解法一:模拟法
分析:
模拟遍历两个数组,找第 k 小的元素,作为中位数,或进行下一步处理返回中位数。k就是下文中的 div
- 设置两个指针
i
、j
分别遍历两个数组- 设置一个计数器
count
表示当前遍历过的元素总数- 设置
div = (m + n + 1)/2
,表示下面两者其一:
- 中位数的位置:如【1,3】【2】,div就指向2的位置;
- 两数组合并后,中间两个元素的第一个元素位置:如【1,2】【3,4】,div就指向两个中间元素2,3中的第一个元素2的位置
- 设置
remain = (m + n +1)%2
,表示序列的类型:
- remain == 0:中位数就是div指向的元素,不需要考虑下一个元素
- 如【1,3】【2】,remain==0,中位数就是div指向的2。
- remain == 1:div指向的元素,是两个中间元素的第一个元素,还需要找到下一个元素的值,取平均作为中位数,
- 如【1,2】【3,4】,remain==1,div指向2,但还要找到2的下一个元素是哪个,最后找到3,取平均:(2+3)/2=2.5为中位数
剩下需要做的就是模拟法需要处理的一些细节问题。
代码:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size();
int n = nums2.size();
int i=0,j=0;
int count=0; //计数器
int div = (m + n + 1)/2; // div表示两个序列合并后的中间元素或中间两个元素的第一个
int remain = (m + n +1)%2;
//remain==0表示不需要考虑下一个元素 ,div就是中位数,如【1,3】 【2】 ,中位数为2
//remain==1表示需要考虑下一个元素,两者的平均作为中位数,如【1,2】【3,4】 ,中位数为(2+3)/2=2.5
int temp;
while(i<m && j<n) {
count++;
if(nums1[i]<=nums2[j]) {
if(count == div) { //访问到了中间元素,或中间两个元素的第一个,下一步根据remain判断是否要考虑下一个元素
if(remain==0) { //remain==0,直接返回nums1[i]
return (double)nums1[i];
} else { //remain==1, 找到合并后序列的下一个元素,取平均数
if(i+1 == m) temp=nums2[j]; //如果nums1[i]已经是第一个序列的最后一个元素,直接取nums2[j]
else temp = nums1[i+1]<nums2[j]?nums1[i+1]:nums2[j]; //否则,比较nums1[i+1]和nums2[j],谁小谁是下一个
return (double)(temp+nums1[i])/2;
}
}
i++;
} else { //同理
if(count == div) {
if(remain==0) {
return (double)nums2[j];
} else {
if(j+1 == n) temp = nums1[i];
else temp = nums2[j+1]<nums1[i]?nums2[j+1]:nums1[i];
return (double)(temp+nums2[j])/2;
}
}
j++;
}
}
while(i<m) { //nums2 已经遍历到尾 ,在剩下的 nums1 中继续找
count++;
if(count == div) {
if(remain==0) {
return (double)nums1[i];
} else {
return (double)(nums1[i+1]+nums1[i])/2;
}
}
i++;
}
while(j<n) { //nums1 已经遍历到尾 ,在剩下的 nums2 中继续找
count++;
if(count == div) {
if(remain==0) {
return (double)nums2[j];
} else {
return (double)(nums2[j+1]+nums2[j])/2;
}
}
j++;
}
return 0.0;
}
解法二:二分查找
分析:
解法一时间复杂度为O(m+n),题目要求达到O(log(m+n)),很容易想到使用二分查找。
为了方便叙述,我们暂且认为中位数就是第k小的数。
在解法一种,也是求第k小的数,只不过每次循环的时候,从nums1或nums2中去掉一个不可能是中位数的元素,也就是一个一个的排除。
虽然两个数组之间没有什么关系,但两个数组内部都是有序的,我们能不能一部分一部分的排除?提到一部分一部分的排除,是不是就想到了二分查找,在二分查找中每次排除掉一半不可能包含target的子序列。
而在真正找中位数的时候:
- 我们需要找到第
k
小的数作为中位数;- 或者找到 (第
k
小的数 + 第k+1
小的数)/2。那么如何利用二分的思想在两个序列中找第
k
小的数那?具体的做法如下:假如我们想要找第
k
小的数,我们可以每次循环排除掉k/2
个数。举例:
当前nums1的长度为4,nums2的长度为10,我们需要找到 第7小 和第8小的数,然后取平均。所以我们需要找到第7小的数。
我们可以比较两个数组的第
k/2
个元素,即比较第3个元素,上面的数组中的4比下面的3大,这说明了下面的数组中的前k/2
个元素都不可能是 第k
小的数,可以直接排除掉 k/2 个元素。为什么可以这样排除?
假如nums1的前k/2个元素为【a b c】,nums2 的前k/2个元素为【x y z】,因为nums2[k/2] < nums1[k/2],也就是 z<c;
而数组内部是有序的,即a<b<c,x<y<z,
那么xyz在很小的情况下可能是:x<y<z<a<b<c或x<a<y<b<z<c等等。
即使xyz取到最大的情况也只是:a<b<x<y<z<c因此,我们无法确定c是第6小的元素(因为nums2中可能还存在比c小的数)
也无法确定a和b是第几小的元素,因为不知道xyz和ab的具体大小关系,ab可能比xyz大,甚至可能比nums2后面的元素大都有可能,所以无法确定。
但是我们可以确定的是,xyz一定是前5小的元素。其实用反证法更好理解:
如果z是第6小的元素,x,y<z是已知的,且在nums2中只有xy比z小,那么如果想让z成为第6小的元素,nums1中需要有3个比z小的元素,即前三个元素abc都小于z,这与z<c相悖,所以假设不成立。z更不可能是第7小,第8小....所以z最多是第5小的元素,即xyz一定是前5小的元素,一定不是中位数,可以排除。
小结:在上面的比较中,两个长度为 k/2 的子序列比较最后一位后,小的那个子序列中的所有元素都一定是前
(k/2)*2-1
个元素,一定不会是要找的第k
个元素,可以直接排除。将nums2的前3个元素排除后,对剩下的数组,利用上面的方法,继续进行比较,但需要注意的是,在排除3个元素后,我们要找的就是所有元素中的第
k-3
小的元素。即第 4 小的元素。
比较3和5,同理可以得到,1和3最多是前3小的数,直接排除。
比较4和4,相等,去掉哪个数组都可以,因为这两个数相等的时候,两个数组中第
k/2
个元素有可能是我们要找的第k小的元素,在k/2之前的都不可能,所以去掉哪个都一样,只要两个数组中的某一个k/2留下即可。在算法中,统一将下面的去掉,即当
nums1[k/2]>=nums2[k/2]
时,将下面的序列排除。
最后,当k=1时,只需要判断两个元素中谁小即可。
当某一个数组的长度不够时,即k/2比数组剩下的长度还大:
我们只需要将其指向这个数组的末尾即可,然后进行和上面一样的判断,谁小把谁排除。
代码:
double findMedianSortedArrays_find_k(vector<int>& nums1, vector<int>& nums2) {
int n = nums1.size();
int m = nums2.size();
int left = (n + m + 1) / 2;
int right = (n + m + 2) / 2;
int remain = (n + m + 1)%2;
if(remain == 0) {
return getKth(nums1, 0, n - 1, nums2, 0, m - 1, left);
} else {
return (getKth(nums1, 0, n - 1, nums2, 0, m - 1, left) + getKth(nums1, 0, n - 1, nums2, 0, m - 1, right)) * 0.5;
}
}
int getKth(vector<int> nums1, int start1, int end1, vector<int> nums2, int start2, int end2, int k) {
int len1 = end1 - start1 + 1;
int len2 = end2 - start2 + 1;
//让 len1 的长度小于 len2,这样就能保证如果有数组空了,一定是 len1
if (len1 > len2) return getKth(nums2, start2, end2, nums1, start1, end1, k);
if (len1 == 0) return nums2[start2 + k - 1];
if (k == 1) return min(nums1[start1], nums2[start2]);
//i和j是 进行比较的两个子序列 的末尾
int i = start1 + min(len1, k / 2) - 1; //如果数组长度比k/2小,
int j = start2 + min(len2, k / 2) - 1;
if (nums1[i] > nums2[j]) {
return getKth(nums1, start1, end1, nums2, j + 1, end2, k - (j - start2 + 1));
} else {
return getKth(nums1, i + 1, end1, nums2, start2, end2, k - (i - start1 + 1));
}
}
原文:https://www.cnblogs.com/kyrieliu/p/leetcode4.html