一开始想了一个错误的状压dp,水了40分。
这里先记录一下错误的做法:
设\(g[i,j,S]\)从\(i\)到\(j\),只经过集合\(S\)中的点的最短路,这个可以\(O(n^3 2 ^ n)\)处理出来。
设\(f[S]\)表示生成树的集合为\(S\)时的最小代价,每次枚举起点,以及新加入生成树的点,利用\(g\)数组可以\(O(1)\)算出来\(K\)。总复杂度是\(O(n^3 2^n)\)
那么为什么错了呢?因为\(g[i,j,S]\)是一个点集内互达的点的\(\min\),其涉及到的边集远大于生成树的边集,所以相当于利用了很多没有选的边来减少\(K\),所以算出来的\(K\)是错误的。
设\(g[S]\)表示从集合\(S\)内的点出发,只走一条边可到达的点集(包括集合\(S\)内的点)。这个可以\(O(n^2 2^n)\)处理。
设\(f[i,S]\)为生成树最大树高为\(i\),目前生成树集合为\(S\)的最小代价。
我们枚举\(S\)的子集\(S0\),设其在\(S\)为全集时的补集为\(S1\),那么当\(g[S0]|S=g[S0]\)时(即通过\(S0\)的中的点可以到达当前集合\(S\)中全部的点),\(f[i][S]=\min\{f[i-1][S0]+cost \}\)。\(cost\)即为将\(S1\)与\(S0\)合并为\(S\)的最小代价。
这里有一个很重要的结论,在最优解中,拓展的集合\(S1\)在\(S\)中树高均为\(i\)(\(S0\)中最大树高为\(i-1\))。
考虑证明,假设有一个点\(x(x\in S0)\),有边\(d[y][x]>d[z][x],dep_y<dep_z\),满足\(d[y][x]*(dep_y+1)<d[z][x]*(dep_z+1)\)。
那在当\(y\)为生成树集中的最深点时一定已经拓展到了点\(x\),所以一定有另一状态\(S_2\)可以转移到\(S\)使答案最优。
即最优解中的生成树一定可以按深度拓展。
所以转移时的\(cost\)只需枚举\(S_1\)中的点,求出\(S_1\)到\(S_0\)中的点的\(\min_{dis}\)之和,乘上\(i\),就是这次加边的\(cost\)。
初始化\(f[2^k]=0(k\in [0,n-1])\)。答案即为所有集合\(S\)为全集时的\(f\)值的\(\min\)。
不看题解还是不会写...
因为状压枚举子集的复杂度是\(O(3^n)\)。所以总复杂度是\(O(n^2 3 ^n)\)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 13;
const int inf = 0x3f3f3f3f;
int n, m, w[N][N];
int f[N][(1<<N)+10], g[(1<<N)+10];
int main() {
memset(w, 0x3f, sizeof(w));
cin >> n >> m;
for(int u, v, k, i = 1; i <= m; ++i) {
cin >> u >> v >> k;
--u; --v;
w[u][v] = w[v][u] = min(w[u][v], k);
}
for(int S = 1; S < 1 << n; ++S)
for(int i = 0; i < n; ++i)
if((S >> i) & 1) {
w[i][i] = 0; g[S] |= 1 << i;
for(int j = 0; j < n; ++j)
if(w[i][j] != inf) g[S] |= 1 << j;
}
memset(f, 0x3f, sizeof(f));
for(int i = 0; i < n; ++i) f[0][1 << i] = 0;
for(int S = 1; S < 1 << n; ++S) {
for(int S0 = S - 1; S0; S0 = (S0 - 1) & S) {
int S1 = S ^ S0;
int sum = 0;
if((g[S0] | S) == g[S0]) {
for(int i = 0; i < n; ++i) {
if((S1 >> i) & 1) {
int tmp = inf;
for(int j = 0; j < n; ++j) {
if(((S0 >> j) & 1))
tmp = min(tmp, w[i][j]);
}
sum += tmp;
}
}
for(int i = 1; i < n; ++i) { // 树高
if(f[i - 1][S0] != inf)
f[i][S] = min(f[i][S], f[i - 1][S0] + i * sum);
}
}
}
}
int ans = inf;
for(int i = 0; i < n; ++i) ans = min(ans, f[i][(1 << n) - 1]);
printf("%d\n", ans);
}
原文:https://www.cnblogs.com/henry-1202/p/11342790.html