比较好想的一道题,只是那个组合数比较恶心。
先说一下我最开始想的$n^4$的沙雕dp:
设f[i][j][k]为前i天给了j个,第i天给了k个,则f[i][j][k]=∑f[i-1][j-k][o];
复杂度凑起来大概是$n^4$,因为本来就是针对30%打的,没有考虑特别大的d。
观察上面的式子,发现第三维并没有什么卵用,把它干掉,f[i][j]表示前i天给j个,那么f[i][j]=∑f[i-1][k](j-m+1<=k<=j),复杂度$n^2d$,显然可以用前缀和优化,复杂度$nd$。
但是d太大还是会死,有位巨佬用矩阵快速幂优化拿到了90分%%%,貌似矩阵有什么规律(然而我并不会)。
题目中d很大,但是n却很小,所以实际有用的只有n天,所以递推到第n天后乘以$C_d^n$,这样对吗?根据以上状态定义,n天也不一定全部都给饼干,所以乘以$C_d^n$会算重,这个时候状态就要稍微修改一下,f[i][j]表示前i天共给了j个,且每天给的饼干数>0,这样就解决了重复的问题,那么最后答案就是$∑f[i][n]*C_d^i$.
然而他又爆了longlong,__int128拯救世界系列。
1 #include<iostream> 2 #include<cstring> 3 #include<cstdio> 4 #define int LL 5 #define LL __int128 6 #define mod 998244353 7 #define ma(x) memset(x,0,sizeof(x)) 8 #define min(a,b) ((a)<(b)?(a):(b)) 9 #define max(a,b) ((a)>(b)?(a):(b)) 10 using namespace std; 11 int n,m; 12 LL d; 13 inline LL read() 14 { 15 LL s=0;char a=getchar(); 16 while(a<‘0‘||a>‘9‘)a=getchar(); 17 while(a>=‘0‘&&a<=‘9‘){s=s*10+a-‘0‘;a=getchar();} 18 return s; 19 } 20 LL f2[2010][2010]; 21 LL poww(LL a,int b) 22 { 23 LL ans=1; 24 while(b) 25 { 26 if(b&1)ans=(ans*a)%mod; 27 a=(a*a)%mod; 28 b=b>>1; 29 } 30 return ans; 31 } 32 LL C(LL d,LL n) 33 { 34 LL jn=1,js=1; 35 for(int i=1;i<=n;i++)jn=(jn*i)%mod; 36 for(int i=d-n+1;i<=d;i++)js=(js*i)%mod; 37 return js*poww(jn,mod-2)%mod; 38 } 39 signed main() 40 { 41 // freopen("in.txt","r",stdin); 42 // freopen("0.out","w",stdout); 43 44 while(scanf("%lld%lld%lld",&n,&d,&m)) 45 { 46 if(!n&&!d&&!m)return 0; 47 ma(f2); 48 if(1ll*d*(m-1)<n){puts("0");continue;} 49 if(d<=n) 50 { 51 for(int i=0;i<m;i++)f2[1][i]=1; 52 for(int i=2;i<=d;i++) 53 { 54 LL sum=0; 55 for(int j=0;j<=n;j++) 56 { 57 sum=(sum+f2[i-1][j]); 58 if(j-m>=0)sum=((sum-f2[i-1][j-m])%mod+mod)%mod; 59 f2[i][j]=(f2[i][j]+sum)%mod; 60 } 61 } 62 LL ans=f2[d][n]; 63 printf("%lld\n",ans%mod); 64 } 65 else 66 { 67 for(int i=1;i<m;i++)f2[1][i]=1; 68 for(int i=2;i<=n;i++) 69 { 70 LL sum=0; 71 for(int j=max(0,i-m);j<i;j++)sum=(sum+f2[i-1][j])%mod; 72 for(int j=i;j<=n&&j<=i*(m-1);j++) 73 { 74 if(j-m>=0)sum=((sum-f2[i-1][j-m])%mod+mod)%mod; 75 f2[i][j]=(f2[i][j]+sum)%mod; 76 sum=(sum+f2[i-1][j]); 77 } 78 } 79 LL ans=0; 80 for(int i=1;i<=n;i++) 81 ans=(ans+f2[i][n]*C(d,i))%mod; 82 printf("%lld\n",ans%mod); 83 } 84 } 85 }
原文:https://www.cnblogs.com/Al-Ca/p/11219272.html