力扣1157-分块查找表,无修改区间众数

  |  
  • 先找可能的候选答案,再验证答案,以下为选出可能候选的一些方法和时间复杂度

    • 随机化 $O(K)$,K 为尝试次数
    • 转换为区间第 k 大,k = j + 1 - (j - i + 1) / 2 - i,可以用划分树 $O(\log N)$
    • 分块查找表 $O(\sqrt{N})$
    • 摩尔投票 $O(N)$
    • 用线段树维护的摩尔投票 $O(\log N)$
  • 扩展:分块查找表做带修改的 RMQ


$1 题目

题目链接

1157. 子数组中占绝大多数的元素

题目描述

实现一个 MajorityChecker 的类,它应该具有下述几个 API:

MajorityChecker(int[] arr) 会用给定的数组 arr 来构造一个 MajorityChecker 的实例。
int query(int left, int right, int threshold) 有这么几个参数:
0 <= left <= right < arr.length 表示数组 arr 的子数组的长度。
2 * threshold > right - left + 1,也就是说阈值 threshold 始终比子序列长度的一半还要大。
每次查询 query(…) 会返回在 arr[left], arr[left+1], …, arr[right] 中至少出现阈值次数 threshold 的元素,如果不存在这样的元素,就返回 -1。

数据范围

提示:

1 <= arr.length <= 20000
1 <= arr[i] <= 20000
对于每次查询,0 <= left <= right < len(arr)
对于每次查询,2 * threshold > right - left + 1
查询次数最多为 10000

样例

示例:

MajorityChecker majorityChecker = new MajorityChecker([1,1,2,2,1,1]);
majorityChecker.query(0,5,4); // 返回 1
majorityChecker.query(0,3,3); // 返回 -1
majorityChecker.query(2,3,2); // 返回 2

$2 题解

算法1: 随机化

若数量大于区间长度一半的众数存在,则随机去一个元素,取到该众数的概率是 1/2,如果取 K 次,则 K 次中至少一次取到该众数的概率为 $p=1-(\frac{1}{2})^{K}$ ,当 K=10 时,至少一次取到众数的概率为 p=0.99902。

因此可以随机取 10 次,每次取出一个数 cand 之后,都查一次该数字在数组中出现了多少次,若次数 >= threshold,则返回 cand;如果 10 次都没有取到众数,则返回 -1,此时有两种可能,一种是次数大于区间长度一半的众数不存在(返回-1正确),另一种是存在但是恰好 10 次都没取到(返回-1错误)。

这里面查询 cand 在查询的区间 [left, right] 中出现了多少次,可以 $O(\log N)$ 完成:用一个 HashMap 保存各个出现过的数字的下标,并且将这些下标排序

1
mapping = unordered_map<int, vector<int>>(); // num -> 有序 idxs

在区间 [left, right] 中随机取得 cand 后,可以二分地得到 cand 在原数组的所有出现下标中,大于等于 left 的第一个位置,和大于 right 的第一个位置
1
2
auto it_left = lower_bound(mapping[cand].begin(), mapping[cand].end(), left);
auto it_right = upper_bound(mapping[cand].begin(), mapping[cand].end(), right);

然后因为 vector 可以随机访问,两个迭代器可以直接相减,得到两个迭代器之间的元素个数。
1
2
if(it_right - it_left >= threshold)
return cand;

这里排序的 vector 不能改成平衡树 set,因为在 set 上得到上述的两个迭代器之后,set 的迭代器不支持直接相减得到之间的元素个数,此时只能用 distance(it_left, it_right),但是 distance 获取中间的元素个数就是 $O(N)$ 的了。

1
2
3
first 和 last 的迭代器类型,直接决定了 distance(first, last) 函数底层的实现机制:
当 first 和 last 为随机访问迭代器时,distance() 底层直接采用 last - first 求得 [first, last) 范围内包含元素的个数,其时间复杂度为O(1)常数阶;
当 first 和 last 为非随机访问迭代器时,distance() 底层通过不断执行 ++first(或者 first++)直到 first==last,由此来获取 [first, last) 范围内包含元素的个数,其时间复杂度为O(n)线性阶。

