题目链接:https://www.luogu.com.cn/problem/P1352
解题思路:树形DP。
首先,我们可以把单位的层级结构抽象成一个树的问题:
每个人都是树上的一个节点,某一个人的直属领导对应该节点的父节点;某一个人的所有直属下属对应该节点的所有子节点。
现在假设我当前所处的点是点 \(u\),那么我面临两种选择:
所以我可以开一个二维数组 \(f[n][2]\),其中:
为了方便接下来做统计,我们再定义一些变量:
我们可以发现,对于点 \(u\):
当我不选 \(u\) 的时候,它的所有儿子节点可以选,也可以不选,此时
\[f[u][0] = \sum_{v} \max( f[v][0], f[v][1] )\]
其中,\(v\) 代表 \(u\) 的所有子节点。
当我选择 \(u\),它的所有儿子节点都不能选,此时
\[f[u][1] = h[u] + \sum_{v} f[v][0]\]
其中,\(v\) 代表 \(u\) 的所有子节点。
而上述两个公式合到一起就是我们的总的状态转移方程。
我们发现,要求一个节点 \(u\) 对应的 \(f[u][0]\) 和 \(f[u][1]\) ,得先求出它的所有儿子节点的 \(f\) 值,而叶子节点的 \(f\) 值可以直接得出(边界条件,对于叶子节点 \(u\) 来说,\(f[u][0]=0,f[u][1]=h[u]\)),这恰好是一种递归结构,可以用记忆化搜索来实现(事实上,对于树形结构的问题,如树形DP,一般使用记忆化搜索的方式编写代码会比较方便)。
所以我们可以使用记忆化搜索的方式来编写如下代码段:
void dfs(int u, int fa) { // u是当前节点,fa用于指代u的父节点
f[u][0] = 0; // 不选择u的情况下初始为0,不包含h[u]
f[u][1] = h[u]; // 选择u的情况下必然包含h[u],所以一开始就加上
int sz = g[u].size(); // 获得u的儿子的个数
for (int i = 0; i < sz; i ++) { // 遍历u的所有邻接点(其中sz-1个都是u的儿子)
int v = g[u][i];
if (v == fa) continue; // 和u相连的有一个是u的父节点,遇到的时候跳过
dfs(v, u); // 先递归遍历子节点计算出f[v][0]和f[v][1]
f[u][0] += max(f[v][0], f[v][1]); // 再奖答案汇总
f[u][1] += f[v][0];
}
}
上述代码实现了
\[f[u][0] = \sum_{v} \max( f[v][0], f[v][1] )\]
和
\[f[u][1] = h[u] + \sum_{v} f[v][0]\]
并且我们可以发现,上述公式对于叶子节点也成立,因为当 \(u\) 为叶子节点时:
\[f[u][0] = \sum_{v} \max( f[v][0], f[v][1] ) = 0\]
\[f[u][1] = h[u] + \sum_{v} f[v][0] = h[u]\]
完整实现代码如下:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 6060;
int f[maxn][2], h[maxn], p[maxn], n;
vector<int> g[maxn];
void dfs(int u, int fa) { // u是当前节点,fa用于指代u的父节点
f[u][0] = 0;
f[u][1] = h[u];
int sz = g[u].size();
for (int i = 0; i < sz; i ++) {
int v = g[u][i];
if (v == fa) continue;
dfs(v, u);
f[u][0] += max(f[v][0], f[v][1]);
f[u][1] += f[v][0];
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i ++) cin >> h[i];
for (int i = 1; i < n; i ++) {
int a, b;
cin >> a >> b;
p[a] = b;
g[b].push_back(a);
}
int rt;
for (int i = 1; i <= n; i ++) {
if (!p[i]) {
rt = i;
break;
}
}
dfs(rt, -1);
cout << max(f[rt][0], f[rt][1]) << endl;
return 0;
}
注意:我们主函数里的rt变量是为了求得该树的根节点,因为只有根节点是没有直属领导的,所以只有 \(p[rt]= 0\)。
然后找到 rt 之后就可以执行 \(dfs(rt, -1)\) ,这将会求得所有点的 \(f\) 值(第二个参数 \(-1\) 是用来表示 \(rt\) 的父节点,因为 \(rt\) 不存在父节点,所以我们开一个不在 \(1 \sim n\) 范围内的数就可以了,我这里出于习惯用了一个 \(-1\))。
原文:https://www.cnblogs.com/quanjun/p/12236208.html