设 dp[n] 为 n 个数字的排列所对应的答案（约定 dp[0]=1），那么可以得到如下 dp 方程：
$$dp[n]=\sum_{i=1}^{n} dp[n-i]\cdot\binom{n-1}{i-1}\cdot(i-1)!\cdot i^2$$
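由于 $\binom{n-1}{i-1}(i-1)!=\frac{(n-1)!}{(n-i)!}$，令 $j=n-i$，上式可化为卷积形式 $\frac{dp[n]}{(n-1)!}=\sum_{j=0}^{n-1}\frac{dp[j]}{j!}\,(n-j)^2$，用分治 FFT（代码中为模 998244353 的 NTT）即可。复杂度为 $O(n\log^2 n)$。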
代码
// Computes dp[n] = sum_{i=1..n} dp[n-i] * C(n-1, i-1) * (i-1)! * i^2  (dp[0] = 1)
// modulo P for n = 1..MAX_N. It uses divide-and-conquer over n together with
// a number-theoretic transform (NTT), then answers queries read from stdin.
//
// The recurrence is rewritten as
//   dp[n] / (n-1)! = sum_{j=0..n-1} (dp[j] / j!) * (n-j)^2,
// which is a convolution of a[j] = dp[j]/j! with b[k] = k^2.
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>

typedef long long LL;

const int P = 119 * (1 << 23) + 1;  // 998244353
const int N = 1 << 18;              // capacity of every work array
const int G = 3;                    // generator used to build the roots of unity
const int NUM = 20;                 // wn[] holds roots for transform sizes 2^0 .. 2^19
const int MAX_N = 100000;           // answers are precomputed for n = 1..MAX_N

LL wn[NUM];       // wn[i] = G^((P-1) / 2^i) mod P
LL jc[N];         // jc[i] = i! mod P
LL inv[N];        // inv[i] = (i!)^(-1) mod P
LL dp[N];         // dp[n] = answer for n (filled by solve)
LL a[N], b[N];    // scratch buffers for the convolution in solve()

// Returns a^b mod m by binary exponentiation (expects b >= 0).
LL quick_mod(LL a, LL b, LL m)
{
    LL ans = 1;
    a %= m;
    while (b) {
        if (b & 1) ans = ans * a % m;
        b >>= 1;
        a = a * a % m;
    }
    return ans;
}

// Precomputes wn[i] = G^((P-1)/2^i) mod P for every stage size used by NTT().
void GetWn()
{
    for (int i = 0; i < NUM; i++) {
        int t = 1 << i;
        wn[i] = quick_mod(G, (P - 1) / t, P);
    }
}

// Permutes a[0..len) into bit-reversed index order (len is a power of two).
void Rader(LL a[], int len)
{
    int j = len >> 1;
    for (int i = 1; i < len - 1; i++) {
        if (i < j) std::swap(a[i], a[j]);
        int k = len >> 1;
        while (j >= k) {
            j -= k;
            k >>= 1;
        }
        if (j < k) j += k;
    }
}

// In-place iterative NTT of length len (power of two).
// on == 1 gives the forward transform. on == -1 gives the inverse transform,
// done as a forward transform, then reversing a[1..len), then scaling by 1/len.
void NTT(LL a[], int len, int on)
{
    Rader(a, len);
    int id = 0;
    for (int h = 2; h <= len; h <<= 1) {
        id++;  // wn[id] is the root matching butterfly span h = 2^id
        for (int j = 0; j < len; j += h) {
            LL w = 1;
            for (int k = j; k < j + h / 2; k++) {
                LL u = a[k] % P;
                LL t = w * (a[k + h / 2] % P) % P;
                a[k] = (u + t) % P;
                a[k + h / 2] = ((u - t) % P + P) % P;
                w = w * wn[id] % P;
            }
        }
    }
    if (on == -1) {
        for (int i = 1; i < len / 2; i++) std::swap(a[i], a[len - i]);
        LL Inv = quick_mod(len, P - 2, P);
        for (int i = 0; i < len; i++) a[i] = a[i] % P * Inv % P;
    }
}

// Cyclic convolution of length n: a <- a * b (mod P). b is overwritten with its transform.
void Conv(LL a[], LL b[], int n)
{
    NTT(a, n, 1);
    NTT(b, n, 1);
    for (int i = 0; i < n; i++) a[i] = a[i] * b[i] % P;
    NTT(a, n, -1);
}

// Divide and conquer over [l, r] (requires l >= 1).
// On return, dp[l..r] contain every contribution from dp[0..r].
// The caller must already have added the contributions from dp[1..l-1].
void solve(long long l, long long r)
{
    if (l == r) {
        // Term i = n of the recurrence: dp[0] * (n-1)! * n^2 with dp[0] = 1.
        dp[l] = (dp[l] + jc[l - 1] * l % P * l % P) % P;
        return;
    }
    int m = (l + r) >> 1;
    solve(l, m);

    // Smallest power of two strictly greater than the segment length.
    // The product's true length is below 2*len, and any wrapped-around
    // indices land below m+1-l, so the entries read back are exact.
    int len = 1;
    while (len <= r - l + 1) len <<= 1;

    for (int k = 0; k < len; k++) {
        a[k] = 0;
        b[k] = 0;
    }
    // a[j-l] = dp[j] / j! for the already-finished left half.
    for (long long j = l; j <= m; j++) a[j - l] = dp[j] * inv[j] % P;
    // b[k] = k^2 for every distance k = n - j that can occur; b[0] stays 0.
    for (long long k = 1; k <= r - l; k++) b[k] = k * k % P;
    Conv(a, b, len);

    // Multiply back by (n-1)!. Operands are < P, so the sum fits in long long.
    for (long long i = m + 1; i <= r; i++)
        dp[i] = (dp[i] + a[i - l] * jc[i - 1]) % P;
    solve(m + 1, r);
}

int main()
{
    GetWn();
    jc[0] = 1;
    inv[0] = 1;
    for (int i = 1; i <= MAX_N; i++) {
        jc[i] = jc[i - 1] * i % P;
        inv[i] = quick_mod(jc[i], P - 2, P);  // Fermat inverse; P is prime
    }
    solve(1, MAX_N);

    int n;
    while (scanf("%d", &n) == 1) {
        // Guard the table lookup. Out-of-range queries previously read
        // outside dp[] (or an unfilled zero slot); they now print 0.
        LL ans = (n >= 1 && n <= MAX_N) ? dp[n] : 0;
        // %lld is the portable specifier; %I64d is MSVC-only and misprints elsewhere.
        printf("%lld\n", ans);
    }
    return 0;
}
原文:http://www.cnblogs.com/fzmh/p/4738048.html