假设我们现在要多项式除法并且取模,\(FFT\)就会很难受了,因为它用的是复数,并且还有精度差。
这时我们需要一个能替代单位复根的东西:原根。
考虑为什么单位复根能用来做\(FFT\),因为它有很多性质,而我们会发现原根也具有这些性质:
以下的\(n\)都为\(2\)的正整数次幂。
设\(g_n=g^{\frac{p-1}{n}}\),\(p\)是质数且\(n\mid(p-1)\),\(g\)是模\(p\)意义下的原根。
1.\(\omega_n^n=\omega_n^0=1\)
\(g_n^n\equiv g^{\frac{n(p-1)}{n}}\equiv g^{p-1}\equiv 1\pmod{p}\)
2.\(\omega_n^{\frac{n}{2}}=-1\)
\(g_n^{\frac{n}{2}}\equiv g^{\frac{n(p-1)}{2n}}\equiv g^{\frac{p-1}{2}}\pmod{p}\)
又因为\((g^{\frac{p-1}{2}})^2\equiv g^{p-1}\equiv1\pmod{p}\)。
方程\(x^2\equiv 1\pmod{p}\)在\(p\)为质数时只有\(1,-1\)两种取值,而\(g^{\frac{(p-1)}{2}} \not\equiv g^{p-1} \equiv 1\pmod{p}\),因此\(g^{\frac{(p-1)}{2}}\equiv -1\pmod{p}\)。
3.\(\omega_n^k=\omega_{dn}^{dk}\)
$g_{dn}^{dk} = g^{\frac{dk(p-1)}{dn}} = g^{\frac{k(p-1)}{n}} = g_n^k $
于是单位复根有的性质原根都有
于是我们开始魔改\(FFT\)。
首先我们的的模数\(p\)要满足$ a\cdot 2^k +1$的形式,并且这个 \(2\) 的幂要大于 \(n\)。常见的有两种:
1.$1004535809 = 479 \times 2^{21} + 1 $,它的最小正原根是 \(3\)。
2.\(998244353 = 2^{23} \times 7 \times 17 + 1\),最小正原根也是 \(3\)
实现时将\(FFT\)中的单位复根换成原根即可,最后逆变换时要乘逆元。
code:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=4e6+10;
const ll mod=998244353;
const ll g=3;
const ll invg=332748118;
int n,m,lim=1,len;
int pos[maxn];
ll a[maxn],b[maxn];
inline ll power(ll x,ll k)
{
ll res=1;
while(k)
{
if(k&1)res=res*x%mod;
x=x*x%mod;k>>=1;
}
return res;
}
inline void ntt(ll* a,int op)
{
for(int i=0;i<lim;i++)if(i<pos[i])swap(a[i],a[pos[i]]);
for(int mid=1;mid<lim;mid<<=1)
{
ll wn=power((op==1)?g:invg,(mod-1)/(mid<<1));
for(int i=0,l=(mid<<1);i<lim;i+=l)
{
ll w=1;
for(int j=0;j<mid;j++,w=w*wn%mod)
{
ll x=a[i+j]%mod,y=w*a[i+mid+j]%mod;
a[i+j]=(x+y)%mod,a[i+mid+j]=(x-y+mod)%mod;
}
}
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)scanf("%lld",&a[i]),a[i]=(a[i]+mod)%mod;
for(int i=0;i<=m;i++)scanf("%lld",&b[i]),b[i]=(b[i]+mod)%mod;
while(lim<=n+m)lim<<=1,len++;
for(int i=0;i<lim;i++)pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
ntt(a,1);ntt(b,1);
//for(int i=0;i<lim;i++)cerr<<a[i]<<' '<<b[i]<<endl;
for(int i=0;i<lim;i++)a[i]=(a[i]*b[i])%mod;
ntt(a,-1);
ll inv=power(lim,mod-2);
for(int i=0;i<=n+m;i++)printf("%lld ",a[i]*inv%mod);
return 0;
}
原文:https://www.cnblogs.com/nofind/p/12118673.html