两个数列的加和移动其实可以看成是一个数列的加减和移动。(想一想,为什么)
我们设第一个数列加的值为k,
则
\[
\displaystyle \sum _{i=1}^{n}(x_i+k-y_i)^2\\=\displaystyle \sum_{i=1}^n(x_i^2+k^2+y_i^2+2kx_i-2ky_i-2x_iy_i)\\=\displaystyle \sum_{i=1}^nx_i^2+\displaystyle \sum_{i=1}^ny_i^2 +nk^2+2kSum_x-2kSum_y-2\displaystyle \sum_{i=1}^nx_iy_i\\=\displaystyle \sum_{i=1}^nx_i^2+\displaystyle \sum_{i=1}^ny_i^2 +nk^2+2(Sum_x-Sum_y)k-2\displaystyle \sum_{i=1}^nx_iy_i
\]
冷静分析.jpg
首先 \(\displaystyle \sum_{i=1}^nx_i^2+\displaystyle \sum_{i=1}^ny_i^2\)还是不会变的,所以我们先加上。
然后\(nk^2+2(Sum_x-Sum_y)k\)和怎么移动无关,通过二次函数的知识我们知道\(k\displaystyle =\frac{Sum_x-Sum_y}{n}\)时最优,注意这时候k不一定是整数,需要考虑一下。
最后就是\(2\displaystyle \sum_{i=1}^nx_iy_i\),当然我们希望这个值越大越好,并且这次和k无关。到我们发扬人类智慧的时候了,我们将x数组反转,然后再复制一倍,我们就惊奇的发现卷积后的数组从\(n+1\)到\(2n\)的每一个值就是移动相应个数后的\(2\displaystyle \sum_{i=1}^nx_iy_i\),取个max就好啦。
三个加起来就是我们要的答案啦。
时间复杂度O(n log n).
#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
int n,m;
LL ans,suma,sumb,zh_AK=-1e18,k=1e18;
const int N=400010,mod=998244353,G=3,Ginv=(mod+1)/3;
int r[N];
LL a[N],b[N];
inline int read()
{
int res=0;char ch=getchar();bool XX=false;
for(;!isdigit(ch);ch=getchar())(ch=='-')&&(XX=true);
for(; isdigit(ch);ch=getchar())res=(res<<3)+(res<<1)+(ch^48);
return XX?-res:res;
}
LL ksm(LL a,LL b,LL mod)
{
LL res=1;
for(;b;b>>=1,a=a*a%mod)
if(b&1)res=res*a%mod;
return res;
}
void NTT(LL *A,int lim,int opt)
{
for(int i=0;i<lim;++i)
r[i]=(r[i>>1]>>1)|((i&1)?(lim>>1):0);
for(int i=0;i<lim;++i)
if(i<r[i])swap(A[i],A[r[i]]);
int len;
LL wn,w,x,y;
for(int mid=1;mid<lim;mid<<=1)
{
len=mid<<1;
wn=ksm(opt==1?G:Ginv,(mod-1)/len,mod);
for(int j=0;j<lim;j+=len)
{
w=1;
for(int k=j;k<j+mid;++k,w=w*wn%mod)
{
x=A[k];y=A[k+mid]*w%mod;
A[k]=(x+y)%mod;
A[k+mid]=(x-y+mod)%mod;
}
}
}
if(opt==1)return;
int ni=ksm(lim,mod-2,mod);
for(int i=0;i<lim;++i)A[i]=A[i]*ni%mod;
}
void MUL(LL *A,int n,LL *B,int m)
{
int lim=1;
while(lim<=(n+m))lim<<=1;
NTT(A,lim,1);NTT(B,lim,1);
for(int i=0;i<lim;++i)A[i]=A[i]*B[i]%mod;
NTT(A,lim,-1);
}
signed main()
{
cin>>n>>m;
for(int i=1;i<=n;++i)a[i]=read(),suma+=a[i];
for(int i=1;i<=n;++i)b[i]=read(),sumb+=b[i];
for(int i=1;i<=n;++i)ans+=a[i]*a[i]+b[i]*b[i];
for(int i=-201;i<=201;++i)k=min(k,n*i*i+2*(suma-sumb)*i);
ans+=k;
for(int i=1;i<=n/2;++i)swap(a[i],a[n-i+1]);
for(int i=n+1;i<=n+n;++i)a[i]=a[i-n];
MUL(a,n+n,b,n);
for(int i=n+1;i<=n+n;++i)zh_AK=max(zh_AK,a[i]);
cout<<ans-2*zh_AK;
return 0;
}
原文:https://www.cnblogs.com/wljss/p/12006768.html