线段树:区间修改(更改为定值)、区间最值查询

  |  

摘要: 线段树,区间修改与区间查询,原理与实现,RMQ 问题

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


在文章 线段树和树状数组:单点修改、区间查询,区间求和问题 中,我们学习了带单点修改、区间查询这一套需求的线段树,解决了区间求和问题。

本文我们看一个类似的问题:区间最值问题。在增删改查方面的一套需求是区间修改、区间查询。我们首先阐述一下用线段树实现这套需求的算法原理和代码模板,然后解决 699. 掉落的方块

区间最值问题

RMQ 问题:给定一个数组,其中有 $N$ 个数字,对于一次查询,给定区间 [l ,r],返回在这个区间内的最值为多少,一共有 $Q$ 次查询。

算法:线段树

在文章 稀疏表(倍增优化DP):区间最值问题 中,我们用倍增优化 DP 解决了元素不更改的区间最值问题。

当数组中的值会有修改,需要边修改,变查询。此时稀疏表的效率就不高了,这与前缀和在求可变数组的区间和时效率不高的情况类似,解决方案也类似,用线段树。

对于线段树,我们之前处理过单点修改,区间查询的一套需求,线段树和树状数组:单点修改、区间查询,区间求和问题。本文我们处理区间修改、区间查询的一套需求。

线段树区间修改

有一个长为 n 的数据数组 nums,有一个表示 [0,n-1] 内各个区间的线段树,线段树的区间更新是这个意思:对于内部的某个区间 $[i, j], 0 \leq i, j \leq n-1$,将区间内的值改为 $v$。

这一类更新,最直观的是对区间内每个点做单点更新:首先更新点 $i$,point_update(i, v),然后更新 $i+1$, point_update(i + 1,v),直到更新 $j$, point_update(j, v)。这样的话,操作次数就爆了。

解法是对每个区间引入一个标记 lazy,称为懒标记。

懒标记

当需要把区间 [i, j] 内的值改为 v 时,没有直接进入线段树的对应区间的子区间取修改,而是给该区间做标记 v,当查询时若需要读取该区间数据,再利用标记算出应当返回的结果。

如果是基于链表实现,lazy 直接作为节点的一个单独子段,表示对区间 [nodeLeft, nodeRight] 标记 v。如果是基于数组实现,则用一个与线段树数组 st_vec 同长度的数组维护,即 lazy 数组。

例如有线段树,根表示 [0, 15] ,对 [5, 13] 上的所有值更新为 v,则仅仅给区间 [5, 13] 做标记 v,即在 [5], [6, 7], [8, 11], [12, 13] 这几个区间上进行标记。其中 [6, 7], [8, 11], [12, 13],这几个区间还有子区间,在做标记的时候不进入子区间。

1
2
3
4
5
[0                                                 15]
[0 7][8 15]
[0 3][4 7][8 11][12 15]
[0 1][2 3][4 5][6 7][8 9][10 11][12 13][14 15]
[0][1][2][3][4][5][6][7][8][9][10][11][12][13][14][15]

这种维护 lazy 标记的核心思想是把向下的修改先存起来,对于每个查询,在向上传递答案时再利用 lazy 标记修正传递的答案。

例如题目 1476. 子矩形查询,用到了这种保存修改,在查询时再计算答案的思想。

以上思路涉及到两个操作 push_uppush_down

push_up 表示向上的更新,push_down 维护的是 lazy 标记的值。

以下是的写法节点值表示区间和的两个操作的写法

