首页 > 其他 > 详细

P3384 【模板】树链剖分

时间:2019-02-20 00:15:11      阅读:192      评论:0      收藏:0      [点我收藏+]

注意

自己理解着打了一遍。注意：两个点往上跳的时候，比较的是两个点各自所在链的链顶深度（而不是两个点本身的深度），每次让链顶更深的那个点跳到其链顶的父亲，直到两点在同一条链上。

技术分享图片
#include <algorithm>
#include  <iterator>
#include  <iostream>
#include   <cstring>
#include   <cstdlib>
#include   <iomanip>
#include    <bitset>
#include    <cctype>
#include    <cstdio>
#include    <string>
#include    <vector>
#include     <stack>
#include     <cmath>
#include     <queue>
#include      <list>
#include       <map>
#include       <set>
#include   <cassert>

/*
        
⊂_ヽ
  \\ Λ_Λ  来了老弟
   \(‘?‘)
    > ⌒ヽ
   /   へ\
   /  / \\
   ? ノ   ヽ_つ
  / /
  / /|
 ( (ヽ
 | |、\
 | 丿 \ ⌒)
 | |  ) /
‘ノ )  L?

*/

using namespace std;
// Segment-tree child argument lists.
// NOTE(review): the surrounding parentheses mean these only work when written as
// `f lson` (the parentheses become the call's argument list); `f(lson)` would pass
// a single comma expression. Neither macro is used in this file.
#define lson (l , mid , rt << 1)
#define rson (mid + 1 , r , rt << 1 | 1)
// Print "x = <value>" to stderr. Expands to a complete statement (ends in ';').
#define debug(x) cerr << #x << " = " << x << "\n";
#define pb push_back
#define pq priority_queue



typedef long long ll;
typedef unsigned long long ull;
//typedef __int128 bll;
typedef pair<ll ,ll > pll;
typedef pair<int ,int > pii;
typedef pair<int,pii> p3;

//priority_queue<int> q;                              // max-heap
//priority_queue<int,vector<int>,greater<int> >q;     // min-heap
#define fi first
#define se second
//#define endl '\n'

// Fast iostream setup: unsync iostreams from C stdio and untie cin from cout.
#define boost ios::sync_with_stdio(false);cin.tie(0)
// Inclusive counting loop: a runs from b up to and including c.
#define rep(a, b, c) for(int a = (b); a <= (c); ++ a)
// Maximum / minimum of three values. These expand to expressions, so they must not
// contain a trailing ';' (one inside the macro breaks uses such as
// `if (cond) x = max3(a, b, c); else ...` or `max3(a, b, c) + 1`).
#define max3(a,b,c) max(max(a,b), c)
#define min3(a,b,c) min(min(a,b), c)


const ll oo = 1ll<<17;        // 131072
const ll mos = 0x7FFFFFFF;    // 2147483647 (INT_MAX)
// NOTE(review): 0x80000000 held in a long long is +2147483648, not -2147483648;
// use (-mos - 1) if the negative 32-bit bound was intended.
const ll nmos = 0x80000000;
const int inf = 0x3f3f3f3f;               // every byte is 0x3f
const ll inff = 0x3f3f3f3f3f3f3f3f;       // every byte is 0x3f (about 4.56e18)
const ll mod = 2147483648;                // 2^31
const double esp = 1e-8;                  // presumably "eps": float-comparison tolerance
const double PI=acos(-1.0);
const double PHI=0.61803399;    // golden-section ratio, (sqrt(5) - 1) / 2
const double tPHI=0.38196601;   // 1 - PHI

