You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation: u v — report how many distinct weights appear on the nodes of the path from u to v.
Input: the first line contains two integers N and M, the number of nodes and the number of operations (N <= 40000, M <= 100000).
In the second line there are N integers. The i-th integer denotes the weight of the i-th node.
In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v).
In the next M lines, each line contains two integers u and v, describing one operation: count how many distinct weights appear on the nodes of the path from u to v.
Output: for each operation, print its result on a separate line.
Sample input:
8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8
Sample output:
4
4
分析:求树上任意两点之间路径上不同权值的个数;
树上莫队的一般步骤为:
预处理出树的LCA(倍增),以及树上每个顶点进入的时间st和退出的时间ed,以及每个时间戳对应的顶点编号seq(代码中为dfn数组,每个顶点在其中出现两次);
询问u→v(设st[u]≤st[v])的路径,无非是以下三种情况:
0.当u==v时,路径上只包含一个点(这也满足情况1:区间st[u]~st[u]中只有u);
1.当u==lca(u,v)时,那么链上的点就是时间戳从st[u]~st[v]中的所有出现一次的点,即seq中下标从st[u]~st[v]且只出现一次的点。
2.当u!=lca(u,v)时,那么链上的点就是时间戳从ed[u]~st[v]中所有出现一次的点,再加上lca(u,v)(lca的st和ed都不在这个区间内,需要单独计入)。最后把每个询问转化为seq上的区间,用普通莫队处理:每个位置进出区间时翻转对应顶点的状态,出现两次的顶点自然被抵消。
代码:
// Tree Mo's algorithm for SPOJ COT2: for each query (u, v), count the number of
// distinct weights on the nodes of the path u -> v.
//
// The tree is flattened into an Euler sequence in which every node appears twice
// (at its enter time st[x] and its exit time ed[x]). A path query becomes a
// contiguous range of that sequence in which exactly the path's nodes (except
// possibly the LCA) appear once; nodes appearing twice cancel out. The LCA is
// added separately when it lies outside the range. All ranges are then answered
// offline with ordinary Mo's algorithm.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

using namespace std;

const int maxn = 1e5 + 10;  // upper bound on node count and on query count
const int maxe = 2 * maxn;  // Euler sequence length: every node appears twice
const int LOG = 20;         // binary-lifting levels 0..19; 2^19 > maxn

int n, m;
int a[maxn];          // weight of node i, compressed to a rank in [1, tot]
int d[maxn];          // sorted distinct raw weights, used for compression
int st[maxn];         // Euler time at which node x is entered
int ed[maxn];         // Euler time at which node x is exited
int dep[maxn];        // depth of node x (root = 1, sentinel node 0 = 0)
int fa[LOG][maxn];    // fa[k][x] = 2^k-th ancestor of x, 0 if it does not exist
int cnt[maxn];        // how many active nodes carry each compressed weight
int ans[maxn];        // answer for the query with this input index
bool vis[maxn];       // true while node x is counted in the current window
int dfn[maxe];        // dfn[t] = node entered or exited at Euler time t
int bl[maxe];         // Mo block index of Euler position t
int tot;              // number of distinct weights
int cl;               // Euler clock
int ret;              // number of distinct weights among currently active nodes
vector<int> e[maxn];  // adjacency lists

// One offline query over the Euler range [l, r]; fa is the path's LCA and id is
// the query's position in the input.
struct node {
    int l, r, fa, id;
    // Mo ordering: by block of l, then by r.
    bool operator<(const node& p) const {
        return bl[l] == bl[p.l] ? r < p.r : bl[l] < bl[p.l];
    }
} qu[maxn];

