[传送门]
很明显,可以转化成求每个点在两棵树中对应的子树中有多少个相同的节点,对答案的贡献就是$C(x, 2)$。关键就是怎么求这个东西。
一是,对第一棵树求出dfs序,然后dfs第二棵树,用树状数组维护节点是否遍历到。对应下标就是第一棵树的dfs序,求每个节点递归其子树前后对应子树的区间和,作个差就是对在两棵树同时出现的节点数了。
#include <bits/stdc++.h> #define ll long long using namespace std; const int N = 1e5 + 7; int n, degree[N], dfn[N], last[N], tol; vector<int> G[N]; ll ans; struct BIT { int tree[N]; inline int lowbit(int x) { return x & -x; } inline void add(int x) { for (int i = x; i <= n; i += lowbit(i)) tree[i]++; } inline int query(int x) { int ans = 0; for (int i = x; i; i -= lowbit(i)) ans += tree[i]; return ans; } } bit; void dfs1(int u) { dfn[u] = ++tol; for (auto v: G[u]) dfs1(v); last[u] = tol; } void dfs2(int u) { bit.add(dfn[u]); int now = bit.query(last[u]) - bit.query(dfn[u] - 1); for (auto v: G[u]) dfs2(v); now = bit.query(last[u]) - bit.query(dfn[u] - 1) - now; ans += 1LL * now * (now - 1) / 2; } int main() { scanf("%d", &n); for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); degree[v]++; } for (int i = 1; i <= n; i++) if (!degree[i]) dfs1(i); for (int i = 1; i <= n; i++) { G[i].clear(); degree[i] = 0; } for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); degree[v]++; } for (int i = 1; i <= n; i++) if (!degree[i]) dfs2(i); printf("%lld\n", ans); return 0; }
第二种做法是主席树,把两棵树的dfs序都求出来,一个节点$u$在第一棵树的子树的区间为$\left[x, y\right]$,在第二棵树的子树的区间为$\left[l, r\right]$,那么就相当于求$\left[l, r\right]$之间有多少个数在$\left[x,y\right]$中。然后主席树维护就行了。那部分感觉有点绕。主席树维护第二颗树dfs序每个位置对应的节点在第一棵树里面的dfs序,然后查询就是查第二棵树dfs序对应的区间中第一棵树的dfs序的区间,相当于就是把$A$用$B$的顺序插入主席树,查询就是查$A$,啊说不清楚,看代码吧。
#include <bits/stdc++.h> #define ll long long using namespace std; const int N = 1e5 + 7; struct Seg { struct Tree { int lp, rp, sum; } tree[N * 30]; int tol; void update(int &p, int q, int l, int r, int pos) { tree[p = ++tol] = tree[q]; tree[p].sum++; if (l == r) return; int mid = l + r >> 1; if (pos <= mid) update(tree[p].lp, tree[q].lp, l, mid, pos); else update(tree[p].rp, tree[q].rp, mid + 1, r, pos); } int query(int p, int q, int l, int r, int x, int y) { if (x <= l && y >= r) return tree[p].sum - tree[q].sum; int mid = l + r >> 1; int ans = 0; if (x <= mid) ans += query(tree[p].lp, tree[q].lp, l, mid, x, y); if (y > mid) ans += query(tree[p].rp, tree[q].rp, mid + 1, r, x, y); return ans; } } seg; int n, degree[N], dfn1[N], last1[N], tol, dfn2[N], last2[N], id[N]; vector<int> G[N]; ll ans; int root[N]; void dfs(int u, int dfn[], int last[]) { dfn[u] = ++tol; for (auto v: G[u]) dfs(v, dfn, last); last[u] = tol; } void build() { for (int i = 1; i <= n; i++) id[dfn2[i]] = i; for (int i = 1; i <= n; i++) seg.update(root[i], root[i - 1], 1, n, dfn1[id[i]]); } int main() { scanf("%d", &n); for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); degree[v]++; } for (int i = 1; i <= n; i++) if (!degree[i]) dfs(i, dfn1, last1); for (int i = 1; i <= n; i++) { G[i].clear(); degree[i] = 0; } tol = 0; for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); degree[v]++; } for (int i = 1; i <= n; i++) if (!degree[i]) dfs(i, dfn2, last2); build(); for (int i = 1; i <= n; i++) { int temp = seg.query(root[last2[i]], root[dfn2[i] - 1], 1, n, dfn1[i], last1[i]) - 1; ans += 1LL * temp * (temp - 1) / 2; } printf("%lld\n", ans); return 0; }
原文:https://www.cnblogs.com/Mrzdtz220/p/11664636.html