求对于有 $n$ 个点的 $e$ 个简单环。有 $k$ 个守卫,每个环至少要有一个守卫的方案数。
$1\leq k\leq n\leq 152501$
考虑对于朴素 $O(n^2)\space dp$ 的优化,简单思考后发现 $dp$ 的过程其实是一个背包卷积的过程。
考虑对每个简单环构造生成函数 $A$ ,则 $A_i=C_{num}^i$ , $num$ 表示其环中节点个数。
$B=\prod_{i=1}^e$ ,答案则为 $B_k$。
现在的问题变成了求 $e$ 个多项式的卷积,而暴力卷积时间复杂度为 $O(n^2\log n)$ ,考虑分治优化即可。
时间复杂度 $O(n\log^2 n)$ 。
#include<iostream> #include<cstring> #include<cstdio> #include<algorithm> #include<vector> #define int long long #define mod 998244353 using namespace std; inline int read(){ int f=1,ans=0;char c=getchar(); while(c<‘0‘||c>‘9‘){if(c==‘-‘)f=-1;c=getchar();} while(c>=‘0‘&&c<=‘9‘){ans=ans*10+c-‘0‘;c=getchar();} return f*ans; } const int MAXN=2402502; int T,n,k,ff[MAXN],M[MAXN],Num[MAXN]; vector<int> ve[MAXN]; int fac[MAXN],inv[MAXN],infac[MAXN]; inline void init(){ fac[0]=fac[1]=1; for(int i=2;i<=152501;i++) fac[i]=fac[i-1]*i,fac[i]%=mod; inv[1]=1;for(int i=2;i<=152501;i++) inv[i]=((mod-mod/i)*inv[mod%i])%mod; infac[0]=1; for(int i=1;i<=152501;i++) infac[i]=infac[i-1]*inv[i],infac[i]%=mod; return; } int find(int x){ if(ff[x]==x) return x; return ff[x]=find(ff[x]); } int merge(int x,int y){ int t1=find(x),t2=find(y); ff[t2]=t1; } int ksm(int a,int b){ int ans=1; while(b){ if(b&1) ans*=a,ans%=mod; a*=a,a%=mod; b>>=1; }return ans; } inline int C(int a,int b){return (((fac[a]*infac[b])%mod)*infac[a-b])%mod;} int f[MAXN],g[MAXN],N,Lim,flip[MAXN]; inline void NTT(int *f,int opt){ for(int i=0;i<N;i++) if(i<flip[i]) swap(f[i],f[flip[i]]); for(int p=2;p<=N;p<<=1){ int len=p>>1,buf=ksm(3,(mod-1)/p); if(opt==-1) buf=ksm(buf,mod-2); for(int be=0;be<N;be+=p){ int tmp=1; for(int l=be;l<be+len;l++){ int t=(f[l+len]*tmp)%mod; f[l+len]=(f[l]-t+mod)%mod,f[l]=(f[l]+t)%mod; tmp*=buf,tmp%=mod; } } }if(opt==-1){ int Inv=ksm(N,mod-2); for(int i=0;i<N;i++) f[i]*=Inv,f[i]%=mod; }return; } inline void _NTT(vector<int> &F,vector<int> G){ int sizf=F.size()-1,sizg=G.size()-1; Lim=sizf+sizg; for(N=1;N<=Lim;N<<=1); for(int i=0;i<N;i++) flip[i]=((flip[i>>1]>>1)|(i&1?N>>1:0)); for(int i=0;i<=sizf;i++) f[i]=F[i]; for(int i=0;i<=sizg;i++) g[i]=G[i]; for(int i=sizf+1;i<=N;i++) f[i]=0; for(int i=sizg+1;i<=N;i++) g[i]=0; NTT(f,1),NTT(g,1); for(int i=0;i<N;i++) f[i]*=g[i],f[i]%=mod; NTT(f,-1); F.clear(); for(int i=0;i<=Lim;i++) F.push_back(f[i]);return; } inline void cdq(int l,int r){ if(l==r) return; int mid=l+r>>1; cdq(l,mid),cdq(mid+1,r); _NTT(ve[l],ve[mid+1]); return; } void solve(){ memset(M,0,sizeof(M)),memset(Num,0,sizeof(Num)); n=read(),k=read(); for(int i=1;i<=n;i++) ff[i]=i; for(int i=1;i<=n;i++) merge(i,read()); for(int i=1;i<=n;i++) ff[i]=find(ff[i]); for(int i=1;i<=n;i++){ if(!M[ff[i]]) M[ff[i]]=++M[0]; Num[M[ff[i]]]++; } for(int i=1;i<=M[0];i++) for(int j=0;j<=Num[i];j++) { if(j) ve[i].push_back(C(Num[i],j)); else ve[i].push_back(0); } cdq(1,M[0]); int Ans1=ve[1][k],Ans2=C(n,k); printf("%lld\n",(Ans1*ksm(Ans2,mod-2))%mod); for(int i=1;i<=M[0];i++) ve[i].clear(); return; } signed main(){ freopen("bishop.in","r",stdin); freopen("bishop.out","w",stdout); T=read();init(); while(T--) solve(); return 0; }
给定一棵有 $n$ 个节点的有根带权,第 $i$ 号点有 $val_i$ 表示 $i$ 号点有多少人。
$q$ 次询问,每次询问 $(u,v)$ 简单路径上的人走到哪个点的总距离最少,求其总距离或更改 $val$ 。
$n,q\leq 152501$
降智好题,忘记了有中位数这个东西,一直在想边对答案的贡献。
考虑若我们将 $(u,v)$ 这条链摘出来,则其选择的点为其中位数(也可以说是让两边 $val$ 值最少),问题就变成了求 $(u,v)$ 到 $x$ 的总距离。
我们设 $F_i=\sum_{lca(i,j)=j} dis_j\times val_j$ ,$dis_i$ 表示 $i$ 号点到跟的距离。
我们设 $x$ 与 $u$ 在同侧,将路径拆为 $(u,x),(x,lca),(lca,v)$ 。
$$Ans=W(u,x)+W(x,lca)+W(lca,v)\\=(F_u-F_{fath_x}-dis_{fath}\times \sum_{i\in (u,x)} C_i)+(F_v-F_{fath_{lca}}-dis_{lca}\times \sum_{i\in{lca,v}} C_i)+(dis_{lca}\times \sum _{i\in {x,lca}}C_i-(F_{x}-F_{fath_{lca}}))$$
直接用线段树或者树状数组加 $dfs$ 序(因为我们发现信息都只有一条与根相连的链),维护 $F$ 与 $C$ 的信息即可。
而求中位数直接比较后倍增或者二分维护即可。
而对于树链剖分时间复杂度 $O(n\log ^3 n)$ ,面对 $n,q\leq 152501$ 的数据会 $T $ 。
利用 $dfs$ 序优化即可,时间复杂度 $O(n\log ^2 n)$ 。
#include<iostream> #include<cstring> #include<cstdio> #include<algorithm> #define int long long using namespace std; inline int read(){ int f=1,ans=0;char c=getchar(); while(c<‘0‘||c>‘9‘){if(c==‘-‘)f=-1;c=getchar();} while(c>=‘0‘&&c<=‘9‘){ans=ans*10+c-‘0‘;c=getchar();} return f*ans; } const int MAXN=200001; int n; int in[MAXN],out[MAXN],fa[MAXN][21],dep[MAXN],cnt,head[MAXN],val[MAXN],tot,dis[MAXN]; struct node{ int u,v,w,nex; }x[MAXN<<1]; struct BIT{ int sum[MAXN]; int lowbit(int x){return x&-x;} void Modify(int x,int w){ for(;x<=n;x+=lowbit(x)) sum[x]+=w; return; } inline int Query(int x){ int ans=0; for(;x;x-=lowbit(x)) ans+=sum[x]; return ans; } inline void Add(int u,int w){ Modify(in[u],w),Modify(out[u]+1,-w); return; } inline int Que(int u){return Query(in[u]);} }t1,t2; inline void add(int u,int v,int w){ x[cnt].u=u,x[cnt].v=v,x[cnt].w=w,x[cnt].nex=head[u],head[u]=cnt++; } inline void dfs(int u,int fath){ fa[u][0]=fath;dep[u]=dep[fath]+1; in[u]=++tot; for(int i=1;(1<<i)<=dep[u];i++) fa[u][i]=fa[fa[u][i-1]][i-1]; for(int i=head[u];i!=-1;i=x[i].nex){ if(x[i].v==fath) continue; dis[x[i].v]=dis[u]+x[i].w; dfs(x[i].v,u); }out[u]=tot;return; } inline int Lca(int u,int v){ if(dep[u]<dep[v]) swap(u,v); for(int i=20;i>=0;i--) if(dep[u]-(1<<i)>=dep[v]) u=fa[u][i]; if(u==v) return u; for(int i=20;i>=0;i--){ if(fa[u][i]==fa[v][i]) continue; u=fa[u][i],v=fa[v][i]; }return fa[u][0]; } void Modify(int u,int w){ t1.Add(u,w-val[u]);t2.Add(u,(w-val[u])*dis[u]); val[u]=w;return; } inline int qcnt(int u,int v){ int lca=Lca(u,v); return t1.Que(u)+t1.Que(v)-2*t1.Que(lca)+val[lca]; } inline int Q1(int u,int v){ return t2.Que(u)-t2.Que(fa[v][0])-dis[v]*qcnt(u,v); } inline int Q2(int u,int v){ return dis[u]*qcnt(u,v)-(t2.Que(u)-t2.Que(fa[v][0])); } inline int Query(int u,int v){ int lca=Lca(u,v),Num=qcnt(u,v); int res=(Num+1)/2,tmp; if(val[u]>=res) tmp=u; else if(val[v]>=res) tmp=v,swap(u,v); else{ if(qcnt(v,lca)>=res) swap(u,v); tmp=u; for(int i=20;i>=0;i--) if(qcnt(fa[tmp][i],u)<res&&dep[tmp]-(1<<i)>=dep[lca]) tmp=fa[tmp][i]; tmp=fa[tmp][0]; } int Ans=0; Ans+=Q1(u,tmp); Ans+=Q1(v,lca); Ans+=Q2(tmp,lca); int G=qcnt(v,lca)-val[lca]; Ans+=G*(dis[tmp]-dis[lca]); return Ans; } int q; signed main(){ freopen("conference.in","r",stdin); freopen("conference.out","w",stdout); memset(head,-1,sizeof(head)); n=read(); for(int i=1;i<=n;i++) val[i]=read(); for(int i=1;i<n;i++){ int u=read(),v=read(),w=read(); add(u,v,w),add(v,u,w); } dfs(1,0); for(int i=1;i<=n;i++){ int t=val[i];val[i]=0; Modify(i,t); } q=read(); for(int i=1;i<=q;i++){ int opt=read(); if(opt==1){ int u=read(),v=read(); printf("%lld\n",Query(u,v)); }else{ int u=read(),w=read(); Modify(u,w); } }return 0; }
原文:https://www.cnblogs.com/si-rui-yang/p/11296737.html