N=10 时,不应当返回 -1 的 case 有 0.001 的概率错误地返回 -1。在实际跑的时候,很难错误返回-1,实测 N=6 时,连测几次一次也没有通过,N=7 时,连跑几次通过和不通过的情况都有,N=10 的时候就很难不通过了,实测的几次均通过,时间 500ms。

代码(c++)

代码中的函数模板是用于后续支持随机地在 [iter1, iter2) 范围内获取一个迭代器的功能

1
auto it = select_randomly(arr.begin() + left, arr.begin() + right + 1)

力扣149:自定义哈希函数&RANSAC算法
的第三种解法 RANSAC 中也有使用到。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
template<typename Iter, typename RandomGenerator>
Iter select_randomly(Iter start, Iter end, RandomGenerator *g) {
std::uniform_int_distribution<> dis(0, std::distance(start, end) - 1);
std::advance(start, dis(*g));
return start;
}

template<typename Iter>
Iter select_randomly(Iter start, Iter end) {
static std::random_device rd;
static std::mt19937 gen(rd());
return select_randomly(start, end, &gen);
}
class MajorityChecker {
public:
MajorityChecker(vector<int>& arr):arr(arr) {
mapping = unordered_map<int, vector<int>>(); // num -> 有序 idxs
int n = arr.size();
for(int i = 0; i < n; ++i)
mapping[arr[i]].push_back(i);
}

int query(int left, int right, int threshold) {
int cnt = 0;
for(int i = 0; i < N; ++i)
{
auto it = select_randomly(arr.begin() + left, arr.begin() + right + 1);
int cand = *it;
auto it_left = lower_bound(mapping[cand].begin(), mapping[cand].end(), left);
auto it_right = upper_bound(mapping[cand].begin(), mapping[cand].end(), right);
if(it_right - it_left >= threshold)
return cand;
}
return -1;
}

private:
unordered_map<int, vector<int>> mapping; // num -> 有序 idxs
vector<int> arr;
int N = 10;
};

算法2: 划分树找区间第 k 大

因为保证2 * threshold > right - left + 1,众数元素的个数在区间长度一半以,因此众数元素一定在 [left, right] 上的所有元素排序后的中位数位置。

查询区间 [i, j] 上排序后的最中间的数,可以套用求区间第 k 大的解法,例如划分树,每次查询的时间复杂度为 $O(\log N)$。

查找得到 [i, j] 上第 int k = j + 1 - (j - i + 1) / 2 - i; 大的元素 cand 后,还不知道它的个数是不是大于等于 threshold,需要再加一步验证。
这里的进一步验证用的方法与算法一中一样,也是用 unordered_map<int, vector<int>> 预处理好各个值的有序索引集合,然后二分地查询 cand 有多少索引在 [left, right] 中

划分树的内容以及代码模板参考 力扣215-快速选择算法,划分树

算法1 与算法 2 都是先快速选出可能的候选值,然后验证的流程。随机化的速度快,但是有概率返回错误结果。

代码(c++)

1400ms

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
struct PTNode
{
int start, end;
vector<int> nums, toleft;
PTNode *left, *right;
PTNode(int s, int e, PTNode* l=nullptr, PTNode* r=nullptr)
:start(s),end(e),nums(vector<int>(end - start + 1)),toleft(vector<int>(end - start + 2)),left(l),right(r) {}
~PTNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};

