调了好久233。
大概想一想就是树分,然后考虑这样路径(u,v)的特征,以根节点(root)切开,u到root的阴阳差值,和v到root巧合互为相反数,然后考虑要有一个点可作为休息点,即u/v到root的路径中要有一点x与u/v到root的阴阳差值相同,然后维护一下就好。
注意的是阴阳差为0的特判……写挂了调好久,对拍也不好写,真是恶心。
1 #define Troy 11/23 2 #define inf 0x7fffffff 3 4 #include <bits/stdc++.h> 5 6 using namespace std; 7 8 inline int read(){ 9 int s=0,k=1;char ch=getchar(); 10 while(ch<‘0‘|ch>‘9‘) ch==‘-‘?k=-1:0,ch=getchar(); 11 while(ch>47&ch<=‘9‘) s=s*10+(ch^48),ch=getchar(); 12 return s*k; 13 } 14 15 const int N=1e5+10; 16 17 typedef long long ll; 18 19 int n; 20 21 struct edges{ 22 int v,w;edges *last; 23 }edge[N<<1],*head[N];int cnt; 24 25 inline void push(int u,int v,int w){ 26 edge[++cnt]=(edges){v,w,head[u]};head[u]=edge+cnt; 27 } 28 29 int tot,root,heavy[N],size[N],top,num,T[N<<1][2],nT[N<<1],pre[N<<1]; 30 31 ll ans,t[N<<1][2]; 32 33 bool vis[N]; 34 35 inline void dfs(int x,int fa){ 36 size[x]=1,heavy[x]=0; 37 for(edges *i=head[x];i;i=i->last)if(!vis[i->v]&&i->v!=fa){ 38 dfs(i->v,x); 39 heavy[x]=max(heavy[x],size[i->v]); 40 size[x]+=size[i->v]; 41 } 42 heavy[x]=max(heavy[x],tot-size[x]); 43 if(top>heavy[x]) 44 top=heavy[x],root=x; 45 } 46 47 #define g(s) t[s+n] 48 #define G(s) T[s+n] 49 #define f(s) pre[s+n] 50 51 inline void update(int x,int s,int fa){ 52 bool a=f(s)>0; 53 if(G(s)[a]!=num) 54 G(s)[a]=num,g(s)[a]=0; 55 ++g(s)[a]; 56 ++f(s); 57 for(edges *i=head[x];i;i=i->last)if(!vis[i->v]&&i->v!=fa) 58 update(i->v,s+(i->w?1:-1),x); 59 --f(s); 60 } 61 62 inline void calc(int x,int s,int fa){ 63 bool a=f(s)>0; 64 ++f(s); 65 if(G(s)[1]==num) 66 ans+=g(s)[1]; 67 if(a&&G(s)[0]==num){ 68 ans+=g(s)[0]; 69 if(!s&&f(s)<=2) ans-=g(s)[0]; 70 } 71 for(edges *i=head[x];i;i=i->last)if(!vis[i->v]&&i->v!=fa) 72 calc(i->v,s+(i->w?-1:1),x); 73 --f(s); 74 } 75 inline void solve(int x){ 76 top=inf; 77 dfs(x,x); 78 vis[x=root]=true; 79 G(0)[0]=++num; 80 g(0)[0]=1; 81 for(edges *i=head[x];i;i=i->last)if(!vis[i->v]){ 82 calc(i->v,i->w?-1:1,x); 83 update(i->v,i->w?1:-1,x); 84 } 85 for(edges *i=head[x];i;i=i->last)if(!vis[i->v]){ 86 tot=size[i->v]; 87 solve(i->v); 88 } 89 } 90 91 int main(){ 92 n=read(); 93 for(int i=1,u,v,w;i^n;++i){ 94 u=read(),v=read(),w=read(); 95 push(u,v,w); 96 push(v,u,w); 97 } 98 tot=n; 99 f(0)=1; 100 solve(1); 101 printf("%lld\n",ans); 102 }
原文:http://www.cnblogs.com/Troywar/p/7885695.html