This problem asks for the median of a matrix in which each row is sorted in non-decreasing order. A naive approach that flattens and sorts the entire matrix takes O(mn log(mn)) time, which exceeds the required time complexity (the problem asks for better than O(mn)). The efficient solution instead uses binary search over the range of values.
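For reference, the naive baseline can be sketched as below. It is too slow for the stated bound but is useful as a correctness oracle on small inputs. The function name `median_by_sorting` is hypothetical, and the sketch assumes, as the problem does, that the matrix holds an odd number of elements.

```python
from typing import List


def median_by_sorting(grid: List[List[int]]) -> int:
    """Naive O(mn log(mn)) baseline: flatten, sort, take the middle element."""
    flat = sorted(value for row in grid for value in row)
    # With an odd element count, the median sits at 0-based index len(flat) // 2.
    return flat[len(flat) // 2]


print(median_by_sorting([[1, 1, 2], [2, 3, 3], [1, 3, 4]]))  # 2
```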
Approach:
The core idea is to binary search over the range of possible values in the matrix (1 to 10^6) rather than over positions. For each candidate median x, we count the number of elements in the matrix that are less than or equal to x. If this count is greater than or equal to the target position (the median's 1-based rank), the median must be less than or equal to x. Otherwise, the median is greater than x. This works because the count never decreases as x grows, so the median is the smallest x whose count reaches the target.
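To make the monotonicity argument concrete, the sketch below evaluates the predicate `count_less_equal(x) >= target` for each candidate value on a small example grid: it is false below the median and true from the median onward. The helper names and the grid are illustrative assumptions, and the linear scan over values is for demonstration only; the actual algorithm replaces it with a binary search.

```python
import bisect
from typing import List


def count_less_equal(grid: List[List[int]], x: int) -> int:
    """Number of matrix elements <= x, using one binary search per sorted row."""
    return sum(bisect.bisect_right(row, x) for row in grid)


def smallest_value_reaching_target(grid: List[List[int]]) -> int:
    """Linear scan (demonstration only): first x whose count reaches the target rank."""
    m, n = len(grid), len(grid[0])
    target = (m * n + 1) // 2
    low = min(row[0] for row in grid)    # smallest element overall
    high = max(row[-1] for row in grid)  # largest element overall
    for x in range(low, high + 1):
        if count_less_equal(grid, x) >= target:
            return x
    raise ValueError("unreachable for a non-empty grid")


grid = [[1, 1, 2], [2, 3, 3], [1, 3, 4]]
for x in range(1, 5):
    c = count_less_equal(grid, x)
    print(x, c, c >= 5)  # target is (9 + 1) // 2 = 5 for this 3 x 3 grid
print(smallest_value_reaching_target(grid))  # 2
```

The loop prints `1 3 False`, `2 5 True`, `3 8 True`, `4 9 True`: the predicate flips exactly once, at the median 2.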
Algorithm:
1. Find the target index: Compute target = (m*n + 1) // 2. For an odd number of elements, the median is the target-th smallest element (a 1-based position) in the sorted sequence of all matrix elements.
2. Binary search: Perform a binary search on the value range [1, 10^6]. In each iteration:
   - Compute mid = (left + right) // 2 as the candidate median.
   - Count the elements less than or equal to mid. This is done efficiently with a binary search on each row, since every row is sorted.
   - If the count is greater than or equal to target, the median is less than or equal to mid, so update the right boundary of the search space to mid.
   - Otherwise, the median is greater than mid, so update the left boundary to mid + 1.
3. Return result: When the binary search converges (left == right), left holds the median value.
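The steps above can be sketched without `bisect`, writing the per-row search out by hand so the "binary search on each row" is visible. For a 3 x 3 grid, step 1 gives target = (9 + 1) // 2 = 5, so the answer is the 5th smallest element. `upper_bound` is a hypothetical helper named after C++'s `std::upper_bound`, whose behavior it mirrors: it returns the number of elements in a sorted row that are <= x. The `low`/`high` parameters are also illustrative additions; their defaults are the [1, 10^6] range from the text.

```python
from typing import List


def upper_bound(row: List[int], x: int) -> int:
    """Index of the first element > x in a sorted row, i.e. the count of elements <= x."""
    lo, hi = 0, len(row)
    while lo < hi:
        mid = (lo + hi) // 2
        if row[mid] <= x:
            lo = mid + 1
        else:
            hi = mid
    return lo


def matrix_median(grid: List[List[int]], low: int = 1, high: int = 10**6) -> int:
    m, n = len(grid), len(grid[0])
    target = (m * n + 1) // 2          # step 1: 1-based rank of the median
    left, right = low, high            # step 2: binary search over values
    while left < right:
        mid = (left + right) // 2      # candidate median
        count = sum(upper_bound(row, mid) for row in grid)
        if count >= target:
            right = mid                # median <= mid
        else:
            left = mid + 1             # median > mid
    return left                        # step 3: left == right == median


print(matrix_median([[1, 1, 2], [2, 3, 3], [1, 3, 4]]))  # 2
```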
Time Complexity Analysis:
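The binary search over values runs O(log M) iterations, where M = 10^6 is the size of the value range (about 20 iterations). Each iteration counts the elements <= mid with one O(log n) binary search per row, costing O(m log n). Therefore, the overall time complexity is O(m log n log M), which is significantly better than the naive O(mn log(mn)) approach.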
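The log M factor can be checked empirically. The sketch below is an instrumented copy of the search with hypothetical counters that record how many value-range iterations and per-row binary searches one call performs. Because each iteration at least halves the candidate interval, the loop over [1, 10^6] runs at most ceil(log2(10^6)) = 20 times, and each iteration performs exactly m row searches.

```python
import bisect
from typing import List, Tuple


def median_with_stats(grid: List[List[int]]) -> Tuple[int, int, int]:
    """Same search as matrixMedian; returns (median, iterations, row_searches)."""
    m, n = len(grid), len(grid[0])
    target = (m * n + 1) // 2
    left, right = 1, 10**6
    iterations = row_searches = 0
    while left < right:
        iterations += 1
        mid = (left + right) // 2
        count = 0
        for row in grid:
            count += bisect.bisect_right(row, mid)  # O(log n) per row
            row_searches += 1
        if count >= target:
            right = mid
        else:
            left = mid + 1
    return left, iterations, row_searches


median, iters, searches = median_with_stats([[1, 1, 2], [2, 3, 3], [1, 3, 4]])
print(median, iters, searches)
```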
Space Complexity Analysis:
The algorithm uses a constant amount of extra space for variables, so the space complexity is O(1).
Code Examples:
The code examples provided in Python, Java, C++, and Go all implement the described algorithm. They differ slightly in syntax and in the library functions used for the per-row binary search (e.g., bisect_right in Python, upper_bound in C++), but the underlying logic is the same.
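The choice of bisect_right (upper_bound in C++) matters when a row contains duplicates: it returns the number of elements <= x, which is what the counting step needs, whereas bisect_left (lower_bound in C++) counts only elements < x. The row below is an arbitrary example.

```python
import bisect

row = [1, 3, 3, 3, 5]  # a sorted row with duplicates of 3

count_le = bisect.bisect_right(row, 3)  # elements <= 3
count_lt = bisect.bisect_left(row, 3)   # elements < 3
print(count_le, count_lt)  # 4 1
```

Using bisect_left here would undercount whenever the candidate value itself appears in the matrix, which can push the search past the true median.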
Example in Python (using bisect):
import bisect
from typing import List


class Solution:
    def matrixMedian(self, grid: List[List[int]]) -> int:
        m, n = len(grid), len(grid[0])
        target = (m * n + 1) // 2  # 1-based rank of the median (m * n is odd)

        def count_less_equal(x: int) -> int:
            # Number of elements <= x: one binary search per sorted row.
            count = 0
            for row in grid:
                count += bisect.bisect_right(row, x)
            return count

        left, right = 1, 10**6  # Range of possible values
        while left < right:
            mid = (left + right) // 2
            if count_less_equal(mid) >= target:
                right = mid      # the median is <= mid
            else:
                left = mid + 1   # the median is > mid
        return left
The other languages follow a similar structure, adapting the binary search and counting mechanisms to their respective standard libraries.
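One way to gain confidence in the Python version is to compare it against the sorting baseline on random inputs. The sketch below uses a condensed copy of the Solution class (with `List` imported so it runs standalone) and checks it on random row-sorted matrices with an odd number of elements. The generator `random_sorted_grid`, the checker `agrees_with_sorting`, the trial count, and the seed are hypothetical choices for illustration.

```python
import bisect
import random
from typing import List


class Solution:
    """Condensed copy of the Python solution, so this check runs standalone."""

    def matrixMedian(self, grid: List[List[int]]) -> int:
        m, n = len(grid), len(grid[0])
        target = (m * n + 1) // 2

        def count_less_equal(x: int) -> int:
            return sum(bisect.bisect_right(row, x) for row in grid)

        left, right = 1, 10**6
        while left < right:
            mid = (left + right) // 2
            if count_less_equal(mid) >= target:
                right = mid
            else:
                left = mid + 1
        return left


def random_sorted_grid(rng: random.Random) -> List[List[int]]:
    """Random grid with sorted rows, values in [1, 10**6], and an odd m * n."""
    m = rng.choice([1, 3, 5])          # odd * odd keeps m * n odd
    n = rng.choice([1, 3, 5, 7])
    high = rng.choice([5, 10**6])      # a small range forces many duplicates
    return [sorted(rng.randint(1, high) for _ in range(n)) for _ in range(m)]


def agrees_with_sorting(trials: int = 200, seed: int = 0) -> bool:
    rng = random.Random(seed)
    solver = Solution()
    for _ in range(trials):
        grid = random_sorted_grid(rng)
        flat = sorted(v for row in grid for v in row)  # sorting baseline
        if solver.matrixMedian(grid) != flat[len(flat) // 2]:
            return False
    return True


print(agrees_with_sorting())  # True
```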