看到这种路径统计问题,一般就想到要用点分治去做。
对于每个重心\(u\),统计经过\(u\)的合法的路径之中的最大值。
第一类路径是从\(u\)出发的,直接逐个子树深搜统计就可以了。第二类路径是由两棵不同子树中的两条第一类路径拼接而成的。
如果仅仅是统计长度在\([l,r]\)之间的路径有多少条,经典的统计+容斥做法就可以解决。然而现在的问题比较复杂,一来不好容斥,二来两两路径配对需要有判定条件:两条路径的接口边颜色是否相同。
我们可以采用一种不需要容斥的做法:逐一枚举子树,并逐一考虑由子树内的每个点出发去其他子树的路径。枚举到当前这棵子树内的某个节点\(x\)时,我们用\(sum[d]\)记录在之前的子树中,相对于\(u\)深度为\(d\)的点的某些信息。那么对于一个点\(x\),它可以接上的路径的信息应该有\(sum[i],\;i\in[l-dep_x,r-dep_x]\)。
所以\(sum[d]\)记录的是什么呢?按照题目意思,我们应该记录所有从深度为\(d\)的点出发到\(u\)路径中的路径最大值。可是注意这个最大值不一定相对所有询问点来说都是最大值,有可能此最大值对应的路径\(\alpha\)与当前询问点路径接口边相同,权值重复,需要减去一次,说不定就不是最大的了。所以,我们额外要记录一个接口颜色异于最大值路径\(\alpha\)的另一条权值最大路径\(\beta\)。
这样一来,对于当前枚举点对应的路径\(\gamma\),想要和先前子树中一条长度为\(d\)的路径结合时,有\(。先看sum[d]=(\alpha,\beta)。先看\)\(\alpha\)的接口边颜色是否和\(\gamma\)相同,如果是,取\(\max\{\alpha-c,\beta\}\)作为最大路径和\(\gamma\)拼接(\(c\)表示\(\alpha\)和\(\gamma\)的接口边颜色);否则,取\(\alpha\)拼接。
有意思的是,因为我们是一个一个子树进行处理,所以当前枚举的所有点的接口边颜色都是一样的,也就意味着不同点尝试配对同一个\(sum\)时的选择都是完全一样的。因此每一个\(sum\)的贡献是定值。我们相当于有一个数列,每次求\([l-dep_x,r-dep_x]\)中的最大值。如果我们把点按照深度来排序,那么这个范围就是一个滑动窗口,可以使用单调队列进行\(\mathcal O(n)\)解决。
#include <cstdio>
#include <queue>
#include <algorithm>
using namespace std;
const int N=200005,INF=2e9+5;
int n,m,l,r,cv[N],all;
int best,bestval,size[N];
int udep[N],cc[N];
int lis[N],lcnt,info[N][3];
deque<int> q;
int qv[N];
bool cut[N],vis[N];
int ans;
struct Data{
int v1,c1,v2,c2;
Data(){v1=v2=-INF; c1=c2=-1;}
inline void insert(int v,int c){
if(c!=c1){
if(v>v1)
v2=v1,c2=c1,v1=v,c1=c;
else if(v>v2)
v2=v,c2=c;
}
else if(v>v1) v1=v;
}
int get(int c){
return c==c1?max(v1-cv[c],v2):v1;
}
}s[N];
int h[N],tot;
struct Edge{int v,c,next;}e[N*2];
inline int min(int x,int y){return x<y?x:y;}
inline int max(int x,int y){return x>y?x:y;}
inline void addEdge(int u,int v,int c){
e[++tot]=(Edge){v,c,h[u]}; h[u]=tot;
e[++tot]=(Edge){u,c,h[v]}; h[v]=tot;
}
void getRt(int u,int fa,int sz){
int maxsub=-1;
size[u]=1;
for(int i=h[u],v;i;i=e[i].next)
if((v=e[i].v)!=fa&&!cut[v]){
getRt(v,u,sz);
size[u]+=size[v];
maxsub=max(maxsub,size[v]);
}
maxsub=max(maxsub,sz-size[u]);
if(maxsub<bestval)
bestval=maxsub,best=u;
}
bool cmp(const int &x,const int &y){return udep[x]<udep[y];}
void dfs(int u,int fa,int dep,int &nowx){
nowx=max(nowx,dep);
for(int i=h[u],v;i;i=e[i].next)
if((v=e[i].v)!=fa&&!cut[v])
dfs(v,u,dep+1,nowx);
}
void insert(int u,int fa,int dep,int topc,int fac,int val){
vis[u]=false;
if(l<=dep&&dep<=r) ans=max(ans,val);
s[dep].insert(val,topc);
for(int i=h[u],v;i;i=e[i].next)
if((v=e[i].v)!=fa&&!cut[v])
insert(v,u,dep+1,topc,e[i].c,fac==e[i].c?val:val+cv[e[i].c]);
}
void collect(int x,int topc){
int head=1,tail=1;
lis[1]=x;
info[x][0]=1; info[x][1]=cv[topc]; info[x][2]=topc;
vis[x]=true;
while(head<=tail){
int u=lis[head++];
for(int i=h[u],v;i;i=e[i].next)
if(!cut[v=e[i].v]&&!vis[v]){
lis[++tail]=v;
vis[v]=true;
info[v][0]=info[u][0]+1;
info[v][1]=info[u][1]+(e[i].c==info[u][2]?0:cv[e[i].c]);
info[v][2]=e[i].c;
}
}
lcnt=tail;
}
void solve(int x,int sz){
best=-1; bestval=INF;
getRt(x,0,sz);
int u=best;
cut[u]=true;
static int son[N],cnt;
cnt=0;
int curr=-1;
for(int i=h[u],v;i;i=e[i].next)
if(!cut[v=e[i].v]){
son[++cnt]=v;
cc[v]=e[i].c;
dfs(v,0,1,udep[v]);
curr=max(curr,udep[son[cnt]]);
}
curr=min(r,curr);
for(int i=1;i<=curr;i++) s[i]=Data();
vis[u]=true;
for(int i=1,v;i<=cnt;i++){
v=son[i];
collect(v,cc[v]);
while(!q.empty()) q.pop_back();
int nowl,nowr,qr=0;
bool first=true;
for(int j=lcnt;j>=1;j--)
if(info[lis[j]][0]<r){
int x=lis[j];
nowl=max(1,l-info[x][0]); nowr=r-info[x][0];
if(first){
first=false;
qr=nowl-1;
while(qr<nowr&&qr<curr&&s[qr+1].v1>-INF){
int newval=s[++qr].get(cc[v]);
if(newval>-INF){
while(!q.empty()&&qv[q.back()]<=newval) q.pop_back();
q.push_back(qr);
qv[qr]=newval;
}
}
}
while(!q.empty()&&q.front()<nowl) q.pop_front();
while(qr<nowr&&qr<curr&&s[qr+1].v1>-INF){
int newval=s[++qr].get(cc[v]);
if(newval>-INF){
while(!q.empty()&&qv[q.back()]<=newval) q.pop_back();
q.push_back(qr);
qv[qr]=newval;
}
}
if(!q.empty())
ans=max(ans,info[x][1]+qv[q.front()]);
}
insert(v,0,1,cc[v],cc[v],cv[cc[v]]);
}
for(int i=h[u],v;i;i=e[i].next)
if(!cut[v=e[i].v])
solve(v,size[v]>size[u]?(sz-size[u]):size[v]);
}
int main(){
scanf("%d%d%d%d",&n,&m,&l,&r);
for(int i=1;i<=m;i++) scanf("%d",cv+i);
for(int i=1,u,v,c;i<n;i++){
scanf("%d%d%d",&u,&v,&c);
addEdge(u,v,c);
}
ans=-INF;
solve(1,n);
printf("%d\n",ans);
return 0;
}
?
原文:https://www.cnblogs.com/RogerDTZ/p/9310131.html