考虑将每种颜色构成的极小连通块缩点,然后直接跑树形dp即可,即f[i][0/1]表示子树内是否有颜色向上延伸时删边的方案数。dp时需要去除某点的贡献,最好用前后缀积的做法而不是求逆。
至于如何缩点,假装要给每种颜色建虚树,按dfs序排一下序找到所有虚树上的边,标记所有虚树上的点(包括不在虚树中但在虚树上两点的路径中)即可。然后重建树。注意标记过程中判一下无解。
(这个div3F代码长度怎么跟我div1F差不多了啊?
#include<iostream> #include<cstdio> #include<cmath> #include<cstdlib> #include<cstring> #include<algorithm> #include<vector> using namespace std; #define ll long long #define P 998244353 #define N 300010 char getc(){char c=getchar();while ((c<‘A‘||c>‘Z‘)&&(c<‘a‘||c>‘z‘)&&(c<‘0‘||c>‘9‘)) c=getchar();return c;} int gcd(int n,int m){return m==0?n:gcd(m,n%m);} int read() { int x=0,f=1;char c=getchar(); while (c<‘0‘||c>‘9‘) {if (c==‘-‘) f=-1;c=getchar();} while (c>=‘0‘&&c<=‘9‘) x=(x<<1)+(x<<3)+(c^48),c=getchar(); return x*f; } int n,m,a[N],p[N<<1],dfn[N],fa[N][20],deep[N],f[N<<1],F[N<<1][2],pre[N<<1],suf[N<<1],t,T,cnt,stk[N],top; vector<int> pos[N]; struct data{int to,nxt; }edge[N<<1],tree[N<<2]; void addedge(int x,int y){t++;edge[t].to=y,edge[t].nxt=p[x],p[x]=t;} void new_addedge(int x,int y){T++;tree[T].to=y,tree[T].nxt=p[x],p[x]=T;} void dfs(int k) { dfn[k]=++cnt; for (int i=p[k];i;i=edge[i].nxt) if (edge[i].to!=fa[k][0]) { fa[edge[i].to][0]=k; deep[edge[i].to]=deep[k]+1; dfs(edge[i].to); } } bool cmp(const int&a,const int&b) { return dfn[a]<dfn[b]; } int lca(int x,int y) { if (deep[x]<deep[y]) swap(x,y); for (int j=19;~j;j--) if (deep[fa[x][j]]>=deep[y]) x=fa[x][j]; if (x==y) return x; for (int j=19;~j;j--) if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j]; return fa[x][0]; } bool paint(int x,int y,int color) { while (x!=y) { if (f[x]>n&&f[x]!=color) return 1; f[x]=color;x=fa[x][0]; } return 0; } void dp(int k,int from) { for (int i=p[k];i;i=tree[i].nxt) if (tree[i].to!=from) dp(tree[i].to,k); if (k>n) { F[k][0]=0;F[k][1]=1; for (int i=p[k];i;i=tree[i].nxt) if (tree[i].to!=from) F[k][1]=1ll*F[k][1]*(F[tree[i].to][0]+F[tree[i].to][1])%P; } else { F[k][0]=1;int cnt=0; for (int i=p[k];i;i=tree[i].nxt) if (tree[i].to!=from) { F[k][0]=1ll*F[k][0]*(F[tree[i].to][0]+F[tree[i].to][1])%P; pre[++cnt]=F[tree[i].to][0]+F[tree[i].to][1]; } for (int i=1;i<=cnt;i++) suf[i]=pre[i]; pre[0]=1;for (int i=1;i<=cnt;i++) pre[i]=1ll*pre[i-1]*pre[i]%P; suf[cnt+1]=1;for (int i=cnt;i>=1;i--) suf[i]=1ll*suf[i]*suf[i+1]%P; int t=0; for (int i=p[k];i;i=tree[i].nxt) if (tree[i].to!=from) { t++; F[k][1]=(F[k][1]+1ll*pre[t-1]*suf[t+1]%P*F[tree[i].to][1])%P; } } } signed main() { #ifndef ONLINE_JUDGE freopen("f.in","r",stdin); freopen("f.out","w",stdout); #endif n=read(),m=read(); for (int i=1;i<=n;i++) a[i]=read(); for (int i=1;i<n;i++) { int x=read(),y=read(); addedge(x,y),addedge(y,x); } fa[1][0]=1;dfs(1); for (int j=1;j<20;j++) for (int i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1]; for (int i=1;i<=n;i++) if (a[i]) f[i]=n+a[i],pos[a[i]].push_back(i);else f[i]=i; for (int i=1;i<=m;i++) if (pos[i].size()>1) { sort(pos[i].begin(),pos[i].end(),cmp); int root=(*pos[i].begin()); for (int j=0;j<pos[i].size()-1;j++) if (deep[lca(pos[i][j],pos[i][j+1])]<deep[root]) root=lca(pos[i][j],pos[i][j+1]); if (f[root]>n&&f[root]!=n+i) {cout<<0;return 0;} top=0;stk[++top]=root;f[root]=n+i; for (int j=((*pos[i].begin())==root);j<pos[i].size();j++) { int l=lca(stk[top],pos[i][j]); if (l!=stk[top]) { while (top>1&&deep[l]<=deep[stk[top-1]]) { if (paint(stk[top],stk[top-1],n+i)) {cout<<0;return 0;} top--; } if (paint(stk[top],l,n+i)) {cout<<0;return 0;} stk[top]=l; } stk[++top]=pos[i][j]; } while (top>1) { if (paint(stk[top],stk[top-1],n+i)) {cout<<0;return 0;} top--; } } //for (int i=1;i<=n;i++) cout<<f[i]<<‘ ‘;cout<<endl; memset(p,0,sizeof(p)); for (int i=1;i<=t;i+=2) if (f[edge[i].to]!=f[edge[i+1].to]) new_addedge(f[edge[i].to],f[edge[i+1].to]), new_addedge(f[edge[i+1].to],f[edge[i].to]); dp(n+1,n+1); cout<<F[n+1][1]; return 0; //NOTICE LONG LONG!!!!! }
Codeforces Round #540 Div. 3 F2
原文:https://www.cnblogs.com/Gloid/p/10409045.html