4. Median of Two Sorted Arrays

Tutorial

My first idea was to binary search on the median value t, checking at each step whether the two arrays together contain the right number of elements less than t. However, each check needs two more binary searches to locate the position of t in both arrays, nested inside the binary search on t. That raises the complexity to roughly O(log V · (log n + log m)), where V is the size of the value range, which is worse than O(log(m+n)).
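To make the value-search idea concrete, here is a minimal sketch, assuming integer inputs and at least one non-empty array. The function names are hypothetical, and it counts elements less than or equal to t (via the standard library's bisect_right) rather than strictly less than t, which makes the k-th smallest value easy to pin down.

```python
from bisect import bisect_right


def kth_smallest_by_value(a, b, k):
    """Return the k-th smallest (1-indexed) value in the union of sorted lists a and b.

    Assumes integer elements and 1 <= k <= len(a) + len(b).
    """
    # The answer lies between the smallest and largest element overall.
    lo = min(x[0] for x in (a, b) if x)
    hi = max(x[-1] for x in (a, b) if x)
    while lo < hi:
        t = (lo + hi) // 2
        # Two inner binary searches: how many elements are <= t in each array?
        if bisect_right(a, t) + bisect_right(b, t) >= k:
            hi = t          # t is large enough; try smaller values
        else:
            lo = t + 1      # too few elements <= t; the answer is larger
    return lo


def median_by_value_search(a, b):
    """Median of two sorted integer lists via binary search on the value."""
    n = len(a) + len(b)
    if n % 2 == 1:
        return float(kth_smallest_by_value(a, b, n // 2 + 1))
    lower = kth_smallest_by_value(a, b, n // 2)
    upper = kth_smallest_by_value(a, b, n // 2 + 1)
    return (lower + upper) / 2.0
```

Each iteration of the outer loop runs two bisect calls, which is where the extra logarithmic factor comes from.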

A better way is to binary search for a cut position in the longer array. Imagine cutting each array into a left part and a right part, four parts in total, such that every element in the two left parts is less than or equal to every element in the two right parts; a cut like this locates the median. For each candidate cut position in the longer array, we can directly calculate where to cut the second array, since the left parts together must hold half of all the elements (rounded down when the total is odd). Then we check whether the cut is valid by comparing the rightmost elements of the left parts with the leftmost elements of the right parts. The complexity is O(log(m+n)).
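The validity check can be illustrated on its own. The sketch below uses hypothetical helper names (boundary, is_valid_split) and assumes the left parts hold (len(a) + len(b)) // 2 elements in total; it tries every cut position by brute force, only to show which one passes the check.

```python
NEG_INF, POS_INF = float('-inf'), float('inf')


def boundary(nums, k):
    """Values just left and right of a cut that puts k elements of nums on the left."""
    left = nums[k - 1] if k > 0 else NEG_INF
    right = nums[k] if k < len(nums) else POS_INF
    return left, right


def is_valid_split(a, b, i):
    """Check the cut that takes i elements from a; the cut in b is then forced."""
    j = (len(a) + len(b)) // 2 - i
    if j < 0 or j > len(b):
        return False
    a_left, a_right = boundary(a, i)
    b_left, b_right = boundary(b, j)
    # Every left-part element must be <= every right-part element.
    return a_left <= b_right and b_left <= a_right


a = [1, 3, 8, 9, 15]   # longer array
b = [7, 11, 18]
valid = [i for i in range(len(a) + 1) if is_valid_split(a, b, i)]
print(valid)  # [3]: left parts are [1, 3, 8] and [7], right parts are [9, 15] and [11, 18]
```

With the cut at 3, the largest left element is 8 and the smallest right element is 9, so the median of the eight elements is (8 + 9) / 2 = 8.5. The binary search replaces the brute-force loop by moving the cut left or right depending on which comparison fails.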

Solution

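class Solution(object):
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        # Make nums1 the longer array; the binary search runs over it.
        if len(nums1) < len(nums2):
            nums1, nums2 = nums2, nums1

        self.n = len(nums1) + len(nums2)

        # Only one non-empty array: read the median directly.
        # `//` keeps the indices integral on both Python 2 and Python 3.
        if not nums2:
            return (nums1[(self.n + 1) // 2 - 1] + nums1[self.n // 2]) / 2.0

        # Find the smallest cut position l in nums1 for which ok() holds.
        l = 0
        r = len(nums1)
        while l < r:
            mid = (l + r) // 2
            if self.ok(mid, nums1, nums2):
                r = mid
            else:
                l = mid + 1
        l1, r1 = self.get_lr(nums1, l)
        l2, r2 = self.get_lr(nums2, self.n // 2 - l)
        # The left parts hold n // 2 elements, so for odd n the median
        # is the smallest element on the right.
        if self.n % 2 == 1:
            return float(min(r1, r2))
        return (max(l1, l2) + min(r1, r2)) / 2.0

    def ok(self, a, nums1, nums2):
        # b is the matching cut position in nums2.
        b = self.n // 2 - a
        if b > len(nums2):
            return False  # cut in nums1 is too far left
        if b < 0:
            return True   # cut in nums1 is too far right
        l1, r1 = self.get_lr(nums1, a)
        l2, r2 = self.get_lr(nums2, b)
        # True once the left part of nums2 no longer exceeds the right part of nums1.
        return l2 <= r1

    def get_lr(self, nums, n):
        # Elements just left and right of cut position n;
        # infinities act as sentinels at the ends.
        l = nums[n - 1] if n != 0 else -float('inf')
        r = nums[n] if n != len(nums) else float('inf')
        return l, r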
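
A simple way to sanity-check a solution like this is to compare it with a slow but obviously correct reference. The sketch below is such a reference, not part of the solution itself; the name median_by_merge is hypothetical, and it runs in O((m+n) log(m+n)) because it sorts the concatenation.

```python
def median_by_merge(a, b):
    """Reference median of two lists: sort everything and pick the middle."""
    merged = sorted(a + b)
    n = len(merged)
    if n == 0:
        raise ValueError("at least one input must be non-empty")
    if n % 2 == 1:
        return float(merged[n // 2])
    return (merged[n // 2 - 1] + merged[n // 2]) / 2.0


print(median_by_merge([1, 3], [2]))     # 2.0
print(median_by_merge([1, 2], [3, 4]))  # 2.5
```

Running both this function and Solution().findMedianSortedArrays on many small random sorted inputs and asserting equal results is a quick way to catch off-by-one errors in the cut positions.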