论文第二题：参考论文的思路，自己敲出来的。
使用 g++ 4.3.2 提交，AC。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
// Capacity bound on the number of vertices (sizes of the arrays below).
const int maxn = 200000 + 5;
// Sentinel magnitude used as "minus infinity"; same value as 2*1e9+7,
// written as an integer literal.
const int inf = 2000000007;

// One directed half of an undirected edge, stored in the static pool g and
// chained per source vertex through `next`.
struct Edge {
    int v;     // target vertex
    int w;     // edge weight
    int next;  // index in g of the next edge with the same source, or -1
    Edge() {}
    Edge(int to, int weight, int nxt) : v(to), w(weight), next(nxt) {}
};
// Edge pool: every undirected edge is added in both directions, hence the
// doubled capacity.
Edge g[maxn << 1];
// head[x]: index in g of the latest edge out of x (-1 when empty); E: edges used.
int head[maxn], E;
// n: vertex count; k: maximum number of black vertices allowed on a path;
// m: number of black vertex ids to read for the current test case.
int n, k, m;
// vis[x]: x has already been used as a centroid and is treated as removed.
bool vis[maxn];
// col[x]: 1 when x is black, else 0; sz/mx: subtree-size scratch for the
// centroid search.
int col[maxn], mx[maxn], sz[maxn];
// Reset per-test-case state for vertices 0..n (n must already be read):
// nothing removed, nothing black, every adjacency list empty, edge pool empty.
void init() {
    const int cnt = n + 1;
    std::fill(vis, vis + cnt, false);
    std::fill(col, col + cnt, 0);
    std::fill(head, head + cnt, -1);
    E = 0;
}
// Push a directed edge s -> t of weight w onto the front of s's edge list.
void add(int s, int t, int w) {
    const int id = E++;
    g[id] = Edge(t, w, head[s]);
    head[s] = id;
}
// Size pass over the current component, rooted so that fa is u's parent
// (-1 at the root). Sets sz[u] to the number of vertices below and including
// u, and mx[u] to the size of u's largest child subtree. Vertices with vis set
// are skipped as removed.
void dfsSz(int u, int fa) {
    int total = 1;
    int biggest = 0;
    for (int e = head[u]; e != -1; e = g[e].next) {
        const int to = g[e].v;
        if (to == fa || vis[to]) continue;
        dfsSz(to, u);
        total += sz[to];
        if (sz[to] > biggest) biggest = sz[to];
    }
    sz[u] = total;
    mx[u] = biggest;
}
// Smallest "largest remaining piece" seen so far and the vertex achieving it.
int mi, root;
// Walk the component rooted at rt (sizes from dfsSz(rt, -1)). For each u,
// finish mx[u] as the larger of its biggest child subtree and the part above
// it (sz[rt] - sz[u]). The first vertex, in visiting order, with the strictly
// smallest mx becomes root (the centroid). The caller must initialise mi.
void dfsRt(int rt, int u, int fa) {
    const int above = sz[rt] - sz[u];
    if (above > mx[u]) mx[u] = above;
    if (mx[u] < mi) {
        mi = mx[u];
        root = u;
    }
    for (int e = head[u]; e != -1; e = g[e].next) {
        const int to = g[e].v;
        if (to != fa && !vis[to]) dfsRt(rt, to, u);
    }
}
// Scratch for combining half-paths at the current centroid (filled by dfs):
//   pre[j] - best half-path weight with j black vertices over subtrees done so far
//   tmp[c] - best half-path weight with exactly c black vertices in the current subtree
//   pt     - largest black[] among the subtrees done so far
int pre[maxn];
int tmp[maxn];
int pt;
// tt[j]: prefix maximum of tmp[0..j].
int tt[maxn];
// black[v]: largest black-vertex count on a downward path starting at neighbour v.
int black[maxn];
// Accumulator written by cal(u, fa, c).
int lim;
// Raise lim to the largest count c reached on any downward path from u
// (away from fa, skipping removed vertices); c already includes u's colour.
void cal(int u, int fa, int c) {
    if (c > lim) lim = c;
    for (int e = head[u]; e != -1; e = g[e].next) {
        const int to = g[e].v;
        if (to == fa || vis[to]) continue;
        cal(to, u, c + col[to]);
    }
}
// Best path weight found so far (main resets it to 0 for every test case).
int ans;
// Sort key for edge indices: ascending black[] of each edge's target vertex.
bool cmp(int a, int b) {
    const int lhs = black[g[a].v];
    const int rhs = black[g[b].v];
    return lhs < rhs;
}
// c and w are the black-vertex count and total weight accumulated along the
// path that reaches u. dfs seeds them from the centroid's incident edge, so
// the centroid itself is not counted. Records the heaviest w for each c in
// tmp[c], then extends into u's children (away from fa, skipping removed vertices).
void cal(int u, int fa, int c, int w) {
    if (w > tmp[c]) tmp[c] = w;
    for (int e = head[u]; e != -1; e = g[e].next) {
        const int to = g[e].v;
        if (to != fa && !vis[to]) cal(to, u, c + col[to], w + g[e].w);
    }
}
// Debug helper (not called in the code shown): print tmp[0..black[v]] on one
// line, then a "~~~" marker line.
void out(int v) {
    int idx = 0;
    while (idx <= black[v]) {
        printf("%d ", tmp[idx]);
        ++idx;
    }
    puts("~~~");
}
// Debug helper (not called in the code shown): print pre[0..pt] on one line,
// then a "!!!!!!" marker line.
void show() {
    for (int idx = 0; idx <= pt; ++idx) printf("%d ", pre[idx]);
    puts("!!!!!!");
}
// Centroid-decomposition step for the component containing u (vertices whose
// vis flag is still clear).
//   1. Find the component's centroid.
//   2. Update ans with the heaviest path through it whose number of black
//      vertices (col == 1) is at most k.
//   3. Mark the centroid as removed and recurse into each remaining piece.
//
// Half-paths start at the centroid c and go down into one neighbour's subtree.
// Black counts below exclude c; col[c] is added when checking against k.
//   black[v] - largest black count on a path starting at neighbour v
//   tmp[j]   - heaviest half-path in v's subtree with exactly j black vertices
//              (initialised to 0, i.e. the empty half-path "c alone", which is
//              valid for any j because it has fewer black vertices)
//   tt[j]    - prefix maximum of tmp
//   pre[j]   - maximum of tmp[j] over the subtrees already processed
// Subtrees are processed in increasing black[v]. So pt, the largest black[] so
// far, never exceeds the current black[v], and each merge costs O(black[v]).
void dfs(int u) {
    // Locate the centroid of u's component.
    mi = n;
    dfsSz(u, -1);
    dfsRt(u, u, -1);
    u = root;

    // Collect edges to unvisited neighbours and measure each subtree's black depth.
    vector<int> vex;
    for (int i = head[u]; ~i; i = g[i].next) {
        int v = g[i].v;
        if (vis[v]) continue;
        vex.push_back(i);
        lim = 0;
        cal(v, u, col[v]);
        black[v] = lim;
    }
    sort(vex.begin(), vex.end(), cmp);

    for (int i = 0; i <= sz[u]; i++)
        pre[i] = -inf;
    // The centroid on its own is a half-path of weight 0 with 0 black vertices.
    // Without this, half-paths of the first subtree were only counted when
    // paired with a later subtree. A centroid with a single unvisited neighbour
    // (any two-vertex component) therefore never counted its edge.
    pre[0] = 0;
    pt = 0;
    for (size_t i = 0; i < vex.size(); i++) {
        int v = g[vex[i]].v, w = g[vex[i]].w;
        for (int j = 0; j <= black[v]; j++)
            tmp[j] = 0;
        cal(v, u, col[v], w);
        tt[0] = tmp[0];
        for (int j = 1; j <= black[v]; j++)
            tt[j] = max(tt[j-1], tmp[j]);
        // Pair every earlier half-path (j black vertices) with the best half-path
        // of this subtree that still fits the remaining budget.
        for (int j = 0; j <= pt; j++) {
            int res = k - j - col[u];
            if (res > black[v]) res = black[v];
            if (res < 0) break;
            ans = max(ans, pre[j] + tt[res]);
        }
        for (int j = 0; j <= black[v]; j++)
            pre[j] = max(pre[j], tmp[j]);
        pt = max(pt, black[v]);
    }

    // Remove the centroid and solve each remaining piece independently.
    vis[u] = 1;
    for (int i = head[u]; ~i; i = g[i].next) {
        int v = g[i].v;
        if (!vis[v]) dfs(v);
    }
}
// Process test cases until the input is exhausted.
// Input per case: "n k m", then m vertex ids to mark black, then n-1 edges
// "x y z" (undirected, weight z).
// Output per case: ans, the heaviest path with at most k black vertices
// (ans starts at 0).
int main() {
    // Require all three header values. The old `~scanf(...)` test stopped only
    // on EOF, so a truncated header (return value 1 or 2) ran another case on
    // stale n/k/m.
    while (scanf("%d%d%d", &n, &k, &m) == 3) {
        init();
        while (m--) {
            int x;
            scanf("%d", &x);
            col[x] = 1;
        }
        for (int i = 1; i < n; i++) {
            int x, y, z;
            scanf("%d%d%d", &x, &y, &z);
            add(x, y, z);
            add(y, x, z);
        }
        ans = 0;
        dfs(1);
        printf("%d\n", ans);
    }
    return 0;
}
原文：http://blog.csdn.net/auto_ac/article/details/19759995