splay模板题
#include <cstdio> typedef long long LL; #define N 300500 LL data[N]; int siz[N],cnt[N],ch[N][2],fa[N],root,cn,n; inline int son(int x) {return ch[fa[x]][1]==x;} inline void pushup(int rt) { int l=ch[rt][0],r=ch[rt][1]; siz[rt]=cnt[rt]+siz[l]+siz[r]; } inline void rotate(int x) { int y=fa[x],z=fa[y],b=son(x),c=son(y),a=ch[x][!b]; if(z) ch[z][c]=x; else root=x; fa[x]=z; if(a) fa[a]=y; ch[y][b]=a; ch[x][!b]=y; fa[y]=x; pushup(y); pushup(x); } void splay(int x,int i) { for(;fa[x]!=i;) { int y=fa[x],z=fa[y]; if(z==i) rotate(x); else { if(son(x)==son(y)) { rotate(y); rotate(x); } else { rotate(x); rotate(x); } } } } void ins(int &rt,LL x) { if(!rt) { rt=++cn; data[cn]=x; siz[cn]=cnt[cn]=1; splay(cn,0); return; } if(data[rt]==x) { cnt[rt]++; siz[rt]++; splay(rt,0); return; } if(x<data[rt]) { ins(ch[rt][0],x); fa[ch[rt][0]]=rt; pushup(rt); } else { ins(ch[rt][1],x); fa[ch[rt][1]]=rt; pushup(rt); } } int getmn(int rt) { int p=rt,ans=-1; for(;p;p=ch[p][0]) ans=p; return ans; } void del(int rt,LL x) { if(data[rt]==x) { if(cnt[rt]>1) { cnt[rt]--; siz[rt]--; } else { splay(rt,0); int p=getmn(ch[rt][1]); if(p!=-1) { splay(p,rt); root=p; fa[p]=0; ch[p][0]=ch[rt][0]; fa[ch[rt][0]]=p; } else { root=ch[rt][0]; fa[ch[rt][0]]=0; } } return; } if(x<data[rt]) { del(ch[rt][0],x); pushup(rt); } else { del(ch[rt][1],x); pushup(rt); } } int getkth(int rt,int k) { int l=ch[rt][0]; if(siz[l]+1<=k&&k<=siz[l]+cnt[rt]) return rt; if(siz[l]+1>k) return getkth(ch[rt][0],k); else if(k>siz[l]+cnt[rt]) return getkth(ch[rt][1],k-(siz[l]+cnt[rt])); } int get_suc(int rt,LL x) { int p=rt,ret=-1; for(;p;) { if(x>=data[p]) p=ch[p][1]; else { ret=p; p=ch[p][0]; } } return ret; } int get_pre(int rt,LL x) { int p=rt,ret=-1; for(;p;) { if(x<=data[p]) p=ch[p][0]; else { ret=p; p=ch[p][1]; } } return ret; } int get_pos(int rt,LL x) { if(data[rt]==x) return rt; if(x<data[rt]) return get_pos(ch[rt][0],x); else return get_pos(ch[rt][1],x); } int Main() { scanf("%d",&n); for(LL opt,x,pos,flag;n--;) { scanf("%lld%lld",&opt,&x); if(!opt) ins(root,x); else if(opt==1) del(root,x); else if(opt==2) printf("%lld\n",data[getkth(root,x)]); else if(opt==3) ins(root,x),pos=get_pos(root,x),splay(pos,0),printf("%d\n",siz[ch[root][0]]),del(root,x); else if(opt==4) pos=get_pre(root,x),pos==-1?printf("-1\n"):printf("%lld\n",data[pos]); else pos=get_suc(root,x),pos==-1?printf("-1\n"):printf("%lld\n",data[pos]); } return 0; } int sb=Main(); int main(int argc,char *argv[]) {;}
原文:http://www.cnblogs.com/ruojisun/p/7517893.html