【二分难题】力扣4-寻找两个正序数组的中位数

  |  

摘要: 各种二分

【对算法,数学,计算机感兴趣的同学,欢迎关注我哈,阅读更多原创文章】
我的网站:潮汐朝夕的生活实验室
我的公众号:算法题刷刷
我的知乎:潮汐朝夕
我的github:FennelDumplings
我的leetcode:FennelDumplings


本文看一个比较难的二分的问题。在两个有序数组中找中位数。可以基于两个数组有序的性质直接在整个区间范围二分。也可以对中位数的值进行值域二分。此外还可以参考 TopK 问题的做法,用减治的思想解决(二分也是一种减治)。

(1) 二分中位数的数值本身,这属于值域二分。思路上类似于 TopK 问题的值域二分算法。
(2) 减治法,思路上类似于 TopK 问题的减治算法。
(3) 二分中位数所在的位置,这是基于两个数组均有序的做法。

$1 题目

给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。

算法的时间复杂度应该为 $O(\log (m+n))$ 。

提示:

1
2
3
4
5
6
nums1.length == m
nums2.length == n
0 <= m <= 1000
0 <= n <= 1000
1 <= m + n <= 2000
-1e6 <= nums1[i], nums2[i] <= 1e6

示例 1:
输入:nums1 = [1,3], nums2 = [2]
输出:2.00000
解释:合并数组 = [1,2,3] ,中位数 2

示例 2:
输入:nums1 = [1,2], nums2 = [3,4]
输出:2.50000
解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5

$2 题解

对于单串的中位数,如果想避免奇偶性的讨论,可以返回:

1
(a[n/2] + a[(n - 1)/2]) / 2;

算法1: 值域二分

值域二分的思路类似于在无序数组中求第 K 大元素中的值域二分算法,参考 topK问题汇总

两个数组的长度之和为 len = m + n

将两个数组视为一个长度为 len 的数组,通过值域二分的方式找出 a[len / 2]a[(len - 1) / 2],对于 len 为奇数,只需要求出 a[len / 2] 即可,对于 len 为偶数,相当于做两次值域二分。

当前正在求 a[len / 2] 猜的答案为 mid,在 nums1 中大于等于 mid 的最小下标为 i,在 nums2 中大于等于 mid 的最小下标为 j,比 mid 小的元素个数为 i + j 个。

如果 a[len / 2] = mid ,则比 mid 小的元素个数最多 len / 2 个,因此判定后删掉一半值域的逻辑如下:

1
2
3
4
if(i + j > len / 2)
right = mid;
else
left = mid;

时间复杂度 $O(2(\log U)(\log N+\log M))$,$U$ 为数组的取值范围。

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
const double EPS = 1e-9;
int n = nums1.size(), m = nums2.size();
int len = n + m;
double left = INT_MIN, right = INT_MAX;
while(left + EPS < right)
{
double mid = (left + right) / 2;
int i = upper_bound(nums1.begin(), nums1.end(), mid) - nums1.begin();
int j = upper_bound(nums2.begin(), nums2.end(), mid) - nums2.begin();
if(i + j > len / 2)
right = mid;
else
left = mid;
}
// a[len / 2] = mid
if(len & 1)
return left;
double tmp = left;
left = INT_MIN, right = INT_MAX;
while(left + EPS < right)
{
double mid = (left + right) / 2;
int i = upper_bound(nums1.begin(), nums1.end(), mid) - nums1.begin();
int j = upper_bound(nums2.begin(), nums2.end(), mid) - nums2.begin();
if(i + j > (len - 1) / 2)
right = mid;
else
left = mid;
}
return (left + tmp) / 2;
}
};

算法2:减治

转换为用减治法找数组中的第 k 大的问题, 其中 k = (N + 1) / 2, N = n + m,思路上与 TopK 问题的快速选择算法类似,参考 topK问题汇总

定义以下子问题:

1
2
3
int solve(nums1, i1, nums2, i2, k)

求 nums1[i1..n1-1] 和 nums2[i2..n2-1] 的第 k 大

两个数组有序且要找第 k 大,如果找到了两个数组的第 K / 2 大,即 nums1[K/2], nums2[K/2], 不妨设 nums1[K/2] < nums2[K/2],则 nums1 中前 K/2 个数一定是排在合并后的数组中的第 K 位之前,而 nums2 就不一定有这个性质了,此时 nums1 的前 K / 2 个数一定不可能是答案,此时问题变成了在剩下的范围里找第 K-K/2 大的数