1
2
3
4
5
6
void push_up(int node)
{
int left_son = node * 2;
int right_son = node * 2 + 1;
st_vec[node] = st_vec[left_son] + st_vec[right_son];
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void push_down(int node, int m) // m 为 node 表示区间的长度
{
if(lazy[node])
{
// 节点 node 有 lazy 标记
int left_son = node * 2;
int right_son = node * 2 + 1;

// 如果 lazy = v 的含义是区间内的值都加 v,则是如下写法
lazy[left_son] += lazy[node]; // 向左子节点传递
lazy[right_son] += lazy[node]; // 向右子节点传递
st_vec[left_son] += lazy[node] * (m - m / 2);
st_vec[right_son] += lazy[node] * (m / 2);
// 如果 lazy = v 的含义是区间内的值都改为 v,则将以上的 += 改为 =

lazy[node] = 0;
}
}

基于数组的线段树

以下代码为区间修改的线段树的数组写法的模板,节点值为 Max。Min, Sum 写法类似。

重点关注子区间结果上传 push_up 和懒标记下传 push_down,以及它们在 range_updaterange_query 中的发动时机。

代码(c++,模板)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
class SeqSegmentTree
{
public:
SeqSegmentTree()
{
st_vec = vector<int>();
lazy = vector<int>();
n = -1;
}

void build(const vector<int>& nums)
{
if(nums.empty()) return;
n = nums.size();
st_vec.resize(n * 4);
lazy.resize(n * 4);
_build(1, 0, n - 1, nums);
}

void range_update(int i, int j, int v)
{
// [i, j] 范围内改为 v
_range_update(1, 0, n - 1, i, j, v);
}

int range_query(int i, int j)
{
return _range_query(1, 0, n - 1, i, j);
}

private:
int _range_query(int node, int nodeLeft, int nodeRight, int start, int end)
{
if(nodeLeft == start && nodeRight == end)
return st_vec[node];
int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2;
int right_son = node * 2 + 1;
push_down(node);
if(end <= nodeMid)
return _range_query(left_son, nodeLeft, nodeMid, start, end);
else if(nodeMid < start)
return _range_query(right_son, nodeMid + 1, nodeRight, start, end);
else
{
return max(_range_query(left_son, nodeLeft, nodeMid, start, nodeMid),
_range_query(right_son, nodeMid + 1, nodeRight, nodeMid + 1, end));
}
}

void _range_update(int node, int nodeLeft, int nodeRight, int start, int end, int v)
{
if(nodeLeft == start && nodeRight == end)
{
lazy[node] = v;
st_vec[node] = v;
return;
}
if(nodeLeft == nodeRight) return;
int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2;
int right_son = node * 2 + 1;
push_down(node);
if(end <= nodeMid)
_range_update(left_son, nodeLeft, nodeMid, start, end, v);
else if(nodeMid < start)
_range_update(right_son, nodeMid + 1, nodeRight, start, end, v);
else
{
_range_update(left_son, nodeLeft, nodeMid, start, nodeMid, v);
_range_update(right_son, nodeMid + 1, nodeRight, nodeMid + 1, end, v);
}
push_up(node);
}

// 懒标记下传
void push_down(int node)
{
if(lazy[node])
{
// 节点 node 有 lazy 标记
int left_son = node * 2;
int right_son = node * 2 + 1;
// 如果 lazy = v 的含义是区间内的值都加 v,则是如下写法
lazy[left_son] = lazy[node]; // 向左子节点传递
lazy[right_son] = lazy[node]; // 向右子节点传递
st_vec[left_son] = lazy[node];
st_vec[right_son] = lazy[node];
// 如果 lazy = v 的含义是区间内的值都改为 v,则将以上的 += 改为 =
lazy[node] = 0;
}
}

// 子区间结果上传
void push_up(int node)
{
int left_son = node * 2;
int right_son = node * 2 + 1;
st_vec[node] = max(st_vec[left_son], st_vec[right_son]);
}

void _build(int node, int nodeLeft, int nodeRight, const vector<int>& nums)
{
if(nodeLeft == nodeRight)
{
st_vec[node] = nums[nodeLeft];
return;
}
int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2;
int right_son = node * 2 + 1;
_build(left_son, nodeLeft, nodeMid, nums);
_build(right_son, nodeMid + 1, nodeRight, nums);
st_vec[node] = max(st_vec[left_son], st_vec[right_son]);
}

vector<int> st_vec; // 节点值表示区间最大值
vector<int> lazy;
int n;
};

基于链表的线段树

以下代码为区间修改的线段树的链式写法的模板,节点值为 Max。Min, Sum 写法类似。

重点关注子区间结果上传 push_up 和懒标记下传 push_down,以及它们在 range_updaterange_query 中的发动时机。

代码(c++,模板)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
struct STNode
{
int nodeLeft, nodeRight;
int maxx;
STNode *left, *right;
int lazy;
STNode(int l, int r, int x, STNode* left=nullptr, STNode* right=nullptr)
:nodeLeft(l),nodeRight(r),maxx(x),left(left),right(right),lazy(0){}
~STNode(){}
};

class SegmentTree
{
public:
SegmentTree()
{
root = nullptr;
}

~SegmentTree()
{
if(root)
{
delete_sub_tree(root);
}
}

void range_update(int i, int j, int v)
{
_range_update(root, i, j, v);
}

int range_query(int i, int j)
{
return _range_query(root, i, j);
}

void build(const vector<int>&arr)
{
if(arr.empty()) return;
int n = arr.size();
root = _build(0, n - 1, arr);
}

void traverse()
{
cout << "==================" << endl;
_traverse(root);
cout << "==================" << endl;
}

private:
STNode *root;

void delete_sub_tree(STNode* node)
{
if(node -> left)
delete_sub_tree(node -> left);
if(node -> right)
delete_sub_tree(node -> right);
delete node;
node = nullptr;
}

void _traverse(STNode* node)
{
cout << "Range: [";
cout << node -> nodeLeft << " , " << node -> nodeRight << "]" << endl;
cout << "Max: " << node -> maxx << endl;
if(node -> nodeLeft != node -> nodeRight)
{
_traverse(node -> left);
_traverse(node -> right);
}
}

// 懒标记下传
void push_down(STNode* node)
{
if(node -> lazy)
{
node -> left -> lazy = node -> lazy;
node -> right -> lazy = node -> lazy;
node -> left -> maxx = node -> lazy;
node -> right -> maxx = node -> lazy;
node -> lazy = 0;
}
}

// 子区间结果上传
void push_up(STNode* node)
{
node -> maxx = max(node -> left -> maxx, node -> right -> maxx);
}

int _range_query(STNode* node, int start, int end)
{
int nodeLeft = node -> nodeLeft;
int nodeRight = node -> nodeRight;
if(nodeLeft == start && nodeRight == end)
return node -> maxx;
int nodeMid = (nodeLeft + nodeRight) / 2;
// 要根据子树结果计算当前节点结果时,懒标记下传
push_down(node);
if(end <= nodeMid)
return _range_query(node -> left, start, end);
else if(nodeMid < start)
return _range_query(node -> right, start, end);
else
{
return max(_range_query(node -> left, start, nodeMid),
_range_query(node -> right, nodeMid + 1, end));
}
}

void _range_update(STNode* node, int start, int end, int v)
{
int nodeLeft = node -> nodeLeft;
int nodeRight = node -> nodeRight;
if(nodeLeft == start && nodeRight == end)
{
node -> lazy = v;
node -> maxx = v;
return;
}
int nodeMid = (nodeLeft + nodeRight) / 2;
if(nodeLeft == nodeRight) return;
// 下传懒标记
push_down(node);
if(end <= nodeMid)
{
_range_update(node -> left, start, end, v);
}
else if(nodeMid < start)
{
_range_update(node -> right, start, end, v);
}
else
{
_range_update(node -> left, start, nodeMid, v);
_range_update(node -> right, nodeMid + 1, end, v);
}
push_up(node);
}

STNode* _build(int nodeLeft, int nodeRight, const vector<int>& arr)
{
if(nodeLeft == nodeRight)
return new STNode(nodeLeft, nodeRight, arr[nodeLeft]);
int nodeMid = (nodeLeft + nodeRight) / 2;
STNode *left_son = _build(nodeLeft, nodeMid, arr);
STNode *right_son = _build(nodeMid + 1, nodeRight, arr);
int maxx = max(left_son -> maxx, right_son -> maxx);
return new STNode(nodeLeft, nodeRight, maxx, left_son, right_son);
}
};

题目:699. 掉落的方块

在二维平面上的 x 轴上,放置着一些方块。

给你一个二维整数数组 positions ,其中 positions[i] = [lefti, sideLengthi] 表示:第 i 个方块边长为 sideLengthi ,其左侧边与 x 轴上坐标点 lefti 对齐。

每个方块都从一个比目前所有的落地方块更高的高度掉落而下。方块沿 y 轴负方向下落,直到着陆到 另一个正方形的顶边 或者是 x 轴上 。一个方块仅仅是擦过另一个方块的左侧边或右侧边不算着陆。一旦着陆,它就会固定在原地,无法移动。

在每个方块掉落后,你必须记录目前所有已经落稳的 方块堆叠的最高高度 。

返回一个整数数组 ans ,其中 ans[i] 表示在第 i 块方块掉落后堆叠的最高高度。

提示:

1
2
3
1 <= positions.length <= 1000
1 <= lefti <= 1e8
1 <= sideLengthi <= 1e6

示例 1:

输入:positions = [[1,2],[2,3],[6,1]]
输出:[2,5,5]
解释:
第 1 个方块掉落后,最高的堆叠由方块 1 组成,堆叠的最高高度为 2 。
第 2 个方块掉落后,最高的堆叠由方块 1 和 2 组成,堆叠的最高高度为 5 。
第 3 个方块掉落后,最高的堆叠仍然由方块 1 和 2 组成,堆叠的最高高度为 5 。
因此,返回 [2, 5, 5] 作为答案。

示例 2:
输入:positions = [[100,100],[200,100]]
输出:[100,100]
解释:
第 1 个方块掉落后,最高的堆叠由方块 1 组成,堆叠的最高高度为 100 。
第 2 个方块掉落后,最高的堆叠可以由方块 1 组成也可以由方块 2 组成,堆叠的最高高度为 100 。
因此,返回 [100, 100] 作为答案。
注意,方块 2 擦过方块 1 的右侧边,但不会算作在方块 1 上着陆。

算法:线段树

共有 $N$ 个方块落下,当第 $i$ 个方块落下时,下面已经堆叠了一些方块。

第 $i$ 个方块的左端点为 pos[i][0] 边长为 pos[i][1],因此区间为 [start, end] = [pos[i][0], pos[i][0] + pos[i][1] - 1],我们要知道的是在已经落稳的方块中,[start, end] 这个范围的最大值是多少,这是一步区间最值查询,结果为 maxx

然后第 $i$ 个方块落稳后,[start, end] 范围的高度更新为 maxx + pos[i][1]。更新后,查询 [0, n-1] 范围内的最值,写入 result[i] 中。

代码 (C++)

以下两份代码,一个是使用基于链表的线段树,一个是使用基于数组的线段树。区别仅在于 sttreeseqsttree

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution {
public:
vector<int> fallingSquares(vector<vector<int>>& positions) {
// 离散化
vector<int> x;
for(const vector<int>& pos: positions)
{
x.push_back(pos[0]);
x.push_back(pos[0] + pos[1] - 1);
}
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());

int n = x.size();
SegmentTree sttree;
vector<int> arr(n);
sttree.build(arr);
vector<int> result;
// sttree.traverse();
for(const vector<int>& pos: positions)
{
int start = _find(pos[0], x);
int end = _find(pos[0] + pos[1] - 1, x);
int v = pos[1];
int maxx = sttree.range_query(start, end);
sttree.range_update(start, end, maxx + v);
result.push_back(sttree.range_query(0, n - 1));
}
return result;
}

private:
int _find(int v, const vector<int>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution {
public:
vector<int> fallingSquares(vector<vector<int>>& positions) {
// 离散化
vector<int> x;
for(const vector<int>& pos: positions)
{
x.push_back(pos[0]);
x.push_back(pos[0] + pos[1] - 1);
}
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());

int n = x.size();
SeqSegmentTree seqsttree;
vector<int> arr(n);
seqsttree.build(arr);
vector<int> result;
for(const vector<int>& pos: positions)
{
int start = _find(pos[0], x);
int end = _find(pos[0] + pos[1] - 1, x);
int v = pos[1];
int maxx = seqsttree.range_query(start, end);
seqsttree.range_update(start, end, maxx + v);
result.push_back(seqsttree.range_query(0, n - 1));
}
return result;
}

private:
int _find(int v, const vector<int>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

Share