698. Partition to K Equal Sum Subsets - Explanation

Description

You are given an integer array nums and an integer k. Return true if it is possible to divide this array into k non-empty subsets whose sums are all equal, and false otherwise.

Example 1:

Input: nums = [2,4,1,3,5], k = 3

Output: true

Explanation: The given array can be divided into three subsets, [5], [2,3], and [4,1], each with sum 5.
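To pin down exactly what counts as a valid answer, here is a small checker. `is_valid_partition` is a hypothetical helper, not part of the original solutions. It confirms that a proposed split has k non-empty subsets, uses every element exactly once (compared as multisets), and that all subset sums match.

```python
from collections import Counter
from typing import List


def is_valid_partition(nums: List[int], subsets: List[List[int]], k: int) -> bool:
    """Check that `subsets` splits `nums` into k non-empty, equal-sum parts."""
    if len(subsets) != k or any(len(s) == 0 for s in subsets):
        return False
    # Every element must be used exactly once (multiset comparison).
    flattened = [x for s in subsets for x in s]
    if Counter(flattened) != Counter(nums):
        return False
    # All subset sums must be identical.
    return len({sum(s) for s in subsets}) == 1


# Example 1: [5], [2,3], [4,1] all sum to 5.
print(is_valid_partition([2, 4, 1, 3, 5], [[5], [2, 3], [4, 1]], 3))  # True
```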


Example 2:

Input: nums = [1,2,3,4], k = 3

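Output: false

Explanation: The total is 1 + 2 + 3 + 4 = 10, which is not divisible by 3, so no split into three equal-sum subsets exists.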

Constraints:

  • 1 <= k <= nums.length <= 16
  • 1 <= nums[i] <= 10,000
  • The frequency of each element is in the range [1, 4].
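For testing the faster approaches on tiny inputs, a brute-force reference can help. The sketch below, a hypothetical `brute_force_partition` that is not one of the article's solutions, tries all k^n ways to assign elements to k labeled buckets. With n up to 16 this is far too slow in general, so treat it only as an oracle for very small cases. Because every nums[i] >= 1, each bucket that reaches the target is automatically non-empty.

```python
from itertools import product
from typing import List


def brute_force_partition(nums: List[int], k: int) -> bool:
    """Try every assignment of elements to k labeled buckets (k ** n cases)."""
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    for labels in product(range(k), repeat=len(nums)):
        sums = [0] * k
        for x, bucket in zip(nums, labels):
            sums[bucket] += x
        if all(s == target for s in sums):
            return True
    return False
```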



Prerequisites

Before attempting this problem, you should be comfortable with:

  • Backtracking - Exploring all possible ways to assign elements to subsets and undoing choices when they don't lead to a solution
  • Recursion - Building subsets recursively and understanding the call stack
  • Bit Manipulation - Using bitmasks to efficiently represent which elements have been used (for optimized solutions)
  • Dynamic Programming - Memoizing states to avoid redundant computation in the bitmask DP approach

1. Backtracking

Intuition

The problem asks whether we can divide the array into exactly k subsets, each with the same sum. First, we check if the total sum is divisible by k. If not, it's impossible. Otherwise, each subset must sum to target = total / k.

We use backtracking to build the subsets one at a time: for the current subset, we try each unused number that still fits under the target. Sorting the array in descending order helps us fail faster, because a large number that cannot fit is discovered near the top of the search tree. Once a subset reaches the target sum, we start building the next one. If we successfully build all k subsets, we return true.
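The same search can also report which subsets it built, which is handy for checking results by hand. This is a sketch, assuming we want the subsets themselves rather than a yes/no answer; the name `find_partition` and its convention of returning None on failure are our own, not from the original solution. It follows the algorithm below, but records each finished subset and restores it when backtracking.

```python
from typing import List, Optional


def find_partition(nums: List[int], k: int) -> Optional[List[List[int]]]:
    """Return k equal-sum subsets of nums, or None if no split exists."""
    total = sum(nums)
    if total % k != 0:
        return None
    target = total // k
    nums = sorted(nums, reverse=True)  # large values first to fail fast
    used = [False] * len(nums)
    subsets: List[List[int]] = []      # finished subsets
    current: List[int] = []            # subset being built

    def backtrack(i: int, remaining: int, subset_sum: int) -> bool:
        if remaining == 0:
            return True
        if subset_sum == target:
            # Current subset is full: record it and start the next one.
            subsets.append(current[:])
            current.clear()
            if backtrack(0, remaining - 1, 0):
                return True
            # Undo: reopen the subset we just closed.
            current.extend(subsets.pop())
            return False
        for j in range(i, len(nums)):
            if used[j] or subset_sum + nums[j] > target:
                continue
            used[j] = True
            current.append(nums[j])
            if backtrack(j + 1, remaining, subset_sum + nums[j]):
                return True
            current.pop()
            used[j] = False
        return False

    return subsets if backtrack(0, k, 0) else None


print(find_partition([2, 4, 1, 3, 5], 3))  # [[5], [4, 1], [3, 2]]
```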

Algorithm

  1. Calculate the total sum. If it's not divisible by k, return false.
  2. Sort the array in descending order for early pruning.
  3. Compute target = total / k.
  4. Use a boolean array used to track which elements have been assigned.
  5. Define a recursive function backtrack(i, k, subsetSum):
    • If k == 0, all subsets are formed, return true.
    • If subsetSum == target, start building the next subset with backtrack(0, k - 1, 0).
    • For each unused element from index i, try adding it if it doesn't exceed the target.
    • Mark the element as used, recurse, then backtrack by unmarking it.
  6. Return the result of backtrack(0, k, 0).
from typing import List

class Solution:
    def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
        total = sum(nums)
        if total % k != 0:
            return False

        nums.sort(reverse=True)  # larger values first, so dead ends are hit sooner
        target = total // k
        used = [False] * len(nums)

        def backtrack(i, k, subsetSum):
            if k == 0:
                return True
            if subsetSum == target:
                return backtrack(0, k - 1, 0)
            for j in range(i, len(nums)):
                if used[j] or subsetSum + nums[j] > target:
                    continue
                used[j] = True
                if backtrack(j + 1, k, subsetSum + nums[j]):
                    return True
                used[j] = False
            return False

        return backtrack(0, k, 0)

Time & Space Complexity

  • Time complexity: O(k * 2^n)
  • Space complexity: O(n)

Where n is the size of the array nums and k is the number of subsets.


2. Backtracking (Pruning)

Intuition

This approach extends the basic backtracking with one key pruning step. When the current subset is still empty (subsetSum == 0), every remaining subset is empty as well, so they are interchangeable. If the element we try first cannot start this subset, it cannot start any of the others either, and since every element must end up somewhere, the current configuration is invalid. We return false immediately instead of exploring equivalent branches that can never succeed.
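To see the pruning at work, the sketch below reruns this section's backtracking with a call counter and a `prune` switch. `count_calls` is a hypothetical instrumentation helper, not part of the solution. On an instance with no valid split, every call returns false in both versions, so the pruned search visits a subset of the calls the plain search makes.

```python
from typing import List, Tuple


def count_calls(nums: List[int], k: int, prune: bool) -> Tuple[bool, int]:
    """Run the backtracking search and report (answer, number of calls)."""
    total = sum(nums)
    if total % k != 0:
        return False, 0
    nums = sorted(nums, reverse=True)
    target = total // k
    used = [False] * len(nums)
    calls = 0

    def backtrack(i: int, k: int, subset_sum: int) -> bool:
        nonlocal calls
        calls += 1
        if k == 0:
            return True
        if subset_sum == target:
            return backtrack(0, k - 1, 0)
        for j in range(i, len(nums)):
            if used[j] or subset_sum + nums[j] > target:
                continue
            used[j] = True
            if backtrack(j + 1, k, subset_sum + nums[j]):
                return True
            used[j] = False
            if prune and subset_sum == 0:
                # nums[j] cannot open an empty subset, so no subset can hold it.
                return False
        return False

    return backtrack(0, k, 0), calls


# [2,2,2,2,3,4,5] with k = 4 has no valid split (4 would need a 1).
print(count_calls([2, 2, 2, 2, 3, 4, 5], 4, prune=False))
print(count_calls([2, 2, 2, 2, 3, 4, 5], 4, prune=True))
```

Both runs return False; only the number of recursive calls differs.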

Algorithm

  1. Calculate the total sum. If it's not divisible by k, return false.
  2. Sort the array in descending order.
  3. Compute target = total / k.
  4. Track used elements with a boolean array.
  5. Define backtrack(i, k, subsetSum):
    • If k == 0, return true.
    • If subsetSum == target, recurse to build the next subset.
    • For each unused element from index i:
      • Skip if adding it exceeds the target.
      • Mark as used and recurse.
      • Backtrack by unmarking.
      • Pruning: If subsetSum == 0 and we failed, return false immediately.
  6. Return the result of backtrack(0, k, 0).
from typing import List

class Solution:
    def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
        total = sum(nums)
        if total % k != 0:
            return False

        nums.sort(reverse=True)
        target = total // k
        used = [False] * len(nums)

        def backtrack(i, k, subsetSum):
            if k == 0:
                return True
            if subsetSum == target:
                return backtrack(0, k - 1, 0)
            for j in range(i, len(nums)):
                if used[j] or subsetSum + nums[j] > target:
                    continue
                used[j] = True
                if backtrack(j + 1, k, subsetSum + nums[j]):
                    return True
                used[j] = False

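                # Pruning: this subset is empty, and so are all remaining ones,
                # so if nums[j] cannot start this one it cannot start any of them.
                if subsetSum == 0: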
                    return False

            return False

        return backtrack(0, k, 0)

Time & Space Complexity

  • Time complexity: O(k * 2^n)
  • Space complexity: O(n)

Where n is the size of the array nums and k is the number of subsets.


3. Backtracking (Bit Mask + Pruning)

Intuition

Instead of using a boolean array to track used elements, we can represent the state as a bitmask. Each bit indicates whether the corresponding element has been used. This representation is more compact and prepares us for memoization in later approaches.
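The bit operations this section relies on are small enough to show in isolation. The helper names below (`is_available`, `take`, `available_indices`) are illustrative only. They follow this section's convention that a set bit means the element is still available, so using element j clears bit j with XOR.

```python
from typing import List


def is_available(mask: int, j: int) -> bool:
    """Bit j set means element j has not been used yet."""
    return ((mask >> j) & 1) == 1


def take(mask: int, j: int) -> int:
    """Clear bit j (mark element j as used). Assumes bit j is currently set."""
    return mask ^ (1 << j)


def available_indices(mask: int, n: int) -> List[int]:
    return [j for j in range(n) if is_available(mask, j)]


n = 5
mask = (1 << n) - 1    # 0b11111: all five elements available
mask = take(mask, 2)   # 0b11011: element 2 is now used
print(bin(mask), available_indices(mask, n))  # 0b11011 [0, 1, 3, 4]
```

XOR is only a correct "clear" when the bit is known to be set, which the solution guarantees by checking `mask & (1 << j)` first; `mask & ~(1 << j)` clears unconditionally.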

Algorithm

  1. Calculate the total sum. If it's not divisible by k, return false.
  2. Sort the array in descending order.
  3. Compute target = total / k.
  4. Initialize mask = (1 << n) - 1 where all bits are set (all elements available).
  5. Define backtrack(i, k, subsetSum, mask):
    • If k == 0, return true.
    • If subsetSum == target, start the next subset with backtrack(0, k - 1, 0, mask).
    • For each element j from index i:
      • If bit j is not set in mask or adding nums[j] exceeds target, skip.
      • Recurse with the bit cleared: mask ^ (1 << j).
      • If subsetSum == 0 and we fail, return false (pruning).
  6. Return backtrack(0, k, 0, (1 << n) - 1).
from typing import List

class Solution:
    def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
        total = sum(nums)
        if total % k != 0:
            return False

        nums.sort(reverse=True)
        target = total // k
        n = len(nums)

        def backtrack(i, k, subsetSum, mask):
            if k == 0:
                return True
            if subsetSum == target:
                return backtrack(0, k - 1, 0, mask)
            for j in range(i, n):
                if (mask & (1 << j)) == 0 or subsetSum + nums[j] > target:
                    continue
                if backtrack(j + 1, k, subsetSum + nums[j], mask ^ (1 << j)):
                    return True
                if subsetSum == 0:
                    return False
            return False

        return backtrack(0, k, 0, (1 << n) - 1)

Time & Space Complexity

  • Time complexity: O(k * 2^n)
  • Space complexity:
    • O(1) or O(n) space depending on the sorting algorithm.
    • O(n) for the recursion stack.

Where n is the size of the array nums and k is the number of subsets.


4. Dynamic Programming (Top-Down) + Bit Mask

Intuition

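The bitmask from the previous approach naturally lends itself to memoization. Different orders of choosing elements can lead to the same mask, and the mask alone determines the rest of the state: the sum of the used elements tells us how many subsets are already complete and how full the current one is. So we cache the result for each mask and never recompute it. The search is still exponential, but each of the 2^n masks is resolved at most once, which bounds the work at O(n * 2^n) instead of re-exploring the same states along different paths.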
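The arithmetic behind "the mask determines the state" is short. `state_from_mask` is a hypothetical helper that assumes this section's convention (a set bit means the element is still available) and the same sorted nums. At an exact multiple of target it reports the open subset as empty; the recursive code may briefly hold subsetSum == target for that mask before resetting, which leads to the same result.

```python
from typing import List, Tuple


def state_from_mask(nums: List[int], mask: int, target: int) -> Tuple[int, int]:
    """Return (finished subsets, sum of the open subset) implied by mask.

    Convention: bit j set means nums[j] is still available.
    """
    used_sum = sum(x for j, x in enumerate(nums) if ((mask >> j) & 1) == 0)
    return used_sum // target, used_sum % target


# nums sorted descending, target = 5; elements 5 and 4 (indices 0, 1) are used.
print(state_from_mask([5, 4, 3, 2, 1], 0b11100, 5))  # (1, 4)
```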

Algorithm

  1. Calculate the total sum. If it's not divisible by k, return false.
  2. Sort the array in descending order.
  3. Compute target = total / k.
  4. Create a memoization array dp of size 2^n, initialized to None (not yet computed).
  5. Define backtrack(i, k, subsetSum, mask):
    • If dp[mask] is already computed, return it.
    • If k == 0, set dp[mask] = true and return.
    • If subsetSum == target, recurse for the next subset and cache.
    • For each element from index i:
      • Skip if bit not set or would exceed target.
      • Recurse with updated mask.
      • Apply pruning when subsetSum == 0.
    • Set dp[mask] = false if no valid configuration found.
  6. Return backtrack(0, k, 0, (1 << n) - 1).
from typing import List

class Solution:
    def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
        total = sum(nums)
        if total % k != 0:
            return False

        nums.sort(reverse=True)
        target = total // k
        n = len(nums)
        dp = [None] * (1 << n)

        def backtrack(i, k, subsetSum, mask):
            if dp[mask] is not None:
                return dp[mask]
            if k == 0:
                dp[mask] = True
                return True
            if subsetSum == target:
                dp[mask] = backtrack(0, k - 1, 0, mask)
                return dp[mask]

            for j in range(i, n):
                if (mask & (1 << j)) == 0 or subsetSum + nums[j] > target:
                    continue
                if backtrack(j + 1, k, subsetSum + nums[j], mask ^ (1 << j)):
                    dp[mask] = True
                    return True
                if subsetSum == 0:
                    dp[mask] = False
                    return dp[mask]
            dp[mask] = False
            return False

        return backtrack(0, k, 0, (1 << n) - 1)

Time & Space Complexity

  • Time complexity: O(n * 2^n)
  • Space complexity: O(2^n)
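A variant that makes the mask the only cache key, instead of memoizing a function that also carries i, k, and subsetSum, is sketched below with `functools.lru_cache`. This is a swapped-in formulation, not the section's code: here a set bit means the element HAS been used, and the fill level of the open subset is recomputed from the mask. That recomputation costs O(n) per state, so this sketch runs in O(n^2 * 2^n) time rather than O(n * 2^n).

```python
from functools import lru_cache
from typing import List


def can_partition_memo(nums: List[int], k: int) -> bool:
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    n = len(nums)
    full = (1 << n) - 1

    @lru_cache(maxsize=None)
    def solve(used: int) -> bool:
        # In this sketch a set bit means the element has been used.
        if used == full:
            return True
        # Fill level of the open subset, derived from the mask alone.
        partial = sum(nums[j] for j in range(n) if (used >> j) & 1) % target
        for j in range(n):
            if ((used >> j) & 1) == 0 and partial + nums[j] <= target:
                if solve(used | (1 << j)):
                    return True
        return False

    return solve(0)
```

Because an element is only added when it fits in the open subset, the running sum can only cross a multiple of target by landing on it exactly, so reaching the full mask means every subset sums to target.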

5. Dynamic Programming (Bottom-Up) + Bit Mask

Intuition

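Instead of recursion, we iterate through all mask values from 0 to 2^n - 1. For each reachable state, we try adding each unused element. The key insight is that dp[mask] stores the sum of the elements in mask modulo target, which is how full the current, partially built subset is. That value depends only on which elements are used, not on the order they were added, so one integer per mask is enough. Since an element is added only when it fits in the current subset, each subset closes at exactly target. If the full mask is reachable (its value is then 0, meaning the last subset is complete), we have a valid partition.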
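The claim that dp[mask] depends only on the mask can be checked directly. The sketch below fills the table the same way this section does (assuming the total is already known to be divisible by k) and then verifies, for every reachable mask, that the stored value equals the sum of that mask's elements modulo target. `bottom_up_table` and `check_invariant` are hypothetical names used only for this check.

```python
from typing import List


def bottom_up_table(nums: List[int], k: int) -> List[int]:
    """Fill dp as in this section and return the whole table (-1 = unreachable)."""
    target = sum(nums) // k
    n = len(nums)
    dp = [0] + [-1] * ((1 << n) - 1)
    for mask in range(1 << n):
        if dp[mask] == -1:
            continue
        for i in range(n):
            if (mask & (1 << i)) == 0 and dp[mask] + nums[i] <= target:
                dp[mask | (1 << i)] = (dp[mask] + nums[i]) % target
    return dp


def check_invariant(nums: List[int], k: int) -> bool:
    """Every reachable mask should store sum(elements in mask) % target."""
    target = sum(nums) // k
    dp = bottom_up_table(nums, k)
    for mask, value in enumerate(dp):
        if value == -1:
            continue
        used_sum = sum(x for i, x in enumerate(nums) if (mask >> i) & 1)
        if value != used_sum % target:
            return False
    return True
```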

Algorithm

  1. Calculate the total sum. If it's not divisible by k, return false.
  2. Compute target = total / k.
  3. Initialize dp array of size 2^n with -1, except dp[0] = 0.
  4. For each mask from 0 to 2^n - 1:
    • If dp[mask] == -1, this state is unreachable; skip it.
    • For each element i:
      • If bit i is not set and adding nums[i] doesn't exceed target:
        • Set dp[mask | (1 << i)] = (dp[mask] + nums[i]) % target.
  5. Return true if dp[(1 << n) - 1] == 0, meaning the full mask is reachable: every element was placed without any subset overflowing, so all k subsets sum to exactly target.
from typing import List

class Solution:
    def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
        total = sum(nums)
        if total % k != 0:
            return False

        target = total // k
        n = len(nums)
        N = 1 << n
        dp = [0] + [-1] * (N - 1)

        for mask in range(N):
            if dp[mask] == -1:
                continue
            for i in range(n):
                if (mask & (1 << i)) == 0 and dp[mask] + nums[i] <= target:
                    dp[mask | (1 << i)] = (dp[mask] + nums[i]) % target

        return dp[N - 1] == 0

Time & Space Complexity

  • Time complexity: O(n * 2^n)
  • Space complexity: O(2^n)
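One way to gain confidence in any of these approaches is a randomized differential test against brute force on small inputs. The sketch below is self-contained: `bottom_up` restates this section's algorithm as a plain function, `brute_force` tries every assignment to k labeled buckets, and `differential_test` (a hypothetical name) compares them on random arrays with n <= 5 using a fixed seed so runs are reproducible.

```python
import random
from itertools import product
from typing import List


def bottom_up(nums: List[int], k: int) -> bool:
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    n = len(nums)
    dp = [0] + [-1] * ((1 << n) - 1)
    for mask in range(1 << n):
        if dp[mask] == -1:
            continue
        for i in range(n):
            if (mask & (1 << i)) == 0 and dp[mask] + nums[i] <= target:
                dp[mask | (1 << i)] = (dp[mask] + nums[i]) % target
    return dp[-1] == 0


def brute_force(nums: List[int], k: int) -> bool:
    # Try every assignment of elements to k labeled buckets.
    total = sum(nums)
    if total % k != 0:
        return False
    target = total // k
    for labels in product(range(k), repeat=len(nums)):
        sums = [0] * k
        for x, bucket in zip(nums, labels):
            sums[bucket] += x
        if all(s == target for s in sums):
            return True
    return False


def differential_test(trials: int = 200, seed: int = 0) -> int:
    """Compare both functions on random small inputs; return the mismatch count."""
    rng = random.Random(seed)
    mismatches = 0
    for _ in range(trials):
        n = rng.randint(1, 5)
        nums = [rng.randint(1, 6) for _ in range(n)]
        k = rng.randint(1, n)
        if bottom_up(nums, k) != brute_force(nums, k):
            mismatches += 1
    return mismatches
```

The brute force grows as k^n, so keep n small; the point is agreement on many cases, not speed.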