归并树:区间大于x的元素个数

  |  

摘要: 归并树,区间大于 x 的元素个数

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文我们从 315. 计算右侧小于当前元素的个数 出发,引出一个维护区间大于 x 的元素个数的数据结构:归并树。

本题有三总主流解法,此前我们都解决过:

  1. 平衡树,参考文章 手撕平衡树-大小平衡树SBT手撕平衡树-Treap
  2. 归并排序,参考文章 线性索引表
  3. 权值线段树,参考文章 权值线段树、权值树状数组:元素排名区间的权值(个数)和

归并树的思考过程就是归并排序解决问题的算法出发,将其中的分治过程放到树种维护,所得的树形结构就是归并树。

归并排序过程

归并排序的过程。divide部分形成了一颗类似线段树的树;combine部分如果左区间[left, mid]对[mid+1, right]的结果有影响,就是CDQ分治的思想

普通分治在合并两个子问题的过程中(图中的 combine 部分),$[left, mid]$ 内的问题不会对 $[mid + 1, right]$ 内的问题产生影响,比如排序,线段树的求和、求极值。

而类似于归并排序的这种分治,合并两个子问题时,还考虑 $[left, mid]$ 内的修改对 $[mid + 1, right]$ 的结果产生的影响。这是一种基于时间的离线分治,也称为 CDQ 分治,其算法原理,以及解决问题的思路参考文章 离线分治:基于时间 (CDQ分治)

归并树

归并排序的过程构成了一颗类似线段树的树(图中的divede部分)。那它就可以支持区间查询。

归并树的核心思想,就是利用线段树的建树过程,将归并排序的过程保存。

归并树节点定义

与线段树的唯一区别是线段树节点存的是区间的某指标,例如和,最大值;归并树节点存的是区间内所有的值

1
2
3
4
5
6
7
8
9
struct MTNode
{
int start, end;
vector<int> data;
MTNode *left, *right;
MTNode(int s, int e, const vector<int>& nums, MTNode* l=nullptr, MTNode* r=nullptr)
:start(s),end(e),data(nums),left(l),right(r) {}
~MTNode(){}
};

建树

自底向上地建立节点:叶子节点 start=end, 直接存 nums[start];如果是非叶子节点,在回溯阶段把左右子节点的数据做归并,作为当前节点存的数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
void build(int start, int end, const vector<int>& nums)
{
root = _build(start, end, nums);
}

MTNode* _build(int start, int end, const vector<int>& nums)
{
if(start == end)
{
return new MTNode(start, end, vector<int>({nums[start]}));
}
int mid = start + (end - start) / 2;
MTNode *left = _build(start, mid, nums);
MTNode *right = _build(mid + 1, end, nums);
vector<int> merged((left -> data).size() + (right -> data).size());
merge((left -> data).begin(), (left -> data).end(), (right -> data).begin(), (right -> data).end(), merged.begin());
MTNode *cur = new MTNode(start, end, merged, left, right);
return cur;
}

区间查询

归并树的经典问题:查询的是区间 $[a, b]$ 中值大于 x 的个数

在归并树中找到$[a, b]$包含的所有不相交的区间对应的节点,分别进行一次二分查找。一次二分查找需要时间$O(\log N)$,而区间$[a, b]$至多被分为 $\log N$ 个不重叠的小区间,这样 $O(\log N)$ 可以得到答案。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
int query(int i, int j, int k)
{
if(i > j) return 0;
int result = 0;
_query(root, i, j, k, result);
return result;
}
void _query(MTNode* root, int i, int j, int k, int& result)
{
if(root -> start == i && root -> end == j)
{
auto pos = upper_bound((root -> data).begin(), (root -> data).end(), k);
result += (root -> data).end() - pos;
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid)
{
_query(root -> left, i, j, k, result);
return;
}
if(i > mid)
{
_query(root -> right, i, j, k, result);
return;
}
_query(root -> left, i, mid, k, result);
_query(root -> right, mid + 1, j, k, result);
}

完整代码 (C++,模板)

query(i, j, k) 返回 ${i, j}$ 中大于 $k$ 的个数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
// 区间查询 > x 的数
struct MTNode
{
int start, end;
vector<int> data;
MTNode *left, *right;
MTNode(int s, int e, const vector<int>& nums, MTNode* l=nullptr, MTNode* r=nullptr)
:start(s),end(e),data(nums),left(l),right(r) {}
~MTNode(){}
};

