力扣215-快速选择算法,划分树

  |  
  • topK 问题有 3 种主流解法
  1. 把原始数据视为数据流,然后用堆始终维护前 K 个元素。困难的问题,一般要从堆中保存的元素,以及堆中元素的排序规则入手。
  2. 快速选择算法,是一种减治算法,它的划分 partition 进而转化为子问题的思路与快排中的划分方式完全相同。按照该划分 partition 的思路把子问题的结果组织成树形结果保存下来共后续查询,得到划分树(与归并树类似,归并树保存的是归并排序的子问题结果)
  3. 值域二分,长 n 的数组,其元素只可能是第 1 大到第 n 大之间。最大值 maxx, 最小值 minx,则答案范围在 [minx, maxx]。在这个范围内始终二分地猜一个数,计算它的排名 x; 根据 x 与 k 的关系决定二分范围减小的方向。
  • STL 的使用 — 有很多问题,topK 是其中的一个组件
    • 对于比较简单的 topK 问题,可以直接用 nth_element(nums.begin(), nums.begin() + k, nums.end());
    • 对于复杂的问题,用结构体 Item 自定义堆中保存的元素,并自定义比较规则,然后定义 priority<Item, vector<Item>, Cmp> pq,其中 Cmp 的写法
1
2
3
4
5
6
7
8
struct Cmp
{
// 返回 true -> item1 在堆中排前面
bool operator()(const Item& item1, const Item& item2) const
{
...
}
};
  • 区间 [i, j] 第 k 大问题 query(*i, j, k)
  1. 归并树的查询 query(i, j, x) 可以回答区间 [i, j] 中大于 x 的元素有多少个,因此可以二分地猜答案,然后用归并树的查询快的得到猜大了还是猜小了。
  2. 划分树将快速选择算法中,每一步分治的子问题结果保存下来。查询 query(i, j, k),时可以直接查找结果返回。
  3. 此外还有主席树整体二分(CDQ分治)等做法

$1 题目

题目链接

215. 数组中的第K个最大元素

题目描述

在未排序的数组中找到第 k 个最大的元素。请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

样例

示例 1:

输入: [3,2,1,5,6,4] 和 k = 2
输出: 5
示例 2:

输入: [3,2,3,1,2,4,5,5,6] 和 k = 4
输出: 4

$2 题解

算法1:堆

开一个小顶堆,用于维护枚举到 nums[i] 时的前 k 个最大元素

1
2
// 直接写 priority_queue<int> pq; 是大顶堆,小顶堆是下面的写法
priority_queue<int, vector<int>, greater<int> > pq;

先把前 k 个数压进去。然后枚举剩余的数,只要堆顶小于当前枚举的数,就先弹出堆顶在压入当前值,保持堆里是前 K 个最大元素。枚举结束后,堆顶就是答案。

时间复杂度稳定 $O(N\log K)$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1) return nums[0];
priority_queue<int, vector<int>, greater<int> > pq;
for(int i = 0; i < k; ++i)
pq.push(nums[i]);
for(int i = k; i < n; ++i)
if(pq.top() < nums[i])
{
pq.pop();
pq.push(nums[i]);
}
return pq.top();
}
};

算法2: 快速选择算法

在子区间 [left, right] 中选择第 k 大的数时,完全照搬快排的划分算法:

  1. 选择一个枢轴(pivot),然后交换到 left 位置
1
2
int randIdx = rand() % (right - left + 1) + left; // 随机选择 pivot
swap(nums[randIdx], nums[left]);
  1. 使用划分算法确定 pivot 的位置。大于 pivot 的元素移到左边,小于等于 pivot 的元素移到右边。
    1
    2
    3
    4
    5
    6
    7
    int pivot = nums[left];
    int l = left, r = right;
    while(l < r)
    {
    //...
    }
    // 一轮 partition 完成

一轮 partition 完成后,pivot 的位置为 l,此时考察 l 和 left + K 的关系,[left, right] 中第 K 个位置的下标是 left + K - 1:

1
2
3
l = left + K - 1: pivot 刚好在 [left, right] 第 k 个位置,找到答案了
l > left + K - 1: [left, l] 中有 l - left + 1 个数字,还要在 [l + 1, right] 中找 K - (l - left + 1) 个
l < left + K - 1: 在 [left, l - 1] 中继续找第 k 大

