力扣315-索引数组,归并树

  |  
  • 索引数组以及在索引数组上做CDQ分治
  • 归并树的思考过程:归并排序 -> CDQ分治 -> 归并树
  • 归并树应用:无修改区间大于 x 的元素个数问题

$1 题目

题目链接

315. 计算右侧小于当前元素的个数

题目描述

给定一个整数数组 nums,按要求返回一个新数组 counts。数组 counts 有该性质: counts[i] 的值是 nums[i] 右侧小于 nums[i] 的元素的数量。

样例

示例:

输入: [5,2,6,1]
输出: [2,1,1,0]
解释:
5 的右侧有 2 个更小的元素 (2 和 1).
2 的右侧仅有 1 个更小的元素 (1).
6 的右侧有 1 个更小的元素 (1).
1 的右侧有 0 个更小的元素.

$2 题解

算法1: 离散化+权值线段树,权值树状数组

493. 翻转对 思路一样。

算法内容参考 力扣493-离散化,权值线段树,权值树状数组

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
struct STNode
{
int start, end;
int sum;
STNode *left, *right;
STNode(int s, int e, int v, STNode* l=nullptr, STNode* r=nullptr)
:start(s), end(e), sum(v), left(l), right(r) {}
~STNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};

class SegmentTree
{
public:
SegmentTree()
{
root = nullptr;
}

~SegmentTree()
{
if(root)
{
delete root;
root = nullptr;
}
}

void build(int start, int end)
{
if(start <= end)
root = _build(start, end);
}

void add(int index)
{
_add(root, index);
}

int query(int i, int j)
{
if(i > j)
return 0;
return _query(root, i, j);
}

private:
STNode *root;

int _query(STNode* root, int i, int j)
{
if(!root) return 0;
if(root -> start == i && root -> end == j)
return root -> sum;
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid)
return _query(root -> left, i, j);
if(mid < i)
return _query(root -> right, i, j);
return _query(root -> left, i, mid) + _query(root -> right, mid + 1, j);
}

void _add(STNode* root, int index)
{
if(!root) return;
if(root -> start == index && root -> end == index)
{
++(root -> sum);
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(mid >= index)
_add(root -> left, index);
else
_add(root -> right, index);
++(root -> sum);
}

STNode* _build(int start, int end)
{
if(start == end)
return new STNode(start, end, 0);
int mid = start + (end - start) / 2;
STNode *left = _build(start, mid);
STNode *right = _build(mid + 1, end);
return new STNode(start, end, 0, left, right);
}
};

class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
if(nums.empty()) return vector<int>();
int n = nums.size();
vector<int> x(nums.begin(), nums.end()); // 离散化后的值
sort(x.begin(), x.end());
x.erase(unique(x.begin(), x.end()), x.end());
int m = x.size();

SegmentTree segmenttree;
segmenttree.build(0, m - 1);
vector<int> result(n);
result[n - 1] = 0;
segmenttree.add(_find(nums[n - 1], x));
for(int i = n - 2; i >= 0; --i)
{
result[i] = segmenttree.query(0, _find(nums[i], x) - 1);
cout << _find(nums[i], x) - 1 << endl;
segmenttree.add(_find(nums[i], x));
}
return result;
}

private:
int _find(int v, const vector<int>& x)
{
return lower_bound(x.begin(), x.end(), v) - x.begin();
}
};

算法2: 索引数组 + CDQ 分治

线性索引表

思路仍然与 493 相同,在归并排序的归并阶段统计信息。但是本题有所不同:归并阶段统计信息时要追溯元素在原始数组中的位置。正常的归并排序过程中,元素的位置已经被打乱,此时要追溯元素在原数组的位置,需要引入索引数组的思想。

考察归并排序过程中的其中一次归并,现在两个子数组 p[left .. mid], q[mid + 1 .. right] 已经有序,现在正在归并。当 p 前进到 i 位置,q 前进到 j 位置时,p[left .. i - 1], q[mid + 1, .. j - 1] 都已经进入了归并排序的辅助数组。此时若 p[i] <= q[j], p[i] 准备进入辅助数组,在进入之前,先要求此时 q 中已近有多少个数字进入了辅助数组,答案是 j - mid。这个是本轮归并对 p[i] 这个数的答案的贡献,因为这些是在 p[i] 的右边且小于 p[i] 的。把历次归并对该数的贡献相加,就得到 p[i] 对应的答案。

