题目描述:在一个\(n\times m\)的网格中,放\(2n\)个棋子,使每一行和每一列都不超过两个棋子。求方案数\(\mathrm{mod} \ 943718401\)。
数据范围:\(n\le m\le 2\times 10^6\)
首先你要知道这个模数是个 NTT 模数。注意到每一行都要有两个棋子。设有\(k\)列有两个棋子,则有\(2(n-k)\)列有一个棋子的方案数是\(S_k\)。
\[
Ans=\sum_{k=0}^n\binom{m}{k,2(n-k)}S_k
\]
然后考虑计算\(S_k\),转换一下,一共有\(2n\)个球,每个颜色有两个球,扔给\(k+2(n-k)\)个有区别的盒子(列),每个盒子里面的两个球颜色不同。然后上个容斥就可以了。
我们假设同一个盒子里面的两个球不同,同种颜色的两个球不同,最后除掉。
\[
S_k=\frac{1}{2^{n+k}}\sum_{i=0}^k(-1)^ii!\binom{k}{i}\binom{n}{i}2^i(2n-2i)!
\]
这是一个卷积的形式,时间复杂度\(O(n\log n)\)
code
#include<bits/stdc++.h>
#define Rint register int
using namespace std;
typedef long long LL;
const int N = 1 << 22, mod = 943718401, G = 7, Gi = 269633829;
inline int kasumi(int a, int b){
int res = 1;
while(b){
if(b & 1) res = (LL) res * a % mod;
a = (LL) a * a % mod; b >>= 1;
}
return res;
}
int rev[N];
inline int calrev(int len){
int limit = 1, L = -1;
while(limit <= len){limit <<= 1; ++ L;}
for(Rint i = 0;i < limit;i ++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
return limit;
}
inline int add(int a, int b){return (a + b >= mod) ? (a + b - mod) : (a + b);}
inline int sub(int a, int b){return (a < b) ? (a + mod - b) : (a - b);}
inline void upd(int &a, int b){a += b; if(a >= mod) a -= mod;}
inline void NTT(int *A, int limit, int type){
for(Rint i = 0;i < limit;i ++)
if(i < rev[i]) swap(A[i], A[rev[i]]);
for(Rint mid = 1;mid < limit;mid <<= 1){
int Wn = kasumi(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
for(Rint j = 0;j < limit;j += (mid << 1))
for(Rint k = 0, w = 1;k < mid;k ++, w = (LL) w * Wn % mod){
int x = A[j + k], y = (LL) w * A[j + k + mid] % mod;
A[j + k] = add(x, y); A[j + k + mid] = sub(x, y);
}
}
if(type == -1){
int inv = kasumi(limit, mod - 2);
for(Rint i = 0;i < limit;i ++)
A[i] = (LL) A[i] * inv % mod;
}
}
int n, fac[N], inv[N], A[N], B[N], ans;
LL m;
inline void init(int m){
fac[0] = 1;
for(Rint i = 1;i <= m;i ++) fac[i] = (LL) fac[i - 1] * i % mod;
inv[m] = kasumi(fac[m], mod - 2);
for(Rint i = m;i;i --) inv[i - 1] = (LL) inv[i] * i % mod;
}
int main(){
scanf("%d%lld", &n, &m); m %= mod; init(n << 1);
int limit = calrev(n << 1);
for(Rint i = 0, t = 1;i <= n;i ++, t = t * (mod - 2ll) % mod){
A[i] = (LL) t * fac[2 * (n - i)] % mod * inv[i] % mod * inv[n - i] % mod;
B[i] = inv[i];
}
NTT(A, limit, 1); NTT(B, limit, 1);
for(Rint i = 0;i < limit;i ++) A[i] = (LL) A[i] * B[i] % mod;
NTT(A, limit, -1);
int tmp = 1;
for(Rint i = 0;i < n;++ i) tmp = (LL) tmp * sub(m, i) % mod;
for(Rint i = n, t = kasumi(2, mod - 2 * n - 1);~i;tmp = (LL) tmp * sub(m, 2 * n - i) % mod, -- i, t = add(t, t))
upd(ans, (LL) A[i] * inv[2 * (n - i)] % mod * tmp % mod * t % mod);
ans = (LL) ans * fac[n] % mod;
printf("%d\n", ans);
}
原文:https://www.cnblogs.com/AThousandMoons/p/11779200.html