点分治板题。
把树从重心分开,递归处理各个子树,并处理路径跨重心的情况:计算所有点中距离重心的距离相加为3的倍数的点对数,再减去每个子树中距离重心距离和加上二倍“重心到该子树根距离”为3的倍数的点对数。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
#define space putchar(' ')
#define enter putchar('\n')
template <class T>
void read(T &x){
char c;
bool op = 0;
while(c = getchar(), c < '0' || c > '9')
if(c == '-') op = 1;
x = c - '0';
while(c = getchar(), c >= '0' && c <= '9')
x = x * 10 + c - '0';
if(op) x = -x;
}
template <class T>
void write(T x){
if(x < 0) putchar('-'), x = -x;
if(x >= 10) write(x / 10);
putchar('0' + x % 10);
}
const int N = 20005, INF = 0x3f3f3f3f;
int n;
int ecnt, adj[N], nxt[2*N], go[2*N], w[2*N];
int fa[N], cnt[3], dis[N], sze[N], son[N], que[N], qr;
bool vis[N];
ll ans;
void add(int u, int v, int ww){
go[++ecnt] = v;
nxt[ecnt] = adj[u];
adj[u] = ecnt;
w[ecnt] = ww;
}
int calcG(int u){
que[qr = 1] = u, fa[u] = 0;
for(int ql = 1; ql <= qr; ql++){
u = que[ql], sze[u] = 1, son[u] = 0;
for(int e = adj[u], v; e; e = nxt[e])
if(!vis[v = go[e]] && v != fa[u])
que[++qr] = v, fa[v] = u;
}
int ret, mi = INF;
for(int ql = qr; ql; ql--){
u = que[ql];
sze[fa[u]] += sze[u];
son[fa[u]] = max(son[fa[u]], sze[u]);
son[u] = max(son[u], qr - sze[u]);
if(son[u] < mi) ret = u, mi = son[u];
}
return ret;
}
ll calc(int u, int l){
memset(cnt, 0, sizeof(cnt));
que[qr = 1] = u, dis[u] = l % 3, fa[u] = 0;
for(int ql = 1; ql <= qr; ql++){
u = que[ql], cnt[dis[u]]++;
for(int e = adj[u], v; e; e = nxt[e])
if(!vis[v = go[e]] && v != fa[u])
que[++qr] = v, fa[v] = u, dis[v] = (dis[u] + w[e]) % 3;
}
return (ll)cnt[0] * cnt[0] + (ll)cnt[1] * cnt[2] * 2;
//一个是1一个是2的情况,正着反着各算一种,要乘二
}
void solve(int u){
u = calcG(u);
vis[u] = 1;
ans += calc(u, 0);
for(int e = adj[u]; e; e = nxt[e])
if(!vis[go[e]]) ans -= calc(go[e], w[e]);
for(int e = adj[u]; e; e = nxt[e])
if(!vis[go[e]]) solve(go[e]);
}
ll gcd(ll a, ll b){
return b ? gcd(b, a % b) : a;
}
int main(){
read(n);
for(int i = 1, u, v, ww; i < n; i++)
read(u), read(v), read(ww), add(u, v, ww), add(v, u, ww);
solve(1);
ll g = gcd(ans, (ll)n * n);
printf("%lld/%lld\n", ans / g, (ll)n * n / g);
return 0;
}
原文:http://www.cnblogs.com/RabbitHu/p/BZOJ2152.html