时间复杂度平均 $O(N)$,最坏 $O(N^{2})$,但是最坏情况太难达到了,一般还是认为快速选择算法是 O(N) 的。

快速选择算法有一个优化:BFPRT算法,也叫中位数的中位数算法,它进一步优化了 pivot 的选取方法,使得最坏时间复杂度也变为 $O(N)$,它由Blum、Floyd、Pratt、Rivest、Tarjan提出。

快速选择算法每完成一轮 partition 将数据分成不均匀的两份后,只有其中一份对结果有影响,可以去掉一部分,数据规模就减少一部分,所以也叫减治算法。插入排序,DFS, BFS, 拓扑排序也隐含了减治的思想:每完成一轮,下一轮就可以少考虑一个数据。

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1) return nums[0];
return partition(nums, k, 0, n - 1);
}

private:
int partition(vector<int>& nums, int k, int left, int right)
{
// 在 nums 的 [left .. right] 中找第 k 大
int randIdx = rand() % (right - left + 1) + left; // 随机选择 pivot
swap(nums[randIdx], nums[left]);
int pivot = nums[left];
int l = left, r = right;
while(l < r)
{
while(l < r && nums[r] <= pivot)
--r;
if(l < r)
{
nums[l] = nums[r];
++l;
}
while(l < r && nums[l] > pivot)
++l;
if(l < r)
{
nums[r] = nums[l];
--r;
}
}
nums[l] = pivot;
if(l == left + k - 1)
return nums[l];
else if(l < left + k - 1)
return partition(nums, k - (l - left + 1), l + 1, right);
else
return partition(nums, k, left, l - 1);
}
};

算法3: 值域二分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
return topk(nums, n + 1 - k);
}

private:
int topk(const vector<int>& nums, int k)
{
// 从小到大排第 k
int n = nums.size();
int maxx = nums[0], minx = nums[0];
for(int i = 1; i < n; ++i)
{
maxx = max(maxx, nums[i]);
minx = min(minx, nums[i]);
}
int left = minx, right = maxx;
while(left < right)
{
int mid = (left + right + 1) / 2;
int x = check(nums, mid);
if(x >= k)
right = mid - 1;
else
left = mid;
}
return left;
}

int check(const vector<int>& nums, int mid)
{
int ans = 0;
for(int i: nums)
if(i < mid)
++ans;
return ans;
}
};

$3 扩展: 区间第 k 大

1
int query(i, j, k) 返回区间 [i, j] 的第 k 大

共 M 次查询。

1. 归并树

归并树的思考过程:力扣315-索引数组,归并树

已经根据原始数组建好归并树之后,可以 $O(\log^{2}N)$ 得到 [i, j] 中大于 x 的元素个数

在归并树建树过程中可以顺便得到数组的最大值 right 和最小值 left,答案肯定在 [left, right]。

值域二分:每次从答案范围 [left, right] 中猜一个答案 mid = (left + right) / 2; 然后查询区间 [left, right] 中大于 mid 的个数 cnt

1
2
3
cnt = k - 1: mid 是答案,返回 mid
cnt > k - 1: mid 猜小了,left = mid + 1 继续猜
cnt < k - 1: mid + 1 肯定大了, right = mid 继续猜

时间复杂度 $O(N\log N+\log^{3}N)$ ,明显弱于快速选择算法的 $O(N)$。但如果是 M 次查询不同的区间,则为 $O(N\log N+M\log^{3}N)$ ,其中前半部分是1次建树,后半部分是 M 次查询。这比快速选择算法的 $O(MN)$ 好。

用归并树AC本题的代码(仅 1 次查询,i = 0, j = n - 1)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
struct MTNode
{
int start, end;
vector<int> data;
MTNode *left, *right;
MTNode(int s, int e, const vector<int>& nums, MTNode* l=nullptr, MTNode* r=nullptr)
:start(s),end(e),data(nums),left(l),right(r) {}
~MTNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};

