Prerequisites
Before attempting this problem, you should be comfortable with:
- Binary Search Tree (BST) Properties - Understanding that inorder traversal of a valid BST produces a sorted sequence
- Inorder Traversal - Both recursive and iterative implementations for visiting nodes in left-root-right order
- Tree Validation - Checking if a binary tree satisfies BST constraints using min/max bounds
- Morris Traversal (Optional) - Space-optimized tree traversal technique using threading for O(1) space solution
1. Brute Force
Intuition
The simplest approach is to try every possible pair of nodes and check if swapping their values produces a valid BST. Since exactly two nodes were swapped, one such pair must restore the tree to its correct state.
For each pair of nodes, we swap their values, validate whether the resulting tree is a valid BST, and either keep the swap (if valid) or revert it (if invalid). While inefficient, this approach guarantees finding the solution by exhaustively checking all possibilities.
Algorithm
- Define a helper function `isBST` that validates whether a tree is a valid BST using BFS with min/max bounds for each node.
- Use a nested DFS approach: the outer DFS (`dfs`) iterates through each node as a potential first swap candidate.
- For each first candidate, the inner DFS (`dfs1`) iterates through every other node as the second swap candidate.
- For each pair, swap their values, check if the tree is now a valid BST, and if so, return `true`. Otherwise, swap back and continue.
- Once a valid swap is found, the tree is corrected in place.
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        """
        Do not return anything, modify root in-place instead.

        Brute force: swap the values of candidate node pairs one pair at a
        time and keep the first swap after which the tree is a valid BST.
        """

        def isBST(node):
            # Breadth-first check in which every queued node carries the open
            # interval (left, right) its value has to fall into.
            if not node:
                return True
            q = deque([(node, float("-inf"), float("inf"))])
            while q:
                cur, left, right = q.popleft()
                if not (left < cur.val < right):
                    return False
                if cur.left:
                    q.append((cur.left, left, cur.val))
                if cur.right:
                    q.append((cur.right, cur.val, right))
            return True

        def dfs1(node1, node2):
            # Pair node1 with node2, then with the nodes below node2.
            # The recursion stops at node1 itself, so node1's descendants are
            # not paired here; they meet node1 once they take the node1 role
            # in dfs().
            if not node2 or node1 is node2:
                return False
            node1.val, node2.val = node2.val, node1.val
            if isBST(root):
                return True
            # The swap did not help: undo it before trying the next pair.
            node1.val, node2.val = node2.val, node1.val
            return dfs1(node1, node2.left) or dfs1(node1, node2.right)

        def dfs(node):
            # Offer every node of the tree as the first element of a pair.
            if not node:
                return False
            if dfs1(node, root):
                return True
            return dfs(node.left) or dfs(node.right)

        # The repair happens in place; in line with the "-> None" annotation
        # nothing is returned.
        dfs(root)
public class Solution {
    /**
     * Repairs, in place, a BST in which the values of two nodes were
     * exchanged. Candidate pairs are swapped one at a time and the first
     * swap that makes the tree a valid BST is kept.
     *
     * @param root root of the tree to repair
     */
    public void recoverTree(TreeNode root) {
        dfs(root, root);
    }

    /**
     * Offers every node below (and including) {@code node1} as the first
     * element of a pair.
     *
     * @return true once a swap that fixes the tree has been applied
     */
    private boolean dfs(TreeNode node1, TreeNode root) {
        if (node1 == null) {
            return false;
        }
        return dfs1(node1, root, root)
                || dfs(node1.left, root)
                || dfs(node1.right, root);
    }

    /**
     * Pairs {@code node1} with {@code node2} and then with the nodes below
     * {@code node2}; the walk does not continue past {@code node1} itself.
     *
     * @return true if one of the tried swaps produced a valid BST
     */
    private boolean dfs1(TreeNode node1, TreeNode node2, TreeNode root) {
        if (node2 == null || node2 == node1) {
            return false;
        }
        swap(node1, node2);
        if (!isBST(root)) {
            // The attempt failed: restore both values before going deeper.
            swap(node1, node2);
            return dfs1(node1, node2.left, root) || dfs1(node1, node2.right, root);
        }
        return true;
    }

    /**
     * Level-order BST validation. Each queue entry is
     * {node, exclusive lower bound, exclusive upper bound}; the bounds are
     * kept as long values.
     */
    private boolean isBST(TreeNode node) {
        Queue<Object[]> pending = new LinkedList<>();
        pending.offer(new Object[]{node, Long.MIN_VALUE, Long.MAX_VALUE});
        while (!pending.isEmpty()) {
            Object[] entry = pending.poll();
            TreeNode current = (TreeNode) entry[0];
            long low = (long) entry[1];
            long high = (long) entry[2];
            if (current == null) {
                continue;
            }
            boolean withinBounds = low < current.val && current.val < high;
            if (!withinBounds) {
                return false;
            }
            pending.offer(new Object[]{current.left, low, (long) current.val});
            pending.offer(new Object[]{current.right, (long) current.val, high});
        }
        return true;
    }

    /** Exchanges the values stored in the two nodes. */
    private void swap(TreeNode a, TreeNode b) {
        int held = b.val;
        b.val = a.val;
        a.val = held;
    }
}
class Solution {
public:
    // Repairs, in place, a BST in which the values of two nodes were
    // exchanged: candidate pairs are swapped one at a time and the first
    // swap that yields a valid BST is kept.
    void recoverTree(TreeNode* root) {
        dfs(root, root, root);
    }

    // Offers every node below (and including) node1 as the first element of
    // a pair; node2 is where the search for the partner starts.
    // Returns true once a swap that fixes the tree has been applied.
    bool dfs(TreeNode* node1, TreeNode* node2, TreeNode* root) {
        if (node1 == nullptr) {
            return false;
        }
        return dfs1(node1, node2, root)
            || dfs(node1->left, node2, root)
            || dfs(node1->right, node2, root);
    }

    // Pairs node1 with node2 and then with the nodes below node2; the walk
    // does not continue past node1 itself.
    bool dfs1(TreeNode* node1, TreeNode* node2, TreeNode* root) {
        if (node2 == nullptr || node2 == node1) {
            return false;
        }
        std::swap(node1->val, node2->val);
        if (!isBST(root)) {
            // The attempt failed: restore both values before going deeper.
            std::swap(node1->val, node2->val);
            return dfs1(node1, node2->left, root) || dfs1(node1, node2->right, root);
        }
        return true;
    }

    // True when an inorder walk of the tree is strictly increasing.
    bool isBST(TreeNode* node) {
        TreeNode* prev = nullptr;
        return inorder(node, prev);
    }

