简化题意:
一棵 \(n\) 个节点的树,边有边权。
每个点可能是关键点,每次操作改变一个点是否是关键点。
求所有关键点形成的极小联通子树的边权和的两倍。
结论:\(DFS\) 序求出后,假设关键点按照 \(DFS\) 序排序后是 \(\{a1,a2,…,ak\}\)。
那么所有关键点形成的极小联通子树的边权和的两倍等于\(dist(a1,a2)+dist(a2,a3)+?+dist(ak?1,ak)+dist(ak,a1)\)。
画个图感性理解一下,应该是很好懂的。
那么求一下 \(DFS\) 序,每次操作相当于往集合里加入/删除一个元素。
假设插入 \(x\),它 \(DFS\) 序左右两边分别是 \(y\) 和 \(z\)。那么答案加上 \(dist(x,y)+dist(x,z)?dist(y,z)\) 即可。
删除同理。
还有,用倍增求 \(LCA\) 进而求树上两点间距离,用 \(set\) 容器维护,就不多说了。
#include<bits/stdc++.h>
typedef long long ll;
const int maxn=1e5+10;
namespace IO
{
char buf[1<<15],*fs,*ft;
inline char getc() { return (ft==fs&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),ft==fs))?0:*fs++; }
template<typename T>inline void read(T &x)
{
x=0;
T f=1, ch=getchar();
while (!isdigit(ch) && ch^'-') ch=getchar();
if (ch=='-') f=-1, ch=getchar();
while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48), ch=getchar();
x*=f;
}
char Out[1<<24],*fe=Out;
inline void flush() { fwrite(Out,1,fe-Out,stdout); fe=Out; }
template<typename T>inline void write(T x)
{
if (!x) *fe++=48;
if (x<0) *fe++='-', x=-x;
T num=0, ch[20];
while (x) ch[++num]=x%10+48, x/=10;
while (num) *fe++=ch[num--];
*fe++='\n';
}
}
using IO::read;
using IO::write;
int ver[maxn<<1],edge[maxn<<1],Next[maxn<<1],head[maxn],len;
inline void add(int x,int y,int z)
{
ver[++len]=y,edge[len]=z,Next[len]=head[x],head[x]=len;
}
int dfn[maxn],num[maxn],id;
int f[maxn][21],dep[maxn];
ll dist[maxn];
inline void dfs(int x)
{
dfn[x]=++id;
num[id]=x;
for (int i=1; i<=20; ++i) f[x][i]=f[f[x][i-1]][i-1];
for (int i=head[x]; i; i=Next[i])
{
int y=ver[i];
if (y==f[x][0]) continue;
f[y][0]=x;
dep[y]=dep[x]+1;
dist[y]=dist[x]+(ll)edge[i];
dfs(y);
}
}
inline int LCA(int x,int y)
{
if (dep[x]>dep[y]) std::swap(x,y);
for (int i=20; i>=0; --i)
if (dep[y]-(1<<i)>=dep[x]) y=f[y][i];//
if (x==y) return x;
for (int i=20; i>=0; --i)
if (f[x][i]^f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
inline ll Dis(int x,int y)
{
return dist[x]+dist[y]-(dist[LCA(x,y)]<<1);
}
bool vis[maxn];
std::set<int>s;
std::set<int>::iterator it;
ll ans;
int main()
{
int n,m;read(n);read(m);
for (int i=1,x,y,z; i<n; ++i) read(x),read(y),read(z),add(x,y,z),add(y,x,z);
dfs(1);
for (int i=1,x,y,z; i<=m; ++i)
{
read(x);
x=dfn[x];
if (!vis[num[x]]) s.insert(x);
y=num[ (it=s.lower_bound(x)) ==s.begin() ? *--s.end() : *--it];
z=num[ (it=s.upper_bound(x)) ==s.end() ? *s.begin() : *it];
if (vis[num[x]]) s.erase(x);
x=num[x];
ll d=Dis(x,y)+Dis(x,z)-Dis(y,z);
if (!vis[x]) vis[x]=1, ans+=d;
else vis[x]=0, ans-=d;
write(ans);
}
IO::flush();
return 0;
}
原文:https://www.cnblogs.com/G-hsm/p/11386319.html