给出一个 $n*m$ 的棋盘,在上面放满白棋和黑棋,问使得仅由黑棋构成的四联通块的数量为 $k$ 的棋子放置方案的数量。
$ n \leq 3,m \leq 10^5,n*m \leq 10^5,k \leq n*m$,答案对 $998244353$ 取模。
考虑朴素dp,设$f[i][j][s]$表示当前在第 $i$ 列,构成了 $j$ 个黑四联通块,当前行的涂色情况为 $s$ 的方案数。
$s$ 是一个状压之后的数,当 $n \leq 2$ 时 $s < 2^n$ ,而当$n=3$,由于特殊状态的存在,$s < 2^n +1$。
这里的特殊状态是指,形如下图的情况:
01111 01111
01000 00000
01111 01111
上面的两种 $1?0?1$ 状态是不同的,前一种上下的 $1$ 之间已经联通,而后一种没有,因此要分开讨论。
在 $n$ 的值不同时分别讨论,实现三个不同的dp即可。
然后你就有 $69$ 分了。
然而还是不够。
定义生成函数 $g(i,s)=\sum_{j}f[i][j][s]x^j$,那么答案即为 $\sum_{s}[x^k]g(m,s)$。
直接求这个生成函数就是暴力dp的做法。
考虑计算一些点值,并通过这些点值解出这个多项式。
对于单个点值的计算,显然可以矩阵快速幂计算出。
问题在于,如何将点值还原成一个多项式。
考虑人民的好朋友NTT的原理:
将原多项式使用DFT带入若干个单位根化成点值,进行完奇怪的操作后,再将利用了单位根计算出的点值用IDFT还原成多项式。
也就是说,如果能获得使用单位根计算出的点值,就能使用IDFT解出目标多项式。
那么就可以矩阵快速幂求出在单位根上的点值,并$IDFT$ 出原多项式即可~
代码:
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
using namespace std;
typedef long long ll;
const int E=9;
const int N=100009;
const ll md=998244353;
int n,m,k,l;
ll w[N],a[N],rev[N];
struct matrix
{
int lena,lenb;
ll a[E][E];
matrix(int _a=0,int b=0){memset(a,0,sizeof(a));lena=_a;lenb=b;}
inline void e()
{
memset(a,0,sizeof(a));
for(int i=0;i<lena;i++)
a[i][i]=1;
}
matrix operator * (const matrix &o)const
{
matrix ret(lena,o.lenb);
for(int i=0;i<lena;i++)
for(int j=0;j<o.lenb;j++)
for(int k=0;k<lenb;k++)
(ret.a[i][j]+=a[i][k]*o.a[k][j]%md)%=md;
return ret;
}
}base,mul;
inline matrix qpow(matrix a,int b)
{
matrix ret(a.lena,a.lenb);ret.e();
while(b)
{
if(b&1)ret=ret*a;
a=a*a;b>>=1;
}
return ret;
}
inline ll qpow(ll a,ll b)
{
ll ret=1;
while(b)
{
if(b&1)ret=ret*a%md;
a=a*a%md;b>>=1;
}
return ret;
}
inline void ntt(ll *a,int n)
{
for(int i=0;i<n;i++)
rev[i]=((rev[i>>1])>>1)|((i&1)*(n>>1));
for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int h=2;h<=n;h<<=1)
{
ll w=qpow(qpow(3,(md-1)/h),md-2);
for(int j=0;j<n;j+=h)
{
ll wn=1ll,x,y;
for(int k=j;k<j+(h>>1);k++)
{
x=a[k],y=wn*a[k+(h>>1)]%md;
a[k]=(x+y)%md;a[k+(h>>1)]=(x-y+md)%md;
wn=wn*w%md;
}
}
}
for(int inv=qpow(n,md-2),i=0;i<n;i++)
a[i]=a[i]*(ll)inv%md;
}
inline ll m1(ll w)
{
ll m1[2][2]=
{
{1,w},
{1,1},
};
int sz=2;
base=matrix(sz,sz);base.e();
mul=matrix(sz,sz);
for(int i=0;i<sz;i++)
for(int j=0;j<sz;j++)
mul.a[i][j]=m1[i][j];
base=base*qpow(mul,m);
ll ret=0;
for(int i=0;i<sz;i++)
ret+=base.a[0][i];
return ret;
}
inline ll m2(ll w)
{
ll m2[4][4]=
{
{1,w,w,w},
{1,1,w,1},
{1,w,1,1},
{1,1,1,1},
};
int sz=4;
base=matrix(sz,sz);base.e();
mul=matrix(sz,sz);
for(int i=0;i<sz;i++)
for(int j=0;j<sz;j++)
mul.a[i][j]=m2[i][j];
base=base*qpow(mul,m);
ll ret=0;
for(int i=0;i<sz;i++)
ret+=base.a[0][i];
return ret;
}
inline ll m3(ll w)
{
ll m3[9][9]=
{
{1,w,w,w,w,w,w*w%md,0,w},
{1,1,w,w,1,w,w,0,1},
{1,w,1,w,1,1,w*w%md,0,1},
{1,w,w,1,w,1,w,0,1},
{1,1,1,w,1,1,w,0,1},
{1,w,1,1,1,1,w,0,1},
{1,1,w,1,1,1,1,0,qpow(w,md-2)},
{1,1,w,1,1,1,0,1,1},
{1,1,1,1,1,1,0,1,1}
};
int sz=9;
base=matrix(sz,sz);base.e();
mul=matrix(sz,sz);
for(int i=0;i<sz;i++)
for(int j=0;j<sz;j++)
mul.a[i][j]=m3[i][j];
base=base*qpow(mul,m);
ll ret=0;
for(int i=0;i<sz;i++)
ret+=base.a[0][i];
return ret;
}
int main()
{
scanf("%d%d%d",&n,&m,&k);
if(k>(n*m+1)/2)return puts("0"),0;
for(l=1;l<=((n*m+1)>>1);l<<=1);
w[0]=1;w[1]=qpow(3,(md-1)/l);
for(int i=2;i<l;i++)
w[i]=w[i-1]*w[1]%md;
for(int i=0;i<l;i++)
a[i]=(n==1?m1:(n==2?m2:m3))(w[i]);
ntt(a,l);
printf("%lld\n",a[k]);
return 0;
}
原文:https://www.cnblogs.com/zltttt/p/9085983.html