BZOJ3697 采药人的路径 <点分治>

Problem

采药人的路径

Time  Limit:  10  Sec\mathrm{Time\;Limit:\;10\;Sec}
Memory  Limit:  128  MB\mathrm{Memory\;Limit:\;128\;MB}

Description

采药人的药田是一个树状结构,每条路径上都种植着同种药材。
采药人以自己对药材独到的见解,对每种药材进行了分类。大致分为两类,一种是阴性的,一种是阳性的。
采药人每天都要进行采药活动。他选择的路径是很有讲究的,他认为阴阳平衡是很重要的,所以他走的一定是两种药材数目相等的路径。
采药工作是很辛苦的,所以他希望他选出的路径中有一个可以作为休息站的节点(不包括起点和终点),满足起点到休息站和休息站到终点的路径也是阴阳平衡的。
他想知道他一共可以选择多少种不同的路径。

Input

11行包含一个整数NN
接下来N1N-1行,每行包含三个整数ai,bi,tia_i,b_i,t_i,表示aia_ibib_i这条路上药材的类型为tit_i

Output

输出符合采药人要求的路径数目。

Sample Input

1
2
3
4
5
6
7
7
1 2 0
3 1 1
2 4 0
5 2 0
6 3 1
5 7 1

Sample Output

1
1

HINT

对于100%100\%的数据,N105N\le10^5

标签:点分治

Solution

点分治基础题。

每次找重心作分治中心,同一子树内的路径数递归计算,只考虑经过当前分治中心的路径数。
对于当前分治中心,处理出其余未分治到的点与其的路径上有多少阴性和阳性道路。设阴性道路边权为1-1,阳性为11,那么若两个点到分治中心的路径拼起来可以构成一条合法道路,一定需要满足两个条件:

  • 路径总长为00
  • 在两条路径中一定有至少一条在路径上存在两个点,使得分治中心到这个两点的长度相同,并且这个长度不为00。特殊情况是两条路径的长度都为00也可。

f[i][0/1]f[i][0/1]表示现在枚举到的子树中,与当前分治中心距离为ii的路径上有/没有两个离分治中心距离相同的点的路径条数;用g[i][0/1]g[i][0/1]表示同样的意义,只是是在前面已枚举的子树中这样的路径条数。那么从当前子树和前面的子树各选一条路径,拼成新路径,这样对答案的贡献是f[x][0]×g[x][1]+f[x][1]×g[x][0]+f[x][1]×g[x][1]f[x][0]\times g[-x][1]+f[x][1]\times g[-x][0]+f[x][1]\times g[-x][1]。除此之外还需要加上两条不同子树中到分治中心长为00的路径组成的路径条数,即f[0][0]×g[0][0]f[0][0]\times g[0][0]

点分时每次DFS\mathrm{DFS}预处理ff,gg统计即可。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <bits/stdc++.h>
#define MAX_N 500000
using namespace std;
typedef long long lnt;
template <class T> inline void read(T &x) {
x = 0; int c = getchar(), f = 1;
for (; !isdigit(c); c = getchar()) if (c == 45) f = -1;
for (; isdigit(c); c = getchar()) (x *= 10) += f*(c-'0');
}
int n, rt, tot, cnt[(MAX_N<<1)+5]; lnt ans;
int sz[MAX_N+5], w[MAX_N+5], dis[MAX_N+5];
int f[(MAX_N<<1)+5][2], g[(MAX_N<<1)+5][2];
vector <int> G[MAX_N+5], E[MAX_N+5]; bool mrk[MAX_N+5];
void insert(int u, int v, int c) {G[u].push_back(v), E[u].push_back(c);}
void addedge(int u, int v, int c) {if (!c) c = -1; insert(u, v, c), insert(v, u, c);}
void getrt(int u, int fa) {
sz[u] = 1, w[u] = 0;
for (int i = 0, v; i < (int)G[u].size(); i++)
if (((v = G[u][i]) ^ fa) && !mrk[v])
getrt(v, u), sz[u] += sz[v], w[u] = max(w[u], sz[v]);
if ((w[u] = max(w[u], tot-sz[u])) < w[rt]) rt = u;
}
int getdis(int u, int fa, int dep) {
int ret = dep; f[dis[u]][cnt[dis[u]]>0]++, cnt[dis[u]]++;
for (int i = 0, v; i < (int)G[u].size(); i++)
if (((v = G[u][i]) ^ fa) && !mrk[v])
dis[v] = dis[u]+E[u][i], ret = max(ret, getdis(v, u, dep+1));
cnt[dis[u]]--; return ret;
}
void DFS(int u) {
int r = 0; mrk[u] = true, g[n][0] = 1;
for (int i = 0, v, d; i < (int)G[u].size(); i++)
if (!mrk[v = G[u][i]]) {
dis[v] = n+E[u][i], d = getdis(v, u, 1);
r = max(r, d), ans += 1LL*f[n][0]*(g[n][0]-1);
for (int j = -d; j <= +d; j++)
ans += 1LL*f[n+j][1]*g[n-j][1],
ans += 1LL*f[n+j][0]*g[n-j][1],
ans += 1LL*f[n+j][1]*g[n-j][0];
for (int j = -d; j <= +d; j++)
g[n+j][0] += f[n+j][0], f[n+j][0] = 0,
g[n+j][1] += f[n+j][1], f[n+j][1] = 0;
}
for (int i = -r; i <= +r; i++) g[n+i][0] = g[n+i][1] = 0;
for (int i = 0, v; i < (int)G[u].size(); i++) if (!mrk[v = G[u][i]])
w[rt = 0] = tot = sz[v], getrt(v, u), DFS(rt);
}
int main() {
read(n);
for (int i = 1, u, v, c; i < n; i++)
read(u), read(v), read(c), addedge(u, v, c);
w[rt = 0] = tot = n, getrt(1, 0), DFS(rt);
return printf("%lld\n", ans), 0;
}