力扣699-平衡树区间模块,线段树区间修改,RMQ

  |  
  • 多次查询区间的最值,也称 RMQ 问题。是一个常用组件。稀疏表和线段树是 RMQ 的两种主流做法。
  • 其中稀疏表适用与数据不变的情况,线段树适用于数据会改变的情况(可能是单点修改或区间修改)
  • 线段树的区间修改模板
    • 数组写法
    • 链式写法
  • 用平衡树维护区间动态插入删除的区间模块,是一个常用组件

  • 实现区间模块 715. Range 模块
  • 扫描线+区间模块处理矩形问题:横轴区间的左右端点作为两种事件,合适地排序,然后用扫描线算法对事件扫描。对依次到来的事件,用区间模块(类似lc715实现的数据结构)对竖轴上的区间做相应操作。类似思路的题目:

$1 题目

题目链接

699. 掉落的方块

题目描述

在无限长的数轴(即 x 轴)上,我们根据给定的顺序放置对应的正方形方块。

第 i 个掉落的方块(positions[i] = (left, side_length))是正方形,其中 left 表示该方块最左边的点位置(positions[i][0]),side_length 表示该方块的边长(positions[i][1])。

每个方块的底部边缘平行于数轴(即 x 轴),并且从一个比目前所有的落地方块更高的高度掉落而下。在上一个方块结束掉落,并保持静止后,才开始掉落新方块。

方块的底边具有非常大的粘性,并将保持固定在它们所接触的任何长度表面上(无论是数轴还是其他方块)。邻接掉落的边不会过早地粘合在一起,因为只有底边才具有粘性。

返回一个堆叠高度列表 ans 。每一个堆叠高度 ans[i] 表示在通过 positions[0], positions[1], …, positions[i] 表示的方块掉落结束后,目前所有已经落稳的方块堆叠的最高高度。

数据范围

1 <= positions.length <= 1000.
1 <= positions[i][0] <= 10^8.
1 <= positions[i][1] <= 10^6.

样例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
>示例 1:
>
>输入: [[1, 2], [2, 3], [6, 1]]
>输出: [2, 5, 5]
>解释:
>
>第一个方块 positions[0] = [1, 2] 掉落:
>_aa
>_aa
>-------
>方块最大高度为 2 。
>
>第二个方块 positions[1] = [2, 3] 掉落:
>__aaa
>__aaa
>__aaa
>_aa__
>_aa__
>--------------
>方块最大高度为5。
>大的方块保持在较小的方块的顶部,不论它的重心在哪里,因为方块的底部边缘有非常大的粘性。
>
>第三个方块 positions[1] = [6, 1] 掉落:
>__aaa
>__aaa
>__aaa
>_aa
>_aa___a
>--------------
>方块最大高度为5。
>
>因此,我们返回结果[2, 5, 5]。

>示例 2:
>
>输入: [[100, 100], [200, 100]]
>输出: [100, 100]
>解释: 相邻的方块不会过早地卡住,只有它们的底部边缘才能粘在表面上。

$2 题解

算法1 - RMQ

RMQ(Range Minimum/Maximum Query),即区间最值查询。

问题:给你一个数组 ,其中有 N 个数字,现在给你一次查询,给你区间[l ,r],问你在这个区间内的最值为多少

数组长度为 N,查询次数为 Q

数组中的数据保持不变的情况 — 稀疏表

参考 : RMQ

当数组中数据不变时,RMQ 算法一般用较长时间$O(N\log N)$做预处理,然后可以在 O(1) 的时间内处理每次查询。该方法称为稀疏表
这种先预处理,在查询的做法与用前缀和处理大量区间和查询的思想相同。

<1> 预处理部分

1
dp[i][j] := 从第 i 位开始,连续 2^j 个数的最小值

例如: arr = [1, 2, 6, 8, 4, 3, 7]

dp[2][1] 表示从第2位数开始连续2个数的最小值,即2,6中的最小值,所以 dp[2][1] = 2;
dp[3][2] 表示从第3位数开始连续4个数的最小值,即6,8,4,3中的最小值,所以 dp[3][2] = 3
dp[i][0] 表示第i个数字本身