    // Inorder walk that compares each node with the previously visited one
    // (carried in prev) and stops at the first non-increasing step.
    bool inorder(TreeNode* node, TreeNode*& prev) {
        if (node == nullptr) {
            return true;
        }
        const bool leftSorted = inorder(node->left, prev);
        if (!leftSorted) {
            return false;
        }
        if (prev != nullptr && node->val <= prev->val) {
            return false;
        }
        prev = node;
        return inorder(node->right, prev);
    }
};
class Solution {
    /**
     * Repairs, in place, a BST in which the values of two nodes were
     * exchanged: candidate pairs are swapped one at a time and the first
     * swap that yields a valid BST is kept.
     *
     * @param {TreeNode} root
     * @return {void}
     */
    recoverTree(root) {
        // Level-order validation; every entry is [node, low, high] with the
        // open interval (low, high) the node's value must fall into.
        function isBST(start) {
            const queue = [[start, -Infinity, Infinity]];
            while (queue.length) {
                const [cur, low, high] = queue.shift();
                if (!cur) continue;
                const inside = low < cur.val && cur.val < high;
                if (!inside) return false;
                queue.push([cur.left, low, cur.val]);
                queue.push([cur.right, cur.val, high]);
            }
            return true;
        }

        // Exchanges the values stored in the two nodes.
        function exchange(a, b) {
            const held = a.val;
            a.val = b.val;
            b.val = held;
        }

        // Pairs node1 with node2 and then with the nodes below node2; the
        // walk does not continue past node1 itself.
        function dfs1(node1, node2) {
            if (!node2 || node1 === node2) return false;
            exchange(node1, node2);
            if (isBST(root)) return true;
            exchange(node1, node2); // undo the failed attempt
            return dfs1(node1, node2.left) || dfs1(node1, node2.right);
        }

        // Offers every node as the first element of a pair.
        function dfs(node) {
            if (!node) return false;
            return dfs1(node, root) || dfs(node.left) || dfs(node.right);
        }

        dfs(root);
    }
}
public class Solution {
    /// <summary>
    /// Repairs, in place, a BST in which the values of two nodes were
    /// exchanged: candidate pairs are swapped one at a time and the first
    /// swap that yields a valid BST is kept.
    /// </summary>
    public void RecoverTree(TreeNode root) {
        Dfs(root, root);
    }

    /// <summary>Offers every node below (and including) node1 as the first element of a pair.</summary>
    /// <returns>true once a swap that fixes the tree has been applied.</returns>
    private bool Dfs(TreeNode node1, TreeNode root) {
        if (node1 == null) {
            return false;
        }
        return Dfs1(node1, root, root)
            || Dfs(node1.left, root)
            || Dfs(node1.right, root);
    }

    /// <summary>
    /// Pairs node1 with node2 and then with the nodes below node2; the walk
    /// does not continue past node1 itself.
    /// </summary>
    private bool Dfs1(TreeNode node1, TreeNode node2, TreeNode root) {
        if (node2 == null || node2 == node1) {
            return false;
        }
        Swap(node1, node2);
        if (!IsBST(root)) {
            // The attempt failed: restore both values before going deeper.
            Swap(node1, node2);
            return Dfs1(node1, node2.left, root) || Dfs1(node1, node2.right, root);
        }
        return true;
    }

    /// <summary>
    /// Level-order BST validation; each entry carries the exclusive lower
    /// and upper bound for its node as long values.
    /// </summary>
    private bool IsBST(TreeNode node) {
        var pending = new Queue<(TreeNode node, long min, long max)>();
        pending.Enqueue((node, long.MinValue, long.MaxValue));
        while (pending.Count > 0) {
            var entry = pending.Dequeue();
            TreeNode cur = entry.node;
            if (cur == null) {
                continue;
            }
            long val = cur.val;
            bool inside = entry.min < val && val < entry.max;
            if (!inside) {
                return false;
            }
            pending.Enqueue((cur.left, entry.min, val));
            pending.Enqueue((cur.right, val, entry.max));
        }
        return true;
    }

