#include<cstdio>
#include<algorithm>
#include<iostream>
using namespace std;
// Largest 32-bit int (0x7fffffff). merge() splays with this key to pull the
// maximum element of a tree to its root.
int inf = 2147483647;
// One splay-tree node: a distinct key plus how many copies of it are stored.
struct node
{
    int num;      // number of copies of `val` held by this node
    int val;      // the key
    int size;     // total copies in this subtree: num plus both children's size
    node* ch[2];  // ch[0]: keys smaller than val, ch[1]: keys larger than val
    node (int v) :val(v)
    {
        size=1;
        num=1;
        ch[0]=ch[1]=NULL;
    }
    // Recompute the cached subtree size from num and the children.
    void sum()
    {
        size=num;
        if(ch[0]!=NULL)
            size+=ch[0]->size;
        if(ch[1]!=NULL)
            size+=ch[1]->size;
    }
    // Where does key v live relative to this node?
    // -1: this node, 0: left subtree (v < val), 1: right subtree (v > val).
    int cmp(int v)
    {
        if(val==v)
            return -1;
        return (val>v ? 0 : 1);
    }
    // Where does the k-th smallest copy (1-based, duplicates counted via num)
    // of this subtree live? Same encoding as cmp():
    // -1: this node, 0: left subtree, 1: right subtree.
    int cmpkth(int k)
    {
        int s=( ch[0]==NULL ? 0 : ch[0]->size );
        if(k<=s)
            return 0;
        if(k<=s+num)
            return -1;   // ranks s+1 .. s+num are this node's own copies
        return 1;
    }
};
// Global tree root (starts out empty). The routines in this file take the
// tree as an explicit parameter rather than touching this variable directly.
node *root = NULL;
// Single rotation at x. `base` is the side x moves down to:
// base==0 lifts the right child (left rotation), base==1 lifts the left child
// (right rotation). The sizes of both nodes involved are refreshed, and x is
// rebound to the node that took its place. x->ch[base^1] must not be NULL.
void rotato(node* &x,int base)
{
    int up=base^1;
    node* pivot=x->ch[up];
    node* moved=pivot->ch[base];
    pivot->ch[base]=x;
    x->ch[up]=moved;
    x->sum();
    pivot->sum();
    x=pivot;
}
// Splay key v to the root of the subtree x using zig-zig / zig-zag steps.
// If v is absent, the last node on v's search path ends up at the root
// instead. x is rebound to the new root and must not be NULL.
void splay(node* &x,int v)
{
    int dir=x->cmp(v);
    if(dir==-1||x->ch[dir]==NULL)
        return;                       // x already holds v, or the search ends here
    node* &child=x->ch[dir];
    int sub=child->cmp(v);
    if(sub==-1||child->ch[sub]==NULL)
    {
        rotato(x,dir^1);              // zig: the child is the target
        return;
    }
    splay(child->ch[sub],v);          // bring the target up to grandchild level
    if(dir==sub)
        rotato(x,dir^1);              // zig-zig: rotate the grandparent first
    else
        rotato(child,sub^1);          // zig-zag: rotate the parent first
    rotato(x,dir^1);
}
void splaykth(node *x,int k)
{
int d=x->cmpkth(k);
if(d!=-1)
{
if(d==1)
k-=x->ch[0]->size-x->num;
int d2=x->ch[d]->cmpkth(k);
if(d2!=-1)
{
int k2=(d2==1 ? k-x->ch[d]->ch[0]->size-x->ch[d]->num : k);
splaykth(x->ch[d]->ch[d2],k2);
if(d==d2)
rotato(x,d2^1),rotato(x,d^1);
else
rotato(x->ch[d],d2^1),rotato(x,d^1);
}
else
rotato(x,d^1);
}
return ;
}
// Predecessor query. Raises ans to the largest key in x's subtree that is
// strictly less than val. ans is an in/out parameter: the caller seeds it,
// and it is left unchanged when no key in the subtree beats it.
void pre(node* x,int val,int &ans)
{
    while(x!=NULL)
    {
        if(x->val<val)
        {
            if(x->val>ans)
                ans=x->val;
            x=x->ch[1];               // larger candidates (< val) may lie right
        }
        else
            x=x->ch[0];               // key >= val (including == val): smaller keys are left
    }
}
// Successor query, the mirror image of pre(). Lowers ans to the smallest key
// in x's subtree that is strictly greater than val. ans is an in/out
// parameter: seed it with a value no real answer can beat (e.g. inf). It is
// taken by reference; otherwise the answer would be discarded.
void nxt(node* x,int val,int &ans)
{
    while(x!=NULL)
    {
        if(x->val>val)
        {
            if(x->val<ans)
                ans=x->val;
            x=x->ch[0];               // smaller candidates (> val) may lie left
        }
        else
            x=x->ch[1];               // key <= val (including == val): larger keys are right
    }
}
// Splay val (or, if val is absent, the key where its search stops) to the root
// of x, and return the number of copies in the new root's left subtree. When
// val is present, this is the count of stored copies strictly smaller than val.
// x is taken by reference so the caller's root pointer follows the splay;
// holding the old root afterwards would lose the upper part of the tree.
// An empty tree yields 0.
int find(node* &x,int val)
{
    if(x==NULL)
        return 0;
    splay(x,val);
    return x->ch[0]==NULL ? 0 : x->ch[0]->size;
}
// Return the k-th smallest stored key (1-based, duplicates counted).
// x is taken by reference so the caller's root pointer follows the splay done
// by splaykth(); a by-value root would be left pointing below the new root.
// Precondition: the tree is non-empty and 1 <= k <= x->size.
int kth(node* &x,int k)
{
    splaykth(x,k);
    return x->val;
}
// Split the tree x around val. Afterwards x holds the keys <= val and the
// returned tree holds the keys > val. This relies on splay() having left
// val (or the key where its search stopped) at the root. An empty tree
// yields NULL.
node *spilt(node* &x,int val)
{
    if(x==NULL)
        return NULL;
    splay(x,val);
    node* top=x;
    node* left;
    node* right;
    if(top->val>val)
    {
        right=top;                    // root goes to the "> val" side
        left=top->ch[0];
        top->ch[0]=NULL;
    }
    else
    {
        left=top;                     // root goes to the "<= val" side
        right=top->ch[1];
        top->ch[1]=NULL;
    }
    top->sum();
    x=left;
    return right;
}
// Join two trees into t1 and clear t2. The caller must ensure that every key
// in t1 is smaller than every key in t2. Either tree (or both) may be empty.
void merge(node* &t1,node* &t2)
{
    if(t1==NULL)
    {
        // Nothing to attach to: t2 (possibly also empty) becomes the result.
        // This also avoids splaying a NULL tree when both are empty.
        t1=t2;
        t2=NULL;
        return;
    }
    splay(t1,inf);                    // bring t1's maximum to the root; it has no right child
    t1->ch[1]=t2;
    t2=NULL;
    t1->sum();
}
// Insert one copy of val into the tree rooted at x (x may be empty).
void insert(node* &x,int val)
{
    node* t2=spilt(x,val);            // x: keys <= val, t2: keys > val
    // If val is already stored, spilt() leaves its node at x's root. x is NULL
    // when the tree was empty or every key exceeds val.
    if(x!=NULL&&x->val==val)
    {
        x->num+=1;
        x->sum();
    }
    else
    {
        node* nw=new node(val);
        merge(x,nw);
    }
    merge(x,t2);
}
// Remove one copy of val from the tree rooted at x. Does nothing if val is
// absent. A node whose last copy is removed is deleted; insert() allocates
// nodes with new.
void erase(node* &x,int val)
{
    node* t2=spilt(x,val);            // x: keys <= val, t2: keys > val
    // When val is present, spilt() leaves its node at x's root with no right
    // child. The copy count must be decremented there, not on t2.
    if(x!=NULL&&x->val==val)
    {
        x->num-=1;
        if(x->num==0)
        {
            node* dead=x;
            x=x->ch[0];
            delete dead;
        }
        else
            x->sum();
    }
    if(x==NULL)
        x=t2;                         // nothing left on the "<= val" side
    else
        merge(x,t2);
}
// Entry point. No driver is wired up: the tree routines above are defined but
// never exercised here, so the program simply exits successfully.
int main()
{
    const int exit_code=0;
    return exit_code;
}
// Original: https://www.cnblogs.com/Lance1ot/p/8947372.html