void split(nod o,nod &a,nod &b,int val)
{
if(o==null) //递归边界,当前位置为空,分裂后当然都为空
{
a=b=null;
return;
}
if(o->val<=val) a=o,split(o->ch[1],a->ch[1],b,val);
//小于等于val的要在a里面,所以先直接让a=o,为什么呢
//既然o->val<=val,显然o的左子树所有值都小于val,因此这些点都是a的
//但是我们不能保证o右子树的所有点<=val,因此递归向下来构建a的右子树,本层对b无贡献,所以还是b
else b=o,split(o->ch[0],a,b->ch[0],val);
//同上
o->upd();
//别忘了维护性质
//为什么要维护性质呢
//其实那个if else应该这么写
/*
if(o->val<=val) a=o,split(o->ch[1],a->ch[1],b,val),a->upd();
else b=o,split(o->ch[0],a,b->ch[0],val),b->upd();
*/
//a或b树会改变,所以要维护
//但其实已经让a=o或者b=o了,所以直接维护o即可
}
void merge(nod &o,nod a,nod b)
{
if(a==null||b==null) //有一个为空,则等于另一个(如果另一个也是空其实就是空了)
{
o=a==null? b:a; //为不空的那个
return;
}
if(a->key<=b->key) o=a,merge(o->ch[1],a->ch[1],b);
//这个key就是rand,不解释
//方法跟split差不多,这样也好记qwq
//反正瞎搞总比不搞弄成一条链强。。。。。。
//这样就可以使极端情况尽量少
else o=b,merge(o->ch[0],a,b->ch[0]);
o->upd();
//别忘了维护性质
}
inline void ins(int val)
{
nod x=null,y=null;
//定义两个空节点
//作用:一会分裂的时候作为两棵树的根,起一个承接作用
nod z=newnode(val);
//定义要插入的节点
split(root,x,y,val);
//因为要保证平衡树的性质,所以插入的位置必须要合适
//我们把所有<=val的点都分给x,剩下的分给y
//这样原来以root为根的数分成了两个
//我们要把z插进去
//怎么插♂呢
//可以把z一个点看成一棵树
//直接暴力合并就行了
merge(x,x,z);
merge(root,x,y);
}
//没了?
//没了!
inline void del(int val)
{
nod x=null,y=null,z=null;
split(root,x,y,val);
split(x,x,z,val-1);
//树x的所有点权都小于val
//树y的所有点权都大于val
//综上,树z的点权等于val
//所以。。。。。。
merge(z,z->ch[0],z->ch[1]);
//我们只删除一个val,所以剩下的要合并,别忘了
merge(x,x,z);
merge(root,x,y);
//把分崩离析(<----瞎用成语)的树恢复原状
}
inline int rnk(int val)
{
nod x=null,y=null;
split(root,x,y,val-1);
//把所有小于val的点分走
int t=x->siz+1;
//x作为所有合法点的根,他的大小不正是我们要找的比val小的数的个数吗?
//加一就是排名!
merge(root,x,y);
//不要过于兴奋,你的树还没有合并!!!!
return t;
}
inline nod kth(nod o,int rank)
{
//第k大不就是排名为k的数么
//这不就是操作3的逆操作吗
while(o->ch[0]->siz+1!=rank) //暴力找。。。
if(o->ch[0]->siz>=rank) o=o->ch[0]; //说明那个数在左子树
else rank-=o->ch[0]->siz+1,o=o->ch[1];
//那个数在右子树,注意,这里要减去左子树大小和自己,因为到了下面,上面比自己小的就统计不到了,
//反正都是比自己小的,直接减去最好
//理解一下
return o;
}
inline nod pre(int val)
{
nod x=null,y=null;
split(root,x,y,val-1);
//分离所有小于y的数
nod z=kth(x,x->siz);
//既然pre为小于val的数中最大的一个,我们就找x树中的最大的那个不就行了?
merge(root,x,y);
//别忘了合并
return z;
}
inline nod nxt(int val)
{
//跟上面只是稍稍有点不同而已
nod x=null,y=null;
split(root,x,y,val);
//把所有小于等于val的点都分走,注意这里可以取等号!
//那么y中的点都大于val
//在其中找最小的
nod z=kth(y,1);
merge(root,x,y);
//别忘合并
return z;
}
#include<cstdio>
#include<queue>
#include<vector>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cctype>
#include<cmath>
#define _ 0
#define LL long long
#define Space putchar(' ')
#define Enter putchar('\n')
#define fuu(x,y,z) for(int x=(y);x<=(z);x++)
#define fu(x,y,z) for(int x=(y);x<(z);x++)
#define fdd(x,y,z) for(int x=(y);x>=(z);x--)
#define fd(x,y,z) for(int x=(y);x>(z);x--)
#define mem(x,y) memset(x,y,sizeof(x))
const int max=1e5+5;
struct node
{
node *ch[2];
int siz,val,key;
node() {siz=val=key=0;}
inline void upd() {siz=ch[0]->siz+ch[1]->siz+1;}
}s[max];
typedef node* nod;
nod root;
nod null;
int cnt;
int n;
inline nod newnode(int k)
{
cnt++;
s[cnt].ch[0]=s[cnt].ch[1]=null;
s[cnt].key=rand(); s[cnt].siz=1; s[cnt].val=k;
return &s[cnt];
}
void split(nod o,nod &a,nod &b,int val)
{
if(o==null)
{
a=b=null;
return;
}
if(o->val<=val) a=o,split(o->ch[1],a->ch[1],b,val),a->upd();
else b=o,split(o->ch[0],a,b->ch[0],val),b->upd();
}
void merge(nod &o,nod a,nod b)
{
if(a==null||b==null)
{
o=a==null? b:a;
return;
}
if(a->key<=b->key) o=a,merge(o->ch[1],a->ch[1],b);
else o=b,merge(o->ch[0],a,b->ch[0]);
o->upd();
}
inline nod kth(nod o,int rank)
{
while(o->ch[0]->siz+1!=rank)
if(o->ch[0]->siz>=rank) o=o->ch[0];
else rank-=o->ch[0]->siz+1,o=o->ch[1];
return o;
}
inline void ins(int val)
{
nod x=null,y=null;
nod z=newnode(val);
split(root,x,y,val);
merge(x,x,z);
merge(root,x,y);
}
inline void del(int val)
{
nod x=null,y=null,z=null;
split(root,x,y,val);
split(x,x,z,val-1);
merge(z,z->ch[0],z->ch[1]);
merge(x,x,z);
merge(root,x,y);
}
inline int rnk(int val)
{
nod x=null,y=null;
split(root,x,y,val-1);
int t=x->siz+1;
merge(root,x,y);
return t;
}
inline nod pre(int val)
{
nod x=null,y=null;
split(root,x,y,val-1);
nod z=kth(x,x->siz);
merge(root,x,y);
return z;
}
inline nod nxt(int val)
{
nod x=null,y=null;
split(root,x,y,val);
nod z=kth(y,1);
merge(root,x,y);
return z;
}
int main()
{
std::ios::sync_with_stdio(false);
std::cin>>n;
null=new node(); null->ch[0]=null->ch[1]=null;
root=null;
ins(0x7fffffff);
int flag,x;
while(n--)
{
std::cin>>flag>>x;
if(flag==1) ins(x);
if(flag==2) del(x);
if(flag==3) std::cout<<rnk(x)<<"\n";
if(flag==4) std::cout<<kth(root,x)->val<<"\n";
if(flag==5) std::cout<<pre(x)->val<<"\n";
if(flag==6) std::cout<<nxt(x)->val<<"\n";
}
return ~~(0^_^0);
}
原文:https://www.cnblogs.com/olinr/p/10011918.html