题意: https://www.lydsy.com/JudgeOnline/problem.php?id=3572
sol:
暴力dp是考虑每个点的贡献,这个题里考虑虚树对dp的优化。
看到nq同级是不是一下子想到虚树.....然而原来那个暴力dp并不行。
会发现建了虚树以后对于一条虚树边有两种情况,一种是两边都被同一个点控制,另一种是被不同点控制,
原因很简单,因为如果假设虚树上两个点中间有询问点没被建出来是不可能的。
所以中间一定不会有询问点。
不同点控制的话中间一定有个点是分界点,倍增出来。
相同点控制直接算贡献。
贡献用虚树和原树上的siz算。
代码稍微有点恶心。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#define pii std::pair<int, int>
#define MP std::make_pair
#define fir first
#define sec second
inline int min(int a, int b) {return a > b ? b : a;}
inline int max(int a, int b) {return a > b ? a : b;}
typedef long long LL;
const int N = 3e5+5, INF = 1e9;
int n, Q;
struct Edge {
int v, ne, w;
} e[N << 1];
int cnt, h[N];
inline void ins(int u, int v) {
cnt++; e[cnt].v = v; e[cnt].ne = h[u]; h[u] = cnt;
cnt++; e[cnt].v = u; e[cnt].ne = h[v]; h[v] = cnt;
}
int fa[N][20], deep[N], dfn[N], dfc, size[N], All;
void dfs(int u) {
dfn[u] = ++dfc, size[u] = 1;
for (int i = 1; (1 << i) <= deep[u]; i++)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = h[u]; i; i = e[i].ne)
if (e[i].v != fa[u][0]) {
fa[e[i].v][0] = u;
deep[e[i].v] = deep[u] + 1;
dfs(e[i].v);
size[u] += size[e[i].v];
}
}
inline int lca(int x, int y) {
if (deep[x] < deep[y])
std :: swap(x, y);
int bin = deep[x] - deep[y];
for (int i = 19; i >= 0; i--)
if ((1 << i) & bin)
x = fa[x][i];
for (int i = 19; i >= 0; i--)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return x == y ? x : fa[x][0];
}
int a[N], st[N], par[N], dis[N], t[N], m, ans[N];
int remain[N];
inline bool cmp(int x, int y) {return dfn[x] < dfn[y];}
inline void ins2(int x, int y) {
par[y] = x;
dis[y] = deep[y] - deep[x];
}
pii g[N];
void dp(int m) {
for (int i = m; i > 1; i--) {
int x = t[i], f = par[x];
g[f] = min(g[f], MP(g[x].fir + dis[x], g[x].sec));
}
for (int i = 2; i <= m; i++) {
int x = t[i], f = par[x];
g[x] = min(g[x], MP(g[f].fir + dis[x], g[f].sec));
}
}
inline int jump1(int x, int tar) {
for (int i = 19; i >= 0; i--)
if (deep[fa[x][i]] >= tar)
x = fa[x][i];
return x;
}
inline int jump(int x, int tar) {
int bin = deep[x] - tar;
for (int i = 19; i >= 0; i--)
if ((1 << i) & bin)
x = fa[x][i];
return x;
}
int ora[N];
void solve() {
int n, m = 0;
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &ora[i]);
a[i] = ora[i], t[++m] = a[i], g[a[i]] = MP(0, a[i]);
}
std :: sort(a + 1, a + 1 + n, cmp);
int top = 0;
for (int i = 1; i <= n; i++) {
if (!top) {
st[++top] = a[i];
continue;
}
int x = a[i], f = lca(x, st[top]);
while (dfn[f] < dfn[st[top]]) {
if (dfn[f] >= dfn[st[top - 1]]) {
ins2(f, st[top--]);
if (f != st[top])
st[++top] = f, t[++m] = f, g[f] = MP(INF, 0);
break;
} else
ins2(st[top - 1], st[top]), top--;
}
st[++top] = x;
}
while (top > 1) ins2(st[top - 1], st[top]), top--;
std :: sort(t + 1, t + 1 + m, cmp);
dp(m);
for (int i = 1; i <= m; i++) remain[t[i]] = size[t[i]];
ans[g[t[1]].sec] += All - size[t[1]];
for (int i = 2; i <= m; i++) {
int x = t[i], f = par[x];
par[x] = 0;
int t = jump(x, deep[f] + 1);
remain[f] -= size[t];
if (g[x].sec == g[f].sec)
ans[g[x].sec] += size[t] - size[x];
else {
int len = g[x].fir + g[f].fir + dis[x], mid = deep[x] - (len / 2 - g[x].fir);
if (!(len & 1) && g[f].sec < g[x].sec)
mid++;
int y = jump(x, mid);
ans[g[f].sec] += size[t] - size[y];
ans[g[x].sec] += size[y] - size[x];
}
}
for (int i = 1; i <= m; i++)
ans[g[t[i]].sec] += remain[t[i]];
for (int i = 1; i <= n; i++)
printf("%d%c", ans[ora[i]], i == n ? '\n' : ' '), ans[ora[i]] = 0;
}
int main() {
scanf("%d", &n);
All = n;
int ax, ay;
for (int i = 1; i < n; i++)
scanf("%d%d", &ax, &ay), ins(ax, ay);
dfs(1);
scanf("%d", &Q);
while (Q--) solve();
}
原文:https://www.cnblogs.com/cjc030205/p/11638090.html