template<typename T>
inline T read(T&x){
    // Fast integer reader from stdin: skips every non-digit character before the
    // number (any '-' seen there makes the result negative), then accumulates
    // decimal digits. The first non-digit after the number is consumed.
    // Stores the parsed value in x and also returns it; yields 0 at end of input.
    x = 0;
    bool neg = false;
    int ch = getchar();   // int, not char: EOF must stay distinguishable from data
    while (ch != EOF && (ch < '0' || ch > '9')) {   // stop at EOF instead of spinning
        if (ch == '-') neg = true;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x = neg ? -x : x;
}

// In-place maximum: x becomes the larger of x and y.
inline void cmax(int &x,int y){ x = std::max(x, y); }
inline void cmax(ll &x,ll y){ x = std::max(x, y); }
// In-place minimum: x becomes the smaller of x and y.
inline void cmin(int &x,int y){ x = std::min(x, y); }
inline void cmin(ll &x,ll y){ x = std::min(x, y); }

/*-----------------------showtime----------------------*/
            // Capacity of the node-indexed arrays (nodes are read as 1..n).
            const int maxn = 1e6+9;
            // n: node count, m: number of operations, R: root node, P: modulus for all sums.
            int n,m,R,P;
            // a[u]: initial value of node u.
            int a[maxn];
            // mp[u]: adjacency list of node u; each tree edge is stored in both directions.
            vector<int>mp[maxn];

            // Filled by dfs1: dp[u] = depth (root has depth 1), fa[u] = parent (the root is
            // its own parent), son[u] = heavy child (stays 0 for leaves), sz[u] = subtree size.
            int dp[maxn],fa[maxn],son[maxn],sz[maxn];

            void dfs1(int u,int f, int deep){
                int mx = 0;
                dp[u] = deep;
                sz[u] = 1;
                fa[u] = f;

                for(int i=0; i<mp[u].size(); i++){
                    int v = mp[u][i];
                    if(f == v) continue;
                    dfs1(v, u, deep + 1);
                    sz[u] += sz[v];
                    if(sz[v] > mx) son[u] = v, mx = sz[v];
                }
            }
            // b[k]: value of the node placed at DFS position k (becomes segment-tree leaf k).
            int b[maxn];
            // id[u]: DFS position of u (heavy child visited first); top[u]: head node of u's chain.
            int id[maxn],top[maxn];
            // Last DFS position handed out by dfs2.
            int cnt = 0;
            void dfs2(int u,int f,int topf){
                // Second pass: hand out DFS positions, visiting the heavy child first so
                // that every chain occupies consecutive positions, and record chain heads.
                top[u] = topf;
                id[u] = ++cnt;
                b[cnt] = a[u];
                // The heavy child continues u's chain.
                if (son[u] != 0) dfs2(son[u], u, topf);
                // Every other (light) child starts a new chain headed by itself.
                for (int child : mp[u]) {
                    if (child != f && child != son[u]) dfs2(child, u, child);
                }
            }
            // Segment tree over DFS positions 1..n in heap layout (children of rt are rt<<1 and
            // rt<<1|1, hence 4x capacity): sum[] = range sum mod P, lazy[] = pending per-element add.
            int sum[maxn<<2],lazy[maxn<<2];
            void build(int l,int r,int rt){
                // Build node rt covering DFS positions [l, r] from b[], storing sums mod P.
                if(l == r){
                    // Reduce the leaf modulo P: an input value may be >= P, and a
                    // full-range hit on a single leaf returns sum[rt] unreduced otherwise.
                    // ((x % P) + P) % P also keeps negative inputs in [0, P).
                    sum[rt] = (int)(((long long)b[l] % P + P) % P);
                    return;
                }
                int mid = (l + r) >> 1;
                build(l, mid, rt<<1);
                build(mid+1,r,rt<<1|1);
                // Both children lie in [0, P) with P an int, so their sum can exceed
                // INT_MAX; add in 64-bit before reducing.
                sum[rt] = (int)(((long long)sum[rt<<1] + sum[rt<<1|1]) % P);
            }
            void pushdown(int l,int r,int rt){
                // Push node rt's pending per-element addition down to both children
                // (left = [l, mid], right = [mid+1, r]) and clear it.
                // All arithmetic is 64-bit: a tag can be close to P (an int) and a
                // segment can span many elements, so the int products overflowed.
                int mid = (l + r) >> 1;
                const long long tag = lazy[rt];
                lazy[rt<<1]   = (int)((lazy[rt<<1] + tag) % P);
                lazy[rt<<1|1] = (int)((lazy[rt<<1|1] + tag) % P);
                sum[rt<<1]    = (int)((sum[rt<<1] + (long long)(mid - l + 1) * tag) % P);
                sum[rt<<1|1]  = (int)((sum[rt<<1|1] + (long long)(r - mid) * tag) % P);
                lazy[rt] = 0;
            }
            void update(int L,int R, int val, int l,int r,int rt){
                // Add val to every DFS position in [L, R], starting at node rt = [l, r].
                if(l >= L && r<=R){
                    // Normalise val into [0, P) and widen before multiplying by the
                    // segment length; the original int product (r-l+1)*val overflowed.
                    const long long add = ((long long)val % P + P) % P;
                    lazy[rt] = (int)((lazy[rt] + add) % P);
                    sum[rt] = (int)((sum[rt] + (long long)(r - l + 1) * add) % P);
                    return;
                }
                int mid = (l + r)>>1;
                if(lazy[rt]) pushdown(l,r,rt);
                if(mid >= L) update(L,R,val, l,mid,rt<<1);
                if(mid < R) update(L,R,val, mid+1,r,rt<<1|1);
                // Two residues below P can exceed INT_MAX: add in 64-bit.
                sum[rt] = (int)(((long long)sum[rt<<1] + sum[rt<<1|1]) % P);
            }
            int query(int L,int R,int l, int r,int rt){
                // Return the sum (mod P) of DFS positions [L, R], starting at node rt = [l, r].
                if(l>=L && r<=R){
                    return sum[rt];
                }
                // Accumulate in 64-bit: two partial results below P can exceed INT_MAX.
                long long res = 0;
                int mid = (l + r)>>1;
                if(lazy[rt]) pushdown(l, r, rt);
                if(mid >= L) res += query(L,R,l,mid,rt<<1);
                if(mid < R) res += query(L,R,mid+1,r,rt<<1|1);
                // Refresh this node from its children (64-bit to avoid overflow).
                sum[rt] = (int)(((long long)sum[rt<<1] + sum[rt<<1|1]) % P);
                return (int)(res % P);
            }
            void add1(int x,int y,int z){
                // Add z to every node on the tree path x..y.
                // While x and y sit on different chains, lift the endpoint whose chain
                // head is deeper: update its whole chain segment, then jump to the
                // parent of that chain head.
                for (; top[x] != top[y]; x = fa[top[x]]) {
                    if (dp[top[y]] > dp[top[x]]) swap(x, y);
                    update(id[top[x]], id[x], z, 1, n, 1);
                }
                // Both endpoints now share a chain; the shallower one has the smaller id.
                if (dp[y] < dp[x]) swap(x, y);
                update(id[x], id[y], z, 1, n, 1);
            }
            int getsum1(int x,int y) {
                // Return the sum (mod P) of the values on the tree path x..y.
                // Accumulate in 64-bit: res and a query result are each below P (an
                // int), so their int sum could overflow before the % P.
                long long res = 0;
                while(top[x] != top[y]){
                    // Lift the endpoint whose chain head is deeper.
                    if(dp[top[x]] < dp[top[y]]) swap(x, y);
                    res = (res + query(id[top[x]], id[x], 1, n, 1)) % P;
                    x = fa[top[x]];
                }
                // Same chain now: the shallower endpoint has the smaller DFS id.
                if(dp[x] > dp[y]) swap(x, y);
                res = (res + query(id[x], id[y], 1, n, 1)) % P;
                return (int)res;
            }
            void add2(int x,int z){
                // Add z to every node in the subtree of x; that subtree occupies the
                // consecutive DFS positions id[x] .. id[x] + sz[x] - 1.
                const int first = id[x];
                const int last = first + sz[x] - 1;
                update(first, last, z, 1, n, 1);
            }
            int getsum2(int x){
                // Sum of the subtree of x: the DFS range of length sz[x] starting at id[x].
                const int first = id[x];
                const int last = first + sz[x] - 1;
                return query(first, last, 1, n, 1);
            }
int main(){
            scanf("%d%d%d%d", &n, &m, &R, &P);
            rep(i, 1, n) scanf("%d", &a[i]);
            rep(i, 1, n-1){
                int u,v;
                scanf("%d%d", &u, &v);
                mp[u].pb(v); mp[v].pb(u);
            }
            dfs1(R, R, 1);
            dfs2(R, R, R);
            build(1, n, 1);

            while(m --){
                int op; scanf("%d", &op);
                if(op == 1) {
                    int x, y, z;
                    scanf("%d%d%d", &x, &y, &z);
                    add1(x, y, z);
                }
                else if(op == 2) {
                    int x, y;
                    scanf("%d%d", &x, &y);
                    printf("%d\n", getsum1(x, y));
                }
                else if(op == 3) {
                    int x,z;
                    scanf("%d%d", &x, &z);
                    add2(x, z);
                }
                else if(op == 4){
                    int x;
                    scanf("%d", &x);
                    printf("%d\n", getsum2(x));
                }
            }
            return 0;
}
View Code

 

P3384 【模板】树链剖分

原文:https://www.cnblogs.com/ckxkexing/p/10404100.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!