Prerequisites

Before attempting this problem, you should be comfortable with:

  • Binary Search Tree (BST) Properties - Understanding that inorder traversal of a valid BST produces a sorted sequence
  • Inorder Traversal - Both recursive and iterative implementations for visiting nodes in left-root-right order
  • Tree Validation - Checking if a binary tree satisfies BST constraints using min/max bounds
  • Morris Traversal (Optional) - A space-optimized traversal technique that uses temporary threading to achieve an O(1)-space solution
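
As a quick refresher on the first prerequisite, the sketch below builds a small BST, shows that its inorder traversal is sorted, and then corrupts it by swapping two values. The TreeNode class mirrors the commented definition used in the solutions; inorder_values is a helper name chosen only for this illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(node):
    """Return the node values in left-root-right order."""
    if node is None:
        return []
    return inorder_values(node.left) + [node.val] + inorder_values(node.right)


# A valid BST:   2
#               / \
#              1   3
root = TreeNode(2, TreeNode(1), TreeNode(3))
sorted_values = inorder_values(root)

# Swap the values of the root and its right child to corrupt the tree.
root.val, root.right.val = root.right.val, root.val
corrupted_values = inorder_values(root)
```

Here sorted_values is [1, 2, 3] and corrupted_values is [1, 3, 2]; the out-of-order pair (3, 2) is the kind of inversion the later approaches look for.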

1. Brute Force

Intuition

The simplest approach is to try every possible pair of nodes and check if swapping their values produces a valid BST. Since exactly two nodes were swapped, one such pair must restore the tree to its correct state.

For each pair of nodes, we swap their values, validate whether the resulting tree is a valid BST, and either keep the swap (if valid) or revert it (if invalid). While inefficient, this approach guarantees finding the solution by exhaustively checking all possibilities.

Algorithm

  1. Define a helper function isBST that validates whether a tree is a valid BST using BFS with min/max bounds for each node.
  2. Use a nested DFS approach: the outer DFS (dfs) iterates through each node as a potential first swap candidate.
  3. For each first candidate, the inner DFS (dfs1) iterates through every other node as the second swap candidate.
  4. For each pair, swap their values, check if the tree is now a valid BST, and if so, return true. Otherwise, swap back and continue.
  5. Once a valid swap is found, the tree is corrected in place.
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
from collections import deque  # needed by the BFS in isBST


class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        """
        Do not return anything, modify root in-place instead.
        """
        def isBST(node):
            if not node:
                return True

            q = deque([(node, float("-inf"), float("inf"))])
            while q:
                cur, left, right = q.popleft()
                if not (left < cur.val < right):
                    return False
                if cur.left:
                    q.append((cur.left, left, cur.val))
                if cur.right:
                    q.append((cur.right, cur.val, right))

            return True

        def dfs1(node1, node2):
            if not node2 or node1 == node2:
                return False
            
            node1.val, node2.val = node2.val, node1.val
            if isBST(root):
                return True
            
            node1.val, node2.val = node2.val, node1.val
            return dfs1(node1, node2.left) or dfs1(node1, node2.right)


        def dfs(node):
            if not node:
                return False
            
            if dfs1(node, root):
                return True
            
            return dfs(node.left) or dfs(node.right)
        
        # The tree is fixed in place, so nothing is returned.
        dfs(root)

Time & Space Complexity

  • Time complexity: O(n^3) in the worst case, since each of the O(n^2) candidate pairs triggers an O(n) validation
  • Space complexity: O(n)

2. Inorder Traversal

Intuition

A key property of BSTs is that an inorder traversal produces values in sorted order. When the values of two nodes have been swapped by mistake, this sorted sequence contains one or two "inversions", positions where a value is greater than the value that follows it.

If the swapped nodes are adjacent in the inorder sequence, there will be exactly one inversion. If they are not adjacent, there will be two inversions. In the first inversion, the larger (out-of-place) node is the first swapped node. In the second inversion (or the same one if only one exists), the smaller node is the second swapped node.

By collecting all nodes during inorder traversal and then scanning for these inversions, we can identify the two swapped nodes and swap their values back.
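
To make the one-inversion and two-inversion cases concrete, here is a minimal sketch that works on plain lists of values rather than tree nodes. The function name find_swapped and the sample lists are assumptions made for illustration only.

```python
def find_swapped(values):
    """Return the two swapped values in an otherwise sorted list."""
    first = second = None
    for i in range(len(values) - 1):
        if values[i] > values[i + 1]:
            # The smaller side of the latest inversion is the second element.
            second = values[i + 1]
            if first is None:
                # The larger side of the first inversion is the first element.
                first = values[i]
    return first, second


# Adjacent swap (3 and 4): a single inversion, at the pair (4, 3).
adjacent = find_swapped([1, 2, 4, 3, 5])

# Non-adjacent swap (2 and 5): two inversions, at (5, 3) and (4, 2).
non_adjacent = find_swapped([1, 5, 3, 4, 2])
```

With these inputs, adjacent is (4, 3) and non_adjacent is (5, 2). In both cases the first element comes from the first inversion and the second element from the last one.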

Algorithm

  1. Perform an inorder traversal of the tree and store all nodes in a list.
  2. Scan the list to find inversions (where arr[i].val > arr[i+1].val).
  3. On the first inversion, mark node1 = arr[i] and node2 = arr[i+1].
  4. If a second inversion is found, update node2 = arr[i+1]; node1 from the first inversion is already correct and stays unchanged.
  5. Swap the values of node1 and node2 to restore the BST.
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        """
        Do not return anything, modify root in-place instead.
        """
        arr = []
        def inorder(node):
            if not node:
                return
            
            inorder(node.left)
            arr.append(node)
            inorder(node.right)

        inorder(root)
        node1, node2 = None, None

        for i in range(len(arr) - 1):
            if arr[i].val > arr[i + 1].val:
                node2 = arr[i + 1]
                if node1 is None:
                    node1 = arr[i]
                else:
                    break
        node1.val, node2.val = node2.val, node1.val

Time & Space Complexity

  • Time complexity: O(n)
  • Space complexity: O(n)

3. Iterative Inorder Traversal

Intuition

This approach uses the same logic as the recursive inorder traversal but implements it iteratively using an explicit stack. The advantage is that we can detect inversions on the fly during traversal rather than collecting all nodes first.

By keeping track of the previously visited node, we can immediately detect when the current node's value is less than the previous node's value, signaling an inversion. This allows us to identify the swapped nodes in a single pass through the tree.

Algorithm

  1. Initialize an empty stack and pointers for node1, node2, prev, and curr.
  2. Traverse the tree iteratively using the stack for inorder traversal.
  3. For each visited node, compare it with prev. If prev.val > curr.val, an inversion is found.
  4. On the first inversion, set node1 = prev and node2 = curr.
  5. On the second inversion, update node2 = curr and break early.
  6. After traversal, swap the values of node1 and node2.
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        stack = []
        node1 = node2 = prev = None
        curr = root

        while stack or curr:
            while curr:
                stack.append(curr)
                curr = curr.left

            curr = stack.pop()
            if prev and prev.val > curr.val:
                node2 = curr
                if not node1:
                    node1 = prev
                else:
                    break
            prev = curr
            curr = curr.right

        node1.val, node2.val = node2.val, node1.val

Time & Space Complexity

  • Time complexity: O(n)
  • Space complexity: O(n)

4. Morris Traversal

Intuition

Morris traversal allows us to perform inorder traversal without using a stack or recursion, achieving O(1) space complexity. The technique works by temporarily modifying the tree structure: for each node with a left child, we find its inorder predecessor and create a temporary link back to the current node.

This temporary threading allows us to return to ancestor nodes after processing the left subtree without needing a stack. We can detect inversions during this traversal just like in the iterative approach, but without the extra space for a stack.
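
Before adding inversion detection, it may help to see the threading mechanism on its own. The following is a minimal sketch of a plain Morris inorder traversal that only collects values; morris_inorder and the sample tree are illustrative names and data, not part of the solution.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def morris_inorder(root):
    """Collect values in inorder using O(1) extra space besides the output."""
    values = []
    curr = root
    while curr:
        if curr.left is None:
            values.append(curr.val)
            curr = curr.right
        else:
            # Inorder predecessor: rightmost node of the left subtree.
            pred = curr.left
            while pred.right and pred.right is not curr:
                pred = pred.right

            if pred.right is None:
                pred.right = curr       # create the temporary thread
                curr = curr.left
            else:
                pred.right = None       # remove the thread, restoring the tree
                values.append(curr.val)
                curr = curr.right
    return values


#       4
#      / \
#     2   5
#    / \
#   1   3
leaf3 = TreeNode(3)
root = TreeNode(4, TreeNode(2, TreeNode(1), leaf3), TreeNode(5))
values = morris_inorder(root)
```

After the call, values is [1, 2, 3, 4, 5] and leaf3.right is None again, so the temporary thread from 3 back to 4 has been removed. Threads are only cleaned up when the traversal revisits their target, which is one reason to let the loop run to completion rather than exiting as soon as the second inversion is seen.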

Algorithm

  1. Initialize pointers for node1, node2, prev, and curr (starting at root).
  2. While curr is not null:
    • If curr has no left child, process it (check for inversion with prev), update prev, and move to the right child.
    • If curr has a left child, find its inorder predecessor (rightmost node in left subtree).
      • If the predecessor's right pointer is null, create a thread to curr and move left.
      • If the predecessor's right pointer points to curr, remove the thread, process curr, update prev, and move right.
  3. After traversal, swap the values of node1 and node2.
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        node1 = node2 = prev = None
        curr = root

        while curr:
            if not curr.left:
                if prev and prev.val > curr.val:
                    node2 = curr
                    if not node1:
                        node1 = prev
                prev = curr
                curr = curr.right
            else:
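                # Find the inorder predecessor: the rightmost node of the left subtree.
                pred = curr.left
                while pred.right and pred.right != curr:
                    pred = pred.right

                if not pred.right:
                    # First visit: thread the predecessor back to curr, then go left.
                    pred.right = curr
                    curr = curr.left
                else:
                    # Second visit: the left subtree is done, so remove the thread.
                    pred.right = None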
                    if prev and prev.val > curr.val:
                        node2 = curr
                        if not node1:
                            node1 = prev
                    prev = curr
                    curr = curr.right

        node1.val, node2.val = node2.val, node1.val

Time & Space Complexity

  • Time complexity: O(n)
  • Space complexity: O(1)

Common Pitfalls

Assuming Only One Inversion Exists

When two nodes are swapped in a BST, there can be either one or two inversions in the inorder sequence depending on whether the swapped nodes are adjacent. Assuming only one inversion and not checking for a second one causes incorrect identification of the swapped nodes.
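
The failure mode can be reproduced on a plain list. The sketch below deliberately implements the flawed idea of stopping at the first inversion; the function name and sample values are hypothetical.

```python
def first_inversion_only(values):
    """Flawed on purpose: returns the indices of the first inversion and stops."""
    for i in range(len(values) - 1):
        if values[i] > values[i + 1]:
            return i, i + 1
    return None


# Sorted order is [1, 2, 3, 4, 5]; the values 2 and 5 were swapped.
values = [1, 5, 3, 4, 2]

buggy = list(values)
i, j = first_inversion_only(buggy)
buggy[i], buggy[j] = buggy[j], buggy[i]
```

The first inversion is the pair (5, 3), so the flawed version swaps those two and produces [1, 3, 5, 4, 2], which is still not sorted. The correct partner for 5 is the 2 found at the second inversion.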

Swapping Node References Instead of Values

The problem asks to recover the tree by swapping values, not by restructuring node pointers. Attempting to swap the actual node positions in the tree is unnecessarily complex and error-prone. Simply swap the val fields of the two identified nodes.
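
A related mistake in Python is swapping the local variables that refer to the nodes, which leaves the tree untouched. The sketch below contrasts the two; swap_references and swap_values are illustrative helper names, and TreeNode mirrors the commented definition used in the solutions.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def swap_references(a, b):
    # Only rebinds the local names a and b; no node is modified.
    a, b = b, a


def swap_values(a, b):
    # Exchanges the stored values; the tree shape stays the same.
    a.val, b.val = b.val, a.val


# Corrupted BST (inorder 1, 3, 2): the root and its right child hold swapped values.
root = TreeNode(3, TreeNode(1), TreeNode(2))

swap_references(root, root.right)
after_reference_swap = (root.left.val, root.val, root.right.val)

swap_values(root, root.right)
after_value_swap = (root.left.val, root.val, root.right.val)
```

Here after_reference_swap is still (1, 3, 2), while after_value_swap is (1, 2, 3).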

Incorrectly Identifying First and Second Nodes

In the first inversion, node1 is the larger (out-of-place) element, and node2 is the smaller one. If a second inversion is found, node2 should be updated to the smaller element of that inversion, while node1 remains unchanged. Mixing up which node to update at each inversion leads to swapping the wrong pair.