线段树和树状数组:单点修改、区间和查询

  |  

摘要: 线段树和树状数组,单点修改与区间查询,原理与实现

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文通以 307. 区域和检索 - 数组可修改 作为模板题来看一下线段树和树状数组的原理与实现。

本文的树状数组和线段树是支持单点修改和区间查询这一套需求的。如果需要区间修改,参考文章 线段树:区间修改、区间查询,区间最值问题

模板题:307. 区域和检索 - 数组可修改

给你一个数组 nums ,请你完成两类查询。

  1. 其中一类查询要求 更新 数组 nums 下标对应的值
  2. 另一类查询要求返回数组 nums 中索引 left 和索引 right 之间( 包含 )的nums元素的 和 ,其中 left <= right

实现 NumArray 类:

  • NumArray(int[] nums) 用整数数组 nums 初始化对象
  • void update(int index, int val) 将 nums[index] 的值 更新 为 val

提示:

1
2
3
4
5
6
1 <= nums.length <= 3e4
-100 <= nums[i] <= 100
0 <= index < nums.length
-100 <= val <= 100
0 <= left <= right < nums.length
调用 update 和 sumRange 方法次数不大于 3e4

  • int sumRange(int left, int right) 返回数组 nums 中索引 left 和索引 right 之间( 包含 )的nums元素的 和 (即,nums[left] + nums[left + 1], …, nums[right])

示例 1:
输入:
[“NumArray”, “sumRange”, “update”, “sumRange”]
[[[1, 3, 5]], [0, 2], [1, 2], [0, 2]]
输出:
[null, 9, null, 8]
解释:
NumArray numArray = new NumArray([1, 3, 5]);
numArray.sumRange(0, 2); // 返回 1 + 3 + 5 = 9
numArray.update(1, 2); // nums = [1,2,5]
numArray.sumRange(0, 2); // 返回 1 + 2 + 5 = 8


算法1 : 线段树

基于链表

线段树是一个查询和修改复杂度都为$\log n$的数据结构。主要用于数组的单点修改、单点查询、区间修改,区间查询。

本题需要支持的操作是单点修改和区间查询。

  • 建树 build(int start, int end, vector<int> vals);
  • 单点更新 update(int index, int val);
  • 区间查询 query(int start, int end);

线段树是一种平衡二叉树(树高为$\log N$),的每个节点代表一个区间,保存区间的两个端点和区间的和这三个值。

线段树节点的定义:

1
2
3
4
5
6
7
8
9
10
11
struct STNode {
int start;
int end;
int sum;
STNode *left;
STNode *right;
STNode(){}
STNode(int start, int end, int sum, STNode *left=nullptr, STNode *right=nullptr)
:start(start),end(end),sum(sum),left(left),right(right){}
~STNode(){}
};

start 和 end 是区间的端点,sum 是区间的和(也可以是最大,最小,异或,模质数下乘法)。

叶子节点表示的就是数组的元素本身,sum 是数组元素值,start = end 为元素在数组中的下标。

一个非叶子节点表示区间 [left, right],它的两个子节点表示区间 [left, mid], [mid+1, right], mid = (left + right) / 2,且节点的 sum 值为左右两个子节点的 sum 值的和。

根节点表示整个数组的范围。

例如数组 nums = [2, 1, 5, 3, 4],建的线段树如下图所示:根节点的区间就是整个区间 [0-4], 值是区间所有元素的和 15,各个子节点表示各自的子区间以及子区间上的和,直到叶子节点表示数组元素本身。

<1> 由给定数组建树

与一般的平衡树差不多,递归建树,不同的是叶子节点的判断是根据 start == end。两个子节点建好后,再建当前节点,当前节点的 sum 值需要用到子节点的 sum 值。

1
2
3
4
5
6
7
8
9
STNode* _build(int start, int end, const vector<int>& vals)
{
if(start == end)
return new STNode(start, end, vals[start]);
int mid = start + (end - start) / 2;
STNode *left = _build(start, mid, vals);
STNode *right = _build(mid + 1, end, vals);
return new STNode(start, end, left -> sum + right -> sum, left, right);
}

<2> 单点更新

通过 index 与 mid = (left + right) / 2 的关系找到 index 在哪个子树:然后沿着子树链条往下走直到叶子,更新叶子的 sum 值。回溯之后更新当前节点的 sum 值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void _point_update(STNode *root, int index, int val)
{
if(root -> start == root -> end && root -> end == index)
{
root -> sum = val;
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(index <= mid) // 更新位置在左子树
_point_update(root -> left, index, val);
else
_point_update(root -> right, index, val);
root -> sum = root -> left -> sum + root -> right -> sum;
}

<3> 区间查询

1
2
3
4
5
6
7
8
9
10
11
12
int _range_query(STNode *root, int i, int j)
{
if(root -> start == i && root -> end == j)
return root -> sum;
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid) // 查询的区间在左子树
return _range_query(root -> left, i, j);
else if(i > mid) // 查询的区间在右子树
return _range_query(root -> right, i, j);
else
return _range_query(root -> left, i, mid) + _range_query(root -> right, mid + 1, j);
}

