1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=4e4+8; 5 int n,all,mx,root,cntd=0,cnt=0; 6 ll ans=0,k; 7 int h[N],siz[N]; 8 ll q[N],dis[N]; 9 bool vis[N]; 10 struct edge{ 11 int to,nex,w; 12 }e[N<<2]; 13 void add(int u,int v,int w){ 14 e[++cnt]=(edge){v,h[u],w}; 15 h[u]=cnt; 16 } 17 void getrt(int u,int fa){ 18 siz[u]=1; 19 int num=0; 20 for(int i=h[u];i;i=e[i].nex){ 21 int v=e[i].to; 22 if(v==fa || vis[v]) continue; 23 getrt(v,u); 24 siz[u]+=siz[v]; 25 num=max(num,siz[v]); 26 } 27 num=max(num,all-siz[u]); 28 if(num<mx){ 29 mx=num; 30 root=u; 31 } 32 } 33 void getdis(int u,int fa){ 34 q[++cntd]=dis[u]; 35 for(int i=h[u];i;i=e[i].nex){ 36 int v=e[i].to; 37 if(v==fa || vis[v]) continue; 38 dis[v]=dis[u]+e[i].w; 39 getdis(v,u); 40 } 41 } 42 ll calc(int u,int len){ 43 cntd=0; 44 dis[u]=len; 45 getdis(u,0); 46 ll sum=0; 47 int l=1,r=cntd; 48 sort(q+1,q+1+cntd); 49 while(l<r){ 50 if(q[l]+q[r]<=k) sum+=r-l,++l; 51 else --r; 52 } 53 return sum; 54 } 55 void dfs(int u){ 56 ans+=calc(u,0); 57 vis[u]=1; 58 for(int i=h[u];i;i=e[i].nex){ 59 int v=e[i].to; 60 if(vis[v]) continue; 61 ans-=calc(v,e[i].w); 62 all=siz[v]; 63 mx=0x3f3f3f3f; 64 getrt(v,0); 65 dfs(root); 66 } 67 } 68 int main(){ 69 scanf("%d",&n); 70 for(int i=1;i<n;++i){ 71 int a,b,c;scanf("%d%d%d",&a,&b,&c); 72 add(a,b,c); 73 add(b,a,c); 74 } 75 scanf("%lld",&k); 76 all=n; 77 mx=0x3f3f3f3f; 78 getrt(1,0); 79 dfs(root); 80 printf("%lld",ans); 81 return 0; 82 }
原文:https://www.cnblogs.com/xiaobuxie/p/11373840.html