class MergeTree
{
public:
MergeTree()
{
root = nullptr;
}

~MergeTree()
{
if(root)
{
delete_sub_tree(root);
}
}

void delete_sub_tree(MTNode* node)
{
if(node -> left)
delete_sub_tree(node -> left);
if(node -> right)
delete_sub_tree(node -> right);
delete node;
node = nullptr;
}

void build(int start, int end, const vector<int>& nums)
{
root = _build(start, end, nums);
}

int query(int i, int j, int k)
{
if(i > j) return 0;
int result = 0;
_query(root, i, j, k, result);
return result;
}

int get(int i)
{
return (root -> data)[i];
}

private:
MTNode *root;

void _query(MTNode* root, int i, int j, int k, int& result)
{
if(root -> start == i && root -> end == j)
{
auto pos = upper_bound((root -> data).begin(), (root -> data).end(), k);
result += (root -> data).end() - pos;
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid)
{
_query(root -> left, i, j, k, result);
return;
}
if(i > mid)
{
_query(root -> right, i, j, k, result);
return;
}
_query(root -> left, i, mid, k, result);
_query(root -> right, mid + 1, j, k, result);
}

MTNode* _build(int start, int end, const vector<int>& nums)
{
if(start == end)
{
return new MTNode(start, end, vector<int>({nums[start]}));
}
int mid = start + (end - start) / 2;
MTNode *left = _build(start, mid, nums);
MTNode *right = _build(mid + 1, end, nums);
vector<int> merged((left -> data).size() + (right -> data).size());
merge((left -> data).begin(), (left -> data).end(), (right -> data).begin(), (right -> data).end(), merged.begin());
MTNode *cur = new MTNode(start, end, merged, left, right);
return cur;
}
};

代码测试

归并树解决静态数组上多次查询区间 [i, j] 中大于 k 的元素个数的问题的完整代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#include <vector>
#include <iostream>
#include <algorithm>

using namespace std;


int main()
{
vector<int> nums(20);
for(int i = 0; i < (int)nums.size(); ++i)
nums[i] = i;
for(int num: nums)
cout << num << " ";
cout << endl;
random_shuffle(nums.begin(), nums.end());
for(int num: nums)
cout << num << " ";
cout << endl;
MergeTree mergetree;
mergetree.build(0, (int)nums.size() - 1, nums);
while(true)
{
int i, j, k;
cin >> i >> j >> k;
cout << mergetree.query(i, j, k) << endl;
}
}

解决 315. 计算右侧小于当前元素的个数 的代码。

注:归并树解决这个问题效率并不高。以下代码测试用例都过了,但是耗时太长。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
int n = nums.size();
for(int &x: nums)
x = -x;
MergeTree *mergetree = new MergeTree();
mergetree -> build(0, (int)nums.size() - 1, nums);
vector<int> result(n);
for(int i = 0; i < n - 1; ++i)
result[i] = mergetree -> query(i + 1, n - 1, nums[i]);
return result;
}
};

归并树解决区间第 k 大问题

255. 第K小数

给定长度为 N 的整数序列 A,下标为 1 ∼ N。

现在要执行 M 次操作,其中第 i 次操作为给出三个整数 li,ri,ki,求 A[li],A[li+1],…,A[ri] 中第 ki 小的数是多少。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
输入格式
第一行包含两个整数 N 和 M。

第二行包含 N 个整数,表示整数序列 A。

接下来 M 行,每行包含三个整数 li,ri,ki,用以描述第 i 次操作。

输出格式
对于每次操作输出一个结果,表示在该次操作中,第 k 小的数的数值。每个结果占一行。

数据范围
N <= 1e5,
M <= 1e4,
|A[i]| <= 109

输入样例:
7 3
1 5 2 6 3 7 4
2 5 3
4 4 1
1 7 3
输出样例:
5
6
3

算法:归并树

已经根据原始数组建好归并树之后,可以 $O(\log^{2}N)$ 得到 $[i, j]$ 中大于 $x$ 的元素个数。

在归并树建树过程中可以顺便得到数组的最大值 $right$ 和最小值 $left$,答案肯定在 $[left, right]$。

值域二分:每次从答案范围 $[left, right]$ 中猜一个答案 $mid = (left + right) / 2$; 然后查询区间 $[left, right]$ 中大于 $mid$ 的个数 $cnt$,进而得到小于等于 $mid$ 的个数为 $j + 1 - i - cnt$。

我们要找的是从小到大排序后的第 k 位,因此二分的逻辑如下:

1
2
j + 1 - i - cnt < k:mid 猜小了,left = mid + 1
j + 1 - i - cnt >= k:mid + 1 肯定大了,mid 不一定,right = mid

时间复杂度 $O(N\log N+M\log^{3}N)$,其中前半部分是1次建树,后半部分是 M 次查询。

代码 (C++)

注:不是最优算法,超时。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int main()
{
int N, M;
cin >> N >> M;
vector<int> nums(N);
for(int i = 0; i < N; ++i)
cin >> nums[i];

MergeTree *mergetree = new MergeTree();
mergetree -> build(0, N - 1, nums);

for(int i = 0; i < M; ++i)
{
int left = mergetree -> get(0), right = mergetree -> get(N - 1);
int l, r, k;
cin >> l >> r >> k;
l--;
r--;
while(left < right)
{
int mid = left + (right - left) / 2;
int cnt = mergetree -> query(l, r, mid); // > mid 的元素个数
if(r + 1 - l - cnt < k)
left = mid + 1;
else
right = mid;
}
cout << left << endl;
}
}

Share