索引数组

j - mid 这个贡献要追加到 p[i] 在原数组中的位置上,假设 p[i] 在原数组中的位置是 idx, 则更新贡献值的操作是 result[idx] += j - mid; 但是在归并排序过程中,p[i] 在原数组中的位置已经改变,此时要获取 p[i] 在原数组中的位置 idx,需要一个索引数组。

索引数组的思想:元素在算法的流程中位置变化了,但是后续还需要定位元素在原数组中的位置。此时建立原数组的索引数组,例如 nums = [5, 2, 6, 1], 索引数组为 indexes = [0, 1, 2, 3],在算法流程中,nums 始终保持不变,改变位置的是 indexes 中的元素,以前是算法完成后需要 nums[i] 有序的性质,现在是要 nums[indexes[i]] 有序的性质。

有类似的思想的结构还有索引堆:堆在插入,更新数据之后,需要做一步向下更新或向上更新,这一步堆数据在数组中的位置就变了。
但是有时在堆的不断更新过程中,需要定位到元素在堆数组中的位置,例如 239. 滑动窗口最大值
此时使用索引堆,保存堆数据的数组始终保持不变,不断更新位置的是索引数组。原来在更新算法流程中始终对 nums[i] 保持堆的性质,现在是对 nums[index[i]] 保持堆的性质。

代码(c++)

代码与朴素归并排序基本不变,以前 nums[i] 改变位置的地方,改为 indexes[i] 改变位置,以前给定 i 获取 nums[i] 的地方,改为获取 nums[indexes[i]]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class Solution_2 {
public:
vector<int> countSmaller(vector<int>& nums) {
if(nums.empty()) return vector<int>();
int n = nums.size();
if(n == 1) return vector<int>({0});
vector<int> indexes(n, 0);
for(int i = 1; i < n; ++i)
indexes[i] = i;
vector<int> result(n, 0);
_mergesort(nums, indexes, result, 0, n - 1);
return result;
}

private:
void _mergesort(const vector<int>& nums, vector<int>& indexes, vector<int>& result, int left, int right)
{
if(left == right) return;
int mid = left + (right - left) / 2;
_mergesort(nums, indexes, result, left, mid);
_mergesort(nums, indexes, result, mid + 1, right);
_merge(nums, indexes, result, left, mid, right);
}

void _merge(const vector<int>& nums, vector<int>& indexes, vector<int>& result, int left, int mid, int right)
{
vector<int> tmp(right - left + 1, 0);
int i = left, j = mid + 1, k = 0;
while(i <= mid && j <= right)
{
int pi = nums[indexes[i]], qj = nums[indexes[j]];
if(pi <= qj)
{
result[indexes[i]] += (j - mid - 1);
tmp[k++] = indexes[i++];
}
else
tmp[k++] = indexes[j++];
}
while(i <= mid)
{
result[indexes[i]] += (j - mid - 1);
tmp[k++] = indexes[i++];
}
while(j <= right) tmp[k++] = indexes[j++];
for(i = left, k = 0; i <= right; ++i, ++k)
indexes[i] = tmp[k];
}
};

$3 扩展: 归并树与无修改区间大于x的元素个数

1) 归并排序过程

归并排序的过程。divide部分形成了一颗类似线段树的树;combine部分如果左区间[left, mid]对[mid+1, right]的结果有影响,就是CDQ分治的思想

2) CDQ分治

普通分治在合并两个子问题的过程中(图中的combine部分),[left, mid]内的问题不会对[mid + 1, right]内的问题产生影响,比如排序,线段树的求和、求极值。而CDQ分治合并两个子问题时,还考虑[left, mid]内的修改对[mid + 1, right] 的结果产生的影响,例如求逆序对个数。归并排序CDQ分治的基础。

3) 归并树

归并排序的过程构成了一颗类似线段树的树(图中的divede部分)。那它就可以支持区间查询。

归并树的核心思想,就是利用线段树的建树过程,将归并排序的过程保存。

归并树节点定义

与线段树的唯一区别是线段树节点存的是区间的某指标,例如和,最大值;归并树节点存的是区间内所有的值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
struct MTNode
{
int start, end;
vector<int> data;
MTNode *left, *right;
MTNode(int s, int e, const vector<int>& nums, MTNode* l=nullptr, MTNode* r=nullptr)
:start(s),end(e),data(nums),left(l),right(r) {}
~MTNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};