    /// <summary>Exchanges the values stored in the two nodes.</summary>
    private void Swap(TreeNode a, TreeNode b) {
        int held = b.val;
        b.val = a.val;
        a.val = held;
    }
}
// recoverTree repairs, in place, a BST in which the values of two nodes were
// exchanged: candidate pairs are swapped one at a time and the first swap
// that yields a valid BST is kept.
func recoverTree(root *TreeNode) {
	// bounded is a node together with the exclusive bounds for its value.
	type bounded struct {
		node      *TreeNode
		low, high int64
	}

	// isBST validates the tree level by level.
	isBST := func(start *TreeNode) bool {
		queue := []bounded{{start, math.MinInt64, math.MaxInt64}}
		for head := 0; head < len(queue); head++ {
			entry := queue[head]
			if entry.node == nil {
				continue
			}
			val := int64(entry.node.Val)
			if !(entry.low < val && val < entry.high) {
				return false
			}
			queue = append(queue,
				bounded{entry.node.Left, entry.low, val},
				bounded{entry.node.Right, val, entry.high})
		}
		return true
	}

	// dfs1 pairs node1 with node2 and then with the nodes below node2; the
	// walk does not continue past node1 itself.
	var dfs1 func(node1, node2 *TreeNode) bool
	dfs1 = func(node1, node2 *TreeNode) bool {
		if node2 == nil || node2 == node1 {
			return false
		}
		node1.Val, node2.Val = node2.Val, node1.Val
		if !isBST(root) {
			// The attempt failed: restore both values before going deeper.
			node1.Val, node2.Val = node2.Val, node1.Val
			return dfs1(node1, node2.Left) || dfs1(node1, node2.Right)
		}
		return true
	}

	// dfs offers every node as the first element of a pair.
	var dfs func(node1 *TreeNode) bool
	dfs = func(node1 *TreeNode) bool {
		if node1 == nil {
			return false
		}
		return dfs1(node1, root) || dfs(node1.Left) || dfs(node1.Right)
	}

	dfs(root)
}
class Solution {
    /**
     * Repairs, in place, a BST in which the values of two nodes were
     * exchanged: candidate pairs are swapped one at a time and the first
     * swap that yields a valid BST is kept.
     */
    fun recoverTree(root: TreeNode?): Unit {
        // Level-order validation; each entry carries the exclusive lower and
        // upper bound for its node as Long values.
        fun isBST(node: TreeNode?): Boolean {
            val pending = ArrayDeque<Triple<TreeNode?, Long, Long>>()
            pending.add(Triple(node, Long.MIN_VALUE, Long.MAX_VALUE))
            while (pending.isNotEmpty()) {
                val (cur, low, high) = pending.removeFirst()
                if (cur == null) continue
                val v = cur.`val`.toLong()
                if (!(low < v && v < high)) return false
                pending.add(Triple(cur.left, low, v))
                pending.add(Triple(cur.right, v, high))
            }
            return true
        }

        // Exchanges the values stored in the two nodes.
        fun exchange(a: TreeNode, b: TreeNode) {
            val held = a.`val`
            a.`val` = b.`val`
            b.`val` = held
        }

        // Pairs node1 with node2 and then with the nodes below node2; the
        // walk does not continue past node1 itself.
        fun dfs1(node1: TreeNode, node2: TreeNode?): Boolean {
            if (node2 == null || node1 === node2) return false
            exchange(node1, node2)
            if (isBST(root)) return true
            exchange(node1, node2) // undo the failed attempt
            return dfs1(node1, node2.left) || dfs1(node1, node2.right)
        }

        // Offers every node as the first element of a pair.
        fun dfs(node1: TreeNode?): Boolean {
            if (node1 == null) return false
            return dfs1(node1, root) || dfs(node1.left) || dfs(node1.right)
        }

        dfs(root)
    }
}
class Solution {
    /// Repairs, in place, a BST in which the values of two nodes were
    /// exchanged: candidate pairs are swapped one at a time and the first
    /// swap that yields a valid BST is kept.
    func recoverTree(_ root: TreeNode?) {
        // Level-order validation; each entry carries the exclusive lower and
        // upper bound for its node as Int64 values.
        func isBST(_ node: TreeNode?) -> Bool {
            var pending: [(TreeNode?, Int64, Int64)] = [(node, Int64.min, Int64.max)]
            while !pending.isEmpty {
                let (candidate, low, high) = pending.removeFirst()
                guard let cur = candidate else { continue }
                let v = Int64(cur.val)
                guard low < v && v < high else { return false }
                pending.append((cur.left, low, v))
                pending.append((cur.right, v, high))
            }
            return true
        }

        // Exchanges the values stored in the two nodes.
        func exchange(_ a: TreeNode, _ b: TreeNode) {
            let held = a.val
            a.val = b.val
            b.val = held
        }

        // Pairs node1 with node2 and then with the nodes below node2; the
        // walk does not continue past node1 itself.
        func dfs1(_ node1: TreeNode, _ node2: TreeNode?) -> Bool {
            guard let node2 = node2, node1 !== node2 else { return false }
            exchange(node1, node2)
            if isBST(root) { return true }
            exchange(node1, node2) // undo the failed attempt
            return dfs1(node1, node2.left) || dfs1(node1, node2.right)
        }

        // Offers every node as the first element of a pair.
        func dfs(_ node1: TreeNode?) -> Bool {
            guard let node1 = node1 else { return false }
            return dfs1(node1, root) || dfs(node1.left) || dfs(node1.right)
        }

        _ = dfs(root)
    }
}
impl Solution {
    /// Repairs, in place, a BST in which the values of two nodes were
    /// exchanged: all nodes are gathered in inorder, every pair is swapped
    /// in turn, and the first swap that yields a valid BST is kept.
    pub fn recover_tree(root: &mut Option<Rc<RefCell<TreeNode>>>) {
        /// Appends the nodes of the subtree to `nodes` in inorder.
        fn collect(node: &Option<Rc<RefCell<TreeNode>>>, nodes: &mut Vec<Rc<RefCell<TreeNode>>>) {
            let n = match node {
                Some(n) => n,
                None => return,
            };
            collect(&n.borrow().left, nodes);
            nodes.push(Rc::clone(n));
            collect(&n.borrow().right, nodes);
        }

        /// True when every value in the subtree lies strictly between `lo`
        /// and `hi` and the BST ordering holds below it.
        fn is_bst(node: &Option<Rc<RefCell<TreeNode>>>, lo: i64, hi: i64) -> bool {
            let n = match node {
                Some(n) => n.borrow(),
                None => return true,
            };
            let v = n.val as i64;
            lo < v && v < hi && is_bst(&n.left, lo, v) && is_bst(&n.right, v, hi)
        }

        let mut nodes = Vec::new();
        collect(root, &mut nodes);

        // Exchanges the values of two distinct nodes.
        let exchange = |a: &Rc<RefCell<TreeNode>>, b: &Rc<RefCell<TreeNode>>| {
            let mut first = a.borrow_mut();
            let mut second = b.borrow_mut();
            std::mem::swap(&mut first.val, &mut second.val);
        };

        for (i, first) in nodes.iter().enumerate() {
            for second in &nodes[i + 1..] {
                exchange(first, second);
                if is_bst(root, i64::MIN, i64::MAX) {
                    return;
                }
                // The attempt failed: restore both values.
                exchange(first, second);
            }
        }
    }
}
Time & Space Complexity
- Time complexity: O(n^2)
- Space complexity: O(n)
2. Inorder Traversal
Intuition
A key property of BSTs is that an inorder traversal produces values in sorted order. When two nodes are swapped incorrectly, this sorted sequence will have one or two "inversions" where a value is greater than the next value.
If the swapped nodes are adjacent in the inorder sequence, there will be exactly one inversion. If they are not adjacent, there will be two inversions. In the first inversion, the larger (out-of-place) node is the first swapped node. In the second inversion (or the same one if only one exists), the smaller node is the second swapped node.
By collecting all nodes during inorder traversal and then scanning for these inversions, we can identify the two swapped nodes and swap their values back.
Algorithm
- Perform an inorder traversal of the tree and store all nodes in a list.
- Scan the list to find inversions (where `arr[i].val > arr[i+1].val`).
- On the first inversion, mark `node1 = arr[i]` and `node2 = arr[i+1]`.
- If a second inversion is found, update `node2 = arr[i+1]` (the first swap target is already correct).
- Swap the values of `node1` and `node2` to restore the BST.
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        """
        Do not return anything, modify root in-place instead.

        Collects the nodes in inorder, finds the one or two places where a
        value is greater than its successor, and swaps the two misplaced
        values. An empty or already sorted tree is left untouched.
        """
        arr = []

        def inorder(node):
            if not node:
                return
            inorder(node.left)
            arr.append(node)
            inorder(node.right)

        inorder(root)

        node1, node2 = None, None
        for before, after in zip(arr, arr[1:]):
            if before.val > after.val:
                # The smaller element of the latest inversion is the second
                # misplaced node; the larger element of the first inversion
                # is the first one.
                node2 = after
                if node1 is None:
                    node1 = before
                else:
                    break

        # No inversion means there is nothing to swap (node2 is only ever set
        # together with node1).
        if node1 is None:
            return
        node1.val, node2.val = node2.val, node1.val
public class Solution {
    /**
     * Restores, in place, a BST in which the values of two nodes were
     * swapped. The nodes are collected in inorder, the one or two inversions
     * are located, and the two misplaced values are exchanged. An empty or
     * already sorted tree is left untouched.
     *
     * @param root root of the tree to repair (may be null)
     */
    public void recoverTree(TreeNode root) {
        List<TreeNode> arr = new ArrayList<>();
        inorder(root, arr);

        TreeNode node1 = null, node2 = null;
        for (int i = 0; i + 1 < arr.size(); i++) {
            if (arr.get(i).val > arr.get(i + 1).val) {
                // Smaller element of the latest inversion.
                node2 = arr.get(i + 1);
                if (node1 != null) {
                    break; // second inversion: both nodes are known
                }
                // Larger element of the first inversion.
                node1 = arr.get(i);
            }
        }

        // Nothing is inverted: avoid dereferencing null.
        if (node1 == null || node2 == null) {
            return;
        }
        int temp = node1.val;
        node1.val = node2.val;
        node2.val = temp;
    }

