BZOJ2588 Count on a tree <树上主席树>

Problem

Count on a tree

Description

给定一棵NN个节点的树,每个点有一个权值,对于MM个询问(u,v,k)(u,v,k),你需要回答u  xor  lastansu\ \ xor\ \ lastansvv这两个节点间第KK小的点权。其中lastanslastans是上一个询问的答案,初始为00,即第一个询问的uu是明文。

Input

第一行两个整数N,MN,M
第二行有NN个整数,其中第ii个整数表示点ii的权值。
后面N1N-1行每行22个整数(x,y)(x,y),表示点xx到点yy有一条边。
最后MM行每行33个整数(u,v,k)(u,v,k),表示一组询问。

Output

MM行,表示每个询问的答案。最后一个询问不输出换行符

Sample Input

1
2
3
4
5
6
7
8
9
10
11
12
13
14
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2

Sample Output

1
2
3
4
5
2
8
9
105
7

Hint

1N,M1051\le N,M\le 10^5

标签:LCA 主席树

Solution

这是一个区间第kk小的问题,所以可以很自然的想到值域主席树。但是此题将区间移到了树上,于是写树上主席树。
在解区间第kk小的时候,对于每次询问区间[a,b][a,b],我们需要找到a1a-1位置的线段树和bb位置的线段树,然后递归queryquery的时候用个数相减。对于这道题,我们把每个结点到根的那条链作为一个序列,用区间第kk小的方法存储,然后找到uuvvLCA\mathrm{LCA}(假定它为tt),递归queryquery的时候计算左区间数的个数,即

u结点对应线段树左区间数的个数+v结点数的个数t结点数的个数t的父结点数的个数u结点对应线段树左区间数的个数+v结点\cdots数的个数-t结点\cdots 数的个数-t的父结点\cdots数的个数

tmp=tr[tr[u].ls].val+tr[tr[v].ls].valtr[tr[t].ls].valtr[tr[fa[t]].ls].valtmp = tr[tr[u].ls].val+tr[tr[v].ls].val-tr[tr[t].ls].val-tr[tr[fa[t]].ls].val
写的时候注意强制在线的操作方式和读入数后先离散化。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <bits/stdc++.h>
#define LOG 20
#define MAX_N 100000
#define mid ((s+t)>>1)
using namespace std;
template <class T> inline void read(T &x) {
x = 0; int c = getchar(), f = 1;
for (; !isdigit(c); c = getchar()) if (c == 45) f = -1;
for (; isdigit(c); c = getchar()) (x *= 10) += f*(c-'0');
}
int n, m, q, cnt;
int rt[MAX_N+5], fa[MAX_N+5][LOG+1], dep[MAX_N+5];
int a[MAX_N+5], b[MAX_N+5], h[MAX_N+5], rk[MAX_N+5];
vector <int> G[MAX_N+5]; struct node {int ls, rs, c;} tr[MAX_N*35+5];
bool cmp (const int &x, const int &y) {return a[x] < a[y];}
void addedge(int u, int v) {G[u].push_back(v), G[v].push_back(u);}
void DFS(int u) {
for (int i = 1; i <= LOG; i++) if (dep[u] >= (1<<i)) fa[u][i] = fa[fa[u][i-1]][i-1];
for (int i = 0, v; i < (int)G[u].size(); i++)
if ((v = G[u][i]) ^ fa[u][0]) fa[v][0] = u, dep[v] = dep[u]+1, DFS(v);
}
int LCA(int a, int b) {
if (dep[a] < dep[b]) swap(a, b);
for (int i = LOG; ~i; i--) if (dep[a]-(1<<i) >= dep[b]) a = fa[a][i]; if (a == b) return a;
for (int i = LOG; ~i; i--) if (fa[a][i]^fa[b][i]) a = fa[a][i], b = fa[b][i]; return fa[a][0];
}
void modify(int v, int o, int s, int t, int x) {
tr[v] = tr[o]; if (s == t) {tr[v].c++; return;}
if (x <= mid) modify(tr[v].ls = ++cnt, tr[o].ls, s, mid, x);
else modify(tr[v].rs = ++cnt, tr[o].rs, mid+1, t, x);
tr[v].c = tr[tr[v].ls].c+tr[tr[v].rs].c;
}
int query(int v1, int v2, int v3, int v4, int s, int t, int k) {
if (s == t) return s;
int lsz = tr[tr[v1].ls].c+tr[tr[v2].ls].c-tr[tr[v3].ls].c-tr[tr[v4].ls].c;
if (k <= lsz) return query(tr[v1].ls, tr[v2].ls, tr[v3].ls, tr[v4].ls, s, mid, k);
return query(tr[v1].rs, tr[v2].rs, tr[v3].rs, tr[v4].rs, mid+1, t, k-lsz);
}
void build(int u) {
modify(rt[u] = ++cnt, rt[fa[u][0]], 1, m, b[u]);
for (int i = 0, v; i < (int)G[u].size(); i++)
if ((v = G[u][i]) ^ fa[u][0]) build(v);
}
int main() {
read(n), read(q);
for (int i = 1; i <= n; i++) read(a[i]), rk[i] = i;
for (int i = 1, u, v; i < n; i++)
read(u), read(v), addedge(u, v);
sort(rk+1, rk+n+1, cmp);
for (int i = 1; i <= n; b[rk[i++]] = m)
if (i == 1 || (a[rk[i]]^a[rk[i-1]])) h[++m] = a[rk[i]];
DFS(1), build(1); int lst = 0;
while (q--) {
int u, v, k; read(u), read(v), read(k), u ^= lst; int anc = LCA(u, v);
printf("%d\n", lst = h[query(rt[u], rt[v], rt[anc], rt[fa[anc][0]], 1, m, k)]);
}
return 0;
}