279. Perfect Squares - Explanation

Description

Given an integer n, return the least number of perfect square numbers that sum to n.

A perfect square is an integer that is the square of an integer. For example, 1, 4, 9, 16, 25... are perfect squares.
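A minimal way to test the definition in code, assuming Python 3.8+ for `math.isqrt` (an exact integer square root, so there is no floating-point rounding). The helper name `is_perfect_square` is hypothetical.

```python
import math

def is_perfect_square(x: int) -> bool:
    """Return True if x (x >= 0) is the square of an integer."""
    r = math.isqrt(x)  # floor of the exact square root
    return r * r == x

print([x for x in range(1, 31) if is_perfect_square(x)])  # [1, 4, 9, 16, 25]
```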

Example 1:

Input: n = 13

Output: 2

Explanation: 13 = 4 + 9.

Example 2:

Input: n = 6

Output: 3

Explanation: 6 = 4 + 1 + 1.

Constraints:

  • 1 <= n <= 10,000



Prerequisites

Before attempting this problem, you should be comfortable with:

  • Recursion - The foundation for building the brute-force and memoized solutions
  • Dynamic Programming - Understanding both top-down (memoization) and bottom-up approaches for optimization
  • Breadth-First Search (BFS) - Used to model the problem as a shortest path search
  • Math (Number Theory) - Understanding perfect squares and Lagrange's four-square theorem for the optimal solution

1. Recursion

Intuition

We want to express n as a sum of the fewest perfect squares. At each step, we can subtract any perfect square that fits, then recursively solve for the remainder. By trying all possible perfect squares and taking the minimum, we find the optimal answer. This brute-force approach explores all combinations but results in repeated subproblems.

Algorithm

  1. Define a recursive function that takes a target value.
  2. Base case: if target is 0, return 0 (no squares needed).
  3. Initialize the result to target (the worst case is using all 1s).
  4. For each perfect square i*i that does not exceed target, recursively solve for (target - i*i) and keep the smaller of the current result and 1 + that recursive result.
  5. Return the minimum count found.
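class Solution:
    def numSquares(self, n: int) -> int:
        def dfs(target):
            if target == 0:
                return 0  # no squares needed to reach 0

            res = target  # worst case: target copies of 1
            for i in range(1, target + 1):
                if i * i > target:
                    break
                # use square i*i, then solve the remainder recursively
                res = min(res, 1 + dfs(target - i * i))
            return res

        return dfs(n)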

Time & Space Complexity

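  • Time complexity: exponential in n. Each call branches into at most $\sqrt{n}$ subcalls and the recursion can be up to $n$ levels deep, giving a loose upper bound of $O((\sqrt{n})^{n})$.
  • Space complexity: $O(n)$ for the recursion stack.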
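To make the repeated subproblems concrete, the sketch below reruns the same plain recursion inside a hypothetical `count_calls` helper that tallies how often each target is visited. For n = 13 there are only 14 distinct targets (0 through 13), yet the total number of calls is larger, and a target such as 5 is solved more than once (for example via 13 → 9 → 5 and via 13 → 12 → 8 → 7 → 6 → 5).

```python
from collections import Counter

def count_calls(n: int) -> Counter:
    """Run the plain recursion and count visits per target value."""
    calls = Counter()

    def dfs(target: int) -> int:
        calls[target] += 1  # record every visit, including repeats
        if target == 0:
            return 0
        res = target  # worst case: all 1s
        i = 1
        while i * i <= target:
            res = min(res, 1 + dfs(target - i * i))
            i += 1
        return res

    dfs(n)
    return calls

calls = count_calls(13)
print(sum(calls.values()), len(calls), calls[5])  # total calls, distinct targets, visits to 5
```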

2. Dynamic Programming (Top-Down)

Intuition

The recursive solution recomputes the same subproblems many times. By caching results in a memoization table, we avoid redundant work: each target value from 0 to n is solved once, and later calls return the cached result. With at most $\sqrt{n}$ choices per target, the running time drops from exponential to $O(n\sqrt{n})$.

Algorithm

  1. Create a memoization dictionary to store results for each target value.
  2. Define a recursive function. If target is 0, return 0. If target is in memo, return the cached value.
  3. Initialize result to target.
  4. For each perfect square i*i up to target, compute 1 + dfs(target - i*i) and track the minimum.
  5. Store the result in memo and return it.
  6. Call the recursive function with n.
class Solution:
    def numSquares(self, n: int) -> int:
        memo = {}

        def dfs(target):
            if target == 0:
                return 0
            if target in memo:
                return memo[target]

            res = target
            for i in range(1, target + 1):
                if i * i > target:
                    break
                res = min(res, 1 + dfs(target - i * i))

            memo[target] = res
            return res

        return dfs(n)

Time & Space Complexity

  • Time complexity: $O(n\sqrt{n})$
  • Space complexity: $O(n)$
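The same memoization can be written with `functools.cache` (Python 3.9+) instead of a hand-managed dictionary. One caveat applies to both versions: dfs(n) first recurses into dfs(n - 1), dfs(n - 2), and so on, so the stack grows about n frames deep, which exceeds CPython's default recursion limit of 1000 when n approaches 10,000. The sketch below shows one possible workaround that is not part of the original solution: it fills the cache in increasing order, so each call finds its smaller subproblems already cached. The function name `num_squares_cached` is hypothetical.

```python
import functools

def num_squares_cached(n: int) -> int:
    @functools.cache
    def dfs(target: int) -> int:
        if target == 0:
            return 0
        res = target  # worst case: target copies of 1
        i = 1
        while i * i <= target:
            res = min(res, 1 + dfs(target - i * i))
            i += 1
        return res

    # Warm the cache from small to large: every smaller target is already
    # cached, so each dfs(t) recurses only one level and the stack stays shallow.
    for t in range(1, n + 1):
        dfs(t)
    return dfs(n)

print(num_squares_cached(13))  # 2
print(num_squares_cached(6))   # 3
```

Warming the cache in increasing order effectively turns the top-down solution into the bottom-up order used by the next approach.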

3. Dynamic Programming (Bottom-Up)

Intuition

Instead of solving top-down with recursion, we can build the solution bottom-up. We compute the minimum number of squares for every value from 1 to n, using previously computed results. For each target, we try subtracting every perfect square and take the minimum result plus one.

Algorithm

  1. Create a dp array of size n+1, initialized to n (worst case of all 1s). Set dp[0] = 0.
  2. For each target from 1 to n, iterate through all perfect squares s*s that do not exceed target.
  3. Update dp[target] = min(dp[target], 1 + dp[target - s*s]).
  4. Return dp[n].
class Solution:
    def numSquares(self, n: int) -> int:
        dp = [n] * (n + 1)
        dp[0] = 0

        for target in range(1, n + 1):
            for s in range(1, target + 1):
                square = s * s
                if target - square < 0:
                    break
                dp[target] = min(dp[target], 1 + dp[target - square])

        return dp[n]

Time & Space Complexity

  • Time complexity: $O(n\sqrt{n})$
  • Space complexity: $O(n)$
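The DP above returns only the count. If you also want one optimal decomposition, such as 13 = 4 + 9 from Example 1, a sketch like the following records which square each dp[target] used and walks back from n. The name `min_square_decomposition` and the `choice` array are hypothetical additions, not part of the original solution.

```python
def min_square_decomposition(n: int) -> list:
    """Bottom-up DP that also records the square chosen for each target,
    then walks back from n to list one optimal decomposition."""
    # Start above any real answer so the first transition always records a choice.
    dp = [n + 1] * (n + 1)
    dp[0] = 0
    choice = [0] * (n + 1)  # choice[t] = square subtracted to reach dp[t]

    for target in range(1, n + 1):
        s = 1
        while s * s <= target:
            square = s * s
            if 1 + dp[target - square] < dp[target]:
                dp[target] = 1 + dp[target - square]
                choice[target] = square
            s += 1

    parts = []
    t = n
    while t > 0:
        parts.append(choice[t])
        t -= choice[t]
    return parts

print(min_square_decomposition(13))  # [4, 9]
print(min_square_decomposition(6))   # [1, 1, 4]
```

Because of the strict `<` comparison, ties keep the first (smallest) square that achieved the minimum, so other equally short decompositions may exist.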

4. Breadth-First Search

Intuition

We can view this as a shortest path problem. Starting from 0, each step adds a perfect square. BFS explores all sums reachable with 1 square, then 2 squares, and so on. The first time we reach n, we have found the minimum number of squares. Using a set to track visited values prevents processing the same sum multiple times.

Algorithm

  1. Initialize a queue with 0 and a set to track seen values.
  2. Process the queue level by level, incrementing the count at each level.
  3. For each current sum, try adding every perfect square s*s such that current + s*s <= n.
  4. If current + s*s equals n, return the current count.
  5. If the new sum has not been seen, add it to the set and enqueue it.
  6. Continue until n is reached.
from collections import deque

class Solution:
    def numSquares(self, n: int) -> int:
        q = deque()
        seen = set()

        res = 0
        q.append(0)
        while q:
            res += 1
            for _ in range(len(q)):
                cur = q.popleft()
                s = 1
                while s * s + cur <= n:
                    nxt = cur + s * s
                    if nxt == n:
                        return res
                    if nxt not in seen:
                        seen.add(nxt)
                        q.append(nxt)
                    s += 1

        return res

Time & Space Complexity

  • Time complexity: $O(n\sqrt{n})$
  • Space complexity: $O(n)$
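To see the level-by-level structure, the sketch below uses a hypothetical `bfs_levels` helper that runs the same BFS but returns the sums first reached at each level, stopping at the level that reaches n. The number of levels equals the answer.

```python
from math import isqrt

def bfs_levels(n: int) -> list:
    """Sums first reached at each BFS level, up to the level that reaches n."""
    squares = [i * i for i in range(1, isqrt(n) + 1)]
    seen = {0}
    frontier = [0]
    levels = []
    while frontier:
        next_frontier = []
        for cur in frontier:
            for sq in squares:
                nxt = cur + sq
                if nxt > n:
                    break  # squares are increasing, so larger ones overshoot too
                if nxt not in seen:
                    seen.add(nxt)
                    next_frontier.append(nxt)
        levels.append(sorted(next_frontier))
        if n in seen:
            return levels
        frontier = next_frontier
    return levels

print(bfs_levels(6))  # [[1, 4], [2, 5], [3, 6]]
```

For n = 6, the sum 6 first appears on level 3, matching Example 2.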

5. Math

Intuition

Lagrange's four-square theorem states that every positive integer can be expressed as the sum of at most four perfect squares, so the answer is always 1, 2, 3, or 4. Legendre's three-square theorem refines this: n is a sum of three squares exactly when it is not of the form $4^k(8m+7)$. Combining the two, we can determine the exact answer without dynamic programming: if n is a perfect square, the answer is 1; if n can be written as the sum of two squares, the answer is 2; if n is of the form $4^k(8m+7)$, the answer is 4; otherwise, the answer is 3.
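A small sketch of the $4^k(8m+7)$ test, using a hypothetical helper `needs_four_squares`, lists the values up to 100 whose answer is 4.

```python
def needs_four_squares(n: int) -> bool:
    """True if n has the form 4^k * (8m + 7), i.e. n is not a sum of three squares."""
    if n <= 0:
        raise ValueError("n must be positive")  # n = 0 would loop forever below
    while n % 4 == 0:
        n //= 4  # strip factors of 4
    return n % 8 == 7

print([n for n in range(1, 101) if needs_four_squares(n)])
# [7, 15, 23, 28, 31, 39, 47, 55, 60, 63, 71, 79, 87, 92, 95]
```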

Algorithm

  1. Remove all factors of 4 from n (divide by 4 while divisible).
  2. If the reduced n is congruent to 7 mod 8, return 4.
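  3. Check if the reduced n is a perfect square. If so, return 1. (Dividing by 4 does not change the answer: 4m is a perfect square, or a sum of two squares, exactly when m is.)
  4. Check if the reduced n can be expressed as the sum of two squares by testing every possible first square i*i. If so, return 2.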
  5. Otherwise, return 3.
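import math

class Solution:
    def numSquares(self, n: int) -> int:
        # Dividing out factors of 4 does not change the answer
        while n % 4 == 0:
            n //= 4

        # Legendre: 8m + 7 is not a sum of three squares
        if n % 8 == 7:
            return 4

        def isSquareNum(num):
            s = math.isqrt(num)  # exact integer square root
            return s * s == num

        if isSquareNum(n):
            return 1

        # Try every first square i*i and test whether the rest is a square
        i = 1
        while i * i <= n:
            if isSquareNum(n - i * i):
                return 2
            i += 1

        return 3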

Time & Space Complexity

  • Time complexity: $O(\sqrt{n})$
  • Space complexity: $O(1)$
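Because the math solution relies on two theorems rather than search, it is reassuring to cross-check it against the bottom-up DP. The sketch below does this for every n up to 2000; `num_squares_math`, `dp_answers`, and `LIMIT` are hypothetical names, and the limit is an arbitrary choice to keep the check fast.

```python
import math

def num_squares_math(n: int) -> int:
    # Same logic as the Math solution above.
    while n % 4 == 0:
        n //= 4
    if n % 8 == 7:
        return 4
    if math.isqrt(n) ** 2 == n:
        return 1
    i = 1
    while i * i <= n:
        rest = n - i * i
        if math.isqrt(rest) ** 2 == rest:
            return 2
        i += 1
    return 3

def dp_answers(limit: int) -> list:
    # Bottom-up DP: dp[t] = fewest squares summing to t, for every t <= limit.
    dp = [0] + [limit] * limit
    for t in range(1, limit + 1):
        s = 1
        while s * s <= t:
            dp[t] = min(dp[t], 1 + dp[t - s * s])
            s += 1
    return dp

LIMIT = 2000
dp = dp_answers(LIMIT)
mismatches = [n for n in range(1, LIMIT + 1) if num_squares_math(n) != dp[n]]
print(mismatches)  # []
```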

Common Pitfalls

Not Recognizing Overlapping Subproblems

A common mistake is implementing plain recursion without memoization. The recursive solution recomputes the same subproblems many times (e.g., dfs(5) might be called from multiple paths). Without caching results, the solution becomes exponentially slow and will time out on larger inputs.

Incorrect Base Case or Initialization

Forgetting to handle the base case dp[0] = 0 or initializing the DP array incorrectly leads to wrong answers. The value dp[0] = 0 is crucial because it represents that zero squares are needed to sum to zero. Similarly, initializing other values to n (worst case of all 1s) ensures the minimum is correctly computed.
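A small sketch makes the role of dp[0] = 0 visible. The hypothetical `dp_table` helper returns the whole table for n = 13; its hypothetical `seed_zero` flag leaves dp[0] at the initial value n instead, and then no entry can ever drop below n.

```python
def dp_table(n: int, seed_zero: bool = True) -> list:
    dp = [n] * (n + 1)  # every entry starts at the "all 1s" worst case
    if seed_zero:
        dp[0] = 0  # zero squares sum to zero; every chain of transitions ends here
    for target in range(1, n + 1):
        s = 1
        while s * s <= target:
            dp[target] = min(dp[target], 1 + dp[target - s * s])
            s += 1
    return dp

print(dp_table(13))
# [0, 1, 2, 3, 1, 2, 3, 4, 2, 1, 2, 3, 3, 2]
print(dp_table(13, seed_zero=False))
# [13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13]
```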

Iterating Over Non-Perfect-Squares

Some implementations mistakenly subtract every number j from 1 to target instead of only perfect squares. Those transitions are invalid (subtracting 2, for example, is not a legal move), so they produce wrong answers as well as wasted work. Always ensure the inner loop only subtracts squares i * i with i * i <= target.
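One way to make that guarantee structural is to list the squares once up front and loop only over that list. The sketch below is the bottom-up DP from approach 3 with this change; the name `num_squares_precomputed` is hypothetical, and `math.isqrt` assumes Python 3.8+.

```python
import math

def num_squares_precomputed(n: int) -> int:
    # Only perfect squares are valid moves, so list them once.
    squares = [i * i for i in range(1, math.isqrt(n) + 1)]
    dp = [0] + [n] * n  # dp[0] = 0, everything else starts at the all-1s worst case
    for target in range(1, n + 1):
        for sq in squares:  # increasing order, so we can stop at the first overshoot
            if sq > target:
                break
            dp[target] = min(dp[target], 1 + dp[target - sq])
    return dp[n]

print(num_squares_precomputed(13))  # 2
```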