【二分难题】力扣4-寻找两个正序数组的中位数

  |  

$1 题目

题目链接

4. 寻找两个正序数组的中位数

题目描述

给定两个大小为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。

请你找出这两个正序数组的中位数,并且要求算法的时间复杂度为 $O(\log (m + n))$。

你可以假设 nums1 和 nums2 不会同时为空。

样例

示例 1:
nums1 = [1, 3]
nums2 = [2]
则中位数是 2.0

示例 2:
nums1 = [1, 2]
nums2 = [3, 4]
则中位数是 (2 + 3)/2 = 2.5

$2 题解

对于单串的中位数,如果想避免奇偶性的讨论,可以返回:

1
(a[n/2] + a[(n - 1)/2]) / 2;

算法1: 值域二分

两个数组的长度之和为 len = m + n

将两个数组视为一个长度为 len 的数组,通过值域二分的方式找出 a[len / 2]a[(len - 1) / 2],对于 len 为奇数,只需要求出 a[len / 2] 即可,对于 len 为偶数,相当于做两次值域二分。

当前正在求 a[len / 2] 猜的答案为 mid,在 nums1 中大于等于 mid 的最小下标为 i,在 nums2 中大于等于 mid 的最小下标为 j,比 mid 小的元素个数为 i + j 个。

如果 a[len / 2] = mid ,则比 mid 小的元素个数最多 len / 2 个,因此判定后删掉一半值域的逻辑如下:

1
2
3
4
if(i + j > len / 2)
right = mid;
else
left = mid;

时间复杂度 $O(2(\log INT)(\log N+\log M))$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
const double EPS = 1e-9;
int n = nums1.size(), m = nums2.size();
int len = n + m;
double left = INT_MIN, right = INT_MAX;
while(left + EPS < right)
{
double mid = (left + right) / 2;
int i = upper_bound(nums1.begin(), nums1.end(), mid) - nums1.begin();
int j = upper_bound(nums2.begin(), nums2.end(), mid) - nums2.begin();
if(i + j > len / 2)
right = mid;
else
left = mid;
}
// a[len / 2] = mid
if(len & 1)
return left;
double tmp = left;
left = INT_MIN, right = INT_MAX;
while(left + EPS < right)
{
double mid = (left + right) / 2;
int i = upper_bound(nums1.begin(), nums1.end(), mid) - nums1.begin();
int j = upper_bound(nums2.begin(), nums2.end(), mid) - nums2.begin();
if(i + j > (len - 1) / 2)
right = mid;
else
left = mid;
}
return (left + tmp) / 2;
}
};

算法2:二分 k

转换为二分地找数组中的第 k 大, 其中 k = (N + 1) / 2, N = n + m

有两种二分方法,一种是对 K 二分,每次排除掉 K / 2,另一种是对整个索引范围二分(对切分位置二分),每次排除掉 $n/2$。算法2关注二分 K 的方法,算法3关注二分切分位置的方法。

二分 K 的接口:

1
2
3
int solve(nums1, i1, nums2, i2, k)

求 nums1[i1..n1-1] 和 nums2[i2..n2-1] 的第 k 大

两个数组有序且要找第 k 大,如果找到了两个数组的第 K / 2 大,即 nums1[K/2], nums2[K/2], 不妨设 nums1[K/2] < nums2[K/2],则 nums1 中前 K/2 个数一定是排在合并后的数组中的第 K 位之前,而 nums2 就不一定有这个性质了,此时 nums1 的前 K / 2 个数一定不可能是答案,此时问题变成了在剩下的范围里找第 K-K/2 大的数

1
solve(nums1, 0, nums2, 0, k)  ->  solve(nums1, k/2, nums2, 0, k - k/2)

例子

1
2
3
4
5
nums1 = [1,3,4,9]
nums2 = [1,2,3,4,5,6,7,8,9,10]

n = 4, m = 10, N = n + m = 14
k = (N + 1) / 2 = 7;

首先找到两个数组中的第 K/2 = 3 大的数,其中 nums2 的较小,因此 nums2 的三个数排除

此时变成在剩下的范围里面找第 k-k/2 = 7 - 3 = 4 大的数,在两个数组中找 4/2 = 2 大的数,其中 nums1 的较小,因此 nums1 的 2 个数排除

此时变成在剩下的范围里面找第 4-4/2 = 2 大的数,在两个数组中找 2/2 = 1 大的数,nums1 的和 nums2 的相等,排除哪一半都可以

  • 在减治的过程中,可能出现数组长 < K/2的情况,这时就取数组末尾比较即可。
  • 可以先求第 (N+1)/2 大,再求 (N+2)/2 大,再取平均,避免奇偶性判断
  • 因为是尾递归,额外空间 $O(1)$,时间复杂度 $O(\log (n+m))$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n1 = nums1.size();
int n2 = nums2.size();
if(n1 > n2) return findMedianSortedArrays(nums2, nums1);
int n = n1 + n2;
int k = (n + 1) / 2; // 找有序数组的第 k 大, 对应索引 k - 1;
if(n1 == 0)
{
if(n % 2 == 1)
return (double)nums2[k - 1];
else
return ((double)nums2[k - 1] + nums2[k]) / 2;
}
int l1 = 0, l2 = 0;
return _findKth(nums1, nums2, l1, l2, n1, n2, k);
}

