A key insight is that applying the XOR-with-k operation on an edge affects both of its endpoints, and applying it twice on the same edge cancels out. If we apply the operation once on every edge along the path between two nodes u and v, each interior node of the path is XORed twice (a no-op), while u and v are XORed exactly once. This means we can effectively XOR any pair of nodes, not just adjacent ones.
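As a quick check of the path argument, here is a small sketch that applies the operation once on every edge of the path between two nodes and returns the resulting values. The helper names path_edges and xor_along_path and the sample trees are illustrative assumptions, not part of the solution.

```python
from collections import deque


def path_edges(n, edges, u, v):
    """Return the edges on the unique path from u to v in a tree."""
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    # BFS from u records each node's parent.
    parent = [-1] * n
    seen = [False] * n
    seen[u] = True
    queue = deque([u])
    while queue:
        node = queue.popleft()
        for nxt in adj[node]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = node
                queue.append(nxt)
    # Walk back from v to u, collecting the edges.
    path = []
    node = v
    while node != u:
        path.append((parent[node], node))
        node = parent[node]
    return path


def xor_along_path(nums, k, edges, u, v):
    """Apply the edge operation once on every edge of the u-v path."""
    out = list(nums)
    for a, b in path_edges(len(nums), edges, u, v):
        out[a] ^= k
        out[b] ^= k
    return out


# Path-shaped tree 0 - 1 - 2 - 3: only the endpoints 0 and 3 change.
print(xor_along_path([1, 2, 3, 4], 5, [[0, 1], [1, 2], [2, 3]], 0, 3))  # [4, 2, 3, 1]
```

Nodes 1 and 2 sit on the path and are XORed twice, so they keep their original values.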
For each node, we have two choices: keep its original value or XOR it with k. Since each operation toggles exactly two nodes, the number of XORed nodes is always even; conversely, by the path argument, any even-sized set of nodes can be XORed by pairing its nodes up. We use DFS to track two states for each subtree: the maximum sum when an even number of its nodes are XORed, and the maximum sum when an odd number are XORed.
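To see both halves of this argument on a small case, the sketch below enumerates every subset of edges (using an edge twice is the same as not using it, so each edge is used 0 or 1 times), records which nodes end up XORed, and compares the result with the collection of all even-sized node subsets. The function names reachable_xor_sets and even_subsets and the example trees are assumptions made for illustration.

```python
from itertools import combinations


def reachable_xor_sets(n, edges):
    """All node sets that can end up XORed, over every subset of edges."""
    reachable = set()
    for mask in range(1 << len(edges)):
        toggled = [0] * n
        for i, (u, v) in enumerate(edges):
            if (mask >> i) & 1:
                toggled[u] ^= 1
                toggled[v] ^= 1
        reachable.add(frozenset(node for node in range(n) if toggled[node]))
    return reachable


def even_subsets(n):
    """All subsets of range(n) with an even number of elements."""
    return {
        frozenset(c)
        for size in range(0, n + 1, 2)
        for c in combinations(range(n), size)
    }


# A 5-node tree: a star centered at 0 with a tail 3 - 4.
edges = [[0, 1], [0, 2], [0, 3], [3, 4]]
print(reachable_xor_sets(5, edges) == even_subsets(5))  # True
```

Every reachable set is even-sized, and every even-sized set is reachable, so only the parity constraint matters.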
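1. Build an adjacency list from the edges.
2. Define dfs(node, par), which returns a pair res for the subtree rooted at node, where res[0] (even XOR count) and res[1] (odd XOR count) are the best sums:
   - Initialize res[0] = nums[node] and res[1] = nums[node] ^ k.
   - For each child, compute cur = dfs(child, node) and update both entries from the old values of res:
     - res[0] = max(res[0] + cur[0], res[1] + cur[1]) (combining even+even or odd+odd gives even)
     - res[1] = max(res[1] + cur[0], res[0] + cur[1]) (combining odd+even or even+odd gives odd)
3. Return res[0] from the root, representing the maximum sum with an even number of XORed nodes.

```python
import sys
from typing import List

class Solution:
    def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
        # Recursion depth can reach n on a path-shaped tree, so never lower the limit.
        sys.setrecursionlimit(max(sys.getrecursionlimit(), 2 * len(nums) + 10))
        adj = [[] for _ in range(len(nums))]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)

        def dfs(node, par):
            # res[0]: best subtree sum with an even number of XORed nodes
            # res[1]: best subtree sum with an odd number of XORed nodes
            res = [nums[node], nums[node] ^ k]
            for child in adj[node]:
                if child == par:
                    continue
                cur = dfs(child, node)
                # Both entries are computed from the old res before it is replaced.
                res = [
                    max(res[0] + cur[0], res[1] + cur[1]),
                    max(res[1] + cur[0], res[0] + cur[1]),
                ]
            return res

        return dfs(0, -1)[0]
```

Since any even-sized set of nodes can be XORed (pair the nodes up and apply the path trick to each pair), the tree structure does not affect the answer, and we can treat this as a selection problem. For each node, we decide whether to XOR it or not, while tracking whether we have XORed an odd or even number of nodes so far. This naturally leads to a DP formulation where the state is the current index and the parity of the number of XORed nodes.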
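Because only parity matters, a brute-force reference that ignores the edges is easy to write and useful for testing the faster approaches on small inputs. This is a sketch that tries every even-sized subset of nodes, so it is exponential and only suitable for small n; the name brute_force is hypothetical.

```python
from itertools import combinations


def brute_force(nums, k):
    """Max sum over all ways to XOR an even number of nodes (exponential)."""
    n = len(nums)
    best = sum(nums)  # XOR zero nodes
    for size in range(2, n + 1, 2):
        for chosen in combinations(range(n), size):
            picked = set(chosen)
            total = sum((x ^ k) if i in picked else x for i, x in enumerate(nums))
            best = max(best, total)
    return best


print(brute_force([1, 2, 1], 3))  # 6
print(brute_force([2, 3], 7))  # 9
print(brute_force([7, 7, 7, 7, 7, 7], 3))  # 42
```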
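1. Define dp[i][xorCnt], where i is the node index and xorCnt is 0 (even) or 1 (odd): the best sum of nodes i..n-1, given the parity of XORs among the nodes before i.
2. Base case: dp[n][0] = 0 (valid: even count at the end) and dp[n][1] = -infinity (invalid: odd count).
3. For each index i, recursively compute the better of two options:
   - Keep nums[i] as is and add dfs(i + 1, xorCnt).
   - Use nums[i] ^ k and add dfs(i + 1, xorCnt ^ 1) (toggle parity).
4. Call dfs(0, 0) to get the maximum sum starting with even parity.

```python
import sys
from typing import List

class Solution:
    def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
        # Recursion depth reaches n, so never lower the limit below what is needed.
        sys.setrecursionlimit(max(sys.getrecursionlimit(), 2 * len(nums) + 10))
        # None marks an uncomputed state; the extra last row holds the base cases.
        dp = [[None] * 2 for _ in range(len(nums))] + [[0, float("-inf")]]

        def dfs(i, xorCnt):
            if dp[i][xorCnt] is not None:
                return dp[i][xorCnt]
            # Option 1: keep nums[i], parity unchanged.
            res = nums[i] + dfs(i + 1, xorCnt)
            # Option 2: XOR nums[i] with k and flip the parity.
            res = max(res, (nums[i] ^ k) + dfs(i + 1, xorCnt ^ 1))
            dp[i][xorCnt] = res
            return res

        return dfs(0, 0)
```

This is the iterative version of the top-down approach. Instead of using recursion with memoization, we fill the DP table from the end to the beginning. At each position, we compute the best sum for both even and odd XOR counts from the values already computed for the next position.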
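To make the table concrete, this sketch fills the same table for one input and returns it whole instead of only dp[0][0]. The function name fill_table and the sample input are assumptions for illustration.

```python
def fill_table(nums, k):
    """Return the full bottom-up table; row i is [even, odd] for the suffix starting at i."""
    n = len(nums)
    dp = [[0, 0] for _ in range(n + 1)]
    dp[n][1] = float("-inf")  # finishing with odd parity is invalid
    for i in range(n - 1, -1, -1):
        keep, flip = nums[i], nums[i] ^ k
        dp[i][0] = max(keep + dp[i + 1][0], flip + dp[i + 1][1])
        dp[i][1] = max(keep + dp[i + 1][1], flip + dp[i + 1][0])
    return dp


for row in fill_table([1, 2, 1], 3):
    print(row)
```

For nums = [1, 2, 1] and k = 3 the rows are [6, 5], [3, 4], [1, 2], and [0, -inf], so the answer read from row 0 is 6. The -inf in the last row is what rules out finishing with an odd number of XORed nodes.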
1. Create a DP table that stores dp[i][0] and dp[i][1] for each index.
2. Base case: dp[n][0] = 0 and dp[n][1] = -infinity.
3. Iterate from i = n - 1 down to 0:
   - dp[i][0] = max(nums[i] + dp[i+1][0], (nums[i] ^ k) + dp[i+1][1])
   - dp[i][1] = max(nums[i] + dp[i+1][1], (nums[i] ^ k) + dp[i+1][0])
4. Return dp[0][0].

```python
from typing import List

class Solution:
    def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
        n = len(nums)
        dp = [[0, 0] for _ in range(n + 1)]
        # Base case: ending with an odd number of XORed nodes is invalid.
        dp[n][1] = float("-inf")
        for i in range(n - 1, -1, -1):
            # Keep nums[i] (parity unchanged) or XOR it (parity flips).
            dp[i][0] = max(nums[i] + dp[i + 1][0], (nums[i] ^ k) + dp[i + 1][1])
            dp[i][1] = max(nums[i] + dp[i + 1][1], (nums[i] ^ k) + dp[i + 1][0])
        return dp[0][0]
```

In the bottom-up approach, each state depends only on the states at the next index. This means we do not need to store the entire DP table: two values, one for even parity and one for odd parity, are enough, and we update them as we process each element.
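Since only two values are needed, here is a sketch of the same recurrence with two plain names and a tuple assignment, which updates both from the previous values at once and removes the need for a separate next_dp list. The function name max_value_sum_two_vars is an assumption; it computes the same value as the space-optimized method in this section, but drops the unused edges parameter.

```python
def max_value_sum_two_vars(nums, k):
    """Space-optimized DP with two scalars instead of a two-element list."""
    even, odd = 0, float("-inf")  # best suffix sums by parity of XORed nodes
    for num in reversed(nums):
        flipped = num ^ k
        # Both right-hand sides are evaluated with the old even/odd values.
        even, odd = max(num + even, flipped + odd), max(num + odd, flipped + even)
    return even


print(max_value_sum_two_vars([1, 2, 1], 3))  # 6
```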
1. Initialize dp = [0, -infinity], representing the even and odd XOR counts.
2. Iterate from i = n - 1 down to 0:
   - next_dp[0] = max(nums[i] + dp[0], (nums[i] ^ k) + dp[1])
   - next_dp[1] = max(nums[i] + dp[1], (nums[i] ^ k) + dp[0])
   - Set dp = next_dp.
3. Return dp[0].

```python
from typing import List

class Solution:
    def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
        # dp[0]: best suffix sum with even parity; dp[1]: with odd parity.
        dp = [0, float("-inf")]
        for i in range(len(nums) - 1, -1, -1):
            next_dp = [0, 0]
            next_dp[0] = max(nums[i] + dp[0], (nums[i] ^ k) + dp[1])
            next_dp[1] = max(nums[i] + dp[1], (nums[i] ^ k) + dp[0])
            dp = next_dp
        return dp[0]
```

For each node, compute the delta (nums[i] ^ k) - nums[i]. A positive delta means XORing that node increases the sum. Since we must XOR an even number of nodes, we greedily pick pairs of nodes with the highest combined deltas: sort the deltas in descending order and take consecutive pairs as long as each pair's sum is positive. Because the list is sorted, pair sums never increase, so once a pair's sum is not positive, no later pair can help.
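To see which pairs the greedy accepts, this sketch returns the accepted pair gains alongside the final sum. The helper name greedy_pairs and the sample inputs are illustrative assumptions.

```python
def greedy_pairs(nums, k):
    """Return (accepted pair gains, final sum) for the sort-and-pair greedy."""
    delta = sorted(((x ^ k) - x for x in nums), reverse=True)
    res = sum(nums)
    taken = []
    # Walk consecutive pairs; a leftover single delta is never used.
    for i in range(0, len(delta) - 1, 2):
        gain = delta[i] + delta[i + 1]
        if gain <= 0:
            break  # later pairs are no larger, so none can help
        taken.append(gain)
        res += gain
    return taken, res


print(greedy_pairs([0, 0, 0], 1))  # ([2], 2): the third +1 has no partner
print(greedy_pairs([1, 2, 1], 3))  # ([2], 6)
```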
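1. Compute delta[i] = (nums[i] ^ k) - nums[i] for each node.
2. Sort delta in descending order.
3. Initialize res = sum(nums).
4. Iterate through the deltas in pairs (indices 0-1, 2-3, etc.):
   - If delta[i] + delta[i+1] > 0, add this sum to res.
   - Otherwise, stop.
5. Return res.

```python
from typing import List

class Solution:
    def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
        # delta[i] is the change in the sum if node i is XORed.
        delta = [(num ^ k) - num for num in nums]
        delta.sort(reverse=True)
        res = sum(nums)
        for i in range(0, len(nums), 2):
            # An unpaired last delta cannot be used on its own.
            if i == len(nums) - 1:
                break
            path_delta = delta[i] + delta[i + 1]
            # Pair sums only decrease from here on, so stop at the first non-positive one.
            if path_delta <= 0:
                break
            res += path_delta
        return res
```

We can avoid sorting. For each node, greedily add whichever is larger, nums[i] or nums[i] ^ k, and track whether we have XORed an odd or even number of nodes. If the count is odd at the end, we must flip the choice for exactly one node: either un-XOR a node we XORed, or XOR a node we kept. Flipping node i costs |nums[i] ^ k - nums[i]|, so the minimum cost to fix parity is the smallest such absolute difference across all nodes.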
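Because the parity fix is the subtle part, a randomized comparison against exhaustive search is a cheap way to gain confidence in it. This is a sketch under stated assumptions: small random inputs, a fixed seed, and hypothetical helper names greedy_min_diff and exhaustive.

```python
import random
from itertools import product


def greedy_min_diff(nums, k):
    """One-pass greedy with the minimum-absolute-difference parity fix."""
    res, parity, min_diff = 0, 0, float("inf")
    for x in nums:
        y = x ^ k
        if y > x:
            res += y
            parity ^= 1
        else:
            res += x
        min_diff = min(min_diff, abs(y - x))
    return res - min_diff if parity else res


def exhaustive(nums, k):
    """Try every keep/XOR assignment with an even number of XORed nodes."""
    best = None
    for choice in product((0, 1), repeat=len(nums)):
        if sum(choice) % 2 == 0:
            total = sum((x ^ k) if c else x for x, c in zip(nums, choice))
            best = total if best is None else max(best, total)
    return best


rng = random.Random(0)
for _ in range(500):
    n = rng.randint(2, 7)
    nums = [rng.randint(0, 20) for _ in range(n)]
    k = rng.randint(1, 20)
    assert greedy_min_diff(nums, k) == exhaustive(nums, k), (nums, k)
print("all random cases agree")
```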
1. Initialize res = 0, xorCnt = 0, and minDiff = infinity.
2. For each node:
   - If nums[i] ^ k > nums[i], add nums[i] ^ k to res and toggle xorCnt.
   - Otherwise, add nums[i] to res.
   - Update minDiff = min(minDiff, |nums[i] ^ k - nums[i]|).
3. If xorCnt is odd, subtract minDiff from res to fix parity.
4. Return res.

```python
from typing import List

class Solution:
    def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
        xorCnt = res = 0
        # Acts as infinity: |(num ^ k) - num| < 2^30 whenever num and k are below 2^30.
        minDiff = 1 << 30
        for num in nums:
            xorNum = num ^ k
            # Take the better value for this node and record the parity change.
            if xorNum > num:
                res += xorNum
                xorCnt ^= 1
            else:
                res += num
            # Cheapest single flip seen so far.
            minDiff = min(minDiff, abs(xorNum - num))
        # With odd parity, flip the choice of the cheapest node.
        return res - xorCnt * minDiff
```

The XOR operation on an edge affects both endpoints simultaneously, so the number of XORed nodes must always be even. Forgetting this constraint and greedily XORing every node that benefits individually produces incorrect results whenever an odd number of nodes would benefit.
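A three-node example makes the failure concrete: with k = 1 and all values 0, every node gains 1 when XORed, but only two of the three can be XORed. The function names below are hypothetical; parity_aware is a compact restatement of the minimum-difference greedy described in this article.

```python
def naive_greedy(nums, k):
    """Wrong: XOR every node that benefits, ignoring parity."""
    return sum(max(x, x ^ k) for x in nums)


def parity_aware(nums, k):
    """Same greedy, then flip the cheapest node if the XOR count is odd."""
    res = sum(max(x, x ^ k) for x in nums)
    odd = sum((x ^ k) > x for x in nums) % 2
    return res - odd * min(abs((x ^ k) - x) for x in nums)


nums, k = [0, 0, 0], 1
print(naive_greedy(nums, k))  # 3, which would need three XORed nodes
print(parity_aware(nums, k))  # 2, the true optimum
```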
A key insight is that applying XOR operations along a path lets you effectively XOR any pair of nodes, not just adjacent ones. A common mistake is to consider only adjacent node pairs, missing that a sequence of edge operations along the connecting path can reach any two nodes in the tree.
When summing node values (especially after XOR operations), the total can exceed the 32-bit integer range. Using a 32-bit int instead of a 64-bit type (long in Java and C#, long long in C++) will overflow and produce wrong results on large inputs.
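Python integers never overflow, so to illustrate the problem, this sketch emulates 32-bit two's-complement wraparound (how Java's int addition behaves) with hypothetical helpers to_int32 and sum_int32. The three values below are an assumption chosen only to push the sum past 2^31 - 1.

```python
def to_int32(x):
    """Reinterpret an integer as a 32-bit two's-complement value."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x


def sum_int32(values):
    """Sum with wraparound after every addition, like a 32-bit int accumulator."""
    total = 0
    for v in values:
        total = to_int32(total + v)
    return total


values = [10**9] * 3  # three large node values
print(sum(values))  # 3000000000, the true sum
print(sum_int32(values))  # -1294967296, the wrapped 32-bit result
```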
In the optimal greedy approach, if an odd number of nodes end up XORed, you must flip the choice for one node. The cost of this fix is the minimum absolute difference |nums[i] ^ k - nums[i]| across all nodes, because the cheapest flip may be XORing a node you kept rather than un-XORing one you took. A common mistake is taking the minimum gain among the XORed nodes instead, or forgetting to track this value during iteration.
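In this small example with k = 3, node 0 gains 3 when XORed and node 1 loses 1, so the greedy XORs only node 0 and ends with odd parity. The cheapest fix is to also XOR node 1 (cost 1), not to un-XOR node 0 (cost 3). Both functions are hypothetical sketches written to contrast the two fixes.

```python
def fix_with_min_gain(nums, k):
    """Wrong: on odd parity, always undo the smallest gain among XORed nodes."""
    gains = [(x ^ k) - x for x in nums]
    res = sum(max(x, x ^ k) for x in nums)
    positive = [g for g in gains if g > 0]
    if len(positive) % 2 == 1:
        res -= min(positive)
    return res


def fix_with_min_abs_diff(nums, k):
    """Correct: on odd parity, flip whichever node is cheapest to flip."""
    gains = [(x ^ k) - x for x in nums]
    res = sum(max(x, x ^ k) for x in nums)
    if sum(g > 0 for g in gains) % 2 == 1:
        res -= min(abs(g) for g in gains)
    return res


nums, k = [0, 2], 3  # 0 ^ 3 = 3 (gain 3), 2 ^ 3 = 1 (loss 1)
print(fix_with_min_gain(nums, k))  # 2
print(fix_with_min_abs_diff(nums, k))  # 4
```

The true optimum here is 4: XOR both nodes to get 3 + 1.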
Applying XOR with k twice on the same value returns the original value. Some solutions fail to recognize that repeated operations on the same edge cancel out, leading to unnecessarily complex state tracking or incorrect assumptions about which operations are possible.