BZOJ5210 最大连通子块和 <树链剖分+树形DP>

Problem

最大连通子块和

Time  Limit:  20  Sec\mathrm{Time\;Limit:\;20\;Sec}
Memory  Limit:  128  MB\mathrm{Memory\;Limit:\;128\;MB}

Description

给出一棵nn个点,以11为根的有根树,点有点权。
要求支持如下两种操作:

  1. M  x  y\mathrm{M}\;x\;y:将点xx的点权改为yy
  2. Q  x\mathrm{Q}\;x:求以xx为根的子树的最大连通子块和

一棵子树的最大连通子块和指该子树所有子连通块的点权和中的最大值(本题中子连通块包括空连通块,点权和为00)。

Input

第一行两个整数nn,mm,表示树的点数以及操作的数目。
第二行nn个整数,第ii个整数wiw_i表示第ii个点的点权。
接下来的n1n-1行,每行两个整数x,yx,y,表示xxyy之间有一条边相连。
接下来的mm行,每行输入一个操作,含义如题目所述。
保证操作为M  x  y\mathrm{M}\;x\;yQ  x\mathrm{Q}\;x之一。

Output

对于每个Q\mathrm{Q}操作输出一行一个整数,表示询问子树的最大连通子块和。

Sample Input

1
2
3
4
5
6
7
8
9
10
5 4
3 -2 0 3 -1
1 2
1 3
4 2
2 5
Q 1
M 4 1
Q 1
Q 2

Sample Output

1
2
3
4
3
1

HINT

1n,m2×1051\le n,m\le2\times10^5,任意时刻wi109|w_i|\le10^9

Source

CQzhangyu&GXZlegend原创

标签:树链剖分 树上DP

Solution

经典树链剖分维护树上DP\mathrm{DP}。以下解法源自出题人CQzhangyu的博客GXZlegend的博客

首先考虑暴力DP\mathrm{DP},令f[x]f[x]表示xx子树中包含xx的连通块权值和最大值,那么f[x]=max(0,w[x]+xyGf[y])f[x]=\max(0, w[x]+\sum_{x\to y\in\mathrm{G}}f[y])。维护s[x]s[x]表示xx子树中连通块权值和最大值,则s[x]=max(f[x],maxxyGs[y])s[x]=\max(f[x],\max_{x\to y\in\mathrm{G}}s[y])。每次修改后重新DP\mathrm{DP},可做到O(n2)O(n^2)

注意到每次修改后不是所有的ff都变化。用树链剖分维护树上DP\mathrm{DP},可以每次不修改所有的ff值。然而直接维护ff的值不方便,因为递推式中的和式在线段树上不便于计算。于是引入g[x]=w[x]+f[y]g[x]=w[x]+\sum f[y],其中yyxx儿子。那么重链上的转移就变为f[x]=max(0,f[son[x]]+g[x])f[x]=\max(0,f[son[x]]+g[x])。这其实就是最大连续子段和的DP\mathrm{DP}方式,可以用线段树维护带修改最大连续子段和。于是对于每次修改,向上跳重链,在线段树上每条重链的区域内维护最大连续子段和即可。注意这里“每条重链的区域”指的是链顶到链底的距离,而非括号序列。

再考虑如何维护ss。注意到ss每次都直接取最值,带修改后,其实是可删除堆的形式。对每个点维护可删除堆来维护轻儿子的ss值最大值,对于重儿子则在线段树上维护。在线段树上修改后update\mathrm{update}时,用每个结点的ss堆顶元素更新最大子段和,即可动态维护ss

查询时,直接向上跳重链,将该重链对应区间的最大子段和取出来打擂即可。

