Kth Smallest Element in a Sorted Matrix

Sanjeev Sharma


Problem

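Given an n × n matrix in which every row and every column is sorted in ascending order, find the kth smallest element (k is 1-based). "kth smallest" means the kth element in sorted order with duplicates counted separately, not the kth distinct value.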
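To make the definition concrete, here is a brute-force baseline (flatten, sort, index) in Python. The sample matrix and the name `kth_smallest_bruteforce` are illustrative and do not come from the original post. The baseline runs in O(n² log n) time and is mainly useful as a reference when testing faster versions.

```python
def kth_smallest_bruteforce(matrix, k):
    """Reference answer: flatten every row, sort, and index (k is 1-based)."""
    flat = sorted(x for row in matrix for x in row)
    return flat[k - 1]

matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print(kth_smallest_bruteforce(matrix, 8))  # 13
```

Duplicates count separately. In sorted order this matrix is 1, 5, 9, 10, 11, 12, 13, 13, 15, so the two 13s are the 7th and 8th smallest.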

Two approaches:

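  1. Min-heap: push the first element of every row (the first column), then pop k times, pushing each popped element's right neighbor; the kth value popped is the answer. O(n + k log n)
  2. Binary search on value: for a candidate mid, count the elements ≤ mid and narrow the range [matrix[0][0], matrix[n-1][n-1]] until a single value remains. O(n log(max - min))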
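The solutions in this post implement only the binary-search approach. Below is a minimal sketch of the min-heap approach using Python's standard `heapq` module. The name `kth_smallest_heap` is illustrative. It is a k-way merge of the sorted rows, so its correctness relies only on each row being sorted.

```python
import heapq

def kth_smallest_heap(matrix, k):
    """Min-heap approach: O(n + k log n) time, O(n) extra space."""
    n = len(matrix)
    # Seed the heap with the first column. Entries are (value, row, col),
    # so equal values are ordered by position.
    heap = [(matrix[i][0], i, 0) for i in range(n)]
    heapq.heapify(heap)  # O(n)
    for _ in range(k):
        # Pop the smallest remaining value, then push its right neighbour
        # in the same row (if any) so that row stays represented.
        value, i, j = heapq.heappop(heap)
        if j + 1 < n:
            heapq.heappush(heap, (matrix[i][j + 1], i, j + 1))
    return value  # the kth value popped

matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_heap(matrix, 8))  # 13
```

A common refinement, which uses the column ordering, is to seed only the first min(n, k) rows. Every first-column entry of row k or later already has at least k elements ≤ it above it in the same column.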

Solutions

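// C++: binary search on value, O(n log(max - min))
#include <vector>
using std::vector;

int kthSmallest(vector<vector<int>>& matrix, int k) {
    int n = matrix.size(), lo = matrix[0][0], hi = matrix[n-1][n-1];
    // Invariant: the answer always lies in [lo, hi].
    while (lo < hi) {
        // Midpoint computed in 64-bit so hi - lo cannot overflow; rounds toward lo.
        int mid = lo + static_cast<int>((static_cast<long long>(hi) - lo) / 2);
        // Staircase count of elements <= mid, starting at the top-right corner.
        int count = 0, j = n - 1;
        for (int i = 0; i < n; i++) {
            while (j >= 0 && matrix[i][j] > mid) j--;  // skip elements > mid in row i
            count += j + 1;                            // matrix[i][0..j] are <= mid
        }
        // Fewer than k elements <= mid means the answer is larger than mid.
        if (count < k) lo = mid + 1; else hi = mid;
    }
    return lo;  // smallest value with at least k elements <= it
}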
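
// Java
class Solution {
    public int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length, lo = matrix[0][0], hi = matrix[n-1][n-1];
        while (lo < hi) {
            // Midpoint in long arithmetic so hi - lo cannot overflow; rounds toward lo.
            int mid = lo + (int) (((long) hi - lo) / 2);
            // Staircase count of elements <= mid, starting at the top-right corner.
            int count = 0, j = n - 1;
            for (int i = 0; i < n; i++) {
                while (j >= 0 && matrix[i][j] > mid) j--;
                count += j + 1;
            }
            if (count < k) lo = mid + 1; else hi = mid;
        }
        return lo;
    }
}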
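
// JavaScript
function kthSmallest(matrix, k) {
    const n = matrix.length;
    let lo = matrix[0][0], hi = matrix[n-1][n-1];
    while (lo < hi) {
        // Math.floor rather than `(lo + hi) >> 1`: bit shifts convert to 32-bit
        // integers, which breaks once lo + hi leaves that range.
        const mid = lo + Math.floor((hi - lo) / 2);
        // Staircase count of elements <= mid, starting at the top-right corner.
        let count = 0, j = n - 1;
        for (let i = 0; i < n; i++) {
            while (j >= 0 && matrix[i][j] > mid) j--;
            count += j + 1;
        }
        if (count < k) lo = mid + 1; else hi = mid;
    }
    return lo;
}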
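
# Python: binary search on value
def kthSmallest(matrix, k):
    n = len(matrix)
    lo, hi = matrix[0][0], matrix[n-1][n-1]
    while lo < hi:
        mid = (lo + hi) // 2  # floor division also rounds correctly for negatives
        # Staircase count of elements <= mid, starting at the top-right corner
        count = 0
        j = n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1  # matrix[i][0..j] are all <= mid
        if count < k:
            lo = mid + 1  # too few elements <= mid: the answer is larger
        else:
            hi = mid      # mid is still a valid upper bound
    return lo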
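One way to gain confidence in the value-range search, especially with duplicates and negative numbers, is to compare it against a flatten-and-sort answer on many small random matrices. The harness below is a hypothetical sketch. `random_sorted_matrix` and `check` are illustrative names, and the generator is just one simple way, assumed here, to build matrices whose rows and columns are sorted.

```python
import random

def kth_smallest(matrix, k):
    # Same value-range binary search as the Python solution in this post.
    n = len(matrix)
    lo, hi = matrix[0][0], matrix[n - 1][n - 1]
    while lo < hi:
        mid = (lo + hi) // 2
        count, j = 0, n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1
        if count < k:
            lo = mid + 1
        else:
            hi = mid
    return lo

def random_sorted_matrix(n, rng):
    # Hypothetical helper: each cell is its larger upper/left neighbour plus a
    # small random step (0..3), so rows and columns are non-decreasing. Steps of
    # 0 create duplicates, and the starting value may be negative.
    m = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            neighbours = []
            if i > 0:
                neighbours.append(m[i - 1][j])
            if j > 0:
                neighbours.append(m[i][j - 1])
            base = max(neighbours) if neighbours else rng.randint(-20, 0)
            m[i][j] = base + rng.randint(0, 3)
    return m

def check(trials=300, seed=0):
    # For every k, compare with the brute-force answer (flatten and sort).
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 6)
        m = random_sorted_matrix(n, rng)
        flat = sorted(x for row in m for x in row)
        for k in range(1, n * n + 1):
            assert kth_smallest(m, k) == flat[k - 1], (m, k)
    return True

print(check())
```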

Complexity

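  • Binary search: O(n log(max - min)) time, where max - min = matrix[n-1][n-1] - matrix[0][0]; O(1) extra space
  • Heap: O(n + k log n) time (build the heap from the first column, then k pops and pushes); O(n) extra space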
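To make the two bounds concrete, the sketch below plugs numbers into them. The inputs (n = 300, values spanning [-10^9, 10^9]) are assumptions chosen for illustration, the function names are hypothetical, and the results are operation-count estimates, not measured running times.

```python
import math

def binary_search_cost(n, lo, hi):
    """Upper bound on comparisons for the value-range binary search.

    Each iteration shrinks [lo, hi] to at most half its size (rounded up), so
    the loop runs at most ceil(log2(hi - lo + 1)) times. Each staircase count
    makes at most 2n comparisons: n left moves in total plus one stop per row.
    """
    iterations = math.ceil(math.log2(hi - lo + 1))
    return iterations, iterations * 2 * n

def heap_cost(n, k):
    """Rough step count for the heap approach: O(n) heapify plus k pops and
    pushes on a heap of at most n entries, about log2(n) steps each."""
    return n + k * math.log2(n)

n = 300
print(binary_search_cost(n, -10**9, 10**9))  # (31, 18600)
print(round(heap_cost(n, 10)))               # small k
print(round(heap_cost(n, n * n)))            # k = n * n
```

The binary-search cost does not depend on k, while the heap cost grows linearly with k. The heap is therefore the better fit when k is small relative to n, and the value-range search is better when k is close to n².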

Key Insight

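Counting the elements ≤ mid takes only O(n) with a staircase walk. Start at the top-right corner and, row by row, move left while the current element is greater than mid. Because columns are sorted, if matrix[i][j] > mid then matrix[i+1][j] > mid as well. The column pointer therefore only ever moves left, and the whole walk makes at most 2n comparisons. The binary search then runs over the range of values [matrix[0][0], matrix[n-1][n-1]] rather than over positions. For integer entries, the smallest value v with count(≤ v) ≥ k must appear in the matrix: count(≤ v - 1) < k, so some element equals v. That value is the kth smallest element.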
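A standalone staircase count makes the O(n) claim easy to check. This sketch assumes integer values and returns both the count and the number of comparisons made. `count_le` and the 3 × 3 matrix are illustrative, not from the original post.

```python
def count_le(matrix, x):
    """Count elements <= x in an n x n matrix whose rows and columns are sorted
    ascending. Returns (count, comparisons) so the O(n) bound is visible."""
    n = len(matrix)
    count, comparisons = 0, 0
    j = n - 1  # start in the top-right corner
    for i in range(n):
        # Step left past every element of row i that exceeds x. Columns are
        # sorted, so matrix[i + 1][j] >= matrix[i][j] > x as well: j never
        # has to move back to the right in later rows.
        while j >= 0:
            comparisons += 1
            if matrix[i][j] <= x:
                break
            j -= 1
        count += j + 1  # matrix[i][0..j] are all <= x
    return count, comparisons

matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print(count_le(matrix, 12))  # (6, 5)
print(count_le(matrix, 13))  # (8, 4)
```

With k = 8, count(≤ 12) = 6 < 8 and count(≤ 13) = 8 ≥ 8. So 13 is the smallest value with at least 8 elements ≤ it, and that is the answer the binary search converges to. Both walks stay within 2n = 6 comparisons.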



Written by

Sanjeev Sharma

Full Stack Engineer · E-mopro