1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdio>
#include <vector>

// Counts unordered vertex pairs {a, b} of a tree such that the a-b path has a
// zero edge-weight sum (input edge type 0 weighs -1, type 1 weighs +1) and
// contains an interior vertex x (x != a, b) splitting it into two zero-sum
// halves.
// NOTE(review): this looks like USACO 2013 Open "Yin and Yang" (Luogu P3085);
// problem identification is inferred from the code -- confirm.
//
// Technique: centroid decomposition. Path sums measured from the current
// centroid are stored offset by n (dis[]), so they lie in 1 .. 2n-1 and index
// the cnt/f/g arrays directly.
//
// The traversal helpers are iterative. With MAX_N = 500000, a path-shaped tree
// would otherwise recurse one frame per vertex and can overflow the stack.

constexpr int MAX_N = 500000;

using lnt = long long;

// Reads one (optionally negative) decimal integer from stdin into x.
// Leaves x == 0 if EOF is reached before any digit.
template <class T>
inline void read(T &x) {
    x = 0;
    int c = std::getchar(), f = 1;
    for (; !std::isdigit(c); c = std::getchar()) {
        if (c == EOF) return;  // truncated input: avoid spinning forever
        if (c == '-') f = -1;
    }
    for (; std::isdigit(c); c = std::getchar()) (x *= 10) += f * (c - '0');
}

int n, rt, tot;
// cnt[s]: number of vertices on the current centroid->node path (centroid
// excluded) whose offset sum is s.
int cnt[(MAX_N << 1) + 5];
lnt ans;
int sz[MAX_N + 5], w[MAX_N + 5], dis[MAX_N + 5];
// f[s][k]: vertices of the subtree being scanned with offset sum s.
// g[s][k]: the same, accumulated over earlier subtrees of the centroid, plus
//          the centroid itself in g[n][0].
// k == 1 iff a zero-sum segment ends at that vertex, i.e. an interior "stop"
// exists between it and the centroid.
int f[(MAX_N << 1) + 5][2], g[(MAX_N << 1) + 5][2];
std::vector<int> G[MAX_N + 5], E[MAX_N + 5];  // adjacency: neighbour, weight
bool mrk[MAX_N + 5];                          // vertex already used as a centroid

// Explicit-stack frame for the iterative traversals below.
struct Frame {
    int u, fa, dep;
    std::size_t next;  // index of the next adjacency entry of u to examine
};
// Shared traversal stack. getrt and getdis never run nested, so reuse is safe.
std::vector<Frame> stk;

void insert(int u, int v, int c) { G[u].push_back(v), E[u].push_back(c); }

// Adds an undirected edge; edge type 0 is stored as weight -1.
void addedge(int u, int v, int c) {
    if (!c) c = -1;
    insert(u, v, c), insert(v, u, c);
}

// Finds the centroid of the unmarked component containing u (entered from fa)
// and stores it in rt. The caller sets tot to the component size and
// w[rt = 0] = tot beforehand. Vertices are finalised in post-order; the first
// one with the smallest max-part wins, as in a recursive DFS.
void getrt(int u, int fa) {
    stk.clear();
    sz[u] = 1, w[u] = 0;
    stk.push_back({u, fa, 0, 0});
    while (!stk.empty()) {
        Frame &top = stk.back();
        const int x = top.u;
        if (top.next < G[x].size()) {
            const int v = G[x][top.next++];
            if (v != top.fa && !mrk[v]) {
                sz[v] = 1, w[v] = 0;
                stk.push_back({v, x, 0, 0});  // invalidates `top`; not used again
            }
            continue;
        }
        // All children of x are done: its subtree size is final.
        if ((w[x] = std::max(w[x], tot - sz[x])) < w[rt]) rt = x;
        stk.pop_back();
        if (!stk.empty()) {
            const int p = stk.back().u;
            sz[p] += sz[x];
            w[p] = std::max(w[p], sz[x]);
        }
    }
}

// Scans the unmarked subtree hanging off the centroid fa through u, starting at
// depth dep. The caller must set dis[u]; dis[] is filled for the rest.
// For each vertex y it increments f[dis[y]][k], where k = 1 iff some vertex
// strictly between the centroid and y has the same offset sum.
// It also sets sz[y] to y's subtree size, so afterwards sz[u] is the size of
// u's component once the centroid is removed.
// Returns the maximum depth reached.
int getdis(int u, int fa, int dep) {
    int ret = dep;
    stk.clear();
    auto enter = [&](int y, int from, int d) {
        f[dis[y]][cnt[dis[y]] > 0]++, cnt[dis[y]]++;  // y joins the current path
        sz[y] = 1;
        ret = std::max(ret, d);
        stk.push_back({y, from, d, 0});
    };
    enter(u, fa, dep);
    while (!stk.empty()) {
        Frame &top = stk.back();
        const int x = top.u;
        if (top.next < G[x].size()) {
            const std::size_t i = top.next++;
            const int v = G[x][i];
            if (v != top.fa && !mrk[v]) {
                const int nd = top.dep + 1;  // read before push_back invalidates `top`
                dis[v] = dis[x] + E[x][i];
                enter(v, x, nd);
            }
            continue;
        }
        cnt[dis[x]]--;  // x leaves the current path
        stk.pop_back();
        if (!stk.empty()) sz[stk.back().u] += sz[x];
    }
    return ret;
}

// Processes centroid u. It counts qualifying pairs whose path passes through u,
// then recurses into each remaining component.
void DFS(int u) {
    int r = 0;
    mrk[u] = true, g[n][0] = 1;  // the centroid itself: offset sum 0, no stop
    for (std::size_t i = 0; i < G[u].size(); i++) {
        const int v = G[u][i];
        if (mrk[v]) continue;
        dis[v] = n + E[u][i];
        const int d = getdis(v, u, 1);
        r = std::max(r, d);
        // Both halves are zero-sum with no interior stop, so the centroid is the
        // stop. The -1 excludes the centroid itself as the other endpoint.
        ans += 1LL * f[n][0] * (g[n][0] - 1);
        // Opposite sums combine to zero; at least one side must supply a stop.
        for (int j = -d; j <= d; j++)
            ans += 1LL * f[n + j][1] * g[n - j][1] +
                   1LL * f[n + j][0] * g[n - j][1] +
                   1LL * f[n + j][1] * g[n - j][0];
        // Merge this subtree into g and reset f for the next subtree.
        for (int j = -d; j <= d; j++) {
            g[n + j][0] += f[n + j][0], f[n + j][0] = 0;
            g[n + j][1] += f[n + j][1], f[n + j][1] = 0;
        }
    }
    for (int i = -r; i <= r; i++) g[n + i][0] = g[n + i][1] = 0;
    for (const int v : G[u])
        if (!mrk[v]) {
            // sz[v] was just computed by getdis relative to u, so it is the
            // exact size of v's component.
            w[rt = 0] = tot = sz[v];
            getrt(v, u);
            DFS(rt);
        }
}

// Input: n, then n-1 lines "u v c" with 1-based vertices and edge type c.
// Output: the pair count.
int main() {
    read(n);
    for (int i = 1, u, v, c; i < n; i++) read(u), read(v), read(c), addedge(u, v, c);
    w[rt = 0] = tot = n, getrt(1, 0), DFS(rt);
    return std::printf("%lld\n", ans), 0;
}
|