// Depth-first traversal from `root`, whose parent is `y` (normally the sentinel
// 0). Fills dfn/st/ed, dep and the binary-lifting table, visiting children in
// adjacency-list order. An explicit stack replaces recursion because a
// path-shaped tree would otherwise recurse N levels deep.
void dfs(int root, int y) {
    vector<pair<int, int> > stk;  // (node, index of next adjacency entry to try)
    stk.reserve(n + 1);
    auto enter = [&](int x, int parent) {
        dfn[++cl] = x;
        st[x] = cl;
        dep[x] = dep[parent] + 1;
        fa[0][x] = parent;
        // Ancestors of x were entered earlier, so their tables are complete.
        for (int k = 1; k < LOG && fa[k - 1][fa[k - 1][x]]; ++k)
            fa[k][x] = fa[k - 1][fa[k - 1][x]];
        stk.emplace_back(x, 0);
    };
    enter(root, y);
    while (!stk.empty()) {
        const int x = stk.back().first;
        if (stk.back().second < (int)e[x].size()) {
            // Read and advance the cursor before enter() may reallocate stk.
            const int z = e[x][stk.back().second++];
            if (z != fa[0][x]) enter(z, x);  // skip the edge back to the parent
        } else {
            dfn[++cl] = x;
            ed[x] = cl;
            stk.pop_back();
        }
    }
}

// Lowest common ancestor of x and y via binary lifting. Relies on dep[0] == 0,
// so jumps past the root are rejected automatically.
int lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    for (int k = LOG - 1; k >= 0; --k)
        if (dep[fa[k][x]] >= dep[y]) x = fa[k][x];
    if (x == y) return x;
    for (int k = LOG - 1; k >= 0; --k) {
        if (fa[k][x] != fa[k][y]) {
            x = fa[k][x];
            y = fa[k][y];
        }
    }
    return fa[0][x];
}

// Toggles the node at Euler position `pos` in or out of the active set and
// updates the distinct-weight count `ret` accordingly.
void modify(int pos) {
    const int x = dfn[pos];
    if (vis[x] ^= 1) {
        if (++cnt[a[x]] == 1) ++ret;
    } else {
        if (--cnt[a[x]] == 0) --ret;
    }
}

int main() {
    if (scanf("%d%d", &n, &m) != 2) return 0;
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
        d[i] = a[i];
    }
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        e[x].push_back(y);
        e[y].push_back(x);
    }

    // Coordinate-compress weights so cnt[] can be indexed directly.
    sort(d + 1, d + n + 1);
    tot = unique(d + 1, d + n + 1) - d - 1;
    for (int i = 1; i <= n; ++i) a[i] = lower_bound(d + 1, d + tot + 1, a[i]) - d;

    dfs(1, 0);

    // Block size about sqrt(2N) over the 2N Euler positions.
    const int sz = (int)round(sqrt(2.0 * n) + 0.5);
    for (int i = 1; i <= 2 * n; ++i) bl[i] = (i - 1) / sz + 1;

    for (int i = 1; i <= m; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        if (st[x] > st[y]) swap(x, y);
        const int w = lca(x, y);
        if (w == x) {
            // x is an ancestor of y (or x == y): [st[x], st[y]] covers the path.
            qu[i] = node{st[x], st[y], w, i};
        } else {
            // Disjoint subtrees: [ed[x], st[y]] covers the path except the LCA.
            qu[i] = node{ed[x], st[y], w, i};
        }
    }
    sort(qu + 1, qu + m + 1);

    int l = 1, r = 0;
    for (int i = 1; i <= m; ++i) {
        const node& q = qu[i];
        while (r < q.r) modify(++r);
        while (l > q.l) modify(--l);
        while (r > q.r) modify(r--);
        while (l < q.l) modify(l++);
        // Only ranges starting at ed[x] (dfn[l] == x != LCA) miss the LCA;
        // neither of its Euler positions lies in the window, so toggling it in
        // and back out again is safe.
        const bool addLca = q.fa != dfn[q.l];
        if (addLca) modify(st[q.fa]);
        ans[q.id] = ret;
        if (addLca) modify(st[q.fa]);
    }
    for (int i = 1; i <= m; ++i) printf("%d\n", ans[i]);
    return 0;
}
原文:http://www.cnblogs.com/dyzll/p/7896957.html