我们知道,树上两个点的LCA要么是当前根节点,要么不是。。所以两个点间的最短路径要么经过当前根节点,要么在一棵当前根节点的子树中。。
考虑点分治,于是在原来同一子树中的两个点必然在一次分治中变为路径经过当前根节点的两个点。
点分治标准开头(雾
对于路径经过当前根节点的点。从当前根到点\(i\)的路径上经过的拥挤点数为\(num_i\),路径长度为\(d_i\)。求这两个值简单\(dfs\)即可。
合并我们采用启发式合并以保证复杂度。首先,我们按当前根子树的最大深度升序将子树排列。
当我们处理到当前根的第\(T\)个子树时,记录:以当前根为一端点,前\((T-1)\)棵当前根的子树中,经过\(s(0\le s\le K)\)个拥挤点的最长路径长度为\(maxv_s\)。
于是,当我们处理到第\(T\)棵子树的节点\(u\)时:
\[ans=max(ans,(maxv_s+d_u)[num_u+s\le K])\]
合并时也要处理\(maxv\)。将所有节点按\(num\)倒序处理,然后按用\(maxv_{i-1}\)更新\(maxv_i\),以找到上文中的\(maxv_s\)。然后用子树\(T\)来更新\(maxv\)。求出子树\(T\)中的每组相等的\(num_i\)中最大的\(d_i\),记为\(dist_{num_i}\),然后用\(dist\)更新\(maxv\)。
最后,显然,用树的重心作为树根为最优。
#include <bits/stdc++.h>
#define next ___________________________________________________________________
using namespace std;
typedef pair<int,int> pii;
const int MAXN = 400005;
int n, k, m, flag[MAXN], head[MAXN], to[MAXN], next[MAXN], w[MAXN], tot = 1, dist[MAXN], tmp, f[MAXN], root, size[MAXN], sum, vis[MAXN], ans, d[MAXN], num[MAXN], maxn, maxv[MAXN];
vector<pii>e;
void add(int x,int y,int z){w[tot] = z,to[tot] = y,next[tot] = head[x],head[x] = tot++;}
void getroot(int now,int fa){
size[now] = 1,f[now] = 0;
for (int i = head[now]; i; i = next[i]){
int v = to[i];
if (v == fa || vis[v]) continue;
getroot(v,now),
size[now] += size[v],
f[now] = max(f[now],size[v]);
}
f[now] = max(f[now],sum - size[now]);
if (f[now] < f[root]) root = now;
}
void getdep(int now,int fa){
maxn = max(maxn,num[now]);
for (int i = head[now]; i; i = next[i]){
int v = to[i];
if (v == fa || vis[v]) continue;
d[v] = d[now] + w[i],num[v] = num[now] + flag[v],
getdep(v,now);
}
}
void getmax(int now,int fa){
dist[num[now]] = max(dist[num[now]],d[now]);
for (int i = head[now]; i; i = next[i]){
int v = to[i];
if (v == fa || vis[v]) continue;
getmax(v,now);
}
}
void solve(int now){
vis[now] = 1,e.clear();
if (flag[now]) k--;
for (int i = head[now]; i; i = next[i]){
int v = to[i];
if (vis[v]) continue;
num[v] = flag[v],d[v] = w[i],maxn = 0,
getdep(v,now),
e.push_back(make_pair(maxn,v));
}
sort(e.begin(),e.end());
for (int i = 0; i < e.size(); i++){
getmax(e[i].second,now);
int res = 0;
if (i != 0)
for (int j = e[i].first; j >= 0; j--){
while (res + 1 + j <= k && res + 1 <= e[i - 1].first)
res++,
maxv[res] = max(maxv[res],maxv[res - 1]);
if (res + j <= k) ans = max(ans,maxv[res] + dist[j]);
}
if (i != e.size() - 1)
for (int j = 0; j <= e[i].first; j++)
maxv[j] = max(maxv[j],dist[j]),
dist[j] = 0;
else
for (int j = 0; j <= e[i].first; j++) maxv[j] = dist[j] = 0;
}
if (flag[now]) k++;
for (int i = head[now]; i; i = next[i]){
int v = to[i];
if (vis[v]) continue;
root = 0,f[0] = sum = size[v];
getroot(v,0),solve(root);
}
}
int main(){
scanf("%d%d%d",&n,&k,&m);
for (int i = 1; i <= m; i++){
int x;
scanf("%d",&x),
flag[x] = 1;
}
for (int i = 1,u , v, w; i < n; i++){
scanf("%d%d%d",&u,&v,&w);
if (i == 1) tmp = w;
add(u,v,w);
add(v,u,w);
}
if (n == 2 && k >= m) return printf("%d\n", tmp),0;
f[0] = sum = n,
getroot(1,0),solve(root),
printf("%d\n",ans);
return 0;
}
原文:https://www.cnblogs.com/QAQAQ/p/10989313.html