1
solve(nums1, 0, nums2, 0, k)  ->  solve(nums1, k/2, nums2, 0, k - k/2)

例子

1
2
3
4
5
nums1 = [1,3,4,9]
nums2 = [1,2,3,4,5,6,7,8,9,10]

n = 4, m = 10, N = n + m = 14
k = (N + 1) / 2 = 7;

首先找到两个数组中的第 K/2 = 3 大的数,其中 nums2 的较小,因此 nums2 的三个数排除

此时变成在剩下的范围里面找第 k-k/2 = 7 - 3 = 4 大的数,在两个数组中找 4/2 = 2 大的数,其中 nums1 的较小,因此 nums1 的 2 个数排除

此时变成在剩下的范围里面找第 4-4/2 = 2 大的数,在两个数组中找 2/2 = 1 大的数,nums1 的和 nums2 的相等,排除哪一半都可以

  • 在减治的过程中,可能出现数组长 < K/2的情况,这时就取数组末尾比较即可。
  • 可以先求第 (N+1)/2 大,再求 (N+2)/2 大,再取平均,避免奇偶性判断。
  • 时间复杂度 $O(k) = O(\log (n+m))$,此外这里额递归是尾递归,额外空间复杂度为 $O(1)$,

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n1 = nums1.size();
int n2 = nums2.size();
if(n1 > n2) return findMedianSortedArrays(nums2, nums1);
int n = n1 + n2;
int k = (n + 1) / 2; // 找有序数组的第 k 大, 对应索引 k - 1;
if(n1 == 0)
{
if(n % 2 == 1)#
return (double)nums2[k - 1];
else
return ((double)nums2[k - 1] + nums2[k]) / 2;
}
int l1 = 0, l2 = 0;
return _findKth(nums1, nums2, l1, l2, n1, n2, k);
}

private:
double _findKth(vector<int>& nums1, vector<int>& nums2, int l1, int l2, int n1, int n2, int k)
{
if(l1 == n1)
{
double small = (double)nums2[l2 + k - 1];
if((n1 + n2) % 2 == 1)
return small;
else
{
double large = nums2[l2 + k];
return (small + large) / 2.0;
}
}
if(l2 == n2) // 保证了 nums1 长度 <= nums2 长度之后,这组判断条件可以去掉
{
double small = (double)nums1[l1 + k - 1];
if((n1 + n2) % 2 == 1)
return small;
else
{
double large = nums1[l1 + k];
return (small + large) / 2.0;
}
}
if(k == 1)
{
if((n1 + n2) % 2 == 1)
return (double)min(nums1[l1], nums2[l2]);
else
{
double small, large;
if(nums1[l1] <= nums2[l2])
{
small = nums1[l1];
if(l1 + 1 < n1)
large = min(nums1[l1 + 1], nums2[l2]);
else large = nums2[l2];
}
else
{
small = nums2[l2];
if(l2 + 1 < n2)
large = min(nums2[l2 + 1], nums1[l1]);
else large = nums1[l1];
}
return (small + large) / 2;
}
}
int kk = k / 2;
int i = l1 + kk - 1;
if(i >= n1)
i = n1 - 1;
int j = l2 + kk - 1;
if(j >= n2)
i = n2 - 1;
if(nums1[i] <= nums2[j])
return _findKth(nums1, nums2, i + 1, l2, n1, n2, k - (i - l1 + 1));
else
return _findKth(nums1, nums2, l1, j + 1, n1, n2, k - (j - l2 + 1));
}
};

简洁写法,避免奇偶判断

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n1 = nums1.size();
int n2 = nums2.size();
if(n1 > n2)
return findMedianSortedArrays(nums2, nums1);
int n = n1 + n2;
int k1 = (n + 1) / 2;
int k2 = (n + 2) / 2; // 直接找 2 次 K, 如果 n 为奇数,返回两次相同的值,可避免奇偶判断
int l1 = 0, l2 = 0;
return (_findKth(nums1, nums2, l1, l2, n1, n2, k1)
+ _findKth(nums1, nums2, l1, l2, n1, n2, k2)) / 2.0;
}
private:
int _findKth(vector<int>& nums1, vector<int>& nums2, int l1, int l2, int n1, int n2, int k)
{
if(l1 == n1)
return nums2[l2 + k - 1];
if(l2 == n2)
return nums1[l1 + k - 1];
if(k == 1)
return min(nums1[l1], nums2[l2]);

int kk = k / 2;
int i = min(l1 + kk - 1, n1 - 1);
int j = min(l2 + kk - 1, n2 - 1);
if(nums1[i] <= nums2[j])
return _findKth(nums1, nums2, i + 1, l2, n1, n2, k - (i - l1 + 1));
else
return _findKth(nums1, nums2, l1, j + 1, n1, n2, k - (j - l2 + 1));
}
};

