题目描述
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意: 从点u到点v的路径上的节点包括u和v本身
输入格式:
输入文件的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来一行n个整数,第i个整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
输出格式:
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
输入样例:
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
输出样例:
4
1
2
2
10
6
5
6
5
16
说明
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
思路:
树链剖分模板题!!!
代码:
#include<cmath> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> using namespace std; const int N=100010; const int oo=1e9; int n,q,a[4*N]; int tot=0,head[N]; int sm[4*N],mx[4*N]; int tpos[N],pre[N],cnt=0; int size[N],son[N],fa[N],d[N],top[N]; struct no { int u,v,nxt; } tu[N]; void add(int u,int v) { tu[++tot].u=u; tu[tot].v=v; tu[tot].nxt=head[u]; head[u]=tot; tu[++tot].u=v; tu[tot].v=u; tu[tot].nxt=head[v]; head[v]=tot; } void dfs1(int u,int f) { size[u]=1; for (int i=head[u]; i; i=tu[i].nxt) { int v=tu[i].v; if (v==f) continue; d[v]=d[u]+1; fa[v]=u; dfs1(v,u); size[u]+=size[v]; if (size[v]>size[son[u]]) son[u]=v; } } void dfs2(int u,int topp) { tpos[u]=++cnt; pre[cnt]=u; top[u]=topp; if (son[u]) dfs2(son[u],topp); for (int i=head[u]; i; i=tu[i].nxt) { int v=tu[i].v; if (v==fa[u]||v==son[u]) continue; dfs2(v,v); } } void pushup(int rt) { sm[rt]=sm[rt*2]+sm[rt*2+1]; mx[rt]=max(mx[rt*2],mx[rt*2+1]); } void build(int rt,int l,int r) { int mid=(l+r)/2; if (l==r) { sm[rt]=mx[rt]=a[pre[l]]; return; } build(rt*2,l,mid); build(rt*2+1,mid+1,r); pushup(rt); } void update(int rt,int l,int r,int q,int v) { int mid=(l+r)/2; if (l==r) { sm[rt]=mx[rt]=v; return; } if (q<=mid) update(rt*2,l,mid,q,v); else update(rt*2+1,mid+1,r,q,v); pushup(rt); } int query1(int rt,int l,int r,int ql,int qr) { int mid=(l+r)/2,ans=0; if (ql<=l&&r<=qr) return sm[rt]; if (ql<=mid) ans+=query1(rt*2,l,mid,ql,qr); if (qr>mid) ans+=query1(rt*2+1,mid+1,r,ql,qr); pushup(rt); return ans; } int query2(int rt,int l,int r,int ql,int qr) { int mid=(l+r)/2,ans=-oo; if (ql<=l&&r<=qr) return mx[rt]; if (ql<=mid) ans=max(ans,query2(rt*2,l,mid,ql,qr)); if (qr>mid) ans=max(ans,query2(rt*2+1,mid+1,r,ql,qr)); pushup(rt); return ans; } int qs(int u,int v) { int ans=0; while (top[u]!=top[v]) { if (d[top[u]]<d[top[v]]) swap(u,v); ans+=query1(1,1,n,tpos[top[u]],tpos[u]); u=fa[top[u]]; } if (d[u]<d[v]) swap(u,v); ans+=query1(1,1,n,tpos[v],tpos[u]); return ans; } int qm(int u,int v) { int ans=-oo; while (top[u]!=top[v]) { if (d[top[u]]<d[top[v]])swap(u,v); ans=max(ans,query2(1,1,n,tpos[top[u]],tpos[u])); u=fa[top[u]]; } if (d[u]<d[v]) swap(u,v); ans=max(ans,query2(1,1,n,tpos[v],tpos[u])); return ans; } int main() { memset(head,0,sizeof(head)); memset(a,0,sizeof(a)); scanf("%d",&n); for (int i=1; i<n; i++) { int u,v; scanf("%d%d",&u,&v); add(u,v); } for (int i=1; i<=n; i++) scanf("%d",&a[i]); d[1]=1; fa[1]=1; dfs1(1,-1); dfs2(1,1); build(1,1,n); scanf("%d",&q); while (q--) { int x,y; char s[10]; scanf("%s%d%d",s,&x,&y); if (s[1]==‘H‘) update(1,1,n,tpos[x],y); if (s[1]==‘M‘) printf("%d\n",qm(x,y)); if (s[1]==‘S‘) printf("%d\n",qs(x,y)); } return 0; }
原文:https://www.cnblogs.com/mysh/p/11291415.html