给你一棵TREE,以及这棵树上边的距离.问有多少对点它们两者间的距离小于等于K
输入格式:
N(n<=40000) 接下来n-1行边描述管道,按照题目中写的输入 接下来是k
输出格式:
一行,有多少对点之间的距离小于等于k
k≤20000 对于任意一条管道边权wi?≤1000
与点分治1树上直接统计不同的是这题用的指针扫描数组
根据排序数组具有单调性,用了l,r两个指针来维护距离,同时利用容斥剪掉不合法的方案(如图)
即对于每个根节点的子树ans-=calc(to,e[i].val)(两个不合法节点之间的距离相对于根节点为e[i].val)
sort
#include<bits/stdc++.h> #define inf 0x3f3f3f3f using namespace std; const int maxn=40000+100; int head[maxn]; int s[maxn],ms[maxn]; int vis[maxn]; int p[maxn]; int d[maxn],dis[maxn]; struct edge { int to,next,val; }e[maxn<<2]; int size=0; int n,m; int sum,rt; int ans=0; inline int read() { int x=0,f=1;char ch=getchar(); while(ch<‘0‘||ch>‘9‘){if(ch==‘-‘)f=-1;ch=getchar();} while(ch>=‘0‘&&ch<=‘9‘){x=(x<<3)+(x<<1)+ch-‘0‘;ch=getchar();} return x*f; } void addedge(int u,int v,int w) { e[++size].to=v;e[size].val=w;e[size].next=head[u];head[u]=size; } void tc(int u,int fa) { s[u]=1;ms[u]=0; for(int i=head[u];i;i=e[i].next) { int to=e[i].to; if(to==fa||vis[to])continue; tc(to,u); s[u]+=s[to]; ms[u]=max(ms[u],s[to]); } ms[u]=max(ms[u],sum-ms[u]); if(ms[u]<ms[rt])rt=u; } void dfs(int u,int fa) { d[++d[0]]=dis[u]; for(int i=head[u];i;i=e[i].next) { int to=e[i].to; if(to==fa||vis[to])continue; dis[to]=dis[u]+e[i].val; dfs(to,u); } } int calc(int u,int x) { int l=1,r=0,res=0; d[0]=0,dis[u]=x; dfs(u,0); for(int i=1;i<=d[0];i++)if(d[i]<=m)p[++r]=d[i]; sort(p+1,p+1+r); while(l<=r) { if(p[l]+p[r]<=m) res+=r-l,++l; else --r; } return res; } void solve(int u) { vis[u]=1;ans+=calc(u,0); for(int i=head[u];i;i=e[i].next) { int to=e[i].to; if(vis[to])continue; ans-=calc(to,e[i].val); sum=s[to],ms[rt=0]=inf; tc(to,u),solve(rt); } } int main() { n=read(); for(int i=1;i<n;i++) { int u=read(),v=read(),w=read(); addedge(u,v,w); addedge(v,u,w); } m=read(); sum=n,ms[rt]=inf; tc(1,0); solve(rt); printf("%d",ans); return 0; }
桶
#include<bits/stdc++.h> #define inf 0x3f3f3f3f using namespace std; const int maxn=40000+100; int head[maxn]; int s[maxn],ms[maxn]; int vis[maxn]; int p[maxn],t[maxn]; int d[maxn],dis[maxn]; struct edge { int to,next,val; }e[maxn<<2]; int size=0; int n,m; int sum,rt; int ans=0; inline int read() { int x=0,f=1;char ch=getchar(); while(ch<‘0‘||ch>‘9‘){if(ch==‘-‘)f=-1;ch=getchar();} while(ch>=‘0‘&&ch<=‘9‘){x=(x<<3)+(x<<1)+ch-‘0‘;ch=getchar();} return x*f; } void addedge(int u,int v,int w) { e[++size].to=v;e[size].val=w;e[size].next=head[u];head[u]=size; } void tc(int u,int fa) { s[u]=1;ms[u]=0; for(int i=head[u];i;i=e[i].next) { int to=e[i].to; if(to==fa||vis[to])continue; tc(to,u); s[u]+=s[to]; ms[u]=max(ms[u],s[to]); } ms[u]=max(ms[u],sum-ms[u]); if(ms[u]<ms[rt])rt=u; } void dfs(int u,int fa) { d[++d[0]]=dis[u]; for(int i=head[u];i;i=e[i].next) { int to=e[i].to; if(to==fa||vis[to])continue; dis[to]=dis[u]+e[i].val; dfs(to,u); } } int calc(int u,int x) { int l=1,r=0,res=0; d[0]=0,dis[u]=x; dfs(u,0); for(int i=1;i<=d[0];i++)if(d[i]<=m)++p[d[i]]; for(int i=0;i<=m;i++) for(int j=1;j<=p[i];j++)t[++r]=i; while(l<=r) { if(t[l]+t[r]<=m) res+=r-l,++l; else --r; } for(int i=1;i<=d[0];i++) p[d[i]]=0; return res; } void solve(int u) { vis[u]=1;ans+=calc(u,0); for(int i=head[u];i;i=e[i].next) { int to=e[i].to; if(vis[to])continue; ans-=calc(to,e[i].val); sum=s[to],ms[rt=0]=inf; tc(to,u),solve(rt); } } int main() { n=read(); for(int i=1;i<n;i++) { int u=read(),v=read(),w=read(); addedge(u,v,w); addedge(v,u,w); } m=read(); sum=n,ms[rt]=inf; tc(1,0); solve(rt); printf("%d",ans); return 0; }
原文:https://www.cnblogs.com/DriverBen/p/10999295.html