class MergeTree
{
public:
MergeTree()
{
root = nullptr;
}

~MergeTree()
{
if(root)
{
delete root;
root = nullptr;
}
}

void build(int start, int end, const vector<int>& nums)
{
root = _build(start, end, nums);
}

int query(int i, int j, int k)
{
if(i > j) return 0;
int result = 0;
_query(root, i, j, k, result);
return result;
}

int get(int i)
{
return (root -> data)[i];
}

private:
MTNode *root;

void _query(MTNode* root, int i, int j, int k, int& result)
{
if(root -> start == i && root -> end == j)
{
auto pos = upper_bound((root -> data).begin(), (root -> data).end(), k);
result += (root -> data).end() - pos;
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid)
{
_query(root -> left, i, j, k, result);
return;
}
if(i > mid)
{
_query(root -> right, i, j, k, result);
return;
}
_query(root -> left, i, mid, k, result);
_query(root -> right, mid + 1, j, k, result);
}

MTNode* _build(int start, int end, const vector<int>& nums)
{
if(start == end)
{
return new MTNode(start, end, vector<int>({nums[start]}));
}
int mid = start + (end - start) / 2;
MTNode *left = _build(start, mid, nums);
MTNode *right = _build(mid + 1, end, nums);
vector<int> merged((left -> data).size() + (right -> data).size());
merge((left -> data).begin(), (left -> data).end(), (right -> data).begin(), (right -> data).end(), merged.begin());
MTNode *cur = new MTNode(start, end, merged, left, right);
return cur;
}
};

class Solution_3 {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1) return nums[0];
MergeTree mergetree;
mergetree.build(0, n - 1, nums);
int left = mergetree.get(0), right = mergetree.get(n - 1);
while(left < right)
{
int mid = left + (right - left) / 2;
int cnt = mergetree.query(0, n - 1, mid); // > mid 的元素个数
if(cnt == k - 1)
return mid;
else if(cnt > k - 1)
left = mid + 1;
else
right = mid;
}
return left;
}
};

2. 划分树

划分树和归并树都是用线段树作为辅助的,其中划分树的建树是模拟快排过程,自顶向下建树,归并树是模拟归并排序过程,自底向上建树,两种方法含有分治思想。

划分树的基本思想就是对于某个节点对应的区间 [start, end],把它划分成两个子区间,左边区间的数小于等于右边区间的数,左子区间对应划分后的 [start, mid],右子区间对应划分后的 [mid + 1, end],其中 mid = (left + right) / 2。查找的时候通过记录进入左子树的数的个数,确定下一个查找区间,直到区间长度变成1(start = end, 叶子节点),就找到了。

划分树的节点定义

1
2
3
4
5
start, end:节点对应的区间
nums[0..end-start]:节点持有的数据,一共 end - start + 1 个,是原始数组排好序之后在 [start, end] 范围的数字,但这里 nums 并没按排好序的顺序存放,而是按原数组的顺序放的。它具体含了哪些数字来自父节点划分的结果
toleft[0..end-start+2]:记录进入左子树的数字的个数
toleft[i] := [0 .. i - 1] 这 i 个数中,进入左子树的数字的个数,有前缀和的思想
toleft[0] = 0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
struct PTNode
{
int start, end;
vector<int> nums, toleft;
PTNode *left, *right;
PTNode(int s, int e, PTNode* l=nullptr, PTNode* r=nullptr)
:start(s),end(e),nums(vector<int>(end - start + 1)),toleft(vector<int>(end - start + 2)),left(l),right(r) {}
~PTNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};

划分树的建树

过程基本就是模拟快排过程,对于区间 [start, end] 取一个已经排过序的区间中位数,然后把小于中值的点放左边,大于的放右边,等于中位数的需要单独统计,使得进入左子树的数字个数为 [mid - left + 1],确保树是平衡的。划分树建树之前先求取并持有原数组排序后的数组 sorted,节点的中位数直接就取 sorted[mid]。

1
int median =  sorted[mid];

只要当前节点不是叶子节点(start = end),就先建好两个子节点

1
2
PTNode *left = new PTNode(root -> start, mid);
PTNode *right = new PTNode(mid + 1, root -> end);

然后执行划分的算法流程,顺序枚举当前节点的所有数据,判断应该划分的方向,往对应的子节点塞进去。如果是塞进了左子树,toleft 加一。

[5, 2, 6, 1] 的建树过程

划分树的查询

query(i, j, k) 查询区间 [i, j] 中第 k 大的数

