求组合数的时候,如果模数p是质数,可以用卢卡斯定理解决。
但是卢卡斯定理仅仅适用于p是质数的情况。
当p不是质数的时候,我们就需要用扩展卢卡斯求解。
实际上,扩展卢卡斯=快速幂+快速乘+exgcd求逆元+质因数分解+crt合并答案+求阶乘,跟卢卡斯定理没什么关系......
如果把模数p分解成p1^k1*p2^k2*...*px^kx的形式,那么我们可以求出c(n,m)分别模每个pi^ki的结果,再用中国剩余定理合并即可。
每个pi^ki一定是互质的,所以用朴素crt就行。
根据组合数的定义,c(n,m)=(n!) / (m!*(n-m)!) ,所以我们只要能想办法求出阶乘,就能再利用exgcd求出逆元,进而求出组合数。
接下来唯一的问题就是怎么快速求出 x! 取模 pi^ki 的结果。
考虑如下的经典样例(据说来自popoqqq):(19!)%(3^2)
19!=1*2*3*4*5*6*7*8*9*10*11*12*13*14*15*16*17*18*19
先把其中的3的倍数提出来,因为求组合数的时候分子分母能约掉。
19!=(1*2*4*5*7*8)*(10*11*13*14*16*17)*(19)*(3*6*9*12*15*18)=(1*2*4*5*7*8)*(1*2*4*5*7*8)*(3*3*3*3*3*3)*(1*2*3*4*5*6)=(1*2*4*5*7*8)^2*19*(3^6)*(1*2*3*4*5*6)。
后面的6!部分可以递归求解,递归终点为0!=1。
3^6最后计算组合数的时候再处理。
那几个(1*2*4*5*7*8)显然是循环的,循环节长度小于pi^ki,可以暴力计算。
显然一共有(x/(pi^ki))个循环节,套个快速幂即可。
剩下的部分,即19,长度等于x%(pi^ki),也小于pi^ki,也可以暴力计算。
至此我们求出了阶乘。
求组合数的时候,考虑pi的倍数的影响。
分子分母分别计数相加减。
最后用crt合并即可。
1 #include<cstdio> 2 typedef long long ll; 3 4 ll n,m,p; 5 6 ll ksm(ll b,ll tp,ll mod) 7 { 8 ll ret=1; 9 while(tp) 10 { 11 if(tp&1)ret=ret*b%mod; 12 b=b*b%mod; 13 tp>>=1; 14 } 15 return ret; 16 } 17 18 ll mul(ll a,ll b,ll mod) 19 { 20 ll ret=0; 21 while(b) 22 { 23 if(b&1)ret=(ret+a)%mod; 24 a=(a+a)%mod; 25 b>>=1; 26 } 27 return ret; 28 } 29 30 ll exgcd(ll a,ll b,ll &x,ll &y) 31 { 32 if(!b) 33 { 34 x=1;y=0; 35 return a; 36 } 37 ll t=exgcd(b,a%b,y,x); 38 y-=a/b*x; 39 } 40 41 ll inv(ll x,ll mod) 42 { 43 ll a,b; 44 exgcd(x,mod,a,b); 45 return (a%mod+mod)%mod; 46 } 47 48 ll fac(ll x,ll pi,ll pk) 49 { 50 if(!x)return 1; 51 ll ans=1; 52 for(ll i=2;i<=pk;i++) 53 if(i%pi)ans=ans*i%pk; 54 ans=ksm(ans,x/pk,pk); 55 for(ll i=2;i<=x%pk;i++) 56 if(i%pi)ans=ans*i%pk; 57 return ans*fac(x/pi,pi,pk)%pk; 58 } 59 60 ll c(ll cn,ll cm,ll pi,ll pk) 61 { 62 if(cm>cn)return 0; 63 ll up=fac(cn,pi,pk),d1=fac(cm,pi,pk),d2=fac(cn-cm,pi,pk); 64 ll cnt=0; 65 for(ll i=cn;i;i/=pi)cnt+=i/pi; 66 for(ll i=cm;i;i/=pi)cnt-=i/pi; 67 for(ll i=cn-cm;i;i/=pi)cnt-=i/pi; 68 return up*inv(d1,pk)%pk*inv(d2,pk)%pk*ksm(pi,cnt,pk)%pk; 69 } 70 71 ll crt(ll a,ll pk) 72 { 73 return a*inv(p/pk,pk)%p*(p/pk)%p; 74 } 75 76 int main() 77 { 78 scanf("%lld%lld%lld",&n,&m,&p); 79 ll tp=p,ans=0; 80 for(ll i=2;i*i<=p;i++) 81 { 82 if(tp%i)continue; 83 ll pk=1; 84 while(!(tp%i))tp/=i,pk*=i; 85 ans=(ans+crt(c(n,m,i,pk),pk))%p; 86 } 87 if(tp>1)ans=(ans+crt(c(n,m,tp,tp),tp))%p; 88 printf("%lld",(ans%p+p)%p); 89 return 0; 90 }
原文:https://www.cnblogs.com/eternhope/p/9898494.html