We need to partition the array into subarrays of length at most k, where each element in a subarray becomes the maximum value of that subarray. The goal is to maximize the total sum after this transformation.
At each position, we have a choice: end the current subarray at any of the next k positions. For each choice, we calculate the contribution (maximum element times subarray length) and recursively solve the remaining array. We try all valid partition lengths and take the maximum result.
Recursive steps:

1. If i reaches the end of the array, return 0.
2. Otherwise, for each j from i to min(i + k - 1, n - 1), update the running maximum cur_max of arr[i..j].
3. For each window ending at j, take cur_max times the window size plus the result of recursing from j + 1, and keep the best.

class Solution:
    def maxSumAfterPartitioning(self, arr: List[int], k: int) -> int:
        def dfs(i):
            # Base case: nothing left to partition.
            if i >= len(arr):
                return 0
            cur_max = 0
            res = 0
            # Try every valid window length starting at i.
            for j in range(i, min(len(arr), i + k)):
                cur_max = max(cur_max, arr[j])
                window_size = j - i + 1
                res = max(res, dfs(j + 1) + cur_max * window_size)
            return res

        return dfs(0)

Time complexity: O(k^n); space complexity: O(n), where k is the maximum length of a subarray and n is the size of the array.
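As a quick check, here is a self-contained sketch of the same recursion (the standalone function name is ours), run on the standard sample input arr = [1, 15, 7, 9, 2, 5, 10], k = 3, whose optimal partition is [1, 15, 7] | [9] | [2, 5, 10]:

```python
from typing import List

def max_sum_brute_force(arr: List[int], k: int) -> int:
    """Try every valid partition length at each index (no caching)."""
    def dfs(i: int) -> int:
        if i >= len(arr):
            return 0
        cur_max = res = 0
        for j in range(i, min(len(arr), i + k)):
            cur_max = max(cur_max, arr[j])
            res = max(res, dfs(j + 1) + cur_max * (j - i + 1))
        return res
    return dfs(0)

# [1, 15, 7] -> 45, [9] -> 9, [2, 5, 10] -> 30
print(max_sum_brute_force([1, 15, 7, 9, 2, 5, 10], 3))  # → 84
```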
The recursive solution has overlapping subproblems. When we compute the maximum sum starting from index i, we might need this value multiple times from different partition choices. By caching results, we avoid redundant calculations.
This is the memoized version of the recursive approach. We store computed results in a cache and return them directly when we encounter the same subproblem again.
Memoized steps:

1. Initialize the cache with the base case cache[n] = 0.
2. On each call dfs(i), if i is in the cache, return the cached value immediately.
3. Otherwise, try every window length from 1 to k, store the best result in cache[i], and return it.

class Solution:
    def maxSumAfterPartitioning(self, arr: List[int], k: int) -> int:
        cache = {len(arr): 0}  # base case: an empty suffix contributes 0

        def dfs(i):
            if i in cache:
                return cache[i]
            cur_max = 0
            res = 0
            for j in range(i, min(len(arr), i + k)):
                cur_max = max(cur_max, arr[j])
                window_size = j - i + 1
                res = max(res, dfs(j + 1) + cur_max * window_size)
            cache[i] = res
            return res

        return dfs(0)

Time complexity: O(n * k); space complexity: O(n), where k is the maximum length of a subarray and n is the size of the array.
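An equivalent sketch using Python's functools.cache decorator instead of the hand-rolled dictionary (the standalone function name is ours); the decorator memoizes dfs by its argument i automatically:

```python
from functools import cache
from typing import List

def max_sum_memo(arr: List[int], k: int) -> int:
    @cache  # memoizes dfs on its argument i
    def dfs(i: int) -> int:
        if i >= len(arr):
            return 0
        cur_max = res = 0
        for j in range(i, min(len(arr), i + k)):
            cur_max = max(cur_max, arr[j])
            res = max(res, dfs(j + 1) + cur_max * (j - i + 1))
        return res
    return dfs(0)

print(max_sum_memo([1, 15, 7, 9, 2, 5, 10], 3))  # → 84
```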
We can flip the top-down approach to bottom-up by filling the DP table from right to left. dp[i] represents the maximum sum achievable for the subarray starting at index i.
Starting from the last element and working backwards, we build up solutions to larger subproblems using already-computed smaller ones. This eliminates recursion overhead and makes the memory access pattern more predictable.
Bottom-up steps:

1. Create a dp array of size n + 1 initialized to 0 (base case is dp[n] = 0).
2. For i = n - 1 down to 0, consider every window starting at i and ending at j (up to i + k - 1).
3. For each window, add cur_max times the window size to dp[j + 1], and update dp[i] with the maximum value found.
4. Return dp[0] as the answer.

class Solution:
    def maxSumAfterPartitioning(self, arr: List[int], k: int) -> int:
        n = len(arr)
        dp = [0] * (n + 1)
        for i in range(n - 1, -1, -1):
            cur_max = 0
            for j in range(i, min(n, i + k)):
                cur_max = max(cur_max, arr[j])
                window_size = j - i + 1
                dp[i] = max(dp[i], dp[j + 1] + cur_max * window_size)
        return dp[0]

Time complexity: O(n * k); space complexity: O(n), where k is the maximum length of a subarray and n is the size of the array.
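To watch the table fill, here is a hypothetical helper (name and signature are ours) that returns the entire dp array rather than just dp[0]:

```python
from typing import List

def dp_table(arr: List[int], k: int) -> List[int]:
    """Same recurrence as above, but expose the whole table."""
    n = len(arr)
    dp = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        cur_max = 0
        for j in range(i, min(n, i + k)):
            cur_max = max(cur_max, arr[j])
            dp[i] = max(dp[i], dp[j + 1] + cur_max * (j - i + 1))
    return dp

print(dp_table([1, 15, 7, 9, 2, 5, 10], 3))
# → [84, 75, 48, 39, 30, 20, 10, 0]
# dp[0] = 84 is the answer; dp[n] = 0 is the base case.
```

Reading right to left shows how each dp[i] is built only from the already-computed entries dp[i+1] .. dp[i+k].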
Looking at the bottom-up solution, we notice that dp[i] only depends on dp[i+1] through dp[i+k]. We never need values more than k positions ahead. This means we can reduce our space from O(n) to O(k) using a circular buffer.
We use modulo arithmetic to wrap around and reuse array positions. This technique is common when the recurrence relation has a bounded look-ahead.
Space-optimized steps:

1. Create a dp array of size k (using circular indexing), with dp[0] = arr[0] as the base case.
2. For i = 1 to n - 1, try every window ending at i with length 1 to k.
3. For each window starting at j, read dp[(j - 1) % k] for the previous subproblem, and store the best result in dp[i % k].
4. Return dp[(n - 1) % k] as the answer.

class Solution:
    def maxSumAfterPartitioning(self, arr: List[int], k: int) -> int:
        dp = [0] * k
        dp[0] = arr[0]
        for i in range(1, len(arr)):
            cur_max = 0
            max_at_i = 0
            # Walk the window start j backwards from i.
            for j in range(i, i - k, -1):
                if j < 0:
                    break
                cur_max = max(cur_max, arr[j])
                window_size = i - j + 1
                cur_sum = cur_max * window_size
                sub_sum = dp[(j - 1) % k] if j > 0 else 0
                max_at_i = max(max_at_i, cur_sum + sub_sum)
            dp[i % k] = max_at_i
        return dp[(len(arr) - 1) % k]

Time complexity: O(n * k); space complexity: O(k), where k is the maximum length of a subarray and n is the size of the array.
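A property-style check, sketched with our own function names, that the O(k)-space circular-buffer pass agrees with the O(n)-space table on random inputs:

```python
import random
from typing import List

def max_sum_linear(arr: List[int], k: int) -> int:
    # O(n)-space reference: fill the table right to left.
    n = len(arr)
    dp = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        cur_max = 0
        for j in range(i, min(n, i + k)):
            cur_max = max(cur_max, arr[j])
            dp[i] = max(dp[i], dp[j + 1] + cur_max * (j - i + 1))
    return dp[0]

def max_sum_circular(arr: List[int], k: int) -> int:
    # O(k)-space forward pass with modulo indexing.
    dp = [0] * k
    dp[0] = arr[0]
    for i in range(1, len(arr)):
        cur_max = 0
        best = 0
        for j in range(i, max(i - k, -1), -1):  # window start j
            cur_max = max(cur_max, arr[j])
            prev = dp[(j - 1) % k] if j > 0 else 0
            best = max(best, cur_max * (i - j + 1) + prev)
        dp[i % k] = best
    return dp[(len(arr) - 1) % k]

random.seed(0)
for _ in range(200):
    a = [random.randint(1, 50) for _ in range(random.randint(1, 12))]
    kk = random.randint(1, len(a))
    assert max_sum_linear(a, kk) == max_sum_circular(a, kk)
print("ok")  # prints "ok" if every random case agrees
```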
Each element in a partition becomes the maximum value of that partition; it does not keep its original value. Each partition therefore contributes max_element * partition_length to the sum, not the sum of its original elements. This transformation is the key insight of the problem.
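A tiny illustration with a hypothetical helper: for arr = [2, 1, 2] and k = 2, an optimal split is [2, 1] | [2], which is scored as [2, 2] | [2]:

```python
def partition_value(blocks):
    # Each block contributes max(block) * len(block),
    # not the sum of its original elements.
    return sum(max(b) * len(b) for b in blocks)

print(partition_value([[2, 1], [2]]))  # → 6, not the original sum 5
```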
When iterating through possible partition endpoints, the window size is j - i + 1, not j - i. Off-by-one errors here lead to underestimating the contribution of each partition and returning a smaller sum than optimal.
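A sketch of this bug (the buggy flag and function name are ours): scoring each block with j - i drops one copy of its maximum, so the result is strictly smaller than the true answer for any array of positive values:

```python
from typing import List

def max_sum(arr: List[int], k: int, buggy: bool = False) -> int:
    n = len(arr)
    dp = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        cur_max = 0
        for j in range(i, min(n, i + k)):
            cur_max = max(cur_max, arr[j])
            size = (j - i) if buggy else (j - i + 1)  # off-by-one when buggy
            dp[i] = max(dp[i], dp[j + 1] + cur_max * size)
    return dp[0]

arr, k = [1, 15, 7, 9, 2, 5, 10], 3
print(max_sum(arr, k))              # correct answer: 84
print(max_sum(arr, k, buggy=True))  # strictly smaller
```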
For each partition starting at index i, you must track the maximum element seen so far as you extend the window. Recalculating the maximum from scratch for each window length turns an O(nk) solution into O(nk^2).
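Both variants below (names are ours) produce identical answers; only the cost of finding the window maximum differs:

```python
from typing import List

def max_sum_recompute(arr: List[int], k: int) -> int:
    # O(n * k^2): max() rescans the whole window on every extension.
    n = len(arr)
    dp = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        for j in range(i, min(n, i + k)):
            dp[i] = max(dp[i], dp[j + 1] + max(arr[i:j + 1]) * (j - i + 1))
    return dp[0]

def max_sum_running(arr: List[int], k: int) -> int:
    # O(n * k): cur_max is extended incrementally with each new element.
    n = len(arr)
    dp = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        cur_max = 0
        for j in range(i, min(n, i + k)):
            cur_max = max(cur_max, arr[j])
            dp[i] = max(dp[i], dp[j + 1] + cur_max * (j - i + 1))
    return dp[0]

arr, k = [1, 15, 7, 9, 2, 5, 10], 3
assert max_sum_recompute(arr, k) == max_sum_running(arr, k) == 84
print("same answers, different costs")
```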