1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
// Rooted-tree point/subtree add with root-path sum queries.
//
// The tree is flattened with a heavy-light decomposition: every node gets a
// position into[u], each subtree occupies the contiguous range
// [into[u], outo[u]], and each heavy chain is contiguous as well.  A segment
// tree with lazy range-add over those positions answers all three operations.
//
// Input (stdin): n m, then n initial node values, then n-1 edges "u v",
// then m operations:
//   1 p x : add x to node p
//   2 p x : add x to every node in the subtree of p
//   3 p   : print the sum of the values on the path from p up to node 1
// Malformed or out-of-range input makes the program exit with status 1.

#include <cstddef>
#include <cstdio>
#include <vector>

// Capacity of the static arrays; n must not exceed this.
#define MAX_N 100000

typedef long long lnt;

std::vector<int> G[MAX_N + 5];                  // adjacency lists, nodes are 1-based
int n, m, rt, c[MAX_N + 5], sz[MAX_N + 5];      // counts, root, initial values, subtree sizes
int dep[MAX_N + 5], fa[MAX_N + 5], son[MAX_N + 5];  // depth, parent (0 for root), heavy child (0 if none)
int into[MAX_N + 5], outo[MAX_N + 5], top[MAX_N + 5], ind;  // first/last position of subtree, chain head, position counter
lnt seg[(MAX_N << 2) + 5], tag[(MAX_N << 2) + 5];  // segment sums and pending per-element additions

// First decomposition pass over the subtree hanging below u: fills dep[], fa[],
// sz[] and son[].  fa[u] and dep[u] must already be set by the caller (both are
// 0 for the root).  The graph is assumed to be a tree.
//
// Written with an explicit stack so that a path-shaped tree cannot exhaust the
// call stack.
void DFS(int u) {
    std::vector<int> order;          // nodes in an order where parents precede children
    std::vector<int> pending(1, u);
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        sz[x] = 1;
        for (std::size_t i = 0; i < G[x].size(); ++i) {
            const int v = G[x][i];
            if (v == fa[x]) continue;
            dep[v] = dep[x] + 1;
            fa[v] = x;
            pending.push_back(v);
        }
    }

    // Walk the order backwards so every child is folded into its parent before
    // the parent itself is folded upwards.  order[0] is u, whose size is not
    // propagated to its own parent by this function.
    for (std::size_t k = order.size(); k-- > 1;) {
        const int x = order[k];
        sz[fa[x]] += sz[x];
    }

    // Heavy child = first child (in adjacency order) with the largest subtree.
    for (std::size_t k = 0; k < order.size(); ++k) {
        const int x = order[k];
        for (std::size_t i = 0; i < G[x].size(); ++i) {
            const int v = G[x][i];
            if (v == fa[x]) continue;
            if (!son[x] || sz[son[x]] < sz[v]) son[x] = v;
        }
    }
}

// Second decomposition pass: numbers the subtree of u starting at ind + 1,
// visiting the heavy child first and then the light children in adjacency
// order, and records top[] (head of each heavy chain).  u itself is placed on
// the chain headed by tp.  Requires fa[], son[] and sz[] from DFS(int).
void DFS(int u, int tp) {
    top[u] = tp;
    std::vector<int> pending(1, u);
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        into[x] = ++ind;
        // The subtree of x takes sz[x] consecutive positions, so its last
        // position is known without waiting for the traversal to come back.
        outo[x] = into[x] + sz[x] - 1;

        // Light children start their own chains.  They are pushed in reverse so
        // that they are popped in adjacency order.
        for (std::size_t i = G[x].size(); i-- > 0;) {
            const int v = G[x][i];
            if (v == fa[x] || v == son[x]) continue;
            top[v] = v;
            pending.push_back(v);
        }
        // The heavy child is pushed last so it is numbered immediately after x,
        // keeping the chain contiguous.
        if (son[x]) {
            top[son[x]] = top[x];
            pending.push_back(son[x]);
        }
    }
}

