The simplest approach is to try every possible pair of nodes and check if swapping their values produces a valid BST. Since exactly two nodes were swapped, one such pair must restore the tree to its correct state.
For each pair of nodes, we swap their values, check whether the resulting tree is a valid BST, and either keep the swap (if valid) or revert it (if invalid). This is inefficient, since there are O(n^2) pairs and each validation takes O(n) time, for O(n^3) overall, but it is guaranteed to find the solution because it checks every possibility.
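1. Define a helper isBST that validates whether a tree is a valid BST using BFS with min/max bounds for each node.
2. The outer traversal (dfs) iterates through each node as a potential first swap candidate.
3. The inner traversal (dfs1) iterates through every other node as the second swap candidate.
4. For each pair, swap the two values and validate the tree. If it is a valid BST, return true. Otherwise, swap back and continue.

# Definition for a binary tree node.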
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
from collections import deque
from typing import Optional


class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        """
        Do not return anything, modify root in-place instead.
        """
        def isBST(node):
            # BFS validation: each queue entry carries the open interval
            # (left, right) that the node's value must fall inside.
            if not node:
                return True
            q = deque([(node, float("-inf"), float("inf"))])
            while q:
                cur, left, right = q.popleft()
                if not (left < cur.val < right):
                    return False
                if cur.left:
                    q.append((cur.left, left, cur.val))
                if cur.right:
                    q.append((cur.right, cur.val, right))
            return True

        def dfs1(node1, node2):
            # Try node2 and its descendants as the partner of node1.
            # Stopping at node1 itself skips node1's subtree, which is safe:
            # those pairs are tried when the descendant plays the role of node1.
            if not node2 or node1 is node2:
                return False
            node1.val, node2.val = node2.val, node1.val
            if isBST(root):
                return True
            # The swap did not produce a valid BST, so revert it.
            node1.val, node2.val = node2.val, node1.val
            return dfs1(node1, node2.left) or dfs1(node1, node2.right)

        def dfs(node):
            # Try every node as the first swap candidate.
            if not node:
                return False
            if dfs1(node, root):
                return True
            return dfs(node.left) or dfs(node.right)

        dfs(root)

A key property of BSTs is that an inorder traversal produces values in sorted order. When the values of two nodes are swapped, this sorted sequence contains one or two "inversions", positions where a value is greater than the value that follows it.
If the swapped nodes are adjacent in the inorder sequence, there will be exactly one inversion. If they are not adjacent, there will be two inversions. In the first inversion, the larger (out-of-place) node is the first swapped node. In the second inversion (or the same one if only one exists), the smaller node is the second swapped node.
By collecting all nodes during inorder traversal and then scanning for these inversions, we can identify the two swapped nodes and swap their values back.
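The inversion rule can be seen on a plain sorted list before bringing the tree into it. The sketch below is only an illustration: `find_swapped` is a hypothetical helper name, and it works on list indices rather than tree nodes.

```python
def find_swapped(values):
    """Return the indices of the two swapped entries in an almost-sorted list."""
    first = second = None
    for i in range(len(values) - 1):
        if values[i] > values[i + 1]:
            # The smaller element of the latest inversion is the second candidate.
            second = i + 1
            if first is None:
                # The larger element of the first inversion is the first candidate.
                first = i
            else:
                # A second inversion was found, so nothing more can follow.
                break
    return first, second


# Adjacent swap (2 and 3 exchanged): a single inversion, 3 > 2.
adjacent = find_swapped([1, 3, 2, 4, 5])
# Non-adjacent swap (2 and 5 exchanged): two inversions, 5 > 3 and 4 > 2.
distant = find_swapped([1, 5, 3, 4, 2])
print(adjacent, distant)
```

In the adjacent case both candidates come from the same inversion; in the non-adjacent case the second inversion overrides the second candidate while the first one is kept.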
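1. Collect the nodes in inorder into an array arr.
2. Scan arr for inversions (positions where arr[i].val > arr[i+1].val).
3. On the first inversion, set node1 = arr[i] and node2 = arr[i+1].
4. On the second inversion, if there is one, update node2 = arr[i+1] (the first swap target is already correct).
5. Swap the values of node1 and node2 to restore the BST.

# Definition for a binary tree node.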
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
from typing import Optional


class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        """
        Do not return anything, modify root in-place instead.
        """
        arr = []

        def inorder(node):
            # Collect the nodes themselves (not just values) in inorder.
            if not node:
                return
            inorder(node.left)
            arr.append(node)
            inorder(node.right)

        inorder(root)

        node1, node2 = None, None
        for i in range(len(arr) - 1):
            if arr[i].val > arr[i + 1].val:
                # The smaller node of the latest inversion is the second swapped node.
                node2 = arr[i + 1]
                if node1 is None:
                    # The larger node of the first inversion is the first swapped node.
                    node1 = arr[i]
                else:
                    # Second inversion found; no further inversions are possible.
                    break
        node1.val, node2.val = node2.val, node1.val

This approach uses the same logic as the recursive inorder traversal but implements it iteratively using an explicit stack. The advantage is that we can detect inversions on the fly during traversal rather than collecting all nodes first.
By keeping track of the previously visited node, we can immediately detect when the current node's value is less than the previous node's value, signaling an inversion. This allows us to identify the swapped nodes in a single pass through the tree.
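1. Initialize node1, node2, prev, and curr.
2. Traverse the tree in inorder with an explicit stack, comparing each visited node curr with prev. If prev.val > curr.val, an inversion is found.
3. On the first inversion, set node1 = prev and node2 = curr.
4. On the second inversion, update node2 = curr and break early.
5. Swap the values of node1 and node2.

# Definition for a binary tree node.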
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
from typing import Optional


class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        stack = []
        node1 = node2 = prev = None
        curr = root
        while stack or curr:
            # Walk as far left as possible, remembering the path on the stack.
            while curr:
                stack.append(curr)
                curr = curr.left
            curr = stack.pop()
            # prev is the inorder predecessor of curr.
            if prev and prev.val > curr.val:
                node2 = curr
                if not node1:
                    node1 = prev
                else:
                    # Second inversion found; both swapped nodes are known.
                    break
            prev = curr
            curr = curr.right
        node1.val, node2.val = node2.val, node1.val

Morris traversal allows us to perform an inorder traversal without a stack or recursion, using only O(1) extra space. The technique works by temporarily modifying the tree structure: for each node with a left child, we find its inorder predecessor and make that predecessor's right pointer a temporary link back to the current node.
This temporary threading allows us to return to ancestor nodes after processing the left subtree without needing a stack. We can detect inversions during this traversal just like in the iterative approach, but without the extra space for a stack.
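To see the threading in isolation, here is a minimal sketch of a plain Morris inorder traversal with no inversion detection. `Node` and `morris_inorder` are hypothetical names used only for this illustration, and the output list is kept just so the visiting order can be inspected.

```python
class Node:
    """Minimal stand-in for a binary tree node (hypothetical name)."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def morris_inorder(root):
    """Return node values in inorder using temporary threads instead of a stack."""
    out = []
    curr = root
    while curr:
        if curr.left is None:
            # No left subtree: visit curr and move right (possibly along a thread).
            out.append(curr.val)
            curr = curr.right
        else:
            # Find the inorder predecessor: rightmost node of the left subtree.
            pred = curr.left
            while pred.right is not None and pred.right is not curr:
                pred = pred.right
            if pred.right is None:
                # First arrival at curr: create the thread and descend left.
                pred.right = curr
                curr = curr.left
            else:
                # Second arrival via the thread: remove it, visit curr, go right.
                pred.right = None
                out.append(curr.val)
                curr = curr.right
    return out


root = Node(4, Node(2, Node(1), Node(3)), Node(6, Node(5), Node(7)))
values = morris_inorder(root)
print(values)
```

Every thread that is created is removed again when the traversal returns through it, so the tree has its original shape once the loop finishes. This is why a Morris-based solution should run the traversal to completion rather than break out early.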
1. Initialize node1, node2, prev, and curr (starting at root).
2. While curr is not null:
   - If curr has no left child, process it (check for inversion with prev), update prev, and move to the right child.
   - If curr has a left child, find its inorder predecessor (rightmost node in left subtree).
     - If the predecessor's right pointer is null, create a thread to curr and move left.
     - If the predecessor's right pointer already points to curr, remove the thread, process curr, update prev, and move right.
3. Swap the values of node1 and node2.

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
from typing import Optional


class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        node1 = node2 = prev = None
        curr = root
        while curr:
            if not curr.left:
                # No left subtree: process curr, then move right.
                if prev and prev.val > curr.val:
                    node2 = curr
                    if not node1:
                        node1 = prev
                prev = curr
                curr = curr.right
            else:
                # Find the inorder predecessor of curr.
                pred = curr.left
                while pred.right and pred.right is not curr:
                    pred = pred.right
                if not pred.right:
                    # Create the temporary thread and descend into the left subtree.
                    pred.right = curr
                    curr = curr.left
                else:
                    # Left subtree is done: remove the thread and process curr.
                    pred.right = None
                    if prev and prev.val > curr.val:
                        node2 = curr
                        if not node1:
                            node1 = prev
                    prev = curr
                    curr = curr.right
        # No early break: the traversal must finish so that all threads are removed.
        node1.val, node2.val = node2.val, node1.val
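
Any of the approaches above can be sanity-checked by comparing the tree's inorder values with their sorted order after recovery. The harness below is a rough sketch under a few assumptions: `Node`, `inorder_values`, `recover_by_inorder`, and `is_recovered` are hypothetical names, `Node` stands in for the commented-out TreeNode class, and `recover_by_inorder` is a compact stand-in for the function under test that follows the inorder-array idea.

```python
class Node:
    """Stand-in for the TreeNode class (hypothetical name)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(node):
    """Return the values of the tree in inorder as a list."""
    if node is None:
        return []
    return inorder_values(node.left) + [node.val] + inorder_values(node.right)


def recover_by_inorder(root):
    """Stand-in recovery function: collect nodes in inorder, then fix the inversions."""
    nodes = []

    def walk(node):
        if node:
            walk(node.left)
            nodes.append(node)
            walk(node.right)

    walk(root)
    first = second = None
    for a, b in zip(nodes, nodes[1:]):
        if a.val > b.val:
            second = b
            if first is None:
                first = a
            else:
                break
    first.val, second.val = second.val, first.val


def is_recovered(recover, root):
    """Run recover on root and report whether the inorder sequence is now sorted."""
    expected = sorted(inorder_values(root))
    recover(root)
    return inorder_values(root) == expected


# Inorder 3, 2, 1: the values 1 and 3 are swapped (non-adjacent case).
tree_a = Node(1, Node(3, None, Node(2)))
# Inorder 1, 3, 2, 4: the values 2 and 3 are swapped (adjacent case).
tree_b = Node(3, Node(1), Node(4, Node(2)))

result_a = is_recovered(recover_by_inorder, tree_a)
result_b = is_recovered(recover_by_inorder, tree_b)
print(result_a, result_b)
```

Passing the recovery function as an argument keeps the check independent of which approach is being tested. Each tree should only be recovered once, since the stand-in assumes exactly two swapped nodes.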