// Heavy-light decomposition (HLD) on a rooted tree, backed by a segment tree
// with lazy range-add. Every stored sum and pending tag is kept reduced into
// [0, modd). Input: n m rt modd, then n node values, then n-1 edges, then m
// operations:
//   1 x y z : add z to every node on the path between x and y
//   2 x y   : print the sum of values on the path between x and y
//   3 x z   : add z to every node in the subtree of x
//   4 x     : print the sum of values in the subtree of x
#include <cstdio>
#include <utility>

using ll = long long;

const int N = 2e5 + 8;

// modd: modulus; n: node count; m: operation count; rt: root node.
// cnt: number of stored directed edges; cnto: next DFS position handed out.
int modd, n, m, rt, cnt = 0, cnto = 0;

// f[u]: parent, dep[u]: depth (root = 1), siz[u]: subtree size,
// son[u]: heavy child (0 for a leaf), top[u]: head of u's heavy chain,
// id[u]: position of u in the heavy-first DFS order,
// w[u]: initial value of node u, wt[p]: initial value at DFS position p,
// h[u]: index of the first edge in u's adjacency list (0 = empty).
int son[N], id[N], top[N], f[N], dep[N], siz[N], h[N], w[N], wt[N];

// Adjacency-list edge pool. Each undirected edge is stored twice, so the pool
// needs room for 2 * (n - 1) entries, not n.
struct edge {
    int to, nex;
} e[N << 1];

// Segment-tree node: range sum and pending add, both reduced mod modd.
// long long so that sum + lazy * len cannot overflow for modd < 2^31.
struct tree {
    ll sum, lazy;
} tr[N << 2];

// Reduce x into [0, modd).
ll norm(ll x) {
    x %= modd;
    return x < 0 ? x + modd : x;
}

// Append the directed edge x -> y to x's adjacency list.
void add(int x, int y) {
    e[++cnt] = edge{y, h[x]};
    h[x] = cnt;
}

// First pass: fill f, dep, siz and son for the subtree rooted at u.
// Recursive, so the recursion depth equals the tree height.
void dfs1(int u, int fa, int deep) {
    dep[u] = deep;
    f[u] = fa;
    siz[u] = 1;
    int maxson = -1;
    for (int i = h[u]; i; i = e[i].nex) {
        int v = e[i].to;
        if (v == fa) continue;
        dfs1(v, u, deep + 1);
        siz[u] += siz[v];
        if (siz[v] > maxson) {  // first child with the largest subtree wins ties
            son[u] = v;
            maxson = siz[v];
        }
    }
}

// Second pass: assign DFS positions, visiting the heavy child first, so each
// heavy chain is contiguous and the subtree of u is [id[u], id[u]+siz[u]-1].
void dfs2(int u, int topfa) {
    id[u] = ++cnto;
    wt[cnto] = w[u];
    top[u] = topfa;
    if (!son[u]) return;
    dfs2(son[u], topfa);
    for (int i = h[u]; i; i = e[i].nex) {
        int v = e[i].to;
        if (v == f[u] || v == son[u]) continue;
        dfs2(v, v);  // every light child starts a new chain
    }
}

// ---- segment tree (begin) ----

// Build node o over DFS positions [l, r] from wt[].
void build(int o, int l, int r) {
    tr[o].lazy = 0;
    if (l == r) {
        tr[o].sum = norm(wt[l]);
        return;
    }
    int mid = (l + r) >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid + 1, r);
    tr[o].sum = (tr[o << 1].sum + tr[o << 1 | 1].sum) % modd;
}

// Add k (already in [0, modd)) to each of the len positions under node o.
void apply_add(int o, int len, ll k) {
    tr[o].sum = (tr[o].sum + k * len) % modd;
    tr[o].lazy = (tr[o].lazy + k) % modd;
}

// Push node o's pending add down to its children. Node o covers len
// positions: the left child holds ceil(len/2) of them, the right floor(len/2).
void push_down(int o, int len) {
    apply_add(o << 1, (len + 1) >> 1, tr[o].lazy);
    apply_add(o << 1 | 1, len >> 1, tr[o].lazy);
    tr[o].lazy = 0;
}

// Sum over positions [x, y] within node o's range [l, r], mod modd.
ll query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return tr[o].sum;
    if (tr[o].lazy) push_down(o, r - l + 1);
    int mid = (l + r) >> 1;
    ll ans = 0;
    if (x <= mid) ans += query(o << 1, l, mid, x, y);
    if (y > mid) ans += query(o << 1 | 1, mid + 1, r, x, y);
    return ans % modd;
}

// Add k (already in [0, modd)) to positions [x, y] within node o's range [l, r].
void change(int o, int l, int r, int x, int y, ll k) {
    if (x <= l && r <= y) {
        apply_add(o, r - l + 1, k);
        return;
    }
    if (tr[o].lazy) push_down(o, r - l + 1);
    int mid = (l + r) >> 1;
    if (x <= mid) change(o << 1, l, mid, x, y, k);
    if (y > mid) change(o << 1 | 1, mid + 1, r, x, y, k);
    tr[o].sum = (tr[o << 1].sum + tr[o << 1 | 1].sum) % modd;
}

// ---- segment tree (end) ----
// ---- heavy-light path / subtree operations (begin) ----

// Sum of values in the subtree of u, mod modd.
ll qSon(int u) {
    return query(1, 1, n, id[u], id[u] + siz[u] - 1);
}

// Sum of values on the path u..v, mod modd. Repeatedly jumps the endpoint
// whose chain head is deeper to the parent of that head.
ll qRange(int u, int v) {
    ll ans = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        ans = (ans + query(1, 1, n, id[top[u]], id[u])) % modd;
        u = f[top[u]];
    }
    if (dep[u] > dep[v]) std::swap(u, v);
    return (ans + query(1, 1, n, id[u], id[v])) % modd;
}

// Add k to every value in the subtree of u.
void updSon(int u, int k) {
    change(1, 1, n, id[u], id[u] + siz[u] - 1, norm(k));
}

// Add k to every value on the path u..v.
void updRange(int u, int v, int k) {
    const ll kk = norm(k);
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        change(1, 1, n, id[top[u]], id[u], kk);
        u = f[top[u]];
    }
    if (dep[u] > dep[v]) std::swap(u, v);
    change(1, 1, n, id[u], id[v], kk);
}

// ---- heavy-light path / subtree operations (end) ----

int main() {
    if (scanf("%d%d%d%d", &n, &m, &rt, &modd) != 4) return 0;
    for (int i = 1; i <= n; ++i) scanf("%d", &w[i]);
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs1(rt, 0, 1);
    dfs2(rt, rt);
    build(1, 1, n);
    while (m--) {
        int op, x, y, z;
        if (scanf("%d", &op) != 1) break;  // truncated input: stop instead of using garbage
        if (op == 1) {
            scanf("%d%d%d", &x, &y, &z);
            updRange(x, y, z);
        } else if (op == 2) {
            scanf("%d%d", &x, &y);
            printf("%lld\n", qRange(x, y));
        } else if (op == 3) {
            scanf("%d%d", &x, &y);
            updSon(x, y);
        } else if (op == 4) {
            scanf("%d", &x);
            printf("%lld\n", qSon(x));
        }
    }
    return 0;
}
// Original article (原文): https://www.cnblogs.com/xiaobuxie/p/11373839.html