class PartitionTree
{
public:
PartitionTree()
{
root = nullptr;
sorted = vector<int>();
}

~PartitionTree()
{
if(root)
{
delete root;
root = nullptr;
}
}

void build(int start, int end, const vector<int>& nums)
{
sorted = nums;
sort(sorted.begin(), sorted.end());
root = new PTNode(start, end);
root -> nums = nums;
_build(root);
}

int query(int i, int j, int k)
{
if(i > j) return 0;
return _query(root, i, j, k);
}

private:
PTNode *root;
vector<int> sorted;

int _query(PTNode* root, int i, int j, int k)
{
if(root -> start == root -> end) return (root -> nums)[0];

int tli = 0;
if(root -> start != i)
tli = (root -> toleft)[i - root -> start];
int tlj = (root -> toleft)[j - root -> start + 1];
int tl = tlj - tli;
int new_i, new_j;
if(tl >= k)
{
// 第 k 大在左子
new_i = root -> start + tli;
new_j = new_i + tl - 1;
return _query(root -> left, new_i, new_j, k);
}
else
{
// 第 k 大在右子
int mid = root -> start + (root -> end - root -> start) / 2;
new_i = mid + 1 + i - (root -> start) - tli;
new_j = new_i + j - i - tl;
return _query(root -> right, new_i, new_j, k - tl);
}
}

void _build(PTNode* root)
{
if(root -> start == root -> end)
return;
int mid = root -> start + (root -> end - root -> start) / 2;
int median = sorted[mid];
PTNode *left = new PTNode(root -> start, mid);
PTNode *right = new PTNode(mid + 1, root -> end);
int n = (root -> nums).size();
int median_to_left = mid - root -> start + 1;
for(int i = 0; i < n; ++i)
{
if((root -> nums)[i] < median)
--median_to_left;
}
// 出循环后 median_to_left 为去往左子树中等于中位数的个数
int to_left = 0; // 去往左子树的个数
int idx_left = 0, idx_right = 0;
for(int i = 0; i < n; ++i)
{
int cur = (root -> nums)[i];
if(cur < median || ((cur == median) && median_to_left > 0))
{
(left -> nums)[idx_left] = cur;
++idx_left;
++to_left;
if(cur == median)
--median_to_left;
}
else
{
(right -> nums)[idx_right] = cur;
++idx_right;
}
(root -> toleft)[i + 1] = to_left;
}
_build(left);
_build(right);
root -> left = left;
root -> right = right;
}
};


class MajorityChecker {
public:
MajorityChecker(vector<int>& arr):arr(arr) {
if(arr.empty()) return;
mapping = unordered_map<int, vector<int>>(); // num -> 有序 idxs
int n = arr.size();
for(int i = 0; i < n; ++i)
mapping[arr[i]].push_back(i);
partitiontree = PartitionTree();
partitiontree.build(0, n - 1, arr);
}

int query(int left, int right, int threshold) {
// right + 1 - (right - left + 1) / 2 - left
int k = right + 1 - (right - left + 1) / 2 - left;
int cand = partitiontree.query(left, right, k);
auto it_left = lower_bound(mapping[cand].begin(), mapping[cand].end(), left);
auto it_right = upper_bound(mapping[cand].begin(), mapping[cand].end(), right);
if(it_right - it_left >= threshold)
return cand;
return -1;
}

private:
unordered_map<int, vector<int>> mapping; // num -> 有序 idxs
PartitionTree partitiontree;
vector<int> arr;
};

算法3:分块查找表

将 [0, n-1] 分成若干块,每块 bs 个,总共有 bn = n / bs + 1 块,其中最后一块的元素个数可能为 0 ~ bs-1。
对应查询 query(l, r),区间 [l, r] 会覆盖到若干个块,其中两侧覆盖到的块,可能只覆盖到一部分,而中间的块都是完整被覆盖的。

此时可能的众数(候选数)有两种情况

  1. 在中间的若干完整块中的众数,记最左边的块为 L, 最右边的块为 R,此时要问 [L..R] 这若干个块的众数是多少,这可以在初始化阶段预处理到一个表里,记为 dp[L][R]。
  2. 两边的不完整块中的某个元素,最多有 $O(\sqrt(N))$ 个。
    对每个候选数,确认它在区间 [l, r] 中的个数,这一步与算法1 和算法2 相同。

