K Closest Points to Origin

Difficulty: Medium

Description

You are given a 2D array points where points[i] = [x_i, y_i] represents a point on the X-Y coordinate plane, and an integer k.

Return the k points that are closest to the origin (0, 0).

The distance between a point (x, y) and the origin is the Euclidean distance: √(x² + y²). Since we only need to compare distances (not compute exact values), comparing x² + y² directly is equivalent and avoids floating-point issues.

You may return the answer in any order. The answer is guaranteed to be unique (except for the order of the points).

Examples

Example 1

Input: points = [[1,3], [-2,2], [5,8], [0,1]], k = 2

Output: [[0,1], [-2,2]]

Explanation: Compute the squared distance from the origin for each point:

  • [1,3]: 1² + 3² = 10
  • [-2,2]: (-2)² + 2² = 8
  • [5,8]: 5² + 8² = 89
  • [0,1]: 0² + 1² = 1

Sorted by distance: [0,1] (1), [-2,2] (8), [1,3] (10), [5,8] (89). The 2 closest are [0,1] and [-2,2]. Output can be in any order.

Example 2

Input: points = [[3,3], [5,-1], [-2,4]], k = 2

Output: [[3,3], [-2,4]]

Explanation: Squared distances:

  • [3,3]: 9 + 9 = 18
  • [5,-1]: 25 + 1 = 26
  • [-2,4]: 4 + 16 = 20

The 2 closest points are [3,3] (distance² = 18) and [-2,4] (distance² = 20). [5,-1] is the farthest at distance² = 26.

Example 3

Input: points = [[0,2], [2,2]], k = 1

Output: [[0,2]]

Explanation: [0,2] has distance² = 4, [2,2] has distance² = 8. The single closest point is [0,2].

Constraints

  • 1 ≤ k ≤ points.length ≤ 10^4
  • -10^4 ≤ x_i, y_i ≤ 10^4
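These bounds also cap the largest squared distance. That matters because the C++ and Java solutions below compute x² + y² in 32-bit `int`. The arithmetic check below is written in Python, whose integers never overflow, and assumes a 32-bit signed int tops out at 2^31 - 1:

```python
MAX_COORD = 10**4          # |x_i|, |y_i| <= 10^4 from the constraints
INT32_MAX = 2**31 - 1      # largest 32-bit signed int

# The largest squared distance occurs at a corner such as (10^4, 10^4).
max_sq_dist = MAX_COORD**2 + MAX_COORD**2

# A comparator that subtracts two squared distances (both >= 0) produces a value
# in [-max_sq_dist, max_sq_dist], so this single check covers that case as well.
fits_in_int32 = max_sq_dist <= INT32_MAX

print(max_sq_dist)     # 200000000
print(fits_in_int32)   # True
```

With a coordinate bound of 10^5, the squared distance could reach 2 × 10^10. At that size, 64-bit arithmetic would be needed.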

Editorial

Brute Force - Sorting

Intuition

The most natural approach: if you need the k closest points, just sort ALL points by distance and pick the first k.

Imagine you have a stack of student exam papers and need to find the top 2 scores. The simplest method is to sort the entire stack by score, then take the top 2. It works, but you did more work than necessary — you sorted ALL papers when you only needed the top 2.

A key optimization: since √(a) < √(b) whenever a < b for non-negative values, we can compare squared distances (x² + y²) instead of actual Euclidean distances. This avoids expensive square root calculations and floating-point precision issues.
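Here is a small check of that claim on the Example 1 points. Ordering by the exact Euclidean distance (`math.hypot`) gives the same sequence as ordering by the integer key x² + y². The helper names are illustrative only.

```python
import math

points = [[1, 3], [-2, 2], [5, 8], [0, 1]]

def euclidean(p):
    # Exact distance: uses a square root and returns a float.
    return math.hypot(p[0], p[1])

def squared(p):
    # Squared distance: integer arithmetic only, no square root.
    return p[0] * p[0] + p[1] * p[1]

by_euclidean = sorted(points, key=euclidean)
by_squared = sorted(points, key=squared)

print(by_squared)                  # [[0, 1], [-2, 2], [1, 3], [5, 8]]
print(by_euclidean == by_squared)  # True
```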

Step-by-Step Explanation

Let's trace with points = [[1,3], [-2,2], [5,8], [0,1]], k = 2.

Step 1: Compute squared distance for each point:

  • [1,3]: 1 + 9 = 10
  • [-2,2]: 4 + 4 = 8
  • [5,8]: 25 + 64 = 89
  • [0,1]: 0 + 1 = 1

Step 2: Sort points by squared distance:

  • [0,1] (dist²=1), [-2,2] (dist²=8), [1,3] (dist²=10), [5,8] (dist²=89)

Step 3: Take the first k=2 points: [[0,1], [-2,2]].

Step 4: Return [[0,1], [-2,2]].
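You can reproduce the steps of this trace directly. Pair each point with its squared distance, sort the pairs, then slice. This decorate-sort-slice form is just one way to write it, and `trace_sort` is a hypothetical helper name:

```python
def trace_sort(points, k):
    # Step 1: pair every point with its squared distance.
    decorated = [(x * x + y * y, [x, y]) for x, y in points]
    # Step 2: sort by squared distance (ties fall back to comparing the points).
    decorated.sort()
    # Step 3: keep only the first k points.
    return decorated, [p for _, p in decorated[:k]]

steps, answer = trace_sort([[1, 3], [-2, 2], [5, 8], [0, 1]], 2)
print(steps)   # [(1, [0, 1]), (8, [-2, 2]), (10, [1, 3]), (89, [5, 8])]
print(answer)  # [[0, 1], [-2, 2]]
```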

Algorithm

  1. For each point [x, y], compute the squared distance x² + y² (no square root needed).
  2. Sort the points array using squared distance as the sorting key.
  3. Return the first k points from the sorted array.

Code

```cpp
#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
        // Order points by squared distance; no square root needed.
        sort(points.begin(), points.end(), [](const vector<int>& a, const vector<int>& b) {
            return (a[0] * a[0] + a[1] * a[1]) < (b[0] * b[0] + b[1] * b[1]);
        });
        return vector<vector<int>>(points.begin(), points.begin() + k);
    }
};
```

```python
class Solution:
    def kClosest(self, points: list[list[int]], k: int) -> list[list[int]]:
        # key= computes each squared distance once per point.
        points.sort(key=lambda p: p[0] * p[0] + p[1] * p[1])
        return points[:k]
```

```java
import java.util.Arrays;

class Solution {
    public int[][] kClosest(int[][] points, int k) {
        // Subtraction cannot overflow: squared distances lie in [0, 2 * 10^8].
        Arrays.sort(points, (a, b) ->
            (a[0] * a[0] + a[1] * a[1]) - (b[0] * b[0] + b[1] * b[1])
        );
        return Arrays.copyOfRange(points, 0, k);
    }
}
```

Complexity Analysis

Time Complexity: O(n log n)

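Sorting n points dominates the runtime. The sort performs O(n log n) comparisons in both the average and worst cases, and each squared-distance computation is O(1). The C++ and Java comparators recompute both distances on every comparison, so they evaluate x² + y² O(n log n) times. Python's key= evaluates it exactly once per point (n times). Either way, the total is O(n log n).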
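How often the distance is computed depends on how the sort is written. The sketch below counts calls in two cases. The first is Python's `key=`. The second is a two-argument comparator wrapped with `functools.cmp_to_key`, which runs once per comparison, like the C++ and Java lambdas. The random point data and the counters are made up for illustration.

```python
import functools
import random

random.seed(0)  # fixed data so the run is repeatable
points = [[random.randint(-100, 100), random.randint(-100, 100)] for _ in range(1000)]

key_calls = 0
cmp_calls = 0

def sq_key(p):
    # Called once per element by sorted(..., key=...).
    global key_calls
    key_calls += 1
    return p[0] * p[0] + p[1] * p[1]

def sq_cmp(a, b):
    # Called once per comparison; computes two squared distances each time.
    global cmp_calls
    cmp_calls += 1
    return (a[0] * a[0] + a[1] * a[1]) - (b[0] * b[0] + b[1] * b[1])

sorted(points, key=sq_key)
sorted(points, key=functools.cmp_to_key(sq_cmp))

print(key_calls)  # 1000: one key per point
print(cmp_calls)  # many more: one call per comparison
```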

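Space Complexity: O(log n) to O(n) auxiliary, depending on the sort implementation

C++'s std::sort (typically introsort) needs only O(log n) stack space. Python's list.sort and Java's Arrays.sort on object arrays are Timsort-based and may allocate a merge buffer of up to n/2 references, which is O(n). Python's key= also stores one key per point. Beyond that, the only allocation is the k-element result.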

Why This Approach Is Not Efficient

Sorting works, but it does more work than necessary. We sort all n points into perfect order, even though we only care about the top k. If n = 10,000 and k = 5, we're sorting 10,000 elements just to pick 5.

The O(n log n) time doesn't depend on k at all — whether k = 1 or k = n, we pay the same cost. Can we do better?

Key insight: We don't need the full sorted order. We only need to identify the k smallest distances. A max-heap of size k lets us maintain a running set of the k closest points seen so far, processing each point in O(log k) time instead of sorting everything.

Better Approach - Max-Heap

Intuition

Think of it this way: you're a talent scout watching auditions. You have k seats in the finals. As each performer auditions, you compare them against the weakest finalist currently seated. If the new performer is better, they replace that weakest finalist.

Translated to our problem: maintain a max-heap of size k. The heap's root is always the farthest point among our current k closest. For each new point:

  • If the heap has fewer than k elements, add the point.
  • If the new point is closer than the heap's root (the farthest of our k closest), remove the root and insert the new point.
  • Otherwise, skip — this point is farther than all k current candidates.

Why a max-heap and not min-heap? Because we need quick access to the farthest point in our k-set. The max-heap keeps the largest distance at the root, making it O(1) to check and O(log k) to replace.
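Python's standard library covers the same "k smallest by key" query with `heapq.nsmallest`. The heapq documentation notes that it performs best when k is small and that `sorted()` is more efficient for larger k. Below is a minimal sketch, assuming Python 3.9+ for the type hints. `k_closest_stdlib` is a hypothetical name, and it returns the points closest first.

```python
import heapq

def k_closest_stdlib(points: list[list[int]], k: int) -> list[list[int]]:
    # k points with the smallest squared distance, in ascending order of distance.
    return heapq.nsmallest(k, points, key=lambda p: p[0] * p[0] + p[1] * p[1])

print(k_closest_stdlib([[1, 3], [-2, 2], [5, 8], [0, 1]], 2))  # [[0, 1], [-2, 2]]
print(k_closest_stdlib([[3, 3], [5, -1], [-2, 4]], 2))        # [[3, 3], [-2, 4]]
```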

Step-by-Step Explanation

Let's trace with points = [[1,3], [-2,2], [5,8], [0,1]], k = 2.

Step 1: Initialize empty max-heap. Heap: []. Size: 0.

Step 2: Process [1,3] (dist²=10). Heap size (0) < k (2), so push. Heap: [(10, [1,3])]. Size: 1.

Step 3: Process [-2,2] (dist²=8). Heap size (1) < k (2), so push. Heap: [(10, [1,3]), (8, [-2,2])]. Root = 10. Size: 2.

Step 4: Process [5,8] (dist²=89). Heap is full. Compare 89 vs root 10. Since 89 > 10, this point is farther than our farthest finalist. Skip it.

Step 5: Process [0,1] (dist²=1). Heap is full. Compare 1 vs root 10. Since 1 < 10, this point is closer! Pop root [1,3] (dist²=10), push [0,1] (dist²=1). Heap: [(8, [-2,2]), (1, [0,1])]. Root = 8.

Step 6: All points processed. Return heap contents: [[-2,2], [0,1]].
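To replay this trace, use a small instrumented version of the heap loop that records the action each point triggers. Python only ships a min-heap, so the sketch stores negated distances in `heapq`. `trace_max_heap` and the action labels are illustrative names, not part of the original solutions.

```python
import heapq

def trace_max_heap(points, k):
    """Hypothetical instrumented copy of the max-heap loop."""
    heap = []      # (-dist², point) pairs; heap[0] is the farthest candidate
    actions = []
    for x, y in points:
        d = x * x + y * y
        if len(heap) < k:
            heapq.heappush(heap, (-d, [x, y]))
            actions.append(("push", d))
        elif d < -heap[0][0]:
            # Closer than the farthest candidate: evict it and insert this point.
            heapq.heapreplace(heap, (-d, [x, y]))
            actions.append(("replace", d))
        else:
            actions.append(("skip", d))
    root_dist = -heap[0][0]            # farthest of the k kept points
    kept = sorted(p for _, p in heap)  # sorted only to make the output stable
    return actions, root_dist, kept

actions, root_dist, kept = trace_max_heap([[1, 3], [-2, 2], [5, 8], [0, 1]], 2)
print(actions)    # [('push', 10), ('push', 8), ('skip', 89), ('replace', 1)]
print(root_dist)  # 8
print(kept)       # [[-2, 2], [0, 1]]
```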

Algorithm

  1. Create an empty max-heap.
  2. For each point [x, y]:
    a. Compute dist = x² + y².
    b. If heap size < k, push (dist, point) onto the heap.
    c. Else if dist < heap root, pop the root and push (dist, point).
  3. Return all points remaining in the heap.

Code

```cpp
#include <queue>
#include <utility>
#include <vector>
using namespace std;

class Solution {
public:
    vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
        // max-heap: (distance², point); top() is the farthest candidate
        priority_queue<pair<int, vector<int>>> maxHeap;

        for (auto& p : points) {
            int dist = p[0] * p[0] + p[1] * p[1];
            if ((int)maxHeap.size() < k) {
                maxHeap.push({dist, p});
            } else if (dist < maxHeap.top().first) {
                // closer than the farthest candidate: replace it
                maxHeap.pop();
                maxHeap.push({dist, p});
            }
        }

        vector<vector<int>> result;
        while (!maxHeap.empty()) {
            result.push_back(maxHeap.top().second);
            maxHeap.pop();
        }
        return result;
    }
};
```

```python
import heapq

class Solution:
    def kClosest(self, points: list[list[int]], k: int) -> list[list[int]]:
        # heapq is a min-heap, so store negated distances to get max-heap behavior
        max_heap = []

        for x, y in points:
            dist = -(x * x + y * y)  # negated squared distance
            if len(max_heap) < k:
                heapq.heappush(max_heap, (dist, [x, y]))
            elif dist > max_heap[0][0]:  # closer than the current farthest point
                heapq.heapreplace(max_heap, (dist, [x, y]))

        return [point for _, point in max_heap]
```

```java
import java.util.PriorityQueue;

class Solution {
    public int[][] kClosest(int[][] points, int k) {
        // max-heap by squared distance: the farthest candidate is at the head
        PriorityQueue<int[]> maxHeap = new PriorityQueue<>(
            (a, b) -> (b[0] * b[0] + b[1] * b[1]) - (a[0] * a[0] + a[1] * a[1])
        );

        for (int[] p : points) {
            // Push every point, then evict the farthest once the heap exceeds k.
            maxHeap.offer(p);
            if (maxHeap.size() > k) {
                maxHeap.poll(); // remove the farthest
            }
        }

        int[][] result = new int[k][];
        int i = 0;
        for (int[] p : maxHeap) { // iteration order is unspecified; any order is accepted
            result[i++] = p;
        }
        return result;
    }
}
```

Complexity Analysis

Time Complexity: O(n log k)

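We iterate through all n points. For each point we compute the distance in O(1), then perform at most one push and one pop. The C++ and Python versions push or replace. The Java version always pushes and then pops once the heap exceeds k. The heap never holds more than k + 1 elements, so each operation costs O(log k). Total: n × O(log k) = O(n log k).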

When k is much smaller than n, this is a significant improvement over O(n log n). For example, with n = 10,000 and k = 10, n·log₂ n ≈ 133,000 while n·log₂ k ≈ 33,000. That is roughly a 4× reduction before constant factors.
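These figures are n·log₂ n and n·log₂ k with constant factors ignored. They indicate scale, not exact comparison counts. The arithmetic:

```python
import math

n, k = 10_000, 10
sort_ops = n * math.log2(n)   # order of magnitude of work for a full sort
heap_ops = n * math.log2(k)   # order of magnitude of work for a size-k heap

print(round(sort_ops))             # 132877
print(round(heap_ops))             # 33219
print(round(sort_ops / heap_ops))  # 4
```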

Space Complexity: O(k)

The heap holds at most k elements, or k + 1 momentarily in the Java version between offer and poll. The output array has k elements. Total extra space: O(k).

Why This Approach Is Not Efficient

The max-heap approach is better than sorting — O(n log k) vs O(n log n) — but it still does O(log k) work per element. Can we do even better?

A heap does less work than a full sort, but it still maintains heap order among its k candidates and pays O(log k) per update to keep it. We don't need any order within the k closest points at all. We only need to partition the array: the k closest points on one side, the rest on the other.
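A binary heap only guarantees that the root is the extreme element. It does not keep the remaining entries sorted. Pushing three values into a `heapq` min-heap shows this, and a max-heap behaves symmetrically:

```python
import heapq

heap = []
for value in [1, 3, 2]:
    heapq.heappush(heap, value)

print(heap[0])               # 1: the root is always the minimum
print(heap)                  # [1, 3, 2]: a valid heap, yet not sorted
print(heap == sorted(heap))  # False

# The heap property only requires each parent <= its children.
ok = all(heap[i] <= heap[c] for i in range(len(heap))
         for c in (2 * i + 1, 2 * i + 2) if c < len(heap))
print(ok)                    # True
```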

This is exactly what the Quickselect algorithm does. Based on the partitioning step from Quicksort, it can find the k-th smallest element (and place all smaller elements before it) in O(n) average time — no heap, no sorting, just clever partitioning.

Optimal Approach - Quickselect

Intuition

Quickselect borrows the partition step from Quicksort. The idea:

  1. Pick a pivot point.
  2. Partition the array: every point at least as close as the pivot goes to the left, and every farther point goes to the right. The pivot lands at its final sorted position.
  3. Check where the pivot landed (0-based index):
    • If at index k-1, we're done: the first k elements are the k closest.
    • If at an index > k-1, the answer lies entirely in the left part, so recurse left.
    • If at an index < k-1, every point up to and including the pivot is in the answer, and we need more from the right, so recurse right.

The key insight: Quicksort recurses into both halves, but Quickselect recurses into only one. If each pivot splits the range roughly in half, the partition passes cost about n + n/2 + n/4 + … < 2n, which is O(n). Binary search also halves the range at each step, but each of its steps is O(1). Here each step is a linear scan of the remaining subarray, so the total is linear rather than logarithmic.

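The worst case is O(n²), which happens when every pivot is the closest or farthest point in its range. Choosing the pivot at random makes this extremely unlikely on any input. The C++ and Java versions below always use the last element as the pivot, so inputs already sorted by distance can hit the worst case (for example, ascending order with k = 1). Only the Python version randomizes.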

Step-by-Step Explanation

Let's trace with points = [[1,3], [-2,2], [5,8], [0,1]], k = 2.
Distances²: [10, 8, 89, 1]. We want the 2 points with smallest distances.

Step 1: Choose pivot = last element [0,1] (dist²=1). Partition range [0, 3].

Step 2: Partition: scan from left. Compare each point's dist² against pivot dist²=1.

  • i=0: [1,3] dist²=10. 10 > 1 → don't swap.
  • i=1: [-2,2] dist²=8. 8 > 1 → don't swap.
  • i=2: [5,8] dist²=89. 89 > 1 → don't swap.

Step 3: Place the pivot at the boundary. No point was closer than the pivot, so the boundary is still 0. Swap points[0] with the pivot at points[3]. Array becomes: [[0,1], [-2,2], [5,8], [1,3]]. Pivot index = 0.

Step 4: Pivot at index 0. We need k=2 points. Since 0 < k-1=1, we need more points from the right. All elements at indices 0..0 are confirmed in the answer. Recurse on [1, 3].

Step 5: Recurse: range [1, 3], need 1 more point. Pivot = [1,3] (dist²=10, last element in range).

  • i=1: [-2,2] dist²=8. 8 < 10 → swap with position 1 (already there). Advance boundary.
  • i=2: [5,8] dist²=89. 89 > 10 → don't swap.

Step 6: Place pivot at position 2. Swap [1,3] with points[2]. Array: [[0,1], [-2,2], [1,3], [5,8]]. Pivot index = 2.

Step 7: Pivot at index 2. Since 2 > k-1 = 1, the search continues on the left part [1, 1]. That one-element range already sits at index k-1, so the search ends. The first 2 elements [[0,1], [-2,2]] are our answer.

Result: Return [[0,1], [-2,2]].
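The trace can be replayed with a traced copy of the partition, using the last element as the pivot as the C++ and Java versions do (a Lomuto partition). For brevity, this sketch replaces the recursion with a loop that moves `left` or `right`. It makes the same partition and branch decisions. `trace_quickselect` is a hypothetical helper name.

```python
def dist(p):
    return p[0] * p[0] + p[1] * p[1]

def trace_quickselect(points, k):
    """Hypothetical traced copy of the last-element-pivot Quickselect."""
    points = [p[:] for p in points]      # leave the caller's list untouched
    snapshots = []                       # (pivot index, array after partition)
    left, right = 0, len(points) - 1
    while left < right:
        pivot_dist = dist(points[right])  # pivot = last element in range
        boundary = left
        for i in range(left, right):
            if dist(points[i]) <= pivot_dist:
                points[i], points[boundary] = points[boundary], points[i]
                boundary += 1
        points[boundary], points[right] = points[right], points[boundary]
        snapshots.append((boundary, [p[:] for p in points]))
        if boundary == k - 1:
            break
        if boundary < k - 1:
            left = boundary + 1
        else:
            right = boundary - 1
    return snapshots, points[:k]

snapshots, answer = trace_quickselect([[1, 3], [-2, 2], [5, 8], [0, 1]], 2)
for pivot_index, state in snapshots:
    print(pivot_index, state)
# 0 [[0, 1], [-2, 2], [5, 8], [1, 3]]
# 2 [[0, 1], [-2, 2], [1, 3], [5, 8]]
print(answer)  # [[0, 1], [-2, 2]]
```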

Algorithm

  1. Define a distance function: dist(p) = p[0]² + p[1]².
  2. Call quickselect on range [0, n-1] with target k:
    a. Choose a pivot: the last element (as in the C++ and Java code), or a random element (as in the Python code), which keeps the expected time O(n) on every input.
    b. Partition: move all points with distance ≤ pivot's distance to the left.
    c. If pivot lands at index k-1, return.
    d. If pivot lands at index > k-1, recurse on the left subarray.
    e. If pivot lands at index < k-1, recurse on the right subarray.
  3. Return points[0..k-1].

Code

```cpp
#include <utility>
#include <vector>
using namespace std;

class Solution {
public:
    vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
        quickSelect(points, 0, (int)points.size() - 1, k);
        return vector<vector<int>>(points.begin(), points.begin() + k);
    }

private:
    int dist(const vector<int>& p) {
        return p[0] * p[0] + p[1] * p[1];
    }

    void quickSelect(vector<vector<int>>& points, int left, int right, int k) {
        if (left >= right) return;

        // Lomuto partition with the last element as the pivot
        int pivotDist = dist(points[right]);
        int boundary = left;

        for (int i = left; i < right; i++) {
            if (dist(points[i]) <= pivotDist) {
                swap(points[i], points[boundary]);
                boundary++;
            }
        }
        swap(points[boundary], points[right]);  // pivot to its final index

        if (boundary == k - 1) return;
        else if (boundary < k - 1) quickSelect(points, boundary + 1, right, k);
        else quickSelect(points, left, boundary - 1, k);
    }
};
```

```python
import random

class Solution:
    def kClosest(self, points: list[list[int]], k: int) -> list[list[int]]:
        def dist(p):
            return p[0] * p[0] + p[1] * p[1]

        def quick_select(left, right):
            # Invariant: left <= k - 1 <= right, so the range is never empty.
            # Randomize pivot so the expected O(n) bound holds on every input
            pivot_idx = random.randint(left, right)
            points[pivot_idx], points[right] = points[right], points[pivot_idx]

            pivot_dist = dist(points[right])
            boundary = left

            for i in range(left, right):
                if dist(points[i]) <= pivot_dist:
                    points[i], points[boundary] = points[boundary], points[i]
                    boundary += 1

            points[boundary], points[right] = points[right], points[boundary]

            if boundary == k - 1:
                return
            elif boundary < k - 1:
                quick_select(boundary + 1, right)
            else:
                quick_select(left, boundary - 1)

        quick_select(0, len(points) - 1)
        return points[:k]
```

```java
import java.util.Arrays;

class Solution {
    public int[][] kClosest(int[][] points, int k) {
        quickSelect(points, 0, points.length - 1, k);
        return Arrays.copyOfRange(points, 0, k);
    }

    private int dist(int[] p) {
        return p[0] * p[0] + p[1] * p[1];
    }

    private void quickSelect(int[][] points, int left, int right, int k) {
        if (left >= right) return;

        // Lomuto partition with the last element as the pivot
        int pivotDist = dist(points[right]);
        int boundary = left;

        for (int i = left; i < right; i++) {
            if (dist(points[i]) <= pivotDist) {
                int[] temp = points[i];
                points[i] = points[boundary];
                points[boundary] = temp;
                boundary++;
            }
        }
        // move the pivot to its final index
        int[] temp = points[boundary];
        points[boundary] = points[right];
        points[right] = temp;

        if (boundary == k - 1) return;
        else if (boundary < k - 1) quickSelect(points, boundary + 1, right, k);
        else quickSelect(points, left, boundary - 1, k);
    }
}
```

Complexity Analysis

Time Complexity: O(n) average, O(n²) worst case

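On average, each partition roughly halves the range. The first pass scans n elements, the second about n/2, the third about n/4, and so on. Total: n + n/2 + n/4 + … < 2n = O(n). The worst case occurs when every pivot is the closest or farthest point in its range, which degenerates to O(n²). A random pivot, as in the Python version, gives expected O(n) time on every input. The last-element pivot in the C++ and Java versions is quadratic on inputs such as points already sorted by increasing distance with k = 1.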
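The sketch below counts distance comparisons for the deterministic last-element pivot used in the C++ and Java versions. Take an input already sorted by increasing distance, with k = 1. Every pivot is then the farthest remaining point, so each pass shrinks the range by only one element. That costs (n - 1) + (n - 2) + … + 1 = n(n - 1)/2 comparisons. `count_comparisons` is a hypothetical instrumented helper.

```python
def count_comparisons(points, k):
    """Quickselect with a last-element pivot, counting distance comparisons."""
    def dist(p):
        return p[0] * p[0] + p[1] * p[1]

    comparisons = 0
    left, right = 0, len(points) - 1
    while left < right:
        pivot_dist = dist(points[right])
        boundary = left
        for i in range(left, right):
            comparisons += 1
            if dist(points[i]) <= pivot_dist:
                points[i], points[boundary] = points[boundary], points[i]
                boundary += 1
        points[boundary], points[right] = points[right], points[boundary]
        if boundary == k - 1:
            break
        if boundary < k - 1:
            left = boundary + 1
        else:
            right = boundary - 1
    return comparisons

n = 100
# Already sorted by increasing distance: every pivot is the farthest point.
sorted_input = [[i, 0] for i in range(n)]
print(count_comparisons(sorted_input, 1))  # 4950, i.e. n * (n - 1) / 2
```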

Space Complexity: O(log n) average

The recursion depth is O(log n) on average, because each partition roughly halves the range. In the worst case it is O(n). Partitioning happens in place, so apart from the recursion stack the only allocation is the k-element result.
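Each call makes at most one recursive call, and makes it last. The recursion can therefore be replaced by a loop that just moves `left` or `right`. Below is a sketch of that variant, assuming a randomized pivot as in the Python solution and Python 3.9+ for the type hints. `k_closest_iterative` is a hypothetical name, and this loop form is an alternative formulation, not one of the original solutions. It uses O(1) extra space beyond the returned slice.

```python
import random

def k_closest_iterative(points: list[list[int]], k: int) -> list[list[int]]:
    """Hypothetical loop-based variant of the Quickselect solution (in place)."""
    def dist(p):
        return p[0] * p[0] + p[1] * p[1]

    target = k - 1            # index where the k-th closest point must end up
    left, right = 0, len(points) - 1
    while left < right:       # invariant: left <= target <= right
        pivot_idx = random.randint(left, right)
        points[pivot_idx], points[right] = points[right], points[pivot_idx]
        pivot_dist = dist(points[right])

        boundary = left
        for i in range(left, right):
            if dist(points[i]) <= pivot_dist:
                points[i], points[boundary] = points[boundary], points[i]
                boundary += 1
        points[boundary], points[right] = points[right], points[boundary]

        if boundary == target:
            break
        if boundary < target:
            left = boundary + 1   # need more points from the right part
        else:
            right = boundary - 1  # the answer lies within the left part
    return points[:k]
```

For instance, `k_closest_iterative([[3, 3], [5, -1], [-2, 4]], 2)` returns [3, 3] and [-2, 4] in an order that depends on the random pivots.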

Comparison:

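| Approach | Time | Space (auxiliary) | Notes |
|---|---|---|---|
| Sorting | O(n log n) | O(log n) to O(n) | Simple; cost is the same for every k |
| Max-Heap | O(n log k) | O(k) | Better when k ≪ n |
| Quickselect | O(n) average | O(log n) average | Best average time, but O(n²) worst case |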