Solution
对于正在处理的以 u 为根的子树,将子树内路径分成两种
对于第一类,遍历整棵树,将子树内每个点(包括根)放入数组 a ,
计算出它到根的路径长度以及属于根的哪一个儿子,
双向遍历 a 数组,计算对答案的贡献。
对于第二类,处理去掉 u 后分成的每棵子树
复杂度O(n2),每次处理时以子树的重心为根可使层数最小,复杂度O(nlog(n))
Code
#include <cstdio> #include <cstdlib> #include <algorithm> using namespace std; const int N=1e4+1,M=101; int q[M],ans[M],b[N],n,m,dis[N],d[N],si[N],rt,sum; int a[N],u,v,w,head[N],nxt[N*2],ver[N*2],edge[N*2],tot; bool vis[N]; inline char get() { static char buf[1024]; static int pos=0,size=0; if(pos==size) { size=fread(buf,1,1024,stdin); pos=0; if(!size) return EOF; else return buf[pos++]; } else return buf[pos++]; } int read() { int sum=0,fh=1; char ch=get(); while(!(ch>=‘0‘ && ch<=‘9‘)) { if(ch==‘-‘) fh=-1; ch=get(); } while(ch>=‘0‘ && ch<=‘9‘ && ch!=EOF) sum=sum*10+ch-48,ch=get(); return sum*fh; } void add(int u,int v,int w) { ver[++tot]=v,nxt[tot]=head[u],edge[tot]=w,head[u]=tot; ver[++tot]=u,nxt[tot]=head[v],edge[tot]=w,head[v]=tot; } void getrt(int u,int fa) { si[u]=1,d[u]=0; for(int i=head[u];i;i=nxt[i]) { int v=ver[i]; if(v==fa || vis[v]) continue; getrt(v,u); si[u]+=si[v]; d[u]=max(d[u],si[v]); } d[u]=max(d[u],sum-si[u]); rt=d[rt]>d[u]?u:rt; } void getdis(int u,int fa,int di,int from){ a[++a[0]]=u; dis[u]=di; b[u]=from; for(int i=head[u];i;i=nxt[i]){ int v=ver[i]; if(v==fa||vis[v])continue; getdis(v,u,di+edge[i],from); } } bool cmp(int x,int y) { return dis[x]<dis[y]; } void calc(int u){ a[0]=0,a[++a[0]]=u; dis[u]=0,b[u]=u; for(int i=head[u];i;i=nxt[i]) { int v=ver[i]; if(vis[v]) continue; dis[v]=edge[i]; getdis(v,u,edge[i],v); } sort(a+1,a+a[0]+1,cmp); for(int i=0;i<m;i++) { int l=1,r=a[0]; if(ans[i]) continue; while(l<r) { if(dis[a[l]]+dis[a[r]]>q[i]) r--; else if(dis[a[l]]+dis[a[r]]<q[i]) l++; else if(b[a[l]]==b[a[r]]) { if(dis[a[r]]==dis[a[r-1]]) r--; else l++; } else { ans[i]=1; break; } } } } void dfs(int u) { vis[u]=true,calc(u); for(int i=head[u];i;i=nxt[i]) { int v=ver[i]; if(vis[v]) continue; sum=d[rt=0]=si[v]; getrt(v,0),dfs(rt); } } int main() { n=read(),m=read(); for(int i=1;i<n;i++) { u=read(),v=read(),w=read(); add(u,v,w); } for(int i=0;i<m;i++) { q[i]=read(); if(q[i]==0) ans[i]=1; } d[rt=0]=sum=n; getrt(1,0); dfs(rt); for(int i=0;i<m;i++) if(ans[i]) puts("AYE"); else puts("NAY"); return 0; }
原文:https://www.cnblogs.com/hsez-cyx/p/12400351.html