代码(c++,模板)

单点更新,区间查询的模板写法,没有增删。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
struct STNode {
int start;
int end;
int sum;
STNode *left;
STNode *right;
STNode(){}
STNode(int start, int end, int sum)
:start(start),end(end),sum(sum),left(nullptr),right(nullptr){}
STNode(int start, int end, int sum, STNode *left, STNode *right)
:start(start),end(end),sum(sum),left(left),right(right){}
};

class SegmentTree {
public:
SegmentTree(){
root = nullptr;
}

~SegmentTree()
{
if(root)
{
delete_sub_tree(root);
}
}

void build(int start, int end, const vector<int>& vals)
{
if(end >= start)
root = _build(start, end, vals);
}

void point_update(int index, int val)
{
_point_update(root, index, val);
}

int range_query(int i, int j)
{
return _range_query(root, i, j);
}

private:
STNode *root;

void delete_sub_tree(const STNode* node)
{
if(node)
{
if(node -> left)
delete_sub_tree(node -> left);
if(node -> right)
delete_sub_tree(node -> right);
}
delete node;
node = nullptr;
}

STNode* _build(int start, int end, const vector<int>& vals)
{
if(start == end)
return new STNode(start, end, vals[start]);
int mid = start + (end - start) / 2;
STNode *left = _build(start, mid, vals);
STNode *right = _build(mid + 1, end, vals);
return new STNode(start, end, left -> sum + right -> sum, left, right);
}

void _point_update(STNode *root, int index, int val)
{
if(root -> start == root -> end && root -> end == index)
{
root -> sum = val;
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(index <= mid) // 更新位置在左子树
_point_update(root -> left, index, val);
else
_point_update(root -> right, index, val);
root -> sum = root -> left -> sum + root -> right -> sum;
}

int _range_query(STNode *root, int i, int j)
{
if(root -> start == i && root -> end == j)
return root -> sum;
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid) // 查询的区间在左子树
return _range_query(root -> left, i, j);
else if(i > mid) // 查询的区间在右子树
return _range_query(root -> right, i, j);
else
return _range_query(root -> left, i, mid) + _range_query(root -> right, mid + 1, j);
}
};

class NumArray {
public:
NumArray(vector<int>& nums) {
int n = nums.size();
int start = 0, end = n - 1;
sttree = SegmentTree();
sttree.build(start, end, nums);
}

void update(int i, int val) {
sttree.point_update(i, val);
}

int sumRange(int i, int j) {
return sttree.range_query(i, j);
}

private:
SegmentTree sttree;
};

基于数组

数据数组 nums 的范围是 [0..n],则线段树的根节点表示的区间总是 [0, n - 1]

因为线段树是几乎被填满的二叉树,因此数组的实现更省空间。线段树的数据存在数组 st_vec 中。st_vec[i] 表示节点 i 的值,这里是和,也可以是最值等。

节点 i 的两个子节点分别为左子节点 i*2 右子节点 i*2+1。线段树数据从下标 1 开始,1 为根节点。

以 nums 建树时,st_vec 开 4*n 的空间是够的。

建树,查询,更新的接口如下:

1
2
3
void _build(int node, int nodeLeft, int nodeRight, const vector<int>& arr)
int _range_query(int node, int nodeLeft, int nodeRight, int start, int end)
void _point_update(int node, int nodeLeft, int nodeRight, int idx, int val)

数组形式的线段树,节点 node 表示的区间 [nodeLeft, nodeRight] 是在递归中作为参数带着的。而链式写法中,这个信息是保存在节点上的,递归中不带该参数。

代码(c++, 模板)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
class SeqSegmentTree
{
public:
SeqSegmentTree()
{
st_vec = vector<int>();
n = -1;
}

void build(const vector<int>& arr)
{
n = arr.size();
int start = 0, end = n - 1;
if(start <= end)
{
st_vec.resize(n * 4);
_build(1, 0, n - 1, arr);
}
}

int range_query(int start, int end)
{
// 这里 [0, n - 1] 跟着参数走
// 链式写法中保存在节点
return _range_query(1, 0, n - 1, start, end);
}

void point_update(int idx, int val)
{
_point_update(1, 0, n - 1, idx, val);
}

private:
vector<int> st_vec;
int n; // 保存根节点的范围 [0, n - 1]

void _point_update(int node, int nodeLeft, int nodeRight, int idx, int val)
{
// [nodeLeft, nodeRight] 与 idx 的关系
if(nodeLeft == nodeRight && nodeLeft == idx)
{
st_vec[node] = val;
return;
}

int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2, right_son = node * 2 + 1;
if(idx <= nodeMid)
_point_update(left_son, nodeLeft, nodeMid, idx, val);
else
_point_update(right_son, nodeMid + 1, nodeRight, idx, val);
st_vec[node] = st_vec[left_son] + st_vec[right_son];
}

void _build(int node, int nodeLeft, int nodeRight, const vector<int>& arr)
{
// node 表示 [nodeLeft, nodeRight]
if(nodeLeft == nodeRight)
{
st_vec[node] = arr[nodeLeft];
return;
}

int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2, right_son = node * 2 + 1;
_build(left_son, nodeLeft, nodeMid, arr);
_build(right_son, nodeMid + 1, nodeRight, arr);
st_vec[node] = st_vec[left_son] + st_vec[right_son];
}

int _range_query(int node, int nodeLeft, int nodeRight, int start, int end)
{
// [nodeLeft, nodeRight] 与 [start, end] 的关系
// 得出交集,返回和
if(nodeLeft == start && nodeRight == end)
return st_vec[node];

int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2, right_son = node * 2 + 1;

if(end <= nodeMid)
return _range_query(left_son, nodeLeft, nodeMid, start, end);
else if(start > nodeMid)
return _range_query(right_son, nodeMid + 1, nodeRight, start, end);
else // [nodeLeft, start, nodeMid, end, nodeRight]
return _range_query(left_son, nodeLeft, nodeMid, start, nodeMid)
+ _range_query(right_son, nodeMid + 1, nodeRight, nodeMid + 1, end);
}
};

