Partition Array for Maximum Sum
Description
Given an integer array arr and an integer k, partition the array into contiguous subarrays where each subarray has length at most k.
After partitioning, every element in each subarray is replaced by the maximum value of that subarray.
Return the largest possible sum of the modified array after performing the optimal partition.
For example, if you partition [1, 15, 7] into one subarray of length 3, all elements become 15 (the maximum), giving a sum of 15 + 15 + 15 = 45. But if you partition it as [1] and [15, 7], the first stays [1] and the second becomes [15, 15], giving a sum of 1 + 15 + 15 = 31, which is clearly worse.
Your goal is to choose the partition that maximizes this total sum.
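To make the scoring rule concrete, the Python sketch below applies one given partition and returns the resulting sum. The helper `partition_sum` is hypothetical. Describing a partition by its list of block lengths is an assumption made for illustration, and the sketch does not enforce the length-at-most-k rule.

```python
def partition_sum(arr: list[int], lengths: list[int]) -> int:
    """Sum of arr after replacing each block with its maximum (hypothetical helper).

    `lengths` gives the block sizes from left to right and must add up to len(arr).
    The at-most-k limit is not checked here.
    """
    if sum(lengths) != len(arr):
        raise ValueError("block lengths must cover the whole array")
    total = 0
    start = 0
    for length in lengths:
        block = arr[start:start + length]
        total += max(block) * length  # every element in the block becomes max(block)
        start += length
    return total


print(partition_sum([1, 15, 7], [3]))     # 45
print(partition_sum([1, 15, 7], [1, 2]))  # 31
```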
Examples
Example 1
Input: arr = [1, 15, 7, 9, 2, 5, 10], k = 3
Output: 84
Explanation: The optimal partition is [1, 15, 7] | [9] | [2, 5, 10].
- [1, 15, 7]: max = 15, all become 15. Contribution = 15 × 3 = 45.
- [9]: max = 9, stays 9. Contribution = 9 × 1 = 9.
- [2, 5, 10]: max = 10, all become 10. Contribution = 10 × 3 = 30.
Total = 45 + 9 + 30 = 84. The modified array looks like [15, 15, 15, 9, 10, 10, 10].
Example 2
Input: arr = [1, 4, 1, 5, 7, 3, 6, 1, 9, 9, 3], k = 4
Output: 83
Explanation: One optimal partition is [1, 4] | [1, 5, 7] | [3, 6, 1, 9] | [9, 3].
- [1, 4]: max = 4, contribution = 4 × 2 = 8.
- [1, 5, 7]: max = 7, contribution = 7 × 3 = 21.
- [3, 6, 1, 9]: max = 9, contribution = 9 × 4 = 36.
- [9, 3]: max = 9, contribution = 9 × 2 = 18.
Total = 8 + 21 + 36 + 18 = 83. Another partition, [1] | [4, 1, 5, 7] | [3, 6, 1, 9] | [9, 3], also reaches 1 + 28 + 36 + 18 = 83. No partition exceeds 83.
Example 3
Input: arr = [1], k = 1
Output: 1
Explanation: With only one element and k = 1, the only possible partition is [1]. The maximum of this subarray is 1, so the sum is 1.
Constraints
- 1 ≤ arr.length ≤ 500
- 0 ≤ arr[i] ≤ 10^9
- 1 ≤ k ≤ arr.length
- The answer fits in a 32-bit integer.
Editorial
Brute Force
Intuition
The most natural way to think about this problem is to try every possible way to partition the array and pick the one that gives the largest sum.
Imagine you're reading the array from left to right. At each position, you must decide: how long should the current partition be? It can be 1, 2, ..., up to k elements (as long as you don't go past the end of the array). Once you decide the length, all elements in that partition are replaced by the partition's maximum, and you move on to solve the remainder of the array.
This is a classic "choice at each step" problem. At the first position, you have up to k choices. For each choice, you recurse on the remaining array, which again offers up to k choices, and so on. You try all combinations and return the maximum total sum.
Step-by-Step Explanation
Let's trace through with arr = [1, 15, 7, 9, 2, 5, 10], k = 3:
Step 1: At position 0, we have 3 choices:
- Length 1: Take [1]. max=1, cost=1×1=1. Recurse on remaining [15,7,9,2,5,10].
- Length 2: Take [1,15]. max=15, cost=15×2=30. Recurse on remaining [7,9,2,5,10].
- Length 3: Take [1,15,7]. max=15, cost=15×3=45. Recurse on remaining [9,2,5,10].
Step 2: Explore Choice 3 (length 3, cost 45). Now at position 3, solving [9,2,5,10]:
- Length 1: Take [9]. cost=9. Recurse on [2,5,10].
- Length 2: Take [9,2]. max=9, cost=18. Recurse on [5,10].
- Length 3: Take [9,2,5]. max=9, cost=27. Recurse on [10].
Step 3: Explore Choice 3→1 (length 1, cost 9). Now at position 4, solving [2,5,10]:
- Length 1: Take [2]. cost=2. Recurse on [5,10].
- Length 2: Take [2,5]. max=5, cost=10. Recurse on [10].
- Length 3: Take [2,5,10]. max=10, cost=30. No remainder.
Step 4: Explore 3→1→3 (length 3, cost 30): Total = 45 + 9 + 30 = 84.
Step 5: We continue exploring other paths. For example:
- Path 3→2 (cost 45+18): At [5,10], best is length 2 → max=10, cost=20. Total = 45+18+20 = 83.
- Path 2→... explores [7,9,2,5,10] from position 2. Many branches.
- Path 1→... explores [15,7,9,2,5,10] from position 1. Even more branches.
Step 6: After exhausting all paths, the maximum found is 84.
Result: 84.
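To check Step 6 mechanically, the sketch below turns the same recursion into a Python generator. Instead of keeping only the best total, it yields every complete path. `all_partitions` is a hypothetical name used only for illustration. It assumes non-negative elements, as the constraints guarantee, so 0 is a safe starting value for the running maximum.

```python
from collections.abc import Iterator


def all_partitions(arr: list[int], k: int, start: int = 0) -> Iterator[tuple[list[list[int]], int]]:
    """Yield every partition of arr[start:] into blocks of length 1..k, with its total.

    Mirrors the brute force: choose the first block, then recurse on the rest.
    """
    n = len(arr)
    if start >= n:
        yield [], 0  # empty remainder: no blocks, zero sum
        return
    max_val = 0  # safe because arr[i] >= 0
    for length in range(1, min(k, n - start) + 1):
        max_val = max(max_val, arr[start + length - 1])
        block = arr[start:start + length]
        for rest, rest_total in all_partitions(arr, k, start + length):
            yield [block] + rest, max_val * length + rest_total


paths = list(all_partitions([1, 15, 7, 9, 2, 5, 10], 3))
best = max(total for _, total in paths)
print(len(paths), best)                                  # 44 84
print([blocks for blocks, total in paths if total == best])  # [[[1, 15, 7], [9], [2, 5, 10]]]
```

For Example 1 there are 44 complete paths, and exactly one of them reaches 84.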
Algorithm
- Define a recursive function solve(start) that returns the maximum sum achievable for arr[start..n-1].
- Base case: if start >= n, return 0 (no elements left).
- Recursive case: try every partition length len from 1 to min(k, n - start):
  - Track the running maximum max_val of arr[start..start+len-1].
  - Compute candidate = max_val × len + solve(start + len).
  - Keep the maximum candidate.
- Return the best candidate.
- The answer is solve(0).
Code
```cpp
#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    int maxSumAfterPartitioning(vector<int>& arr, int k) {
        int n = arr.size();
        return solve(arr, k, 0, n);
    }

private:
    // Maximum sum achievable for arr[start..n-1]
    int solve(vector<int>& arr, int k, int start, int n) {
        if (start >= n) return 0;
        int maxVal = 0, best = 0;
        for (int len = 1; len <= k && start + len <= n; len++) {
            // Running maximum of arr[start..start+len-1]
            maxVal = max(maxVal, arr[start + len - 1]);
            int candidate = maxVal * len + solve(arr, k, start + len, n);
            best = max(best, candidate);
        }
        return best;
    }
};
```

```python
class Solution:
    def maxSumAfterPartitioning(self, arr: list[int], k: int) -> int:
        n = len(arr)

        # Maximum sum achievable for arr[start:]
        def solve(start):
            if start >= n:
                return 0
            max_val = 0
            best = 0
            for length in range(1, k + 1):
                if start + length > n:
                    break
                max_val = max(max_val, arr[start + length - 1])
                candidate = max_val * length + solve(start + length)
                best = max(best, candidate)
            return best

        return solve(0)
```

```java
class Solution {
    public int maxSumAfterPartitioning(int[] arr, int k) {
        return solve(arr, k, 0);
    }

    // Maximum sum achievable for arr[start..arr.length-1]
    private int solve(int[] arr, int k, int start) {
        if (start >= arr.length) return 0;
        int maxVal = 0, best = 0;
        for (int len = 1; len <= k && start + len <= arr.length; len++) {
            maxVal = Math.max(maxVal, arr[start + len - 1]);
            int candidate = maxVal * len + solve(arr, k, start + len);
            best = Math.max(best, candidate);
        }
        return best;
    }
}
```

Complexity Analysis
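Time Complexity: O(2^n)
At each position, we branch into up to k recursive calls. Each loop iteration makes exactly one call, so the running time is proportional to the total number of calls.

With m elements remaining, the call count satisfies C(m) = 1 + C(m-1) + C(m-2) + ... + C(m - min(k, m)), with C(0) = 1. When k ≥ n, every remaining length is allowed and the count doubles with each extra element, giving C(n) = 2^n. A smaller k only removes terms, so 2^n is the worst case.

The often-quoted bound O(k^n) comes from the cruder recurrence T(n) = k × T(n-1). It overestimates for large k and is wrong for k = 1, where the recursion is linear. For n = 500, 2^500 calls is far beyond any feasible computation.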
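The call count is easy to tabulate directly. This Python sketch counts how many times the brute-force `solve` is entered when m elements remain: one call for the current position, plus one subtree per allowed partition length. The name `count_calls` is hypothetical.

```python
from functools import lru_cache


def count_calls(m: int, k: int) -> int:
    """How many times the brute-force solve() runs when m elements remain (hypothetical helper)."""

    @lru_cache(maxsize=None)
    def calls(rem: int) -> int:
        if rem == 0:
            return 1  # the base-case call that returns 0
        # this call, plus one recursive subtree per partition length 1..min(k, rem)
        return 1 + sum(calls(rem - j) for j in range(1, min(k, rem) + 1))

    return calls(m)


print(count_calls(7, 3))   # 96  (Example 1)
print(count_calls(7, 7))   # 128 = 2^7
print(count_calls(10, 1))  # 11  (linear when k = 1)
```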
Space Complexity: O(n)
The recursion stack depth is at most n (one call per element in the worst case). No additional data structures are used.
Why This Approach Is Not Efficient
The brute force runs in exponential time. With n up to 500, this is completely infeasible.
The core problem is overlapping subproblems. The recursive function solve(start) depends only on start, not on how we arrived there. But we call solve(start) multiple times from different earlier partitions. For example:
- Partition [1] then [15, 7] reaches solve(3).
- Partition [1, 15] then [7] also reaches solve(3).
- Partition [1, 15, 7] directly reaches solve(3).
All three paths need the answer for solve(3), but the brute force recomputes it each time. There are only n + 1 distinct states (start ranges from 0 to n), but the brute force may visit each state exponentially many times.
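Instrumenting the brute force makes the repetition visible. The Python sketch below reruns the same recursion on Example 1 and counts how often each `start` is entered. `visit_counts` is a hypothetical helper. Besides the three paths listed, [1] | [15] | [7] also reaches solve(3), so that state is entered four times.

```python
from collections import Counter


def visit_counts(arr: list[int], k: int) -> Counter:
    """Run the brute force and record how often solve(start) is entered for each start."""
    n = len(arr)
    visits: Counter = Counter()

    def solve(start: int) -> int:
        visits[start] += 1
        if start >= n:
            return 0
        max_val = 0
        best = 0
        for length in range(1, min(k, n - start) + 1):
            max_val = max(max_val, arr[start + length - 1])
            best = max(best, max_val * length + solve(start + length))
        return best

    solve(0)
    return visits


counts = visit_counts([1, 15, 7, 9, 2, 5, 10], 3)
print(sorted(counts.items()))
# [(0, 1), (1, 1), (2, 2), (3, 4), (4, 7), (5, 13), (6, 24), (7, 44)]
```

Only 8 distinct states exist, yet the recursion enters them 96 times in total.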
The fix: Dynamic Programming. If we store solve(start) after computing it once, we avoid all redundant work. With n + 1 states, each requiring O(k) work to compute, the total becomes O(n × k), which is polynomial instead of exponential. We can also build this table bottom-up (tabulation), which removes the recursion entirely.
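A minimal top-down version of this fix, sketched in Python, is the brute-force recursion with `functools.lru_cache` storing each solve(start). The function name `max_sum_memo` is illustrative.

```python
from functools import lru_cache


def max_sum_memo(arr: list[int], k: int) -> int:
    """Top-down DP: the brute-force recursion with each solve(start) cached."""
    n = len(arr)

    @lru_cache(maxsize=None)
    def solve(start: int) -> int:
        if start >= n:
            return 0
        max_val = 0
        best = 0
        for length in range(1, min(k, n - start) + 1):
            max_val = max(max_val, arr[start + length - 1])
            best = max(best, max_val * length + solve(start + length))
        return best

    return solve(0)
```

Each state is now computed once. One caveat: the recursion depth still grows with n, and in CPython the default recursion limit of 1000 may be tight for n near 500. That is one practical reason to prefer the bottom-up table.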
Optimal Approach - Dynamic Programming (Tabulation)
Intuition
We define dp[i] as the maximum sum achievable for the first i elements of the array (arr[0..i-1]).
To compute dp[i], we consider every possible last partition, meaning the subarray that ends at position i-1. This last partition can have length 1, 2, ..., up to k (as long as it stays within bounds). For a last partition of length j:
- It covers elements arr[i-j], arr[i-j+1], ..., arr[i-1].
- The maximum value in this partition replaces all its elements.
- The contribution of this partition is max(arr[i-j..i-1]) × j.
- The rest of the array (the first i-j elements) contributes dp[i-j].
So the recurrence is:
dp[i] = max over j = 1..min(k, i) of { dp[i-j] + max(arr[i-j..i-1]) × j }
Think of it like building a wall of bricks. You're placing bricks from left to right. At each position, you decide the width of the current brick (1 to k). A wider brick earns more (its value is the maximum in its span, multiplied by its width), but it also pushes the boundary further. You want to choose widths that maximize the total wall value.
We build the table from left to right, ensuring that when we compute dp[i], all dp[i-j] values we need are already computed.
Step-by-Step Explanation
Let's trace through with arr = [1, 15, 7, 9, 2, 5, 10], k = 3:
Step 1: Initialize dp[0] = 0 (empty prefix has zero sum).
Step 2: Compute dp[1] — first 1 element [1].
- j=1: last partition is [1]. max=1. dp[0] + 1×1 = 0 + 1 = 1.
- dp[1] = 1.
Step 3: Compute dp[2] — first 2 elements [1, 15].
- j=1: last partition is [15]. max=15. dp[1] + 15×1 = 1 + 15 = 16.
- j=2: last partition is [1, 15]. max=15. dp[0] + 15×2 = 0 + 30 = 30.
- dp[2] = max(16, 30) = 30. Taking both elements as one partition is better!
Step 4: Compute dp[3] — first 3 elements [1, 15, 7].
- j=1: [7], max=7. dp[2] + 7 = 37.
- j=2: [15, 7], max=15. dp[1] + 30 = 31.
- j=3: [1, 15, 7], max=15. dp[0] + 45 = 45.
- dp[3] = max(37, 31, 45) = 45. One big partition of 3 wins — all elements become 15.
Step 5: Compute dp[4] — first 4 elements [1, 15, 7, 9].
- j=1: [9], max=9. dp[3] + 9 = 45 + 9 = 54.
- j=2: [7, 9], max=9. dp[2] + 18 = 30 + 18 = 48.
- j=3: [15, 7, 9], max=15. dp[1] + 45 = 1 + 45 = 46.
- dp[4] = max(54, 48, 46) = 54. Surprise: j=1 wins! The short partition [9] alone (54) beats the longer partitions (48, 46), because dp[3]=45 already captured the 15s optimally.
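Step 6: Compute dp[5] for the first 5 elements [1, 15, 7, 9, 2].
- j=1: [2], max=2. dp[4] + 2 = 54 + 2 = 56.
- j=2: [9, 2], max=9. dp[3] + 18 = 45 + 18 = 63.
- j=3: [7, 9, 2], max=9. dp[2] + 27 = 30 + 27 = 57.
- dp[5] = max(56, 63, 57) = 63. Best is j=2.
Step 7: Compute dp[6] for the first 6 elements [1, 15, 7, 9, 2, 5].
- j=1: [5], max=5. dp[5] + 5 = 63 + 5 = 68.
- j=2: [2, 5], max=5. dp[4] + 10 = 54 + 10 = 64.
- j=3: [9, 2, 5], max=9. dp[3] + 27 = 45 + 27 = 72.
- dp[6] = max(68, 64, 72) = 72. Best is j=3.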
Step 8: Compute dp[7] — all 7 elements.
- j=1: [10], max=10. dp[6] + 10 = 72 + 10 = 82.
- j=2: [5, 10], max=10. dp[5] + 20 = 63 + 20 = 83.
- j=3: [2, 5, 10], max=10. dp[4] + 30 = 54 + 30 = 84.
- dp[7] = max(82, 83, 84) = 84. The longest partition wins here because the maximum (10) spreads across 3 elements.
Result: dp[7] = 84.
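The trace picks a winning width at every step. If we record those winners, we can recover the partition itself and not just its sum. The Python sketch below does this with an extra `choice` array. The names `max_sum_with_partition` and `choice` are illustrative, not part of the original solution.

```python
def max_sum_with_partition(arr: list[int], k: int) -> tuple[int, list[int], list[list[int]]]:
    """Bottom-up DP that also remembers the winning width for each prefix.

    Returns (best total, full dp table, one optimal partition). Because the
    update uses a strict comparison, ties keep the smallest width.
    """
    n = len(arr)
    dp = [0] * (n + 1)
    choice = [0] * (n + 1)  # choice[i] = width of the last block in a best split of arr[:i]
    for i in range(1, n + 1):
        max_val = 0
        for j in range(1, min(k, i) + 1):
            max_val = max(max_val, arr[i - j])
            candidate = dp[i - j] + max_val * j
            # choice[i] == 0 forces a recorded width even if every candidate is 0
            if choice[i] == 0 or candidate > dp[i]:
                dp[i] = candidate
                choice[i] = j
    # Walk back from the end, peeling off the recorded last block each time.
    blocks = []
    i = n
    while i > 0:
        j = choice[i]
        blocks.append(arr[i - j:i])
        i -= j
    blocks.reverse()
    return dp[n], dp, blocks


total, dp, blocks = max_sum_with_partition([1, 15, 7, 9, 2, 5, 10], 3)
print(total)   # 84
print(dp)      # [0, 1, 30, 45, 54, 63, 72, 84]
print(blocks)  # [[1, 15, 7], [9], [2, 5, 10]]
```

On Example 2 the same function returns [[1, 4], [1, 5, 7], [3, 6, 1, 9], [9, 3]], one of the partitions that reach 83.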
(Animation: DP table filling with partition width choices. The table is filled left to right. For each position, every partition width from 1 to k is tried and the one giving the maximum sum is kept, so each entry looks back 1 to k positions.)
Algorithm
- Create an array dp of size n+1, initialized to 0, where dp[i] = max sum for the first i elements.
- For each position i from 1 to n:
  - Initialize max_val = 0.
  - For each partition width j from 1 to min(k, i):
    - Update max_val = max(max_val, arr[i - j]) (track the running max as we extend the partition leftward).
    - Compute dp[i] = max(dp[i], dp[i - j] + max_val × j).
- Return dp[n].
Key optimization: We compute max_val incrementally as we extend j. This avoids rescanning the partition for its maximum at each width, keeping the inner loop at O(1) per iteration.
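For contrast, the Python sketch below implements the same recurrence but recomputes each block's maximum with a slice. The name `max_sum_rescan` is hypothetical. It gives the same answers, but each width j now costs O(j), so one row costs O(k^2) and the whole table costs O(n × k^2).

```python
def max_sum_rescan(arr: list[int], k: int) -> int:
    """Same recurrence as the tabulation, without the running maximum (hypothetical helper)."""
    n = len(arr)
    dp = [0] * (n + 1)
    for i in range(1, n + 1):
        for j in range(1, min(k, i) + 1):
            # max(arr[i - j:i]) rescans the whole block: O(j) per width
            dp[i] = max(dp[i], dp[i - j] + max(arr[i - j:i]) * j)
    return dp[n]
```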
Code
```cpp
#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    int maxSumAfterPartitioning(vector<int>& arr, int k) {
        int n = arr.size();
        // dp[i] = max sum for the first i elements
        vector<int> dp(n + 1, 0);
        for (int i = 1; i <= n; i++) {
            int maxVal = 0;
            for (int j = 1; j <= min(k, i); j++) {
                // Extend partition leftward by one element
                maxVal = max(maxVal, arr[i - j]);
                dp[i] = max(dp[i], dp[i - j] + maxVal * j);
            }
        }
        return dp[n];
    }
};
```

```python
class Solution:
    def maxSumAfterPartitioning(self, arr: list[int], k: int) -> int:
        n = len(arr)
        # dp[i] = max sum for the first i elements
        dp = [0] * (n + 1)
        for i in range(1, n + 1):
            max_val = 0
            for j in range(1, min(k, i) + 1):
                # Extend partition leftward by one element
                max_val = max(max_val, arr[i - j])
                dp[i] = max(dp[i], dp[i - j] + max_val * j)
        return dp[n]
```

```java
class Solution {
    public int maxSumAfterPartitioning(int[] arr, int k) {
        int n = arr.length;
        // dp[i] = max sum for the first i elements
        int[] dp = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            int maxVal = 0;
            for (int j = 1; j <= Math.min(k, i); j++) {
                // Extend partition leftward by one element
                maxVal = Math.max(maxVal, arr[i - j]);
                dp[i] = Math.max(dp[i], dp[i - j] + maxVal * j);
            }
        }
        return dp[n];
    }
}
```

Complexity Analysis
Time Complexity: O(n × k)
The outer loop runs n times (once per element). For each element, the inner loop tries at most k partition widths and does O(1) work per width: one max comparison, one multiplication and one addition. The exact iteration count is the sum of min(k, i) over i = 1..n, which is at most n × k. For n = k = 500, that is 125,250 iterations, well under the 250,000 bound and extremely fast.
Space Complexity: O(n)
We use a single array dp of size n+1. No other data structures grow with input size.
Note: We could further optimize space to O(k) by observing that dp[i] only depends on dp[i-1] through dp[i-k]. However, O(n) space is already very efficient for n ≤ 500, so this optimization is rarely needed.
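Below is a sketch of that O(k) variant in Python. It assumes a circular buffer of k + 1 slots, where dp[i] lives in slot i mod (k + 1). `max_sum_rolling` is a hypothetical name. The O(k) figure counts only auxiliary space; the input array itself is still O(n).

```python
def max_sum_rolling(arr: list[int], k: int) -> int:
    """Bottom-up DP keeping only the last k + 1 dp values (hypothetical helper).

    Computing dp[i] reads dp[i-1] .. dp[i-k], which occupy the other k slots,
    so slot i % (k + 1), which still holds dp[i-k-1], is free to overwrite.
    """
    size = k + 1
    buf = [0] * size  # buf[0] holds dp[0] = 0
    for i in range(1, len(arr) + 1):
        max_val = 0
        best = 0
        for j in range(1, min(k, i) + 1):
            max_val = max(max_val, arr[i - j])
            best = max(best, buf[(i - j) % size] + max_val * j)
        buf[i % size] = best
    return buf[len(arr) % size]
```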