538. Convert BST to Greater Tree - Explanation

Description

Given the root of a Binary Search Tree (BST), convert it to a Greater Tree such that every key of the original BST is changed to the original key plus the sum of all keys greater than the original key in the BST.

As a reminder, a binary search tree is a tree that satisfies these constraints:

  • The left subtree of a node contains only nodes with keys less than the node's key.
  • The right subtree of a node contains only nodes with keys greater than the node's key.
  • Both the left and right subtrees must also be binary search trees.

Example 1:

Input: root = [10,9,15,4,null,12,19,2,5,11]

Output: [67,76,34,85,null,46,19,87,81,57]

Example 2:

Input: root = [2,1,3]

Output: [5,6,3]
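You can check the examples by hand. Sort the keys in descending order and keep a running sum. The conversion changes values but never the tree's shape, and all keys are unique, so each key in the level-order input can be replaced by its running sum. The sketch below does this with two hypothetical helpers, greater_sum_map and expected_output. Because it sorts, it runs in O(n log n) and serves as a reference check for the examples, not as a solution.

```python
from typing import Dict, List, Optional


def greater_sum_map(keys: List[int]) -> Dict[int, int]:
    """Map each (unique) key to itself plus the sum of all larger keys."""
    result: Dict[int, int] = {}
    running = 0
    for key in sorted(keys, reverse=True):  # largest key first
        running += key
        result[key] = running
    return result


def expected_output(level_order: List[Optional[int]]) -> List[Optional[int]]:
    """Replace every key in a level-order list; None (a missing child) stays None."""
    mapping = greater_sum_map([v for v in level_order if v is not None])
    return [None if v is None else mapping[v] for v in level_order]


print(expected_output([10, 9, 15, 4, None, 12, 19, 2, 5, 11]))
# [67, 76, 34, 85, None, 46, 19, 87, 81, 57]
print(expected_output([2, 1, 3]))
# [5, 6, 3]
```

For Example 1 the keys sum to 87, so the smallest key 2 becomes 87 and the largest key 19 stays 19.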

Constraints:

  • The number of nodes in the tree is in the range [0, 10,000].
  • -10,000 <= Node.val <= 10,000
  • All the values in the tree are unique.
  • root is guaranteed to be a valid binary search tree.



Prerequisites

Before attempting this problem, you should be comfortable with:

  • Binary Search Tree Properties - Understanding that in a BST, all nodes in the right subtree are greater than the current node
  • Tree Traversal (In-Order) - Visiting nodes in sorted order (left, node, right) and reverse in-order (right, node, left) for descending order
  • Depth First Search (DFS) - Recursively traversing tree structures to process each node
  • Stack (for Iterative DFS) - Converting recursive tree traversal to an iterative approach using an explicit stack
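Below is a minimal sketch of the two traversal orders, run on the tree from Example 1. The TreeNode class mirrors the commented-out LeetCode definition used in the solutions. The other names are hypothetical helpers for illustration: build_tree reads a LeetCode-style level-order list with None for missing children, and in_order and reverse_in_order perform the two traversals.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(level_order: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a LeetCode-style level-order list (None = missing child)."""
    if not level_order or level_order[0] is None:
        return None
    root = TreeNode(level_order[0])
    queue = deque([root])
    i = 1
    while queue and i < len(level_order):
        node = queue.popleft()
        # The next two entries (if present) are this node's left and right children.
        if level_order[i] is not None:
            node.left = TreeNode(level_order[i])
            queue.append(node.left)
        i += 1
        if i < len(level_order) and level_order[i] is not None:
            node.right = TreeNode(level_order[i])
            queue.append(node.right)
        i += 1
    return root


def in_order(node: Optional[TreeNode]) -> List[int]:
    # left, node, right: ascending order in a BST
    if node is None:
        return []
    return in_order(node.left) + [node.val] + in_order(node.right)


def reverse_in_order(node: Optional[TreeNode]) -> List[int]:
    # right, node, left: descending order in a BST
    if node is None:
        return []
    return reverse_in_order(node.right) + [node.val] + reverse_in_order(node.left)


root = build_tree([10, 9, 15, 4, None, 12, 19, 2, 5, 11])
print(in_order(root))          # [2, 4, 5, 9, 10, 11, 12, 15, 19]
print(reverse_in_order(root))  # [19, 15, 12, 11, 10, 9, 5, 4, 2]
```

In-order yields the keys in ascending order and reverse in-order yields them in descending order. Descending order is what the running-sum (one-pass) solutions rely on.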

1. Depth First Search (Two Pass)

Intuition

For each node in a BST, the Greater Tree value should be the sum of all node values greater than or equal to the current node's value. In a BST, these greater values lie in the node's right subtree and in any ancestor whose left subtree contains the node. A simple two-pass approach first computes the total sum of all nodes, then performs an in-order traversal, which visits nodes in ascending order. When we reach a node, every smaller value has already been subtracted from the total, so the remaining sum is exactly the sum of all values greater than or equal to the node. We assign that sum to the node, then subtract its original value before moving on.

Algorithm

  1. First pass: Recursively calculate the sum of all node values in the tree.
  2. Second pass: Perform an in-order traversal (left, current, right).
  3. At each node during the second pass:
    • Save the original value temporarily.
    • Update the node's value to the current total sum (which represents all values >= this node).
    • Subtract the original value from the total sum for subsequent nodes.
  4. Return the modified root.
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        def getSum(node):
            if not node:
                return 0
            return node.val + getSum(node.left) + getSum(node.right)

        totalSum = getSum(root)

        def dfs(node):
            nonlocal totalSum
            if not node:
                return

            dfs(node.left)
            tmp = node.val
            node.val = totalSum
            totalSum -= tmp
            dfs(node.right)

        dfs(root)
        return root

Time & Space Complexity

  • Time complexity: O(n)
  • Space complexity: O(h) for the recursion stack, where h is the height of the tree; O(n) in the worst case (a skewed tree).

2. Depth First Search (One Pass)

Intuition

We can do this in a single pass by traversing the tree in reverse in-order (right, current, left). In a BST, this visits nodes from largest to smallest. We maintain a running sum of all nodes visited so far. When we visit a node, all previously visited nodes have greater values, so we add the current node's value to our running sum and update the node to this sum.

Algorithm

  1. Initialize a running sum variable to 0.
  2. Perform a reverse in-order traversal: first visit the right subtree, then the current node, then the left subtree.
  3. At each node:
    • Save the node's original value.
    • Add the running sum to the node's value (this gives the sum of all greater nodes plus itself).
    • Add the original value to the running sum for future nodes.
  4. Continue until all nodes are processed and return the root.
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        curSum = 0

        def dfs(node):
            nonlocal curSum
            if not node:
                return

            dfs(node.right)
            tmp = node.val
            node.val += curSum
            curSum += tmp
            dfs(node.left)

        dfs(root)
        return root

Time & Space Complexity

  • Time complexity: O(n)
  • Space complexity: O(h) for the recursion stack, where h is the height of the tree; O(n) in the worst case (a skewed tree).

3. Iterative DFS

Intuition

The recursive reverse in-order traversal can be converted to an iterative version using a stack. We simulate the call stack explicitly, pushing nodes as we move right and popping them in descending order. This produces the same result as the recursive one-pass solution but avoids recursion overhead and the risk of exceeding the recursion limit on very deep trees. With up to 10,000 nodes, a skewed tree can be far deeper than CPython's default recursion limit of 1000.

Algorithm

  1. Initialize a running sum to 0 and an empty stack.
  2. Start with the root node and traverse as far right as possible, pushing each node onto the stack.
  3. Pop a node from the stack:
    • Add its value to the running sum.
    • Update the node's value to the running sum.
    • Move to its left child and repeat the right-traversal process.
  4. Continue until the stack is empty and there are no more nodes to process.
  5. Return the root.
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        curSum = 0
        stack = []
        node = root

        while stack or node:
            while node:
                stack.append(node)
                node = node.right

            node = stack.pop()
            curSum += node.val
            node.val = curSum
            node = node.left

        return root

Time & Space Complexity

  • Time complexity: O(n)
  • Space complexity: O(h) for the explicit stack, where h is the height of the tree; O(n) in the worst case.
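A skewed tree shows the recursion-depth concern in practice. The sketch below assumes CPython, whose default recursion limit is 1000. It uses a hypothetical right_skewed_bst helper to build a chain of 10,000 nodes, the maximum the constraints allow. On that chain, the recursive reverse in-order version is expected to raise RecursionError while the explicit-stack version finishes. The exact outcome depends on the interpreter's recursion limit, which sys.setrecursionlimit can change.

```python
import sys
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def right_skewed_bst(n: int) -> TreeNode:
    """Keys 1..n, each node the right child of the previous one (height n)."""
    root = node = TreeNode(1)
    for key in range(2, n + 1):
        node.right = TreeNode(key)
        node = node.right
    return root


def convert_recursive(root: Optional[TreeNode]) -> Optional[TreeNode]:
    cur_sum = 0

    def dfs(node):
        nonlocal cur_sum
        if not node:
            return
        dfs(node.right)  # one Python frame per level of the tree
        cur_sum += node.val
        node.val = cur_sum
        dfs(node.left)

    dfs(root)
    return root


def convert_iterative(root: Optional[TreeNode]) -> Optional[TreeNode]:
    cur_sum, stack, node = 0, [], root
    while stack or node:
        while node:  # push the right spine onto an explicit list
            stack.append(node)
            node = node.right
        node = stack.pop()
        cur_sum += node.val
        node.val = cur_sum
        node = node.left
    return root


def recursive_overflows(n: int) -> bool:
    """True if the recursive version exceeds the interpreter's recursion limit."""
    try:
        convert_recursive(right_skewed_bst(n))
        return False
    except RecursionError:
        return True


n = 10_000  # upper bound on the node count from the constraints
print(sys.getrecursionlimit(), recursive_overflows(n))
print(convert_iterative(right_skewed_bst(n)).val)  # 1 + 2 + ... + 10000 = 50005000
```

On a right-skewed chain the explicit stack grows to n entries, which matches the O(n) worst-case space bound. A Python list holds that many references without trouble.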

4. Morris Traversal

Intuition

Morris traversal allows us to traverse a tree without using extra space for a stack or recursion. It works by temporarily modifying the tree structure to create links back to ancestor nodes, then restoring the original structure. For this problem, we use reverse Morris in-order traversal to visit nodes from largest to smallest while maintaining only O(1) extra space.

Algorithm

  1. Initialize a running sum to 0 and start at the root.
  2. While the current node is not null:
    • If the current node has a right child, find its in-order successor: the leftmost node of its right subtree. In reverse in-order, this is the node visited just before the current node.
    • If that node's left child is null, set it to point to the current node (create a thread) and move to the right child.
    • If that node's left child already points to the current node, the right subtree has been fully processed. Remove the thread, process the current node (add its value to the running sum and update the node), and move to the left child.
    • If there's no right child, process the current node and move to the left child.
  3. Return the root after all nodes are processed.
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        curSum = 0
        cur = root

        while cur:
            if cur.right:
                prev = cur.right
                while prev.left and prev.left != cur:
                    prev = prev.left

                if not prev.left:
                    prev.left = cur
                    cur = cur.right
                else:
                    prev.left = None
                    curSum += cur.val
                    cur.val = curSum
                    cur = cur.left
            else:
                curSum += cur.val
                cur.val = curSum
                cur = cur.left

        return root

Time & Space Complexity

  • Time complexity: O(n)
  • Space complexity: O(1)
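Morris traversal temporarily rewires left pointers, so it is worth checking that every thread is removed by the end. The sketch below is a standalone function with the same logic as the Morris solution. build_tree and to_list are hypothetical helpers that convert to and from LeetCode-style level-order lists. A leftover thread would leave a cycle back to an ancestor, so matching the expected level-order output also checks that the original structure was restored.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(level_order: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a LeetCode-style level-order list (None = missing child)."""
    if not level_order or level_order[0] is None:
        return None
    root = TreeNode(level_order[0])
    queue, i = deque([root]), 1
    while queue and i < len(level_order):
        node = queue.popleft()
        if level_order[i] is not None:
            node.left = TreeNode(level_order[i])
            queue.append(node.left)
        i += 1
        if i < len(level_order) and level_order[i] is not None:
            node.right = TreeNode(level_order[i])
            queue.append(node.right)
        i += 1
    return root


def to_list(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Serialize to a LeetCode-style level-order list, trimming trailing Nones."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


def convert_morris(root: Optional[TreeNode]) -> Optional[TreeNode]:
    cur_sum, cur = 0, root
    while cur:
        if cur.right:
            # Leftmost node of the right subtree: visited just before cur in reverse in-order.
            succ = cur.right
            while succ.left and succ.left is not cur:
                succ = succ.left
            if succ.left is None:
                succ.left = cur  # create a temporary thread back to cur
                cur = cur.right
            else:
                succ.left = None  # remove the thread, restoring the original tree
                cur_sum += cur.val
                cur.val = cur_sum
                cur = cur.left
        else:
            cur_sum += cur.val
            cur.val = cur_sum
            cur = cur.left
    return root


print(to_list(convert_morris(build_tree([10, 9, 15, 4, None, 12, 19, 2, 5, 11]))))
# [67, 76, 34, 85, None, 46, 19, 87, 81, 57]
```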

Common Pitfalls

Using Standard In-Order Traversal Instead of Reverse In-Order

Standard in-order traversal visits nodes from smallest to largest. For the Greater Tree, we need to accumulate sums from larger values first, requiring reverse in-order (right, node, left).

# Wrong - standard in-order (left, node, right)
dfs(node.left)
node.val += curSum
dfs(node.right)

# Correct - reverse in-order (right, node, left)
dfs(node.right)
node.val += curSum
dfs(node.left)
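The sketch below makes this pitfall concrete. It runs the same update (save tmp, node.val += curSum, curSum += tmp) under both traversal orders on the tree from Example 2. The convert, example2, and vals helpers are hypothetical, written only for this comparison.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def convert(root, reverse):
    """Apply the running-sum update in standard (reverse=False) or reverse in-order."""
    cur_sum = 0

    def dfs(node):
        nonlocal cur_sum
        if not node:
            return
        # Standard in-order visits left first; reverse in-order visits right first.
        first, second = (node.right, node.left) if reverse else (node.left, node.right)
        dfs(first)
        tmp = node.val
        node.val += cur_sum
        cur_sum += tmp
        dfs(second)

    dfs(root)
    return root


def example2():
    # Example 2: root 2 with left child 1 and right child 3
    return TreeNode(2, TreeNode(1), TreeNode(3))


def vals(root):
    return [root.val, root.left.val, root.right.val]


print(vals(convert(example2(), reverse=False)))  # [3, 1, 6]  (wrong)
print(vals(convert(example2(), reverse=True)))   # [5, 6, 3]  (matches Example 2)
```

With the standard order, each node receives the sum of the keys less than or equal to it, so the root 2 becomes 3 instead of 5.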

Incorrect Order of Update and Accumulation

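A subtle bug occurs when the running sum accumulates the node's updated value instead of its original value. Once node.val has been increased by curSum, adding node.val to curSum counts the previous running sum twice. There are two correct orders. You can save the original value, update the node, then add the saved value to curSum. Or you can add the node's value to curSum first and then assign curSum to the node, as the iterative solution does.

# Wrong - adds the updated value, so curSum is counted twice
node.val += curSum
curSum += node.val

# Correct - save the original value first
tmp = node.val
node.val += curSum
curSum += tmp

# Also correct - accumulate first, then assign
curSum += node.val
node.val = curSum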
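The sketch below compares the double-counting variant with the saved-original version on the tree from Example 2. Both use reverse in-order traversal. A hypothetical buggy flag switches between the two update orders, and convert, example2, and vals are illustrative helper names.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def convert(root, buggy):
    cur_sum = 0

    def dfs(node):
        nonlocal cur_sum
        if not node:
            return
        dfs(node.right)  # reverse in-order: largest keys first
        if buggy:
            node.val += cur_sum
            cur_sum += node.val  # BUG: adds the updated value, so cur_sum is counted twice
        else:
            tmp = node.val  # save the original key before overwriting it
            node.val += cur_sum
            cur_sum += tmp
        dfs(node.left)

    dfs(root)
    return root


def example2():
    # Example 2: root 2 with left child 1 and right child 3
    return TreeNode(2, TreeNode(1), TreeNode(3))


def vals(root):
    return [root.val, root.left.val, root.right.val]


print(vals(convert(example2(), buggy=True)))   # [5, 9, 3]
print(vals(convert(example2(), buggy=False)))  # [5, 6, 3]
```

The buggy version still gets the root right (5). The left child, however, ends up as 9 instead of 6. The running sum absorbed the root's updated value 5 instead of its original 2, becoming 8 instead of 5.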

Forgetting to Handle Null Nodes in Recursive Calls

Without a proper base case, the recursion will crash when reaching null children. Always check for null before processing a node.

# Wrong - crashes on null nodes
def dfs(node):
    dfs(node.right)  # raises AttributeError when node is None

# Correct
def dfs(node):
    if not node:
        return
    dfs(node.right)