首先,我们要脑洞大开想到点分治。
注意到每个点都要输出对应的\(ans\),所以想到在点分树上跳,得到答案。
一个点答案的贡献要分为好几个部分。
这是最容易统计的部分。你需要得到点分树上自己的子树对自己的贡献。
显然一个dfs即可搞定。
假设当前要求的点是\(x\),枚举它的祖先为\(y\)。
我们需要得到\(y\rightarrow x\)这条路径上的贡献乘上\(size_y-size_x\),其中\(size_x\)表示x所在子树的大小,\(size_y\)表示\(y\)在点分树上子树大小。
显然这个可以在建点分树时用一个dfs搞定。
还是上面的\(x,y\),我们需要统计\(y\)子树中除去\(x\)子树的点对\(x\)的贡献。
考虑每种颜色分开考虑。
对于每个颜色,统计会出现它的路径,可以一个dfs搞定。
然后把\(x\)所在的子树的贡献删去之后就可以了。
发现第二部分和第三部分会算重,要把重复的去掉。
对于\(y\rightarrow x\)路径上出现过的颜色,要把它出现过的路径数删去。
也是建树时一个dfs可以搞定的。
你发现上面这些全是建树时就可以算出来的,于是你根本不需要建树,直接搞就好了。
细节较多,祝你好运。
不过把我的缺省源删去后其实还挺短的,只有大概110行。
#include<bits/stdc++.h>
clock_t t=clock();
namespace my_std{
using namespace std;
#define pii pair<int,int>
#define fir first
#define sec second
#define MP make_pair
#define rep(i,x,y) for (int i=(x);i<=(y);i++)
#define drep(i,x,y) for (int i=(x);i>=(y);i--)
#define go(x) for (int i=head[x];i;i=edge[i].nxt)
#define templ template<typename T>
#define sz 101010
typedef long long ll;
typedef double db;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
templ inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
templ inline bool chkmax(T &x,T y){return x<y?x=y,1:0;}
templ inline bool chkmin(T &x,T y){return x>y?x=y,1:0;}
templ inline void read(T& t)
{
t=0;char f=0,ch=getchar();double d=0.1;
while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
t=(f?-t:t);
}
template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
char sr[1<<21],z[20];int C=-1,Z=0;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
inline void print(register int x)
{
if(C>1<<20)Ot();if(x<0)sr[++C]='-',x=-x;
while(z[++Z]=x%10+48,x/=10);
while(sr[++C]=z[Z],--Z);sr[++C]='\n';
}
void file()
{
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
#endif
}
inline void chktime()
{
#ifndef ONLINE_JUDGE
cout<<(clock()-t)/1000.0<<'\n';
#endif
}
#ifdef mod
ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;return ret;}
ll inv(ll x){return ksm(x,mod-2);}
#else
ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;return ret;}
#endif
// inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std;
int n;
int col[sz];
struct hh{int t,nxt;}edge[sz<<1];
int head[sz],ecnt;
void make_edge(int f,int t)
{
edge[++ecnt]=(hh){t,head[f]};
head[f]=ecnt;
edge[++ecnt]=(hh){f,head[t]};
head[t]=ecnt;
}
ll ans[sz];
bool vis[sz];
int size[sz],mn,root,sum;
#define v edge[i].t
void calcsize(int x,int fa)
{
size[x]=1;
go(x)
if (v!=fa&&!vis[v])
calcsize(v,x),size[x]+=size[v];
}
void findroot(int x,int fa)
{
int S=-1;
go(x)
if (v!=fa&&!vis[v])
findroot(v,x),chkmax(S,size[v]);
chkmax(S,sum-size[x]);
if (chkmin(mn,S)) root=x;
}
void findroot(int x){calcsize(x,0);mn=1e9;sum=size[x];findroot(x,0);}
int cnt[sz]; // 每种颜色出现次数
int S1; // 出现过的颜色数量
void dfs1(int x,int u,int fa) // 统计u往下的路径
{
++cnt[col[x]];if (cnt[col[x]]==1) ++S1;
ans[u]+=S1;
go(x) if (!vis[v]&&v!=fa) dfs1(v,u,x);
--cnt[col[x]];if (cnt[col[x]]==0) --S1;
}
ll S[sz],SS; // 每种颜色出现过的路径数,\sum_i S[i]
void dfs2(int x,int fa) // 统计每种颜色出现过的路径数
{
++cnt[col[x]];if (cnt[col[x]]==1) S[col[x]]+=size[x],SS+=size[x];
go(x) if (!vis[v]&&v!=fa) dfs2(v,x);
--cnt[col[x]];
}
void clr(int x,int fa)
{
++cnt[col[x]];if (cnt[col[x]]==1) S[col[x]]-=size[x],SS-=size[x];
go(x) if (!vis[v]&&v!=fa) clr(v,x);
--cnt[col[x]];
}
ll S2; // 已经重复需要被减去的路径数
void dfs3(int x,int fa,int Size)
{
++cnt[col[x]];if (cnt[col[x]]==1) S2+=S[col[x]],++S1;
ans[x]+=SS+S1*Size-S2;
go(x) if (!vis[v]&&v!=fa) dfs3(v,x,Size);
--cnt[col[x]];if (cnt[col[x]]==0) S2-=S[col[x]],--S1;
}
void CLR(int x,int fa)
{
S[col[x]]=0;
go(x) if (v!=fa&&!vis[v]) CLR(v,x);
}
void calc(int x)
{
calcsize(x,0);
dfs1(x,x,0);
dfs2(x,0);
go(x) if (!vis[v])
{
cnt[col[x]]=1;SS-=size[v];S[col[x]]-=size[v];clr(v,x);cnt[col[x]]=0;
dfs3(v,x,size[x]-size[v]);
cnt[col[x]]=1;SS+=size[v];S[col[x]]+=size[v];dfs2(v,x);cnt[col[x]]=0;
}
CLR(x,0);SS=0;
}
void solve(int x)
{
vis[x]=1;
calc(x);
go(x)
if (!vis[v])
findroot(v),solve(root);
}
#undef v
int main()
{
file();
read(n);
int x,y;
rep(i,1,n) read(col[i]);
rep(i,1,n-1) read(x,y),make_edge(x,y);
findroot(1);solve(root);
rep(i,1,n) printf("%lld\n",ans[i]);
return 0;
}
原文:https://www.cnblogs.com/p-b-p-b/p/10482817.html