973. K Closest Points to Origin - Explanation

Description

You are given a 2-D array points where points[i] = [xi, yi] represents the coordinates of a point on an X-Y plane. You are also given an integer k.

Return the k closest points to the origin (0, 0).

The distance between two points is defined as the Euclidean distance (sqrt((x1 - x2)^2 + (y1 - y2)^2)).

You may return the answer in any order.

Example 1:

Input: points = [[0,2],[2,2]], k = 1

Output: [[0,2]]

Explanation: The distance between (0, 2) and the origin (0, 0) is 2. The distance between (2, 2) and the origin is sqrt(2^2 + 2^2) ≈ 2.82842. So the closest point to the origin is (0, 2).

Example 2:

Input: points = [[0,2],[2,0],[2,2]], k = 2

Output: [[0,2],[2,0]]

Explanation: The output [2,0],[0,2] would also be accepted.

Constraints:

  • 1 <= k <= points.length <= 1000
  • -100 <= points[i][0], points[i][1] <= 100


Recommended Time & Space Complexity

You should aim for a solution as good as or better than O(n log k) time and O(k) space, where n is the size of the input array and k is the number of points to be returned.


Hint 1

A naive solution would be to sort the array in ascending order based on the distances of the points from the origin (0, 0) and return the first k points. This would take O(n log n) time. Can you think of a better way? Perhaps you could use a data structure that maintains only k points and allows efficient insertion and removal.


Hint 2

We can use a Max-Heap that keeps the maximum element at its top and allows retrieval in O(1) time. This data structure is ideal because we need to return the k closest points to the origin. By maintaining only k points in the heap, we can efficiently remove the farthest point when the size exceeds k. How would you implement this?


Hint 3

We initialize a Max-Heap that orders points based on their distances from the origin. Starting with an empty heap, we iterate through the array of points, inserting each point into the heap. If the size of the heap exceeds k, we remove the farthest point (the maximum element in the heap). After completing the iteration, the heap will contain the k closest points to the origin. Finally, we convert the heap into an array and return it.




Prerequisites

Before attempting this problem, you should be comfortable with:

  • Sorting with Custom Comparators - Sorting elements based on a computed value (distance) rather than the elements themselves
  • Euclidean Distance - Understanding that distance from origin is sqrt(x^2 + y^2), and that squared distance preserves ordering
  • Heap Data Structure - Using min-heaps and max-heaps to efficiently track smallest/largest elements
  • QuickSelect Algorithm - Partitioning arrays to find the k-th element in average O(n) time without full sorting

1. Sorting

Intuition

To find the k closest points to the origin (0, 0), we compare points by their distance from the origin.
Since the actual distance uses a square root, and square root preserves ordering, we can instead compare using squared distance:

\[
d^2 = x^2 + y^2
\]

This avoids unnecessary computation and is sufficient for sorting.
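
As a quick check of that claim, the sketch below sorts a small list of points once by true Euclidean distance and once by squared distance. The sample points and the helper names true_dist and squared_dist are illustrative assumptions, not part of the problem.

```python
import math

# Illustrative sample points (not from the problem statement).
points = [[3, 4], [0, 2], [-1, -1], [2, 2]]

def true_dist(p):
    # Euclidean distance from the origin, using a square root.
    return math.sqrt(p[0] ** 2 + p[1] ** 2)

def squared_dist(p):
    # Squared distance: integer arithmetic only, no square root.
    return p[0] ** 2 + p[1] ** 2

by_true = sorted(points, key=true_dist)
by_squared = sorted(points, key=squared_dist)

print(by_true == by_squared)  # both keys produce the same ordering here
```

Both keys order these points identically, because the square root is an increasing function on non-negative numbers.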

If we sort all points by this squared distance, then the first k points in sorted order must be the k closest ones.

Algorithm

  1. For every point (x, y), compute its squared distance: dist = x^2 + y^2
  2. Sort all points based on this dist value.
  3. Return the first k points from the sorted list.

from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        # Sort in place by squared distance from the origin.
        points.sort(key=lambda p: p[0]**2 + p[1]**2)
        return points[:k]

Time & Space Complexity

  • Time complexity: O(n log n)
  • Space complexity: O(1) or O(n) depending on the sorting algorithm.
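
Note that points.sort reorders the caller's list in place. If that side effect is unwanted, one possible variant uses sorted, which builds a new list at the cost of O(n) extra space. This is a sketch only; the function name k_closest_sorted is made up for illustration.

```python
from typing import List

def k_closest_sorted(points: List[List[int]], k: int) -> List[List[int]]:
    # sorted() returns a new list, so the order of the input is preserved.
    return sorted(points, key=lambda p: p[0] ** 2 + p[1] ** 2)[:k]

pts = [[0, 2], [2, 0], [2, 2]]
closest = k_closest_sorted(pts, 2)
print(closest)
print(pts)  # the input list is left in its original order
```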

2. Min-Heap

Intuition

A min-heap always gives you the smallest element first.
If we insert every point into a min-heap, using its squared distance from the origin as the priority, then:

  • The closest point will be at the top.
  • The next closest will be removed next, and so on.

So if we remove from the heap k times, we get exactly the k closest points.

This works because each removal takes the smallest remaining distance from the top of the heap.

Algorithm

  1. Create an empty min-heap.
  2. For each point (x, y):
    • Compute squared distance x^2 + y^2.
    • Push (distance, x, y) into the heap.
  3. Repeat k times:
    • Remove the smallest element from the heap.
    • Add its (x, y) coordinates to the result.
  4. Return the result list of k closest points.

import heapq
from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        # Build [squared distance, x, y] entries for every point.
        minHeap = []
        for x, y in points:
            dist = (x ** 2) + (y ** 2)
            minHeap.append([dist, x, y])

        # Heapify in O(n), then pop the k smallest distances.
        heapq.heapify(minHeap)
        res = []
        while k > 0:
            dist, x, y = heapq.heappop(minHeap)
            res.append([x, y])
            k -= 1

        return res

Time & Space Complexity

  • Time complexity: O(n + k log n)
  • Space complexity: O(n)

Where n is the length of the array points.
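
Python's standard library also offers heapq.nsmallest, which selects the k smallest items according to a key function. A minimal sketch, assuming a free function is acceptable in place of the Solution class (the name k_closest_nsmallest is hypothetical):

```python
import heapq
from typing import List

def k_closest_nsmallest(points: List[List[int]], k: int) -> List[List[int]]:
    # nsmallest returns the k items with the smallest key, in ascending key order.
    return heapq.nsmallest(k, points, key=lambda p: p[0] ** 2 + p[1] ** 2)

result = k_closest_nsmallest([[0, 2], [2, 2], [1, 0]], 2)
print(result)
```

The result comes back ordered from closest to farthest, which the problem permits but does not require.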


3. Max-Heap

Intuition

We want the k closest points, not all points sorted.

Use a max-heap of size k:

  • The heap always keeps the k closest points found so far.
  • The point with the largest distance among these k sits at the top.
  • When a new point is closer than the farthest in the heap, we remove the farthest and insert the new one.

This way, the heap never grows beyond size k, and it always contains the k best candidates.

Algorithm

  1. Create an empty max-heap.
  2. For each point:
    • Compute its squared distance from origin: d = x^2 + y^2.
    • Insert (d, point) into the heap.
    • If heap size exceeds k:
      • Remove the element with the maximum distance.
  3. After processing all points:
    • The heap contains exactly the k closest points.
  4. Return all points stored in the heap.

import heapq
from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        # heapq is a min-heap, so store negated distances to simulate a max-heap.
        maxHeap = []
        for x, y in points:
            dist = -(x ** 2 + y ** 2)
            heapq.heappush(maxHeap, [dist, x, y])
            # Evict the farthest point whenever the heap holds more than k.
            if len(maxHeap) > k:
                heapq.heappop(maxHeap)

        res = []
        while maxHeap:
            dist, x, y = heapq.heappop(maxHeap)
            res.append([x, y])
        return res

Time & Space Complexity

  • Time complexity: O(n log k)
  • Space complexity: O(k)

Where n is the length of the array points.
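
A possible variant avoids letting the heap grow to k + 1 entries by using heapq.heappushpop once the heap is full. This is a sketch under the same negated-distance convention as the solution above; the function name k_closest_pushpop is an assumption made for illustration.

```python
import heapq
from typing import List

def k_closest_pushpop(points: List[List[int]], k: int) -> List[List[int]]:
    max_heap = []  # entries are [-squared_distance, x, y]
    for x, y in points:
        entry = [-(x * x + y * y), x, y]
        if len(max_heap) < k:
            heapq.heappush(max_heap, entry)
        else:
            # Push the new entry, then pop the smallest key, which is
            # the farthest point among the k + 1 candidates.
            heapq.heappushpop(max_heap, entry)
    return [[x, y] for _, x, y in max_heap]

result = k_closest_pushpop([[0, 2], [2, 2], [1, 0], [3, 3]], 2)
print(sorted(result))
```

The returned points are in heap order, not distance order, which is fine because any order is accepted.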


4. Quick Select

Intuition

We want the k closest points, but we do NOT need them sorted.

This is a natural use case for QuickSelect, which reuses the partition step from QuickSort:

  • Pick a pivot point.
  • Partition all points into:
    • points closer than the pivot
    • points farther than the pivot
  • After partitioning, the pivot ends at its correct position in the final sorted order.
  • If the pivot ends up at index p:
    • If p == k, then the left side already contains the k closest points.
    • If p < k, search the right half.
    • If p > k, search the left half.

This avoids fully sorting the array and runs in O(n) time on average.

Algorithm

  1. Define a function to compute squared distance: dist = x^2 + y^2.
  2. Use a partition function:
    • Choose a pivot distance.
    • Rearrange points so all smaller distances go left, larger go right.
    • Return the pivot's final index.
  3. Maintain two pointers: L = 0, R = n - 1.
  4. Repeatedly partition:
    • If pivot index p == k, stop.
    • If p < k, move L = p + 1.
    • If p > k, move R = p - 1.
  5. After partitioning ends, the first k points in the array are the k closest.
  6. Return those k points.

class Solution:
    def kClosest(self, points, k):
        # Squared distance from the origin (no square root needed).
        sq_dist = lambda p: p[0] ** 2 + p[1] ** 2

        def partition(l, r):
            # Lomuto partition using the last element as the pivot.
            pivotIdx = r
            pivotDist = sq_dist(points[pivotIdx])
            i = l
            for j in range(l, r):
                if sq_dist(points[j]) <= pivotDist:
                    points[i], points[j] = points[j], points[i]
                    i += 1
            points[i], points[r] = points[r], points[i]
            return i

        L, R = 0, len(points) - 1
        # Start at len(points) so the loop is skipped when k == len(points).
        pivot = len(points)

        while pivot != k:
            pivot = partition(L, R)
            if pivot < k:
                L = pivot + 1
            else:
                R = pivot - 1
        return points[:k]

Time & Space Complexity

  • Time complexity: O(n) in the average case, O(n^2) in the worst case.
  • Space complexity: O(1)
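
The worst case arises when the chosen pivot is repeatedly the farthest (or closest) remaining point, for example with a last-element pivot on input already sorted by distance. One common mitigation, sketched here as an assumption and not as part of the solution above, is to pick the pivot at random. The function name k_closest_quickselect is hypothetical, and inputs with many equal distances can still degrade with this partition scheme.

```python
import random
from typing import List

def k_closest_quickselect(points: List[List[int]], k: int) -> List[List[int]]:
    pts = list(points)  # shallow copy so the caller's list is not reordered

    def sq_dist(p):
        return p[0] ** 2 + p[1] ** 2

    def partition(lo: int, hi: int) -> int:
        # Move a randomly chosen pivot to the end, then do a Lomuto partition.
        rand_idx = random.randint(lo, hi)
        pts[rand_idx], pts[hi] = pts[hi], pts[rand_idx]
        pivot_dist = sq_dist(pts[hi])
        i = lo
        for j in range(lo, hi):
            if sq_dist(pts[j]) <= pivot_dist:
                pts[i], pts[j] = pts[j], pts[i]
                i += 1
        pts[i], pts[hi] = pts[hi], pts[i]
        return i

    lo, hi = 0, len(pts) - 1
    while lo < hi:
        p = partition(lo, hi)
        if p == k:
            break
        if p < k:
            lo = p + 1
        else:
            hi = p - 1
    return pts[:k]

result = k_closest_quickselect([[3, 3], [0, 2], [2, 2], [1, 0], [-1, -1]], 3)
print(sorted(result))
```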

Common Pitfalls

Computing Actual Distance with Square Root

Using sqrt(x^2 + y^2) is unnecessary and introduces floating-point precision issues. Since we only compare relative distances, squared distance x^2 + y^2 preserves ordering and avoids costly square root operations.

Using the Wrong Heap Type

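A min-heap over all n points works only if you pop k times at the end, which costs O(n) space. The size-k approach needs a max-heap, so that the element evicted when the heap exceeds k entries is the farthest one. Trimming a min-heap to size k instead evicts the closest points and produces incorrect results. In languages whose built-in heap is a min-heap (such as Python's heapq), simulate the max-heap by negating the distance.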
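
To make the failure concrete, the sketch below runs the size-k trimming loop twice on a small illustrative input: once with a plain min-heap key and once with the negated key. The variable names and sample points are assumptions for demonstration only.

```python
import heapq

points = [[0, 2], [2, 2], [1, 0]]  # illustrative input
k = 1

# Mistake: a min-heap on squared distance, trimmed to size k.
# heappop removes the smallest key, so the closest point is evicted.
wrong_heap = []
for x, y in points:
    heapq.heappush(wrong_heap, (x * x + y * y, x, y))
    if len(wrong_heap) > k:
        heapq.heappop(wrong_heap)
wrong = [[x, y] for _, x, y in wrong_heap]

# Fix: negate the key so heappop evicts the farthest point instead.
right_heap = []
for x, y in points:
    heapq.heappush(right_heap, (-(x * x + y * y), x, y))
    if len(right_heap) > k:
        heapq.heappop(right_heap)
right = [[x, y] for _, x, y in right_heap]

print(wrong)  # the farthest point survives
print(right)  # the closest point survives
```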

Integer Overflow in Distance Calculation

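Under the constraints above (|x|, |y| <= 100), the squared distance is at most 2 * 100^2 = 20,000, so overflow cannot occur. Even if coordinates reach 10^4 in magnitude, each square is at most 10^8 and the sum at most 2 * 10^8, which is well below the signed 32-bit maximum of about 2.1 * 10^9. Overflow becomes a real concern only for larger coordinate ranges (beyond roughly 3 * 10^4 in magnitude for a signed 32-bit sum). In languages with fixed-width integers, use a 64-bit type for the distance in that case.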