首页 > 其他 > 详细

[2018.5.22集训]棋盘-矩阵乘法-NTT-插值

时间:2018-05-25 00:57:17      阅读:198      评论:0      收藏:0      [点我收藏+]

题目大意

给出一个 $n*m$ 的棋盘,在上面放满白棋和黑棋,问使得仅由黑棋构成的四联通块的数量为 $k$ 的棋子放置方案的数量。
$ n \leq 3,m \leq 10^5,n*m \leq 10^5,k \leq n*m$,答案对 $998244353$ 取模。

题解

考虑朴素dp,设$f[i][j][s]$表示当前在第 $i$ 列,构成了 $j$ 个黑四联通块,当前行的涂色情况为 $s$ 的方案数。
$s$ 是一个状压之后的数,当 $n \leq 2$ 时 $s < 2^n$ ,而当$n=3$,由于特殊状态的存在,$s < 2^n +1$。
这里的特殊状态是指,形如下图的情况:

01111  01111
01000  00000
01111  01111

上面的两种 $1?0?1$ 状态是不同的,前一种上下的 $1$ 之间已经联通,而后一种没有,因此要分开讨论。

在 $n$ 的值不同时分别讨论,实现三个不同的dp即可。

然后你就有 $69$ 分了。

然而还是不够。
定义生成函数 $g(i,s)=\sum_{j}f[i][j][s]x^j$,那么答案即为 $\sum_{s}[x^k]g(m,s)$。
直接求这个生成函数就是暴力dp的做法。

考虑计算一些点值,并通过这些点值解出这个多项式。
对于单个点值的计算,显然可以矩阵快速幂计算出。
问题在于,如何将点值还原成一个多项式。

考虑人民的好朋友NTT的原理:
将原多项式使用DFT带入若干个单位根化成点值,进行完奇怪的操作后,再将利用了单位根计算出的点值用IDFT还原成多项式。
也就是说,如果能获得使用单位根计算出的点值,就能使用IDFT解出目标多项式。

那么就可以矩阵快速幂求出在单位根上的点值,并$IDFT$ 出原多项式即可~

代码:

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
using namespace std;

typedef long long ll;
const int E=9;
const int N=100009;
const ll md=998244353;

int n,m,k,l;
ll w[N],a[N],rev[N];

struct matrix
{
    int lena,lenb;
    ll a[E][E];

    matrix(int _a=0,int b=0){memset(a,0,sizeof(a));lena=_a;lenb=b;}

    inline void e()
    {
        memset(a,0,sizeof(a));
        for(int i=0;i<lena;i++)
            a[i][i]=1;
    }

    matrix operator * (const matrix &o)const
    {
        matrix ret(lena,o.lenb);
        for(int i=0;i<lena;i++)
            for(int j=0;j<o.lenb;j++)
                for(int k=0;k<lenb;k++)
                    (ret.a[i][j]+=a[i][k]*o.a[k][j]%md)%=md;
        return ret;
    }
}base,mul;

inline matrix qpow(matrix a,int b)
{
    matrix ret(a.lena,a.lenb);ret.e();
    while(b)
    {
        if(b&1)ret=ret*a;
        a=a*a;b>>=1;
    }
    return ret;
}

inline ll qpow(ll a,ll b)
{
    ll ret=1;
    while(b)
    {
        if(b&1)ret=ret*a%md;
        a=a*a%md;b>>=1;
    }
    return ret;
}

inline void ntt(ll *a,int n)
{
    for(int i=0;i<n;i++)
        rev[i]=((rev[i>>1])>>1)|((i&1)*(n>>1));
    for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int h=2;h<=n;h<<=1)
    {
        ll w=qpow(qpow(3,(md-1)/h),md-2);
        for(int j=0;j<n;j+=h)
        {
            ll wn=1ll,x,y;
            for(int k=j;k<j+(h>>1);k++)
            {
                x=a[k],y=wn*a[k+(h>>1)]%md;
                a[k]=(x+y)%md;a[k+(h>>1)]=(x-y+md)%md;
                wn=wn*w%md;
            }
        }
    }
    for(int inv=qpow(n,md-2),i=0;i<n;i++)
        a[i]=a[i]*(ll)inv%md;
}

inline ll m1(ll w)
{
    ll m1[2][2]=
    {
        {1,w},
        {1,1},
    };
    int sz=2;
    base=matrix(sz,sz);base.e();
    mul=matrix(sz,sz);
    for(int i=0;i<sz;i++)
        for(int j=0;j<sz;j++)
            mul.a[i][j]=m1[i][j];
    base=base*qpow(mul,m);
    ll ret=0;
    for(int i=0;i<sz;i++)
        ret+=base.a[0][i];
    return ret;
}

inline ll m2(ll w)
{
    ll m2[4][4]=
    {
        {1,w,w,w},
        {1,1,w,1},
        {1,w,1,1},
        {1,1,1,1},
    };
    int sz=4;
    base=matrix(sz,sz);base.e();
    mul=matrix(sz,sz);
    for(int i=0;i<sz;i++)
        for(int j=0;j<sz;j++)
            mul.a[i][j]=m2[i][j];
    base=base*qpow(mul,m);
    ll ret=0;
    for(int i=0;i<sz;i++)
        ret+=base.a[0][i];
    return ret;
}

inline ll m3(ll w)
{
    ll m3[9][9]=
    {
        {1,w,w,w,w,w,w*w%md,0,w},
        {1,1,w,w,1,w,w,0,1},
        {1,w,1,w,1,1,w*w%md,0,1},
        {1,w,w,1,w,1,w,0,1},
        {1,1,1,w,1,1,w,0,1},
        {1,w,1,1,1,1,w,0,1},
        {1,1,w,1,1,1,1,0,qpow(w,md-2)},
        {1,1,w,1,1,1,0,1,1},
        {1,1,1,1,1,1,0,1,1}
    };
    int sz=9;
    base=matrix(sz,sz);base.e();
    mul=matrix(sz,sz);
    for(int i=0;i<sz;i++)
        for(int j=0;j<sz;j++)
            mul.a[i][j]=m3[i][j];
    base=base*qpow(mul,m);
    ll ret=0;
    for(int i=0;i<sz;i++)
        ret+=base.a[0][i];
    return ret;
}

int main()
{
    scanf("%d%d%d",&n,&m,&k);
    if(k>(n*m+1)/2)return puts("0"),0;
    for(l=1;l<=((n*m+1)>>1);l<<=1);

    w[0]=1;w[1]=qpow(3,(md-1)/l);
    for(int i=2;i<l;i++)
        w[i]=w[i-1]*w[1]%md;

    for(int i=0;i<l;i++)
        a[i]=(n==1?m1:(n==2?m2:m3))(w[i]);
    ntt(a,l);

    printf("%lld\n",a[k]);
    return 0;
}

[2018.5.22集训]棋盘-矩阵乘法-NTT-插值

原文:https://www.cnblogs.com/zltttt/p/9085983.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!