http://www.lydsy.com/JudgeOnline/problem.php?id=3992 (题目链接)
集合${S}$中有若干个不超过${m}$的非负整数,问由这些数组成一个长度${n}$的序列,使序列中的数的乘积对${m}$取模正好等于${x}$,问存在多少方案。
好神的题。算法还是要多复习,我连${NTT}$都忘记怎么写了T_T
这还是我的第一发原根→_→。对于原根的求法,我们枚举${rt=2 to inf}$,然后判断是否存在${t<m-1}$使${rt^t=1}$。所以我们枚举${t=2 to \sqrt{m}}$,依次进行判断即可${rt^{(m-1)/i}}$是否${=1}$就可以了。
然后我们可以很简单的列出dp方程${f_{i,j}}$表示,已经放到了第${i}$个数,它们的乘积是${j}$的方案数。转移也就很显然了:$${f[i][j]=\sum_{k=1}^{m-1}f_{i-1,j*inv[k]}}$$
复杂度${O(nm^2)}$,于是我们就可以获得10分的高分,是不是很良心啊。
考虑这个东西怎么优化,我们把每一个${j}$都写成${m}$的原根的几次方,然后乘就变成加辣,然后我们就可以卷积辣。
然后你发现${n}$有${10^9}$,我们快速幂一波,然后就AC辣。
一开始没想清没注意到还是循环卷积卧槽T_T
// bzoj3992 #include<algorithm> #include<iostream> #include<cstdlib> #include<cstring> #include<cstdio> #include<cmath> #include<ctime> #define LL long long #define inf (1ll<<30) #define MOD 1004535809 #define Pi acos(-1.0) #define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout); using namespace std; const int maxn=20010; int f[maxn],g[maxn],rev[maxn],vis[maxn]; int n,m,rt,S,X,N,L; int power(int a,int b,int c) { int res=1; while (b) { if (b&1) res=(LL)res*a%c; b>>=1;a=(LL)a*a%c; } return res; } void root(int p) { if (p==2) {rt=1;return;} for (rt=2;;rt++) { int flag=1; for (int i=2;i*i<=p;i++) if (power(rt,(p-1)/i,p)==1) {flag=0;break;} if (flag) break; } } namespace NTT { LL A[maxn],B[maxn]; void NTT(LL *a,int f) { for (int i=0;i<N;i++) if (i<rev[i]) swap(a[i],a[rev[i]]); for (int i=1;i<N;i<<=1) { LL gn=power(3,(MOD-1)/(i<<1),MOD); for (int p=i<<1,j=0;j<N;j+=p) { LL g=1; for (int k=0;k<i;k++,(g*=gn)%=MOD) { LL x=a[k+j],y=g*a[k+j+i]%MOD; a[k+j]=(x+y)%MOD,a[k+j+i]=(x-y+MOD)%MOD; } } } if (f==-1) reverse(a+1,a+N); } void Init(int *a,int *b) { for (int i=0;i<N;i++) A[i]=a[i],B[i]=b[i]; NTT(A,1);NTT(B,1); for (int i=0;i<N;i++) (A[i]*=B[i])%=MOD; NTT(A,-1); LL ev=power(N,MOD-2,MOD); for (int i=0;i<N;i++) (A[i]*=ev)%=MOD; for (int i=0;i<m-1;i++) a[i]=(A[i]+A[i+m-1])%MOD; } } using namespace NTT; int main() { scanf("%d%d%d%d",&n,&m,&X,&S); root(m); for (int x,i=1;i<=S;i++) scanf("%d",&x),vis[x]=1; for (int p=1,i=0;i<m-1;i++,(p*=rt)%=m) if (vis[p]) f[i]=1; for (N=1,L=-1;N<(m-1)*2;N<<=1) L++; for (int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1) | ((i&1)<<L); g[0]=1; while (n) { if (n&1) Init(g,f); n>>=1;Init(f,f); } for (int i=0,p=1;i<m-1;i++,(p*=rt)%=m) if (p==X) {printf("%d",g[i]);break;} return 0; }
原文:http://www.cnblogs.com/MashiroSky/p/6395794.html