预处理 dp[L][R] 的过程:枚举 L = [0,..,bn-1],然后枚举 i = [L * bs, .. n-1],用计数数组记录每个出现过的元素的次数,当某次更新后,当前元素 nums[i] 取到了最大次数,则更新 dp[L][i/bs] = nums[i]

预处理阶段时间复杂度为 $O(bs \times N)$,查询阶段时间复杂度为 $O(bs \times Q\log N)$

一般的分块会取 $bs = \sqrt{N}$,这样总时间复杂度就是 $O(N\sqrt{N} + Q\sqrt{N}\log N)$,可以调整 bs 使得前后两项尽量相等,速度会更快。例如下面代码中两组 bs 设置的对比。

代码(c++)

bs = floor(sqrt(n)) 1800ms
bs = floor(sqrt(n * 2)) 1200ms

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class Block
{
public:
void build(const vector<int>& nums)
{
this -> nums = nums;
n = nums.size();
bs = floor(sqrt(n));
nb = n / bs + 1;
dp = vector<vector<int>>(nb, vector<int>(nb, 0));
mapping = vector<vector<int>>(20001);

for(int i = 0; i < n; ++i)
mapping[nums[i]].push_back(i);

vector<int> cnts(20001); // num -> cnt;
for(int block_l = 0; block_l < nb; ++block_l)
{
cnts.clear();
int max_cnt = 0;
int max_num = 0; // dp[block_l][block_r];
for(int i = block_l * bs; i < n; ++i)
{
int num = nums[i];
++cnts[num];
if(cnts[num] > max_cnt)
{
max_cnt = cnts[num];
max_num = num;
dp[block_l][i / bs] = max_num;
}
}
}
}

int query(int l, int r, int threshold)
{
int block_l = l / bs, block_r = r / bs;
int L = block_l, R = block_r;
if(block_l * bs < l || r < ((block_l + 1) * bs - 1))
{
++L;
for(int i = l; i < min((block_l + 1) * bs, n); ++i)
{
int cand = nums[i];
auto it_left = lower_bound(mapping[cand].begin(), mapping[cand].end(), l);
auto it_right = upper_bound(mapping[cand].begin(), mapping[cand].end(), r);
if(it_right - it_left >= threshold)
return cand;
}
}
if(block_r != block_l)
{
if(r < (block_r + 1) * bs - 1)
{
--R;
for(int i = block_r * bs; i <= min((block_r + 1) * bs - 1, r); ++i)
{
int cand = nums[i];
auto it_left = lower_bound(mapping[cand].begin(), mapping[cand].end(), l);
auto it_right = upper_bound(mapping[cand].begin(), mapping[cand].end(), r);
if(it_right - it_left >= threshold)
return cand;
}
}
}
if(L <= R)
{
int cand = dp[L][R];
auto it_left = lower_bound(mapping[cand].begin(), mapping[cand].end(), l);
auto it_right = upper_bound(mapping[cand].begin(), mapping[cand].end(), r);
if(it_right - it_left >= threshold)
return cand;
}
return -1;
}

private:
vector<int> nums;
vector<vector<int>> dp; // dp[block_l][block_r] := 块 block_l ~ block_r 的众数
vector<vector<int>> mapping; // num -> 有序 idxs
int nb, bs, n;
};

class MajorityChecker {
public:
MajorityChecker(vector<int>& arr) {
block = Block();
block.build(arr);
}

int query(int left, int right, int threshold) {
return block.query(left, right, threshold);
}

private:
Block block;
};

算法4:摩尔投票+线段树+二分;线段树的区间合并

