题意还是很好懂的,问题也很容易转化为求每个点能到的点的个数之和,最后除以$2$即可
考虑任意一点i能到的点的个数。这些点所组成的点集等于所有包含节点$i$的链的点集的并集。
需要哪些信息才能维护出这个点集?
由于每条链都包含了节点$i$,因此这个点集会组成一个连通块(暂且这么叫吧),这个连通块显然可以通过确定其边界上的点确定下来,那么问题变为如何维护出连通块的大小
这就是在考察$dfs$序的应用了,先以$dfs$序建立线段树。借助于虚树的思想,我们维护出每一个区间内选择了点到$1$号节点的根缀的并集大小,记为$siz$。假设现在我们已知线段树中一个区间的左区间信息与右区间信息,如何得出这个区间的信息?直接把左右区间的$siz$相加显然不对,有重复的部分,重复部分的大小是多少?结合$dfs$序,可以发现重合部分就是左区间选择的点中$dfs$序最大的点$mx$与右区间选择的点中$dfs$序最小的点$mn$的$lca$的深度,即$dep[lca(mx, mn)]$,左右区间的$siz$相加再减去这个就是这个就是这个区间的$siz$,而最后的连通块的大小就是整个序列的$siz$减去$dep[lca(mx, mn)]$,$mx$是整个点集中$dfs$序最大的点,$mn$是$dfs$序最小的点。
为了快速求$lca$可以预处理欧拉序与$RMQ$,在$O(1)$内查询$lca$,信息更新就是$log$的时间复杂度
我们需要一个连通块的边界,这个边界显然是由所有链的边界组成的,因此在树上差分就可维护出边界
这是对于一个点的情况。对于所有的点,因为使用了树上差分,父亲节点需要从儿子节点继承信息,于是就需要线段树合并
总时间复杂度是$O(NlogN)$的
代码:
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn = 100005; inline int read() { int ret, f=1; char c; while((c=getchar())&&(c<‘0‘||c>‘9‘))if(c==‘-‘)f=-1; ret = c-‘0‘; while((c=getchar())&&(c>=‘0‘&&c<=‘9‘))ret=(ret<<3)+(ret<<1)+c-‘0‘; return ret*f; } int n, m, root[maxn]; ll ans; int head[maxn], tot; struct edge{ int nxt, to; }e[maxn<<1]; void Addedge(int x, int y) { e[++tot] = (edge){head[x], y}; head[x] = tot; } int f[maxn]; int cur, lg[maxn<<1], st[20][maxn<<1], arr[maxn<<1], dfn[maxn], dep[maxn]; void dfs1(int x, int fa) { arr[++cur] = x; dfn[x] = cur; for(int i = head[x]; i; i = e[i].nxt) { int id = e[i].to; if(id == fa) continue; dep[id] = dep[x] + 1; f[id] = x; dfs1(id, x); arr[++cur] = x; } } void RMQ() { lg[0] = -1; for(int i = 1; i <= cur; ++i) { lg[i] = lg[i>>1] + 1; st[0][i] = arr[i]; } for(int j = 1; j <= lg[cur]; ++j) for(int i = 1; i + (1 << j) - 1 <= cur; ++i) { int u = st[j-1][i], v = st[j-1][i + (1 << (j - 1))]; st[j][i] = (dep[u] < dep[v]? u: v); } } int Get_lca(int x, int y) { if(!x || !y) return 0; x = dfn[x]; y = dfn[y]; if(x > y) swap(x, y); int u = st[lg[y-x+1]][x], v = st[lg[y-x+1]][y-(1 << lg[y-x+1])+1]; return dep[u] < dep[v]? u: v; } int ndnum, tim; struct seg_tree{ int ls, rs, siz, mx, mn, num; }tr[maxn*64]; void update(int x) { int lson = tr[x].ls, rson = tr[x].rs; tr[x].siz = tr[lson].siz + tr[rson].siz - dep[Get_lca(tr[lson].mx, tr[rson].mn)]; tr[x].mx = (tr[rson].mx? tr[rson].mx: tr[lson].mx); tr[x].mn = (tr[lson].mn? tr[lson].mn: tr[rson].mn); } void Modify(int x, int L, int R, int p, int w) { if(L == R) { tr[x].num += w; tr[x].siz = (tr[x].num? dep[p]: 0); tr[x].mx = tr[x].mn = (tr[x].num? p: 0); return ; } int mid = (L + R) >> 1; if(dfn[p] <= mid) { if(!tr[x].ls) tr[x].ls = ++ ndnum; Modify(tr[x].ls, L, mid, p, w); } else { if(!tr[x].rs) tr[x].rs = ++ ndnum; Modify(tr[x].rs, mid + 1, R, p, w); } update(x); } int Merg(int x, int y, int L, int R) { if(!x || !y) return x + y; if(L == R) { tr[x].num += tr[y].num; tr[x].siz = (tr[x].num? dep[arr[L]]: 0); tr[x].mx = tr[x].mn = (tr[x].num? arr[L]: 0); return x; } int mid = (L + R) >> 1; tr[x].ls = Merg(tr[x].ls, tr[y].ls, L, mid); tr[x].rs = Merg(tr[x].rs, tr[y].rs, mid + 1, R); update(x); return x; } void dfs2(int x) { for(int i = head[x]; i; i = e[i].nxt) { int id = e[i].to; if(id == f[x]) continue; dfs2(id); root[x] = Merg(root[x], root[id], 1, cur); } ans += (ll)tr[root[x]].siz - dep[Get_lca(tr[root[x]].mx, tr[root[x]].mn)]; } int main() { n = read(); m = read(); for(int i = 1; i < n; ++i) { int u = read(), v = read(); Addedge(u, v); Addedge(v, u); } dfs1(1, 0); RMQ(); for(int i = 1; i <= n; ++i) root[i] = ++ ndnum; while(m --) { int u = read(), v = read(), lca = Get_lca(u, v); Modify(root[u], 1, cur, u, 1); Modify(root[u], 1, cur, v, 1); Modify(root[v], 1, cur, u, 1); Modify(root[v], 1, cur, v, 1); Modify(root[lca], 1, cur, u, -1); Modify(root[lca], 1, cur, v, -1); if(lca != 1) { Modify(root[f[lca]], 1, cur, u, -1); Modify(root[f[lca]], 1, cur, v, -1); } } dfs2(1); printf("%lld\n", ans >> 1); return 0; }
原文:https://www.cnblogs.com/Joker-Yza/p/11664505.html