【题意】
给定一个有边权的树,m次询问,每次给定k个关键点,求所有关键点互不相连的最小代价
【分析】
和CF613D Kingdom and its Cities十分相似,建立好虚树后,设计树形dp
f[i]表示i子树内的关键点都互相断开的答案,那么对于每个节点u:
如果u是关键点,可以考虑断开u的所有存在关键点的儿子
如果u不是关键点,则有两种选择,可以断开u->son的点,也可以断开son子树内的关键点,即f[u]=f[u]+min(e[i].v,f[son])
注意清零的方式即可
【代码】
#include<bits/stdc++.h> using namespace std; #define int long long #define inf 123456789000000000 #define mod 1000000007 const int maxn=3e5+5; int read() { int x = 0, f = 1; char c = getchar(); while(c < ‘0‘ || c > ‘9‘) { if(c == ‘-‘) f = -1; c = getchar();} while(c >= ‘0‘ && c <= ‘9‘) x = x * 10 + c - 48, c = getchar(); return x * f; } struct edge { int v, w, next; }e[maxn << 1]; int n, m, head[maxn], cnt, is[maxn], mi[maxn], dfn[maxn], col, t; int size[maxn], fa[maxn], top[maxn], son[maxn], dep[maxn], s[maxn]; vector<int>v[maxn]; void add(int u, int v, int w) { e[++ cnt] = (edge){v, w, head[u]}; head[u] = cnt; } bool cmp(int a, int b){return dfn[a] < dfn[b];} void dfs1(int u, int fr) { size[u] = 1, fa[u] = fr, dep[u] = dep[fr] + 1; for(int i = head[u]; i; i = e[i].next) { int v = e[i].v; if(v == fr) continue; mi[v] = min(mi[u], e[i].w); dfs1(v, u), size[u] += size[v]; if(size[son[u]] < size[v]) son[u] = v; } } void dfs2(int u, int fr) { top[u] = fr, dfn[u] = ++ col; if(!son[u]) return; dfs2(son[u], fr); for(int i = head[u]; i; i = e[i].next) { int v = e[i].v; if(v != fa[u] && v != son[u]) dfs2(v, v); } } int lca(int a, int b) { while(top[a] != top[b]) dep[top[a]] > dep[top[b]] ? a = fa[top[a]] : b = fa[top[b]]; return dep[a] < dep[b] ? a : b; } void push(int x) { if(t == 1) {s[++ t] = x;return;} int l = lca(x, s[t]); if(l == s[t]) return; while(t > 1 && dfn[s[t - 1]] >= dfn[l]) v[s[t - 1]].push_back(s[t]), --t; if(s[t] != l) v[l].push_back(s[t]), s[t] = l; s[++ t] = x; } int dp(int u) { if(v[u].size() == 0) return mi[u]; int temp = 0; for(int i = 0; i < v[u].size(); ++ i) temp += dp(v[u][i]); v[u].clear(); return min(mi[u], temp); } signed main() { n = read(); for(int i = 1; i < n; ++ i) { int u = read(), v = read(), w = read(); add(u, v, w), add(v, u, w); } mi[1] = inf, dfs1(1, 0), dfs2(1, 1); int T = read(); while(T --) { m = read(); for(int i = 1; i <= m; ++ i) is[i] = read(); sort(is + 1, is + m + 1, cmp); s[t = 1] = 1; for(int i = 1; i <= m; ++ i) push(is[i]); while(t > 0) v[s[t - 1]].push_back(s[t]), --t; printf("%lld\n", dp(1)); } return 0; }
原文:https://www.cnblogs.com/andylnx/p/14792281.html