You are given the root of a Binary Search Tree (BST). Convert it to a Greater Tree such that every key of the original BST is changed to the original key plus the sum of all keys greater than the original key in the BST.
As a reminder, a binary search tree is a tree that satisfies these constraints:
The left subtree of a node contains only nodes with keys less than the node's key.
The right subtree of a node contains only nodes with keys greater than the node's key.
Both the left and right subtrees must also be binary search trees.
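A brute-force reference makes the definition concrete. The Python sketch below assumes a LeetCode-style TreeNode class. The helpers build_balanced, inorder_values, and greater_tree_bruteforce are hypothetical and exist only for illustration. The sketch snapshots all keys, then replaces each key with itself plus every strictly greater key. That costs O(n²) time but is easy to check by hand.

```python
from typing import List, Optional


class TreeNode:
    """Assumed to mirror the LeetCode binary tree node definition."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_balanced(keys: List[int]) -> Optional[TreeNode]:
    """Hypothetical helper: build a height-balanced BST from sorted, distinct keys."""
    if not keys:
        return None
    mid = len(keys) // 2
    return TreeNode(keys[mid], build_balanced(keys[:mid]), build_balanced(keys[mid + 1:]))


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Collect node values left to right (ascending key order in a BST)."""
    out: List[int] = []

    def walk(node: Optional[TreeNode]) -> None:
        if node:
            walk(node.left)
            out.append(node.val)
            walk(node.right)

    walk(root)
    return out


def greater_tree_bruteforce(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """O(n^2) reference: each key becomes itself plus the sum of all strictly greater keys."""
    keys = inorder_values(root)  # snapshot of the original keys before any update

    def walk(node: Optional[TreeNode]) -> None:
        if node:
            original = node.val
            node.val = original + sum(k for k in keys if k > original)
            walk(node.left)
            walk(node.right)

    walk(root)
    return root


if __name__ == "__main__":
    root = greater_tree_bruteforce(build_balanced(list(range(9))))
    print(inorder_values(root))  # [36, 36, 35, 33, 30, 26, 21, 15, 8]
```

For the keys 0 through 8, the key 3 becomes 3 + 4 + 5 + 6 + 7 + 8 = 33.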
Prerequisites
Before attempting this problem, you should be comfortable with:
Binary Search Tree Properties - Understanding that in a BST, all nodes in the right subtree are greater than the current node
Tree Traversal (In-Order) - Visiting nodes in sorted order (left, node, right) and reverse in-order (right, node, left) for descending order
Depth First Search (DFS) - Recursively traversing tree structures to process each node
Stack (for Iterative DFS) - Converting recursive tree traversal to an iterative approach using an explicit stack
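The two traversal orders listed above can be sketched in Python as generators. This minimal illustration assumes a LeetCode-style TreeNode class and a hypothetical build_balanced helper. A prefix sum over the reverse in-order sequence already produces the Greater Tree values, starting from the largest key.

```python
from itertools import accumulate
from typing import Iterator, List, Optional


class TreeNode:
    """Assumed to mirror the LeetCode binary tree node definition."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_balanced(keys: List[int]) -> Optional[TreeNode]:
    """Hypothetical helper: build a height-balanced BST from sorted, distinct keys."""
    if not keys:
        return None
    mid = len(keys) // 2
    return TreeNode(keys[mid], build_balanced(keys[:mid]), build_balanced(keys[mid + 1:]))


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    """Left, node, right: yields BST keys in ascending order."""
    if node:
        yield from inorder(node.left)
        yield node.val
        yield from inorder(node.right)


def reverse_inorder(node: Optional[TreeNode]) -> Iterator[int]:
    """Right, node, left: yields BST keys in descending order."""
    if node:
        yield from reverse_inorder(node.right)
        yield node.val
        yield from reverse_inorder(node.left)


if __name__ == "__main__":
    root = build_balanced([1, 3, 5, 7, 9])
    print(list(inorder(root)))                      # [1, 3, 5, 7, 9]
    print(list(reverse_inorder(root)))              # [9, 7, 5, 3, 1]
    print(list(accumulate(reverse_inorder(root))))  # [9, 16, 21, 24, 25]
```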
1. Depth First Search (Two Pass)
Intuition
For each node in a BST, the Greater Tree value is the sum of all keys greater than or equal to the node's key. In a BST, those greater keys live in the node's right subtree and in any ancestor with a larger key. A simple two-pass approach first computes the total sum of all nodes, then traverses in-order (smallest to largest). As we visit each node, we set its value to the remaining sum and subtract its original value from that sum. The remainder therefore always equals the sum of the keys not yet visited, which are exactly the keys greater than or equal to the current one.
Algorithm
First pass: Recursively calculate the sum of all node values in the tree.
Second pass: Perform an in-order traversal (left, current, right).
At each node during the second pass:
Save the original value temporarily.
Update the node's value to the current total sum (which represents all values >= this node).
Subtract the original value from the total sum for subsequent nodes.
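The two passes can be sketched in Python as follows. The sketch assumes a LeetCode-style TreeNode class. The name convert_bst_two_pass and the build_balanced and inorder_values helpers are hypothetical.

```python
from typing import List, Optional


class TreeNode:
    """Assumed to mirror the LeetCode binary tree node definition."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_balanced(keys: List[int]) -> Optional[TreeNode]:
    """Hypothetical helper: build a height-balanced BST from sorted, distinct keys."""
    if not keys:
        return None
    mid = len(keys) // 2
    return TreeNode(keys[mid], build_balanced(keys[:mid]), build_balanced(keys[mid + 1:]))


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Collect node values in in-order (left, node, right) order."""
    if not root:
        return []
    return inorder_values(root.left) + [root.val] + inorder_values(root.right)


def convert_bst_two_pass(root: Optional[TreeNode]) -> Optional[TreeNode]:
    def get_sum(node: Optional[TreeNode]) -> int:
        if not node:
            return 0
        return node.val + get_sum(node.left) + get_sum(node.right)

    total = get_sum(root)  # first pass: sum of every key in the tree

    def dfs(node: Optional[TreeNode]) -> None:
        nonlocal total
        if not node:
            return
        dfs(node.left)          # visit smaller keys first (standard in-order)
        original = node.val
        node.val = total        # total = sum of all keys >= original
        total -= original       # remove this key before visiting larger ones
        dfs(node.right)

    dfs(root)
    return root
```

On the keys 0 through 8, the in-order values after conversion are 36, 36, 35, 33, 30, 26, 21, 15, 8.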
/**
* Definition for a binary tree node.
* public class TreeNode {
* public int val;
* public TreeNode left;
* public TreeNode right;
* public TreeNode(int val=0, TreeNode left=null, TreeNode right=null) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
public class Solution {
    public TreeNode ConvertBST(TreeNode root) {
        // First pass: sum of every key in the tree.
        int GetSum(TreeNode node) {
            if (node == null) return 0;
            return node.val + GetSum(node.left) + GetSum(node.right);
        }

        int totalSum = GetSum(root);

        // Second pass: in-order, so totalSum always equals the sum of keys >= node.val.
        void Dfs(TreeNode node) {
            if (node == null) return;
            Dfs(node.left);
            int tmp = node.val;
            node.val = totalSum;
            totalSum -= tmp;
            Dfs(node.right);
        }

        Dfs(root);
        return root;
    }
}
/**
* Definition for a binary tree node.
* type TreeNode struct {
* Val int
* Left *TreeNode
* Right *TreeNode
* }
*/
func convertBST(root *TreeNode) *TreeNode {
	// First pass: sum of every key in the tree.
	var getSum func(node *TreeNode) int
	getSum = func(node *TreeNode) int {
		if node == nil {
			return 0
		}
		return node.Val + getSum(node.Left) + getSum(node.Right)
	}

	totalSum := getSum(root)

	// Second pass: in-order, so totalSum always equals the sum of keys >= node.Val.
	var dfs func(node *TreeNode)
	dfs = func(node *TreeNode) {
		if node == nil {
			return
		}
		dfs(node.Left)
		tmp := node.Val
		node.Val = totalSum
		totalSum -= tmp
		dfs(node.Right)
	}

	dfs(root)
	return root
}
/**
* Definition for a binary tree node.
* class TreeNode(var `val`: Int = 0) {
* var left: TreeNode? = null
* var right: TreeNode? = null
* }
*/
class Solution {
    fun convertBST(root: TreeNode?): TreeNode? {
        // First pass: sum of every key in the tree.
        fun getSum(node: TreeNode?): Int {
            if (node == null) return 0
            return node.`val` + getSum(node.left) + getSum(node.right)
        }

        var totalSum = getSum(root)

        // Second pass: in-order, so totalSum always equals the sum of keys >= node.`val`.
        fun dfs(node: TreeNode?) {
            if (node == null) return
            dfs(node.left)
            val tmp = node.`val`
            node.`val` = totalSum
            totalSum -= tmp
            dfs(node.right)
        }

        dfs(root)
        return root
    }
}
/**
* Definition for a binary tree node.
* public class TreeNode {
* public var val: Int
* public var left: TreeNode?
* public var right: TreeNode?
* public init() { self.val = 0; self.left = nil; self.right = nil; }
* public init(_ val: Int) { self.val = val; self.left = nil; self.right = nil; }
* public init(_ val: Int, _ left: TreeNode?, _ right: TreeNode?) {
* self.val = val
* self.left = left
* self.right = right
* }
* }
*/
class Solution {
    func convertBST(_ root: TreeNode?) -> TreeNode? {
        // First pass: sum of every key in the tree.
        func getSum(_ node: TreeNode?) -> Int {
            guard let node = node else { return 0 }
            return node.val + getSum(node.left) + getSum(node.right)
        }

        var totalSum = getSum(root)

        // Second pass: in-order, so totalSum always equals the sum of keys >= node.val.
        func dfs(_ node: TreeNode?) {
            guard let node = node else { return }
            dfs(node.left)
            let tmp = node.val
            node.val = totalSum
            totalSum -= tmp
            dfs(node.right)
        }

        dfs(root)
        return root
    }
}
Time & Space Complexity
Time complexity: O(n)
Space complexity: O(h) for the recursion stack, where h is the height of the tree. This is O(n) in the worst case (a skewed tree).
2. Depth First Search (One Pass)
Intuition
We can do this in a single pass by traversing the tree in reverse in-order (right, current, left). In a BST, this visits nodes from largest to smallest. We maintain a running sum of all nodes visited so far. When we visit a node, all previously visited nodes have greater values, so we add the current node's value to our running sum and update the node to this sum.
Algorithm
Initialize a running sum variable to 0.
Perform a reverse in-order traversal: first visit the right subtree, then the current node, then the left subtree.
At each node:
Save the node's original value.
Add the running sum to the node's value (this gives the sum of all greater nodes plus itself).
Add the original value to the running sum for future nodes.
Continue until all nodes are processed and return the root.
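A minimal Python sketch of the one-pass reverse in-order traversal follows. It assumes a LeetCode-style TreeNode class. The name convert_bst_one_pass and the build_balanced and inorder_values helpers are hypothetical.

```python
from typing import List, Optional


class TreeNode:
    """Assumed to mirror the LeetCode binary tree node definition."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_balanced(keys: List[int]) -> Optional[TreeNode]:
    """Hypothetical helper: build a height-balanced BST from sorted, distinct keys."""
    if not keys:
        return None
    mid = len(keys) // 2
    return TreeNode(keys[mid], build_balanced(keys[:mid]), build_balanced(keys[mid + 1:]))


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Collect node values in in-order (left, node, right) order."""
    if not root:
        return []
    return inorder_values(root.left) + [root.val] + inorder_values(root.right)


def convert_bst_one_pass(root: Optional[TreeNode]) -> Optional[TreeNode]:
    cur_sum = 0  # sum of all keys visited so far, i.e. all keys larger than the current one

    def dfs(node: Optional[TreeNode]) -> None:
        nonlocal cur_sum
        if not node:
            return
        dfs(node.right)       # larger keys first
        original = node.val
        node.val += cur_sum   # own key plus every greater key
        cur_sum += original   # add the ORIGINAL key for the smaller nodes still to come
        dfs(node.left)

    dfs(root)
    return root
```

Negative keys need no special handling. For keys -3, -1, 2, the result in in-order is -2, 1, 2.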
/**
* Definition for a binary tree node.
* class TreeNode {
* constructor(val = 0, left = null, right = null) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
class Solution {
    /**
     * @param {TreeNode} root
     * @return {TreeNode}
     */
    convertBST(root) {
        let curSum = 0;

        // Reverse in-order: right subtree (larger keys), node, left subtree.
        const dfs = (node) => {
            if (!node) return;
            dfs(node.right);
            const tmp = node.val;
            node.val += curSum;
            curSum += tmp;
            dfs(node.left);
        };

        dfs(root);
        return root;
    }
}
/**
* Definition for a binary tree node.
* public class TreeNode {
* public int val;
* public TreeNode left;
* public TreeNode right;
* public TreeNode(int val=0, TreeNode left=null, TreeNode right=null) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
public class Solution {
    public TreeNode ConvertBST(TreeNode root) {
        int curSum = 0;

        // Reverse in-order: right subtree (larger keys), node, left subtree.
        void Dfs(TreeNode node) {
            if (node == null) return;
            Dfs(node.right);
            int tmp = node.val;
            node.val += curSum;
            curSum += tmp;
            Dfs(node.left);
        }

        Dfs(root);
        return root;
    }
}
/**
* Definition for a binary tree node.
* type TreeNode struct {
* Val int
* Left *TreeNode
* Right *TreeNode
* }
*/
func convertBST(root *TreeNode) *TreeNode {
	curSum := 0

	// Reverse in-order: right subtree (larger keys), node, left subtree.
	var dfs func(node *TreeNode)
	dfs = func(node *TreeNode) {
		if node == nil {
			return
		}
		dfs(node.Right)
		tmp := node.Val
		node.Val += curSum
		curSum += tmp
		dfs(node.Left)
	}

	dfs(root)
	return root
}
/**
* Definition for a binary tree node.
* class TreeNode(var `val`: Int = 0) {
* var left: TreeNode? = null
* var right: TreeNode? = null
* }
*/
class Solution {
    fun convertBST(root: TreeNode?): TreeNode? {
        var curSum = 0

        // Reverse in-order: right subtree (larger keys), node, left subtree.
        fun dfs(node: TreeNode?) {
            if (node == null) return
            dfs(node.right)
            val tmp = node.`val`
            node.`val` += curSum
            curSum += tmp
            dfs(node.left)
        }

        dfs(root)
        return root
    }
}
/**
* Definition for a binary tree node.
* public class TreeNode {
* public var val: Int
* public var left: TreeNode?
* public var right: TreeNode?
* public init() { self.val = 0; self.left = nil; self.right = nil; }
* public init(_ val: Int) { self.val = val; self.left = nil; self.right = nil; }
* public init(_ val: Int, _ left: TreeNode?, _ right: TreeNode?) {
* self.val = val
* self.left = left
* self.right = right
* }
* }
*/
class Solution {
    func convertBST(_ root: TreeNode?) -> TreeNode? {
        var curSum = 0

        // Reverse in-order: right subtree (larger keys), node, left subtree.
        func dfs(_ node: TreeNode?) {
            guard let node = node else { return }
            dfs(node.right)
            let tmp = node.val
            node.val += curSum
            curSum += tmp
            dfs(node.left)
        }

        dfs(root)
        return root
    }
}
Time & Space Complexity
Time complexity: O(n)
Space complexity: O(h) for the recursion stack, where h is the height of the tree. This is O(n) in the worst case (a skewed tree).
3. Iterative DFS
Intuition
The recursive reverse in-order traversal can be converted to an iterative version using a stack. We simulate the call stack explicitly, pushing nodes as we traverse right, then processing them in order. This achieves the same result as the recursive one-pass solution but avoids recursion overhead and potential stack overflow for very deep trees.
Algorithm
Initialize a running sum to 0 and an empty stack.
Start with the root node and traverse as far right as possible, pushing each node onto the stack.
Pop a node from the stack:
Add its value to the running sum.
Update the node's value to the running sum.
Move to its left child and repeat the right-traversal process.
Continue until the stack is empty and there are no more nodes to process.
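The stack-based loop can be sketched in Python as below. The sketch assumes a LeetCode-style TreeNode class. The names convert_bst_iterative, build_balanced, build_right_chain, and inorder_values are hypothetical. The usage deliberately runs on a 10,000-node right-skewed chain. A recursive version would likely exceed CPython's default recursion limit (typically 1000) on that input, while the explicit stack simply grows to 10,000 entries.

```python
from typing import List, Optional


class TreeNode:
    """Assumed to mirror the LeetCode binary tree node definition."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_balanced(keys: List[int]) -> Optional[TreeNode]:
    """Hypothetical helper: build a height-balanced BST from sorted, distinct keys."""
    if not keys:
        return None
    mid = len(keys) // 2
    return TreeNode(keys[mid], build_balanced(keys[:mid]), build_balanced(keys[mid + 1:]))


def build_right_chain(n: int) -> TreeNode:
    """Hypothetical helper: keys 1..n, each node the right child of the previous one."""
    root = node = TreeNode(1)
    for key in range(2, n + 1):
        node.right = TreeNode(key)
        node = node.right
    return root


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Iterative in-order walk, so it also works on very deep trees."""
    out, stack, node = [], [], root
    while stack or node:
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        out.append(node.val)
        node = node.right
    return out


def convert_bst_iterative(root: Optional[TreeNode]) -> Optional[TreeNode]:
    cur_sum = 0
    stack: List[TreeNode] = []
    node = root
    while stack or node:
        while node:               # push the path toward the largest unprocessed key
            stack.append(node)
            node = node.right
        node = stack.pop()        # largest key not yet processed
        cur_sum += node.val       # accumulate first...
        node.val = cur_sum        # ...then assign, so the node includes its own key
        node = node.left          # continue with the smaller keys
    return root
```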
/**
* Definition for a binary tree node.
* class TreeNode {
* constructor(val = 0, left = null, right = null) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
class Solution {
    /**
     * @param {TreeNode} root
     * @return {TreeNode}
     */
    convertBST(root) {
        let curSum = 0;
        const stack = [];
        let node = root;
        while (stack.length || node) {
            // Push the path toward the largest unprocessed key.
            while (node) {
                stack.push(node);
                node = node.right;
            }
            node = stack.pop();
            curSum += node.val;
            node.val = curSum;
            node = node.left;
        }
        return root;
    }
}
/**
* Definition for a binary tree node.
* public class TreeNode {
* public int val;
* public TreeNode left;
* public TreeNode right;
* public TreeNode(int val=0, TreeNode left=null, TreeNode right=null) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
using System.Collections.Generic;

public class Solution {
    public TreeNode ConvertBST(TreeNode root) {
        int curSum = 0;
        Stack<TreeNode> stack = new Stack<TreeNode>();
        TreeNode node = root;
        while (stack.Count > 0 || node != null) {
            // Push the path toward the largest unprocessed key.
            while (node != null) {
                stack.Push(node);
                node = node.right;
            }
            node = stack.Pop();
            curSum += node.val;
            node.val = curSum;
            node = node.left;
        }
        return root;
    }
}
/**
* Definition for a binary tree node.
* type TreeNode struct {
* Val int
* Left *TreeNode
* Right *TreeNode
* }
*/
func convertBST(root *TreeNode) *TreeNode {
	curSum := 0
	stack := []*TreeNode{}
	node := root
	for len(stack) > 0 || node != nil {
		// Push the path toward the largest unprocessed key.
		for node != nil {
			stack = append(stack, node)
			node = node.Right
		}
		node = stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		curSum += node.Val
		node.Val = curSum
		node = node.Left
	}
	return root
}
/**
* Definition for a binary tree node.
* class TreeNode(var `val`: Int = 0) {
* var left: TreeNode? = null
* var right: TreeNode? = null
* }
*/
class Solution {
    fun convertBST(root: TreeNode?): TreeNode? {
        var curSum = 0
        val stack = ArrayDeque<TreeNode>()
        var node = root
        while (stack.isNotEmpty() || node != null) {
            // Push the path toward the largest unprocessed key.
            while (node != null) {
                stack.addLast(node)
                node = node.right
            }
            node = stack.removeLast()
            curSum += node.`val`
            node.`val` = curSum
            node = node.left
        }
        return root
    }
}
/**
* Definition for a binary tree node.
* public class TreeNode {
* public var val: Int
* public var left: TreeNode?
* public var right: TreeNode?
* public init() { self.val = 0; self.left = nil; self.right = nil; }
* public init(_ val: Int) { self.val = val; self.left = nil; self.right = nil; }
* public init(_ val: Int, _ left: TreeNode?, _ right: TreeNode?) {
* self.val = val
* self.left = left
* self.right = right
* }
* }
*/
class Solution {
    func convertBST(_ root: TreeNode?) -> TreeNode? {
        var curSum = 0
        var stack = [TreeNode]()
        var node = root
        while !stack.isEmpty || node != nil {
            // Push the path toward the largest unprocessed key.
            while node != nil {
                stack.append(node!)
                node = node?.right
            }
            node = stack.removeLast()
            curSum += node!.val
            node!.val = curSum
            node = node?.left
        }
        return root
    }
}
Time & Space Complexity
Time complexity: O(n)
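Space complexity: O(h) for the explicit stack, where h is the height of the tree. This is O(n) in the worst case (a skewed tree).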
4. Morris Traversal
Intuition
Morris traversal allows us to traverse a tree without using extra space for a stack or recursion. It works by temporarily modifying the tree structure to create links back to ancestor nodes, then restoring the original structure. For this problem, we use reverse Morris in-order traversal to visit nodes from largest to smallest while maintaining only O(1) extra space.
Algorithm
Initialize a running sum to 0 and start at the root.
While the current node is not null:
If the current node has a right child, find its in-order successor: the leftmost node of its right subtree, which is the node visited immediately before the current node in reverse in-order.
If the predecessor's left child is null, set it to point to the current node (create a thread) and move to the right child.
If the predecessor's left child points to the current node, remove the thread, process the current node (add to running sum and update value), and move to the left child.
If there's no right child, process the current node and move to the left child.
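# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
from typing import Optional


class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        curSum = 0
        cur = root
        while cur:
            if cur.right:
                # Leftmost node of the right subtree: visited just before cur in reverse in-order.
                prev = cur.right
                while prev.left and prev.left is not cur:
                    prev = prev.left
                if not prev.left:
                    # First visit: thread back to cur, then descend into the right subtree.
                    prev.left = cur
                    cur = cur.right
                else:
                    # Second visit: the right subtree is done, so remove the thread and process cur.
                    prev.left = None
                    curSum += cur.val
                    cur.val = curSum
                    cur = cur.left
            else:
                curSum += cur.val
                cur.val = curSum
                cur = cur.left
        return root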
/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode() {}
* TreeNode(int val) { this.val = val; }
* TreeNode(int val, TreeNode left, TreeNode right) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
public class Solution {
    public TreeNode convertBST(TreeNode root) {
        int curSum = 0;
        TreeNode cur = root;
        while (cur != null) {
            if (cur.right != null) {
                TreeNode prev = cur.right;
                while (prev.left != null && prev.left != cur) {
                    prev = prev.left;
                }
                if (prev.left == null) {
                    prev.left = cur;      // create thread back to cur
                    cur = cur.right;
                } else {
                    prev.left = null;     // remove thread, then process cur
                    curSum += cur.val;
                    cur.val = curSum;
                    cur = cur.left;
                }
            } else {
                curSum += cur.val;
                cur.val = curSum;
                cur = cur.left;
            }
        }
        return root;
    }
}
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
public:
    TreeNode* convertBST(TreeNode* root) {
        int curSum = 0;
        TreeNode* cur = root;
        while (cur) {
            if (cur->right) {
                TreeNode* prev = cur->right;
                while (prev->left && prev->left != cur) {
                    prev = prev->left;
                }
                if (!prev->left) {
                    prev->left = cur;      // create thread back to cur
                    cur = cur->right;
                } else {
                    prev->left = nullptr;  // remove thread, then process cur
                    curSum += cur->val;
                    cur->val = curSum;
                    cur = cur->left;
                }
            } else {
                curSum += cur->val;
                cur->val = curSum;
                cur = cur->left;
            }
        }
        return root;
    }
};
/**
* Definition for a binary tree node.
* class TreeNode {
* constructor(val = 0, left = null, right = null) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
class Solution {
    /**
     * @param {TreeNode} root
     * @return {TreeNode}
     */
    convertBST(root) {
        let curSum = 0;
        let cur = root;
        while (cur) {
            if (cur.right) {
                let prev = cur.right;
                while (prev.left && prev.left !== cur) {
                    prev = prev.left;
                }
                if (!prev.left) {
                    prev.left = cur;      // create thread back to cur
                    cur = cur.right;
                } else {
                    prev.left = null;     // remove thread, then process cur
                    curSum += cur.val;
                    cur.val = curSum;
                    cur = cur.left;
                }
            } else {
                curSum += cur.val;
                cur.val = curSum;
                cur = cur.left;
            }
        }
        return root;
    }
}
/**
* Definition for a binary tree node.
* public class TreeNode {
* public int val;
* public TreeNode left;
* public TreeNode right;
* public TreeNode(int val=0, TreeNode left=null, TreeNode right=null) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
public class Solution {
    public TreeNode ConvertBST(TreeNode root) {
        int curSum = 0;
        TreeNode cur = root;
        while (cur != null) {
            if (cur.right != null) {
                TreeNode prev = cur.right;
                while (prev.left != null && prev.left != cur) {
                    prev = prev.left;
                }
                if (prev.left == null) {
                    prev.left = cur;      // create thread back to cur
                    cur = cur.right;
                } else {
                    prev.left = null;     // remove thread, then process cur
                    curSum += cur.val;
                    cur.val = curSum;
                    cur = cur.left;
                }
            } else {
                curSum += cur.val;
                cur.val = curSum;
                cur = cur.left;
            }
        }
        return root;
    }
}
/**
* Definition for a binary tree node.
* type TreeNode struct {
* Val int
* Left *TreeNode
* Right *TreeNode
* }
*/
func convertBST(root *TreeNode) *TreeNode {
	curSum := 0
	cur := root
	for cur != nil {
		if cur.Right != nil {
			prev := cur.Right
			for prev.Left != nil && prev.Left != cur {
				prev = prev.Left
			}
			if prev.Left == nil {
				prev.Left = cur // create thread back to cur
				cur = cur.Right
			} else {
				prev.Left = nil // remove thread, then process cur
				curSum += cur.Val
				cur.Val = curSum
				cur = cur.Left
			}
		} else {
			curSum += cur.Val
			cur.Val = curSum
			cur = cur.Left
		}
	}
	return root
}
/**
* Definition for a binary tree node.
* class TreeNode(var `val`: Int = 0) {
* var left: TreeNode? = null
* var right: TreeNode? = null
* }
*/
class Solution {
    fun convertBST(root: TreeNode?): TreeNode? {
        var curSum = 0
        var cur = root
        while (cur != null) {
            val right = cur.right
            if (right != null) {
                // Leftmost node of the right subtree: visited just before cur in reverse in-order.
                var prev: TreeNode = right
                while (prev.left != null && prev.left !== cur) {
                    prev = prev.left!!
                }
                if (prev.left == null) {
                    prev.left = cur      // create thread back to cur
                    cur = right
                } else {
                    prev.left = null     // remove thread, then process cur
                    curSum += cur.`val`
                    cur.`val` = curSum
                    cur = cur.left
                }
            } else {
                curSum += cur.`val`
                cur.`val` = curSum
                cur = cur.left
            }
        }
        return root
    }
}
/**
* Definition for a binary tree node.
* public class TreeNode {
* public var val: Int
* public var left: TreeNode?
* public var right: TreeNode?
* public init() { self.val = 0; self.left = nil; self.right = nil; }
* public init(_ val: Int) { self.val = val; self.left = nil; self.right = nil; }
* public init(_ val: Int, _ left: TreeNode?, _ right: TreeNode?) {
* self.val = val
* self.left = left
* self.right = right
* }
* }
*/
class Solution {
    func convertBST(_ root: TreeNode?) -> TreeNode? {
        var curSum = 0
        var cur = root
        while cur != nil {
            if cur!.right != nil {
                var prev = cur!.right
                while prev!.left != nil && prev!.left !== cur {
                    prev = prev!.left
                }
                if prev!.left == nil {
                    prev!.left = cur     // create thread back to cur
                    cur = cur!.right
                } else {
                    prev!.left = nil     // remove thread, then process cur
                    curSum += cur!.val
                    cur!.val = curSum
                    cur = cur!.left
                }
            } else {
                curSum += cur!.val
                cur!.val = curSum
                cur = cur!.left
            }
        }
        return root
    }
}
Time & Space Complexity
Time complexity: O(n)
Space complexity: O(1)
Common Pitfalls
Using Standard In-Order Traversal Instead of Reverse In-Order
Standard in-order traversal visits nodes from smallest to largest. For the Greater Tree, we need to accumulate sums from larger values first, requiring reverse in-order (right, node, left).
# Wrong - standard in-order (left, node, right)
dfs(node.left)
node.val += curSum
dfs(node.right)

# Correct - reverse in-order (right, node, left)
dfs(node.right)
node.val += curSum
dfs(node.left)
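To see why the traversal direction matters, the Python sketch below runs the same update in either order. It assumes a LeetCode-style TreeNode class. The names convert_with_order, build_balanced, and inorder_values are hypothetical. With the standard in-order walk, each key ends up with the sum of all keys less than or equal to it instead of greater than or equal.

```python
from typing import List, Optional


class TreeNode:
    """Assumed to mirror the LeetCode binary tree node definition."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_balanced(keys: List[int]) -> Optional[TreeNode]:
    """Hypothetical helper: build a height-balanced BST from sorted, distinct keys."""
    if not keys:
        return None
    mid = len(keys) // 2
    return TreeNode(keys[mid], build_balanced(keys[:mid]), build_balanced(keys[mid + 1:]))


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Collect node values in in-order (left, node, right) order."""
    if not root:
        return []
    return inorder_values(root.left) + [root.val] + inorder_values(root.right)


def convert_with_order(root: Optional[TreeNode], reverse: bool) -> Optional[TreeNode]:
    """reverse=True visits right, node, left (correct); reverse=False visits left, node, right."""
    cur_sum = 0

    def dfs(node: Optional[TreeNode]) -> None:
        nonlocal cur_sum
        if not node:
            return
        first, second = (node.right, node.left) if reverse else (node.left, node.right)
        dfs(first)
        original = node.val
        node.val += cur_sum
        cur_sum += original
        dfs(second)

    dfs(root)
    return root
```

For the keys 1, 2, 3, the reverse order gives 6, 5, 3 (in-order), while the standard order gives 1, 3, 6.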
Incorrect Order of Update and Accumulation
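A subtle bug occurs when the running sum is updated with the node's new value instead of its original value. Once node.val has been overwritten, it already includes curSum, so adding it to curSum again double-counts every greater key. Either save the original value first, or add the original value to curSum before overwriting the node.
# Wrong - node.val already includes curSum, so greater keys are counted twice
node.val += curSum
curSum += node.val

# Correct - save the original value first
tmp = node.val
node.val += curSum
curSum += tmp

# Also correct - accumulate first, then assign (as in the iterative solution)
curSum += node.val
node.val = curSum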
Forgetting to Handle Null Nodes in Recursive Calls
Without a proper base case, the recursion will crash when reaching null children. Always check for null before processing a node.
# Wrong - crashes on null nodes
def dfs(node):
    dfs(node.right)  # AttributeError when node is None

# Correct
def dfs(node):
    if not node:
        return
    dfs(node.right)