无语。。开读入优化结果读入炸了。。。
看起来是时候换成快读了...
本题也是一个树链剖分模版,维护树上的单点修改和区间最值,区间和的查询。
注意有负值,在初始化比较数的时候要设成-inf。
树链剖分实际上就是一遍dfs跑出所有点的深度,父节点,子树size,重链子节点这些基本信息
然后再跑一遍dfs跑出所有点所在链的top节点,并且跑出树剖特有的的类似dfs序一样的东西存起来。我的dfsx存的是线段树里的lo/ro对应的原树的点编号,all[now].id存的是该原来树上的节点的dfs序号。
然后就可以build线段树了,用跑出来的dfs序作为区间,这样每一条链在线段树上成为一个连续的区间。然后可以维护区间信息。
require的时候可以类似lca的方式在原来的树上不断找链的top节点,注意这个时候用的是原树上的编号,然后不断地加上线段树查询的这条链的对应值即可。update的时候直接找线段树上all[now].id的位置并且更新即可。
(代码又臭又长
#include <bits/stdc++.h>
using namespace std;
#define ios ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define endl ‘\n‘
#define debugg(x) cout<<#x<<‘=‘<<x<<endl;
#define debug1(x,y,z) cout<<#x<<‘ ‘<<x<<‘ ‘<<#y<<‘ ‘<<y<<‘ ‘<<#z<<‘ ‘<<z<<endl;
#define debug cout<<endl<<"********"<<endl;
#define ll long long
#define int ll
#define ull unisgned long long
#define ld long double
#define itn int
#define pii pair<int,int>
#define rep(I, A, B) for (int I = (A); I <= (B); ++I)
#define dwn(I, A, B) for (int I = (A); I >= (B); --I)
#define mod (ll)(1e9+7)
//#define mid ((lo+ro)>>1)
void fre(){
freopen("test.in","r",stdin);
freopen("test.out","w",stdout);
}
void fc(){fclose(stdin);fclose(stdout);}
const int maxn=1e5+10;
const ll inf=0x7fffffff;
int n,q,tot;
struct node{
vector<int> son;int top;int fat;int dep;int siz;int zson;int id;
}all[maxn];
int dfsx[maxn];
struct xnode{
int lo;int ro;int sum;int zhi;
}tree[maxn*4];
int zhi[maxn];
void add(int u,int v){all[u].son.push_back(v);}
void dfs1(int now){
all[now].siz=1;
for(int i=0;i<(int)all[now].son.size();i++){
int v=all[now].son[i];
if(all[v].dep) continue;
all[v].dep=all[now].dep+1;
all[v].fat=now;
dfs1(v);
all[now].siz+=all[v].siz;
if(all[v].siz>=all[all[now].zson].siz) all[now].zson=v;
}
}
void dfs2(int now,int fat){
all[now].id=++tot;
dfsx[tot]=now;
all[now].top=fat;
if(all[now].zson) dfs2(all[now].zson,fat);
for(int i=0;i<(int)all[now].son.size();i++){
int v=all[now].son[i];
if(v==all[now].fat||v==all[now].zson) continue;
dfs2(v,v);
}
}
void pushup(int now){
tree[now].sum=tree[now*2].sum+tree[now*2+1].sum;
tree[now].zhi=max(tree[now*2].zhi,tree[now*2+1].zhi);
}
void build(int now,int lo,int ro){
tree[now].lo=lo;tree[now].ro=ro;
if(lo==ro){tree[now].zhi=zhi[dfsx[tree[now].lo]];tree[now].sum=tree[now].zhi;return;}
int mid=(lo+ro)/2;
build(now*2,lo,mid);
build(now*2+1,mid+1,ro);
pushup(now);
}
void upd(int now,int pos,int zhi){
if(tree[now].lo==tree[now].ro){tree[now].zhi=zhi;tree[now].sum=zhi;return;}
if(tree[now].lo>pos||tree[now].ro<pos) return;
int mid=(tree[now].lo+tree[now].ro)/2;
if(pos<=mid) upd(now*2,pos,zhi);
else upd(now*2+1,pos,zhi);
pushup(now);
}
ll reqs(int now,int lo,int ro){
if(tree[now].lo>ro||tree[now].ro<lo) return -inf;
if(tree[now].lo>=lo&&tree[now].ro<=ro) return tree[now].zhi;
int mid=(tree[now].lo+tree[now].ro)/2;
ll ans=-inf;
if(mid>=lo) ans=max(ans,reqs(now*2,lo,ro));
if(mid<ro) ans=max(ans,reqs(now*2+1,lo,ro));
pushup(now);
return ans;
}
ll reqmax(int now1,int now2){
ll res=-inf;
int fat1=all[now1].top,fat2=all[now2].top;
while(fat1!=fat2){
if(all[fat1].dep<all[fat2].dep) swap(fat1,fat2),swap(now1,now2);
res=max(res,reqs(1,all[fat1].id,all[now1].id));
now1=all[fat1].fat;fat1=all[now1].top;
}
if(all[now1].dep<all[now2].dep) swap(now1,now2);
res=max(res,reqs(1,all[now2].id,all[now1].id));
return res;
}
ll reqh(int now,int lo,int ro){
if(tree[now].lo>=lo&&tree[now].ro<=ro) return tree[now].sum;
if(tree[now].lo>ro||tree[now].ro<lo) return 0;
int mid=(tree[now].lo+tree[now].ro)/2;
ll ans=0;
if(mid>=lo) ans+=reqh(now*2,lo,ro);
if(mid<ro) ans+=reqh(now*2+1,lo,ro);
pushup(now);
return ans;
}
ll reqsum(int now1,int now2){
ll res=0;
int fat1=all[now1].top,fat2=all[now2].top;
while(fat1!=fat2){
if(all[fat1].dep<all[fat2].dep) swap(now1,now2),swap(fat1,fat2);
res+=reqh(1,all[fat1].id,all[now1].id);
now1=all[fat1].fat;
fat1=all[now1].top;
}
if(all[now1].dep<all[now2].dep) swap(now1,now2);
res+=reqh(1,all[now2].id,all[now1].id);
return res;
}
signed main(){
cin>>n;
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;add(u,v);add(v,u);
}
for(int i=1;i<=n;i++) cin>>zhi[i];
all[1].dep=1;
dfs1(1);
dfs2(1,1);
build(1,1,n);
cin>>q;
while(q--){
int u,v;char s[10];
cin>>s>>u>>v;
if(s[1]==‘H‘) upd(1,all[u].id,v);
else if(s[1]==‘M‘) cout<<(int)reqmax(u,v)<<endl;
else if(s[1]==‘S‘) cout<<(int)reqsum(u,v)<<endl;
}
return 0;
}
原文:https://www.cnblogs.com/14long-Alex/p/14608205.html