线段树维护区间最值RMQ

  |  
  • 线段树的区间修改模板
    • 数组写法
    • 链式写法
  • 线段树的单点更新和区间查询参考 : 线段树,树状数组

问题

问题:给你一个数组 ,其中有 N 个数字,现在给你一次查询,给你区间[l ,r],问你在这个区间内的最值为多少

数组长度为 N,查询次数为 Q

主流解法


1、搜索,$O(n)$ 预处理 $O(qn)$ 在线查询。

2、稀疏表(ST),$O(n\log n)$ 预处理 $O(q)$ 在线查询。

3、线段树/树状数组,$O(n)$ 预处理 $O(q\log n)$ 在线查询。

4、RMQ标准算法:先规约成LCA,再规约成约束RMQ,$O(n)$ 预处理 $O(q)$ 在线查询。


这里关注线段树这种做法

树状数组的做法参考 树状数组维护区间最值RMQ

线段树

适用于数组中的数据会变化的情况

参考题目:699. 掉落的方块

力扣699-平衡树区间模块,线段树区间修改,RMQ

当数组中的值会有修改,需要边修改,变查询。此时稀疏表的效率就不高了,这与前缀和在求可变数组的区间和时效率不高的情况类似,解决方案也类似,用线段树。

本题遇到的就是这种情况。

线段树区间修改

关于线段树的单点更新和区间查询参考 : 力扣307-线段树,树状数组

有一个长为 n 的数据数组 nums,有一个表示 [0,n-1] 内各个区间的线段树,线段树的区间更新是这个意思:对于内部的某个区间 $[i, j], 0 \leq i, j \leq n-1$,
将区间内的值改为 v。

这一类更新,最直观的是对区间内每个点做单点更新:首先更新点 i,point_update(i, v),然后更新 i+1, point_update(i + 1,v),…,直到更新 j, point_update(j, v)。这样的话,操作次数就爆了。

解法是对每个区间引入一个标记 lazy,称为懒标记。
当需要把区间 [i, j] 内的值改为 v 时,没有直接进入线段树的对应区间的子区间取修改,而是给该区间做标记 v,当查询时若需要读取该区间数据,再利用标记算出应当返回的结果。
如果是链式写法,lazy 直接作为节点的一个单独子段,表示对区间 [nodeLeft, nodeRight] 标记 v。如果是数组写法,则用一个与线段树数组 st_vec 同长度的数组维护,即 lazy 数组。

例如有线段树,根表示 [0, 15] ,对 [5, 13] 上的所有值更新为 v,则仅仅给区间[5, 13]做标记 v,即在 [5], [6, 7], [8, 11], [12, 13] 这几个区间上进行标记。其中 [6, 7], [8, 11], [12, 13],这几个区间还有子区间,在做标记的时候不进入子区间。

1
2
3
4
5
[0                                                 15]
[0 7][8 15]
[0 3][4 7][8 11][12 15]
[0 1][2 3][4 5][6 7][8 9][10 11][12 13][14 15]
[0][1][2][3][4][5][6][7][8][9][10][11][12][13][14][15]

这种维护 lazy 标记的核心思想是把向下的修改先存起来,对于每个查询,在向上传递答案是再利用 lazy 标记修正传递的答案。
例如题目 1476. 子矩形查询,用到了这种保存修改,在查询时再计算答案的思想。

以上思路涉及到两个操作 push_uppush_down
push_up 表示向上的更新,push_down 维护的是 lazy 标记的值。
以下是的写法节点值表示区间和的两个操作的写法