// Recomputes the sum stored at segment-tree node v from its two children.
void updata(int v) { seg[v] = seg[v << 1] + seg[v << 1 | 1]; }

// Pushes the pending addition of node v (covering [s, t]) down to its children.
void downtag(int v, int s, int t) {
    if (!tag[v]) return;
    const int mid = (s + t) >> 1;
    seg[v << 1] += tag[v] * (mid - s + 1);
    seg[v << 1 | 1] += tag[v] * (t - mid);
    tag[v << 1] += tag[v];
    tag[v << 1 | 1] += tag[v];
    tag[v] = 0;
}

// Adds x to every position in [l, r]; node v covers [s, t].
void modify(int v, int s, int t, int l, int r, lnt x) {
    if (s >= l && t <= r) {
        seg[v] += x * (t - s + 1);
        tag[v] += x;
        return;
    }
    const int mid = (s + t) >> 1;
    downtag(v, s, t);
    if (l <= mid) modify(v << 1, s, mid, l, r, x);
    if (r >= mid + 1) modify(v << 1 | 1, mid + 1, t, l, r, x);
    updata(v);
}

// Returns the sum of the positions in [l, r]; node v covers [s, t].
lnt query(int v, int s, int t, int l, int r) {
    if (s >= l && t <= r) return seg[v];
    const int mid = (s + t) >> 1;
    lnt ret = 0;
    downtag(v, s, t);
    if (l <= mid) ret += query(v << 1, s, mid, l, r);
    if (r >= mid + 1) ret += query(v << 1 | 1, mid + 1, t, l, r);
    // No updata() here: after downtag() seg[v] already equals the sum of its
    // children and a query does not change them.
    return ret;
}

// Operation 1: add x to node p.
void solve1(int p, lnt x) { modify(1, 1, n, into[p], into[p], x); }

// Operation 2: add x to every node in the subtree of p.
void solve2(int p, lnt x) { modify(1, 1, n, into[p], outo[p], x); }

// Operation 3: print the sum of the values on the path from p up to rt.
void solve3(int p) {
    lnt ans = 0;
    // Consume one whole chain segment per step until p sits on the root's chain.
    while (top[p] != rt) {
        ans += query(1, 1, n, into[top[p]], into[p]);
        p = fa[top[p]];
    }
    ans += query(1, 1, n, into[rt], into[p]);
    std::printf("%lld\n", ans);
}

int main() {
    if (std::scanf("%d%d", &n, &m) != 2 || n < 1 || n > MAX_N) return 1;
    rt = 1;

    for (int i = 1; i <= n; i++) {
        if (std::scanf("%d", c + i) != 1) return 1;
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        if (std::scanf("%d%d", &u, &v) != 2) return 1;
        // Endpoints index the fixed-size arrays, so reject anything outside 1..n.
        if (u < 1 || u > n || v < 1 || v > n) return 1;
        G[u].push_back(v);
        G[v].push_back(u);
    }

    DFS(rt);
    DFS(rt, rt);
    // Every node must have received a position; otherwise the edges did not
    // connect all n nodes and into[] would hold 0 for the unreached ones.
    if (ind != n) return 1;

    for (int i = 1; i <= n; i++) modify(1, 1, n, into[i], into[i], c[i]);

    while (m-- > 0) {
        int opt;
        if (std::scanf("%d", &opt) != 1) break;  // input ended early
        if (opt == 1 || opt == 2) {
            int p;
            lnt x;
            if (std::scanf("%d%lld", &p, &x) != 2 || p < 1 || p > n) return 1;
            if (opt == 1) solve1(p, x);
            else solve2(p, x);
        } else if (opt == 3) {
            int p;
            if (std::scanf("%d", &p) != 1 || p < 1 || p > n) return 1;
            solve3(p);
        }
    }
    return 0;
}
|