    /** Appends the nodes of the subtree to {@code arr} in inorder. */
    private void inorder(TreeNode node, List<TreeNode> arr) {
        if (node == null) return;
        inorder(node.left, arr);
        arr.add(node);
        inorder(node.right, arr);
    }
}
class Solution {
public:
void recoverTree(TreeNode* root) {
vector<TreeNode*> arr;
inorder(root, arr);
TreeNode* node1 = nullptr;
TreeNode* node2 = nullptr;
for (int i = 0; i < arr.size() - 1; i++) {
if (arr[i]->val > arr[i + 1]->val) {
node2 = arr[i + 1];
if (!node1) node1 = arr[i];
else break;
}
}
swap(node1->val, node2->val);
}
void inorder(TreeNode* node, vector<TreeNode*>& arr) {
if (!node) return;
inorder(node->left, arr);
arr.push_back(node);
inorder(node->right, arr);
}
};
class Solution {
    /**
     * Restores, in place, a BST in which the values of two nodes were
     * swapped. The nodes are collected in inorder, the one or two inversions
     * are located, and the two misplaced values are exchanged. An empty or
     * already sorted tree is left untouched.
     *
     * @param {TreeNode} root
     * @return {void}
     */
    recoverTree(root) {
        const arr = [];
        const inorder = (node) => {
            if (!node) return;
            inorder(node.left);
            arr.push(node);
            inorder(node.right);
        };
        inorder(root);

        let node1 = null;
        let node2 = null;
        for (let i = 0; i + 1 < arr.length; i++) {
            if (arr[i].val > arr[i + 1].val) {
                // Smaller element of the latest inversion.
                node2 = arr[i + 1];
                if (node1 !== null) break; // second inversion: both nodes known
                // Larger element of the first inversion.
                node1 = arr[i];
            }
        }

        // Nothing is inverted: avoid reading a property of null.
        if (node1 === null || node2 === null) return;
        [node1.val, node2.val] = [node2.val, node1.val];
    }
}
public class Solution {
    /// <summary>
    /// Restores, in place, a BST in which the values of two nodes were
    /// swapped. The nodes are collected in inorder, the one or two
    /// inversions are located, and the two misplaced values are exchanged.
    /// An empty or already sorted tree is left untouched.
    /// </summary>
    /// <param name="root">Root of the tree to repair (may be null).</param>
    public void RecoverTree(TreeNode root) {
        var arr = new List<TreeNode>();
        Inorder(root, arr);

        TreeNode node1 = null, node2 = null;
        for (int i = 0; i + 1 < arr.Count; i++) {
            if (arr[i].val > arr[i + 1].val) {
                // Smaller element of the latest inversion.
                node2 = arr[i + 1];
                if (node1 != null) {
                    break; // second inversion: both nodes are known
                }
                // Larger element of the first inversion.
                node1 = arr[i];
            }
        }

        // Nothing is inverted: avoid dereferencing null.
        if (node1 == null || node2 == null) {
            return;
        }
        int temp = node1.val;
        node1.val = node2.val;
        node2.val = temp;
    }

    /// <summary>Appends the nodes of the subtree to arr in inorder.</summary>
    private void Inorder(TreeNode node, List<TreeNode> arr) {
        if (node == null) return;
        Inorder(node.left, arr);
        arr.Add(node);
        Inorder(node.right, arr);
    }
}
// recoverTree restores, in place, a BST in which the values of two nodes
// were swapped. The nodes are collected in inorder, the one or two inversions
// are located, and the two misplaced values are exchanged. An empty or
// already sorted tree is left untouched.
func recoverTree(root *TreeNode) {
	var arr []*TreeNode
	var inorder func(node *TreeNode)
	inorder = func(node *TreeNode) {
		if node == nil {
			return
		}
		inorder(node.Left)
		arr = append(arr, node)
		inorder(node.Right)
	}
	inorder(root)

	var node1, node2 *TreeNode
	for i := 0; i+1 < len(arr); i++ {
		if arr[i].Val > arr[i+1].Val {
			// Smaller element of the latest inversion.
			node2 = arr[i+1]
			if node1 != nil {
				break // second inversion: both nodes are known
			}
			// Larger element of the first inversion.
			node1 = arr[i]
		}
	}

	// Nothing is inverted: avoid a nil pointer dereference.
	if node1 == nil || node2 == nil {
		return
	}
	node1.Val, node2.Val = node2.Val, node1.Val
}
class Solution {
    /**
     * Restores, in place, a BST in which the values of two nodes were
     * swapped. The nodes are collected in inorder, the one or two inversions
     * are located, and the two misplaced values are exchanged. An empty or
     * already sorted tree is left untouched.
     */
    fun recoverTree(root: TreeNode?): Unit {
        val arr = mutableListOf<TreeNode>()
        fun inorder(node: TreeNode?) {
            if (node == null) return
            inorder(node.left)
            arr.add(node)
            inorder(node.right)
        }
        inorder(root)

        var node1: TreeNode? = null
        var node2: TreeNode? = null
        for (i in 0 until arr.size - 1) {
            if (arr[i].`val` > arr[i + 1].`val`) {
                // Smaller element of the latest inversion.
                node2 = arr[i + 1]
                if (node1 != null) break // second inversion: both nodes known
                // Larger element of the first inversion.
                node1 = arr[i]
            }
        }

        // Nothing is inverted: return instead of failing on a null assertion.
        val first = node1 ?: return
        val second = node2 ?: return
        val temp = first.`val`
        first.`val` = second.`val`
        second.`val` = temp
    }
}
class Solution {
    /// Restores, in place, a BST in which the values of two nodes were
    /// swapped. The nodes are collected in inorder, the one or two inversions
    /// are located, and the two misplaced values are exchanged. An empty or
    /// already sorted tree is left untouched.
    func recoverTree(_ root: TreeNode?) {
        var arr = [TreeNode]()
        func inorder(_ node: TreeNode?) {
            guard let node = node else { return }
            inorder(node.left)
            arr.append(node)
            inorder(node.right)
        }
        inorder(root)

        var node1: TreeNode? = nil
        var node2: TreeNode? = nil
        // Pairing each element with its successor yields no pairs for an
        // empty array, whereas the range 0..<(arr.count - 1) would be invalid.
        for (before, after) in zip(arr, arr.dropFirst()) {
            if before.val > after.val {
                // Smaller element of the latest inversion.
                node2 = after
                if node1 != nil { break } // second inversion: both nodes known
                // Larger element of the first inversion.
                node1 = before
            }
        }

        // Nothing is inverted: return instead of force-unwrapping nil.
        guard let first = node1, let second = node2 else { return }
        let temp = first.val
        first.val = second.val
        second.val = temp
    }
}
impl Solution {
    /// Restores, in place, a BST in which the values of two nodes were
    /// swapped. The nodes are collected in inorder, the one or two inversions
    /// are located, and the two misplaced values are exchanged. An empty or
    /// already sorted tree is left untouched.
    pub fn recover_tree(root: &mut Option<Rc<RefCell<TreeNode>>>) {
        /// Appends the nodes of the subtree to `arr` in inorder.
        fn inorder(node: &Option<Rc<RefCell<TreeNode>>>, arr: &mut Vec<Rc<RefCell<TreeNode>>>) {
            if let Some(n) = node {
                inorder(&n.borrow().left, arr);
                arr.push(Rc::clone(n));
                inorder(&n.borrow().right, arr);
            }
        }

        let mut arr = Vec::new();
        inorder(root, &mut arr);

        // Indices into `arr` of the two misplaced nodes.
        let mut node1: Option<usize> = None;
        let mut node2: Option<usize> = None;
        // Starting at 1 and looking back avoids computing `arr.len() - 1`,
        // which underflows for an empty vector.
        for i in 1..arr.len() {
            if arr[i - 1].borrow().val > arr[i].borrow().val {
                // Smaller element of the latest inversion.
                node2 = Some(i);
                if node1.is_none() {
                    // Larger element of the first inversion.
                    node1 = Some(i - 1);
                } else {
                    break;
                }
            }
        }

        // Swap only when an inversion was found instead of unwrapping.
        if let (Some(i), Some(j)) = (node1, node2) {
            let mut a = arr[i].borrow_mut();
            let mut b = arr[j].borrow_mut();
            std::mem::swap(&mut a.val, &mut b.val);
        }
    }
}
Time & Space Complexity
- Time complexity: O(n)
- Space complexity: O(n)
3. Iterative Inorder Traversal
Intuition
This approach uses the same logic as the recursive inorder traversal but implements it iteratively using an explicit stack. The advantage is that we can detect inversions on the fly during traversal rather than collecting all nodes first.
By keeping track of the previously visited node, we can immediately detect when the current node's value is less than the previous node's value, signaling an inversion. This allows us to identify the swapped nodes in a single pass through the tree.
Algorithm
- Initialize an empty stack and pointers for `node1`, `node2`, `prev`, and `curr`.
- Traverse the tree iteratively using the stack for inorder traversal.
- For each visited node, compare it with `prev`. If `prev.val > curr.val`, an inversion is found.
- On the first inversion, set `node1 = prev` and `node2 = curr`.
- On the second inversion, update `node2 = curr` and break early.
- After traversal, swap the values of `node1` and `node2`.
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        """
        Do not return anything, modify root in-place instead.

        Runs an inorder traversal with an explicit stack, compares every
        node with the previously visited one to find the one or two
        inversions, and swaps the two misplaced values. An empty or already
        sorted tree is left untouched.
        """
        stack = []
        node1 = node2 = prev = None
        curr = root

        while stack or curr is not None:
            # Descend to the leftmost node, remembering the path.
            while curr is not None:
                stack.append(curr)
                curr = curr.left
            curr = stack.pop()

            if prev is not None and prev.val > curr.val:
                # Smaller element of the latest inversion.
                node2 = curr
                if node1 is None:
                    # Larger element of the first inversion.
                    node1 = prev
                else:
                    break  # second inversion: both nodes are known

            prev = curr
            curr = curr.right

        # No inversion means there is nothing to swap (node2 is only ever set
        # together with node1).
        if node1 is None:
            return
        node1.val, node2.val = node2.val, node1.val