基于摩尔投票的暴力做法:对每个查询 query(l, r),都在 [l, r] 内做一次摩尔投票,需要时间复杂度 $O(N)$, 空间复杂度 $O(1)$。关于摩尔投票算法的思路,参考 力扣169,299-摩尔投票
总时间复杂度为 $O(QN)$,单纯用摩尔投票无法通过。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
int query(int l, int r)
{
int cand = 0;
int cnt = 0;
for(int i = l; i <= r; ++i)
{
if(cnt == 0)
{
cand = nums[i];
++cnt;
continue;
}
if(cand != nums[i])
--cnt;
else
++cnt;
}
cnt = 0;
for(int i = l; i <= r; ++i)
if(nums[i] == cand)
++cnt;
if(cnt >= threshold)
return cand;
else
return -1;
}

本题最精彩的算法是把摩尔投票和线段树结合起来,先说结论:时间复杂度 $O(Q\log N)$,空间复杂度 $O(N)$。

首先对于区间问题,无论区间有没有修改,都可以考虑用线段树,但是用线段树有一个条件,即区间可加:打散区间的过程是本着二分分治的思路,分到叶子的时候区间长度是 1.打散区间的过程是本着二分分治的思路,分到叶子的时候区间长度是 1。
一个区间 [l, r] 总可以拆成不多于 $2\log N$ 段区间,因为左右子树的深度均为 $\log N$。所以查询 query(l, r) 的时候,线段树的做法就是先将 [l, r] 拆成各个有用的区间(维护在树节点中),然后将各个区间的答案(节点维护的值)合并到一起,作为 [l, r] 的查询结果。

对于摩尔投票,线段树节点维护的是各个区间摩尔投票的结果:cand 和 cnt。

因为摩尔投票不要求顺序,只要众数元素的个数大于 N/2,摩尔投票总可以选到正确的值,因此对于区间 [l, r] 可以先做左区间 [l, mid] 的摩尔投票,结果为 cand1, cnt1;
再做右区间 [mid+1, r] 的摩尔投票,结果为 cand2, cnt2。

如果 cand1 和 cand2 相同,该区间 [l, r] 的结果直接就是 cand1, cnt1+cnt2;如果 cand1 和 cand2 不相等,则拿多的减小的(摩尔投票的逻辑),得到新的 cnt 和 cand。所得新 cand 是可能的众数(后续需要验证),如果[left, right] 中有超过一半的数,合并后一定是 cand。

一次查询中:查询可能的候选是 $O(\log N)$,得到候选之后验证候选依然用算法1~3用的方法,也是 $O(\log N)$。

关于线段树维护区间信息的思路和写法,参考 力扣307-线段树,树状数组

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
struct STNode
{
int nodeLeft, nodeRight;
int cand, cnt;
STNode *left, *right;
STNode(int l, int r, int cand, int cnt, STNode* left=nullptr, STNode* right=nullptr)
:nodeLeft(l),nodeRight(r),cand(cand),cnt(cnt),left(left),right(right){}
~STNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};

class SegmentTree
{
public:
SegmentTree()
{
root = nullptr;
}
~SegmentTree()
{
if(root)
{
delete root;
root = nullptr;
}
}

void build(const vector<int>& nums)
{
int n = nums.size();
root = _build(0, n - 1, nums);
}

int query(int l, int r)
{
PII item = _query(root, l, r);
return item.first;
}

private:
STNode *root;
using PII = pair<int, int>;

PII _query(STNode* node, int l, int r)
{
int nodeLeft = node -> nodeLeft, nodeRight = node -> nodeRight;
// [l, r], [nodeLeft, nodeRight]
if(l == nodeLeft && r == nodeRight)
return PII(node -> cand, node -> cnt);
int nodeMid = (nodeLeft + nodeRight) / 2;
if(r <= nodeMid)
return _query(node -> left, l, r);
else if(nodeMid < l)
return _query(node -> right, l, r);
else
{
PII item1 = _query(node -> left, l, nodeMid);
PII item2 = _query(node -> right, nodeMid + 1, r);
int cand1 = item1.first, cnt1 = item1.second;
int cand2 = item2.first, cnt2 = item2.second;
if(cand1 == cand2)
return PII(cand1, cnt1 + cnt2);
int cand, cnt;
if(cnt1 >= cnt2)
{
cand = cand1;
cnt = cnt1 - cnt2;
}
else
{
cand = cand2;
cnt = cnt2 - cnt1;
}
return PII(cand, cnt);
}
}


STNode* _build(int nodeLeft, int nodeRight, const vector<int>& nums)
{
if(nodeLeft == nodeRight)
{
return new STNode(nodeLeft, nodeRight, nums[nodeLeft], 1);
}
int nodeMid = (nodeLeft + nodeRight) / 2;
STNode* left = _build(nodeLeft, nodeMid, nums);
STNode* right = _build(nodeMid + 1, nodeRight, nums);
int cand1 = left -> cand, cnt1 = left -> cnt;
int cand2 = right -> cand, cnt2 = right -> cnt;
if(cand1 == cand2)
{
return new STNode(nodeLeft, nodeRight, cand1, cnt1 + cnt2, left, right);
}
else
{
int cand, cnt;
if(cnt1 >= cnt2)
{
cand = cand1;
cnt = cnt1 - cnt2;
}
else
{
cand = cand2;
cnt = cnt2 - cnt1;
}
return new STNode(nodeLeft, nodeRight, cand, cnt, left, right);
}
}
};

