首页 > 其他 > 详细

洛谷P5050 【模板】多项式多点求值

时间:2019-03-29 12:50:05      阅读:128      评论:0      收藏:0      [点我收藏+]

https://www.luogu.org/problemnew/show/P5050

给定多项式A(x),求$A(x_l)$,$A(x_{l+1})$,..,$A(x_r)$

分治:(如果r-l+1=1,直接O(deg(A))暴力求出即可)

首先设$mid=\lfloor\frac{l+r}{2}\rfloor$,$P^{[0]}(x)=\prod_{i=l}^{mid}(x-x_i)$,$P^{[1]}(x)=\prod_{i=mid+1}^{r}(x-x_i)$

以[l,mid]的求值为例:设$A^{[0]}(x)=A(x)\,mod\,P^{[0]}(x)$

即$A(x)=P^{[0]}(x)B^{[0]}(x)+A^{[0]}(x)$($B^{[0]}$为某个多项式)

可以发现,将$x_l$,$x_{l+1}$,..,$x_{mid}$带入$P^{[0]}(x)$,值都为0

因此对于$l<=i<=mid$,$A(x_i)=A^{[0]}(x_i)$,递归下去算就行;[mid+1,r]的求值同理

这个P可以在分治过程中处理出来

时间复杂度大概是$O(n\,log^2\,n)$(未区分n=r-l+1,m=deg(A))

版本1:基于版本1,加了小范围暴力,预处理了P方便快速插值

技术分享图片
  1 #prag  2 ma GCC optimize(2)
  3 #include<cstdio>
  4 #include<algorithm>
  5 #include<cstring>
  6 #include<vector>
  7 #include<cmath>
  8 using namespace std;
  9 #define fi first
 10 #define se second
 11 #define mp make_pair
 12 #define pb push_back
 13 typedef long long ll;
 14 typedef unsigned long long ull;
 15 const int md=998244353;
 16 const int N=131072;
 17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md))
 18 inline int del(int a,int b)
 19 {
 20     a-=b;
 21     return a<0?a+md:a;
 22 }
 23 int rev[N];
 24 void init(int len)
 25 {
 26     int bit=0,i;
 27     while((1<<(bit+1))<=len)    ++bit;
 28     for(i=1;i<len;++i)
 29         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
 30 }
 31 ull poww(ull a,ull b)
 32 {
 33     ull ans=1;
 34     for(;b;b>>=1,a=a*a%md)
 35         if(b&1)
 36             ans=ans*a%md;
 37     return ans;
 38 }
 39 int inv[300011];
 40 void dft(int *a,int len,int idx)//要求len为2的幂
 41 {
 42     int i,j,k,t1,t2;ull wn,wnk;
 43     for(i=0;i<len;++i)
 44         if(i<rev[i])
 45             swap(a[i],a[rev[i]]);
 46     for(i=1;i<len;i<<=1)
 47     {
 48         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
 49         for(j=0;j<len;j+=(i<<1))
 50         {
 51             wnk=1;
 52             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
 53             {
 54                 t1=a[k];t2=a[k+i]*wnk%md;
 55                 a[k]+=t2;
 56                 (a[k]>=md)&&(a[k]-=md);
 57                 a[k+i]=t1-t2;
 58                 (a[k+i]<0)&&(a[k+i]+=md);
 59             }
 60         }
 61     }
 62     if(idx==-1)
 63     {
 64         ull ilen=inv[len];
 65         for(i=0;i<len;++i)
 66             a[i]=a[i]*ilen%md;
 67     }
 68 }
 69 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素);要求len是2的幂
 70 {
 71     static int t1[N],t2[N];
 72     g[0]=poww(f[0],md-2);
 73     for(int i=2,j;i<=len;i<<=1)
 74     {
 75         memcpy(t1,f,sizeof(int)*i);
 76         memcpy(t2,g,sizeof(int)*(i>>1));
 77         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
 78         init(i);
 79         dft(t1,i,1);dft(t2,i,1);
 80         for(j=0;j<i;++j)
 81             t1[j]=ull(t1[j])*t2[j]%md;
 82         dft(t1,i,-1);
 83         for(j=0;j<(i>>1);++j)
 84             t1[j]=t1[j+(i>>1)];
 85         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
 86         dft(t1,i,1);
 87         for(j=0;j<i;++j)
 88             t1[j]=ull(t1[j])*t2[j]%md;
 89         dft(t1,i,-1);
 90         for(j=i>>1;j<i;++j)
 91             g[j]=md-t1[j-(i>>1)];
 92     }
 93 }
 94 inline void p_de(int *f,int len)//derivative求导;f=f‘
 95 {
 96     for(int i=0;i<len-1;++i)
 97         f[i]=ull(i+1)*f[i+1]%md;
 98     f[len-1]=0;
 99 }
