对于一个经典问题《区间k小数问题》:给定一个长度为n的整数序列{a},求某个区间从小到大排序第k小数是多少?
这类问题的特点是:这是一类静态问题:询问过程中原序列是不变的
静态问题的常见解法:
数据结构 | 时间复杂度 | 空间复杂度 |
---|---|---|
归并树 | ||
划分树 | O(NlogN) | O(NlogN) |
树套树 | O(Nlog^2N)(支持修改) | O(NlogN) |
可持续化线段树(主席树) | O(NlogN) | O(NlogN) |
这里我们主要介绍主席树:
构造:
主席树维护值域(如果数值较大需要离散化,主席树维护的值域最常见是1e5,1e6容易爆空间)
在数值上建立线段树,维护每个数值区间中一共有多少数
时/空限制:1s / 64MB
输入样例:
7 3
1 5 2 6 3 7 4
2 5 3
4 4 1
1 7 3
输出样例:
5
6
3
求k小数的思路:
思考一个简单的问题:(在一棵线段树中)整体求第k小数(无区间限制):
(二分的思想)
首先计算[1,mid]区间中有多少个数,如果cnt >= k 则第k小数在左边,我们递归到左子树,否则在右边,递归到右子树
求区间[L,R]中第k小数
(前缀和 + 二分的思想)
用主席树维护值域区间中的个数cnt
主席树有一个特点,对于每一个版本的线段树,其结构都是相同的但是某些节点的信息是不同的
第i个版本的线段树维护了前i个数的信息,所以我们可以用前缀和的思想
同时递归第R个版本以及第L - 1个版本的线段树
首先求出两个版本的线段树维护的做区间的数量cnt1_l,cnt2_l
cnt1_l - cnt2_l 就是L~R个数,在[L,mid]中出现的个数
比较cnt1_l - cnt2_l 与k的大小
cnt1_l - cnt2_l >= k同时递归左子树
cnt1_l - cnt2_l < k同时递归右子树
空间复杂度计算:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 100010;
struct Node
{
int l,r;//记录的不是左右边界而是左右儿子
int cnt;
}tr[N * 4 + N * 17];//N * 4初始线段的的大小,N * 17,总共会修改N次每次修改logN个节点
int root[N],idx;//每个版本的根节点,节点编号
int a[N];
int alls[N],cnt;
int n,m;
int find(int x)
{
return lower_bound(alls + 1,alls + 1 + n,x) - alls;
}
int build(int l,int r)
{
int p = ++ idx;//建立新节点
if(l == r) return p;
int mid = l + r >> 1;
tr[p].l = build(l,mid);
tr[p].r = build(mid + 1,r);
return p;
}
int insert(int p,int l,int r,int x)
{
int q = ++ idx;//创建新节点
tr[q] = tr[p];//在上一个版本的基础上修改
if(l == r)
{
tr[q].cnt ++;
return q;
}
int mid = l + r >> 1;
if(x <= mid) tr[q].l = insert(tr[p].l,l,mid,x);//修改左子树
else tr[q].r = insert(tr[p].r,mid + 1,r,x);
tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt;//子节点更新父节点
return q;
}
int query(int p,int q,int l,int r,int k)
{
if(l == r) return r;
int cnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt;//看左区间的个数
int mid = l + r >> 1;
if(k <= cnt) return query(tr[p].l,tr[q].l,l,mid,k);
else return query(tr[p].r,tr[q].r,mid + 1,r,k - cnt);//二分
}
int main()
{
scanf("%d%d",&n,&m);
for(int i = 1;i <= n;i ++)
{
scanf("%d",&a[i]);
alls[++ cnt] = a[i];
}
sort(alls + 1,alls + 1 + n);
cnt = unique(alls + 1,alls + 1 + n) - alls - 1;
root[0] = build(1,cnt);
for(int i = 1;i <= n;i ++)
root[i] = insert(root[i - 1],1,cnt,find(a[i]));
while(m --)
{
int l,r,k;
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",alls[query(root[l - 1],root[r],1,cnt,k)]);//同时跳两个版本二分
}
return 0;
}
总结:
原文:https://www.cnblogs.com/jzdx/p/14652134.html