/*
 * Splay: array-backed splay tree holding a multiset of int values.
 *
 * Node 0 is a sentinel meaning "no node"; its ch/cnt/siz fields must stay 0
 * so that siz[0] == 0 and links to 0 are harmless.
 * Every node stores one distinct value val[x] with multiplicity cnt[x];
 * siz[x] is the total multiplicity stored in x's subtree.
 * Nodes are taken from a bump counter `top` and are not recycled, so at most
 * MAXN - 1 distinct insertions may happen between calls to Init().
 *
 * NOTE(review): GetValue/GetPrev/GetNext return the constant `INF` (or
 * `-INF`) as a "not found" sentinel. INF is not defined here; it is assumed
 * to be declared before this struct, as in the original code. Confirm.
 */
struct Splay {
    static const int MAXN = 200000 + 10;
    int top, root, ch[MAXN][2], pa[MAXN];
    int val[MAXN], cnt[MAXN], siz[MAXN];

    /* Recompute siz[x] from its children and its own multiplicity. */
    void _PushUp(int x) {
        siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + cnt[x];
    }

    /* Return true if x is the right child of its parent. */
    bool _Get(int x) {
        return x == ch[pa[x]][1];
    }

    /* Reset every field of node x to zero. */
    void _Clear(int x) {
        ch[x][0] = 0, ch[x][1] = 0, pa[x] = 0;
        val[x] = 0, siz[x] = 0, cnt[x] = 0;
    }

    /* Allocate a fresh node with value v, multiplicity c and parent p. */
    int _NewNode(int v, int c, int p) {
        int x = ++top;
        _Clear(x);  // the slot may hold stale links from before the last Init()
        val[x] = v, cnt[x] = c, pa[x] = p;
        _PushUp(x);
        return x;
    }

    /* Rotate x above its parent y, preserving the in-order sequence. */
    void _Rotate(int x) {
        int y = pa[x], z = pa[y], d = _Get(x);
        ch[y][d] = ch[x][d ^ 1];
        // Never write through the sentinel: pa[0] must stay 0.
        if(ch[x][d ^ 1])
            pa[ch[x][d ^ 1]] = y;
        ch[x][d ^ 1] = y, pa[y] = x, pa[x] = z;
        if(z)
            ch[z][y == ch[z][1]] = x;
        // y is now a child of x, so refresh y first; x then reads the new siz[y].
        _PushUp(y), _PushUp(x);
    }

    /* Splay node x to the root. Splaying the sentinel 0 is a no-op. */
    void _Splay(int x) {
        if(!x)
            return;
        for(int p = pa[x]; p; p = pa[x]) {
            if(pa[p])
                _Rotate((_Get(x) == _Get(p)) ? p : x);  // zig-zig vs zig-zag
            _Rotate(x);
        }
        root = x;
    }

    /* Reset to an empty tree. Node slots are cleared lazily by _NewNode. */
    void Init() {
        top = 0, root = 0;
        _Clear(0);
    }

    /* Add c copies of value v, then splay v's node to the root. */
    void Insert(int v, int c = 1) {
        if(!root) {
            root = _NewNode(v, c, 0);
            return;
        }
        int cur = root, p = 0;
        while(cur && val[cur] != v)
            p = cur, cur = ch[cur][val[cur] < v];
        if(cur) {
            cnt[cur] += c;
            _PushUp(cur);
        } else {
            cur = _NewNode(v, c, p);
            ch[p][val[p] < v] = cur;
        }
        if(p)
            _PushUp(p);
        _Splay(cur);  // rotations also refresh siz on the whole access path
    }

    /*
     * Return the rank of value v: 1 + (number of stored elements < v),
     * i.e. the 1-based position of the first element >= v.
     * If v is present, its node becomes the root; otherwise the last node
     * visited is splayed (keeps the amortized bound).
     */
    int GetRank(int v) {
        int cur = root, last = 0, res = 1;
        while(cur) {
            last = cur;
            if(val[cur] > v) {
                cur = ch[cur][0];
            } else if(val[cur] == v) {
                res += siz[ch[cur][0]];
                break;
            } else {
                res += siz[ch[cur][0]] + cnt[cur];
                cur = ch[cur][1];
            }
        }
        _Splay(last);
        return res;
    }

    /*
     * Return the value of the element with 1-based rank r (counting
     * multiplicity), or INF when r is out of range.
     */
    int GetValue(int r) {
        int cur = root, last = 0, res = INF;
        while(cur) {
            last = cur;
            if(siz[ch[cur][0]] >= r) {
                cur = ch[cur][0];
            } else if(siz[ch[cur][0]] + cnt[cur] >= r) {
                res = val[cur];
                break;
            } else {
                r -= siz[ch[cur][0]] + cnt[cur];
                cur = ch[cur][1];
            }
        }
        _Splay(last);
        return res;
    }

    /*
     * Splay the maximum node of the root's left subtree (the root's
     * in-order predecessor) and return it, or return 0 if there is none.
     */
    int _Prev() {
        int cur = ch[root][0];
        if(!cur)
            return 0;
        while(ch[cur][1])
            cur = ch[cur][1];
        _Splay(cur);
        return cur;
    }

    /* Return the largest stored value strictly less than v, or -INF if none. */
    int GetPrev(int v) {
        int cur = root, last = 0, res = -INF;
        while(cur) {
            last = cur;
            if(val[cur] < v) {
                res = val[cur];
                cur = ch[cur][1];
            } else {
                cur = ch[cur][0];
            }
        }
        _Splay(last);
        return res;
    }

    /*
     * Splay the minimum node of the root's right subtree (the root's
     * in-order successor) and return it, or return 0 if there is none.
     */
    int _Next() {
        int cur = ch[root][1];
        if(!cur)
            return 0;
        while(ch[cur][0])
            cur = ch[cur][0];
        _Splay(cur);
        return cur;
    }

    /* Return the smallest stored value strictly greater than v, or INF if none. */
    int GetNext(int v) {
        int cur = root, last = 0, res = INF;
        while(cur) {
            last = cur;
            if(val[cur] > v) {
                res = val[cur];
                cur = ch[cur][0];
            } else {
                cur = ch[cur][1];
            }
        }
        _Splay(last);
        return res;
    }

    /*
     * Remove c copies of value v. If fewer than or equal to c copies are
     * stored, the node is removed entirely. Does nothing if v is absent.
     */
    void Remove(int v, int c = 1) {
        if(!root)
            return;
        GetRank(v);  // splays v's node to the root when v is present
        if(val[root] != v)
            return;
        if(cnt[root] > c) {
            cnt[root] -= c;
            _PushUp(root);
            return;
        }
        int cur = root;
        if(ch[cur][0] && ch[cur][1]) {
            // After splaying cur's predecessor x to the root, cur is x's right
            // child with an empty left subtree; splice cur's right subtree onto x.
            int x = _Prev();
            ch[x][1] = ch[cur][1], pa[ch[cur][1]] = x;
        } else {
            // At most one child: it (or 0) becomes the new root.
            root = ch[cur][0] + ch[cur][1];
            pa[root] = 0;
        }
        _Clear(cur);
        _PushUp(root);
    }
} splay;
// Original article: https://www.cnblogs.com/purinliang/p/14321732.html