给定一颗边权为1的树,点权为v,给定m条从s到t的路径,对于每个点,求\(ans_i=\Sigma_{j=1}^m [dist(s_j,i)==v[i]]\)
有两种可能的做法,一种是把路径全加进去,再每一个点求\(ans\),另一种是一条一条加路径,每次求贡献。而这道题用的是第一种
对于每一条路径,将\(s->t\)的路径拆分为\(s->lca->t\),\(s->lca\)为左路径,\(lca->t\)为右路径,对于左路径上的点\(i\),左路径对它有贡献当且仅当\(dep[s]-dep[i]==v[i]\),将\(i\)项移动到一边,就有\(dep[s]==dep[i]+v[i]\),这样就可以把路径全部加进去再求\(i\)点的贡献。右路径同理有\(dep[s]-2*dep[lca]==v[i]-dep[i]\)
\(dfs\)整颗树,从下向上回溯时求\(ans\),对于回溯时遇到的点\(i\),加入以它为\(s/t\)的左右路径,求一次\(ans\),退出时减去以它为\(lca\)的左右路径即可,然后就发现过不了样例
(假设当前在以\(u\)为根的子树中\(dfs\))因为在\(u\)的某一珂子树\(v\)中的时候,\(u\)的其他子树里面的路径也被记录下来了,这时就会导致一条路径对不在它上面的点做出贡献。
解决方法:因为求的是满足上面两个约束的路径条数,那么在递归进入\(v\)子树加入里面的路径之前,先减去此时满足约束条件的路径条数,这样就保证了做出贡献的都是\(v\)子树里面出发的路径
另外,如果\(i\)是路径的\(lca\)的话可能被计算两次(左右路径各一次),所以要减1,即当满足\(v[lca]==dep[s]-dep[lca]\)的时候减1
Code:
#include<bits/stdc++.h>
#define N 900005
#define M 300005
using namespace std;
const int temp = 300000;
int n,m;
int dep[M],w[M],fa[M][18],ans[M];
int lsum[N],rsum[N];//桶
vector<int> L[M],R[M];//lca在上面,出
vector<int> l[M],r[M];//s,t在下面,入
struct Edge
{
int next,to;
}edge[M<<1];int head[M],cnt=1;
void add_edge(int from,int to)
{
edge[++cnt].next=head[from];
edge[cnt].to=to;
head[from]=cnt;
}
template <class T>
void read(T &x)
{
char c;int sign=1;
while((c=getchar())>'9'||c<'0') if(c=='-') sign=-1; x=c-48;
while((c=getchar())>='0'&&c<='9') x=x*10+c-48; x*=sign;
}
void dfs(int rt)
{
dep[rt]=dep[fa[rt][0]]+1;
for(int i=head[rt];i;i=edge[i].next)
{
int v=edge[i].to;
if(v==fa[rt][0]) continue;
fa[v][0]=rt;
for(int i=1;i<18;++i) fa[v][i]=fa[fa[v][i-1]][i-1];
dfs(v);
}
}
int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int i=17;i>=0;--i) if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];
if(x==y) return x;
for(int i=17;i>=0;--i) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void DFS(int rt)
{
ans[rt]-=lsum[dep[rt]+w[rt]+temp]+rsum[w[rt]-dep[rt]+temp];
for(int i=head[rt];i;i=edge[i].next)
{
int v=edge[i].to;
if(v==fa[rt][0]) continue;
DFS(v);
}
//加入rt
for(int i=0;i<(int)l[rt].size();++i)
{
int val=l[rt][i]+temp;
++lsum[val];
}
for(int i=0;i<(int)r[rt].size();++i)
{
int val=r[rt][i]+temp;
++rsum[val];
}
ans[rt]+=lsum[dep[rt]+w[rt]+temp]+rsum[w[rt]-dep[rt]+temp];
//删除以rt为lca的链
for(int i=0;i<(int)L[rt].size();++i)
{
int val=L[rt][i]+temp;
--lsum[val];
}
for(int i=0;i<(int)R[rt].size();++i)
{
int val=R[rt][i]+temp;
--rsum[val];
}
}
int main()
{
read(n);read(m);
for(int i=1;i<n;++i)
{
int x,y;
read(x);read(y);
add_edge(x,y);
add_edge(y,x);
}
dfs(1);
for(int i=1;i<=n;++i) read(w[i]);
for(int i=1;i<=m;++i)
{
int S,T;
read(S);read(T);
int lc=lca(S,T);
if(w[lc]==dep[S]-dep[lc]) ans[lc]--;
L[lc].push_back(dep[S]);
R[lc].push_back(dep[S]-2*dep[lc]);
l[S].push_back(dep[S]);
r[T].push_back(dep[S]-2*dep[lc]);
}
DFS(1);
for(int i=1;i<=n;++i) printf("%d ",ans[i]);
return 0;
}
原文:https://www.cnblogs.com/Chtholly/p/11370320.html