Given the root of a Binary Search Tree (BST), convert it to a Greater Tree such that every key of the original BST is changed to the original key plus the sum of all keys greater than the original key in the BST.
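As a reminder, a binary search tree is a tree that satisfies these constraints:
The left subtree of a node contains only nodes with keys less than the node's key.
The right subtree of a node contains only nodes with keys greater than the node's key.
Both the left and right subtrees must also be binary search trees.
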
Example 1:
Input: root = [10,9,15,4,null,12,19,2,5,11]
Output: [67,76,34,85,null,46,19,87,81,57]

Example 2:
Input: root = [2,1,3]
Output: [5,6,3]

Constraints:
0 <= The number of nodes in the tree <= 10,000.
-10,000 <= Node.val <= 10,000
root is guaranteed to be a valid binary search tree.

For each node in a BST, the Greater Tree value is the sum of all node values greater than or equal to the current node's value. Those greater values lie in the node's right subtree and, for every ancestor that has the node in its left subtree, in that ancestor and its right subtree. A simple two-pass approach first computes the total sum of all nodes, then traverses the tree in-order (smallest to largest). When we visit a node, the running total equals that node's value plus everything greater, so we store the total in the node and then subtract the node's original value from it.

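To see why the subtraction works, here is a minimal sketch that applies the same two passes to a plain ascending list, which stands in for the in-order sequence of the tree. The function name greater_sums_ascending is made up for illustration and is not part of the solution.

```python
def greater_sums_ascending(sorted_keys):
    """Two-pass idea on an ascending list (a stand-in for in-order traversal)."""
    remaining = sum(sorted_keys)      # pass 1: total of all keys
    result = []
    for key in sorted_keys:           # pass 2: smallest to largest
        result.append(remaining)      # key plus everything greater than it
        remaining -= key              # this key is not "greater" for later keys
    return result


# Keys of Example 2 in ascending order
print(greater_sums_ascending([1, 2, 3]))  # [6, 5, 3]
```

The result lines up with Example 2: key 1 becomes 6, key 2 becomes 5, and key 3 stays 3.
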
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        def getSum(node):
            if not node:
                return 0
            return node.val + getSum(node.left) + getSum(node.right)

        totalSum = getSum(root)

        def dfs(node):
            nonlocal totalSum
            if not node:
                return

            dfs(node.left)
            tmp = node.val
            node.val = totalSum
            totalSum -= tmp
            dfs(node.right)

        dfs(root)
        return root

We can do this in a single pass by traversing the tree in reverse in-order (right, current, left). In a BST, this visits nodes from largest to smallest. We maintain a running sum of all nodes visited so far. When we visit a node, all previously visited nodes have greater values, so we add the current node's value to our running sum and update the node to this sum.

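As a quick illustration of the visiting order, the sketch below walks the Example 2 tree in reverse in-order and prints the running sum at each node without modifying the tree. The TreeNode class mirrors the commented-out definition used in this article; reverse_inorder is a hypothetical helper name.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def reverse_inorder(node):
    """Yield nodes from the right subtree, then the node, then the left subtree."""
    if node:
        yield from reverse_inorder(node.right)
        yield node
        yield from reverse_inorder(node.left)


# Example 2 tree: [2,1,3]
root = TreeNode(2, TreeNode(1), TreeNode(3))

running = 0
for node in reverse_inorder(root):
    running += node.val               # sum of this key and all greater keys
    print(node.val, "->", running)
# 3 -> 3
# 2 -> 5
# 1 -> 6
```
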
1. Initialize curSum to 0.
2. Traverse the tree in reverse in-order: recurse into the right subtree, process the current node, then recurse into the left subtree.
3. When processing a node, save its original value, add curSum to the node, then add the original value to curSum.
4. Return the root.

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        curSum = 0

        def dfs(node):
            nonlocal curSum
            if not node:
                return

            dfs(node.right)
            tmp = node.val
            node.val += curSum
            curSum += tmp
            dfs(node.left)

        dfs(root)
        return root

The recursive reverse in-order traversal can be converted to an iterative version using a stack. We simulate the call stack explicitly, pushing nodes as we traverse right, then processing them in order. This achieves the same result as the recursive one-pass solution but avoids recursion overhead and potential stack overflow for very deep trees.

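The depth concern is easy to reproduce. The sketch below builds a degenerate BST with 10,000 nodes (the maximum allowed by the constraints), each node being the left child of the previous one, and runs a recursive one-pass conversion on it. It assumes CPython with its default recursion limit, which is typically 1000, so the outcome on other setups may differ. The names convert_recursive and build_left_chain are illustrative only.

```python
import sys


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def convert_recursive(root):
    """Recursive reverse in-order conversion (same idea as the one-pass solution)."""
    cur_sum = 0

    def dfs(node):
        nonlocal cur_sum
        if not node:
            return
        dfs(node.right)
        cur_sum += node.val
        node.val = cur_sum
        dfs(node.left)

    dfs(root)
    return root


def build_left_chain(n):
    """Degenerate BST: n, n-1, ..., 1, each node the left child of the previous."""
    root = tail = TreeNode(n)
    for val in range(n - 1, 0, -1):
        tail.left = TreeNode(val)
        tail = tail.left
    return root


try:
    convert_recursive(build_left_chain(10_000))
    print("recursion finished; limit is", sys.getrecursionlimit())
except RecursionError:
    print("RecursionError; limit is", sys.getrecursionlimit())
```

A stack-based traversal keeps its pending nodes in a list on the heap, so the same chain does not depend on the interpreter's recursion limit.
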
1. Initialize curSum to 0 and an empty stack.
2. While the stack is not empty or the current node is not null, push the current node and keep moving to its right child until reaching null.
3. Pop a node, add its value to curSum, and store curSum in the node.
4. Move to the popped node's left child and repeat.
5. Return the root.

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        curSum = 0
        stack = []
        node = root

        while stack or node:
            while node:
                stack.append(node)
                node = node.right

            node = stack.pop()
            curSum += node.val
            node.val = curSum
            node = node.left

        return root

Morris traversal allows us to traverse a tree without using extra space for a stack or recursion. It works by temporarily modifying the tree structure to create links back to ancestor nodes, then restoring the original structure. For this problem, we use reverse Morris in-order traversal to visit nodes from largest to smallest while maintaining only O(1) extra space.

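To make the threading idea concrete, here is a small sketch that only reads the values in descending order and leaves the keys untouched, so it is easy to confirm afterwards that every temporary link has been removed. The function name morris_reverse_values is hypothetical, and the output list is there purely for display (the traversal itself needs no stack).

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def morris_reverse_values(root):
    """Collect values largest to smallest using temporary threads."""
    values = []
    cur = root
    while cur:
        if cur.right is None:
            values.append(cur.val)
            cur = cur.left
            continue
        # Leftmost node of the right subtree is the in-order successor of cur
        succ = cur.right
        while succ.left and succ.left is not cur:
            succ = succ.left
        if succ.left is None:
            succ.left = cur          # create a thread back to cur
            cur = cur.right
        else:
            succ.left = None         # right subtree is done: remove the thread
            values.append(cur.val)
            cur = cur.left
    return values


# Example 2 tree: [2,1,3]
root = TreeNode(2, TreeNode(1), TreeNode(3))
print(morris_reverse_values(root))   # [3, 2, 1]
print(root.right.left)               # None (thread removed)
```
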
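1. Initialize curSum to 0 and start at the root.
2. While the current node is not null:
   - If it has a right child, find the leftmost node of its right subtree (its in-order successor).
     - If that node's left pointer is null, set it to point to the current node (create a thread) and move to the right child.
     - Otherwise the thread already exists: remove it, add the current node's value to curSum, store curSum in the node, and move to the left child.
   - If it has no right child, add its value to curSum, store curSum in the node, and move to the left child.
3. Return the root.

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        curSum = 0
        cur = root

        while cur:
            if cur.right:
                prev = cur.right
                while prev.left and prev.left != cur:
                    prev = prev.left

                if not prev.left:
                    prev.left = cur
                    cur = cur.right
                else:
                    prev.left = None
                    curSum += cur.val
                    cur.val = curSum
                    cur = cur.left
            else:
                curSum += cur.val
                cur.val = curSum
                cur = cur.left

        return root

Standard in-order traversal visits nodes from smallest to largest. For the Greater Tree, we need to accumulate sums from larger values first, requiring reverse in-order (right, node, left).
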
# Wrong - standard in-order (left, node, right)
dfs(node.left)
node.val += curSum
dfs(node.right)

# Correct - reverse in-order (right, node, left)
dfs(node.right)
node.val += curSum
dfs(node.left)

A subtle bug occurs when you update the node first and then add the node's new value to curSum. The running sum must grow by the node's original value only. Either save the original value before overwriting it, or add it to curSum first and then assign curSum to the node (as the iterative and Morris solutions do).

# Wrong - adds the already-updated value to curSum
node.val += curSum
curSum += node.val  # node.val now includes curSum, so greater keys are counted twice

# Correct
tmp = node.val
node.val += curSum
curSum += tmp

Without a proper base case, the recursion will crash when reaching null children. Always check for null before processing a node.

# Wrong - crashes on null nodes
def dfs(node):
    dfs(node.right)  # AttributeError if node is None

# Correct
def dfs(node):
    if not node:
        return
    dfs(node.right)
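
The solutions above leave TreeNode commented out, so they do not run outside an online judge. The harness below is one possible way to check a conversion against the two examples. It assumes the Input and Output lists use level-order notation with null for a missing child. The names build_tree, to_level_order, and convert_bst are illustrative, and convert_bst is a standalone copy of the stack-based approach.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values):
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        parent = queue.popleft()
        if values[i] is not None:
            parent.left = TreeNode(values[i])
            queue.append(parent.left)
        i += 1
        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            queue.append(parent.right)
        i += 1
    return root


def to_level_order(root):
    """Serialize a tree back to a level-order list without trailing None entries."""
    out, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


def convert_bst(root):
    """Reverse in-order conversion with an explicit stack."""
    cur_sum, stack, node = 0, [], root
    while stack or node:
        while node:
            stack.append(node)
            node = node.right
        node = stack.pop()
        cur_sum += node.val
        node.val = cur_sum
        node = node.left
    return root


example1 = build_tree([10, 9, 15, 4, None, 12, 19, 2, 5, 11])
print(to_level_order(convert_bst(example1)))
# [67, 76, 34, 85, None, 46, 19, 87, 81, 57]

example2 = build_tree([2, 1, 3])
print(to_level_order(convert_bst(example2)))
# [5, 6, 3]
```
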