算法3: 二分

二分在一个序列上所在的位置 mid1, 另一个序列上的位置 mid2 对应可求,每次可以排除数组长度 $N + M$ 的一半。我们假设第一个数组的长度不大于第二个数组的长度,即 $N \leq M$。

其实质是在第一个数组 nums1 上做区间二分,第二个数组 nums2 是用来判定 mid1 是取大了还是取小了的。

取到二分中点 mid1, mid2 之后,nums1[0..mid1-1]nums1[0..mid2-1] 是左半部分,共 k 个数。

首先是一些边界条件:

  • mid1 取到 0,说明前 k 大元素全都在 nums2,较小的那个中位数为 nums2[mid2 - 1]

  • mid2 取到 0,说明前 k 大元素全都在 nums1,较小的那个中位数为 nums1[mid1 - 1]。当两个序列等长的时候可能出现这种情况。

  • mid1 取到 n,说明 nums1 中的元素均属于在前 k 大元素,较小的中位数为 nums2[mid2-1]

0 < mid1 < n1时,判断 mid1 取大了、取消了、还是取到了答案:

  • 如果 nums1[mid1 - 1] > nums2[mid2],即第一个数组左半边的最大值比第二个数组右半边的最小值要大,则 mid1 取大了,置 r1 = mid1 - 1

  • 如果 nums1[mid2] < nums2[mid2 - 1],即第一个数组右半边的最小值比第二个数组左半边的最大值要大,则 mid1 取小了,置 l1 = mid1 + 1

  • 不是上面两种情况,说明已经找到目标,max(nums1[mid1-1], nums2[mid2-1]) 为较小的中位数。

若 $N + M$ 为偶数,还需要求一下较大的中位数,为 min(nums1[mid1], nums2[mid2])

时间复杂度为 $O(\log \min(N, M))$

代码 (C++)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// 二分在一个序列上所在的位置, 另一个序列上的位置对应可求,每次排除一半: 单序列上做二分
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n1 = nums1.size();
int n2 = nums2.size();
if(n1 > n2)
return findMedianSortedArrays(nums2, nums1);
// n1 <= n2
int n = n1 + n2;
int k = (n + 1) / 2; // 找有序数组的第 k 大, 对应索引 k - 1;
// n1 > 0
int l1 = 0, r1 = n1; // 注意:n1 也是一组划分, 代表 nums1 的 n1 个都选了
while(l1 <= r1) // 循环内一定会找到,这个很关键,一些边界条件不用写了
{
// n1 = 0 时,l1 = 0, r1 = 0, 也必进循环,循环内一定找到答案
int mid1 = l1 + (r1 - l1) / 2; // nums1 二分到索引 mid1, 有 mid1 个数, nums1[mid1] 本身属于右子序列
int mid2 = k - mid1; // nums2 中要取 k - mid1 个数,nums2[mid2] 本身属于右子序列
// 边界判断
// 要么 mid1 取到 0, n 两个边界值之前已经找到答案
// 要么 mid1 取到边界值 0, n 时是答案
if(mid1 != 0 && nums1[mid1 - 1] > nums2[mid2])
r1 = mid1 - 1;
else if(mid1 != n1 && nums1[mid1] < nums2[mid2 - 1])
l1 = mid1 + 1;
else
{
// 找到答案,若取到边界值,则一定是答案
double small;
if(mid1 == 0)
small = nums2[mid2 - 1];
else if(mid2 == 0)
small = nums1[mid1 - 1]; // 两个序列等长的情况
else
small = max(nums1[mid1 - 1], nums2[mid2 - 1]);
if(n % 2 == 1)
return small;
double large;
if(mid1 == n1)
large = nums2[mid2];
else if(mid2 == n2)
large = nums1[mid1];
else
large = min(nums1[mid1], nums2[mid2]);
return (small + large) / 2.0;
}
}
return 0.0;
}
};

Share