建树

自底向上地建立节点:叶子节点 start=end, 直接存 nums[start];如果是非叶子节点,在回溯阶段把左右子节点的数据做归并,作为当前节点存的数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
void build(int start, int end, const vector<int>& nums)
{
root = _build(start, end, nums);
}

MTNode* _build(int start, int end, const vector<int>& nums)
{
if(start == end)
{
return new MTNode(start, end, vector<int>({nums[start]}));
}
int mid = start + (end - start) / 2;
MTNode *left = _build(start, mid, nums);
MTNode *right = _build(mid + 1, end, nums);
vector<int> merged((left -> data).size() + (right -> data).size());
merge((left -> data).begin(), (left -> data).end(), (right -> data).begin(), (right -> data).end(), merged.begin());
MTNode *cur = new MTNode(start, end, merged, left, right);
return cur;
}

区间查询

归并树的经典问题:查询的是区间 [a, b] 中值大于 x 的个数

在归并树中找到[a, b]包含的所有不相交的区间对应的节点,分别进行一次二分查找。一次二分查找需要时间O(logN),而区间[a, b]至多被分为 logN 个不重叠的小区间,这样 [公式] 可以得到答案。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
int query(int i, int j, int k)
{
if(i > j) return 0;
int result = 0;
_query(root, i, j, k, result);
return result;
}
void _query(MTNode* root, int i, int j, int k, int& result)
{
if(root -> start == i && root -> end == j)
{
auto pos = upper_bound((root -> data).begin(), (root -> data).end(), k);
result += (root -> data).end() - pos;
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid)
{
_query(root -> left, i, j, k, result);
return;
}
if(i > mid)
{
_query(root -> right, i, j, k, result);
return;
}
_query(root -> left, i, mid, k, result);
_query(root -> right, mid + 1, j, k, result);
}

区间 [i, j] 中大于 k 的元素个数

归并树解决静态数组上多次查询区间 [i, j] 中大于 k 的元素个数的问题的完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <vector>
#include <iostream>
#include <algorithm>

using namespace std;

// 区间查询 > x 的数
struct MTNode
{
int start, end;
vector<int> data;
MTNode *left, *right;
MTNode(int s, int e, const vector<int>& nums, MTNode* l=nullptr, MTNode* r=nullptr)
:start(s),end(e),data(nums),left(l),right(r) {}
~MTNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};

class MergeTree
{
public:
MergeTree()
{
root = nullptr;
}

~MergeTree()
{
if(root)
{
delete root;
root = nullptr;
}
}

void build(int start, int end, const vector<int>& nums)
{
root = _build(start, end, nums);
}

int query(int i, int j, int k)
{
if(i > j) return 0;
int result = 0;
_query(root, i, j, k, result);
return result;
}

private:
MTNode *root;

void _query(MTNode* root, int i, int j, int k, int& result)
{
if(root -> start == i && root -> end == j)
{
auto pos = upper_bound((root -> data).begin(), (root -> data).end(), k);
result += (root -> data).end() - pos;
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid)
{
_query(root -> left, i, j, k, result);
return;
}
if(i > mid)
{
_query(root -> right, i, j, k, result);
return;
}
_query(root -> left, i, mid, k, result);
_query(root -> right, mid + 1, j, k, result);
}

MTNode* _build(int start, int end, const vector<int>& nums)
{
if(start == end)
{
return new MTNode(start, end, vector<int>({nums[start]}));
}
int mid = start + (end - start) / 2;
MTNode *left = _build(start, mid, nums);
MTNode *right = _build(mid + 1, end, nums);
vector<int> merged((left -> data).size() + (right -> data).size());
merge((left -> data).begin(), (left -> data).end(), (right -> data).begin(), (right -> data).end(), merged.begin());
MTNode *cur = new MTNode(start, end, merged, left, right);
return cur;
}
};

int main()
{
vector<int> nums(20);
for(int i = 0; i < (int)nums.size(); ++i) nums[i] = i;
for(int num: nums) cout << num << " ";
cout << endl;
random_shuffle(nums.begin(), nums.end());
for(int num: nums) cout << num << " ";
cout << endl;
MergeTree mergetree;
mergetree.build(0, (int)nums.size() - 1, nums);
while(true)
{
int i, j, k;
cin >> i >> j >> k;
cout << mergetree.query(i, j, k) << endl;
}
}

Share