1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>

// Path max-XOR queries on a tree, answered with a persistent binary trie.
//
// Input, repeated until EOF: "n m", the vertex values c[1..n], n-1 undirected
// edges "u v", then m queries "u v x". For every query one line is printed:
// the maximum of (x XOR c[w]) over every vertex w on the tree path u..v.
//
// Vertex u owns trie version root[u], which holds the values of all vertices
// on the path from vertex 1 down to u. The values on path u..v are therefore
// root[u] + root[v] - root[lca] - root[parent(lca)], and query() walks those
// four versions in lockstep.
//
// Assumes vertices are numbered 1..n with 1 <= n <= MAX_N and that every c[i]
// lies in [0, 2^(MAX_D+1)) -- TODO confirm against the problem's limits.

constexpr int MAX_N = 100000;
// The trie branches on bits MAX_D..0 of each value.
constexpr int MAX_D = 15;
// One insert() allocates one node per branching bit (MAX_D..0) plus a leaf.
constexpr int TRIE_NODES_PER_VALUE = MAX_D + 2;

int n, m, cnt, c[MAX_N + 5], root[MAX_N + 5];
int anc[MAX_N + 5][MAX_D + 5], dep[MAX_N + 5];
bool vis[MAX_N + 5];
std::vector<int> G[MAX_N + 5];
struct node {
    int ls, rs, size;
} trie[MAX_N * 20 + 500];

// Node 0 is the shared empty trie; every vertex adds TRIE_NODES_PER_VALUE nodes.
static_assert(MAX_N * TRIE_NODES_PER_VALUE + 1 <= MAX_N * 20 + 500,
              "trie pool too small for MAX_N values");
// Binary lifting needs levels 0..floor(log2(max depth)); anc has MAX_D + 5 columns.
static_assert((1 << (MAX_D + 5)) > MAX_N,
              "anc[][] has too few binary-lifting levels for MAX_N");

// Resets the per-test state. Trie nodes are not cleared: node 0 is never
// written, and insert() overwrites every node it allocates before reading it.
void init() {
    cnt = 0;
    std::memset(root, 0, sizeof(root));
    std::memset(anc, 0, sizeof(anc));
    std::memset(dep, 0, sizeof(dep));
    std::memset(vis, false, sizeof(vis));
    for (int i = 1; i <= n; i++) G[i].clear();
}

// Creates trie version v as a copy of version o plus one occurrence of val.
// `range` is the bit currently branched on (1 << MAX_D at the root); range == 0
// is a leaf. New node indices come from the global counter cnt.
void insert(int v, int o, int val, int range) {
    trie[v] = trie[o];
    if (range == 0) {
        trie[v].size++;
        return;
    }
    const int fresh = ++cnt;
    if (val & range) {
        trie[v].rs = fresh;
        insert(fresh, trie[o].rs, val, range >> 1);
    } else {
        trie[v].ls = fresh;
        insert(fresh, trie[o].ls, val, range >> 1);
    }
    trie[v].size = trie[trie[v].ls].size + trie[trie[v].rs].size;
}

// Roots the tree at `src` and, in one depth-first pass, fills dep[], the
// binary-lifting table anc[][] and the trie version root[] of every vertex.
// An explicit stack replaces recursion so that a path-shaped tree with MAX_N
// vertices cannot overflow the call stack.
void preprocess(int src) {
    static int stk[MAX_N + 5];  // each vertex is pushed at most once
    int top = 0;
    vis[src] = true;
    stk[top++] = src;
    while (top > 0) {
        const int u = stk[--top];
        // Every ancestor of u was popped before u was pushed, so their anc[]
        // rows and trie versions are already complete.
        for (int i = 1; (1 << i) <= dep[u]; i++) anc[u][i] = anc[anc[u][i - 1]][i - 1];
        root[u] = ++cnt;
        // For src, anc[src][0] == 0 and root[0] == 0 is the empty trie.
        insert(root[u], root[anc[u][0]], c[u], 1 << MAX_D);
        for (int w : G[u]) {
            if (vis[w]) continue;
            vis[w] = true;
            anc[w][0] = u;
            dep[w] = dep[u] + 1;
            stk[top++] = w;
        }
    }
}

// Lowest common ancestor of a and b via binary lifting (requires preprocess()).
int LCA(int a, int b) {
    if (dep[a] < dep[b]) std::swap(a, b);
    int lg = -1;  // largest lg with 2^lg <= dep[a]; -1 when a is the root
    while ((1 << (lg + 1)) <= dep[a]) lg++;
    for (int j = lg; j >= 0; j--)
        if (dep[a] - (1 << j) >= dep[b]) a = anc[a][j];
    if (a == b) return a;
    for (int j = lg; j >= 0; j--) {
        if (anc[a][j] != anc[b][j]) {
            a = anc[a][j];
            b = anc[b][j];
        }
    }
    return anc[a][0];
}

// Child of trie node v on the chosen side.
static int child(int v, bool right) { return right ? trie[v].rs : trie[v].ls; }

// Greedy maximum XOR against the multiset v1 + v2 - v3 - v4, over bits
// range, range/2, ..., 1 of x. Returns only those low bits of the best XOR.
int query(int v1, int v2, int v3, int v4, int x, int range) {
    if (range == 0) return 0;
    // Prefer the child whose bit differs from x's bit: it contributes `range`.
    bool right = (x & range) == 0;
    const int a1 = child(v1, right), a2 = child(v2, right);
    const int a3 = child(v3, right), a4 = child(v4, right);
    if (trie[a1].size + trie[a2].size - trie[a3].size - trie[a4].size > 0)
        return range + query(a1, a2, a3, a4, x, range >> 1);
    right = !right;
    return query(child(v1, right), child(v2, right), child(v3, right), child(v4, right), x,
                 range >> 1);
}

int main() {
    // Bits of x above the trie's width are identical in every x ^ c[w].
    const int high_mask = ~((2 << MAX_D) - 1);
    while (std::scanf("%d%d", &n, &m) == 2) {
        init();
        for (int i = 1; i <= n; i++) std::scanf("%d", &c[i]);
        for (int i = 1; i < n; i++) {
            int u, v;
            std::scanf("%d%d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        preprocess(1);
        while (m--) {
            int u, v, x;
            std::scanf("%d%d%d", &u, &v, &x);
            const int lca = LCA(u, v);
            const int low =
                query(root[u], root[v], root[lca], root[anc[lca][0]], x, 1 << MAX_D);
            std::printf("%d\n", (x & high_mask) | low);
        }
    }
    return 0;
}
|