首先确定 [start ..i - 1], [start..j] 中去往左子树的数字个数 tli, tlj。其中 start = i 的情况需要特判。tlj - tli 是 [i..j] 中去往左子树的元素个数,记 tl。

1
2
3
4
5
int tli = 0;
if(root -> start != i)
tli = (root -> toleft)[i - root -> start];
int tlj = (root -> toleft)[j - root -> start + 1];
int tl = tlj - tli;

看 k 和 tl 的关系决定子查询,类似于快速选择算法中考察 l 和 left + K 的关系。

1
2
3
tl >= k: 第 k 大在左子树,答案是 query(root -> left, new_i, new_j, k)
tl < k: 第 k 大在右子树,答案是 query(root -> right, new_i, new_j, k - tl)
new_i, new_j 需要画图推导

用划分树AC本题的代码(仅 1 次查询,i = 0, j = n - 1)

但是跑的效果很差,因为查询只有1次,需要查询次数较多时才能体现划分树的优势

查询次数为 M 时,时间复杂度 $O(N\log N + M\log N)$,比 M 次快速选择的 $O(MN)$ 和归并树的 $O(N\log N + M\log^{3}N)$ 都好。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
struct PTNode
{
int start, end;
vector<int> nums, toleft;
PTNode *left, *right;
PTNode(int s, int e, PTNode* l=nullptr, PTNode* r=nullptr)
:start(s),end(e),nums(vector<int>(end - start + 1)),toleft(vector<int>(end - start + 2)),left(l),right(r) {}
~PTNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};

class PartitionTree
{
public:
PartitionTree()
{
root = nullptr;
sorted = vector<int>();
}

~PartitionTree()
{
if(root)
{
delete root;
root = nullptr;
}
}

void build(int start, int end, const vector<int>& nums)
{
sorted = nums;
sort(sorted.begin(), sorted.end());
root = new PTNode(start, end);
root -> nums = nums;
_build(root);
}

int query(int i, int j, int k)
{
if(i > j) return 0;
return _query(root, i, j, k);
}

private:
PTNode *root;
vector<int> sorted;

int _query(PTNode* root, int i, int j, int k)
{
if(root -> start == root -> end) return (root -> nums)[0];

int tli = 0;
if(root -> start != i)
tli = (root -> toleft)[i - root -> start];
int tlj = (root -> toleft)[j - root -> start + 1];
int tl = tlj - tli;
int new_i, new_j;
if(tl >= k)
{
// 第 k 大在左子
new_i = root -> start + tli;
new_j = new_i + tl - 1;
return _query(root -> left, new_i, new_j, k);
}
else
{
// 第 k 大在右子
int mid = root -> start + (root -> end - root -> start) / 2;
new_i = mid + 1 + i - (root -> start) - tli;
new_j = new_i + j - i - tl;
return _query(root -> right, new_i, new_j, k - tl);
}
}

void _build(PTNode* root)
{
if(root -> start == root -> end)
return;
int mid = root -> start + (root -> end - root -> start) / 2;
int median = sorted[mid];
PTNode *left = new PTNode(root -> start, mid);
PTNode *right = new PTNode(mid + 1, root -> end);
int n = (root -> nums).size();
int median_to_left = mid - root -> start + 1;
for(int i = 0; i < n; ++i)
{
if((root -> nums)[i] < median)
--median_to_left;
}
// 出循环后 median_to_left 为去往左子树中等于中位数的个数
int to_left = 0; // 去往左子树的个数
int idx_left = 0, idx_right = 0;
for(int i = 0; i < n; ++i)
{
int cur = (root -> nums)[i];
if(cur < median || ((cur == median) && median_to_left > 0))
{
(left -> nums)[idx_left] = cur;
++idx_left;
++to_left;
if(cur == median)
--median_to_left;
}
else
{
(right -> nums)[idx_right] = cur;
++idx_right;
}
(root -> toleft)[i + 1] = to_left;
}
_build(left);
_build(right);
root -> left = left;
root -> right = right;
}
};

class Solution_4 {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1) return nums[0];
PartitionTree partitiontree;
partitiontree.build(0, n - 1, nums);
return partitiontree.query(0, n - 1, n + 1 - k);
}
};


Share