100 inline void p_in(int *f,int len)//integral积分;f=?f
101 {
102     for(int i=len-1;i>=1;--i)
103         f[i]=ull(f[i-1])*inv[i]%md;
104     f[0]=0;
105 }
106 void p_ln(int *f,int len)//要求len为2的幂,f[0]=1
107 {
108     static int t3[N];
109     p_inv(f,t3,len);p_de(f,len);
110     init(len<<1);
111     dft(f,len<<1,1);dft(t3,len<<1,1);
112     for(int i=0;i<(len<<1);++i)
113         f[i]=ull(f[i])*t3[i]%md;
114     dft(f,len<<1,-1);p_in(f,len);
115 }
116 void p_exp(int *f,int *g,int len)//要求len为2的幂,f[0]=0
117 {
118     static int t1[N],t2[N];
119     g[0]=1;
120     for(int i=2,j;i<=len;i<<=1)
121     {
122         memcpy(t1,g,sizeof(int)*(i>>1));
123         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
124         p_ln(t1,i);
125         for(j=0;j<(i>>1);++j)
126             t1[j]=del(f[j+(i>>1)],t1[j+(i>>1)]);
127         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
128         init(i);
129         dft(t1,i,1);
130         memcpy(t2,g,sizeof(int)*(i>>1));
131         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
132         dft(t2,i,1);
133         for(j=0;j<i;++j)
134             t1[j]=ull(t1[j])*t2[j]%md;
135         dft(t1,i,-1);
136         for(j=i>>1;j<i;++j)
137             g[j]=t1[j-(i>>1)];
138     }
139 }
140 void p_div(int *a,int *b,int *c,int n,int m)//c=a/b;deg(a)=n,deg(b)=m,deg(c)=n-m;a,b无前导0;n>=m
141 {
142     reverse(a,a+n+1);reverse(b,b+m+1);
143     int x=n-m+1,t=1;
144     for(;t<x;t<<=1);
145     memset(b+m+1,0,sizeof(int)*max(t-m-1,0));
146     p_inv(b,c,t);
147     memset(c+x,0,sizeof(int)*((t<<1)-x));
148     memset(a+x,0,sizeof(int)*((t<<1)-x));
149     init(t<<1);
150     dft(a,t<<1,1);dft(c,t<<1,1);
151     for(int i=0;i<(t<<1);++i)
152         c[i]=ull(c[i])*a[i]%md;
153     dft(c,t<<1,-1);
154     memset(c+(n-m+1),0,sizeof(int)*((t<<1)-n+m-1));
155     reverse(c,c+x);
156 }
157 void p_divmod(int *a,int *b,int *c,int *d,int n,int m)//c=a/b,d=a%b,deg(d)=(<=)m-1;其余同上
158 {
159     static int t1[N];
160     memcpy(d,a,sizeof(int)*(m+1));
161     int x=n+1,t=1;
162     for(;t<x;t<<=1);
163     memcpy(t1,b,sizeof(int)*(m+1));
164     memset(t1+m+1,0,sizeof(int)*max(t-m-1,0));
165     p_div(a,b,c,n,m);
166     memcpy(a,c,sizeof(int)*(n-m+1));
167     memset(a+n-m+1,0,sizeof(int)*(t-n+m-1));
168     init(t);
169     dft(a,t,1);dft(t1,t,1);
170     for(int i=0;i<t;++i)
171         t1[i]=ull(t1[i])*a[i]%md;
172     dft(t1,t,-1);
173     for(int i=0;i<=m;++i)
174         delto(d[i],t1[i]);
175 }
176 namespace P_me
177 {
178     int *ta[N];//用线段树的方法给递归的每一层一个编号,ta[i]表示编号为i的层的P函数的各项系数
179     int data[N*40],*tp;//内存池
180     int *a,*x,*y;
181 #define LC (u<<1)
182 #define RC (u<<1|1)
183     int mt1[N];
184     const int T=200;//小范围暴力阀值
185     void _p_me1(int l,int r,int u)//计算(x-x_l)(x-x_{l+1})..(x-x_r)并存下来
186     {
187         if(r-l<=T)
188         {
189             int i,j;
190             tp[0]=1;
191             for(i=l;i<=r;++i)
192             {
193                 tp[i-l+1]=tp[i-l];
194                 for(j=i-l;j>=1;--j)
195                 {
196                     tp[j]=(ull(tp[j])*(md-x[i])+tp[j-1])%md;
197                 }
198                 tp[0]=ull(tp[0])*(md-x[i])%md;
199             }
200             ta[u]=tp;tp+=r-l+2;
201             return;
202         }
203         int mid=(l+r)>>1;
204         _p_me1(l,mid,LC);_p_me1(mid+1,r,RC);
205         int x=r-l+2,t=1;//x=(mid-l+1)+(r-mid)+1
206         for(;t<x;t<<=1);
207         memcpy(mt1,ta[LC],sizeof(int)*(mid-l+2));
208         memset(mt1+mid-l+2,0,sizeof(int)*(t-mid+l-2));
209         memcpy(tp,ta[RC],sizeof(int)*(r-mid+1));
210         memset(tp+r-mid+1,0,sizeof(int)*(t-r+mid-1));
211         init(t);
212         dft(mt1,t,1);dft(tp,t,1);
213         for(int i=0;i<t;++i)
214             tp[i]=ull(tp[i])*mt1[i]%md;
215         dft(tp,t,-1);
216         ta[u]=tp;tp+=r-l+2;
217     }
218     int mt2[N],mt3[N];
219     void _p_me2(int *a,int n,int l,int r,int u)//a是A的系数,deg(A)<=n;求A(x_l)到A(x_r),放入y_l到y_r
220     {
221         if(r-l<=T)
222         {
223             int t,i,j;
224             for(i=l;i<=r;++i)
225             {
226                 t=a[n];
227                 for(j=n-1;j>=0;--j)
228                     t=(ull(t)*x[i]+a[j])%md;
229                 y[i]=t;
230             }
231             return;
232         }
233         int x=(n+1)<<1,t=1;
234         for(;t<x;t<<=1);
235         int mt4[t];//根据需要改成new?
236         int mid=(l+r)>>1,n1;
237         memcpy(mt1,a,sizeof(int)*(n+1));
238         for(n1=n;n1>=0 && mt1[n1]==0;)    --n1;
239         if(n1<0)
240         {
241             memset(y+l,0,sizeof(int)*(r-l+1));
242             return;
243         }
244         memcpy(mt2,ta[LC],sizeof(int)*(mid-l+2));
245         if(n1<mid-l+1)
246         {
247             memcpy(mt4,mt1,sizeof(int)*(n1+1));
248             _p_me2(mt4,n1,l,mid,LC);
249         }
250         else
251         {
252             p_divmod(mt1,mt2,mt3,mt4,n1,mid-l+1);
253             _p_me2(mt4,mid-l,l,mid,LC);    
254         }
255         memcpy(mt1,a,sizeof(int)*(n+1));
256         for(n1=n;n1>=0 && mt1[n1]==0;)    --n1;
257         memcpy(mt2,ta[RC],sizeof(int)*(r-mid+1));
258         if(n1<r-mid)
259         {
260             memcpy(mt4,mt1,sizeof(int)*(n1+1));
261             _p_me2(mt4,n1,mid+1,r,RC);
262         }
263         else
264         {
265             p_divmod(mt1,mt2,mt3,mt4,n1,r-mid);
266             _p_me2(mt4,r-mid-1,mid+1,r,RC);
267         }
268     }
269     void p_multieval(int *a0,int *x0,int *y0,int n,int m)//deg(a)=n,x有m个数
270     {
271         tp=data;
272         a=a0;x=x0;y=y0;
273         _p_me1(0,m-1,1);
274         _p_me2(a,n,0,m-1,1);
275     }
276 }
277 using P_me::p_multieval;
278 int a[N],x[N],y[N];
279 int n,m;
280 int main()
281 {
282     int i;
283     inv[1]=1;
284     for(i=2;i<=300000;++i)
285         inv[i]=ull(md-md/i)*inv[md%i]%md;
286     //n=100000;m=100000;
287     scanf("%d%d",&n,&m);
288     for(i=0;i<=n;++i)
289         //a[i]=rand()%md;
290         scanf("%d",a+i);
291     for(i=0;i<m;++i)
292         //x[i]=rand()%md;
293         scanf("%d",x+i);
294     p_multieval(a,x,y,n,m);
295     for(i=0;i<m;++i)
296         printf("%d\n",y[i]);
297     return 0;
298 }
View Code

 

洛谷P5050 【模板】多项式多点求值

原文:https://www.cnblogs.com/hehe54321/p/10616071.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!