求 dp[i][j] 的时候可以把它分成两部分,第一部分是从 $i$ 到 $i+2^{j-1}-1$ ,第二部分从 $i+2^{j-1}$ 到 $i+2^{j}-1$。

因为二进制数前一个数是后一个的两倍,所以可以把 $i$ 到 $i+2^{j}-1$ (即 dp[i][j] 表示的区间范围) 这个区间通过 $2^{j-1}$ 分
成相等的两部分,进而写出转移方程:

1
dp[i][j] = min(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1])

在外层枚举 j : 1..(1<<j)<=N,含义:
先更新每两个($2^{1}$)元素中的最小值,然后通过每两个元素的最小值获得每4个元素中的最小值,依次类推更新所有长度的最小值。

<2> 查询部分

查询区间 [l, r],令 $k = \log_{2}(r - l + 1)$,则区间 [l, r] 的最小值计算如下

1
query(l, r) = min(dp[l][k], dp[r - (1 << k) + 1][k])

因为 dp[l][k], dp[r - (1 << k) + 1][k] 分别维护区间 [l, $l + 2^{k} - 1$] 和 [$r - 2^{k} + 1$, r]。而 $r - 2^{k} + 1 \leq r - 2^{k} + 1$

代码(C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <iostream>
#include <vector>
#include <cmath>

using namespace std;

class RMQ
{
public:
RMQ(){}

void init(const vector<int>& arr)
{
int n = arr.size();
// 2 ^ m <= n
// log2(2^m) <= log2(n)
// m <= log2(n)
int m = log2(n);
dp.assign(n + 1, vector<int>(m + 1, 0));
for(int i = 1; i <= n; ++i)
dp[i][0] = arr[i]; //初始化
for(int j = 1; (1 << j) <= n; ++j)
for(int i = 1; i + (1 << j) - 1 <= n; ++i)
dp[i][j] = min(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);
}

int query(int l, int r)
{
int k = log2(r - l + 1);
return min(dp[l][k], dp[r - (1 << k) + 1][k]);
}
private:
vector<vector<int>> dp;
};

int main()
{
int n;
cin >> n;
vector<int> arr(n);
for(int i = 0; i < n; ++i)
cin >> arr[i];
cout << "数据: " << endl;
for(int i = 0; i < n; ++i)
cout << i << " " << arr[i] << endl;
RMQ rmq;
rmq.init(arr);
while(true)
{
int start, end;
cin >> start >> end;
cout << "查询区间: [" << start << ", " << end << "], 最小值:";
cout << rmq.query(start, end) << endl;
}
};

稀疏表隐含了倍增法的思想。这与倍增法处理 LCA 问题的思路是一直的。 力扣1483-树节点的第K个祖先

数组中的数据会变化的情况 — 线段树

参考 : 线段树维护区间最值RMQ

当数组中的值会有修改,需要边修改,变查询。此时稀疏表的效率就不高了,这与前缀和在求可变数组的区间和时效率不高的情况类似,解决方案也类似,用线段树。

本题遇到的就是这种情况。

线段树区间修改

关于线段树的单点更新和区间查询参考 : 力扣307-线段树,树状数组

有一个长为 n 的数据数组 nums,有一个表示 [0,n-1] 内各个区间的线段树,线段树的区间更新是这个意思:对于内部的某个区间 $[i, j], 0 \leq i, j \leq n-1$,
将区间内的值改为 v。

这一类更新,最直观的是对区间内每个点做单点更新:首先更新点 i,point_update(i, v),然后更新 i+1, point_update(i + 1,v),…,直到更新 j, point_update(j, v)。这样的话,操作次数就爆了。

解法是对每个区间引入一个标记 lazy,称为懒标记。
当需要把区间 [i, j] 内的值改为 v 时,没有直接进入线段树的对应区间的子区间取修改,而是给该区间做标记 v,当查询时若需要读取该区间数据,再利用标记算出应当返回的结果。
如果是链式写法,lazy 直接作为节点的一个单独子段,表示对区间 [nodeLeft, nodeRight] 标记 v。如果是数组写法,则用一个与线段树数组 st_vec 同长度的数组维护,即 lazy 数组。

例如有线段树,根表示 [0, 15] ,对 [5, 13] 上的所有值更新为 v,则仅仅给区间[5, 13]做标记 v,即在 [5], [6, 7], [8, 11], [12, 13] 这几个区间上进行标记。其中 [6, 7], [8, 11], [12, 13],这几个区间还有子区间,在做标记的时候不进入子区间。

1
2
3
4
5
[0                                                 15]
[0 7][8 15]
[0 3][4 7][8 11][12 15]
[0 1][2 3][4 5][6 7][8 9][10 11][12 13][14 15]
[0][1][2][3][4][5][6][7][8][9][10][11][12][13][14][15]

这种维护 lazy 标记的核心思想是把向下的修改先存起来,对于每个查询,在向上传递答案是再利用 lazy 标记修正传递的答案。
例如题目 1476. 子矩形查询,用到了这种保存修改,在查询时再计算答案的思想。

以上思路涉及到两个操作 push_uppush_down
push_up 表示向上的更新,push_down 维护的是 lazy 标记的值。
以下是的写法节点值表示区间和的两个操作的写法

1
2
3
4
5
6
void push_up(int node)
{
int left_son = node * 2;
int right_son = node * 2 + 1;
st_vec[node] = st_vec[left_son] + st_vec[right_son];
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void push_down(int node, int m) // m 为 node 表示区间的长度
{
if(lazy[node])
{
// 节点 node 有 lazy 标记
int left_son = node * 2;
int right_son = node * 2 + 1;

// 如果 lazy = v 的含义是区间内的值都加 v,则是如下写法
lazy[left_son] += lazy[node]; // 向左子节点传递
lazy[right_son] += lazy[node]; // 向右子节点传递
st_vec[left_son] += lazy[node] * (m - m / 2);
st_vec[right_son] += lazy[node] * (m / 2);
// 如果 lazy = v 的含义是区间内的值都改为 v,则将以上的 += 改为 =

lazy[node] = 0;
}
}

线段树RMQ 数组写法

以下代码为区间修改的线段树的数组写法的模板,节点值为 Max。Min, Sum 写法类似。

重点关注子区间结果上传 push_up 和懒标记下传 push_dowm,以及它们在 range_updaterange_query 中的发动时机。

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class SeqSegmentTree
{
public:
SeqSegmentTree()
{
st_vec = vector<int>();
lazy = vector<int>();
n = -1;
}

void build(const vector<int>& nums)
{
if(nums.empty()) return;
n = nums.size();
st_vec.resize(n * 4);
lazy.resize(n * 4);
_build(1, 0, n - 1, nums);
}

void range_update(int i, int j, int v)
{
// [i, j] 范围内改为 v
_range_update(1, 0, n - 1, i, j, v);
}

int range_query(int i, int j)
{
return _range_query(1, 0, n - 1, i, j);
}

private:
int _range_query(int node, int nodeLeft, int nodeRight, int start, int end)
{
if(nodeLeft == start && nodeRight == end)
return st_vec[node];
int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2;
int right_son = node * 2 + 1;
push_down(node);
if(end <= nodeMid)
return _range_query(left_son, nodeLeft, nodeMid, start, end);
else if(nodeMid < start)
return _range_query(right_son, nodeMid + 1, nodeRight, start, end);
else
{
return max(_range_query(left_son, nodeLeft, nodeMid, start, nodeMid),
_range_query(right_son, nodeMid + 1, nodeRight, nodeMid + 1, end));
}
}

void _range_update(int node, int nodeLeft, int nodeRight, int start, int end, int v)
{
if(nodeLeft == start && nodeRight == end)
{
lazy[node] = v;
st_vec[node] = v;
return;
}
if(nodeLeft == nodeRight) return;
int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2;
int right_son = node * 2 + 1;
push_down(node);
if(end <= nodeMid)
_range_update(left_son, nodeLeft, nodeMid, start, end, v);
else if(nodeMid < start)
_range_update(right_son, nodeMid + 1, nodeRight, start, end, v);
else
{
_range_update(left_son, nodeLeft, nodeMid, start, nodeMid, v);
_range_update(right_son, nodeMid + 1, nodeRight, nodeMid + 1, end, v);
}
push_up(node);
}

// 懒标记下传
void push_down(int node)
{
if(lazy[node])
{
// 节点 node 有 lazy 标记
int left_son = node * 2;
int right_son = node * 2 + 1;
// 如果 lazy = v 的含义是区间内的值都加 v,则是如下写法
lazy[left_son] = lazy[node]; // 向左子节点传递
lazy[right_son] = lazy[node]; // 向右子节点传递
st_vec[left_son] = lazy[node];
st_vec[right_son] = lazy[node];
// 如果 lazy = v 的含义是区间内的值都改为 v,则将以上的 += 改为 =
lazy[node] = 0;
}
}

// 子区间结果上传
void push_up(int node)
{
int left_son = node * 2;
int right_son = node * 2 + 1;
st_vec[node] = max(st_vec[left_son], st_vec[right_son]);
}

void _build(int node, int nodeLeft, int nodeRight, const vector<int>& nums)
{
if(nodeLeft == nodeRight)
{
st_vec[node] = nums[nodeLeft];
return;
}
int nodeMid = (nodeLeft + nodeRight) / 2;
int left_son = node * 2;
int right_son = node * 2 + 1;
_build(left_son, nodeLeft, nodeMid, nums);
_build(right_son, nodeMid + 1, nodeRight, nums);
st_vec[node] = max(st_vec[left_son], st_vec[right_son]);
}

vector<int> st_vec; // 节点值表示区间最大值
vector<int> lazy;
int n;
};

class Solution {
public:
vector<int> fallingSquares(vector<vector<int>>& positions) {
// 离散化
vector<int> x;
for(const vector<int>& pos: positions)
{
x.push_back(pos[0]);
x.push_back(pos[0] + pos[1] - 1);
}
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());

int n = x.size();
SeqSegmentTree seqsttree;
vector<int> arr(n);
seqsttree.build(arr);
vector<int> result;
for(const vector<int>& pos: positions)
{
int start = _find(pos[0], x);
int end = _find(pos[0] + pos[1] - 1, x);
int v = pos[1];
int maxx = seqsttree.range_query(start, end);
seqsttree.range_update(start, end, maxx + v);
result.push_back(seqsttree.range_query(0, n - 1));
}
return result;
}

private:
int _find(int v, const vector<int>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

线段树RMQ 链式写法

以下代码为区间修改的线段树的链式写法的模板,节点值为 Max。Min, Sum 写法类似。

重点关注子区间结果上传 push_up 和懒标记下传 push_dowm,以及它们在 range_updaterange_query 中的发动时机。

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
struct STNode
{
int nodeLeft, nodeRight;
int maxx;
STNode *left, *right;
int lazy;
STNode(int l, int r, int x, STNode* left=nullptr, STNode* right=nullptr)
:nodeLeft(l),nodeRight(r),maxx(x),left(left),right(right),lazy(0){}
~STNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};

class SegmentTree
{
public:
SegmentTree()
{
root = nullptr;
}

~SegmentTree()
{
if(root)
{
delete root;
root = nullptr;
}
}

void range_update(int i, int j, int v)
{
_range_update(root, i, j, v);
}

int range_query(int i, int j)
{
return _range_query(root, i, j);
}

void build(const vector<int>&arr)
{
if(arr.empty()) return;
int n = arr.size();
root = _build(0, n - 1, arr);
}

void traverse()
{
cout << "==================" << endl;
_traverse(root);
cout << "==================" << endl;
}

private:
STNode *root;

void _traverse(STNode* node)
{
cout << "Range: [";
cout << node -> nodeLeft << " , " << node -> nodeRight << "]" << endl;
cout << "Max: " << node -> maxx << endl;
if(node -> nodeLeft != node -> nodeRight)
{
_traverse(node -> left);
_traverse(node -> right);
}
}

// 懒标记下传
void push_down(STNode* node)
{
if(node -> lazy)
{
node -> left -> lazy = node -> lazy;
node -> right -> lazy = node -> lazy;
node -> left -> maxx = node -> lazy;
node -> right -> maxx = node -> lazy;
node -> lazy = 0;
}
}

// 子区间结果上传
void push_up(STNode* node)
{
node -> maxx = max(node -> left -> maxx, node -> right -> maxx);
}

int _range_query(STNode* node, int start, int end)
{
int nodeLeft = node -> nodeLeft;
int nodeRight = node -> nodeRight;
if(nodeLeft == start && nodeRight == end)
return node -> maxx;
int nodeMid = (nodeLeft + nodeRight) / 2;
// 要根据子树结果计算当前节点结果时,懒标记下传
push_down(node);
if(end <= nodeMid)
return _range_query(node -> left, start, end);
else if(nodeMid < start)
return _range_query(node -> right, start, end);
else
{
return max(_range_query(node -> left, start, nodeMid),
_range_query(node -> right, nodeMid + 1, end));
}
}

void _range_update(STNode* node, int start, int end, int v)
{
int nodeLeft = node -> nodeLeft;
int nodeRight = node -> nodeRight;
if(nodeLeft == start && nodeRight == end)
{
node -> lazy = v;
node -> maxx = v;
return;
}
int nodeMid = (nodeLeft + nodeRight) / 2;
if(nodeLeft == nodeRight) return;
// 下传懒标记
push_down(node);
if(end <= nodeMid)
{
_range_update(node -> left, start, end, v);
}
else if(nodeMid < start)
{
_range_update(node -> right, start, end, v);
}
else
{
_range_update(node -> left, start, nodeMid, v);
_range_update(node -> right, nodeMid + 1, end, v);
}
push_up(node);
}

STNode* _build(int nodeLeft, int nodeRight, const vector<int>& arr)
{
if(nodeLeft == nodeRight)
return new STNode(nodeLeft, nodeRight, arr[nodeLeft]);
int nodeMid = (nodeLeft + nodeRight) / 2;
STNode *left_son = _build(nodeLeft, nodeMid, arr);
STNode *right_son = _build(nodeMid + 1, nodeRight, arr);
int maxx = max(left_son -> maxx, right_son -> maxx);
return new STNode(nodeLeft, nodeRight, maxx, left_son, right_son);
}
};

class Solution {
public:
vector<int> fallingSquares(vector<vector<int>>& positions) {
// 离散化
vector<int> x;
for(const vector<int>& pos: positions)
{
x.push_back(pos[0]);
x.push_back(pos[0] + pos[1] - 1);
}
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());

int n = x.size();
SegmentTree sttree;
vector<int> arr(n);
sttree.build(arr);
vector<int> result;
// sttree.traverse();
for(const vector<int>& pos: positions)
{
int start = _find(pos[0], x);
int end = _find(pos[0] + pos[1] - 1, x);
int v = pos[1];
int maxx = sttree.range_query(start, end);
sttree.range_update(start, end, maxx + v);
result.push_back(sttree.range_query(0, n - 1));
}
return result;
}

private:
int _find(int v, const vector<int>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

算法2 - 区间模块

用区间模块 Ranges 记录已经落下过的正方形。

对于一个正方形 [pos, len],它横跨区间 [pos, pos + len),如果这个范围此前没有落下过正方形,即区间模块中没有这个范围的记录,则区间模块中可以直接记录
区间[pos, pos+len) 的高度为 len。如果此前已经有正方形落在了该区域,则[pos, pos+len) 会与 Ranges 中的一些已有区间重叠。那么就先要处理这些重叠,重叠可能发生的情况如下

1
2
       [pos,                           pos+len)
[left1, right1) [left2, right2) [left3, right3)

受影响的有若干区间,[pos, pos+len) 最终会更新的高度为这些受影响区间中最高值加上 len。而受影响区间中可能有一部分是在重叠部分之外,
例如上面的区间 [left1, right1) , [left3, right3)。则非重叠的部分仍然维护这原来的最高值。因此更新流程如下:

  • step1: 在区间模块中找到所有与 [pos, pos+len) 重叠的区间
  • step2: 在若干重叠区间中找到两端与 [pos, pos+len) 不重叠的部分,并记录这若干区间的高度最大值
  • step3: 将这若干个区间全部动 Ranges 模块中删掉
  • step4: 插入 [pos, pos+len) 的高度记录,为 step2 得到的这一部分的原高度最大值 + len
  • step5:如果被删的若干重叠区间两侧有与 [pos, pos+len) 不重叠的部分,则把这部分的记录再插入回去
  • step6: 新插入的高度值与 Ranges 模块中维护的全局最大值比较,将较大这作为本次更新的答案输出.

以上算法中,主要用到的就是区间在动态的插入删除过程中,插入的时候动态地将重叠部分合并,删除的时候只将重叠部分删掉,已有区间中未与被删区间重叠的部分需要保留。
这个功能用一个 TreeMap 维护,键是区间左端点,值是右端点(开区间)。每次需要插入或者删除新区间时,首先利用左端点的有序性可以
快速找到所有的重叠区间,然后在重叠区间上进行操作,本题中是更新最高的高度,然后做插入删除的动作,在插入删除过程中始终维护 Ranges 模块中的区间无重叠。

题目 715. Range 模块 正是实现的以上功能的数据结构,并且有一些以此数据结构为组件的题目,但作为组件时,需要根据特定的需求对内部
微调一些逻辑。核心有两点,一是用平衡树维护区间的有序性,而是在操作区间(插入,删除)前,先利用有序性快速找到所有重叠区间,然后的操作就在这些重叠区间上完成。

区间模块经常根扫描线算法配合处理矩形问题:横轴区间的左右端点作为两种事件,合适地排序,然后用扫描线算法对事件扫描。对依次到来的事件,用区间模块(类似lc715实现的数据结构)对竖轴上的区间做相应操作。类似思路的题目:

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class Ranges
{
public:
Ranges()
{
ranges = map<int, PII>();
max_h = 0;
}

int update(int start, int end, int len)
{
// [start, end) 范围的 h 最大值, 返回 h
IT l, r;
get_overlap(start, end, l, r);
// [l, r) 为有重叠的区间
if(l == r) // [start, end) 范围无已有区间
{
ranges.insert(PIP(start, PII(end, len)));
max_h = max(max_h, len);
return max_h;
}
auto last = r;
--last;

bool has_left = false, has_right = false;
int left_l = l -> first;
int left_r = start;
int left_h = (l -> second).second;
int right_l = end;
int right_r = (last -> second).first;
int right_h = (last -> second).second;
if(left_l < left_r)
has_left = true;
if(right_l < right_r)
has_right = true;

auto iter = l;
int h = (iter -> second).second;
++iter;
while(iter != r)
{
h = max(h, (iter -> second).second);
++iter;
}

ranges.erase(l, r);
if(has_left)
ranges.insert(PIP(left_l, PII(left_r, left_h)));
if(has_right)
ranges.insert(PIP(right_l, PII(right_r, right_h)));
int new_h = h + len;
ranges.insert(PIP(start, PII(end, new_h)));
max_h = max(max_h, new_h);
return max_h;
}

private:
using PII = pair<int, int>;
using PIP = pair<int, PII>;
using IT = map<int, PII>::iterator;
map<int, PII> ranges; // left -> (right, h)
int max_h;

void get_overlap(int left, int right, IT& l, IT& r)
{
l = ranges.upper_bound(left);
r = ranges.lower_bound(right);
if(l != ranges.begin())
{
--l;
if((l -> second).first <= left)
++l;
}
}
};

class Solution {
public:
vector<int> fallingSquares(vector<vector<int>>& positions) {
int n = positions.size();
vector<int> result(n, -1);
Ranges ranges;
for(int i = 0; i < n; ++i)
{

int left = positions[i][0];
int len = positions[i][1];
int right = left + len; // right 是开区间
result[i] = ranges.update(left, right, len);
}
return result;
}
};

Share