1
2
3
4
5
6
void push_up(int node)
{
int left_son = node * 2;
int right_son = node * 2 + 1;
st_vec[node] = st_vec[left_son] + st_vec[right_son];
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void push_down(int node, int m) // m 为 node 表示区间的长度
{
if(lazy[node])
{
// 节点 node 有 lazy 标记
int left_son = node * 2;
int right_son = node * 2 + 1;

// 如果 lazy = v 的含义是区间内的值都加 v,则是如下写法
lazy[left_son] += lazy[node]; // 向左子节点传递
lazy[right_son] += lazy[node]; // 向右子节点传递
st_vec[left_son] += lazy[node] * (m - m / 2);
st_vec[right_son] += lazy[node] * (m / 2);
// 如果 lazy = v 的含义是区间内的值都改为 v,则将以上的 += 改为 =

lazy[node] = 0;
}
}

线段树RMQ 数组写法

以下代码为区间修改的线段树的数组写法的模板,节点值为 Max。Min, Sum 写法类似。

重点关注子区间结果上传 push_up 和懒标记下传 push_dowm,以及它们在 range_updaterange_query 中的发动时机。

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class SeqSegmentTree
{
public:
SeqSegmentTree()
{
st_vec = vector<int>();
lazy = vector<int>();
n = -1;
}

void build(const vector<int>& nums)
{
if(nums.empty()) return;
n = nums.size();
st_vec.resize(n * 4);
lazy.resize(n * 4);
_build(1, 0, n - 1, nums);
}

void range_update(int i, int j, int v)
{
// [i, j] 范围内改为 v
_range_update(1, 0, n - 1, i, j, v);
}

int range_query(int i, int j)
{
return _range_query(1, 0, n - 1, i, j);
}

private:
int _range_query(int node, int nodeLeft, int nodeRight, int start, int end)
{
if(nodeLeft == start && nodeRight == end)
return st_vec[node];
int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2;
int right_son = node * 2 + 1;
push_down(node);
if(end <= nodeMid)
return _range_query(left_son, nodeLeft, nodeMid, start, end);
else if(nodeMid < start)
return _range_query(right_son, nodeMid + 1, nodeRight, start, end);
else
{
return max(_range_query(left_son, nodeLeft, nodeMid, start, nodeMid),
_range_query(right_son, nodeMid + 1, nodeRight, nodeMid + 1, end));
}
}

void _range_update(int node, int nodeLeft, int nodeRight, int start, int end, int v)
{
if(nodeLeft == start && nodeRight == end)
{
lazy[node] = v;
st_vec[node] = v;
return;
}
if(nodeLeft == nodeRight) return;
int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2;
int right_son = node * 2 + 1;
push_down(node);
if(end <= nodeMid)
_range_update(left_son, nodeLeft, nodeMid, start, end, v);
else if(nodeMid < start)
_range_update(right_son, nodeMid + 1, nodeRight, start, end, v);
else
{
_range_update(left_son, nodeLeft, nodeMid, start, nodeMid, v);
_range_update(right_son, nodeMid + 1, nodeRight, nodeMid + 1, end, v);
}
push_up(node);
}

// 懒标记下传
void push_down(int node)
{
if(lazy[node])
{
// 节点 node 有 lazy 标记
int left_son = node * 2;
int right_son = node * 2 + 1;
// 如果 lazy = v 的含义是区间内的值都加 v,则是如下写法
lazy[left_son] = lazy[node]; // 向左子节点传递
lazy[right_son] = lazy[node]; // 向右子节点传递
st_vec[left_son] = lazy[node];
st_vec[right_son] = lazy[node];
// 如果 lazy = v 的含义是区间内的值都改为 v,则将以上的 += 改为 =
lazy[node] = 0;
}
}

// 子区间结果上传
void push_up(int node)
{
int left_son = node * 2;
int right_son = node * 2 + 1;
st_vec[node] = max(st_vec[left_son], st_vec[right_son]);
}

void _build(int node, int nodeLeft, int nodeRight, const vector<int>& nums)
{
if(nodeLeft == nodeRight)
{
st_vec[node] = nums[nodeLeft];
return;
}
int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2;
int right_son = node * 2 + 1;
_build(left_son, nodeLeft, nodeMid, nums);
_build(right_son, nodeMid + 1, nodeRight, nums);
st_vec[node] = max(st_vec[left_son], st_vec[right_son]);
}

vector<int> st_vec; // 节点值表示区间最大值
vector<int> lazy;
int n;
};

class Solution {
public:
vector<int> fallingSquares(vector<vector<int>>& positions) {
// 离散化
vector<int> x;
for(const vector<int>& pos: positions)
{
x.push_back(pos[0]);
x.push_back(pos[0] + pos[1] - 1);
}
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());

int n = x.size();
SeqSegmentTree seqsttree;
vector<int> arr(n);
seqsttree.build(arr);
vector<int> result;
for(const vector<int>& pos: positions)
{
int start = _find(pos[0], x);
int end = _find(pos[0] + pos[1] - 1, x);
int v = pos[1];
int maxx = seqsttree.range_query(start, end);
seqsttree.range_update(start, end, maxx + v);
result.push_back(seqsttree.range_query(0, n - 1));
}
return result;
}

private:
int _find(int v, const vector<int>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

线段树RMQ 链式写法

以下代码为区间修改的线段树的链式写法的模板,节点值为 Max。Min, Sum 写法类似。

重点关注子区间结果上传 push_up 和懒标记下传 push_dowm,以及它们在 range_updaterange_query 中的发动时机。

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
struct STNode
{
int nodeLeft, nodeRight;
int maxx;
STNode *left, *right;
int lazy;
STNode(int l, int r, int x, STNode* left=nullptr, STNode* right=nullptr)
:nodeLeft(l),nodeRight(r),maxx(x),left(left),right(right),lazy(0){}
~STNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};

class SegmentTree
{
public:
SegmentTree()
{
root = nullptr;
}

~SegmentTree()
{
if(root)
{
delete root;
root = nullptr;
}
}

void range_update(int i, int j, int v)
{
_range_update(root, i, j, v);
}

int range_query(int i, int j)
{
return _range_query(root, i, j);
}

void build(const vector<int>&arr)
{
if(arr.empty()) return;
int n = arr.size();
root = _build(0, n - 1, arr);
}

void traverse()
{
cout << "==================" << endl;
_traverse(root);
cout << "==================" << endl;
}

private:
STNode *root;

void _traverse(STNode* node)
{
cout << "Range: [";
cout << node -> nodeLeft << " , " << node -> nodeRight << "]" << endl;
cout << "Max: " << node -> maxx << endl;
if(node -> nodeLeft != node -> nodeRight)
{
_traverse(node -> left);
_traverse(node -> right);
}
}

// 懒标记下传
void push_down(STNode* node)
{
if(node -> lazy)
{
node -> left -> lazy = node -> lazy;
node -> right -> lazy = node -> lazy;
node -> left -> maxx = node -> lazy;
node -> right -> maxx = node -> lazy;
node -> lazy = 0;
}
}

// 子区间结果上传
void push_up(STNode* node)
{
node -> maxx = max(node -> left -> maxx, node -> right -> maxx);
}

int _range_query(STNode* node, int start, int end)
{
int nodeLeft = node -> nodeLeft;
int nodeRight = node -> nodeRight;
if(nodeLeft == start && nodeRight == end)
return node -> maxx;
int nodeMid = (nodeLeft + nodeRight) / 2;
// 要根据子树结果计算当前节点结果时,懒标记下传
push_down(node);
if(end <= nodeMid)
return _range_query(node -> left, start, end);
else if(nodeMid < start)
return _range_query(node -> right, start, end);
else
{
return max(_range_query(node -> left, start, nodeMid),
_range_query(node -> right, nodeMid + 1, end));
}
}

void _range_update(STNode* node, int start, int end, int v)
{
int nodeLeft = node -> nodeLeft;
int nodeRight = node -> nodeRight;
if(nodeLeft == start && nodeRight == end)
{
node -> lazy = v;
node -> maxx = v;
return;
}
int nodeMid = (nodeLeft + nodeRight) / 2;
if(nodeLeft == nodeRight) return;
// 下传懒标记
push_down(node);
if(end <= nodeMid)
{
_range_update(node -> left, start, end, v);
}
else if(nodeMid < start)
{
_range_update(node -> right, start, end, v);
}
else
{
_range_update(node -> left, start, nodeMid, v);
_range_update(node -> right, nodeMid + 1, end, v);
}
push_up(node);
}

STNode* _build(int nodeLeft, int nodeRight, const vector<int>& arr)
{
if(nodeLeft == nodeRight)
return new STNode(nodeLeft, nodeRight, arr[nodeLeft]);
int nodeMid = (nodeLeft + nodeRight) / 2;
STNode *left_son = _build(nodeLeft, nodeMid, arr);
STNode *right_son = _build(nodeMid + 1, nodeRight, arr);
int maxx = max(left_son -> maxx, right_son -> maxx);
return new STNode(nodeLeft, nodeRight, maxx, left_son, right_son);
}
};

class Solution {
public:
vector<int> fallingSquares(vector<vector<int>>& positions) {
// 离散化
vector<int> x;
for(const vector<int>& pos: positions)
{
x.push_back(pos[0]);
x.push_back(pos[0] + pos[1] - 1);
}
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());

int n = x.size();
SegmentTree sttree;
vector<int> arr(n);
sttree.build(arr);
vector<int> result;
// sttree.traverse();
for(const vector<int>& pos: positions)
{
int start = _find(pos[0], x);
int end = _find(pos[0] + pos[1] - 1, x);
int v = pos[1];
int maxx = sttree.range_query(start, end);
sttree.range_update(start, end, maxx + v);
result.push_back(sttree.range_query(0, n - 1));
}
return result;
}

private:
int _find(int v, const vector<int>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

Share