心路历程:
这道题暴力解很简单,一看到要求O(log(m+n))的复杂度就只能是双指针,但是实测发现这道题用归并排序更快。这可能就是平均复杂度和实际复杂度的Gap吧。
二分法的思路:
要找到第 k (k>1) 小的元素,那么就取 pivot1 = nums1[k/2-1] 和 pivot2 = nums2[k/2-1] 比较;如果 pivot = pivot1,那么 nums1[0 … k/2-1] 都不可能是第 k 小的元素。把这些元素全部 “删除”,剩下的作为新的 nums1 数组。同理pivot = pivot2时也一样。
如果是一般情况,nums1 中小于等于 pivot1 的元素有 nums1[0 … k/2-2] 共计 k/2-1 个,nums2 中小于等于 pivot2 的元素有 nums2[0 … k/2-2] 共计 k/2-1 个。取 pivot = min(pivot1, pivot2),两个数组中小于等于 pivot 的元素共计不会超过 (k/2-1) + (k/2-1) <= k-2 个。由于我们 “删除” 了一些元素(这些元素都比第 k 小的元素要小),因此需要修改 k 的值,减去删除的数的个数。
解法一:归并排序法(有序数组的条件不能浪费)
class Solution:def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:# 两个数组按照顺序进行归并,但是其实复杂度最差为O(m+n)的程度nums3 = []i1, i2 = 0, 0m, n = len(nums1), len(nums2)while True:if i1 < m and i2 < n:if nums1[i1] < nums2[i2]: nums3.append(nums1[i1])i1 += 1else:nums3.append(nums2[i2])i2 += 1elif i1 == m and i2 < n:nums3 += nums2[i2:]breakelif i1 < m and i2 == n:nums3 += nums1[i1:]breakelse:breakhalf = (m+n) // 2if (m+n) % 2 == 0:return (nums3[half-1] + nums3[half]) / 2else:return nums3[half]
解法二:二分法(不断通过移动索引删除元素)
class Solution: def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float: def findkthnum(k): i1, i2 = 0, 0 while True: if i1 == m: return nums2[i2 + k - 1] if i2 == n: return nums1[i1 + k - 1] if k == 1: return min(nums1[i1], nums2[i2]) newi1 = min(i1 + k // 2 - 1, m - 1) # 防止超过数组本身长度newi2 = min(i2 + k // 2 - 1, n - 1) if nums1[newi1] <= nums2[newi2]: # 每次只删除小的那一半,因为必然不在这里面k -= newi1 + 1 - i1 # 这个就是删除ki1 = newi1 + 1 # 这个就是‘删除数组’else: k -= newi2 + 1 - i2i2 = newi2 + 1 m, n = len(nums1), len(nums2) l = m + nif l % 2 == 0: return (findkthnum(l // 2) + findkthnum(l // 2 + 1)) / 2else: return findkthnum(((l + 1) // 2))