
Recover Binary Search Tree

Difficulty: Medium

Description

You are given the root of a binary search tree (BST) in which the values of exactly two nodes have been swapped by mistake. Your task is to recover the tree so that it becomes a valid BST again, without changing its structure.

In other words, only the values of the two misplaced nodes should be swapped back — no node should be moved, added, or removed. The tree's shape must remain exactly as given.

Examples

Example 1

Input: root = [1, 3, null, null, 2]

Output: [3, 1, null, null, 2]

Explanation: The tree looks like this:

    1
   /
  3
   \
    2

Node 3 is in the left subtree of node 1, but 3 > 1, which violates the BST property. The nodes with values 1 and 3 were swapped by mistake. After swapping them back, the tree becomes:

    3
   /
  1
   \
    2

Now every node satisfies the BST property: 1 < 3, and 2 is in the right subtree of 1 with 1 < 2 < 3.

Example 2

Input: root = [3, 1, 4, null, null, 2]

Output: [2, 1, 4, null, null, 3]

Explanation: The tree looks like this:

    3
   / \
  1   4
     /
    2

The inorder traversal gives [1, 3, 2, 4]. In a valid BST, the inorder traversal is strictly increasing, but here 3 > 2 breaks that order. The nodes with values 3 and 2 were swapped by mistake. After swapping them back:

    2
   / \
  1   4
     /
    3

The inorder traversal is now [1, 2, 3, 4], which is strictly increasing — a valid BST.
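The examples encode trees in level order, with null marking a missing child. The sketch below decodes and encodes that format so the examples can be tried out. It assumes a plain TreeNode class with val, left, and right fields. The names build_tree, to_level_order, and inorder are hypothetical helpers, not part of the problem.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed shape: val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Decode a level-order list (None = missing child) into a tree."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries, if present, are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Encode a tree back into level-order form, trimming trailing Nones."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


def inorder(root: Optional[TreeNode]) -> List[int]:
    """Iterative inorder traversal returning the node values."""
    vals, stack, node = [], [], root
    while stack or node:
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        vals.append(node.val)
        node = node.right
    return vals


root = build_tree([3, 1, 4, None, None, 2])  # Example 2
print(to_level_order(root))  # [3, 1, 4, None, None, 2]
print(inorder(root))         # [1, 3, 2, 4]
```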

Constraints

  • 2 ≤ number of nodes ≤ 1000
  • -2^31 ≤ Node.val ≤ 2^31 - 1
  • Exactly two nodes in the BST have been swapped

Editorial

Brute Force

Intuition

The most fundamental property of a BST is that its inorder traversal produces values in strictly increasing (sorted) order. If exactly two nodes were swapped, the inorder traversal will no longer be sorted — but only two values will be out of place.
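We can check the "only two values out of place" claim directly on the inorder sequences from the examples. The sketch below works on plain lists, and both function names are illustrative. Comparing each sequence against its sorted copy shows that exactly two positions differ.

```python
from typing import List


def is_strictly_increasing(vals: List[int]) -> bool:
    """True if vals could be the inorder sequence of a valid BST."""
    return all(a < b for a, b in zip(vals, vals[1:]))


def misplaced_positions(vals: List[int]) -> List[int]:
    """Indices where vals disagrees with its own sorted order."""
    return [i for i, (got, want) in enumerate(zip(vals, sorted(vals))) if got != want]


print(is_strictly_increasing([1, 3, 2, 4]))  # False (Example 2's inorder)
print(misplaced_positions([1, 3, 2, 4]))     # [1, 2]
print(misplaced_positions([3, 2, 1]))        # [0, 2] (Example 1's inorder)
```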

Imagine you have a row of numbered cards that should be in ascending order, but someone secretly swapped two of them. The simplest way to fix this is to collect all the numbers, sort them properly, and then place them back into the card slots one by one from left to right.

Similarly, we can perform an inorder traversal of the corrupted BST, collect all values into an array, sort that array to obtain the correct ordering, and then do a second inorder traversal to write the sorted values back into the tree nodes.

Step-by-Step Explanation

Let's trace through Example 2: root = [3, 1, 4, null, null, 2]

The tree structure is:

    3
   / \
  1   4
     /
    2

Step 1: Perform inorder traversal to collect all node values.

  • Visit left subtree of 3 → visit node 1 (leaf) → collect 1
  • Visit node 3 → collect 3
  • Visit right subtree of 3 → visit left subtree of 4 → visit node 2 (leaf) → collect 2
  • Visit node 4 → collect 4
  • Collected array: [1, 3, 2, 4]

Step 2: Sort the collected array.

  • Before sort: [1, 3, 2, 4]
  • After sort: [1, 2, 3, 4]

Step 3: Perform a second inorder traversal to reassign sorted values back to nodes.

  • Visit node 1 → assign values[0] = 1 (unchanged)
  • Visit node 3 → assign values[1] = 2 (changed from 3 to 2)
  • Visit node 2 → assign values[2] = 3 (changed from 2 to 3)
  • Visit node 4 → assign values[3] = 4 (unchanged)

Step 4: The tree is now corrected:

    2
   / \
  1   4
     /
    3

Inorder: [1, 2, 3, 4] — a valid BST.


Algorithm

  1. Perform an inorder traversal of the BST and store all node values in an array
  2. Sort the array to get the correct ordering
  3. Perform a second inorder traversal of the BST
  4. At each node during the second traversal, replace its value with the next value from the sorted array
  5. After the second traversal completes, the BST is restored
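The same collect-sort-reassign idea can be written with one inorder pass that collects node references instead of values. This removes the need for an index counter, and an explicit stack avoids Python's recursion limit on skewed trees. The sketch assumes a minimal TreeNode class. The names inorder_nodes and recover_tree_sorted are hypothetical.

```python
from typing import Iterator, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def inorder_nodes(root: Optional[TreeNode]) -> Iterator[TreeNode]:
    """Yield nodes in inorder using an explicit stack instead of recursion."""
    stack: List[TreeNode] = []
    node = root
    while stack or node:
        while node:              # descend as far left as possible
            stack.append(node)
            node = node.left
        node = stack.pop()
        yield node
        node = node.right


def recover_tree_sorted(root: Optional[TreeNode]) -> None:
    """Brute force: collect nodes in inorder, then write sorted values back."""
    nodes = list(inorder_nodes(root))
    for node, value in zip(nodes, sorted(n.val for n in nodes)):
        node.val = value


root = TreeNode(3, TreeNode(1), TreeNode(4, TreeNode(2)))  # Example 2
recover_tree_sorted(root)
print([n.val for n in inorder_nodes(root)])  # [1, 2, 3, 4]
```

Storing node references costs the same O(n) space as storing values, so the complexity matches the original brute force.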

Code

```cpp
class Solution {
public:
    void recoverTree(TreeNode* root) {
        vector<int> inorder;
        collectInorder(root, inorder);
        sort(inorder.begin(), inorder.end());
        int idx = 0;
        assignInorder(root, inorder, idx);
    }

private:
    void collectInorder(TreeNode* node, vector<int>& vals) {
        if (!node) return;
        collectInorder(node->left, vals);
        vals.push_back(node->val);
        collectInorder(node->right, vals);
    }

    void assignInorder(TreeNode* node, vector<int>& vals, int& idx) {
        if (!node) return;
        assignInorder(node->left, vals, idx);
        node->val = vals[idx++];
        assignInorder(node->right, vals, idx);
    }
};
```
```python
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        inorder = []

        def collect(node):
            if not node:
                return
            collect(node.left)
            inorder.append(node.val)
            collect(node.right)

        collect(root)
        inorder.sort()

        self.idx = 0

        def assign(node):
            if not node:
                return
            assign(node.left)
            node.val = inorder[self.idx]
            self.idx += 1
            assign(node.right)

        assign(root)
```
```java
class Solution {
    private int idx = 0;

    public void recoverTree(TreeNode root) {
        List<Integer> inorder = new ArrayList<>();
        collectInorder(root, inorder);
        Collections.sort(inorder);
        idx = 0;
        assignInorder(root, inorder);
    }

    private void collectInorder(TreeNode node, List<Integer> vals) {
        if (node == null) return;
        collectInorder(node.left, vals);
        vals.add(node.val);
        collectInorder(node.right, vals);
    }

    private void assignInorder(TreeNode node, List<Integer> vals) {
        if (node == null) return;
        assignInorder(node.left, vals);
        node.val = vals.get(idx++);
        assignInorder(node.right, vals);
    }
}
```

Complexity Analysis

Time Complexity: O(n log n)

The inorder traversal takes O(n) to collect all values, and sorting the array costs O(n log n). The second traversal to reassign values is another O(n). The dominant term is the sorting step, giving O(n log n) overall.

Space Complexity: O(n)

We store all n node values in an auxiliary array. Additionally, the recursion stack for inorder traversal uses O(h) space where h is the height of the tree. In the worst case (skewed tree), h = n, but the array already uses O(n), so overall space is O(n).

Why This Approach Is Not Efficient

The brute force approach sorts the entire inorder array, costing O(n log n) time, even though only two values are out of place. This is wasteful — we do not need to sort the whole array to identify which two nodes were swapped.

Moreover, we use O(n) extra space to store the full inorder sequence. Since the BST property guarantees that inorder traversal is sorted, any violation of the sorted order directly reveals the swapped nodes.

The key insight is: if we traverse the tree in inorder and track the previously visited node, we can detect the exact two nodes that were swapped by finding where the sorted order breaks. This eliminates the need for both sorting and storing the full array, reducing time to O(n) and space to O(h) where h is the tree height.

Better Approach - Inorder Violation Detection

Intuition

Since inorder traversal of a valid BST yields sorted values, swapping two nodes creates either one or two "violations" — places where a value is greater than the value that follows it.

Picture a line of people arranged by height from shortest to tallest. If two people secretly swap positions, you can find them by walking down the line and checking each consecutive pair. If a taller person stands before a shorter one, you have found a violation.

There are two scenarios:

Case 1 — Adjacent swap (one violation): If the two swapped nodes are next to each other in the inorder sequence, there is exactly one place where the order breaks. For example, in [1, 3, 2, 4], the pair (3, 2) is the only violation. The swapped nodes are 3 and 2.

Case 2 — Non-adjacent swap (two violations): If the swapped nodes are far apart in the inorder sequence, the order breaks in two places. For example, in [1, 7, 3, 4, 5, 2, 8], the first violation is (7, 3) and the second is (5, 2). The swapped nodes are 7 (from the first violation) and 2 (from the second violation).

We use three pointers — first, middle, and last — along with a prev pointer to track the previously visited node. On the first violation, we set first = prev and middle = current. On the second violation (if any), we set last = current. Finally, if last exists, we swap first and last; otherwise, we swap first and middle.
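The pointer bookkeeping is easier to see on a plain list first. The sketch below applies the first/middle/last rule to an inorder sequence of values. It is a hypothetical helper that uses indices in place of node pointers, and it assumes exactly one swap, as the constraints guarantee. Example 1's inorder [3, 2, 1] falls under Case 2.

```python
from typing import List, Optional, Tuple


def find_swapped(vals: List[int]) -> Tuple[int, int]:
    """Return the two swapped values using the first/middle/last rule.

    Assumes exactly one pair was swapped, so at least one violation exists.
    """
    first: Optional[int] = None   # index of prev at the first violation
    middle: Optional[int] = None  # index of current at the first violation
    last: Optional[int] = None    # index of current at the second violation
    for i in range(1, len(vals)):
        if vals[i - 1] > vals[i]:  # violation: prev > current
            if first is None:
                first, middle = i - 1, i
            else:
                last = i
    second = last if last is not None else middle
    return vals[first], vals[second]


print(find_swapped([1, 3, 2, 4]))           # (3, 2): adjacent, one violation
print(find_swapped([1, 7, 3, 4, 5, 2, 8]))  # (7, 2): non-adjacent, two violations
print(find_swapped([3, 2, 1]))              # (3, 1): Example 1's inorder
```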

Figure: Two diagrams showing an adjacent swap (one violation) and a non-adjacent swap (two violations) in the BST inorder traversal.

Step-by-Step Explanation

Let's trace through Example 2: root = [3, 1, 4, null, null, 2]

Tree structure:

    3
   / \
  1   4
     /
    2

We perform an inorder traversal while tracking prev, first, middle, and last.

Step 1: Initialize all pointers to null: prev = null, first = null, middle = null, last = null.

Step 2: Visit node 1 (leftmost). prev is null, so no comparison. Set prev = node(1).

Step 3: Visit node 3 (root). Compare: prev.val(1) < current.val(3)? Yes, no violation. Set prev = node(3).

Step 4: Visit node 2. Compare: prev.val(3) < current.val(2)? No! 3 > 2, this is a violation.

  • This is the first violation (first is null), so set first = prev = node(3), middle = current = node(2).
  • Set prev = node(2).

Step 5: Visit node 4. Compare: prev.val(2) < current.val(4)? Yes, no violation. Set prev = node(4).

Step 6: Traversal complete. Check: is last null? Yes, so this is an adjacent swap case. Swap values of first(3) and middle(2).

  • node with value 3 becomes 2, node with value 2 becomes 3.

Result: Tree is now [2, 1, 4, null, null, 3] — a valid BST with inorder [1, 2, 3, 4].


Algorithm

  1. Initialize four pointers: prev, first, middle, last, all set to null
  2. Perform an inorder traversal of the BST
  3. At each node, if prev exists and prev.val > current.val:
    • If first is null (first violation): set first = prev, middle = current
    • Else (second violation): set last = current
  4. Update prev = current after each node
  5. After traversal completes:
    • If last is not null: swap values of first and last (non-adjacent case)
    • Else: swap values of first and middle (adjacent case)
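The same algorithm can also run with an explicit stack instead of recursion. The space bound stays O(h), but the version no longer depends on the language's recursion limit. CPython's default limit is 1000 frames, which a recursive traversal of a 1000-node skewed tree can exceed unless the limit is raised. The sketch assumes a minimal TreeNode class, and recover_tree_iterative is a hypothetical name.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def recover_tree_iterative(root: Optional[TreeNode]) -> None:
    """Inorder violation detection driven by an explicit stack."""
    stack: List[TreeNode] = []
    prev = first = middle = last = None
    node = root
    while stack or node:
        while node:                      # push the left spine
            stack.append(node)
            node = node.left
        node = stack.pop()               # node is now in inorder position
        if prev is not None and prev.val > node.val:
            if first is None:
                first, middle = prev, node   # first violation
            else:
                last = node                  # second violation
        prev = node
        node = node.right
    # Non-adjacent swap pairs first with last; adjacent swap pairs first with middle.
    other = last if last is not None else middle
    first.val, other.val = other.val, first.val


root = TreeNode(1, TreeNode(3, None, TreeNode(2)))  # Example 1
recover_tree_iterative(root)
print(root.val, root.left.val, root.left.right.val)  # 3 1 2
```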

Code

```cpp
class Solution {
public:
    TreeNode* prev = nullptr;
    TreeNode* first = nullptr;
    TreeNode* middle = nullptr;
    TreeNode* last = nullptr;

    void recoverTree(TreeNode* root) {
        inorder(root);
        if (last != nullptr) {
            swap(first->val, last->val);
        } else {
            swap(first->val, middle->val);
        }
    }

private:
    void inorder(TreeNode* node) {
        if (node == nullptr) return;

        inorder(node->left);

        if (prev != nullptr && prev->val > node->val) {
            if (first == nullptr) {
                first = prev;
                middle = node;
            } else {
                last = node;
            }
        }
        prev = node;

        inorder(node->right);
    }
};
```
```python
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        self.prev = None
        self.first = None
        self.middle = None
        self.last = None

        def inorder(node):
            if not node:
                return

            inorder(node.left)

            if self.prev and self.prev.val > node.val:
                if self.first is None:
                    self.first = self.prev
                    self.middle = node
                else:
                    self.last = node

            self.prev = node
            inorder(node.right)

        inorder(root)

        if self.last:
            self.first.val, self.last.val = self.last.val, self.first.val
        else:
            self.first.val, self.middle.val = self.middle.val, self.first.val
```
```java
class Solution {
    private TreeNode prev = null;
    private TreeNode first = null;
    private TreeNode middle = null;
    private TreeNode last = null;

    public void recoverTree(TreeNode root) {
        inorder(root);
        if (last != null) {
            int temp = first.val;
            first.val = last.val;
            last.val = temp;
        } else {
            int temp = first.val;
            first.val = middle.val;
            middle.val = temp;
        }
    }

    private void inorder(TreeNode node) {
        if (node == null) return;

        inorder(node.left);

        if (prev != null && prev.val > node.val) {
            if (first == null) {
                first = prev;
                middle = node;
            } else {
                last = node;
            }
        }
        prev = node;

        inorder(node.right);
    }
}
```

Complexity Analysis

Time Complexity: O(n)

We perform a single inorder traversal visiting each of the n nodes exactly once. At each node, we do O(1) work (one comparison and a few pointer updates). Total time is O(n).

Space Complexity: O(h)

The space is used by the recursion call stack, which goes as deep as the height of the tree. For a balanced BST, h = O(log n). For a skewed tree, h = O(n) in the worst case. We use only O(1) extra space for the four tracking pointers (prev, first, middle, last), but the recursion stack dominates.

Why This Approach Is Not Space-Optimal

The inorder violation detection approach runs in O(n) time, which is optimal since we must examine every node at least once. However, it uses O(h) space due to the recursion call stack, where h is the height of the tree.

For a balanced BST with 1000 nodes, h ≈ 10, which is negligible. But in the worst case of a completely skewed tree, h = n = 1000, meaning the recursion stack uses O(n) space.
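The gap between the two cases is easy to measure. The sketch below builds a height-balanced BST and a fully right-skewed BST with 1000 nodes each, then counts their levels. The helpers are hypothetical, and height is counted in nodes along the longest root-to-leaf path.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def balanced_bst(lo: int, hi: int) -> Optional[TreeNode]:
    """Build a height-balanced BST holding lo..hi by choosing the middle as root."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, balanced_bst(lo, mid - 1), balanced_bst(mid + 1, hi))


def skewed_bst(n: int) -> Optional[TreeNode]:
    """Build the right-skewed BST 1 -> 2 -> ... -> n without recursion."""
    root = None
    for v in range(n, 0, -1):
        root = TreeNode(v, None, root)
    return root


def height(root: Optional[TreeNode]) -> int:
    """Count levels with a BFS so deep trees do not hit the recursion limit."""
    levels, queue = 0, deque([root] if root else [])
    while queue:
        levels += 1
        for _ in range(len(queue)):
            node = queue.popleft()
            queue.extend(c for c in (node.left, node.right) if c)
    return levels


print(height(balanced_bst(1, 1000)))  # 10
print(height(skewed_bst(1000)))       # 1000
```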

The follow-up challenge asks: can we achieve O(1) extra space? The recursion stack is the bottleneck. If we could traverse the tree in inorder without recursion and without an explicit stack, we would eliminate this space overhead.

This is exactly what Morris Inorder Traversal achieves — it temporarily modifies the tree structure by creating threaded links to traverse without any stack, using only O(1) extra space. After traversal, all temporary modifications are undone, restoring the original tree.

Optimal Approach - Morris Inorder Traversal

Intuition

Morris Traversal is a clever technique that lets us perform inorder traversal using O(1) extra space — no recursion stack, no explicit stack. The idea is to temporarily create "threads" (links) from the inorder predecessor back to the current node, so we can find our way back up the tree without a stack.

Imagine you are exploring a maze. Normally you would use a ball of string (the stack) to find your way back. Morris Traversal instead temporarily draws arrows on the walls (threads) so you can retrace your path, and erases them once you are done.

Here is how it works for each node:

  • If the current node has no left child, process it and move right.
  • If the current node has a left child, find the inorder predecessor (the rightmost node in the left subtree).
    • If the predecessor's right child is null, create a thread (set predecessor's right to current node) and move left.
    • If the predecessor's right child already points to the current node (thread exists), we have returned via the thread. Remove the thread (restore null), process the current node, and move right.
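These rules can be isolated from the recovery logic as a plain Morris inorder generator. The sketch below assumes a minimal TreeNode class, and morris_inorder and shape are hypothetical names. After a full traversal it confirms that every thread was removed, so the tree's shape is unchanged.

```python
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def morris_inorder(root: Optional[TreeNode]) -> Iterator[int]:
    """Yield values in inorder with O(1) extra space using temporary threads."""
    curr = root
    while curr:
        if curr.left is None:
            yield curr.val           # no left subtree: visit, then go right
            curr = curr.right
            continue
        pred = curr.left             # find the rightmost node of the left subtree
        while pred.right and pred.right is not curr:
            pred = pred.right
        if pred.right is None:
            pred.right = curr        # create a thread back to curr
            curr = curr.left
        else:
            pred.right = None        # returned via the thread: remove it
            yield curr.val
            curr = curr.right


def shape(node: Optional[TreeNode]):
    """Nested tuples describing the tree, used to check restoration."""
    return None if node is None else (node.val, shape(node.left), shape(node.right))


root = TreeNode(3, TreeNode(1), TreeNode(4, TreeNode(2)))  # Example 2
before = shape(root)
print(list(morris_inorder(root)))  # [1, 3, 2, 4]
print(shape(root) == before)       # True
```

Threads are removed only as the traversal reaches them. Abandoning the generator early, for example by breaking out of a loop, would therefore leave the tree modified, so this sketch relies on full consumption.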

While performing this traversal, we apply the same violation detection logic: compare each processed node with the previously processed node to find the two swapped nodes.

Step-by-Step Explanation

Let's trace Morris Traversal on Example 2: root = [3, 1, 4, null, null, 2]

Tree structure:

    3
   / \
  1   4
     /
    2

Step 1: Start at curr = node(3). It has a left child (node 1).

  • Find inorder predecessor of 3 in left subtree: go to node 1, then keep going right. Node 1 has no right child, so predecessor = node(1).
  • Predecessor's right is null → create thread: node(1).right = node(3). Move curr to left: curr = node(1).

Step 2: curr = node(1). It has no left child.

  • Process node(1): prev is null, so no comparison. Set prev = node(1).
  • Move right: curr = node(1).right = node(3) (via the thread we created).

Step 3: curr = node(3). It has a left child (node 1).

  • Find inorder predecessor: start at node 1 and try to go right. node(1).right already points to node(3), which is curr (the thread!), so the search stops with predecessor = node(1).
  • Thread detected → remove thread: node(1).right = null. Process node(3): compare prev.val(1) < curr.val(3)? Yes, no violation. Set prev = node(3).
  • Move right: curr = node(4).

Step 4: curr = node(4). It has a left child (node 2).

  • Find inorder predecessor of 4 in left subtree: node(2) has no right child, so predecessor = node(2).
  • Predecessor's right is null → create thread: node(2).right = node(4). Move curr to left: curr = node(2).

Step 5: curr = node(2). It has no left child.

  • Process node(2): compare prev.val(3) < curr.val(2)? No! 3 > 2 → VIOLATION!
  • First violation: set first = node(3), middle = node(2). Set prev = node(2).
  • Move right: curr = node(2).right = node(4) (via thread).

Step 6: curr = node(4). It has a left child (node 2).

  • Find predecessor: start at node 2 and try to go right. node(2).right already points to node(4), which is curr, so the thread exists and predecessor = node(2).
  • Thread detected → remove thread: node(2).right = null. Process node(4): compare prev.val(2) < curr.val(4)? Yes, no violation. Set prev = node(4).
  • Move right: curr = null. Traversal complete.

Step 7: last is null → adjacent swap case. Swap values of first(3) and middle(2). BST restored.
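Logging each thread operation reproduces Steps 1 through 6. The sketch below instruments the traversal, and its event tuples are an illustrative format, not part of the algorithm. It assumes a minimal TreeNode class.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def morris_events(root: Optional[TreeNode]) -> List[Tuple]:
    """Run Morris traversal and log thread creation, thread removal, and visits."""
    events: List[Tuple] = []
    curr = root
    while curr:
        if curr.left is None:
            events.append(("visit", curr.val))
            curr = curr.right
        else:
            pred = curr.left
            while pred.right and pred.right is not curr:
                pred = pred.right
            if pred.right is None:
                pred.right = curr
                events.append(("thread", pred.val, curr.val))
                curr = curr.left
            else:
                pred.right = None
                events.append(("unthread", pred.val, curr.val))
                events.append(("visit", curr.val))
                curr = curr.right
    return events


root = TreeNode(3, TreeNode(1), TreeNode(4, TreeNode(2)))  # Example 2
for event in morris_events(root):
    print(event)
# ('thread', 1, 3), ('visit', 1), ('unthread', 1, 3), ('visit', 3),
# ('thread', 2, 4), ('visit', 2), ('unthread', 2, 4), ('visit', 4)
```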


Algorithm

  1. Initialize: curr = root, prev = null, first = null, middle = null, last = null
  2. While curr is not null:
    • If curr has no left child:
      • Process curr: if prev exists and prev.val > curr.val, handle violation
      • Set prev = curr, move curr = curr.right
    • If curr has a left child:
      • Find the inorder predecessor (rightmost node in left subtree)
      • If predecessor's right is null: Create thread (predecessor.right = curr), move curr = curr.left
      • If predecessor's right is curr: Remove thread (predecessor.right = null), process curr (check violation), set prev = curr, move curr = curr.right
  3. After traversal: if last exists, swap first.val and last.val; else swap first.val and middle.val
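The violation check appears in two places in the loop, so it can be factored into one nested helper. The sketch below does that, then runs a randomized self-check. The check corrupts random BSTs by swapping two node values and confirms that recovery restores sorted inorder. The names recover_tree_morris, random_bst, and inorder_nodes are hypothetical, and a minimal TreeNode class is assumed.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def recover_tree_morris(root: Optional[TreeNode]) -> None:
    """Morris-based recovery; the violation check lives in one nested helper."""
    prev = first = middle = last = None

    def visit(node: TreeNode) -> None:
        nonlocal prev, first, middle, last
        if prev is not None and prev.val > node.val:
            if first is None:
                first, middle = prev, node
            else:
                last = node
        prev = node

    curr = root
    while curr:
        if curr.left is None:
            visit(curr)
            curr = curr.right
            continue
        pred = curr.left
        while pred.right and pred.right is not curr:
            pred = pred.right
        if pred.right is None:
            pred.right = curr        # create thread
            curr = curr.left
        else:
            pred.right = None        # remove thread, then visit
            visit(curr)
            curr = curr.right

    other = last if last is not None else middle
    first.val, other.val = other.val, first.val


def random_bst(values: List[int]) -> Optional[TreeNode]:
    """Insert values one by one into an unbalanced BST."""
    root = None
    for v in values:
        if root is None:
            root = TreeNode(v)
            continue
        node = root
        while True:
            side = "left" if v < node.val else "right"
            child = getattr(node, side)
            if child is None:
                setattr(node, side, TreeNode(v))
                break
            node = child
    return root


def inorder_nodes(root: Optional[TreeNode]) -> List[TreeNode]:
    """Iterative inorder traversal returning the nodes themselves."""
    out, stack, node = [], [], root
    while stack or node:
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        out.append(node)
        node = node.right
    return out


rng = random.Random(0)
for _ in range(200):
    vals = rng.sample(range(-50, 50), rng.randint(2, 30))
    root = random_bst(vals)
    a, b = rng.sample(inorder_nodes(root), 2)  # corrupt: swap two node values
    a.val, b.val = b.val, a.val
    recover_tree_morris(root)
    assert [n.val for n in inorder_nodes(root)] == sorted(vals)
print("all random trials recovered")
```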

Code

```cpp
class Solution {
public:
    void recoverTree(TreeNode* root) {
        TreeNode* curr = root;
        TreeNode* prev = nullptr;
        TreeNode* first = nullptr;
        TreeNode* middle = nullptr;
        TreeNode* last = nullptr;

        while (curr != nullptr) {
            if (curr->left == nullptr) {
                // No left child — process and go right
                if (prev != nullptr && prev->val > curr->val) {
                    if (first == nullptr) {
                        first = prev;
                        middle = curr;
                    } else {
                        last = curr;
                    }
                }
                prev = curr;
                curr = curr->right;
            } else {
                // Find inorder predecessor
                TreeNode* predecessor = curr->left;
                while (predecessor->right != nullptr && predecessor->right != curr) {
                    predecessor = predecessor->right;
                }

                if (predecessor->right == nullptr) {
                    // Create thread and move left
                    predecessor->right = curr;
                    curr = curr->left;
                } else {
                    // Thread exists — remove it and process
                    predecessor->right = nullptr;
                    if (prev != nullptr && prev->val > curr->val) {
                        if (first == nullptr) {
                            first = prev;
                            middle = curr;
                        } else {
                            last = curr;
                        }
                    }
                    prev = curr;
                    curr = curr->right;
                }
            }
        }

        if (last != nullptr) {
            swap(first->val, last->val);
        } else {
            swap(first->val, middle->val);
        }
    }
};
```
```python
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        curr = root
        prev = None
        first = None
        middle = None
        last = None

        while curr:
            if curr.left is None:
                # No left child — process and go right
                if prev and prev.val > curr.val:
                    if first is None:
                        first = prev
                        middle = curr
                    else:
                        last = curr
                prev = curr
                curr = curr.right
            else:
                # Find inorder predecessor
                predecessor = curr.left
                while predecessor.right and predecessor.right != curr:
                    predecessor = predecessor.right

                if predecessor.right is None:
                    # Create thread and move left
                    predecessor.right = curr
                    curr = curr.left
                else:
                    # Thread exists — remove it and process
                    predecessor.right = None
                    if prev and prev.val > curr.val:
                        if first is None:
                            first = prev
                            middle = curr
                        else:
                            last = curr
                    prev = curr
                    curr = curr.right

        if last:
            first.val, last.val = last.val, first.val
        else:
            first.val, middle.val = middle.val, first.val
```
```java
class Solution {
    public void recoverTree(TreeNode root) {
        TreeNode curr = root;
        TreeNode prev = null;
        TreeNode first = null;
        TreeNode middle = null;
        TreeNode last = null;

        while (curr != null) {
            if (curr.left == null) {
                // No left child — process and go right
                if (prev != null && prev.val > curr.val) {
                    if (first == null) {
                        first = prev;
                        middle = curr;
                    } else {
                        last = curr;
                    }
                }
                prev = curr;
                curr = curr.right;
            } else {
                // Find inorder predecessor
                TreeNode predecessor = curr.left;
                while (predecessor.right != null && predecessor.right != curr) {
                    predecessor = predecessor.right;
                }

                if (predecessor.right == null) {
                    // Create thread and move left
                    predecessor.right = curr;
                    curr = curr.left;
                } else {
                    // Thread exists — remove it and process
                    predecessor.right = null;
                    if (prev != null && prev.val > curr.val) {
                        if (first == null) {
                            first = prev;
                            middle = curr;
                        } else {
                            last = curr;
                        }
                    }
                    prev = curr;
                    curr = curr.right;
                }
            }
        }

        if (last != null) {
            int temp = first.val;
            first.val = last.val;
            last.val = temp;
        } else {
            int temp = first.val;
            first.val = middle.val;
            middle.val = temp;
        }
    }
}
```

Complexity Analysis

Time Complexity: O(n)

Predecessor finding makes Morris Traversal look more expensive, but each of the tree's n - 1 edges is walked at most three times. The first walk happens when curr moves along the edge. The other two happen during predecessor searches: once when a thread is created and once when it is removed. Each thread is also followed exactly once, and there are at most n - 1 threads (one per node with a left child). The total work is therefore O(n).
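The per-edge bound can be checked empirically by instrumenting the traversal. The sketch below counts how many times each original edge is walked and how many thread hops occur. The move wrapper and the edge bookkeeping are illustrative additions, not part of Morris Traversal itself, and a minimal TreeNode class is assumed.

```python
from collections import Counter
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def morris_edge_usage(root: Optional[TreeNode]) -> Tuple[Counter, int]:
    """Count moves along each original edge and along threads during Morris traversal."""
    real = set()  # (id(parent), id(child)) for every original edge
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child:
                real.add((id(node), id(child)))
                stack.append(child)

    edge_moves: Counter = Counter()
    thread_moves = 0

    def move(src: TreeNode, dst: Optional[TreeNode]) -> Optional[TreeNode]:
        """Record one pointer step from src to dst and return dst."""
        nonlocal thread_moves
        if dst is not None:
            if (id(src), id(dst)) in real:
                edge_moves[(id(src), id(dst))] += 1
            else:
                thread_moves += 1
        return dst

    curr = root
    while curr:
        if curr.left is None:
            curr = move(curr, curr.right)
            continue
        pred = move(curr, curr.left)
        while pred.right and pred.right is not curr:
            pred = move(pred, pred.right)
        if pred.right is None:
            pred.right = curr            # create thread
            curr = move(curr, curr.left)
        else:
            pred.right = None            # remove thread
            curr = move(curr, curr.right)
    return edge_moves, thread_moves


# Left-skewed chain 50 -> 49 -> ... -> 1: every node except the last has a left child.
chain = None
for v in range(1, 51):
    chain = TreeNode(v, chain)
edges, threads = morris_edge_usage(chain)
print(max(edges.values()), len(edges), threads)  # 3 49 49
```

On a left-skewed chain, every original edge is walked exactly three times, which is the worst case the analysis describes.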

Space Complexity: O(1)

This is the key advantage of Morris Traversal. We use no recursion stack and no explicit stack. The only extra variables are a fixed number of pointers (curr, prev, first, middle, last, predecessor) — all O(1) space regardless of input size. The temporary threads modify the tree in-place and are fully restored before the traversal ends.