// Online distance-bounded weight sums on a tree, answered with a dynamic
// centroid decomposition ("point-divide tree").
//
// Input: n m, then n vertex weights, then n-1 edges, then m operations
// "opt x y" whose x and y are XOR-ed with the previously printed answer:
//   opt == 0 : print the total weight of vertices at distance <= y from x
//   opt != 0 : set the weight of vertex x to y
//
// Every centroid c owns two Fenwick trees indexed by (distance + 1):
//   tree0(c) : weights in c's component, keyed by distance to c
//   tree1(c) : weights in c's component, keyed by distance to cfa[c]; this is
//              subtracted at the parent level so nothing is counted twice.
// All sums are kept in int, exactly as in the original implementation.
// NOTE(review): assumes input values are small enough that no sum overflows int.
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

constexpr int MAX_N = 100000;

// Smallest L with 2^(L+1) >= n: descending jumps 2^L, ..., 2^0 then cover any
// depth difference in a tree of n vertices (yields 16 for n = 100000).
constexpr int lift_levels(int n) {
    int l = 0;
    while ((1 << (l + 1)) < n) ++l;
    return l;
}
constexpr int LOG = lift_levels(MAX_N);

int n, m;
int weight[MAX_N + 5];            // current weight of every vertex
std::vector<int> adj[MAX_N + 5];  // adjacency lists of the input tree
int anc[MAX_N + 5][LOG + 1];      // anc[u][i]: 2^i-th ancestor (tree rooted at 1), 0 past the root
int dep[MAX_N + 5];               // depth in the tree rooted at 1
int cfa[MAX_N + 5];               // parent in the centroid tree, 0 for the top centroid
int blen[MAX_N + 5];              // Fenwick length of each centroid: component size + 1
std::size_t off[MAX_N + 5];       // first slot of each centroid inside the Fenwick pools
bool removed[MAX_N + 5];          // vertex has already been used as a centroid
std::vector<int> bit0, bit1;      // Fenwick pools, sized exactly after decomposition

// Scratch state reused by find_centroid().
std::vector<int> comp;            // current component in BFS order
int bfs_par[MAX_N + 5];           // BFS parent inside the current component (0 for the source)
int sub_sz[MAX_N + 5];            // subtree size inside the BFS tree of the component
int heavy[MAX_N + 5];             // largest child subtree size inside that BFS tree

// Reads one decimal integer from stdin; a '-' seen anywhere before the first
// digit makes it negative. Returns false (leaving x == 0) when EOF is reached
// before any digit, instead of spinning forever on EOF.
template <class T>
bool read(T &x) {
    x = 0;
    int ch = std::getchar();
    int sign = 1;
    for (; !std::isdigit(ch); ch = std::getchar()) {
        if (ch == EOF) return false;
        if (ch == '-') sign = -1;
    }
    for (; std::isdigit(ch); ch = std::getchar()) x = x * 10 + sign * (ch - '0');
    return true;
}

// Fenwick index for distance d: d + 1 clamped to len. Written as a comparison
// so that a huge radius cannot overflow d + 1.
inline int fenwick_index(int d, int len) { return d < len ? d + 1 : len; }

// Adds delta at distance d (d >= 0) in a Fenwick tree of length len.
void bit_add(int *tr, int d, int delta, int len) {
    if (d < 0) return;
    for (int p = fenwick_index(d, len); p <= len; p += p & -p) tr[p] += delta;
}

// Sum over distances 0..d; an empty range (0) for d < 0.
int bit_sum(const int *tr, int d, int len) {
    if (d < 0) return 0;
    int s = 0;
    for (int p = fenwick_index(d, len); p > 0; p -= p & -p) s += tr[p];
    return s;
}

int *tree0(int c) { return bit0.data() + off[c]; }
int *tree1(int c) { return bit1.data() + off[c]; }

// Roots the tree at vertex 1 and fills dep[] and the binary-lifting table.
// Iterative BFS: a path-shaped tree cannot overflow the call stack. BFS order
// guarantees every ancestor's row of anc[][] is complete before it is read.
void build_lca() {
    std::vector<int> order;
    order.reserve(static_cast<std::size_t>(n));
    order.push_back(1);
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int u = order[head];
        for (int i = 1; i <= LOG; ++i) anc[u][i] = anc[anc[u][i - 1]][i - 1];
        for (int v : adj[u]) {
            if (v == anc[u][0]) continue;
            anc[v][0] = u;
            dep[v] = dep[u] + 1;
            order.push_back(v);
        }
    }
}