class MajorityChecker {
public:
MajorityChecker(vector<int>& arr) {
int n = arr.size();
mapping = vector<vector<int>>(20001);
for(int i = 0; i < n; ++i)
mapping[arr[i]].push_back(i);
sttree = SegmentTree();
sttree.build(arr);
}

int query(int left, int right, int threshold) {
int cand = sttree.query(left, right);
auto it_left = lower_bound(mapping[cand].begin(), mapping[cand].end(), left);
auto it_right = upper_bound(mapping[cand].begin(), mapping[cand].end(), right);
if(it_right - it_left >= threshold)
return cand;
else
return -1;
}

private:
SegmentTree sttree;
vector<vector<int>> mapping; // num -> 有序 idxs
};

$3 扩展:分块查找表做带修改 RMQ

分块查找表的主要思路是对于数据数组 nums[0..n-1], 共 n 个元素,每 $\sqrt{n}$ 个元素分在一个桶内进行维护。
对区间的操作可以从 $O(N)$ 降到 $O(\sqrt{N})$。
桶里面的数据如何维护需要看需要实现的功能来定。
分块算法可以维护一些线段树维护不了的东西,例如单调队列等,线段树能维护的东西必须能够进行信息合并,而分块则不需要。

以 RMQ 为例,原始数据为 nums, 长度为 n,令 block_size = floor(sqrt{n}),将 nums 中的元素每 block_size 个一桶,并维护桶的最小值。
共有 bn = n / block_size + 1 个桶,最后一个桶的元素个数可能为 [0, block_size - 1]。原始数据与需要维护的信息组织成数据结构如下:

1
2
3
4
5
vector<int> nums;
vector<int> maxx; // maxx[block_id] := 桶的最大值
vector<int> lazy; // lazy[block_id] := 桶的懒标记
vector<bool> has_lazy; // has_lazy[block_l] := 桶的懒标记是否有效
int block_size, bn, n;

其中 maxx[0..bn], lazy[0..bn], has_lazy[0..bn] 均为块的信息。nums 中的数据直接按照下标分成各个桶,nums[i] 所属的桶 id 为 block_id = i / block_size
然后 maxx[block_id], lazy[block_id], has_lazy[block_id] 即可访问到块的信息。这中数据组织的方式称为块状数组

区间查询 query(i, j)

[i, j] 范围会跨若干个桶,其中:

  1. 中间有一些桶完全把区间覆盖住了,对这些桶,直接返回桶维护的最值,共 $\sqrt{N}$ 个桶。
  2. 左右两边各自可能会有一个桶与 [i, j] 重叠但是只包含了一部分区间,成其为半桶,对左右这两个半桶,如果懒标记有效,可以直接将桶维护的最值返回;如果懒标记无效,需要逐个遍历取得最值返回,一个桶共 $\sqrt{N}$ 个元素。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
