
Count Partitions with Target Difference

Description

You are given an integer array arr[] and an integer diff. Your task is to count the number of ways you can partition the array into two subsets, S1 and S2 (either may be empty), such that sum(S1) is at least sum(S2) and their difference sum(S1) - sum(S2) equals diff.

A partition splits the array into two groups where every element belongs to exactly one group. The union of S1 and S2 must equal the original array, and no element can appear in both subsets.

Return the total number of such valid partitions.

Examples

Example 1

Input: arr = [5, 2, 6, 4], diff = 3

Output: 1

Explanation: The only valid partition is S1 = {6, 4} with sum 10, and S2 = {5, 2} with sum 7. The difference is 10 - 7 = 3, which matches the required difference. No other partition of this array produces a difference of 3.

Example 2

Input: arr = [1, 1, 1, 1], diff = 0

Output: 6

Explanation: Since diff = 0, we need both subsets to have equal sums. The total sum is 4, so each subset must sum to 2. We need to choose exactly two 1's for S1 (the rest go to S2). The number of ways to pick 2 items from 4 identical values at distinct indices is C(4,2) = 6. For instance: {arr[0], arr[1]} and {arr[2], arr[3]}, or {arr[0], arr[2]} and {arr[1], arr[3]}, and so on.

Example 3

Input: arr = [3, 2, 7, 1], diff = 4

Output: 0

Explanation: The total sum is 13. For a valid partition with difference 4, we would need S1 = (13 + 4) / 2 = 8.5. Since 8.5 is not an integer, no valid partition exists. Whenever (totalSum + diff) is odd, the answer is always 0.
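To make the definition concrete, here is a small exhaustive checker in Python. It is a sketch, not one of the editorial's solutions: the function name `count_partitions_by_definition` is hypothetical, and it simply tries all 2^n ways to send each index to S1 or S2, counting those where sum(S1) - sum(S2) equals diff. That limits it to small arrays, but it is handy for confirming the three examples.

```python
from itertools import product


def count_partitions_by_definition(arr, diff):
    """Count splits (S1, S2) with sum(S1) - sum(S2) == diff by brute force.

    Every index goes to S1 (True) or S2 (False), so all 2**len(arr)
    assignments are tried; only practical for small arrays.
    """
    total = sum(arr)
    count = 0
    for choice in product((True, False), repeat=len(arr)):
        s1 = sum(x for x, in_s1 in zip(arr, choice) if in_s1)
        if s1 - (total - s1) == diff:   # diff >= 0 forces sum(S1) >= sum(S2)
            count += 1
    return count


print(count_partitions_by_definition([5, 2, 6, 4], 3))  # 1
print(count_partitions_by_definition([1, 1, 1, 1], 0))  # 6
print(count_partitions_by_definition([3, 2, 7, 1], 4))  # 0
```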

Constraints

  • 1 ≤ arr.size() ≤ 50
  • 0 ≤ diff ≤ 50
  • 0 ≤ arr[i] ≤ 6

Editorial

Brute Force

Intuition

The most natural way to solve this problem is to realize that every element has exactly two choices: it either goes into subset S1, or it goes into subset S2. If we try every possible assignment, we get all possible partitions of the array.

Before we start generating partitions, there is a key mathematical reduction that simplifies the problem enormously. Let totalSum be the sum of all elements. If S1 and S2 are two subsets with S1 + S2 = totalSum and S1 - S2 = diff, then adding these two equations gives 2 * S1 = totalSum + diff, so S1 = (totalSum + diff) / 2. This means we do not need to track both subsets — we just need to count how many subsets have sum equal to target = (totalSum + diff) / 2.

Two quick validity checks before we start:

  • If (totalSum + diff) is odd, no integer subset sum can equal target, so the answer is 0.
  • If diff > totalSum, the answer is also 0 because even putting all elements into one subset cannot create a large enough gap.
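The reduction and both validity checks fit in a few lines. This Python sketch uses a hypothetical helper, `reduced_target`, that returns the subset sum S1 must reach, or `None` when no partition can exist.

```python
def reduced_target(arr, diff):
    """Return the sum S1 must reach, or None if no valid partition exists."""
    total = sum(arr)
    # S1 + S2 = total and S1 - S2 = diff  =>  2 * S1 = total + diff
    if (total + diff) % 2 != 0:   # S1 would not be an integer
        return None
    if diff > total:              # the gap cannot exceed the whole sum
        return None
    return (total + diff) // 2


print(reduced_target([5, 2, 6, 4], 3))  # 10
print(reduced_target([1, 1, 1, 1], 0))  # 2
print(reduced_target([3, 2, 7, 1], 4))  # None (17 is odd)
print(reduced_target([1, 1], 4))        # None (diff > totalSum)
```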

With the target computed, the brute force approach uses recursion. For each element, we branch into two choices: include it in our current subset (reducing the remaining target) or exclude it. When we reach the end of the array, if the remaining target is exactly 0, we found one valid subset.

Step-by-Step Explanation

Let's trace through with arr = [1, 1, 1, 1], diff = 0.

totalSum = 4. target = (4 + 0) / 2 = 2. We need to count subsets that sum to 2.

We call solve(index=0, remaining=2):

Step 1: At index 0 (value 1), remaining = 2. We branch:

  • Include arr[0]: solve(1, 2-1) = solve(1, 1)
  • Exclude arr[0]: solve(1, 2)

Step 2: In the include branch, at index 1 (value 1), remaining = 1. We branch:

  • Include arr[1]: solve(2, 0)
  • Exclude arr[1]: solve(2, 1)

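Step 3: solve(2, 0): remaining is 0, so {arr[0], arr[1]} already reaches the target. Both remaining elements are 1, so the only way to keep the sum at 2 is to exclude them. This is one valid subset. Return 1.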

Step 4: solve(2, 1) — at index 2 (value 1), remaining = 1. Branch:

  • Include arr[2]: solve(3, 0) → remaining = 0, valid! Return 1. This is subset {arr[0], arr[2]}.
  • Exclude arr[2]: solve(3, 1) → at index 3 (value 1), remaining = 1.
    • Include arr[3]: solve(4, 0) → base case, valid! Return 1. Subset {arr[0], arr[3]}.
    • Exclude arr[3]: solve(4, 1) → index out of bounds, remaining ≠ 0. Return 0.

Step 5: Back to the exclude branch of Step 1: solve(1, 2). At index 1 (value 1), remaining = 2:

  • Include arr[1]: solve(2, 1) — this produces subsets {arr[1], arr[2]} and {arr[1], arr[3]}, returning 2.
  • Exclude arr[1]: solve(2, 2) — at index 2 (value 1), remaining = 2:
    • Include arr[2]: solve(3, 1) — include arr[3] → solve(4, 0) = 1. Subset {arr[2], arr[3]}. Exclude arr[3] → 0. Returns 1.
    • Exclude arr[2]: solve(3, 2) — remaining is 2 but only arr[3]=1 left. Cannot reach 0. Returns 0.

Step 6: Total = 3 (from include arr[0]) + 3 (from exclude arr[0]) = 6. The answer is 6.

The recursion tree grows exponentially: with n elements it can have up to 2^n leaves (2^4 = 16 here), although branches that overshoot the target are cut off early.
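To see the six subsets the trace finds, the recursion can record which indices it included whenever it reaches the end of the array with remaining = 0. This is an illustrative Python sketch; `list_subsets` and its `chosen` list are hypothetical additions. It checks the sum only at the end of the array, which is also what lets a zero in arr be counted both in and out of the subset.

```python
def list_subsets(arr, target):
    """Return (count, index tuples) for subsets of arr summing to target,
    in the order an include-first recursion discovers them."""
    found = []

    def solve(index, remaining, chosen):
        if remaining < 0:              # overshot: abandon this branch
            return 0
        if index == len(arr):          # every element has been decided
            if remaining == 0:
                found.append(tuple(chosen))
                return 1
            return 0
        chosen.append(index)           # include arr[index]
        include = solve(index + 1, remaining - arr[index], chosen)
        chosen.pop()                   # undo the choice, then exclude it
        exclude = solve(index + 1, remaining, chosen)
        return include + exclude

    return solve(0, target, []), found


count, found = list_subsets([1, 1, 1, 1], 2)
print(count)  # 6
print(found)  # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
```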


Algorithm

  1. Compute totalSum = sum of all elements in arr
  2. Check validity: if (totalSum + diff) is odd or diff > totalSum, return 0
  3. Compute target = (totalSum + diff) / 2
  4. Define a recursive function solve(index, remaining):
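    • Base case: if remaining < 0, return 0 (this branch overshot the target)
    • Base case: if index == n, return 1 if remaining == 0, else 0 (checking only at the end matters because arr[i] may be 0, and each zero can still go into either subset)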
    • Recursive case: return solve(index+1, remaining - arr[index]) + solve(index+1, remaining)
  5. Return solve(0, target)

Code

#include <vector>
using namespace std;

class Solution {
public:
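    int solve(vector<int>& arr, int index, int remaining) {
        // Overshot the target: no valid subset along this path
        if (remaining < 0) return 0;
        // All elements decided: count this subset only if it hits the target.
        // Checking here rather than as soon as remaining == 0 matters because
        // arr[i] can be 0, and every zero may still go into either subset.
        if (index == (int)arr.size()) return remaining == 0 ? 1 : 0;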
        
        // Include current element or exclude it
        int include = solve(arr, index + 1, remaining - arr[index]);
        int exclude = solve(arr, index + 1, remaining);
        
        return include + exclude;
    }
    
    int countPartitions(vector<int>& arr, int diff) {
        int totalSum = 0;
        for (int x : arr) totalSum += x;
        
        if ((totalSum + diff) % 2 != 0) return 0;
        if (diff > totalSum) return 0;
        
        int target = (totalSum + diff) / 2;
        return solve(arr, 0, target);
    }
};

class Solution:
    def countPartitions(self, arr, diff):
        total_sum = sum(arr)
        
        if (total_sum + diff) % 2 != 0:
            return 0
        if diff > total_sum:
            return 0
        
        target = (total_sum + diff) // 2
        
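        def solve(index, remaining):
            # Overshot the target: no valid subset along this path
            if remaining < 0:
                return 0
            # All elements decided; check the sum only here because zeros
            # (allowed by the constraints) can still go into either subset
            if index == len(arr):
                return 1 if remaining == 0 else 0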
            
            include = solve(index + 1, remaining - arr[index])
            exclude = solve(index + 1, remaining)
            
            return include + exclude
        
        return solve(0, target)

class Solution {
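    private int solve(int[] arr, int index, int remaining) {
        // Overshot the target: no valid subset along this path
        if (remaining < 0) return 0;
        // Decide every element before checking the sum, since zeros can go either way
        if (index == arr.length) return remaining == 0 ? 1 : 0;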
        
        int include = solve(arr, index + 1, remaining - arr[index]);
        int exclude = solve(arr, index + 1, remaining);
        
        return include + exclude;
    }
    
    public int countPartitions(int[] arr, int diff) {
        int totalSum = 0;
        for (int x : arr) totalSum += x;
        
        if ((totalSum + diff) % 2 != 0) return 0;
        if (diff > totalSum) return 0;
        
        int target = (totalSum + diff) / 2;
        return solve(arr, 0, target);
    }
}

Complexity Analysis

Time Complexity: O(2^n)

At each index, the recursion branches into two calls (include or exclude). With n elements, this creates a binary tree of depth n with up to 2^n leaves. Branches are abandoned only when they overshoot the target, which does not improve the worst case, and no results are cached.
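One way to see the worst case is to count calls. In this Python sketch (the helper `count_calls` is hypothetical, and it checks the sum only at the end of the array), an array of n zeros with target 0 never overshoots, so no branch is pruned and the recursion makes exactly 2^(n+1) - 1 calls: one per node of a full binary tree of depth n.

```python
def count_calls(arr, target):
    """Run the plain include/exclude recursion; return (ways, number of calls)."""
    calls = 0

    def solve(index, remaining):
        nonlocal calls
        calls += 1
        if remaining < 0:
            return 0
        if index == len(arr):
            return 1 if remaining == 0 else 0
        return solve(index + 1, remaining - arr[index]) + solve(index + 1, remaining)

    ways = solve(0, target)
    return ways, calls


for n in (5, 10, 15):
    ways, calls = count_calls([0] * n, 0)
    print(n, ways, calls)  # ways == 2**n, calls == 2**(n + 1) - 1
```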

Space Complexity: O(n)

The recursion stack goes at most n levels deep (one level per array element). No additional data structures are used besides the call stack.

Why This Approach Is Not Efficient

The brute force recursion has O(2^n) time complexity. With n up to 50, that means up to 2^50 ≈ 10^15 operations — far beyond any reasonable time limit.

The root cause of the inefficiency is overlapping subproblems. Many different paths through the recursion tree arrive at the same (index, remaining) pair. For example, in our trace of [1, 1, 1, 1], the subproblem solve(2, 1) was computed twice — once when we included arr[0] and excluded arr[1], and again when we excluded arr[0] and included arr[1]. Each duplicate recomputation triggers its own subtree of redundant work.
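The overlap can be measured directly. This Python sketch (the name `count_state_visits` is hypothetical) tallies how often the plain recursion enters each (index, remaining) state for arr = [1, 1, 1, 1] and target 2; calls cut off because remaining went negative are not counted.

```python
from collections import Counter


def count_state_visits(arr, target):
    """Count how often the plain recursion enters each (index, remaining) state."""
    visits = Counter()

    def solve(index, remaining):
        if remaining < 0:
            return 0
        visits[(index, remaining)] += 1
        if index == len(arr):
            return 1 if remaining == 0 else 0
        return solve(index + 1, remaining - arr[index]) + solve(index + 1, remaining)

    solve(0, target)
    return visits


visits = count_state_visits([1, 1, 1, 1], 2)
print(visits[(2, 1)])  # 2
print(visits[(3, 1)])  # 3
print(visits[(4, 0)])  # 6, one arrival per valid subset
```

Every repeated entry in this tally is work that a cache would skip.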

If we could store the result of each unique (index, remaining) pair the first time we compute it, we would avoid all this repeated work. This is exactly what memoization provides.

Better Approach - Memoization (Top-Down DP)

Intuition

The memoization approach uses the same recursive structure as the brute force but adds a cache. Before computing solve(index, remaining), we check: have we already solved this exact subproblem? If yes, return the stored result immediately. If no, compute it recursively and store the result before returning.

The state of each subproblem is uniquely identified by two parameters: the current index (which element we are deciding about) and the remaining sum we still need to reach. So our cache is a 2D table dp[index][remaining].

Think of it like a student solving a math worksheet. Without memoization, every time they encounter the same sub-calculation, they redo it from scratch. With memoization, they write each result on a sticky note. The next time the same calculation comes up, they just read the note instead of recalculating.
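In Python, the sticky notes can also come from the standard library: `functools.lru_cache` caches a function's results keyed on its arguments, which here are exactly (index, remaining). The sketch below is an alternative to the explicit dp table, not the editorial's method; `count_partitions_memo` is a hypothetical name.

```python
from functools import lru_cache


def count_partitions_memo(arr, diff):
    """Top-down count of subsets summing to (total + diff) / 2, cached per state."""
    total = sum(arr)
    if (total + diff) % 2 != 0 or diff > total:
        return 0
    target = (total + diff) // 2
    n = len(arr)

    @lru_cache(maxsize=None)  # cache keyed on the (index, remaining) arguments
    def solve(index, remaining):
        if remaining < 0:
            return 0
        if index == n:
            return 1 if remaining == 0 else 0
        return solve(index + 1, remaining - arr[index]) + solve(index + 1, remaining)

    return solve(0, target)


print(count_partitions_memo([1, 1, 1, 1], 0))  # 6
print(count_partitions_memo([0, 0, 0], 0))     # 8: each zero goes either way
```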

Step-by-Step Explanation

Let's trace with arr = [1, 1, 1, 1], diff = 0, target = 2. We use a memo table dp[index][remaining].

Step 1: Call solve(0, 2). Not in memo. Branch: include arr[0] → solve(1, 1), exclude arr[0] → solve(1, 2).

Step 2: solve(1, 1). Not in memo. Branch: include arr[1] → solve(2, 0), exclude arr[1] → solve(2, 1).

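Step 3: solve(2, 0). The target is already reached. Both remaining elements are 1, so including either would overshoot; the only completion is to exclude them both, so this call returns 1.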

Step 4: solve(2, 1). Not in memo. Branch: include arr[2] → solve(3, 0) = 1, exclude arr[2] → solve(3, 1).

Step 5: solve(3, 1). Not in memo. Include arr[3] → solve(4, 0) = 1, exclude arr[3] → solve(4, 1) = 0. Result = 1. Store dp[3][1] = 1.

Step 6: solve(2, 1) = 1 + 1 = 2. Store dp[2][1] = 2.

Step 7: solve(1, 1) = 1 + 2 = 3. Store dp[1][1] = 3.

Step 8: Now process solve(1, 2). Branch: include arr[1] → solve(2, 1). Check memo: dp[2][1] = 2! Return immediately without recomputing.

Step 9: Exclude arr[1] → solve(2, 2). Not in memo. Include arr[2] → solve(3, 1). Check memo: dp[3][1] = 1! Returned instantly.

Step 10: Exclude arr[2] → solve(3, 2). Include arr[3]=1 → solve(4, 1) = 0. Exclude → solve(4, 2) = 0. dp[3][2] = 0.

Step 11: solve(2, 2) = 1 + 0 = 1. Store dp[2][2] = 1.

Step 12: solve(1, 2) = 2 + 1 = 3. Store dp[1][2] = 3.

Step 13: solve(0, 2) = 3 + 3 = 6. The memo prevented recomputation of solve(2,1) and solve(3,1), saving significant work.
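To check the trace, the memo table can log every cache hit. This is a Python sketch with a hypothetical `memo_with_hits` helper. Because it tests remaining == 0 only at the end of the array, solve(3, 0) is an ordinary memoized state here, so it appears as a hit alongside solve(2, 1) and solve(3, 1).

```python
def memo_with_hits(arr, target):
    """Memoized subset count that also logs states served from the cache."""
    n = len(arr)
    dp = [[-1] * (target + 1) for _ in range(n)]
    hits = []

    def solve(index, remaining):
        if remaining < 0:
            return 0
        if index == n:
            return 1 if remaining == 0 else 0
        if dp[index][remaining] != -1:
            hits.append((index, remaining))   # cache hit: no recursion needed
            return dp[index][remaining]
        dp[index][remaining] = (solve(index + 1, remaining - arr[index])
                                + solve(index + 1, remaining))
        return dp[index][remaining]

    return solve(0, target), hits


ways, hits = memo_with_hits([1, 1, 1, 1], 2)
print(ways)  # 6
print(hits)  # [(3, 0), (2, 1), (3, 1)]
```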


Algorithm

  1. Compute totalSum and target = (totalSum + diff) / 2 (with validity checks)
  2. Create a 2D memo table dp[n][target+1] initialized to -1
  3. Define solve(index, remaining):
    • Base case: if remaining < 0, return 0
    • Base case: if index == n, return 1 if remaining == 0, else 0 (so that zeros are counted both in and out of the subset)
    • If dp[index][remaining] ≠ -1, return dp[index][remaining] (cache hit)
    • Compute result = solve(index+1, remaining - arr[index]) + solve(index+1, remaining)
    • Store dp[index][remaining] = result and return it
  4. Return solve(0, target)

Code

#include <vector>
using namespace std;

class Solution {
public:
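    int solve(vector<int>& arr, int index, int remaining, vector<vector<int>>& dp) {
        // Overshot the target (this also keeps the dp indices in range)
        if (remaining < 0) return 0;
        // Check the sum only after every element is decided, so zeros count both ways
        if (index == (int)arr.size()) return remaining == 0 ? 1 : 0;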
        
        if (dp[index][remaining] != -1) return dp[index][remaining];
        
        int include = solve(arr, index + 1, remaining - arr[index], dp);
        int exclude = solve(arr, index + 1, remaining, dp);
        
        return dp[index][remaining] = include + exclude;
    }
    
    int countPartitions(vector<int>& arr, int diff) {
        int totalSum = 0;
        for (int x : arr) totalSum += x;
        
        if ((totalSum + diff) % 2 != 0) return 0;
        if (diff > totalSum) return 0;
        
        int target = (totalSum + diff) / 2;
        int n = arr.size();
        vector<vector<int>> dp(n, vector<int>(target + 1, -1));
        
        return solve(arr, 0, target, dp);
    }
};

class Solution:
    def countPartitions(self, arr, diff):
        total_sum = sum(arr)
        
        if (total_sum + diff) % 2 != 0:
            return 0
        if diff > total_sum:
            return 0
        
        target = (total_sum + diff) // 2
        n = len(arr)
        dp = [[-1] * (target + 1) for _ in range(n)]
        
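        def solve(index, remaining):
            # Overshot the target (this also keeps the dp indices in range)
            if remaining < 0:
                return 0
            # Check the sum only after every element is decided, so zeros count both ways
            if index == n:
                return 1 if remaining == 0 else 0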
            
            if dp[index][remaining] != -1:
                return dp[index][remaining]
            
            include = solve(index + 1, remaining - arr[index])
            exclude = solve(index + 1, remaining)
            
            dp[index][remaining] = include + exclude
            return dp[index][remaining]
        
        return solve(0, target)

import java.util.Arrays;

class Solution {
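    private int solve(int[] arr, int index, int remaining, int[][] dp) {
        // Overshot the target (this also keeps the dp indices in range)
        if (remaining < 0) return 0;
        // Check the sum only after every element is decided, so zeros count both ways
        if (index == arr.length) return remaining == 0 ? 1 : 0;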
        
        if (dp[index][remaining] != -1) return dp[index][remaining];
        
        int include = solve(arr, index + 1, remaining - arr[index], dp);
        int exclude = solve(arr, index + 1, remaining, dp);
        
        return dp[index][remaining] = include + exclude;
    }
    
    public int countPartitions(int[] arr, int diff) {
        int totalSum = 0;
        for (int x : arr) totalSum += x;
        
        if ((totalSum + diff) % 2 != 0) return 0;
        if (diff > totalSum) return 0;
        
        int target = (totalSum + diff) / 2;
        int n = arr.length;
        int[][] dp = new int[n][target + 1];
        for (int[] row : dp) Arrays.fill(row, -1);
        
        return solve(arr, 0, target, dp);
    }
}

Complexity Analysis

Time Complexity: O(n × target)

There are n × (target + 1) unique subproblems, where n is the array length and target = (totalSum + diff) / 2. Each subproblem is computed at most once and cached. Each computation does O(1) work (two recursive calls plus addition). So total work is O(n × target).

Space Complexity: O(n × target)

The memo table has n rows and (target + 1) columns. Additionally, the recursion stack can go up to depth n. The dominant term is the memo table: O(n × target).

Why This Approach Is Not Efficient

While memoization reduces time from O(2^n) to O(n × target), it still uses O(n × target) space for the 2D memo table plus O(n) for the recursion stack. Since each element is at most 6 and there are at most 50 elements, totalSum ≤ 300, and with diff ≤ 50 the target is at most (300 + 50) / 2 = 175. A table of 50 × 176 entries is manageable, but not ideal.

The key insight for further optimization: the recursion processes elements one at a time (index 0, 1, 2, ...), and the answer for index i depends only on results for index i+1, never on index i+2 or beyond. This means we can replace the recursion with an iterative bottom-up table, and then shrink that table to a single 1D array by iterating over the sums in the right order (from high to low).

This eliminates the recursion stack overhead and reduces space from O(n × target) to O(target).
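Before shrinking to one dimension, it can help to see the full bottom-up table. In this Python sketch (hypothetical name `count_partitions_2d`), dp[i][j] counts subsets of the first i elements that sum to j. It indexes by prefix rather than by suffix as the recursion does, which gives the same count because the order of the elements does not matter. Note that row i reads only row i - 1.

```python
def count_partitions_2d(arr, diff):
    """Bottom-up table: dp[i][j] = number of subsets of arr[:i] summing to j."""
    total = sum(arr)
    if (total + diff) % 2 != 0 or diff > total:
        return 0
    target = (total + diff) // 2
    n = len(arr)

    dp = [[0] * (target + 1) for _ in range(n + 1)]
    dp[0][0] = 1  # with no elements, only the empty subset (sum 0) exists

    for i in range(1, n + 1):
        x = arr[i - 1]
        for j in range(target + 1):
            dp[i][j] = dp[i - 1][j]            # subsets that skip arr[i-1]
            if j >= x:
                dp[i][j] += dp[i - 1][j - x]   # subsets that take arr[i-1]
    return dp[n][target]


print(count_partitions_2d([5, 2, 6, 4], 3))  # 1
print(count_partitions_2d([1, 1, 1, 1], 0))  # 6
```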

Optimal Approach - Tabulation with Space Optimization (Bottom-Up DP)

Intuition

Instead of top-down recursion, we build the solution bottom-up using a 1D array. The idea is elegant: dp[j] represents the number of subsets (from the elements we have processed so far) that sum to exactly j.

Initially, before processing any element, the only achievable sum is 0 (by taking the empty subset), so dp[0] = 1 and all other entries are 0.

For each element arr[i], we update the dp array in reverse order (from target down to arr[i]). For each sum j, dp[j] gets updated to dp[j] + dp[j - arr[i]]. The first term (dp[j]) counts subsets that already sum to j without using arr[i]. The second term (dp[j - arr[i]]) counts subsets that summed to j - arr[i] before, and now with arr[i] included, they sum to j.

We iterate in reverse to avoid using the same element twice in a single pass. If we went left-to-right, the updated dp[j - arr[i]] value (which already includes arr[i]) could be used again when computing dp[j], effectively double-counting.
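The effect of the loop direction is easy to demonstrate. In this Python sketch, the hypothetical `reverse` flag of `count_subsets` switches between the correct high-to-low loop and the left-to-right version, which lets an element contribute more than once.

```python
def count_subsets(arr, target, reverse=True):
    """1D subset-sum count; reverse=False shows the left-to-right mistake."""
    dp = [0] * (target + 1)
    dp[0] = 1
    for num in arr:
        if reverse:
            sums = range(target, num - 1, -1)  # high to low: num used at most once
        else:
            sums = range(num, target + 1)      # low to high: num can be reused
        for j in sums:
            dp[j] += dp[j - num]
    return dp[target]


print(count_subsets([2], 4))                          # 0
print(count_subsets([2], 4, reverse=False))           # 1
print(count_subsets([1, 1, 1, 1], 2))                 # 6
print(count_subsets([1, 1, 1, 1], 2, reverse=False))  # 10
```

With a single element 2 and target 4, the left-to-right loop reports 1 way because it adds the 2 twice. For [1, 1, 1, 1] it reports 10 instead of 6, which is the count you would get if each element could be reused any number of times.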

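Zeros need no special handling here. When arr[i] = 0, the inner loop runs all the way down to j = 0 and dp[j] += dp[j] doubles every count, since each subset built so far can either take the zero or leave it. After processing all elements, dp[target] holds the answer.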

Step-by-Step Explanation

Let's trace with arr = [1, 1, 1, 1], diff = 0. totalSum = 4, target = (4 + 0) / 2 = 2.

Step 1: Initialize dp = [1, 0, 0]. dp[0] = 1 means there is one way to make sum 0 (empty subset). dp[1] = dp[2] = 0.

Step 2: Process arr[0] = 1. Iterate j from 2 down to 1:

  • j = 2: dp[2] += dp[2 - 1] = dp[2] + dp[1] = 0 + 0 = 0
  • j = 1: dp[1] += dp[1 - 1] = dp[1] + dp[0] = 0 + 1 = 1
  • dp = [1, 1, 0]. Now there's 1 way to make sum 1: {arr[0]}.

Step 3: Process arr[1] = 1. Iterate j from 2 down to 1:

  • j = 2: dp[2] += dp[1] = 0 + 1 = 1
  • j = 1: dp[1] += dp[0] = 1 + 1 = 2
  • dp = [1, 2, 1]. Sum 2 can be made 1 way: {arr[0], arr[1]}.

Step 4: Process arr[2] = 1. Iterate j from 2 down to 1:

  • j = 2: dp[2] += dp[1] = 1 + 2 = 3
  • j = 1: dp[1] += dp[0] = 2 + 1 = 3
  • dp = [1, 3, 3]. Sum 2 can now be made 3 ways, using the index sets {0,1}, {0,2}, {1,2}.

Step 5: Process arr[3] = 1. Iterate j from 2 down to 1:

  • j = 2: dp[2] += dp[1] = 3 + 3 = 6
  • j = 1: dp[1] += dp[0] = 3 + 1 = 4
  • dp = [1, 4, 6].

Step 6: Answer = dp[2] = 6, one for each way to choose 2 of the 4 indices {0, 1, 2, 3}.
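The trace can be reproduced by saving a copy of dp after each element. This Python sketch uses a hypothetical `dp_snapshots` helper that runs the same high-to-low update.

```python
def dp_snapshots(arr, target):
    """Return a copy of the 1D dp array after each element is processed."""
    dp = [0] * (target + 1)
    dp[0] = 1
    history = []
    for num in arr:
        for j in range(target, num - 1, -1):
            dp[j] += dp[j - num]
        history.append(dp[:])  # copy, so later updates don't change it
    return history


for step, row in enumerate(dp_snapshots([1, 1, 1, 1], 2), start=1):
    print(step, row)
# 1 [1, 1, 0]
# 2 [1, 2, 1]
# 3 [1, 3, 3]
# 4 [1, 4, 6]
```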


Algorithm

  1. Compute totalSum = sum of all elements in arr
  2. If (totalSum + diff) is odd or diff > totalSum, return 0
  3. Compute target = (totalSum + diff) / 2
  4. Create a 1D array dp of size (target + 1), initialized to 0
  5. Set dp[0] = 1 (one way to make sum 0: empty subset)
  6. For each element arr[i] in the array:
    • For j from target down to arr[i]:
      • dp[j] += dp[j - arr[i]]
  7. Return dp[target]

Code

#include <vector>
using namespace std;

class Solution {
public:
    int countPartitions(vector<int>& arr, int diff) {
        int totalSum = 0;
        for (int x : arr) totalSum += x;
        
        if ((totalSum + diff) % 2 != 0) return 0;
        if (diff > totalSum) return 0;
        
        int target = (totalSum + diff) / 2;
        
        vector<int> dp(target + 1, 0);
        dp[0] = 1;
        
        for (int i = 0; i < arr.size(); i++) {
            for (int j = target; j >= arr[i]; j--) {
                dp[j] += dp[j - arr[i]];
            }
        }
        
        return dp[target];
    }
};

class Solution:
    def countPartitions(self, arr, diff):
        total_sum = sum(arr)
        
        if (total_sum + diff) % 2 != 0:
            return 0
        if diff > total_sum:
            return 0
        
        target = (total_sum + diff) // 2
        
        dp = [0] * (target + 1)
        dp[0] = 1
        
        for num in arr:
            for j in range(target, num - 1, -1):
                dp[j] += dp[j - num]
        
        return dp[target]

class Solution {
    public int countPartitions(int[] arr, int diff) {
        int totalSum = 0;
        for (int x : arr) totalSum += x;
        
        if ((totalSum + diff) % 2 != 0) return 0;
        if (diff > totalSum) return 0;
        
        int target = (totalSum + diff) / 2;
        
        int[] dp = new int[target + 1];
        dp[0] = 1;
        
        for (int i = 0; i < arr.length; i++) {
            for (int j = target; j >= arr[i]; j--) {
                dp[j] += dp[j - arr[i]];
            }
        }
        
        return dp[target];
    }
}
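Because the constraints allow arr[i] = 0, it is worth comparing the tabulation against a direct enumeration of the definition on small random inputs. This Python sketch is a test harness, not part of the solution: the function names are hypothetical, array lengths are kept at 9 or fewer so that enumerating 2^n assignments stays cheap, and the diff range 0 to 20 is an arbitrary choice.

```python
import random
from itertools import product


def count_by_dp(arr, diff):
    """1D tabulation: count subsets summing to (total + diff) / 2."""
    total = sum(arr)
    if (total + diff) % 2 != 0 or diff > total:
        return 0
    target = (total + diff) // 2
    dp = [0] * (target + 1)
    dp[0] = 1
    for num in arr:
        for j in range(target, num - 1, -1):
            dp[j] += dp[j - num]
    return dp[target]


def count_by_enumeration(arr, diff):
    """Try every S1/S2 assignment and apply the definition directly."""
    total = sum(arr)
    count = 0
    for choice in product((False, True), repeat=len(arr)):
        s1 = sum(x for x, in_s1 in zip(arr, choice) if in_s1)
        if s1 - (total - s1) == diff:
            count += 1
    return count


def cross_check(trials=300, seed=0):
    """Compare both counts on random small inputs; raise on the first mismatch."""
    rng = random.Random(seed)
    for _ in range(trials):
        arr = [rng.randint(0, 6) for _ in range(rng.randint(1, 9))]
        diff = rng.randint(0, 20)
        assert count_by_dp(arr, diff) == count_by_enumeration(arr, diff), (arr, diff)
    return True


print(cross_check())           # True unless a mismatch raises AssertionError
print(count_by_dp([1, 0], 1))  # 2: S1 = {1} or S1 = {1, 0}
```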

Complexity Analysis

Time Complexity: O(n × target)

We have an outer loop over n elements, and for each element, an inner loop that iterates over at most (target + 1) values. Each iteration does O(1) work. Total: O(n × target), where target = (totalSum + diff) / 2. With n ≤ 50 and each element ≤ 6, target ≤ (300 + 50) / 2 = 175, so the worst case is about 50 × 175 = 8,750 operations — extremely fast.

Space Complexity: O(target)

We use a single 1D array of size (target + 1). No recursion stack needed since the approach is iterative. This is a significant improvement over the O(n × target) space of the memoization approach.