平衡树白痴,学了忘忘了学,平均比人晚两年
主要是双旋操作。
上图所说的
void rotate(int x){
int y=tr[x].fa,z=tr[y].fa,k=(tr[y].ch[1]==x);
tr[z].ch[(tr[z].ch[1]==y)]=x; tr[x].fa=z;
tr[y].ch[k]=tr[x].ch[k^1]; tr[tr[x].ch[k^1]].fa=y;
tr[x].ch[k^1]=y; tr[y].fa=x;
update(y); update(x);
return;
}
void splay(int x,int goal){
// cout<<x<<" "<<goal<<endl;
while(tr[x].fa!=goal){
int y=tr[x].fa,z=tr[y].fa;
if(z!=goal)
((tr[z].ch[0]==y)^(tr[y].ch[0]==x))?rotate(x):rotate(y);
rotate(x);
}
if(goal==0) rt=x;
return;
}
接下来:
insert:按二叉检索树找找到那个位置然后更新,再splay保持平衡。
delete:把x的前继splay到根,后继splay到右儿子,更新右儿子的左儿子,再splay保持平衡。
find:找到val=x的节点并把它splay到根。
rk:find(x) 然后左节点的点的个数,注意因为之前加入了 -inf 所以不用+1
kth:之前加入了-inf,所以查找的是第 x+1 个。特判一下不可能找到的情况其他往下找就是了。
前继后继:find(x) 之后,前继:在左儿子里找永远右的;后继:在右儿子里找永远左的。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10;
int inf=1e9;
int n,tot,rt;
struct node{
int val,fa,sz,cnt,ch[2];
}tr[N];
void update(int x){
tr[x].sz=tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+tr[x].cnt; return;
}
void newnode(int &nw,int fa,int x){
nw=++tot;
tr[fa].ch[x>tr[fa].val]=nw; tr[nw].fa=fa;
tr[nw].val=x; tr[nw].sz=tr[nw].cnt=1;
tr[nw].ch[0]=tr[nw].ch[1]=0;
return;
}
void rotate(int x){
int y=tr[x].fa,z=tr[y].fa,k=(tr[y].ch[1]==x);
tr[z].ch[(tr[z].ch[1]==y)]=x; tr[x].fa=z;
tr[y].ch[k]=tr[x].ch[k^1]; tr[tr[x].ch[k^1]].fa=y;
tr[x].ch[k^1]=y; tr[y].fa=x;
update(y); update(x);
return;
}
void splay(int x,int goal){
// cout<<x<<" "<<goal<<endl;
while(tr[x].fa!=goal){
int y=tr[x].fa,z=tr[y].fa;
if(z!=goal)
((tr[z].ch[0]==y)^(tr[y].ch[0]==x))?rotate(x):rotate(y);
rotate(x);
}
if(goal==0) rt=x;
return;
}
void find(int x){
int nw=rt;
while(tr[nw].ch[x>tr[nw].val]&&tr[nw].val!=x)
nw=tr[nw].ch[x>tr[nw].val];
splay(nw,0);
return;
}
int nxt(int x,int tp){
find(x);
// cout<<"*";
int nw=rt;
if(tr[nw].val<x&&!tp) return nw;
if(tr[nw].val>x&&tp) return nw;
nw=tr[nw].ch[tp];
while(tr[nw].ch[tp^1]) nw=tr[nw].ch[tp^1];
return nw;
}
void ins(int x){
int nw=rt,fa=0;
while(nw&&tr[nw].val!=x)
fa=nw,nw=tr[nw].ch[x>tr[nw].val];
if(nw) tr[nw].cnt++,tr[nw].sz++;
else newnode(nw,fa,x);
splay(nw,0);
}
void del(int x){
int las=nxt(x,0),nex=nxt(x,1);
splay(las,0); splay(nex,las);
int nw=tr[nex].ch[0];
if(tr[nw].cnt>1) tr[nw].cnt--,tr[nw].sz--,splay(nw,0);
else tr[nex].ch[0]=0;
return;
}
int rk(int x){
find(x); return tr[tr[rt].ch[0]].sz;
}
int kth(int x){
if(tr[rt].sz<x) return 0;
int nw=rt;
while(x){
int l=tr[nw].ch[0],r=tr[nw].ch[1];
if(x<=tr[l].sz){ nw=l; continue; }
x-=tr[l].sz;
if(x<=tr[nw].cnt) return tr[nw].val;
x-=tr[nw].cnt;
nw=r;
}
}
int main(){
// freopen("ex.in","r",stdin);
// freopen("ex.out","w",stdout);
scanf("%d",&n);
ins(-inf); ins(inf);
for(int i=1,opt,x;i<=n;i++){
scanf("%d%d",&opt,&x);
if(opt==1) ins(x);
else if(opt==2) del(x);
else if(opt==3) printf("%d\n",rk(x));
else if(opt==4) printf("%d\n",kth(x+1));
else if(opt==5) printf("%d\n",tr[nxt(x,0)].val);
else if(opt==6) printf("%d\n",tr[nxt(x,1)].val);
}
return 0;
}
就这?就这?就这?就这破玩意折腾了我两年??
原文:https://www.cnblogs.com/zdsrs060330/p/14502082.html