这样修改复杂度O(log2n)O(\log^2n),查询复杂度O(logn)O(\log n),总复杂度O(mlog2n)O(m\log^2n)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <bits/stdc++.h>
#define MAX_N 200000
#define mid ((s+t)>>1)
using namespace std;
typedef long long lnt;
template <class T> inline void read(T &x) {
x = 0; int c = getchar(), f = 1;
for (; !isdigit(c); c = getchar()) if (c == 45) f = -1;
for (; isdigit(c); c = getchar()) (x *= 10) += f*(c-'0');
}
int n, m, a[MAX_N+5]; vector <int> G[MAX_N+5];
int ind, dep[MAX_N+5], fa[MAX_N+5], son[MAX_N+5];
int sz[MAX_N+5], top[MAX_N+5], into[MAX_N+5], outo[MAX_N+5];
lnt f[MAX_N+5], g[MAX_N+5], mxf[MAX_N+5];
struct node {
lnt s, mx, lmx, rmx;
node () {s = mx = lmx = rmx = 0LL;}
inline friend node operator + (const node &a, const node &b) {
node ret;
ret.s = a.s+b.s, ret.mx = max(max(a.mx, b.mx), a.rmx+b.lmx);
ret.lmx = max(a.lmx, a.s+b.lmx), ret.rmx = max(b.rmx, a.rmx+b.s);
return ret;
}
} tr[MAX_N<<2];
struct heap {
priority_queue <lnt> i, o;
inline void push(lnt x) {i.push(x);}
inline void pop(lnt x) {o.push(x);}
inline lnt top() {
while (!o.empty() && i.top() == o.top())
i.pop(), o.pop();
return i.top();
}
} h[MAX_N+5];
void addedge(int u, int v) {G[u].push_back(v), G[v].push_back(u);}
void DFS(int u) {
sz[u] = 1;
for (int i = 0; i < (int)G[u].size(); i++) {
int v = G[u][i]; if (v == fa[u]) continue;
dep[v] = dep[u]+1, fa[v] = u, DFS(v), sz[u] += sz[v];
if (!son[u] || sz[son[u]] < sz[v]) son[u] = v;
}
}
void DFS(int u, int tp) {
top[u] = tp, g[into[u] = ++ind] = a[u];
if (son[u]) DFS(son[u], tp);
for (int i = 0, v; i < (int)G[u].size(); i++)
if (((v = G[u][i]) ^ fa[u]) && (v ^ son[u]))
DFS(v, v), g[into[u]] += f[v], h[into[u]].push(mxf[v]);
outo[u] = son[u] ? outo[son[u]] : ind;
f[u] = max(f[son[u]]+g[into[u]], 0LL);
mxf[u] = max(mxf[son[u]], max(f[u], h[into[u]].top()));
}
void build(int v, int s, int t) {
if (s == t) {
tr[v].s = g[s], tr[v].mx = max(g[s], h[s].top());
tr[v].lmx = tr[v].rmx = max(g[s], 0LL); return;
}
build(v<<1, s, mid), build(v<<1|1, mid+1, t);
tr[v] = tr[v<<1]+tr[v<<1|1];
}
void modify(int v, int s, int t, int p) {
if (s == t) {
tr[v].s = g[s], tr[v].mx = max(g[s], h[s].top());
tr[v].lmx = tr[v].rmx = max(g[s], 0LL); return;
}
if (p <= mid) modify(v<<1, s, mid, p);
if (p >= mid+1) modify(v<<1|1, mid+1, t, p);
tr[v] = tr[v<<1]+tr[v<<1|1];
}
node query(int v, int s, int t, int l, int r) {
if (s >= l && t <= r) return tr[v]; node ret;
if (l <= mid) ret = ret+query(v<<1, s, mid, l, r);
if (r >= mid+1) ret = ret+query(v<<1|1, mid+1, t, l, r);
return ret;
}
void change(int u, int val) {
node pr, cr;
pr.lmx = g[into[u]];
cr.lmx = g[into[u]]-a[u]+val;
for (int i = 0; u; i++, u = fa[top[u]]) {
g[into[u]] += cr.lmx-pr.lmx;
if (i) h[into[u]].pop(pr.mx), h[into[u]].push(cr.mx);
pr = query(1, 1, n, into[top[u]], outo[u]);
modify(1, 1, n, into[u]);
cr = query(1, 1, n, into[top[u]], outo[u]);
}
}
int main() {
read(n), read(m);
for (int i = 1; i <= n; i++) read(a[i]), h[i].push(0);
for (int i = 1, u, v; i < n; i++) read(u), read(v), addedge(u, v);
DFS(1), DFS(1, 1), build(1, 1, n);
while (m--) {
char opt[2]; int x, y; scanf("%s", opt);
if (opt[0] == 'M') read(x), read(y), change(x, y), a[x] = y;
else read(x), printf("%lld\n", query(1, 1, n, into[x], outo[x]).mx);
}
return 0;
}