public class Solution {
    /**
     * Restores, in place, a BST in which the values of two nodes were
     * swapped. An inorder traversal with an explicit stack compares every
     * node with the previously visited one to find the one or two
     * inversions; the two misplaced values are then exchanged. An empty or
     * already sorted tree is left untouched.
     *
     * @param root root of the tree to repair (may be null)
     */
    public void recoverTree(TreeNode root) {
        Stack<TreeNode> stack = new Stack<>();
        TreeNode node1 = null, node2 = null, prev = null, curr = root;

        while (!stack.isEmpty() || curr != null) {
            // Descend to the leftmost node, remembering the path.
            while (curr != null) {
                stack.push(curr);
                curr = curr.left;
            }
            curr = stack.pop();

            if (prev != null && prev.val > curr.val) {
                // Smaller element of the latest inversion.
                node2 = curr;
                if (node1 != null) {
                    break; // second inversion: both nodes are known
                }
                // Larger element of the first inversion.
                node1 = prev;
            }

            prev = curr;
            curr = curr.right;
        }

        // Nothing is inverted: avoid dereferencing null.
        if (node1 == null || node2 == null) {
            return;
        }
        int temp = node1.val;
        node1.val = node2.val;
        node2.val = temp;
    }
}
class Solution {
public:
void recoverTree(TreeNode* root) {
stack<TreeNode*> stack;
TreeNode *node1 = nullptr, *node2 = nullptr, *prev = nullptr, *curr = root;
while (!stack.empty() || curr) {
while (curr) {
stack.push(curr);
curr = curr->left;
}
curr = stack.top(); stack.pop();
if (prev && prev->val > curr->val) {
node2 = curr;
if (!node1) node1 = prev;
else break;
}
prev = curr;
curr = curr->right;
}
swap(node1->val, node2->val);
}
};
class Solution {
    /**
     * Restores, in place, a BST in which the values of two nodes were
     * swapped. An inorder traversal with an explicit stack compares every
     * node with the previously visited one to find the one or two
     * inversions; the two misplaced values are then exchanged. An empty or
     * already sorted tree is left untouched.
     *
     * @param {TreeNode} root
     * @return {void}
     */
    recoverTree(root) {
        const stack = [];
        let node1 = null;
        let node2 = null;
        let prev = null;
        let curr = root;

        while (stack.length > 0 || curr) {
            // Descend to the leftmost node, remembering the path.
            while (curr) {
                stack.push(curr);
                curr = curr.left;
            }
            curr = stack.pop();

            if (prev && prev.val > curr.val) {
                // Smaller element of the latest inversion.
                node2 = curr;
                if (node1) break; // second inversion: both nodes are known
                // Larger element of the first inversion.
                node1 = prev;
            }

            prev = curr;
            curr = curr.right;
        }

        // Nothing is inverted: avoid reading a property of null.
        if (!node1 || !node2) return;
        [node1.val, node2.val] = [node2.val, node1.val];
    }
}
public class Solution {
    /// <summary>
    /// Restores, in place, a BST in which the values of two nodes were
    /// swapped. An inorder traversal with an explicit stack compares every
    /// node with the previously visited one to find the one or two
    /// inversions; the two misplaced values are then exchanged. An empty or
    /// already sorted tree is left untouched.
    /// </summary>
    /// <param name="root">Root of the tree to repair (may be null).</param>
    public void RecoverTree(TreeNode root) {
        var stack = new Stack<TreeNode>();
        TreeNode node1 = null, node2 = null, prev = null, curr = root;

        while (stack.Count > 0 || curr != null) {
            // Descend to the leftmost node, remembering the path.
            while (curr != null) {
                stack.Push(curr);
                curr = curr.left;
            }
            curr = stack.Pop();

            if (prev != null && prev.val > curr.val) {
                // Smaller element of the latest inversion.
                node2 = curr;
                if (node1 != null) {
                    break; // second inversion: both nodes are known
                }
                // Larger element of the first inversion.
                node1 = prev;
            }

            prev = curr;
            curr = curr.right;
        }

        // Nothing is inverted: avoid dereferencing null.
        if (node1 == null || node2 == null) {
            return;
        }
        int temp = node1.val;
        node1.val = node2.val;
        node2.val = temp;
    }
}
// recoverTree restores, in place, a BST in which the values of two nodes
// were swapped. An inorder traversal with an explicit stack compares every
// node with the previously visited one to find the one or two inversions;
// the two misplaced values are then exchanged. An empty or already sorted
// tree is left untouched.
func recoverTree(root *TreeNode) {
	var stack []*TreeNode
	var node1, node2, prev *TreeNode
	curr := root

	for len(stack) > 0 || curr != nil {
		// Descend to the leftmost node, remembering the path.
		for ; curr != nil; curr = curr.Left {
			stack = append(stack, curr)
		}
		top := len(stack) - 1
		curr, stack = stack[top], stack[:top]

		if prev != nil && prev.Val > curr.Val {
			// Smaller element of the latest inversion.
			node2 = curr
			if node1 != nil {
				break // second inversion: both nodes are known
			}
			// Larger element of the first inversion.
			node1 = prev
		}

		prev = curr
		curr = curr.Right
	}

	// Nothing is inverted: avoid a nil pointer dereference.
	if node1 == nil || node2 == nil {
		return
	}
	node1.Val, node2.Val = node2.Val, node1.Val
}
class Solution {
    /**
     * Restores, in place, a BST in which the values of two nodes were
     * swapped. An inorder traversal with an explicit stack compares every
     * node with the previously visited one to find the one or two
     * inversions; the two misplaced values are then exchanged. An empty or
     * already sorted tree is left untouched.
     */
    fun recoverTree(root: TreeNode?): Unit {
        val stack = ArrayDeque<TreeNode>()
        var node1: TreeNode? = null
        var node2: TreeNode? = null
        var prev: TreeNode? = null
        var curr = root

        while (stack.isNotEmpty() || curr != null) {
            // Descend to the leftmost node, remembering the path.
            while (curr != null) {
                stack.addLast(curr)
                curr = curr.left
            }
            val node = stack.removeLast()

            val before = prev
            if (before != null && before.`val` > node.`val`) {
                // Smaller element of the latest inversion.
                node2 = node
                if (node1 != null) break // second inversion: both nodes known
                // Larger element of the first inversion.
                node1 = before
            }

            prev = node
            curr = node.right
        }

        // Nothing is inverted: return instead of failing on a null assertion.
        val first = node1 ?: return
        val second = node2 ?: return
        val temp = first.`val`
        first.`val` = second.`val`
        second.`val` = temp
    }
}
class Solution {
    /// Restores, in place, a BST in which the values of two nodes were
    /// swapped. An inorder traversal with an explicit stack compares every
    /// node with the previously visited one to find the one or two
    /// inversions; the two misplaced values are then exchanged. An empty or
    /// already sorted tree is left untouched.
    func recoverTree(_ root: TreeNode?) {
        var stack = [TreeNode]()
        var node1: TreeNode? = nil
        var node2: TreeNode? = nil
        var prev: TreeNode? = nil
        var curr = root

        while !stack.isEmpty || curr != nil {
            // Descend to the leftmost node, remembering the path.
            while let node = curr {
                stack.append(node)
                curr = node.left
            }
            let node = stack.removeLast()

            if let before = prev, before.val > node.val {
                // Smaller element of the latest inversion.
                node2 = node
                if node1 != nil { break } // second inversion: both nodes known
                // Larger element of the first inversion.
                node1 = before
            }

            prev = node
            curr = node.right
        }

        // Nothing is inverted: return instead of force-unwrapping nil.
        guard let first = node1, let second = node2 else { return }
        let temp = first.val
        first.val = second.val
        second.val = temp
    }
}
impl Solution {
    /// Restores, in place, a BST in which the values of two nodes were
    /// swapped. An inorder traversal with an explicit stack compares every
    /// node with the previously visited one to find the one or two
    /// inversions; the two misplaced values are then exchanged. An empty or
    /// already sorted tree is left untouched.
    pub fn recover_tree(root: &mut Option<Rc<RefCell<TreeNode>>>) {
        let mut stack: Vec<Rc<RefCell<TreeNode>>> = Vec::new();
        let mut node1: Option<Rc<RefCell<TreeNode>>> = None;
        let mut node2: Option<Rc<RefCell<TreeNode>>> = None;
        let mut prev: Option<Rc<RefCell<TreeNode>>> = None;
        let mut curr = root.clone();

        while !stack.is_empty() || curr.is_some() {
            // Descend to the leftmost node, remembering the path.
            while let Some(c) = curr {
                curr = c.borrow().left.clone();
                stack.push(c);
            }
            let c = match stack.pop() {
                Some(c) => c,
                None => break,
            };

            if let Some(ref p) = prev {
                if p.borrow().val > c.borrow().val {
                    // Smaller element of the latest inversion.
                    node2 = Some(Rc::clone(&c));
                    if node1.is_none() {
                        // Larger element of the first inversion.
                        node1 = Some(Rc::clone(p));
                    } else {
                        break; // second inversion: both nodes are known
                    }
                }
            }

            prev = Some(Rc::clone(&c));
            curr = c.borrow().right.clone();
        }

        // Swap only when an inversion was found instead of unwrapping.
        if let (Some(n1), Some(n2)) = (node1, node2) {
            std::mem::swap(&mut n1.borrow_mut().val, &mut n2.borrow_mut().val);
        }
    }
}
Time & Space Complexity
- Time complexity: O(n)
- Space complexity: O(n)
4. Morris Traversal
Intuition
Morris traversal allows us to perform inorder traversal without using a stack or recursion, achieving O(1) space complexity. The technique works by temporarily modifying the tree structure: for each node with a left child, we find its inorder predecessor and create a temporary link back to the current node.
This temporary threading allows us to return to ancestor nodes after processing the left subtree without needing a stack. We can detect inversions during this traversal just like in the iterative approach, but without the extra space for a stack.
Algorithm
- Initialize pointers for `node1`, `node2`, `prev`, and `curr` (starting at root).
- While `curr` is not null:
  - If `curr` has no left child, process it (check for inversion with `prev`), update `prev`, and move to the right child.
  - If `curr` has a left child, find its inorder predecessor (rightmost node in left subtree).
    - If the predecessor's right pointer is null, create a thread to `curr` and move left.
    - If the predecessor's right pointer points to `curr`, remove the thread, process `curr`, update `prev`, and move right.