private:
double _findKth(vector<int>& nums1, vector<int>& nums2, int l1, int l2, int n1, int n2, int k)
{
if(l1 == n1)
{
double small = (double)nums2[l2 + k - 1];
if((n1 + n2) % 2 == 1)
return small;
else
{
double large = nums2[l2 + k];
return (small + large) / 2.0;
}
}
if(l2 == n2) // 保证了 nums1 长度 <= nums2 长度之后,这组判断条件可以去掉
{
double small = (double)nums1[l1 + k - 1];
if((n1 + n2) % 2 == 1)
return small;
else
{
double large = nums1[l1 + k];
return (small + large) / 2.0;
}
}
if(k == 1)
{
if((n1 + n2) % 2 == 1)
return (double)min(nums1[l1], nums2[l2]);
else
{
double small, large;
if(nums1[l1] <= nums2[l2])
{
small = nums1[l1];
if(l1 + 1 < n1) large = min(nums1[l1 + 1], nums2[l2]);
else large = nums2[l2];
}
else
{
small = nums2[l2];
if(l2 + 1 < n2) large = min(nums2[l2 + 1], nums1[l1]);
else large = nums1[l1];
}
return (small + large) / 2;
}
}
int kk = k / 2;
int i = l1 + kk - 1;
if(i >= n1) i = n1 - 1;
int j = l2 + kk - 1;
if(j >= n2) i = n2 - 1;
if(nums1[i] <= nums2[j])
return _findKth(nums1, nums2, i + 1, l2, n1, n2, k - (i - l1 + 1));
else
return _findKth(nums1, nums2, l1, j + 1, n1, n2, k - (j - l2 + 1));
}
};

优化写法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n1 = nums1.size();
int n2 = nums2.size();
if(n1 > n2) return findMedianSortedArrays(nums2, nums1);
int n = n1 + n2;
int k1 = (n + 1) / 2;
int k2 = (n + 2) / 2; // 直接找 2 次 K, 如果 n 为奇数,返回两次相同的值,可避免奇偶判断
int l1 = 0, l2 = 0;
return (_findKth(nums1, nums2, l1, l2, n1, n2, k1)
+ _findKth(nums1, nums2, l1, l2, n1, n2, k2)) / 2.0;
}
private:
int _findKth(vector<int>& nums1, vector<int>& nums2, int l1, int l2, int n1, int n2, int k)
{
if(l1 == n1) return nums2[l2 + k - 1];
if(l2 == n2) return nums1[l1 + k - 1];
if(k == 1) return min(nums1[l1], nums2[l2]);

int kk = k / 2;
int i = min(l1 + kk - 1, n1 - 1);
int j = min(l2 + kk - 1, n2 - 1);
if(nums1[i] <= nums2[j])
return _findKth(nums1, nums2, i + 1, l2, n1, n2, k - (i - l1 + 1));
else
return _findKth(nums1, nums2, l1, j + 1, n1, n2, k - (j - l2 + 1));
}
};

算法3: 二分整个索引范围

二分在一个序列上所在的位置, 另一个序列上的位置对应可求,每次排除一半: 单序列上做二分

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
// 二分在一个序列上所在的位置, 另一个序列上的位置对应可求,每次排除一半: 单序列上做二分
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n1 = nums1.size();
int n2 = nums2.size();
if(n1 > n2) return findMedianSortedArrays(nums2, nums1);
// n1 <= n2
int n = n1 + n2;
int k = (n + 1) / 2; // 找有序数组的第 k 大, 对应索引 k - 1;
// n1 > 0
int l1 = 0, r1 = n1; // 注意:n1 也是一组划分, 代表 nums1 的 n1 个都选了
while(l1 <= r1) // 循环内一定会找到,这个很关键,一些边界条件不用写了
{
// n1 = 0 时,l1 = 0, r1 = 0, 也必进循环,循环内一定找到答案
int mid1 = l1 + (r1 - l1) / 2; // nums1 二分到索引 mid1, 有 mid1 个数, nums1[mid1] 本身属于右子序列
int mid2 = k - mid1; // nums2 中要取 k - mid1 个数,nums2[mid2] 本身属于右子序列
// 边界判断
// 要么 mid1 取到 0, n 两个边界值之前已经找到答案
// 要么 mid1 取到边界值 0, n 时是答案
if(mid1 != 0 && nums1[mid1 - 1] > nums2[mid2])
r1 = mid1 - 1;
else if(mid1 != n1 && nums1[mid1] < nums2[mid2 - 1])
l1 = mid1 + 1;
else
{
// 找到答案,若取到边界值,则一定是答案
double small;
if(mid1 == 0) small = nums2[mid2 - 1];
else if(mid2 == 0) small = nums1[mid1 - 1]; // 两个序列等长的情况
else small = max(nums1[mid1 - 1], nums2[mid2 - 1]);
if(n % 2 == 1) return small;
double large;
if(mid1 == n1) large = nums2[mid2];
else if(mid2 == n2) large = nums1[mid1];
else large = min(nums1[mid1], nums2[mid2]);
return (small + large) / 2.0;
}
}
return 0.0;
}
};

Share