我们知道,树上两个点的LCA要么是当前根节点,要么不是。。所以两个点间的最短路径要么经过当前根节点,要么在一棵当前根节点的子树中。。
考虑点分治,于是在原来同一子树中的两个点必然在一次分治中变为路径经过当前根节点的两个点。
处理路径经过当前根节点的两个点的情况。对于当前树,每个节点(根节点除外)记录深度\(dep_i\)(根节点深度为\(0\))和除当前根节点外的最远祖先\(fa_i\)。。
于是有:
\[\sum [fa_i\ne fa_j \land dep_i+dep_j \le K]\]
显然,式子等于:
\[\sum [dep_i+dep_j\le K]-\sum[fa_i=fa_j\land dep_i+dep_j\le K]\]
于是可以这样解决:
在当前树中,将\(dep\)排序,用\(l\)表示左指针,\(r\)表示右指针,\(l\)从左向右遍历。如果\(dep_l+dep_r\le k\),则点对\((l,t)(i<t\le r)\)都符合题意,于是将\(r-l\)加入答案中,并且\(l\)++;否则\(r\)--。
需要注意的是链的情况。。时间复杂度会退化成\(O(N^2)\)。。我们可以将中心作为根,以保证复杂度为\(O(Nlog^2N)\)
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 10005;
struct edge {
int v, l;
edge(int v_, int l_) :v(v_), l(l_) {};
};
vector<edge> g[MAXN];
vector<int> dep;
int n, k, dist[MAXN], vis[MAXN], f[MAXN], root, ans, s[MAXN], tot;
//求当前树的重心
void getroot(int now, int fa) {
/*
fa: now在当前树中的父亲
*/
int u;
s[now] = 1, f[now] = 0;
for (int i = 0; i < g[now].size(); i++) {
u = g[now][i].v;
if (u != fa && !vis[u])
getroot(u, now),
s[now] += s[u],
f[now] = max(f[now], s[u]);
}
f[now] = max(f[now], tot - s[now]);
if (f[now] < f[root]) root = now;
}
//求当前(子)树中点的深度
void getdep(int now, int fa) {
/*
fa: now在当前树中的父亲
dist[now]: now在当前子树中的深度
*/
int u;
dep.push_back(dist[now]),
s[now] = 1;
for (int i = 0; i < g[now].size(); i++) {
u = g[now][i].v;
if (u != fa && !vis[u])
dist[u] = dist[now] + g[now][i].l,
getdep(u, now),
s[now] += s[u];
}
}
//计算当前(子)树中dep[i]+dep[j]<=k的点对个数
int calc(int now, int len) {
/*
len: now在当前树中的深度
*/
dep.clear(),dist[now] = len;
getdep(now, 0),
sort(dep.begin(), dep.end());
int cnt = 0, l = 0, r = dep.size() - 1;
while (l < r)
if (dep[l] + dep[r] <= k) cnt += r - l, l++;
else r--;
return cnt;
}
//计算当前树中满足题目要求的点对个数
void work(int now) {
int u;
ans += calc(now, 0),
vis[now] = true;
for (int i = 0; i < g[now].size(); i++) {
u = g[now][i].v;
if (!vis[u])
ans -= calc(u, g[now][i].l),
f[0] = tot = s[u],
root = 0,
getroot(u, 0),
work(root);
}
}
int main() {
while (~scanf("%d%d",&n,&k)) {
if (!n && !k) break;
for (int i = 0; i <= n; i++) g[i].clear();
memset(vis, 0, sizeof(int)*(n+1));
int u, v, l;
for (int i = 1; i < n; i++)
scanf("%d%d%d",&u,&v,&l),
g[u].push_back(edge(v, l)),g[v].push_back(edge(u, l));
f[0] = n,root = 0,tot = n;
getroot(1, 0),
ans = 0,
work(root),
printf("%d\n",ans);
}
}
原文:https://www.cnblogs.com/QAQAQ/p/10989248.html