topK问题:第K个最大元素

  |  

摘要: TopK 问题的几种主流解法

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文我们看一个大家都研究的很透的问题,数组中的第 K 个最大元素。这种 TopK 问题有 3 种主流解法:

  1. 用堆始终维护前 K 个元素
  2. 快速选择算法
  3. 值域二分

有很多问题,topK 是其中的一个组件,如果数据比较简单,可以直接用 nth_element(nums.begin(), nums.begin() + k, nums.end());

对于复杂的问题,用结构体 Item 自定义堆中保存的元素,并自定义比较规则,然后定义 priority<Item, vector<Item>, Cmp> pq


$1 题目

给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。

请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。

提示:

1
2
1 <= k <= nums.length <= 1e5
-1e4 <= nums[i] <= 1e4

示例 1:

输入: [3,2,1,5,6,4] 和 k = 2
输出: 5
示例 2:

输入: [3,2,3,1,2,4,5,5,6] 和 k = 4
输出: 4

$2 题解

算法1:堆

把原始数据视为数据流,然后用堆始终维护前 K 个元素。对于数据很复杂的情况,一般要从堆中保存的元素,以及堆中元素的排序规则入手。

开一个小顶堆,用于维护枚举到 nums[i] 时的前 k 个最大元素。

1
2
// 直接写 priority_queue<int> pq; 是大顶堆,小顶堆是下面的写法
priority_queue<int, vector<int>, greater<int> > pq;

先把前 k 个数压进去。然后枚举剩余的数,只要堆顶小于当前枚举的数,就先弹出堆顶在压入当前值,保持堆里是前 K 个最大元素。枚举结束后,堆顶就是答案。

时间复杂度稳定 $O(N\log K)$

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1) return nums[0];
priority_queue<int, vector<int>, greater<int> > pq;
for(int i = 0; i < k; ++i)
pq.push(nums[i]);
for(int i = k; i < n; ++i)
if(pq.top() < nums[i])
{
pq.pop();
pq.push(nums[i]);
}
return pq.top();
}
};

算法2: 快速选择算法

快速选择算法,是一种减治算法,它的划分 partition 进而转化为子问题的思路与快排中的划分方式完全相同。按照该划分 partition 的思路把子问题的结果组织成树形结果保存下来共后续查询,得到划分树(与归并树类似,归并树保存的是归并排序的子问题结果)。

在子区间 [left, right] 中选择第 k 大的数时,完全照搬快排的划分算法:

(1) 选择一个枢轴(pivot),然后交换到 left 位置:

1
2
int randIdx = rand() % (right - left + 1) + left; // 随机选择 pivot
swap(nums[randIdx], nums[left]);

(2) 使用划分算法确定 pivot 的位置。大于 pivot 的元素移到左边,小于等于 pivot 的元素移到右边。

1
2
3
4
5
6
7
int pivot = nums[left];
int l = left, r = right;
while(l < r)
{
//...
}
// 一轮 partition 完成

一轮 partition 完成后,pivot 的位置为 l,此时考察 l 和 left + K 的关系,[left, right] 中第 K 个位置的下标是 left + K - 1:

1
2
3
l = left + K - 1: pivot 刚好在 [left, right] 第 k 个位置,找到答案了
l > left + K - 1: [left, l] 中有 l - left + 1 个数字,还要在 [l + 1, right] 中找 K - (l - left + 1) 个
l < left + K - 1: 在 [left, l - 1] 中继续找第 k 大

时间复杂度平均 $O(N)$,最坏 $O(N^{2})$,关于平均情况时间复杂度为 $O(N)$,涉及到一些推导,之后再算法分析的文章中再来讨论。

快速选择算法有一个优化:BFPRT算法,也叫中位数的中位数算法,它进一步优化了 pivot 的选取方法,使得最坏时间复杂度也变为 $O(N)$,它由Blum、Floyd、Pratt、Rivest、Tarjan提出。其推导过程也在算法分析中讨论。

快速选择算法每完成一轮 partition 将数据分成不均匀的两份后,只有其中一份对结果有影响,可以去掉一部分,数据规模就减少一部分,所以也叫减治算法。插入排序,DFS, BFS, 拓扑排序也隐含了减治的思想:每完成一轮,下一轮就可以少考虑一个数据。

代码(c++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1) return nums[0];
return partition(nums, k, 0, n - 1);
}

private:
int partition(vector<int>& nums, int k, int left, int right)
{
// 在 nums 的 [left .. right] 中找第 k 大
int randIdx = rand() % (right - left + 1) + left; // 随机选择 pivot
swap(nums[randIdx], nums[left]);
int pivot = nums[left];
int l = left, r = right;
while(l < r)
{
while(l < r && nums[r] <= pivot)
--r;
if(l < r)
{
nums[l] = nums[r];
++l;
}
while(l < r && nums[l] > pivot)
++l;
if(l < r)
{
nums[r] = nums[l];
--r;
}
}
nums[l] = pivot;
if(l == left + k - 1)
return nums[l];
else if(l < left + k - 1)
return partition(nums, k - (l - left + 1), l + 1, right);
else
return partition(nums, k, left, l - 1);
}
};

算法3: 值域二分

长 n 的数组,其元素只可能是第 1 大到第 n 大之间。最大值 maxx, 最小值 minx,则答案范围在 [minx, maxx]。在这个范围内始终二分地猜一个数,计算它的排名 x; 根据 x 与 k 的关系决定二分范围减小的方向。

这样时间复杂度为 $O(N\log U)$,$U$ 为数组取值范围。

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
return topk(nums, n + 1 - k);
}

private:
int topk(const vector<int>& nums, int k)
{
// 从小到大排第 k
int n = nums.size();
int maxx = nums[0], minx = nums[0];
for(int i = 1; i < n; ++i)
{
maxx = max(maxx, nums[i]);
minx = min(minx, nums[i]);
}
int left = minx, right = maxx;
while(left < right)
{
int mid = (left + right + 1) / 2;
int x = check(nums, mid);
if(x >= k)
right = mid - 1;
else
left = mid;
}
return left;
}

int check(const vector<int>& nums, int mid)
{
int ans = 0;
for(int i: nums)
if(i < mid)
++ans;
return ans;
}
};

Share