- After traversal, swap the values of `node1` and `node2`.
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        """
        Do not return anything, modify root in-place instead.

        Morris inorder traversal: instead of a stack, the right pointer of a
        node's inorder predecessor is temporarily pointed back at the node.
        Every visited node is compared with the previously visited one to
        find the one or two inversions, and the two misplaced values are
        swapped at the end. An empty or already sorted tree is left
        untouched.
        """
        node1 = node2 = prev = None
        curr = root

        while curr is not None:
            if curr.left is not None:
                # Rightmost node of the left subtree, or the node whose
                # thread already points back at curr.
                pred = curr.left
                while pred.right is not None and pred.right is not curr:
                    pred = pred.right

                if pred.right is None:
                    # First arrival: leave a thread and handle the left
                    # subtree before visiting curr.
                    pred.right = curr
                    curr = curr.left
                    continue

                # Second arrival: the left subtree is done, remove the thread.
                pred.right = None

            # Visit curr (reached either without a left child or after its
            # left subtree was finished).
            if prev is not None and prev.val > curr.val:
                node2 = curr
                if node1 is None:
                    node1 = prev
            prev = curr
            curr = curr.right

        # The loop is not cut short, so every thread has been removed again.
        # No inversion means there is nothing to swap.
        if node1 is None:
            return
        node1.val, node2.val = node2.val, node1.val