// Lowest common ancestor of a and b in the tree rooted at 1.
int LCA(int a, int b) {
    if (dep[a] < dep[b]) std::swap(a, b);
    for (int i = LOG; i >= 0; --i)
        if (dep[a] - (1 << i) >= dep[b]) a = anc[a][i];
    if (a == b) return a;
    for (int i = LOG; i >= 0; --i)
        if (anc[a][i] != anc[b][i]) {
            a = anc[a][i];
            b = anc[b][i];
        }
    return anc[a][0];
}

// Number of edges on the path between u and v.
int dist(int u, int v) { return dep[u] + dep[v] - 2 * dep[LCA(u, v)]; }

// Collects (into comp) the component of src among non-removed vertices and
// returns a vertex minimising the largest piece left after removing it.
// Iterative: sizes are accumulated in reverse BFS order (children first).
int find_centroid(int src) {
    comp.clear();
    comp.push_back(src);
    bfs_par[src] = 0;
    for (std::size_t head = 0; head < comp.size(); ++head) {
        const int u = comp[head];
        for (int v : adj[u]) {
            if (v != bfs_par[u] && !removed[v]) {
                bfs_par[v] = u;
                comp.push_back(v);
            }
        }
    }
    const int total = static_cast<int>(comp.size());
    for (int u : comp) {
        sub_sz[u] = 1;
        heavy[u] = 0;
    }
    int best = src;
    int best_w = total;  // larger than any achievable piece size
    for (auto it = comp.rbegin(); it != comp.rend(); ++it) {
        const int u = *it;
        const int w = std::max(heavy[u], total - sub_sz[u]);
        if (w < best_w) {
            best_w = w;
            best = u;
        }
        const int p = bfs_par[u];
        if (p != 0) {
            sub_sz[p] += sub_sz[u];
            heavy[p] = std::max(heavy[p], sub_sz[u]);
        }
    }
    return best;
}

// Builds the centroid tree for the component containing u; parent is the
// centroid one level above (0 at the top). Recursion depth is O(log n).
void decompose(int u, int parent) {
    const int c = find_centroid(u);
    const int total = static_cast<int>(comp.size());  // read before recursion reuses comp
    cfa[c] = parent;
    removed[c] = true;
    blen[c] = total + 1;  // distances to cfa[c] can reach `total`, hence +1
    for (int v : adj[c])
        if (!removed[v]) decompose(v, c);
}

// Adds delta to the weight of x in every centroid level that contains x.
void modify(int x, int delta) {
    bit_add(tree0(x), 0, delta, blen[x]);
    for (int u = x; cfa[u]; u = cfa[u]) {
        const int d = dist(cfa[u], x);
        bit_add(tree1(u), d, delta, blen[u]);
        bit_add(tree0(cfa[u]), d, delta, blen[cfa[u]]);
    }
}

// Total weight of vertices within distance k of x.
int query(int x, int k) {
    int res = bit_sum(tree0(x), k, blen[x]);
    for (int u = x; cfa[u]; u = cfa[u]) {
        const int d = dist(cfa[u], x);
        if (d <= k)
            res += bit_sum(tree0(cfa[u]), k - d, blen[cfa[u]]) -
                   bit_sum(tree1(u), k - d, blen[u]);
    }
    return res;
}

int main() {
    if (!read(n) || !read(m) || n < 1 || n > MAX_N) return 1;
    for (int i = 1; i <= n; ++i)
        if (!read(weight[i])) return 1;
    for (int i = 1; i < n; ++i) {
        int u, v;
        if (!read(u) || !read(v) || u < 1 || u > n || v < 1 || v > n) return 1;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    build_lca();
    decompose(1, 0);

    // Each centroid uses Fenwick indices 1..blen; slot 0 is unused.
    std::size_t pool = 0;
    for (int i = 1; i <= n; ++i) {
        off[i] = pool;
        pool += static_cast<std::size_t>(blen[i]) + 1;
    }
    bit0.assign(pool, 0);
    bit1.assign(pool, 0);
    for (int i = 1; i <= n; ++i) modify(i, weight[i]);

    int last = 0;
    while (m-- > 0) {
        int opt, x, y;
        if (!read(opt) || !read(x) || !read(y)) break;
        x ^= last;
        y ^= last;
        if (x < 1 || x > n) return 1;  // decoded vertex out of range: input is corrupt
        if (opt) {
            modify(x, y - weight[x]);
            weight[x] = y;
        } else {
            std::printf("%d\n", last = query(x, y));
        }
    }
    return 0;
}