Partition to K Equal Sum Subsets

Difficulty: Medium

Description

Given an integer array nums and a positive integer k, determine whether it is possible to divide the array into exactly k non-empty subsets such that the sum of elements in every subset is the same.

Each element must appear in exactly one subset — you cannot leave any element out, and you cannot use any element more than once. All k subsets must have equal total sum.

Return true if such a partition exists, and false otherwise.

Examples

Example 1

Input: nums = [4, 3, 2, 3, 5, 2, 1], k = 4

Output: true

Explanation: The total sum is 4 + 3 + 2 + 3 + 5 + 2 + 1 = 20. Each subset must sum to 20 / 4 = 5. One valid partition is:

  • Subset 1: {5} → sum = 5
  • Subset 2: {4, 1} → sum = 5
  • Subset 3: {3, 2} → sum = 5
  • Subset 4: {3, 2} → sum = 5

All four subsets have equal sum 5, so the answer is true.
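The partition rules (every element used exactly once, exactly k non-empty subsets, equal sums) can be checked mechanically. Below is a minimal sketch of such a checker in Python. The name is_valid_partition is hypothetical and not part of the problem's interface, and the checker treats nums as a multiset, so equal values are interchangeable.

```python
from collections import Counter


def is_valid_partition(nums, subsets, k):
    """Hypothetical checker: is `subsets` a k-way equal-sum partition of `nums`?"""
    # Exactly k subsets, none of them empty.
    if len(subsets) != k or any(len(s) == 0 for s in subsets):
        return False
    # Every element used exactly once: the multiset union must equal nums.
    if Counter(x for s in subsets for x in s) != Counter(nums):
        return False
    # All subset sums must coincide.
    return len({sum(s) for s in subsets}) == 1


nums = [4, 3, 2, 3, 5, 2, 1]
print(is_valid_partition(nums, [[5], [4, 1], [3, 2], [3, 2]], 4))  # True
print(is_valid_partition(nums, [[5], [4, 1], [3, 2], [3]], 4))     # False: a 2 is left out
```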

Example 2

Input: nums = [1, 2, 3, 4], k = 3

Output: false

Explanation: The total sum is 1 + 2 + 3 + 4 = 10. For k = 3 subsets with equal sum, each would need to sum to 10 / 3 ≈ 3.33. Since 10 is not divisible by 3, it is impossible to create three subsets with equal integer sums.

Example 3

Input: nums = [2, 2, 2, 2, 3, 4, 5], k = 4

Output: false

Explanation: The total sum is 2 + 2 + 2 + 2 + 3 + 4 + 5 = 20, so each subset would need sum 20 / 4 = 5. The 5 can fill a subset on its own, but the 4 needs a partner worth exactly 1, and the smallest other element is 2 (4 + 2 = 6 > 5). Since the 4 cannot belong to any subset with sum 5, no valid partition exists.
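A quick necessary condition behind Example 3 is that every element must belong to at least one subset summing to the target. The sketch below (subsets_with_sum is a hypothetical helper) enumerates, using itertools.combinations, every index set that contains a given element and sums to the target. An empty result for any element proves the answer is false. It is exponential and only meant for checking small examples by hand.

```python
from itertools import combinations


def subsets_with_sum(nums, target, must_include):
    """Return every index tuple containing `must_include` whose values sum to `target`."""
    others = [i for i in range(len(nums)) if i != must_include]
    found = []
    for r in range(len(others) + 1):
        for combo in combinations(others, r):
            if nums[must_include] + sum(nums[i] for i in combo) == target:
                found.append((must_include,) + combo)
    return found


nums = [2, 2, 2, 2, 3, 4, 5]
print(subsets_with_sum(nums, 5, nums.index(4)))  # [] -> the 4 fits in no subset
```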

Constraints

  • 1 ≤ k ≤ nums.length ≤ 16
  • 1 ≤ nums[i] ≤ 10^4
  • The frequency of each element is in the range [1, 4]

Editorial

Brute Force

Intuition

Think of this problem as having k empty buckets and a pile of numbered balls. You need to distribute all the balls into the buckets so that every bucket holds the same total weight.

The most straightforward approach is to try every possible assignment: pick up the first ball and try dropping it into bucket 1, then bucket 2, and so on up to bucket k. For each choice, pick up the next ball and try all k buckets again. Continue until all balls are placed or you realize the current assignment cannot work.

When a bucket's total would exceed the target sum, you skip that bucket. If no bucket can accept the current ball, you backtrack — remove the previous ball from its bucket and try a different bucket for it.

Before starting, we perform a quick feasibility check: the total sum of all elements must be divisible by k. If it is not, no equal partition is possible and we immediately return false.

Step-by-Step Explanation

Let's trace through with nums = [4, 3, 2, 3, 5, 2, 1], k = 4, target = 20 / 4 = 5.

We maintain a buckets array of size k = 4, initialized to [0, 0, 0, 0].

Step 1: Try nums[0] = 4 in bucket 0. buckets = [4, 0, 0, 0]. 4 ≤ 5 ✓. Proceed.

Step 2: Try nums[1] = 3 in bucket 0. 4 + 3 = 7 > 5 ✗. Cannot fit!

Step 3: Try nums[1] = 3 in bucket 1. buckets = [4, 3, 0, 0]. 3 ≤ 5 ✓. Proceed.

Step 4: Try nums[2] = 2 in bucket 0. 4 + 2 = 6 > 5 ✗. Try bucket 1: 3 + 2 = 5 ≤ 5 ✓. buckets = [4, 5, 0, 0]. Bucket 1 complete!

Step 5: Try nums[3] = 3 in bucket 0. 4 + 3 = 7 > 5 ✗. Bucket 1 full. Try bucket 2: 0 + 3 = 3 ≤ 5 ✓. buckets = [4, 5, 3, 0]. Proceed.

Step 6: Try nums[4] = 5 in bucket 0. 4 + 5 = 9 > 5 ✗. Bucket 1 full. Bucket 2: 3 + 5 = 8 > 5 ✗. Try bucket 3: 0 + 5 = 5 ≤ 5 ✓. buckets = [4, 5, 3, 5]. Bucket 3 complete!

Step 7: Try nums[5] = 2 in bucket 0. 4 + 2 = 6 > 5 ✗. Bucket 1 full. Bucket 2: 3 + 2 = 5 ≤ 5 ✓. buckets = [4, 5, 5, 5]. Bucket 2 complete!

Step 8: Try nums[6] = 1 in bucket 0. 4 + 1 = 5 ≤ 5 ✓. buckets = [5, 5, 5, 5]. All buckets = target!

Step 9: All elements placed. Return true!
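To reproduce the trace above, the brute-force search can record which bucket each value lands in along the current path. The following is a sketch of the same search as in the Code section, with a hypothetical path list added for illustration. Entries are popped on backtrack, so on success path holds exactly the final assignment.

```python
def brute_force_trace(nums, k):
    """Brute-force bucket search that also returns the (value, bucket) placements."""
    total = sum(nums)
    if total % k != 0:
        return False, []
    target = total // k
    buckets = [0] * k
    path = []  # (value, bucket index) for each element on the current path

    def dfs(index):
        if index == len(nums):
            return True
        for j in range(k):
            if buckets[j] + nums[index] <= target:
                buckets[j] += nums[index]
                path.append((nums[index], j))
                if dfs(index + 1):
                    return True
                path.pop()  # undo this placement before trying the next bucket
                buckets[j] -= nums[index]
        return False

    return dfs(0), path


found, path = brute_force_trace([4, 3, 2, 3, 5, 2, 1], 4)
print(found, path)
# True [(4, 0), (3, 1), (2, 1), (3, 2), (5, 3), (2, 2), (1, 0)]
```

The placements match Steps 1 to 8. This particular input never needs to backtrack.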

Algorithm

  1. Compute total sum of nums. If total % k ≠ 0, return false.
  2. Set target = total / k.
  3. Create an array buckets of size k, initialized to 0.
  4. Define recursive function dfs(index):
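    • Base case: if index == n, return true (all elements placed). No bucket ever exceeds target and together the buckets hold k × target, so at this point every bucket equals target; the explicit check in the code is a redundant safeguard.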
    • For each bucket j from 0 to k-1:
      • If buckets[j] + nums[index] ≤ target:
        • Place element: buckets[j] += nums[index]
        • Recurse: if dfs(index + 1) returns true, return true.
        • Backtrack: buckets[j] -= nums[index]
    • Return false.
  5. Return dfs(0).

Code

C++

#include <vector>
using namespace std;

class Solution {
public:
    bool canPartitionKSubsets(vector<int>& nums, int k) {
        int total = 0;
        for (int n : nums) total += n;
        if (total % k != 0) return false;
        int target = total / k;
        vector<int> buckets(k, 0);
        return dfs(nums, buckets, 0, target, k);
    }

    bool dfs(vector<int>& nums, vector<int>& buckets, int index, int target, int k) {
        if (index == (int)nums.size()) {
            for (int b : buckets) {
                if (b != target) return false;
            }
            return true;
        }
        for (int j = 0; j < k; j++) {
            if (buckets[j] + nums[index] <= target) {
                buckets[j] += nums[index];
                if (dfs(nums, buckets, index + 1, target, k)) return true;
                buckets[j] -= nums[index];
            }
        }
        return false;
    }
};

Python

class Solution:
    def canPartitionKSubsets(self, nums: list[int], k: int) -> bool:
        total = sum(nums)
        if total % k != 0:
            return False
        target = total // k
        buckets = [0] * k

        def dfs(index: int) -> bool:
            if index == len(nums):
                return all(b == target for b in buckets)
            for j in range(k):
                if buckets[j] + nums[index] <= target:
                    buckets[j] += nums[index]
                    if dfs(index + 1):
                        return True
                    buckets[j] -= nums[index]
            return False

        return dfs(0)

Java

class Solution {
    public boolean canPartitionKSubsets(int[] nums, int k) {
        int total = 0;
        for (int n : nums) total += n;
        if (total % k != 0) return false;
        int target = total / k;
        int[] buckets = new int[k];
        return dfs(nums, buckets, 0, target, k);
    }

    private boolean dfs(int[] nums, int[] buckets, int index, int target, int k) {
        if (index == nums.length) {
            for (int b : buckets) {
                if (b != target) return false;
            }
            return true;
        }
        for (int j = 0; j < k; j++) {
            if (buckets[j] + nums[index] <= target) {
                buckets[j] += nums[index];
                if (dfs(nums, buckets, index + 1, target, k)) return true;
                buckets[j] -= nums[index];
            }
        }
        return false;
    }
}

Complexity Analysis

Time Complexity: O(k^n)

At each recursive level, we try placing the current element in one of k buckets. The recursion tree has depth n and branching factor k, yielding up to k^n nodes in the worst case. For n = 16 and k = 16, this is 16^16 ≈ 1.8 × 10^19 — astronomically slow.

Space Complexity: O(n)

The recursion stack has depth n. The buckets array uses O(k) space, and k ≤ n. Total auxiliary space is O(n).

Why This Approach Is Not Efficient

The brute force explores up to k^n possible assignments. Even for modest inputs like n = 16 and k = 4, that is 4^16 ≈ 4.3 × 10^9 — well beyond acceptable time limits.

The inefficiency comes from two sources:

  1. Symmetrical exploration: If two buckets currently hold the same sum, placing the next element in either leads to equivalent subproblems. The brute force tries both, and because bucket labels are interchangeable, the same grouping can be explored up to k! times (once per relabeling of the buckets).
  2. Poor element ordering: Processing elements in their original (often arbitrary) order means small elements create a massive branching tree of possibilities before discovering failures caused by large elements that cannot fit.
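The symmetry problem in point 1 can be measured directly. The sketch below (count_valid_assignments is a hypothetical helper) runs the brute-force search to completion instead of stopping at the first success and counts every complete assignment. For Example 1 there are only two distinct partitions by index: {5} and {4, 1} plus the two ways to pair the 3s with the 2s. Yet the search finds each one once per ordering of the four buckets, 2 × 4! = 48 times.

```python
def count_valid_assignments(nums, k):
    """Count bucket-labelled assignments where every bucket reaches total / k."""
    total = sum(nums)
    if total % k != 0:
        return 0
    target = total // k
    buckets = [0] * k

    def dfs(index):
        if index == len(nums):
            # No bucket exceeds target and they sum to k * target, so all are full.
            return 1
        count = 0
        for j in range(k):
            if buckets[j] + nums[index] <= target:
                buckets[j] += nums[index]
                count += dfs(index + 1)
                buckets[j] -= nums[index]
        return count

    return dfs(0)


print(count_valid_assignments([4, 3, 2, 3, 5, 2, 1], 4))  # 48 == 2 * 4!
```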

We need pruning strategies: sorting elements to fail fast, and skipping duplicate bucket states to eliminate symmetrical exploration.

Better Approach - Backtracking with Sorting and Pruning

Intuition

The structure is identical to brute force: we still try placing each element into one of k buckets and backtrack on failure. However, three pruning rules (two of them new) dramatically shrink the search tree:

1. Sort elements in descending order. Larger elements have fewer valid bucket placements. By placing them first, we discover dead-end branches much sooner. If the largest element exceeds the target, we fail immediately. If a large element fits in only one bucket, we avoid branching entirely at that step.

2. Skip duplicate bucket states. Before placing an element in bucket j, check if bucket j has the same accumulated sum as bucket j-1. If so, placing the element in either bucket leads to an equivalent subproblem — skip bucket j to avoid redundant computation.

3. Early termination. As in the brute force, if placing an element in a bucket would push its sum past the target, we skip that bucket without recursing. The check itself is not new, but combined with the descending sort it cuts off large branches near the root of the search tree.

Additionally, if any single element exceeds the target sum, we return false immediately — that element can never fit in any subset.

These optimizations do not change the worst-case complexity of O(k^n), but they reduce the practical running time by orders of magnitude. For the given constraints (n ≤ 16), optimized backtracking typically runs in milliseconds.
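One way to see the effect of these rules on a given input is to count recursive calls. The sketch below is an illustrative experiment, not part of the editorial's solution. The function name search_with_call_count and the pruned flag are hypothetical. The function runs the same search with pruning (descending sort plus duplicate-skip) switched off or on and returns the answer together with the number of dfs calls. It omits the separate nums[0] > target early exit, since the search rejects that case anyway. Counts depend heavily on the input, so treat them as illustrative rather than as a benchmark.

```python
def search_with_call_count(nums, k, pruned):
    """Bucket backtracking; returns (answer, number of dfs calls)."""
    total = sum(nums)
    if total % k != 0:
        return False, 0
    target = total // k
    order = sorted(nums, reverse=True) if pruned else list(nums)
    buckets = [0] * k
    calls = 0

    def dfs(index):
        nonlocal calls
        calls += 1
        if index == len(order):
            return True
        for j in range(k):
            # Duplicate-skip: same sum as bucket j - 1 means an equivalent subproblem.
            if pruned and j > 0 and buckets[j] == buckets[j - 1]:
                continue
            if buckets[j] + order[index] <= target:
                buckets[j] += order[index]
                if dfs(index + 1):
                    return True
                buckets[j] -= order[index]
        return False

    return dfs(0), calls


for pruned in (False, True):
    print(pruned, search_with_call_count([2, 2, 2, 2, 3, 4, 5], 4, pruned))
```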

Step-by-Step Explanation

Let's trace through with nums = [4, 3, 2, 3, 5, 2, 1], k = 4, target = 5.

After sorting descending: nums = [5, 4, 3, 3, 2, 2, 1]. Buckets = [0, 0, 0, 0].

Step 1: Place nums[0]=5 in bucket 0. buckets = [5, 0, 0, 0]. Bucket 0 = target! Complete.

  • Before this placement all four buckets held 0, so duplicate-skip means bucket 0 is the only bucket ever tried for the 5; buckets 1, 2, and 3 would lead to equivalent subproblems.

Step 2: Try nums[1]=4 in bucket 0: 5+4=9 > 5 ✗. Pruned!

Step 3: Place nums[1]=4 in bucket 1. buckets = [5, 4, 0, 0]. Bucket 1 needs 1 more.

  • Duplicate-skip: had this branch failed, buckets 2 and 3 would have been skipped for the 4, because each held 0, the same sum as the bucket just before it.

Step 4: Try nums[2]=3 in bucket 0: 5+3>5 ✗. Try bucket 1: 4+3=7>5 ✗. Both pruned!

Step 5: Place nums[2]=3 in bucket 2. buckets = [5, 4, 3, 0]. Bucket 2 needs 2 more.

Step 6: Try nums[3]=3 in buckets 0,1: both overflow. Try bucket 2: 3+3=6>5 ✗.

Step 7: Place nums[3]=3 in bucket 3. buckets = [5, 4, 3, 3]. Bucket 3 needs 2 more.

Step 8: Try nums[4]=2 in bucket 0: 5+2=7>5 ✗. Bucket 1: 4+2=6>5 ✗. Bucket 2: 3+2=5 ≤ 5 ✓.
buckets = [5, 4, 5, 3]. Bucket 2 complete!

Step 9: Try nums[5]=2 in buckets 0,2: overflow. Bucket 1: 4+2=6>5 ✗. Bucket 3: 3+2=5 ≤ 5 ✓.
buckets = [5, 4, 5, 5]. Bucket 3 complete!

Step 10: Place nums[6]=1 in bucket 1. 4+1=5 = target. buckets = [5, 5, 5, 5]. All complete!

Step 11: All elements placed. Return true!
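The same path-recording idea, applied to the optimized search, reproduces this trace. The sketch below (pruned_trace is a hypothetical name) sorts a copy of nums in descending order. It then applies the duplicate-skip and overflow checks in the order the Algorithm lists them and returns the (value, bucket) placements on the successful path.

```python
def pruned_trace(nums, k):
    """Sorted backtracking with duplicate-skip; returns (answer, placements)."""
    total = sum(nums)
    if total % k != 0:
        return False, []
    target = total // k
    order = sorted(nums, reverse=True)  # largest elements first
    if order[0] > target:
        return False, []
    buckets = [0] * k
    path = []  # (value, bucket index) on the current search path

    def dfs(index):
        if index == len(order):
            return True
        for j in range(k):
            if j > 0 and buckets[j] == buckets[j - 1]:
                continue  # equivalent to a bucket already considered at this level
            if buckets[j] + order[index] <= target:
                buckets[j] += order[index]
                path.append((order[index], j))
                if dfs(index + 1):
                    return True
                path.pop()
                buckets[j] -= order[index]
        return False

    return dfs(0), path


print(pruned_trace([4, 3, 2, 3, 5, 2, 1], 4))
# (True, [(5, 0), (4, 1), (3, 2), (3, 3), (2, 2), (2, 3), (1, 1)])
```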

Algorithm

  1. Compute total sum. If total % k ≠ 0, return false.
  2. Set target = total / k.
  3. Sort nums in descending order. If nums[0] > target, return false.
  4. Initialize buckets = [0] * k.
  5. Define recursive function dfs(index):
    • Base case: if index == n, return true.
    • For each bucket j from 0 to k-1:
      • Duplicate skip: if j > 0 and buckets[j] == buckets[j-1], skip.
      • Prune: if buckets[j] + nums[index] > target, skip.
      • Place: buckets[j] += nums[index]. Recurse. If true, return true. Backtrack.
    • Return false.
  6. Return dfs(0).

Code

C++

#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    bool canPartitionKSubsets(vector<int>& nums, int k) {
        int total = 0;
        for (int n : nums) total += n;
        if (total % k != 0) return false;
        int target = total / k;
        sort(nums.rbegin(), nums.rend());
        if (nums[0] > target) return false;
        vector<int> buckets(k, 0);
        return dfs(nums, buckets, 0, target, k);
    }

    bool dfs(vector<int>& nums, vector<int>& buckets, int index, int target, int k) {
        if (index == (int)nums.size()) return true;
        for (int j = 0; j < k; j++) {
            // Same sum as bucket j - 1: placing here gives an equivalent subproblem, so skip.
            if (j > 0 && buckets[j] == buckets[j - 1]) continue;
            if (buckets[j] + nums[index] <= target) {
                buckets[j] += nums[index];
                if (dfs(nums, buckets, index + 1, target, k)) return true;
                buckets[j] -= nums[index];
            }
        }
        return false;
    }
};

Python

class Solution:
    def canPartitionKSubsets(self, nums: list[int], k: int) -> bool:
        total = sum(nums)
        if total % k != 0:
            return False
        target = total // k
        nums.sort(reverse=True)
        if nums[0] > target:
            return False
        buckets = [0] * k

        def dfs(index: int) -> bool:
            if index == len(nums):
                return True
            for j in range(k):
                # Same sum as bucket j - 1: placing here gives an equivalent subproblem, so skip.
                if j > 0 and buckets[j] == buckets[j - 1]:
                    continue
                if buckets[j] + nums[index] <= target:
                    buckets[j] += nums[index]
                    if dfs(index + 1):
                        return True
                    buckets[j] -= nums[index]
            return False

        return dfs(0)

Java

import java.util.Arrays;

class Solution {
    public boolean canPartitionKSubsets(int[] nums, int k) {
        int total = 0;
        for (int n : nums) total += n;
        if (total % k != 0) return false;
        int target = total / k;
        Arrays.sort(nums);
        // Reverse to descending
        for (int i = 0, j = nums.length - 1; i < j; i++, j--) {
            int temp = nums[i];
            nums[i] = nums[j];
            nums[j] = temp;
        }
        if (nums[0] > target) return false;
        int[] buckets = new int[k];
        return dfs(nums, buckets, 0, target, k);
    }

    private boolean dfs(int[] nums, int[] buckets, int index, int target, int k) {
        if (index == nums.length) return true;
        for (int j = 0; j < k; j++) {
            // Same sum as bucket j - 1: placing here gives an equivalent subproblem, so skip.
            if (j > 0 && buckets[j] == buckets[j - 1]) continue;
            if (buckets[j] + nums[index] <= target) {
                buckets[j] += nums[index];
                if (dfs(nums, buckets, index + 1, target, k)) return true;
                buckets[j] -= nums[index];
            }
        }
        return false;
    }
}

Complexity Analysis

Time Complexity: O(k^n) worst case

The theoretical worst case is still O(k^n), because each element can potentially go in any of k buckets. However, the three optimizations reduce practical running time enormously:

  • Descending sort ensures large elements constrain the search early.
  • Duplicate-skip avoids exploring equivalent bucket arrangements.
  • Early pruning cuts off impossible branches before recursing.

For n ≤ 16 and k ≤ 16, this optimized backtracking passes comfortably on practical inputs.

Space Complexity: O(n)

The recursion stack has depth n, and the buckets array uses O(k) ≤ O(n) space. Sorting is done in-place.

Why This Approach Is Not Efficient

Although the pruned backtracking works excellently in practice, its worst-case time complexity remains O(k^n). Adversarial inputs — particularly those with many identical elements and a large k — can still force exponential exploration.

The underlying limitation is that backtracking makes element-by-element decisions. When many elements have similar values and many buckets have similar sums, the pruning rules become less effective, and the algorithm may explore a near-exhaustive search.

We can achieve a better worst-case guarantee by changing the problem representation. Instead of asking "which bucket does each element go to?", we ask "which subset of elements forms each complete bucket?" This leads to a bitmask dynamic programming approach with guaranteed O(n × 2^n) time — much better than O(k^n) for large k.

Optimal Approach - Bitmask Dynamic Programming

Intuition

Instead of deciding "which bucket does element i go to?", we flip the perspective and track which elements have been used so far using a bitmask.

A bitmask is a binary number where bit i is 1 if element i has been used and 0 otherwise, with bit 0 as the rightmost (least-significant) bit. For n = 7 elements, the bitmask 0110101 means elements 0, 2, 4, and 5 are used.
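Decoding and extending a mask takes one shift and one bitwise operation per element. Here is a small sketch with hypothetical helper names. Bit 0 is the least-significant (rightmost) bit, matching the mask >> i & 1 and mask | (1 << i) expressions used in the code for this approach.

```python
def used_elements(mask, n):
    """Indices whose bit is set; bit 0 is the rightmost bit."""
    return [i for i in range(n) if (mask >> i) & 1]


def add_element(mask, i):
    """Return the mask with element i marked as used."""
    return mask | (1 << i)


print(used_elements(0b0110101, 7))     # [0, 2, 4, 5]
print(bin(add_element(0b0110101, 1)))  # 0b110111
```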

For each bitmask (representing a subset of used elements), we store dp[mask] — the accumulated sum within the current bucket being filled. When this sum reaches the target exactly, it wraps back to 0, signaling that one bucket is complete and we start filling the next.

The key insight: after processing all 2^n masks, dp[all_bits_set] is either -1 (no valid order of placements reaches it) or exactly 0. Reaching it means every element was placed without any bucket overflowing. Since the total sum equals k × target and each wrap-around marks one completed bucket, ending at 0 means exactly k buckets were filled to the target.
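Here is why the final check works. For every reachable mask, dp[mask] is the sum of the chosen elements modulo target: each transition adds nums[i] and reduces modulo target only when a bucket fills exactly. So the full mask, whose sum is k × target, can only hold 0, and the real question is whether it is reachable at all. The sketch below builds the table as in the Algorithm and asserts that invariant. The function name check_dp_invariant is hypothetical, and it assumes the divisibility check has already passed.

```python
def check_dp_invariant(nums, k):
    """Build the bitmask DP table and confirm dp[mask] == sum(mask) % target."""
    total = sum(nums)
    assert total % k == 0  # sketch assumes the divisibility check already passed
    target = total // k
    n = len(nums)
    dp = [-1] * (1 << n)
    dp[0] = 0
    for mask in range(1 << n):
        if dp[mask] == -1:
            continue
        for i in range(n):
            if (mask >> i) & 1:
                continue
            if dp[mask] + nums[i] <= target:
                dp[mask | (1 << i)] = (dp[mask] + nums[i]) % target
    # Invariant: every reachable state stores (sum of its elements) % target.
    for mask, value in enumerate(dp):
        if value != -1:
            chosen = sum(nums[i] for i in range(n) if (mask >> i) & 1)
            assert value == chosen % target
    # The full mask is reachable exactly when a valid partition exists.
    return dp[(1 << n) - 1] != -1


print(check_dp_invariant([4, 3, 2, 3, 5, 2, 1], 4))  # True
print(check_dp_invariant([2, 2, 2, 2, 3, 4, 5], 4))  # False
```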

This approach visits each of the 2^n masks once and tries at most n elements from each, giving O(n × 2^n) total work. For n = 16, that is at most 16 × 65,536 = 1,048,576 transitions, fast regardless of how the elements are distributed.

Step-by-Step Explanation

Let's trace through with nums = [4, 3, 2, 3, 5, 2, 1], k = 4, target = 5.

We create a dp array of size 2^7 = 128. Initialize dp[0000000] = 0 (no elements used, current bucket sum = 0). All other entries are -1 (unreachable).

Step 1: Process mask 0000000 (dp = 0). Every element can be added; three of the seven transitions are shown:

  • Element 0 (val 4): 0+4 = 4 ≤ 5. dp[0000001] = 4.
  • Element 4 (val 5): 0+5 = 5 = target. dp[0010000] = (0+5)%5 = 0. One bucket complete!
  • Element 6 (val 1): 0+1 = 1 ≤ 5. dp[1000000] = 1.

Step 2: Process mask 0000001 (dp = 4, element 0 used). Need 1 more for current bucket.

  • Element 6 (val 1): 4+1 = 5 = target. dp[1000001] = 0. Bucket complete!
  • Element 1 (val 3): 4+3 = 7 > 5. Skip!

Step 3: Process mask 0010000 (dp = 0, element 4 used). One bucket done.

  • Element 0 (val 4): 0+4 = 4 ≤ 5. dp[0010001] = 4.
  • Element 1 (val 3): 0+3 = 3 ≤ 5. dp[0010010] = 3.

Step 4: Continue processing reachable masks... Eventually:

  • dp[0010001] = 4, add element 6 (val 1): dp[1010001] = 0. Two buckets done.
  • From dp = 0 states, continue adding elements to fill more buckets.

Step 5: Through many transitions, dp[1111111] is reached with value 0.

  • All 7 elements used, accumulated sum = 0 (wraps happened 4 times = 4 complete buckets).

Step 6: dp[1111111] == 0 → Return true!
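The specific states named in this trace can be checked by building the full table. The following is a minimal sketch (build_dp is a hypothetical name) using the same transitions as the Algorithm. Masks are written with Python's 0b prefix and bit 0 on the right, as in the trace.

```python
def build_dp(nums, k):
    """Return the bitmask DP table, or None if total % k != 0."""
    total = sum(nums)
    if total % k != 0:
        return None
    target = total // k
    n = len(nums)
    dp = [-1] * (1 << n)  # -1 marks an unreachable state
    dp[0] = 0
    for mask in range(1 << n):
        if dp[mask] == -1:
            continue
        for i in range(n):
            if (mask >> i) & 1:
                continue
            if dp[mask] + nums[i] <= target:
                dp[mask | (1 << i)] = (dp[mask] + nums[i]) % target
    return dp


dp = build_dp([4, 3, 2, 3, 5, 2, 1], 4)
for mask in (0b0000001, 0b0010000, 0b1000001, 0b0010001, 0b1010001, 0b0000011, 0b1111111):
    print(format(mask, "07b"), dp[mask])
```

The state 0000011 (elements 0 and 1, values 4 and 3) prints -1. Whichever of the two is added second pushes the sum to 7 > 5, so that state is never reached.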

Algorithm

  1. Compute total sum. If total % k ≠ 0, return false.
  2. Set target = total / k. If any element exceeds target, return false.
  3. Create a dp array of size 2^n, initialized to -1 (unreachable). Set dp[0] = 0.
  4. Iterate through all masks from 0 to 2^n - 1:
    • If dp[mask] == -1, skip (unreachable state).
    • For each element i not in the mask (bit i is 0):
      • If dp[mask] + nums[i] ≤ target:
        • Set dp[mask | (1 << i)] = (dp[mask] + nums[i]) % target
  5. Return whether dp[(1 << n) - 1] == 0.

Code

C++

#include <vector>
using namespace std;

class Solution {
public:
    bool canPartitionKSubsets(vector<int>& nums, int k) {
        int total = 0;
        for (int n : nums) total += n;
        if (total % k != 0) return false;
        int target = total / k;
        int n = nums.size();
        for (int x : nums) {
            if (x > target) return false;
        }

        vector<int> dp(1 << n, -1);
        dp[0] = 0;

        for (int mask = 0; mask < (1 << n); mask++) {
            if (dp[mask] == -1) continue;
            for (int i = 0; i < n; i++) {
                if (mask & (1 << i)) continue;
                if (dp[mask] + nums[i] <= target) {
                    // Hitting target exactly wraps to 0: the current bucket is complete.
                    dp[mask | (1 << i)] = (dp[mask] + nums[i]) % target;
                }
            }
        }

        return dp[(1 << n) - 1] == 0;
    }
};

Python

class Solution:
    def canPartitionKSubsets(self, nums: list[int], k: int) -> bool:
        total = sum(nums)
        if total % k != 0:
            return False
        target = total // k
        if max(nums) > target:
            return False

        n = len(nums)
        dp = [-1] * (1 << n)
        dp[0] = 0

        for mask in range(1 << n):
            if dp[mask] == -1:
                continue
            for i in range(n):
                if mask >> i & 1:
                    continue
                if dp[mask] + nums[i] <= target:
                    # Hitting target exactly wraps to 0: the current bucket is complete.
                    dp[mask | (1 << i)] = (dp[mask] + nums[i]) % target

        return dp[(1 << n) - 1] == 0

Java

import java.util.Arrays;

class Solution {
    public boolean canPartitionKSubsets(int[] nums, int k) {
        int total = 0;
        for (int n : nums) total += n;
        if (total % k != 0) return false;
        int target = total / k;
        int n = nums.length;
        for (int x : nums) {
            if (x > target) return false;
        }

        int[] dp = new int[1 << n];
        Arrays.fill(dp, -1);
        dp[0] = 0;

        for (int mask = 0; mask < (1 << n); mask++) {
            if (dp[mask] == -1) continue;
            for (int i = 0; i < n; i++) {
                if ((mask & (1 << i)) != 0) continue;
                if (dp[mask] + nums[i] <= target) {
                    // Hitting target exactly wraps to 0: the current bucket is complete.
                    dp[mask | (1 << i)] = (dp[mask] + nums[i]) % target;
                }
            }
        }

        return dp[(1 << n) - 1] == 0;
    }
}

Complexity Analysis

Time Complexity: O(n × 2^n)

We iterate over all 2^n possible masks. For each reachable mask, we try adding each of the n elements. This gives at most n × 2^n total transitions. For n = 16, that is 16 × 65,536 = 1,048,576, which is very fast.

Compared to the other approaches:

  • Brute force: O(k^n) — up to 16^16 ≈ 1.8 × 10^19 for worst case
  • Pruned backtracking: O(k^n) worst case, fast in practice
  • Bitmask DP: O(n × 2^n) = O(16 × 2^16) ≈ 10^6 — guaranteed fast

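The bitmask DP gives a better worst-case bound whenever k ≥ 3. The ratio k^n / (n × 2^n) = (k/2)^n / n is at least 1.5^n / n, which exceeds 1 for every n ≥ 1 and grows exponentially. For n = 16 and k = 3, the bounds are 3^16 ≈ 4.3 × 10^7 versus 16 × 2^16 ≈ 1.05 × 10^6. For k = 2, the brute-force bound 2^n is actually smaller than n × 2^n, so the advantage applies from k = 3 upward.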
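The bounds behind this comparison are plain arithmetic. This quick check for n = 16 compares upper bounds on operation counts, not measured running times; the helper name bounds is hypothetical.

```python
def bounds(n, k):
    """Return (brute-force bound k**n, bitmask DP bound n * 2**n)."""
    return k ** n, n * 2 ** n


for k in (2, 3, 4, 16):
    brute, dp = bounds(16, k)
    print(f"k={k:2d}  k^n={brute:.2e}  n*2^n={dp:.2e}  dp_smaller={dp < brute}")
```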

Space Complexity: O(2^n)

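The dp array has 2^n entries. For n = 16, that is 65,536 integers: 262,144 bytes (256 KiB) with the 4-byte int used in the C++ and Java versions, well within limits. A Python list needs more, since each slot holds an 8-byte object reference.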