int range_query(int l, int r)
{
int bucket_l = l / block_size, bucket_r = r / block_size;
int result = nums[l];

// 左半桶
if(has_lazy[bucket_l])
result = maxx[bucket_l];
else
for(int i = l; i <= min((bucket_l + 1) * block_size - 1, r); ++i)
result = max(result, nums[i]);
// 右半桶
if(bucket_l != bucket_r)
{
if(has_lazy[bucket_r])
result = max(result, maxx[bucket_r]);
else
for(int i = bucket_r * block_size; i <= r; ++i)
result = max(result, nums[i]);
}
// 中间的桶
for(int bucket_id = bucket_l + 1; bucket_id < bucket_r; ++bucket_id)
result = max(result, maxx[bucket_id]);
return result;
}

区间更新 update(i, j, x)

与区间查询同样,会在 [i, j] 范围内面对若干个桶,其中:

  1. 对于中间若干个把区间完全覆盖住的桶,将 x 更新到桶的懒标记 lazy, 以及桶的最大值 maxx, 并将 has_lazy 置为 true,而不对原数组修改,共 $\sqrt{N}$ 个桶。
  2. 对于左右两边各自可能会出现的半桶,如果懒标记有效,则先逐个遍历将懒标记下传,下传完成后将 has_lazy 置为 false,然后逐个遍历并直接在原数组上将值改为 x,一个桶共 $\sqrt{N}$ 个元素。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
void range_update(int l, int r, int x)
{
int block_l = l / block_size, block_r = r / block_size;
// 左半桶
if(has_lazy[block_l])
{
for(int i = block_l * block_size; i < min((block_l + 1) * block_size, n); ++i)
nums[i] = lazy[block_l];
has_lazy[block_l] = false;
}
for(int i = l; i <= min((block_l + 1) * block_size - 1, r); ++i)
nums[i] = x;
maxx[block_l] = max(maxx[block_l], x);
// 右半桶
if(block_l != block_r)
{
if(has_lazy[block_r])
{
for(int i = block_r * block_size; i < min((block_r + 1) * block_size, n); ++i)
nums[i] = lazy[block_r];
has_lazy[block_r] = false;
}
for(int i = block_r * block_size; i <= r; ++i)
nums[i] = x;
maxx[block_r] = max(maxx[block_r], x);
}
// 中间的桶
for(int block_id = block_l + 1; block_id < block_r; ++block_id)
{
lazy[block_id] = x;
has_lazy[block_id] = true;
maxx[block_id] = x;
}
}

单点更新

先将对应块的标记 lazy[i] 下传,再暴力更新被修改块的状态。时间复杂度 O(\sqrt(N))

1
2
3
4
5
6
7
8
9
10
11
12
13
void point_update(int idx, int x)
{
// 先将对应块的标记 `lazy[i]` 下传,再暴力更新被修改块的状态。时间复杂度 O(\sqrt(N))
int block_id = idx / block_size;
if(has_lazy[block_id])
{
for(int i = block_id * block_size; i < min((block_id + 1) * block_size, n); ++i)
nums[i] = lazy[block_id];
has_lazy[block_id] = false;
}
nums[idx] = x;
maxx[block_id] = max(maxx[block_id], x);
}

分块查找表RMQ 解 699. 掉落的方块

