题目描述
有一个\(n\)个元素的置换,你要选择\(k\)个元素,问有多少种方案满足:对于每个轮换,你都选择了其中的一个元素。
对\(998244353\)取模。
\(k\leq n\leq 152501\)
题解
吐槽
为什么一道FFT题要把\(n\)设为\(150000\)?
解法一
先把轮换拆出来。
直接DP。
设\(f_{i,j}\)为前\(i\)个轮换选择了\(j\)个元素,且每个轮换都选择了至少一个元素的方案数。
\[
f_{i,j}=\sum_{k=1}^{a_i}f_{i-1,j-k}\binom{a_i}{k}
\]
时间复杂度为\(O(n^2)\),因为枚举的是第\(i\)组和前\(i-1\)组的配对,而任意两个元素之间最多被配对一次。
可以分治FFT做到\(O(n\log^2 n)\)
解法二
考虑容斥。
设\(m\)为轮换个数。
枚举有哪些轮换\(S\)中可能有被选中的元素,容斥系数就是\({(-1)}^{m-|S|}\)(\(sum\)为这些轮换的大小总和):
或者枚举哪些轮换\(S\)中没有被选中的元素,容斥系数就是\({(-1)}^{|S|}\):
\[
\begin{align}
s&=\sum_{S}{(-1)}^{m-|S|}\binom{sum}{k}\s&=\sum_{S}{(-1)}^{|S|}\binom{n-sum}{k}\\end{align}
\]
现在我们要对于每一个\(i\),计算\(f_i=\sum_{S,sum=i}{(-1)}^{|S|}\)。
构造生成函数\(A_i(x)=1-x^{a_i}\),那么\(F(x)=\prod_{i=1}^mA_i(x)\)。
直接做还是\(O(n\log^2n)\)的。我们需要一些优化。
\[
\begin{align}
F(x)&=\prod_{i=1}^m1-x^{a_i}\\ln(F(x))&=\sum_{i=1}^n\ln(1-x^{a_i})\\ln(F(x))&=\sum_{i=1}^n\sum_{j=a_i}-\frac{x^{ja_i}}{j}
\end{align}
\]
那么可以在\(O(n\log n)\)内算出\(\ln(F(x))\),然后\(\exp\)一下。
时间复杂度:\(O(n\log n)\)
由于常数过大,所以要用下面那条式子(因为只用计算到\(x^{n-k}\))。
解法一
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<iostream>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
int rd()
{
int s=0,c;
while((c=getchar())<'0'||c>'9');
s=c-'0';
while((c=getchar())>='0'&&c<='9')
s=s*10+c-'0';
return s;
}
const int p=998244353;
const int g=3;
ll fp(ll a,ll b)
{
ll s=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)
s=s*a%p;
return s;
}
ll inv[200010];
ll fac[200010];
ll ifac[200010];
int a[200010];
int n,m,k;
int c[200010];
int b[200010];
ll getc(int x,int y)
{
return fac[x]*ifac[y]%p*ifac[x-y]%p;
}
ll *f[500010];
int len[500010];
int cnt;
int a1[600010];
int a2[600010];
int rev[600010];
void ntt(int *a,int n,int t)
{
for(int i=1;i<n;i++)
{
rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
if(i>rev[i])
swap(a[i],a[rev[i]]);
}
for(int i=2;i<=n;i<<=1)
{
int wn=fp(g,(p-1)/i*(t==1?1:i-1));
for(int j=0;j<n;j+=i)
{
int w=1;
for(int k=j;k<j+i/2;k++)
{
int u=a[k];
int v=(ll)a[k+i/2]*w%p;
a[k]=(u+v)%p;
a[k+i/2]=(u-v)%p;
w=(ll)w*wn%p;
}
}
}
if(t==-1)
{
int inv=fp(n,p-2);
for(int i=0;i<n;i++)
a[i]=(ll)a[i]*inv%p;
}
}
void solve(int &now,int l,int r)
{
now=++cnt;
if(l==r)
{
len[now]=min(a[l],k);
f[now]=new ll[len[now]+1];
f[now][0]=0;
for(int i=1;i<=len[now];i++)
f[now][i]=ifac[i]*ifac[a[l]-i]%p;
return;
}
int ls,rs,mid=(l+r)>>1;
solve(ls,l,mid);
solve(rs,mid+1,r);
len[now]=min(len[ls]+len[rs],k);
f[now]=new ll[len[now]+1];
int v=1;
while(v<=len[ls]+len[rs])
v<<=1;
for(int i=0;i<v;i++)
a1[i]=(i<=len[ls]?f[ls][i]:0);
for(int i=0;i<v;i++)
a2[i]=(i<=len[rs]?f[rs][i]:0);
ntt(a1,v,1);
ntt(a2,v,1);
for(int i=0;i<v;i++)
a1[i]=(ll)a1[i]*a2[i]%p;
ntt(a1,v,-1);
for(int i=0;i<=len[now];i++)
f[now][i]=a1[i];
delete [] f[ls];
delete [] f[rs];
}
void solve()
{
// scanf("%d%d",&n,&k);
n=rd();
k=rd();
for(int i=1;i<=n;i++)
c[i]=rd();
// scanf("%d",&c[i]);
if(k==n)
{
printf("1\n");
return;
}
m=0;
cnt=0;
memset(b,0,sizeof b);
memset(a,0,sizeof a);
for(int i=1;i<=n;i++)
if(!b[i])
{
m++;
for(int j=i;!b[j];j=c[j])
{
b[j]=1;
a[m]++;
}
}
if(k<m)
{
printf("0\n");
return;
}
int rt;
solve(rt,1,m);
ll ans=f[rt][k];
ans=ans*fp(getc(n,k),p-2)%p;
for(int i=1;i<=m;i++)
ans=ans*fac[a[i]]%p;
ans=(ans+p)%p;
printf("%lld\n",ans);
}
int main()
{
open("a");
inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
for(int i=2;i<=200000;i++)
{
inv[i]=-p/i*inv[p%i]%p;
fac[i]=fac[i-1]*i%p;
ifac[i]=ifac[i-1]*inv[i]%p;
}
int t;
// scanf("%d",&t);
t=rd();
while(t--)
solve();
return 0;
}
解法二
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<iostream>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
int rd()
{
int s=0,c;
while((c=getchar())<'0'||c>'9');
s=c-'0';
while((c=getchar())>='0'&&c<='9')
s=s*10+c-'0';
return s;
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
const int p=998244353;
const int g=3;
ll fp(ll a,ll b)
{
ll s=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)
s=s*a%p;
return s;
}
ll inv[300010];
ll fac[300010];
ll ifac[300010];
namespace ntt
{
int rev[600000];
void ntt(int *a,int n,int t)
{
for(int i=1;i<n;i++)
{
rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
if(i>rev[i])
swap(a[i],a[rev[i]]);
}
for(int i=2;i<=n;i<<=1)
{
int wn=fp(g,(p-1)/i*(t==1?1:i-1));
for(int j=0;j<n;j+=i)
{
int w=1;
for(int k=j;k<j+i/2;k++)
{
int u=a[k];
int v=(ll)a[k+i/2]*w%p;
a[k]=(u+v)%p;
a[k+i/2]=(u-v)%p;
w=(ll)w*wn%p;
}
}
}
if(t==-1)
{
int inv=fp(n,p-2);
for(int i=0;i<n;i++)
a[i]=(ll)a[i]*inv%p;
}
}
void getinv(int *a,int *b,int n)
{
if(n==1)
{
b[0]=fp(a[0],p-2);
return;
}
getinv(a,b,n>>1);
static int a1[600000],a2[600000];
for(int i=0;i<n;i++)
a1[i]=a[i];
for(int i=n;i<n<<1;i++)
a1[i]=0;
for(int i=0;i<n>>1;i++)
a2[i]=b[i];
for(int i=n>>1;i<n<<1;i++)
a2[i]=0;
ntt(a1,n<<1,1);
ntt(a2,n<<1,1);
for(int i=0;i<n<<1;i++)
a1[i]=a2[i]*(2-(ll)a1[i]*a2[i]%p)%p;
ntt(a1,n<<1,-1);
for(int i=0;i<n;i++)
b[i]=a1[i];
}
void getln(int *a,int *b,int n)
{
static int a1[600000],a2[600000];
for(int i=1;i<n;i++)
a1[i-1]=(ll)a[i]*i%p;
a1[n-1]=0;
getinv(a,a2,n);
for(int i=n;i<n<<1;i++)
a1[i]=a2[i]=0;
ntt(a1,n<<1,1);
ntt(a2,n<<1,1);
for(int i=0;i<n<<1;i++)
a1[i]=(ll)a1[i]*a2[i]%p;
ntt(a1,n<<1,-1);
for(int i=1;i<n;i++)
b[i]=(ll)a1[i-1]*inv[i]%p;
b[0]=0;
}
void getexp(int *a,int *b,int n)
{
if(n==1)
{
b[0]=1;
return;
}
getexp(a,b,n>>1);
static int a1[600000],a2[600000],a3[600000];
for(int i=n>>1;i<n;i++)
b[i]=0;
getln(b,a3,n);
for(int i=0;i<n>>1;i++)
{
a1[i]=b[i];
a2[i]=(a[i+(n>>1)]-a3[i+(n>>1)])%p;
}
for(int i=n>>1;i<n;i++)
a1[i]=a2[i]=0;
ntt(a1,n,1);
ntt(a2,n,1);
for(int i=0;i<n;i++)
a1[i]=(ll)a1[i]*a2[i]%p;
ntt(a1,n,-1);
for(int i=0;i<n>>1;i++)
b[i+(n>>1)]=a1[i];
}
}
int a[200010];
int n,m,k;
int c[200010];
int b[200010];
int cnt;
ll ans;
int d[300010];
int s[300010];
int f[300010];
ll getc(int x,int y)
{
if(y>x||y<0)
return 0;
return fac[x]*ifac[y]%p*ifac[x-y]%p;
}
void dfs(int x,int y,int v)
{
if(x>m)
{
ans=(ans+v*getc(y,k))%p;
return;
}
dfs(x+1,y,v);
dfs(x+1,y+a[x],-v);
}
void solve()
{
// scanf("%d%d",&n,&k);
n=rd();
k=rd();
for(int i=1;i<=n;i++)
c[i]=rd();
// scanf("%d",&c[i]);
if(k==n)
{
printf("1\n");
return;
}
m=0;
cnt=0;
memset(b,0,sizeof b);
memset(a,0,sizeof a);
for(int i=1;i<=n;i++)
if(!b[i])
{
m++;
for(int j=i;!b[j];j=c[j])
{
b[j]=1;
a[m]++;
}
}
if(k<m)
{
printf("0\n");
return;
}
memset(d,0,sizeof d);
memset(s,0,sizeof s);
for(int i=1;i<=m;i++)
d[a[i]]++;
for(int i=1;i<=n;i++)
if(d[i])
for(int j=1;i*j<=n;j++)
s[i*j]=(s[i*j]-inv[j]*d[i])%p;
int l=1;
while(l<=n-k)
l<<=1;
s[0]=1;
ntt::getexp(s,f,l);
ans=0;
for(int i=0;i<=n-k;i++)
ans=(ans+f[i]*getc(n-i,k))%p;
// ans=(ans+f[i]*getc(i,k))%p;
ans=ans*fp(getc(n,k),p-2)%p;
// if(m&1)
// ans=-ans;
ans=(ans+p)%p;
printf("%lld\n",ans);
}
int main()
{
open("a");
inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
for(int i=2;i<=300000;i++)
{
inv[i]=-p/i*inv[p%i]%p;
fac[i]=fac[i-1]*i%p;
ifac[i]=ifac[i-1]*inv[i]%p;
}
int t;
// scanf("%d",&t);
t=rd();
while(t--)
solve();
return 0;
}