#include<iostream> #include<cstring> #include<cstdio> #include<algorithm> #include<cmath> #include<ctime> #include<climits> using namespace std; inline int read(){ int f=1,ans=0;char c=getchar(); while(c<‘0‘||c>‘9‘){if(c==‘-‘)f=-1;c=getchar();} while(c>=‘0‘&&c<=‘9‘){ans=ans*10+c-‘0‘;c=getchar();} return f*ans; } const int N=300001; struct node{ int l,r,cnt,num,size,rnk; }tr[N]; int n,root,tot; void update(int k){ tr[k].size=tr[k].cnt; tr[k].size+=tr[tr[k].l].size; tr[k].size+=tr[tr[k].r].size;return; } void zag(int &k){ int tp=tr[k].l; tr[k].l=tr[tp].r; tr[tp].r=k; tr[tp].size=tr[k].size; update(k); k=tp; return; } void zig(int &k){ int tp=tr[k].r; tr[k].r=tr[tp].l; tr[tp].l=k; tr[tp].size=tr[k].size; update(k); k=tp; return; } void insert(int x,int &k){ if(k==0){ k=++tot; tr[k].cnt=tr[k].size=1;tr[k].num=x; tr[k].rnk=rand(); return; } tr[k].size++; if(x==tr[k].num){tr[k].cnt++;return;} if(x<tr[k].num){ insert(x,tr[k].l); if(tr[tr[k].l].rnk<tr[k].rnk)zag(k); }else{ insert(x,tr[k].r); if(tr[tr[k].r].rnk<tr[k].rnk) zig(k); } return; } void del(int x,int &k){ if(x==tr[k].num){ if(tr[k].cnt>1){tr[k].size--,tr[k].cnt--;return;} if(tr[k].l*tr[k].r==0) {k=tr[k].l+tr[k].r;return;} if(tr[tr[k].l].rnk<tr[tr[k].r].rnk){zag(k);del(x,k);return;} else{zig(k);del(x,k);return;} } tr[k].size--; if(x<tr[k].num) del(x,tr[k].l); else del(x,tr[k].r); return; } int rank_x(int x,int k){ if(x==tr[k].num) return tr[tr[k].l].size+1; if(x<tr[k].num) return rank_x(x,tr[k].l); return tr[tr[k].l].size+tr[k].cnt+rank_x(x,tr[k].r); } int rank(int x,int k){ if(k==0) return 0; if(tr[tr[k].l].size<x&&tr[tr[k].l].size+tr[k].cnt>=x) return tr[k].num; if(x<=tr[tr[k].l].size) return rank(x,tr[k].l); return rank(x-tr[tr[k].l].size-tr[k].cnt,tr[k].r); } int pre(int x,int k){ if(k==0) return INT_MIN; if(x<=tr[k].num) return pre(x,tr[k].l); return max(tr[k].num,pre(x,tr[k].r)); } int nex(int x,int k){ if(k==0) return INT_MAX; if(x>=tr[k].num) return nex(x,tr[k].r); return min(tr[k].num,nex(x,tr[k].l)); } int main(){ srand(time(0)); n=read(); while(n--){ int opt=read(),x=read(); if(opt==1) insert(x,root); if(opt==2) del(x,root); if(opt==3) printf("%d\n",rank_x(x,root)); if(opt==4) printf("%d\n",rank(x,root)); if(opt==5) printf("%d\n",pre(x,root)); if(opt==6) printf("%d\n",nex(x,root)); } return 0; }
原文:https://www.cnblogs.com/si-rui-yang/p/10190089.html