Median in a row-wise sorted Matrix
Description
You are given a 2D matrix of size n × m, where each row is sorted in non-decreasing order. The total number of elements in the matrix (n × m) is guaranteed to be odd.
Your task is to find the median of all the elements in the matrix.
The median is defined as the middle element when all n × m elements are arranged in sorted order. Since the total count is always odd, the median is unique and always exists.
For example, if the sorted list of all elements is [1, 2, 3, 3, 5, 6, 6, 9, 9], the median is the 5th element (index 4 in 0-based), which is 5.
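As a minimal sketch of this definition (the helper name `median_of_sorted` is illustrative and not part of the problem statement), the middle element of an odd-length sorted list sits at 0-based index `len // 2`:

```python
def median_of_sorted(values: list[int]) -> int:
    """Return the middle element of an odd-length, already sorted list."""
    if len(values) % 2 == 0:
        raise ValueError("expected an odd number of elements")
    return values[len(values) // 2]


print(median_of_sorted([1, 2, 3, 3, 5, 6, 6, 9, 9]))  # 5
```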
Examples
Example 1
Input: mat = [[1, 3, 5], [2, 6, 9], [3, 6, 9]]
Output: 5
Explanation: If we collect all elements and sort them, we get [1, 2, 3, 3, 5, 6, 6, 9, 9]. There are 9 elements total, so the median is the element at position ⌊9/2⌋ = 4 (0-indexed), which is 5.
Example 2
Input: mat = [[2, 4, 9], [3, 6, 7], [4, 7, 10]]
Output: 6
Explanation: All elements sorted: [2, 3, 4, 4, 6, 7, 7, 9, 10]. The middle element (index 4) is 6.
Example 3
Input: mat = [[3], [4], [8]]
Output: 4
Explanation: All elements sorted: [3, 4, 8]. The middle element (index 1) is 4.
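The three examples can also be checked mechanically. The sketch below is only a sanity check, not one of the editorial's solutions; `EXAMPLES` and `reference_median` are names introduced here for illustration. It relies on the standard library's `statistics.median`, which returns the middle value itself when the number of data points is odd.

```python
from itertools import chain
from statistics import median

# (matrix, expected median) pairs copied from Examples 1-3
EXAMPLES = [
    ([[1, 3, 5], [2, 6, 9], [3, 6, 9]], 5),
    ([[2, 4, 9], [3, 6, 7], [4, 7, 10]], 6),
    ([[3], [4], [8]], 4),
]


def reference_median(mat):
    # Flatten the rows lazily; statistics.median sorts the data internally.
    return median(chain.from_iterable(mat))


for mat, expected in EXAMPLES:
    assert reference_median(mat) == expected
```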
Constraints
- 1 ≤ n, m ≤ 400
- 1 ≤ mat[i][j] ≤ 2000
- n × m is always odd
- Each row of the matrix is sorted in non-decreasing order
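For local testing it can help to check an input against these constraints before running a solution. The following is a hedged sketch; the function name `satisfies_constraints` is an assumption made for this example and is not required by the problem.

```python
def satisfies_constraints(mat: list[list[int]]) -> bool:
    """Check the four constraints listed above for a candidate matrix."""
    n = len(mat)
    if not 1 <= n <= 400:
        return False
    m = len(mat[0])
    # n * m must be odd, which also means n and m are both odd
    if not 1 <= m <= 400 or (n * m) % 2 == 0:
        return False
    for row in mat:
        if len(row) != m:
            return False
        if any(not 1 <= v <= 2000 for v in row):
            return False
        # each row must be sorted in non-decreasing order
        if any(row[j] > row[j + 1] for j in range(m - 1)):
            return False
    return True
```

Because the product n × m is odd only when both factors are odd, the largest matrix that passes this check is 399 × 399.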
Editorial
Brute Force
Intuition
The most straightforward way to find the median is to treat the matrix as if it were a single list. We collect every element from every row into one big array, sort that array, and then simply pick the middle element.
Think of it like having multiple sorted piles of numbered cards spread across a table. The easiest way to find the middle card is to gather all piles together, sort the entire collection, and pick the card right in the center.
Step-by-Step Explanation
Let's trace through with mat = [[1, 3, 5], [2, 6, 9], [3, 6, 9]]:
Step 1: Start with an empty list to collect all elements.
- flat = []
Step 2: Traverse row 0: [1, 3, 5]. Append each element.
- flat = [1, 3, 5]
Step 3: Traverse row 1: [2, 6, 9]. Append each element.
- flat = [1, 3, 5, 2, 6, 9]
Step 4: Traverse row 2: [3, 6, 9]. Append each element.
- flat = [1, 3, 5, 2, 6, 9, 3, 6, 9]
Step 5: Sort the flattened array.
- flat = [1, 2, 3, 3, 5, 6, 6, 9, 9]
Step 6: Compute median index = total_elements / 2 = 9 / 2 = 4 (integer division).
Step 7: Return flat[4] = 5.
[Animation: Brute Force (Flatten, Sort, Pick Middle). All matrix elements are gathered into a single array, the array is sorted, and the median is read from the center position.]
Algorithm
- Create an empty array flat.
- For each row in the matrix, append all elements to flat.
- Sort flat in non-decreasing order.
- Return flat[total_elements / 2] as the median.
Code
C++

    #include <vector>
    #include <algorithm>
    using namespace std;

    class Solution {
    public:
        int median(vector<vector<int>>& mat) {
            vector<int> flat;
            int n = mat.size();
            int m = mat[0].size();
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    flat.push_back(mat[i][j]);
                }
            }
            sort(flat.begin(), flat.end());
            return flat[(n * m) / 2];
        }
    };

Python

    class Solution:
        def median(self, mat: list[list[int]]) -> int:
            flat = []
            for row in mat:
                flat.extend(row)
            flat.sort()
            return flat[len(flat) // 2]

Java

    import java.util.*;

    class Solution {
        public int median(int[][] mat) {
            int n = mat.length;
            int m = mat[0].length;
            int[] flat = new int[n * m];
            int idx = 0;
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    flat[idx++] = mat[i][j];
                }
            }
            Arrays.sort(flat);
            return flat[(n * m) / 2];
        }
    }

Complexity Analysis
Time Complexity: O(n × m × log(n × m))
We first traverse all n × m elements to flatten the matrix (O(n × m)), and then sort the resulting array which takes O(n × m × log(n × m)). The sorting step dominates.
Space Complexity: O(n × m)
We create an auxiliary array holding all n × m elements. For the largest allowed matrices (399 × 399 = 159,201 elements, since n × m must be odd and so both n and m are odd), this is manageable but far from optimal.
Why This Approach Is Not Efficient
The brute force approach completely ignores the fact that each row is already sorted. By flattening and re-sorting, we throw away this valuable structural information and do redundant work.
With n and m up to 400, we have up to about 160,000 elements. Sorting them costs roughly 160,000 × 17 ≈ 2.7 million comparisons, which is acceptable for this constraint range, but the O(n × m) extra space is wasteful.
More importantly, the approach would not scale well if the constraints were larger. The key insight is that we don't need the exact sorted order of all elements; we only need the one element that sits at the median position. Exploiting the sorted rows lets us avoid the full sort, first with a heap-based merge and ultimately with binary search on the answer.
Better Approach - Merge Using Min-Heap
Intuition
Since each row is sorted, we can think of the problem like merging multiple sorted lists — similar to the "merge k sorted lists" pattern. We use a min-heap (priority queue) to always extract the globally smallest element across all row fronts.
Imagine you have n conveyor belts, each carrying items in increasing order of size. You want to find the item that would be in the exact middle if you combined everything. Instead of dumping all items into one pile, you peek at the front item on each belt, pick the smallest one, advance that belt, and repeat. After pulling out exactly (n×m)/2 + 1 items (using integer division), the last one you pulled is the median.
This avoids sorting the entire collection: we stop after extracting just over half of the elements.
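In Python the same lazy k-way merge is available in the standard library as `heapq.merge`, so the idea can be sketched in a few lines. This is an illustrative shortcut rather than the explicit heap loop this section develops, and the function name `median_by_lazy_merge` is made up for the example.

```python
import heapq
from itertools import islice


def median_by_lazy_merge(mat: list[list[int]]) -> int:
    total = len(mat) * len(mat[0])
    # heapq.merge yields the values of all sorted rows in sorted order, lazily.
    merged = heapq.merge(*mat)
    # Skip total // 2 values; the next one sits at 0-based index total // 2.
    return next(islice(merged, total // 2, None))
```

Because the iterator is consumed only up to the median, the second half of the merged sequence is never produced.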
Step-by-Step Explanation
Let's trace through with mat = [[1, 3, 5], [2, 6, 9], [3, 6, 9]]:
We need the element at position (9+1)/2 = 5th smallest (1-indexed), i.e., we extract 5 elements.
Step 1: Initialize min-heap with first element of each row.
- Heap: [(1, row0, col0), (2, row1, col0), (3, row2, col0)]
- Count = 0
Step 2: Extract min = 1 (from row 0, col 0). Push next element from row 0: mat[0][1] = 3.
- Heap: [(2, row1, col0), (3, row0, col1), (3, row2, col0)]
- Count = 1
Step 3: Extract min = 2 (from row 1, col 0). Push next element from row 1: mat[1][1] = 6.
- Heap: [(3, row0, col1), (3, row2, col0), (6, row1, col1)]
- Count = 2
Step 4: Extract min = 3 (from row 0, col 1). Push next from row 0: mat[0][2] = 5.
- Heap: [(3, row2, col0), (5, row0, col2), (6, row1, col1)]
- Count = 3
Step 5: Extract min = 3 (from row 2, col 0). Push next from row 2: mat[2][1] = 6.
- Heap: [(5, row0, col2), (6, row1, col1), (6, row2, col1)]
- Count = 4
Step 6: Extract min = 5 (from row 0, col 2). Count = 5 = desired position.
- Median found: 5
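The trace above can be reproduced with a small generator that reports every extraction. This is a sketch for experimentation; `heap_extractions` is a hypothetical helper name, and ties between equal values are broken by row index because Python compares the tuples element by element.

```python
import heapq


def heap_extractions(mat):
    """Yield (count, value, row, col) for each extraction up to the median."""
    m = len(mat[0])
    desired = (len(mat) * m + 1) // 2
    # Seed the heap with the first element of every row.
    heap = [(row[0], i, 0) for i, row in enumerate(mat)]
    heapq.heapify(heap)
    for count in range(1, desired + 1):
        val, row, col = heapq.heappop(heap)
        if col + 1 < m:
            # Advance the row that just gave up its front element.
            heapq.heappush(heap, (mat[row][col + 1], row, col + 1))
        yield count, val, row, col


for step in heap_extractions([[1, 3, 5], [2, 6, 9], [3, 6, 9]]):
    print(step)
```

The value in the last tuple yielded is the median.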
[Animation: Min-Heap Merge (Extracting Elements Until Median). A min-heap merges the sorted rows by repeatedly extracting the smallest front element and advancing that row's pointer until the median position is reached.]
Algorithm
- Compute desired = (n * m + 1) / 2, the position of the median (1-indexed).
- Create a min-heap and push (mat[i][0], i, 0) for each row i.
- Initialize count = 0 and result = -1.
- While count < desired:
    - Extract the minimum element (val, row, col) from the heap.
    - Set result = val and increment count.
    - If col + 1 < m, push (mat[row][col+1], row, col+1) into the heap.
- Return result.
Code
C++

    #include <vector>
    #include <queue>
    using namespace std;

    class Solution {
    public:
        int median(vector<vector<int>>& mat) {
            int n = mat.size();
            int m = mat[0].size();
            int desired = (n * m + 1) / 2;
            // min-heap: {value, row, col}
            priority_queue<vector<int>, vector<vector<int>>, greater<vector<int>>> minHeap;
            // Push first element of each row
            for (int i = 0; i < n; i++) {
                minHeap.push({mat[i][0], i, 0});
            }
            int count = 0, result = -1;
            while (count < desired) {
                auto top = minHeap.top();
                minHeap.pop();
                int val = top[0], row = top[1], col = top[2];
                result = val;
                count++;
                if (col + 1 < m) {
                    minHeap.push({mat[row][col + 1], row, col + 1});
                }
            }
            return result;
        }
    };

Python

    import heapq

    class Solution:
        def median(self, mat: list[list[int]]) -> int:
            n = len(mat)
            m = len(mat[0])
            desired = (n * m + 1) // 2
            # Min-heap: (value, row, col)
            min_heap = []
            for i in range(n):
                heapq.heappush(min_heap, (mat[i][0], i, 0))
            count = 0
            result = -1
            while count < desired:
                val, row, col = heapq.heappop(min_heap)
                result = val
                count += 1
                if col + 1 < m:
                    heapq.heappush(min_heap, (mat[row][col + 1], row, col + 1))
            return result

Java

    import java.util.*;

    class Solution {
        public int median(int[][] mat) {
            int n = mat.length;
            int m = mat[0].length;
            int desired = (n * m + 1) / 2;
            // min-heap: {value, row, col}
            PriorityQueue<int[]> minHeap = new PriorityQueue<>(
                (a, b) -> a[0] - b[0]
            );
            for (int i = 0; i < n; i++) {
                minHeap.offer(new int[]{mat[i][0], i, 0});
            }
            int count = 0, result = -1;
            while (count < desired) {
                int[] top = minHeap.poll();
                int val = top[0], row = top[1], col = top[2];
                result = val;
                count++;
                if (col + 1 < m) {
                    minHeap.offer(new int[]{mat[row][col + 1], row, col + 1});
                }
            }
            return result;
        }
    }

Complexity Analysis
Time Complexity: O(n × m × log(n))
We extract from the heap exactly (n×m+1)/2 times, and each extraction is followed by at most one insertion. Each heap operation (insert or extract) takes O(log n) since the heap holds at most n elements (one per row). Therefore the total time is O(n × m × log(n)).
Space Complexity: O(n)
The min-heap stores at most one element per row at any time, so it uses O(n) space — a significant improvement over the O(n × m) of the brute force.
Why This Approach Is Not Efficient
While the heap approach improves space from O(n×m) to O(n), its time complexity O(n × m × log n) is not fundamentally better than sorting for this problem's constraints. We still process roughly half of all elements.
The crucial observation we haven't exploited yet: we don't need to identify the exact sorted order up to the median position. We only need to find which value sits at the median rank. Since each row is sorted, we can use binary search on the value domain — for any candidate value x, we can quickly count how many elements across all rows are ≤ x using binary search within each row. This approach doesn't touch individual elements at all; instead, it narrows down the value space logarithmically.
This leads to an O(n × log(m) × log(max−min)) solution with O(1) extra space — dramatically better for large matrices.
Optimal Approach - Binary Search on Value
Intuition
Instead of finding the actual sorted order of elements, we flip the question: for a given value x, how many elements in the matrix are less than or equal to x?
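If we can answer this question efficiently, we can binary search over all possible values to find the smallest x where the count of elements ≤ x is at least (n×m+1)/2. That x is the median. It is guaranteed to be a value that actually occurs in the matrix: the count only increases at values that are present, so the smallest x that reaches the threshold must be one of them.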
Here's the key insight that makes this fast: since each row is sorted, counting elements ≤ x within a single row is just a matter of finding where x would be inserted — which is a binary search taking O(log m). Doing this for all n rows costs O(n × log m).
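The counting step on its own can be sketched as follows, using `bisect_right` from Python's standard library as the insertion-point search. The helper name `count_le` is introduced here purely for illustration.

```python
from bisect import bisect_right


def count_le(mat: list[list[int]], x: int) -> int:
    """Count the elements <= x across all rows of a row-wise sorted matrix."""
    # bisect_right(row, x) is the index just past the last element <= x,
    # which equals the number of elements <= x in that sorted row.
    return sum(bisect_right(row, x) for row in mat)


mat = [[1, 3, 5], [2, 6, 9], [3, 6, 9]]
print(count_le(mat, 5))  # 5
```

The count never decreases as x grows, which is the property that makes a binary search over x valid.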
Think of it like a guessing game. Someone hides the median among 160,000 numbers. Instead of looking at each number, you pick a guess and ask "how many numbers are ≤ my guess?" If the count is less than half, your guess is too small — go higher. If the count is at least half, your guess might be the answer — try going lower to tighten the bound. Each guess costs only O(n × log m) to verify, and you make at most O(log(max−min)) guesses.
Step-by-Step Explanation
Let's trace with mat = [[1, 3, 5], [2, 6, 9], [3, 6, 9]]:
Setup: n=3, m=3, total=9, desired = (9+1)/2 = 5. We need the smallest value x such that at least 5 elements are ≤ x.
Value range: min=1 (smallest first-column element), max=9 (largest last-column element).
Step 1: lo=1, hi=9. Compute mid = (1+9)/2 = 5.
- Count elements ≤ 5 in each row:
- Row [1,3,5]: upper_bound(5) = 3 elements ≤ 5
- Row [2,6,9]: upper_bound(5) = 1 element ≤ 5
- Row [3,6,9]: upper_bound(5) = 1 element ≤ 5
- Total count = 3+1+1 = 5
- count (5) ≥ desired (5) → mid could be the answer. Set hi = 5.
Step 2: lo=1, hi=5. Compute mid = (1+5)/2 = 3.
- Count elements ≤ 3 in each row:
- Row [1,3,5]: upper_bound(3) = 2 elements ≤ 3
- Row [2,6,9]: upper_bound(3) = 1 element ≤ 3
- Row [3,6,9]: upper_bound(3) = 1 element ≤ 3
- Total count = 2+1+1 = 4
- count (4) < desired (5) → 3 is too small. Set lo = 4.
Step 3: lo=4, hi=5. Compute mid = (4+5)/2 = 4.
- Count elements ≤ 4 in each row:
- Row [1,3,5]: upper_bound(4) = 2
- Row [2,6,9]: upper_bound(4) = 1
- Row [3,6,9]: upper_bound(4) = 1
- Total count = 2+1+1 = 4
- count (4) < desired (5) → 4 is too small. Set lo = 5.
Step 4: lo=5, hi=5. Loop ends (lo == hi).
- Result: lo = 5. The median is 5.
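To replay this trace programmatically, the loop can record each iteration as it goes. The sketch below assumes the same search bounds as the walkthrough; `median_with_trace` is a hypothetical name, and each recorded tuple is (lo, hi, mid, count) as seen at the start of that iteration.

```python
from bisect import bisect_right


def median_with_trace(mat):
    desired = (len(mat) * len(mat[0]) + 1) // 2
    lo = min(row[0] for row in mat)   # smallest first-column element
    hi = max(row[-1] for row in mat)  # largest last-column element
    trace = []
    while lo < hi:
        mid = lo + (hi - lo) // 2
        count = sum(bisect_right(row, mid) for row in mat)
        trace.append((lo, hi, mid, count))
        if count < desired:
            lo = mid + 1  # too few elements <= mid: the median is larger
        else:
            hi = mid      # enough elements <= mid: mid may be the median
    return lo, trace


result, trace = median_with_trace([[1, 3, 5], [2, 6, 9], [3, 6, 9]])
for row in trace:
    print(row)
print(result)
```

For the example matrix this records three iterations, matching Steps 1 to 3, before the loop exits with lo equal to hi.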
[Animation: Binary Search on Value (Narrowing the Median Candidate). Binary search over the value range [min, max] narrows down to the exact median by counting elements ≤ mid in each sorted row.]
Algorithm
- Find minVal = minimum of all first-column elements, maxVal = maximum of all last-column elements.
- Set desired = (n * m + 1) / 2.
- Binary search: lo = minVal, hi = maxVal.
- While lo < hi:
    - Compute mid = lo + (hi - lo) / 2.
    - For each row, use upper_bound to count elements ≤ mid.
    - Sum counts across all rows to get totalCount.
    - If totalCount < desired: lo = mid + 1 (need a larger value).
    - Else: hi = mid (mid could be the answer, try smaller).
- Return lo (which equals hi).
Code
C++

    #include <vector>
    #include <algorithm>
    #include <climits>
    using namespace std;

    class Solution {
    public:
        int median(vector<vector<int>>& mat) {
            int n = mat.size();
            int m = mat[0].size();
            int minVal = INT_MAX, maxVal = INT_MIN;
            for (int i = 0; i < n; i++) {
                minVal = min(minVal, mat[i][0]);
                maxVal = max(maxVal, mat[i][m - 1]);
            }
            int desired = (n * m + 1) / 2;
            int lo = minVal, hi = maxVal;
            while (lo < hi) {
                int mid = lo + (hi - lo) / 2;
                int count = 0;
                for (int i = 0; i < n; i++) {
                    count += upper_bound(mat[i].begin(), mat[i].end(), mid) - mat[i].begin();
                }
                if (count < desired) {
                    lo = mid + 1;
                } else {
                    hi = mid;
                }
            }
            return lo;
        }
    };

Python

    from bisect import bisect_right

    class Solution:
        def median(self, mat: list[list[int]]) -> int:
            n = len(mat)
            m = len(mat[0])
            min_val = min(row[0] for row in mat)
            max_val = max(row[-1] for row in mat)
            desired = (n * m + 1) // 2
            lo, hi = min_val, max_val
            while lo < hi:
                mid = lo + (hi - lo) // 2
                count = 0
                for row in mat:
                    count += bisect_right(row, mid)
                if count < desired:
                    lo = mid + 1
                else:
                    hi = mid
            return lo

Java

    import java.util.*;

    class Solution {
        public int median(int[][] mat) {
            int n = mat.length;
            int m = mat[0].length;
            int minVal = Integer.MAX_VALUE;
            int maxVal = Integer.MIN_VALUE;
            for (int i = 0; i < n; i++) {
                minVal = Math.min(minVal, mat[i][0]);
                maxVal = Math.max(maxVal, mat[i][m - 1]);
            }
            int desired = (n * m + 1) / 2;
            int lo = minVal, hi = maxVal;
            while (lo < hi) {
                int mid = lo + (hi - lo) / 2;
                int count = 0;
                for (int i = 0; i < n; i++) {
                    count += upperBound(mat[i], mid);
                }
                if (count < desired) {
                    lo = mid + 1;
                } else {
                    hi = mid;
                }
            }
            return lo;
        }

        private int upperBound(int[] row, int target) {
            int lo = 0, hi = row.length;
            while (lo < hi) {
                int mid = lo + (hi - lo) / 2;
                if (row[mid] <= target) {
                    lo = mid + 1;
                } else {
                    hi = mid;
                }
            }
            return lo;
        }
    }

Complexity Analysis
Time Complexity: O(n × log(m) × log(maxVal − minVal))
The outer binary search runs O(log(maxVal − minVal)) iterations. For the given constraints (values between 1 and 2000), this is at most ⌈log₂(2000)⌉ = 11 iterations. Within each iteration, we perform a binary search (O(log m)) on each of the n rows, costing O(n × log m). Total: O(n × log m × log(maxVal − minVal)).
With n and m at their upper bound of 400 and a value range of about 2000, this comes to roughly 400 × 9 × 11 = 39,600 steps, which is extremely fast.
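The arithmetic behind this estimate can be reproduced directly. This is a back-of-the-envelope sketch that treats 400 as an upper bound on both dimensions and counts binary-search steps only; the variable names are chosen for this example.

```python
import math

n = m = 400               # upper bounds from the constraints
value_span = 2000 - 1     # largest possible maxVal - minVal

# Halving steps needed to shrink a range of value_span + 1 candidates to one
outer_iterations = math.ceil(math.log2(value_span + 1))
# Approximate halving steps for one upper-bound search in a row of length m
inner_steps = math.ceil(math.log2(m))

print(outer_iterations, inner_steps, n * inner_steps * outer_iterations)
```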
Space Complexity: O(1)
We only use a constant number of variables (lo, hi, mid, count). No auxiliary data structures are needed — a massive improvement over both previous approaches.
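As a final sanity check, the three approaches can be compared against each other on random inputs. The harness below is a sketch with compact re-implementations written for this purpose; the names `brute`, `by_heap`, `by_value_search` and `random_matrix` are not from the editorial's solution classes, and the matrix sizes are kept small and odd as an assumption so that n × m stays odd.

```python
import heapq
import random
from bisect import bisect_right


def brute(mat):
    flat = sorted(v for row in mat for v in row)
    return flat[len(flat) // 2]


def by_heap(mat):
    m = len(mat[0])
    heap = [(row[0], i, 0) for i, row in enumerate(mat)]
    heapq.heapify(heap)
    for _ in range((len(mat) * m + 1) // 2):
        val, r, c = heapq.heappop(heap)
        if c + 1 < m:
            heapq.heappush(heap, (mat[r][c + 1], r, c + 1))
    return val  # value of the last extraction


def by_value_search(mat):
    desired = (len(mat) * len(mat[0]) + 1) // 2
    lo, hi = min(r[0] for r in mat), max(r[-1] for r in mat)
    while lo < hi:
        mid = (lo + hi) // 2
        if sum(bisect_right(r, mid) for r in mat) < desired:
            lo = mid + 1
        else:
            hi = mid
    return lo


def random_matrix(rng):
    # Odd dimensions keep n * m odd, as the problem requires.
    n, m = rng.choice([1, 3, 5, 7]), rng.choice([1, 3, 5, 7])
    return [sorted(rng.randint(1, 2000) for _ in range(m)) for _ in range(n)]


rng = random.Random(0)  # fixed seed so the run is repeatable
for _ in range(200):
    mat = random_matrix(rng)
    assert brute(mat) == by_heap(mat) == by_value_search(mat)
```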