Link
首先我们有一个静态的dp。
设\(f_{u,0/1}\)表示只考虑\(u\)的子树,\(u\)不选/选的答案。
那么很显然有:
考虑利用重链剖分来进行这个过程,设\(h_u\)表示\(u\)的重儿子,\(g_{u,0/1}\)表示\(f_{u,0/1}\)在不考虑\(h_u\)子树情况下的答案。
那么有:
对于一条重链,实际上我们只关心\(f_{top}\)。
假如我们已经求出了链上所有点的\(g\),那么我们可以做一个序列dp得到\(f_{top}\)。
实际上我们可以把这个序列dp的转移写成矩阵乘法的形式。
定义\(C=AB\)为满足\(C_{i,j}=\max\limits_k(A_{i,k}+B_{k,j})\)的矩阵,那么有:
注意到新定义的矩阵乘法仍然具有结合律,因此我们可以用线段树维护每条重链上的矩阵的区间积。为了方便我们在线段树外同时记录每个点的转移矩阵。
这样做的时间复杂度为\(O(n\log n+q\log^2n)\)。
#include<cctype>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
const int N=100007,inf=1e9;
char ibuf[1<<23|1],*iS=ibuf;
int n,m,val[N],fa[N],size[N],son[N],top[N],ch[N],dfn[N],id[N],f[N][2];
std::vector<int>e[N];
struct matrix{int a[2][2];int*operator[](int x){return a[x];}}t[4*N],a[N];
matrix operator*(matrix a,matrix b)
{
matrix c;
c[0][0]=std::max(a[0][0]+b[0][0],a[0][1]+b[1][0]),c[0][1]=std::max(a[0][0]+b[0][1],a[0][1]+b[1][1]);
c[1][0]=std::max(a[1][0]+b[0][0],a[1][1]+b[1][0]),c[1][1]=std::max(a[1][0]+b[0][1],a[1][1]+b[1][1]);
return c;
}
int read(){int x=0,f=1;while(isspace(*iS))++iS;if(*iS==‘-‘)++iS,f=-1;while(isdigit(*iS))(x*=10)+=*iS++&15;return f*x;}
void dfs1(int u)
{
size[u]=1;
for(int v:e[u]) if(v^fa[u]) if(fa[v]=u,dfs1(v),size[u]+=size[v],size[v]>size[son[u]]) son[u]=v;
}
void dfs2(int u,int tp)
{
static int tim;id[dfn[u]=++tim]=ch[u]=u,top[u]=tp;
if(son[u]) dfs2(son[u],tp),ch[u]=ch[son[u]];
for(int v:e[u]) if(v^fa[u]&&v^son[u]) dfs2(v,v);
}
void dfs3(int u)
{
f[u][1]=val[u];
for(int v:e[u]) if(v^fa[u]) dfs3(v),f[u][0]+=std::max(f[v][0],f[v][1]),f[u][1]+=f[v][0];
}
matrix get(int u)
{
int g0=0,g1=val[u];
for(int v:e[u]) if(v^fa[u]&&v^son[u]) g0+=std::max(f[v][0],f[v][1]),g1+=f[v][0];
return {g0,g0,g1,-inf};
}
#define ls p<<1
#define rs p<<1|1
#define mid ((l+r)/2)
void pushup(int p){t[p]=t[ls]*t[rs];}
void build(int p,int l,int r)
{
if(l==r) return a[l]=t[p]=get(id[l]),void();
build(ls,l,mid),build(rs,mid+1,r),pushup(p);
}
void update(int p,int l,int r,int x)
{
if(l==r) return t[p]=a[l],void();
x<=mid? update(ls,l,mid,x):update(rs,mid+1,r,x),pushup(p);
}
matrix query(int p,int l,int r,int L,int R)
{
if(L<=l&&r<=R) return t[p];
if(R<=mid) return query(ls,l,mid,L,R);
if(L>mid) return query(rs,mid+1,r,L,R);
return query(ls,l,mid,L,R)*query(rs,mid+1,r,L,R);
}
#undef ls
#undef rs
#undef mid
void modify(int u,int w)
{
a[dfn[u]][1][0]+=w-val[u],val[u]=w;
while(u)
{
matrix p=query(1,1,n,dfn[top[u]],dfn[ch[u]]);
update(1,1,n,dfn[u]);
matrix q=query(1,1,n,dfn[top[u]],dfn[ch[u]]);
if(!(u=fa[top[u]]))break;
int x=dfn[u],g0=p[0][0],g1=p[1][0],f0=q[0][0],f1=q[1][0];
a[x][0][0]=a[x][0][1]=a[x][0][0]+std::max(f0,f1)-std::max(g0,g1),a[x][1][0]=a[x][1][0]+f0-g0;
}
}
void work()
{
int u=read(),w=read();modify(u,w);
matrix ans=query(1,1,n,dfn[1],dfn[ch[1]]);
printf("%d\n",std::max(ans[0][0],ans[1][0]));
}
int main()
{
fread(ibuf,1,1<<23,stdin);
n=read(),m=read();
for(int i=1;i<=n;++i) val[i]=read();
for(int i=1,u,v;i<n;++i) u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
dfs1(1),dfs2(1,1),dfs3(1);
build(1,1,n);
for(int i=1;i<=m;++i) work();
}
原文:https://www.cnblogs.com/cjoierShiina-Mashiro/p/12845678.html