class NumArray {
public:
NumArray(vector<int>& nums) {
seqsttree = SeqSegmentTree();
seqsttree.build(nums);
}

void update(int i, int val) {
seqsttree.point_update(i, val);
}

int sumRange(int i, int j) {
return seqsttree.range_query(i, j);
}

private:
SeqSegmentTree seqsttree;
};

算法2: 树状数组

树状数组只能维护前缀信息(本题是前缀和),区间和可以用两个前缀和的差得到,所以区间和的问题用树状数组可做。

例如一个长度为 8 的数组 a ,用树状数组 b 表示的前缀和结构如下图:

树状数组中的值与原数组值的关系如下:

b[1] = a[1]
b[2] = a[1] + a[2]
b[3] = a[3]
b[4] = a[1] + a[2] + a[3] + a[4]
b[5] = a[5]
b[6] = a[5] + a[6]
b[7] = a[7]
b[8] = a[1] + a[2] + a[3] + a[4] + a[5] + a[6] + a[7] + a[8]

b[index] 更新(增加delta)后,会有比 index 大的某些位置也会受影响,例如更新 b[5] 之后,会依次影响到 b[6]b[8],参考图里画的树。下一个受影响的的位置是 index 加上它最低位的 1 表示的数。最低位的1表示的数,实现方法是 x & (-x)

1
2
3
4
int _lowbit(int x)
{
return x & (-x);
}

<1> 单点更新

index 上的值增加 delta 后,依次把后续受影响的位置也加上 delta。

1
2
3
4
5
6
7
8
9
void update(int index, int delta)
{
int n = sums.size();
while(index < n)
{
sums[index] += delta;
index += _lowbit(index);
}
}

<2> 单点查询

查询前缀和。例如:

b[4]=a[1]+a[2]+a[3]+a[4];
b[5]=a[5];
可以推出: query(5) = b[4]+b[5];
序号写为二进制: query(101) = b[(100)] + b[(101)];
第一次 101,减去最低位的 1 就是 100;

也就是单点更新的逆操作:

1
2
3
4
5
6
7
8
9
10
int query(int i)
{
int sum = 0;
while(i > 0)
{
sum += sums[i];
i -= _lowbit(i);
}
return sum;
}

区间查询可以通过两次单点查询得到:

1
2
// 求 [i..j] 的区间和
query(j + 1) - query(i);

代码(c++,模板)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class BIT {
public:
BIT():sums(1, 0){}
BIT(int n):sums(n + 1, 0){}

void update(int index, int delta)
{
int n = sums.size();
while(index < n)
{
sums[index] += delta;
index += _lowbit(index);
}
}

int query(int i)
{
int sum = 0;
while(i > 0)
{
sum += sums[i];
i -= _lowbit(i);
}
return sum;
}

private:
vector<int> sums;

int _lowbit(int x)
{
return x & (-x);
}
};

class NumArray_2 {
public:
// vec初始化完后 nums.size() 就变为0了
NumArray_2(vector<int> nums):vec(move(nums)),bit(vec.size()) {
int n = vec.size();
for(int i = 0; i < n; ++i)
bit.update(i + 1, vec[i]);
}

void update(int i, int val) {
bit.update(i + 1, val - vec[i]);
vec[i] = val;
}

int sumRange(int i, int j) {
return bit.query(j + 1) - bit.query(i);
}

private:
vector<int> vec;
BIT bit;
};

线段树和树状数组的区别和联系

具体区别和联系如下:

  1. 时间复杂度相同, 但是树状数组的常数优于线段树。

  2. 树状数组的作用被线段树完全涵盖, 凡是可以使用树状数组解决的问题, 使用线段树一定可以解决, 但是线段树能够解决的问题树状数组未必能够解决。

  3. 树状数组的代码量比线段树小很多。

  4. 线段树和树状数组都可以推广到高维。


Share