本题可以用平衡树维护区间模块,以及带区间修改的 RMQ 两条路线。在文章 力扣699-平衡树区间模块,线段树区间修改,RMQ 中,平衡树区间模块,RMQ 都用了,其中 RMQ 使用带区间修改的线段树实现的。以下为带区间修改的分块查找表实现 RMQ 的做法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
class Block
{
public:
Block(){}

void build(const vector<int>& arr)
{
nums = arr;
n = nums.size();
block_size = floor(sqrt(n)); // 8 -> 2, 9 -> 3, 10 -> 3, 11 -> 3, ..., 15 -> 3, 16 -> 4
bn = n / block_size + 1; // 最后一个桶的元素个数可能为 [0, block_size - 1]
lazy = vector<int>(bn, 0);
maxx = vector<int>(bn, 0);
has_lazy = vector<bool>(bn, false);
for(int i = 0; i < n; ++i)
{
int block_id = i / block_size;
maxx[block_id] = max(maxx[block_id], nums[i]);
}
}

void point_update(int idx, int x)
{
// 先将对应块的标记 `lazy[i]` 下传,再暴力更新被修改块的状态。时间复杂度 O(\sqrt(N))
int block_id = idx / block_size;
if(has_lazy[block_id])
{
for(int i = block_id * block_size; i < min((block_id + 1) * block_size, n); ++i)
nums[i] = lazy[block_id];
has_lazy[block_id] = false;
}
nums[idx] = x;
maxx[block_id] = max(maxx[block_id], x);
}

void range_update(int l, int r, int x)
{
// 对左右半桶,若懒标记有效,则先将懒标记 `lazy[block_l], lazy[block_r]` 下传,再暴力更新被修改块的状态。时间复杂度 O(\sqrt(N))
int block_l = l / block_size, block_r = r / block_size;
// 左半桶
if(has_lazy[block_l])
{
for(int i = block_l * block_size; i < min((block_l + 1) * block_size, n); ++i)
nums[i] = lazy[block_l];
has_lazy[block_l] = false;
}
for(int i = l; i <= min((block_l + 1) * block_size - 1, r); ++i)
nums[i] = x;
maxx[block_l] = max(maxx[block_l], x);
// 右半桶
if(block_l != block_r)
{
if(has_lazy[block_r])
{
for(int i = block_r * block_size; i < min((block_r + 1) * block_size, n); ++i)
nums[i] = lazy[block_r];
has_lazy[block_r] = false;
}
for(int i = block_r * block_size; i <= r; ++i)
nums[i] = x;
maxx[block_r] = max(maxx[block_r], x);
}
// 中间的桶
for(int block_id = block_l + 1; block_id < block_r; ++block_id)
{
lazy[block_id] = x;
has_lazy[block_id] = true;
maxx[block_id] = x;
}
}

int range_query(int l, int r)
{
// 对于中间跨过的整块,直接利用块保存的信息统计答案,两端剩余部分任然可以暴力扫描统计。
int block_l = l / block_size, block_r = r / block_size;
int result = nums[l];

// 左半桶
if(has_lazy[block_l])
result = maxx[block_l];
else
for(int i = l; i <= min((block_l + 1) * block_size - 1, r); ++i)
result = max(result, nums[i]);
// 右半桶
if(block_l != block_r)
{
if(has_lazy[block_r])
result = max(result, maxx[block_r]);
else
for(int i = block_r * block_size; i <= r; ++i)
result = max(result, nums[i]);
}
// 中间的桶
for(int block_id = block_l + 1; block_id < block_r; ++block_id)
result = max(result, maxx[block_id]);
return result;
}

vector<int> nums;
vector<int> maxx;
vector<int> lazy;
vector<bool> has_lazy;
int block_size, bn, n;
};

class Solution {
public:
vector<int> fallingSquares(vector<vector<int>>& positions) {
// 离散化
vector<int> x;
for(const vector<int>& pos: positions)
{
x.push_back(pos[0]);
x.push_back(pos[0] + pos[1] - 1);
}
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());

int n = x.size();
Block block;
vector<int> arr(n);
block.build(arr);
vector<int> result;
// sttree.traverse();
for(const vector<int>& pos: positions)
{
int start = _find(pos[0], x);
int end = _find(pos[0] + pos[1] - 1, x);
int v = pos[1];
int maxx = block.range_query(start, end);
block.range_update(start, end, maxx + v);
result.push_back(block.range_query(0, n - 1));
}
return result;
}

private:
int _find(int v, const vector<int>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

Share