public class Solution {
    /**
     * Restores, in place, a BST in which the values of two nodes were
     * swapped, using a Morris inorder traversal: instead of a stack, the
     * right pointer of a node's inorder predecessor is temporarily pointed
     * back at the node. Every visited node is compared with the previously
     * visited one to find the one or two inversions, and the two misplaced
     * values are exchanged at the end. An empty or already sorted tree is
     * left untouched.
     *
     * @param root root of the tree to repair (may be null)
     */
    public void recoverTree(TreeNode root) {
        TreeNode node1 = null, node2 = null, prev = null, curr = root;

        while (curr != null) {
            if (curr.left != null) {
                // Rightmost node of the left subtree, or the node whose
                // thread already points back at curr.
                TreeNode pred = curr.left;
                while (pred.right != null && pred.right != curr) {
                    pred = pred.right;
                }

                if (pred.right == null) {
                    // First arrival: leave a thread and go left first.
                    pred.right = curr;
                    curr = curr.left;
                    continue;
                }

                // Second arrival: left subtree done, remove the thread.
                pred.right = null;
            }

            // Visit curr.
            if (prev != null && prev.val > curr.val) {
                node2 = curr;
                if (node1 == null) node1 = prev;
            }
            prev = curr;
            curr = curr.right;
        }

        // Nothing is inverted: avoid dereferencing null.
        if (node1 == null || node2 == null) {
            return;
        }
        int temp = node1.val;
        node1.val = node2.val;
        node2.val = temp;
    }
}
class Solution {
public:
    // Restores, in place, a BST in which the values of two nodes were
    // swapped, using a Morris inorder traversal: instead of a stack, the
    // right pointer of a node's inorder predecessor is temporarily pointed
    // back at the node. Every visited node is compared with the previously
    // visited one to find the one or two inversions, and the two misplaced
    // values are exchanged at the end. An empty or already sorted tree is
    // left untouched.
    void recoverTree(TreeNode* root) {
        TreeNode* node1 = nullptr;
        TreeNode* node2 = nullptr;
        TreeNode* prev = nullptr;
        TreeNode* curr = root;

        while (curr != nullptr) {
            if (curr->left != nullptr) {
                // Rightmost node of the left subtree, or the node whose
                // thread already points back at curr.
                TreeNode* pred = curr->left;
                while (pred->right != nullptr && pred->right != curr) {
                    pred = pred->right;
                }

                if (pred->right == nullptr) {
                    // First arrival: leave a thread and go left first.
                    pred->right = curr;
                    curr = curr->left;
                    continue;
                }

                // Second arrival: left subtree done, remove the thread.
                pred->right = nullptr;
            }

            // Visit curr.
            if (prev != nullptr && prev->val > curr->val) {
                node2 = curr;
                if (node1 == nullptr) node1 = prev;
            }
            prev = curr;
            curr = curr->right;
        }

        // Nothing is inverted: avoid dereferencing a null pointer.
        if (node1 == nullptr || node2 == nullptr) {
            return;
        }
        std::swap(node1->val, node2->val);
    }
};
class Solution {
    /**
     * Restores, in place, a BST in which the values of two nodes were
     * swapped, using a Morris inorder traversal: instead of a stack, the
     * right pointer of a node's inorder predecessor is temporarily pointed
     * back at the node. Every visited node is compared with the previously
     * visited one to find the one or two inversions, and the two misplaced
     * values are exchanged at the end. An empty or already sorted tree is
     * left untouched.
     *
     * @param {TreeNode} root
     * @return {void}
     */
    recoverTree(root) {
        let node1 = null;
        let node2 = null;
        let prev = null;
        let curr = root;

        while (curr !== null) {
            if (curr.left !== null) {
                // Rightmost node of the left subtree, or the node whose
                // thread already points back at curr.
                let pred = curr.left;
                while (pred.right !== null && pred.right !== curr) {
                    pred = pred.right;
                }

                if (pred.right === null) {
                    // First arrival: leave a thread and go left first.
                    pred.right = curr;
                    curr = curr.left;
                    continue;
                }

                // Second arrival: left subtree done, remove the thread.
                pred.right = null;
            }

            // Visit curr.
            if (prev !== null && prev.val > curr.val) {
                node2 = curr;
                if (node1 === null) node1 = prev;
            }
            prev = curr;
            curr = curr.right;
        }

        // Nothing is inverted: avoid reading a property of null.
        if (node1 === null || node2 === null) return;
        const temp = node1.val;
        node1.val = node2.val;
        node2.val = temp;
    }
}
public class Solution {
    /// <summary>
    /// Restores, in place, a BST in which the values of two nodes were
    /// swapped, using a Morris inorder traversal: instead of a stack, the
    /// right pointer of a node's inorder predecessor is temporarily pointed
    /// back at the node. Every visited node is compared with the previously
    /// visited one to find the one or two inversions, and the two misplaced
    /// values are exchanged at the end. An empty or already sorted tree is
    /// left untouched.
    /// </summary>
    /// <param name="root">Root of the tree to repair (may be null).</param>
    public void RecoverTree(TreeNode root) {
        TreeNode node1 = null, node2 = null, prev = null, curr = root;

        while (curr != null) {
            if (curr.left != null) {
                // Rightmost node of the left subtree, or the node whose
                // thread already points back at curr.
                TreeNode pred = curr.left;
                while (pred.right != null && pred.right != curr) {
                    pred = pred.right;
                }

                if (pred.right == null) {
                    // First arrival: leave a thread and go left first.
                    pred.right = curr;
                    curr = curr.left;
                    continue;
                }

                // Second arrival: left subtree done, remove the thread.
                pred.right = null;
            }

            // Visit curr.
            if (prev != null && prev.val > curr.val) {
                node2 = curr;
                if (node1 == null) node1 = prev;
            }
            prev = curr;
            curr = curr.right;
        }

        // Nothing is inverted: avoid dereferencing null.
        if (node1 == null || node2 == null) {
            return;
        }
        int temp = node1.val;
        node1.val = node2.val;
        node2.val = temp;
    }
}
// recoverTree restores, in place, a BST in which the values of two nodes
// were swapped, using a Morris inorder traversal: instead of a stack, the
// Right pointer of a node's inorder predecessor is temporarily pointed back
// at the node. Every visited node is compared with the previously visited
// one to find the one or two inversions, and the two misplaced values are
// exchanged at the end. An empty or already sorted tree is left untouched.
func recoverTree(root *TreeNode) {
	var node1, node2, prev *TreeNode
	curr := root

	for curr != nil {
		if curr.Left != nil {
			// Rightmost node of the left subtree, or the node whose thread
			// already points back at curr.
			pred := curr.Left
			for pred.Right != nil && pred.Right != curr {
				pred = pred.Right
			}

			if pred.Right == nil {
				// First arrival: leave a thread and go left first.
				pred.Right = curr
				curr = curr.Left
				continue
			}

			// Second arrival: left subtree done, remove the thread.
			pred.Right = nil
		}

		// Visit curr.
		if prev != nil && prev.Val > curr.Val {
			node2 = curr
			if node1 == nil {
				node1 = prev
			}
		}
		prev = curr
		curr = curr.Right
	}

	// Nothing is inverted: avoid a nil pointer dereference.
	if node1 == nil || node2 == nil {
		return
	}
	node1.Val, node2.Val = node2.Val, node1.Val
}
class Solution {
    /**
     * Restores, in place, a BST in which the values of two nodes were
     * swapped, using a Morris inorder traversal: instead of a stack, the
     * right pointer of a node's inorder predecessor is temporarily pointed
     * back at the node. Every visited node is compared with the previously
     * visited one to find the one or two inversions, and the two misplaced
     * values are exchanged at the end. An empty or already sorted tree is
     * left untouched.
     */
    fun recoverTree(root: TreeNode?): Unit {
        var node1: TreeNode? = null
        var node2: TreeNode? = null
        var prev: TreeNode? = null
        var curr = root

        while (curr != null) {
            val leftChild = curr.left
            if (leftChild != null) {
                // Rightmost node of the left subtree, or the node whose
                // thread already points back at curr.
                var pred: TreeNode = leftChild
                while (true) {
                    val next = pred.right
                    if (next == null || next === curr) break
                    pred = next
                }

                if (pred.right == null) {
                    // First arrival: leave a thread and go left first.
                    pred.right = curr
                    curr = leftChild
                    continue
                }

                // Second arrival: left subtree done, remove the thread.
                pred.right = null
            }

            // Visit curr.
            if (prev != null && prev.`val` > curr.`val`) {
                node2 = curr
                if (node1 == null) node1 = prev
            }
            prev = curr
            curr = curr.right
        }

        // Nothing is inverted: return instead of failing on a null assertion.
        val first = node1 ?: return
        val second = node2 ?: return
        val temp = first.`val`
        first.`val` = second.`val`
        second.`val` = temp
    }
}
class Solution {
    /// Restores, in place, a BST in which the values of two nodes were
    /// swapped, using a Morris inorder traversal: instead of a stack, the
    /// right pointer of a node's inorder predecessor is temporarily pointed
    /// back at the node. Every visited node is compared with the previously
    /// visited one to find the one or two inversions, and the two misplaced
    /// values are exchanged at the end. An empty or already sorted tree is
    /// left untouched.
    func recoverTree(_ root: TreeNode?) {
        var node1: TreeNode? = nil
        var node2: TreeNode? = nil
        var prev: TreeNode? = nil
        var curr = root

        while let node = curr {
            if let leftChild = node.left {
                // Rightmost node of the left subtree, or the node whose
                // thread already points back at `node`.
                var pred = leftChild
                while let next = pred.right, next !== node {
                    pred = next
                }

                if pred.right == nil {
                    // First arrival: leave a thread and go left first.
                    pred.right = node
                    curr = leftChild
                    continue
                }

                // Second arrival: left subtree done, remove the thread.
                pred.right = nil
            }

            // Visit `node`.
            if let before = prev, before.val > node.val {
                node2 = node
                if node1 == nil { node1 = before }
            }
            prev = node
            curr = node.right
        }

        // Nothing is inverted: return instead of force-unwrapping nil.
        guard let first = node1, let second = node2 else { return }
        let temp = first.val
        first.val = second.val
        second.val = temp
    }
}
impl Solution {
    /// Restores, in place, a BST in which the values of two nodes were
    /// swapped, using a Morris inorder traversal: instead of a stack, the
    /// right pointer of a node's inorder predecessor is temporarily pointed
    /// back at the node. Every visited node is compared with the previously
    /// visited one to find the one or two inversions, and the two misplaced
    /// values are exchanged at the end. An empty or already sorted tree is
    /// left untouched.
    pub fn recover_tree(root: &mut Option<Rc<RefCell<TreeNode>>>) {
        let mut node1: Option<Rc<RefCell<TreeNode>>> = None;
        let mut node2: Option<Rc<RefCell<TreeNode>>> = None;
        let mut prev: Option<Rc<RefCell<TreeNode>>> = None;
        let mut curr = root.clone();

        while let Some(c) = curr.clone() {
            let left = c.borrow().left.clone();
            if let Some(mut pred) = left {
                // Walk to the rightmost node of the left subtree, or to the
                // node whose thread already points back at `c`.
                loop {
                    let step = {
                        let p = pred.borrow();
                        let advance = match p.right.as_ref() {
                            Some(r) if !Rc::ptr_eq(r, &c) => Some(Rc::clone(r)),
                            _ => None,
                        };
                        advance
                    };
                    match step {
                        Some(next) => pred = next,
                        None => break,
                    }
                }

                if pred.borrow().right.is_none() {
                    // First arrival: leave a thread and go left first.
                    pred.borrow_mut().right = Some(Rc::clone(&c));
                    curr = c.borrow().left.clone();
                    continue;
                }

                // Second arrival: left subtree done, remove the thread.
                pred.borrow_mut().right = None;
            }

            // Visit `c`.
            if let Some(ref p) = prev {
                if p.borrow().val > c.borrow().val {
                    node2 = Some(Rc::clone(&c));
                    if node1.is_none() {
                        node1 = Some(Rc::clone(p));
                    }
                }
            }
            prev = Some(Rc::clone(&c));
            curr = c.borrow().right.clone();
        }

        // Swap only when an inversion was found instead of unwrapping.
        if let (Some(n1), Some(n2)) = (node1, node2) {
            std::mem::swap(&mut n1.borrow_mut().val, &mut n2.borrow_mut().val);
        }
    }
}
Time & Space Complexity
- Time complexity: O(n)
- Space complexity: O(1)
Common Pitfalls
Assuming Only One Inversion Exists
When two nodes are swapped in a BST, there can be either one or two inversions in the inorder sequence depending on whether the swapped nodes are adjacent. Assuming only one inversion and not checking for a second one causes incorrect identification of the swapped nodes.
Swapping Node References Instead of Values
The problem asks to recover the tree by swapping values, not by restructuring node pointers. Attempting to swap the actual node positions in the tree is unnecessarily complex and error-prone. Simply swap the val fields of the two identified nodes.
Incorrectly Identifying First and Second Nodes
In the first inversion, node1 is the larger (out-of-place) element, and node2 is the smaller one. If a second inversion is found, node2 should be updated to the smaller element of that inversion, while node1 remains unchanged. Mixing up which node to update at each inversion leads to swapping the wrong pair.