本文涉及算法:线段树、dfs。树链剖分是码量十分巨大的数据结构,但十分有用。
一道来源不明的题:
给一棵树,每个结点都有一个点权\(a_i\),求从\(x\)到\(y\)的简单路径上的点权和。
有\(10^5\)次询问。
强行枚举从\(x\)到\(y\)进行求和。
时间复杂度\(O(n^2)\)
那有没有更优秀的算法呢?
对于\(ans_i = ans_{fa_i} + a_i\)
查询时输出\(ans_x + ans_y - ans_{lca(x,y)}\)
时间复杂度\(O(nlog_n)\)
(lca(x,y)为x和y的最近公共祖先)
树上前缀和的复杂度已经十分优秀了,完全可以解决静态树上求和。
又一道来源不明的题:
给一棵树,每个结点都有一个点权\(a_i\),可以进行从\(x\)到\(y\)的简单路径上的修改,询问从\(x\)到\(y\)的简单路径上的权和。
同上。
时间复杂度\(O(n^2)\)
对于每一次点权修改重新修改前缀和。
查询方法同上。
当有了修改操作,树上前缀和的时间复杂度貌似变为:
\(O(n^2)!\)
这个方法几乎等同于暴力。
通过仔细打表发现,每次修改树上前缀和的时间太高了,有没有不浪费的算法又资瓷修改操作的呢?
有!我会线段树!
但是我们发现,线段树不资瓷的是树上修改,那我们有没有方法将树劈成一些链来处理呢?
有,树链剖分!
为了行文方便,我们先在这定义一下一些代名词:
重儿子\(son_i\):他的父节点中子树结点数量最多的儿子。
轻儿子 : 他的父节点中不是重儿子的其他儿子。
重链 : 由重儿子连接而成的链。
轻链 :轻儿子组成连接而成的链。
树链剖分有一个主题思想就是将一棵树变成一堆链来线段树。
对于查询操作
我们先找到这棵树的重儿子,将重儿子连成重链,剩下的全部连成轻链,我们记下所有链的顶端,每次求和完成后跳到他的顶端继续游戏。
(注:轻链的顶端是他自己)
其实这个过程就是倍增LCA的过程,每次跳的是深度深的结点,因为防止跳过头。
我们可以用线段树来维护重链的值。
int fa[200005],son[200005],head[200005],size[200005];
int d[200005],rk[200005];
int top[200005],id[200005],w[200005];
struct E{
int next,to;
} edge[200005];
struct T{
int l,r,w,f;
} a[500000];
\(fa_i\) 第i个结点的父亲
\(son_i\) 第i个结点的重儿子
\(head\)、\(edge\)为前向星用的数组
\(d_i\) 为第i个结点的深度
\(size_i\) 为第i个结点子树结点的数量
\(top_i\) 为链的顶端
\(id_i\) 为第i个结点在线段树中对应的结点标号
\(rk_i\) 为线段树第i个结点中对应现实结点的标号
\(w_i\)为读入的点权
\(a\)为线段树数组
void dfs1(int x)
{
size[x] = 1;//一开始x的子树结点只有自己
d[x] = d[fa[x]] + 1;//深度等于他父亲的深度+1
for (int v,i = head[x]; i; i = edge[i].next)
if ((v = edge[i].to) != fa[x]) //找到的是他的儿子
{
fa[v] = x;//下一个结点的父亲是自己
dfs1(v);
size[x] += size[v];//合并子树
if (size[son[x]] < size[v])
son[x] = v; //如果子树节点数比max大,设为重儿子
}
}
void dfs2(int x,int tp)
{
top[x] = tp;//定义顶端
id[x] = ++sum;
rk[sum] = x;
if (son[x])
dfs2(son[x],tp);//重儿子优先成重链
for (int v,i = head[x]; i; i = edge[i].next)
if ((v = edge[i].to) != fa[x] && v != son[x]) //把轻儿子割出来
dfs2(v,v);//轻儿子的top是自己
}
void build(int x,int l,int r)
{
a[x].l = l;
a[x].r = r;
if (l == r)
{
a[x].w = w[rk[l]];
if (a[x].w > Mod)
a[x].w %= Mod;
return;
}
int mid = (l + r) / 2;
build(x * 2,l,mid);
build(x * 2 + 1,mid + 1,r);
a[x].w = (a[x * 2].w + a[x * 2 + 1].w) % Mod;
}
void down(int x)
{
a[x * 2].f += a[x].f;
a[x * 2 + 1].f += a[x].f;
a[x * 2].w += a[x].f * (a[x * 2].r - a[x * 2].l + 1) % Mod;
a[x * 2 + 1].w += a[x].f * (a[x * 2 + 1].r - a[x * 2 + 1].l + 1) % Mod;
a[x].f = 0;
}
void change_interval(int k)
{
if (a[k].l >= as && a[k].r <= bs)
{
a[k].w += g * (a[k].r - a[k].l + 1) % Mod;
a[k].f += g;
return;
}
if (a[k].f) down(k);
int mid = (a[k].l + a[k].r) / 2;
if (as <= mid)
change_interval(k * 2);
if (mid < bs)
change_interval(k * 2 + 1);
a[k].w = (a[k * 2].w + a[k * 2 + 1].w) % Mod;
}
void ask_interval(int k)
{
if (a[k].l >= as && a[k].r <= bs)
{
ans = (ans + a[k].w) % Mod;
return;
}
if (a[k].f) down(k);
int mid = (a[k].l + a[k].r) / 2;
if (as <= mid)
ask_interval(k * 2);
if (mid < bs)
ask_interval(k * 2 + 1);
}
//不解释。
int Si(int x,int y)//输入从$x$到$y$的简单路径上的值
{
ans = 0;
while (top[x] != top[y])//没有在一起
{
if (d[top[x]] < d[top[y]])//我们保证x深度深
swap(x,y);
as = id[top[x]];
bs = id[x];
ask_interval(1);//左区间为链顶,右区间为链尾
x = fa[top[x]];//继续向上
}
if (id[x] > id[y])
swap(x,y);
as = id[x];
bs = id[y];//同理
ask_interval(1);
return ans % Mod;
}//链询问
int ts(int x,int y)
{
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]])
swap(x,y);
as = id[top[x]];
bs = id[x];
change_interval(1);
x = fa[top[x]];
}
if (id[x] > id[y])
swap(x,y);
as = id[x];
bs = id[y];
//这里同查询
}//链修改
只要将以上代码合并,
我们可以在\(O(nlog_n)\)的复杂度内A掉在引子2出现的例题了。
总代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,m,fa[200005],son[200005],edgenum,head[200005],size[200005];
int d[200005],rk[200005],g,ans,as,bs,r,Mod,sum;
int top[200005],id[200005],w[200005];
struct E{
int next,to;
} edge[200005];
struct T{
int l,r,w,f;
} a[500000];
void ins(int x,int y)
{
edge[++edgenum].to = y;
edge[edgenum].next = head[x];
head[x] = edgenum;
}
void dfs1(int x)
{
size[x] = 1;
d[x] = d[fa[x]] + 1;
for (int v,i = head[x]; i; i = edge[i].next)
if ((v = edge[i].to) != fa[x])
{
fa[v] = x;
dfs1(v);
size[x] += size[v];
if (size[son[x]] < size[v])
son[x] = v;
}
}
void dfs2(int x,int tp)
{
top[x] = tp;
id[x] = ++sum;
rk[sum] = x;
if (son[x])
dfs2(son[x],tp);
for (int v,i = head[x]; i; i = edge[i].next)
if ((v = edge[i].to) != fa[x] && v != son[x])
dfs2(v,v);
}
void build(int x,int l,int r)
{
a[x].l = l;
a[x].r = r;
if (l == r)
{
a[x].w = w[rk[l]];
if (a[x].w > Mod)
a[x].w %= Mod;
return;
}
int mid = (l + r) / 2;
build(x * 2,l,mid);
build(x * 2 + 1,mid + 1,r);
a[x].w = (a[x * 2].w + a[x * 2 + 1].w) % Mod;
}
void down(int x)
{
a[x * 2].f += a[x].f;
a[x * 2 + 1].f += a[x].f;
a[x * 2].w += a[x].f * (a[x * 2].r - a[x * 2].l + 1) % Mod;
a[x * 2 + 1].w += a[x].f * (a[x * 2 + 1].r - a[x * 2 + 1].l + 1) % Mod;
a[x].f = 0;
}
void change_interval(int k)
{
if (a[k].l >= as && a[k].r <= bs)
{
//cout<<g<<endl;
a[k].w += g * (a[k].r - a[k].l + 1) % Mod;
a[k].f += g;
return;
}
if (a[k].f) down(k);
int mid = (a[k].l + a[k].r) / 2;
if (as <= mid)
change_interval(k * 2);
if (mid < bs)
change_interval(k * 2 + 1);
a[k].w = (a[k * 2].w + a[k * 2 + 1].w) % Mod;
}
void ask_interval(int k)
{
if (a[k].l >= as && a[k].r <= bs)
{
ans = (ans + a[k].w) % Mod;
//cout<<"Test:"<<ans<<endl;
return;
}
if (a[k].f) down(k);
int mid = (a[k].l + a[k].r) / 2;
if (as <= mid)
ask_interval(k * 2);
if (mid < bs)
ask_interval(k * 2 + 1);
}
int Si(int x,int y)
{
ans = 0;
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]])
swap(x,y);
as = id[top[x]];
bs = id[x];
ask_interval(1);
//cout<<ans<<endl;
//ans = ans % Mod;
x = fa[top[x]];
}
if (id[x] > id[y])
swap(x,y);
as = id[x];
bs = id[y];
ask_interval(1);
//cout<<ans<<endl;
return ans % Mod;
}
int ts(int x,int y)
{
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]])
swap(x,y);
as = id[top[x]];
bs = id[x];
change_interval(1);
x = fa[top[x]];
}
if (id[x] > id[y])
swap(x,y);
as = id[x];
bs = id[y];
change_interval(1);
//cout<<g<<endl;
}
signed main(){
scanf("%lld%lld",&n,&m);
Mod = 1000000008;
for (int i = 1; i <= n; i++)
scanf("%lld",&w[i]);
for (int i = 1; i < n; i++)
{
int x,y;
scanf("%lld%lld",&x,&y);
ins(x,y);
ins(y,x);
}
dfs1(1);
dfs2(1,1);
build(1,1,n);
for (int i = 1; i <= m; i++)
{
int op,x,y;
scanf("%lld",&op);
if (op == 1)
{
scanf("%lld%lld%lld",&x,&y,&g);
ts(x,y);
} else
if (op == 2)
{
scanf("%lld%lld",&x,&y);
printf("%lld\n",Si(x,y) % Mod);
}
}
return 0;
}
对于子树操作其实很好求:
将左区间定义为x,右区间定义为x+size[x]-1就ok了
因为这一段在线段树上已经是一段完整的区间了。
上代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,m,fa[200005],son[200005],edgenum,head[200005],size[200005];
int d[200005],rk[200005],g,ans,as,bs,r,Mod,sum;
int top[200005],id[200005],w[200005];
struct E{
int next,to;
} edge[200005];
struct T{
int l,r,w,f;
} a[500000];
void ins(int x,int y)
{
edge[++edgenum].to = y;
edge[edgenum].next = head[x];
head[x] = edgenum;
}
void dfs1(int x)
{
size[x] = 1;
d[x] = d[fa[x]] + 1;
for (int v,i = head[x]; i; i = edge[i].next)
if ((v = edge[i].to) != fa[x])
{
fa[v] = x;
dfs1(v);
size[x] += size[v];
if (size[son[x]] < size[v])
son[x] = v;
}
}
void dfs2(int x,int tp)
{
top[x] = tp;
id[x] = ++sum;
rk[sum] = x;
if (son[x])
dfs2(son[x],tp);
for (int v,i = head[x]; i; i = edge[i].next)
if ((v = edge[i].to) != fa[x] && v != son[x])
dfs2(v,v);
}
void build(int x,int l,int r)
{
a[x].l = l;
a[x].r = r;
if (l == r)
{
a[x].w = w[rk[l]];
if (a[x].w > Mod)
a[x].w %= Mod;
return;
}
int mid = (l + r) / 2;
build(x * 2,l,mid);
build(x * 2 + 1,mid + 1,r);
a[x].w = (a[x * 2].w + a[x * 2 + 1].w) % Mod;
}
void down(int x)
{
a[x * 2].f += a[x].f;
a[x * 2 + 1].f += a[x].f;
a[x * 2].w += a[x].f * (a[x * 2].r - a[x * 2].l + 1) % Mod;
a[x * 2 + 1].w += a[x].f * (a[x * 2 + 1].r - a[x * 2 + 1].l + 1) % Mod;
a[x].f = 0;
}
void change_interval(int k)
{
if (a[k].l >= as && a[k].r <= bs)
{
a[k].w += g * (a[k].r - a[k].l + 1) % Mod;
a[k].f += g;
return;
}
if (a[k].f) down(k);
int mid = (a[k].l + a[k].r) / 2;
if (as <= mid)
change_interval(k * 2);
if (mid < bs)
change_interval(k * 2 + 1);
a[k].w = (a[k * 2].w + a[k * 2 + 1].w) % Mod;
}
void ask_interval(int k)
{
if (a[k].l >= as && a[k].r <= bs)
{
ans = (ans + a[k].w) % Mod;
return;
}
if (a[k].f) down(k);
int mid = (a[k].l + a[k].r) / 2;
if (as <= mid)
ask_interval(k * 2);
if (mid < bs)
ask_interval(k * 2 + 1);
}
int Si(int x,int y)
{
ans = 0;
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]])
swap(x,y);
as = id[top[x]];
bs = id[x];
ask_interval(1);
x = fa[top[x]];
}
if (id[x] > id[y])
swap(x,y);
as = id[x];
bs = id[y];
ask_interval(1);
return ans % Mod;
}
int ts(int x,int y)
{
while (top[x] != top[y])
{
if (d[top[x]] < d[top[y]])
swap(x,y);
as = id[top[x]];
bs = id[x];
change_interval(1);
x = fa[top[x]];
}
if (id[x] > id[y])
swap(x,y);
as = id[x];
bs = id[y];
change_interval(1);
}
signed main(){
scanf("%lld%lld%lld%lld",&n,&m,&r,&Mod);
for (int i = 1; i <= n; i++)
scanf("%lld",&w[i]);
for (int i = 1; i < n; i++)
{
int x,y;
scanf("%lld%lld",&x,&y);
ins(x,y);
ins(y,x);
}
dfs1(r);
dfs2(r,r);
build(1,1,n);
for (int i = 1; i <= m; i++)
{
int op,x,y;
scanf("%lld",&op);
if (op == 1)
{
scanf("%lld%lld%lld",&x,&y,&g);
ts(x,y);
} else
if (op == 2)
{
scanf("%lld%lld",&x,&y);
printf("%lld\n",Si(x,y) % Mod);
} else
if (op == 3)
{
scanf("%lld%lld",&x,&y);
as = id[x];
bs = id[x] + size[x] - 1;
g = y;
change_interval(1);
} else
if (op == 4)
{
ans = 0;
scanf("%lld",&x);
as = id[x];
bs = id[x] + size[x] - 1;
ask_interval(1);
printf("%lld\n",ans % Mod);
}
}
return 0;
}
链查询,子树修改。
在查询完毕后别忘了修改链。
代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
int head[100005],rk[100005],top[100005],id[100005],size[100005];
int son[100005],n,m,edgenum,d[100005],fa[100005],sum,ans;
int as,bs,insert[100005],cnt,anss,g;
struct E{
int next,to;
} edge[300005];
struct T{
int l,r,w,f;
} a[400005];
void ins(int x,int y)
{
edge[++edgenum].to = y;
edge[edgenum].next = head[x];
head[x] = edgenum;
}
void dfs1(int x)
{
size[x] = 1;
d[x] = d[fa[x]] + 1;
for (int v,i = head[x]; i; i = edge[i].next)
if ((v = edge[i].to) != fa[x])
{
fa[v] = x;
dfs1(v);
size[x] += size[v];
if (size[son[x]] < size[v] || !son[x])
son[x] = v;
}
}
void dfs2(int x,int tp)
{
top[x] = tp;
id[x] = ++sum;
rk[sum] = x;
if (son[x])
dfs2(son[x],tp);
for (int v,i = head[x]; i; i = edge[i].next)
if ((v = edge[i].to) != fa[x] && v != son[x])
dfs2(v,v);
}
void build(int x,int l,int r)
{
a[x].l = l;
a[x].r = r;
if (l == r)
{
a[x].f = -1;
return;
}
int mid = (l + r) / 2;
build(x * 2,l,mid);
build(x * 2 + 1,mid + 1,r);
}
void down(int x)
{
a[x * 2].f = a[x].f;
a[x * 2 + 1].f = a[x].f;
a[x * 2].w = a[x].f * (a[x * 2].r - a[x * 2].l + 1);
a[x * 2 + 1].w = a[x].f * (a[x * 2 + 1].r - a[x * 2 + 1].l + 1);
a[x].f = -1;
}
void change_interval(int k)
{
if (a[k].l >= as && a[k].r <= bs)
{
a[k].f = g;
a[k].w = (a[k].r - a[k].l + 1) * g;
return;
}
int mid = (a[k].l + a[k].r) / 2;
if (a[k].f != -1) down(k);
if (as <= mid)
change_interval(k * 2);
if (bs > mid)
change_interval(k * 2 + 1);
a[k].w = a[k * 2].w + a[k * 2 + 1].w;
}
void ask_interval(int k)
{
if (a[k].l >= as && a[k].r <= bs)
{
anss += a[k].w;
return;
}
int mid = (a[k].l + a[k].r) / 2;
if (a[k].f != -1) down(k);
if (as <= mid)
ask_interval(k * 2);
if (bs > mid)
ask_interval(k * 2 + 1);
}
int Si(int x)
{
int ans = 0;
int fs = top[x];
while (fs)
{
anss = 0;
as = id[fs];
bs = id[x];
ask_interval(1);
ans += id[x] - id[fs] - anss + 1;
change_interval(1);
x = fa[fs];
fs = top[x];
}
as = id[0];
bs = id[x];
anss = 0;
ask_interval(1);
ans += id[x] - id[0] - anss + 1;
change_interval(1);
return ans;
}
signed main(){
scanf("%lld",&n);
for (int i = 1; i < n; i++)
{
int x;
scanf("%lld",&x);
ins(x,i);
ins(i,x);
}
dfs1(0);
dfs2(0,0);
build(1,1,n);
scanf("%lld",&m);
for (int i = 1; i <= m; i++)
{
char s[10];
int x;
scanf("%s%lld",s,&x);
anss = 0;
cnt = 0;
if (s[0] == ‘i‘)
{
g = 1;
printf("%lld\n",Si(x));
} else
{
as = id[x];
bs = id[x] + size[x] - 1;
ask_interval(1);
g = 0;
change_interval(1);
printf("%lld\n",anss);
}
}
return 0;
}
[SDOI2011]染色 DP+树剖
原文:https://www.cnblogs.com/taoyc/p/10158612.html