线段树(Segment Tree)也叫区间树,其本质上是一种二分搜索树,不同点在于线段树中每个节点不再是存放单纯的元素,而是存放了一个可以表示区间的值,通常是该区间合并后的值。并且每个区间会被平均分为2个子区间,作为它的左右子节点。比如说根节点存放了区间 [1,10]
,那么就会被分为区间 [1,5]
作为左子节点,区间 [6,10]
作为右子节点。
例如,我们可以将这样一个数组所表示的区间构造成线段树:
并且指定区间合并规则为区间内的元素求和,那么构造出来的线段树表示如下:
关于线段树的一个经典问题就是:区间染色。假设有一面墙,长度为 n,每次选择一段儿墙进行染色。在 m 次操作后,我们可以在 [i, j]
区间内看见多少中颜色?
对于这个问题,我们可以使用一个数组来实现:
对于染色操作(更新区间)我们可以遍历数组找到目标区间进行染色,时间复杂度是 $O(n)$。对于查询操作(查询区间)也是遍历数组即可,同样时间复杂度为 $O(n)$。显然用线性结构来解决这类问题的时间复杂度要更高一些,此时线段树就派上用场了,因为树形结构的时间复杂度通常在 $O(logn)$。
除此之外,线段树的另一个经典问题就是:区间查询。查询一个区间 [i, j]
的最大值和最小值,或者区间数字之和。例如,在实际业务中很常见的基于区间的统计查询:2017年注册用户中消费最高的用户?消费最少的用户?学习时间最长的用户?某个太空区间中天体总量?
对于静态区间数据(区间内的数据不会发生变化)来说,是比较好解决的,但以上所提到的问题都是动态的区间数据(区间内的数据在不断的变化),此时线段树就是一个比较好的选择。
通过以上的介绍,我们能总结出线段树的两个核心操作:
[i, j]
的最大值、最小值,或者区间数字之和线段树虽然不像堆那样是一棵完全二叉树,但线段树由于其特性满足平衡二叉树(左右子树高度相差不超过1),所以依然可以使用数组进行表示。我们可以将其看做是一颗满二叉树,空节点就当做叶子节点即可。如下示例:
既然可以用数组来表示一棵线段树,那么如果区间有 n 个元素,此时应该创建多大容量的数组来构建一颗线段树呢?对于这个问题,我们先来看如何求一棵满二叉树的节点:假设这棵树有 h 层,那么这棵树就一共有 $2^h-1$ 个节点(大约是 $2^h$)。对于最后一层($h - 1$ 层)来说,就有 $2^{(h-1)}$ 个节点。因此,最后一层的节点数大致等于前面所有层节点之和。
了解了如何求满二叉树的节点数量后,回到之前的问题,如果区间有 n 个元素,此时应该开多大空间的数组?我们可以分成两种情况:
通常来说,我们的线段树不考虑添加元素,即区间固定(区间内的数据可以是不固定的),那么使用 $4n$ 的静态空间即可。这也是普遍构造线段树时,使用的一个通用值。除非对内存有严格要求,否则一般开辟 $4n$ 的数组空间即可。而且对于内存有要求的情况下,一般也不会采用数组来表示,此时链式结会是更优的选择。
接下来,我们就实现一下线段树的基础结构代码:
package tree;
/**
* 线段树 - 基于数组的表示实现
*
* @author 01
* @date 2021-01-27
**/
public class SegmentTree<E> {
/**
* 保存原始数组,即需要被构造成线段树的区间
*/
private E[] data;
/**
* 线段树的数组表示
*/
private E[] tree;
public SegmentTree(E[] arr) {
this.data = (E[]) new Object[arr.length];
System.arraycopy(arr, 0, this.data, 0, arr.length);
// 开辟 4n 的数组空间用于构造线段树
this.tree = (E[]) new Object[4 * arr.length];
}
public int getSize() {
return data.length;
}
public E get(int index) {
if (index < 0 || index >= data.length) {
throw new IllegalArgumentException("Index is illegal");
}
return data[index];
}
/**
* 返回完全二叉树的数组表示中,一个索引所表示的元素的左子节点的索引
*/
private int leftChild(int index) {
return 2 * index + 1;
}
/**
* 返回完全二叉树的数组表示中,一个索引所表示的元素的右子节点的索引
*/
private int rightChild(int index) {
return 2 * index + 2;
}
}
在本小节中,我们来根据之前实现的基础代码,完成创建线段树逻辑的编写。需要说明一下的是,在本例中,线段树每个节点所存储的元素是区间合并后的值。具体的实现代码如下:
/**
* 用户自定义的区间合并逻辑
*/
private final Merger<E> merger;
public SegmentTree(E[] arr, Merger<E> merger) {
this.merger = merger;
this.data = (E[]) new Object[arr.length];
System.arraycopy(arr, 0, this.data, 0, arr.length);
// 开辟 4n 的数组空间用于构建线段树
this.tree = (E[]) new Object[4 * arr.length];
// 构建线段树,传入根节点索引,以及区间的左右端点
buildSegmentTree(0, 0, data.length - 1);
}
/**
* 在treeIndex的位置创建表示区间[left...right]的线段树
*/
private void buildSegmentTree(int treeIndex, int left, int right) {
// 区间中只有一个元素,代表递归到底了
if (left == right) {
tree[treeIndex] = data[left];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 计算中间点,需要避免整型溢出
int mid = left + (right - left) / 2;
// 构建左子树
buildSegmentTree(leftTreeIndex, left, mid);
// 构建右子树
buildSegmentTree(rightTreeIndex, mid + 1, right);
// 对于两个区间的合并规则是与业务相关的,所以要调用用户自定义的逻辑来完成
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
/**
* 遍历打印树中节点中值信息。
*
* @return String
*/
@Override
public String toString() {
StringBuilder res = new StringBuilder();
res.append(‘[‘);
for (int i = 0; i < tree.length; i++) {
if (tree[i] != null) {
res.append(tree[i]);
} else {
res.append("null");
}
if (i != tree.length - 1) {
res.append(", ");
}
}
res.append(‘]‘);
return res.toString();
}
用户传入的 Merger
是一个接口,其定义如下:
package tree;
/**
* 合并器接口
*
* @author 01
* @date 2021-01-27
**/
public interface Merger<E> {
/**
* 用户自定义的区间合并逻辑
*
* @param a 区间a
* @param b 区间b
* @return 合并后的结果
*/
E merge(E a, E b);
}
最后,我们来编写一个简单的测试用例进行一下测试:
package tree;
/**
* 测试SegmentTree
*
* @author 01
*/
public class SegmentTreeTests {
public static void main(String[] args) {
Integer[] nums = {-2, 0, 3, -5, 2, -1};
SegmentTree<Integer> segTree = new SegmentTree<>(
nums, Integer::sum // 对两个区间中的值进行求和
);
System.out.println(segTree);
}
}
输出结果如下:
[-3, 1, -4, -2, 3, -3, -1, -2, 0, null, null, -5, 2, null, null, null, null, null, null, null, null, null, null, null]
-3
,因为对整个数组的求和结果就是 -3
。左子节点为 1
,因为 -2 + 0 + 3 = 1
。右子节点为 -4
,同理,因为 -5 + 2 + -1 = -4
,其余以此类推。结果符合预期,证明我们实现的线段树没有问题。例如,我们要对如下这棵线段树查询 [2, 5]
这个区间:
由于我们之前传入的 Merger
实现的是求和逻辑,那么这相当于查询2 ~ 5区间所有元素的和。从根节点开始往下,我们知道分割位置,左节点查询 [2, 3]
,右节点查询 [4, 5]
,找到两个节点之后合并就可以了。
具体的实现代码如下:
/**
* 查询区间[queryLeft, queryRight]的值,如[2, 5]
*/
public E query(int queryLeft, int queryRight) {
if (queryLeft < 0 || queryLeft >= data.length ||
queryRight < 0 || queryRight >= data.length ||
queryLeft > queryRight) {
throw new IllegalArgumentException("Index is illegal");
}
return query(0, 0,
data.length - 1, queryLeft, queryRight);
}
/**
* 在以treeIndex为根的线段树中[left...right]的范围里,搜索区间[queryLeft...queryRight]的值
*/
private E query(int treeIndex, int left, int right,
int queryLeft, int queryRight) {
// 找到了目标区间
if (left == queryLeft && right == queryRight) {
return tree[treeIndex];
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 计算中间点,需要避免整型溢出
int mid = left + (right - left) / 2;
if (queryLeft >= mid + 1) {
// 目标区间不在左子树中,查找右子树
return query(rightTreeIndex, mid + 1, right, queryLeft, queryRight);
} else if (queryRight <= mid) {
// 目标区间不在右子树中,查找左子树
return query(leftTreeIndex, left, mid, queryLeft, queryRight);
}
// 目标区间一部分在右子树中,一部分在左子树中,则两个子树都需要找
E leftResult = query(leftTreeIndex, left, mid, queryLeft, mid);
E rightResult = query(rightTreeIndex, mid + 1, right, mid + 1, queryRight);
// 找到目标区间的值,将其合并后返回
return merger.merge(leftResult, rightResult);
}
进行一个简单的测试:
public static void main(String[] args) {
Integer[] nums = {-2, 0, 3, -5, 2, -1};
SegmentTree<Integer> segTree = new SegmentTree<>(
nums, Integer::sum // 对两个区间中的值进行求和
);
System.out.println(segTree.query(0,2));
System.out.println(segTree.query(2,5));
System.out.println(segTree.query(0,5));
}
输出结果如下:
1
-1
-3
我们使用线段树来解决区间相关的问题,主要是针对区间内的数据是动态变化的情况,如果是静态区间一般不需要用到线段树。所以在本小节,我们就来实现线段树中的更新操作。
实际上线段树中的更新操作,本质上是在二分查找。因为根据线段树的特性,待更新的目标节点肯定是一个叶子节点,我们只需要找到这个叶子节点并进行更新即可。我们查找待更新节点的依据是数组的索引,而数组的索引是从 0 ~ n 有序的,所以在一个有序的区间中查找某个特定的值,妥妥的就是二分查找了。
知道了我们在更新线段树中某个节点时,要找的这个待更新节点是一个叶子节点,并且找到这个叶子节点的过程本质上是一个二分查找,那么这个思路就很清晰了。
首先,将找到叶子节点的条件作为递归的退出条件。然后计算中间点,并将线段树数组划分为 [left...mid]
和 [mid+1...right]
两个区间。接着判断要找的数组索引落在哪个区间,就继续往哪个区间递归查找。最后,将区间的值进行合并。如此一来,就完成了目标节点的更新操作。
具体的实现代码如下:
/**
* 将index位置的值,更新为e
*/
public void set(int index, E e) {
if (index < 0 || index >= data.length) {
throw new IllegalArgumentException("Index is illegal");
}
data[index] = e;
set(0, 0, data.length - 1, index, e);
}
/**
* 在以treeIndex为根的线段树中更新index的值为e
*/
private void set(int treeIndex, int left, int right, int index, E e) {
// 找到了叶子节点
if (left == right) {
// 进行更新
tree[treeIndex] = e;
return;
}
int mid = left + (right - left) / 2;
// 将线段树数组划分为[left...mid]和[mid+1...right]两个区间
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if (index >= mid + 1) {
// index在右子树
set(rightTreeIndex, mid + 1, right, index, e);
} else {
// index在左子树
set(leftTreeIndex, left, mid, index, e);
}
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
在本文的最后,我们来使用自己实现的线段树解决一个Leetcode上的307号问题:
该问题的主要需求是更新数组下标对应的值,以及查询数组中某个区间内的元素总和。像这种对区间内数据有更新需求的,会使得区间内数据动态变化的,就很适合使用线段树来解决。具体的实现代码如下:
package tree.solution;
import tree.SegmentTree;
/**
* Leetcode 307. Range Sum Query - Mutable
* https://leetcode.com/problems/range-sum-query-mutable/description/
*/
class NumArray {
private SegmentTree<Integer> segTree;
public NumArray(int[] nums) {
if (nums.length != 0) {
Integer[] data = new Integer[nums.length];
for (int i = 0; i < nums.length; i++) {
data[i] = nums[i];
}
segTree = new SegmentTree<>(data, Integer::sum);
}
}
public void update(int i, int val) {
if (segTree == null) {
throw new IllegalArgumentException("Error");
}
segTree.set(i, val);
}
public int sumRange(int i, int j) {
if (segTree == null) {
throw new IllegalArgumentException("Error");
}
return segTree.query(i, j);
}
}
原文:https://blog.51cto.com/zero01/2608654