
Kth Smallest Element in a BST

Difficulty: Medium

Description

Given the root of a binary search tree (BST) and a positive integer k, return the kth smallest value among all the node values in the tree. The value of k is 1-indexed, meaning k=1 asks for the smallest element, k=2 asks for the second smallest, and so on.

A binary search tree is a binary tree where for every node, all values in its left subtree are strictly less than the node's value, and all values in its right subtree are strictly greater. This ordering property is what makes it possible to efficiently find the kth smallest element.
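The rule applies to entire subtrees, not just to a node's direct children, so a check has to carry an allowed (low, high) range down the tree. The following is a minimal Python sketch of that check. The `TreeNode` class mirrors the definition used in the solutions below, and `is_valid_bst` is a hypothetical helper name, not part of the problem.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_valid_bst(node: Optional[TreeNode],
                 low: float = float("-inf"),
                 high: float = float("inf")) -> bool:
    """Return True if every value in the subtree lies strictly between low and high."""
    if node is None:
        return True
    if not (low < node.val < high):
        return False
    # Everything on the left must stay below node.val; everything on the right, above it.
    return (is_valid_bst(node.left, low, node.val)
            and is_valid_bst(node.right, node.val, high))


# The tree from Example 1 is valid; swapping the values 1 and 2 breaks the rule.
good = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
bad = TreeNode(3, TreeNode(2, None, TreeNode(1)), TreeNode(4))
print(is_valid_bst(good), is_valid_bst(bad))  # True False
```

Comparing each node only with its children would wrongly accept a root 5 whose left child 3 has a right child 7. The range check rejects it because 7 is not below 5.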

You are guaranteed that k is valid — it is always between 1 and the total number of nodes in the tree, inclusive.

Examples

Example 1

Input: root = [3, 1, 4, null, 2], k = 1

Output: 1

Explanation: The BST looks like:

    3
   / \
  1   4
   \
    2

The inorder traversal of this BST produces the sorted sequence [1, 2, 3, 4]. The 1st smallest element is 1.

Example 2

Input: root = [5, 3, 6, 2, 4, null, null, 1], k = 3

Output: 3

Explanation: The BST looks like:

        5
       / \
      3   6
     / \
    2   4
   /
  1

The inorder traversal produces [1, 2, 3, 4, 5, 6]. The 3rd smallest element is 3.

Example 3

Input: root = [2, 1, 3], k = 2

Output: 2

Explanation: The BST has three nodes with values 1, 2, 3. The sorted order is [1, 2, 3], and the 2nd smallest element is 2.
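The `root = [...]` inputs above list the tree in level order, with `null` marking a missing child. Here is a minimal Python decoder, assuming the usual LeetCode convention: children are assigned left to right to nodes taken from a queue, and Python's `None` stands in for `null`. `build_tree` and `inorder` are illustrative names.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        parent = queue.popleft()
        # The next two entries (if present) are this parent's left and right child.
        if i < len(values) and values[i] is not None:
            parent.left = TreeNode(values[i])
            queue.append(parent.left)
        i += 1
        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            queue.append(parent.right)
        i += 1
    return root


def inorder(node: Optional[TreeNode]) -> List[int]:
    """Left subtree, then the node, then the right subtree."""
    if node is None:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)


print(inorder(build_tree([3, 1, 4, None, 2])))                # [1, 2, 3, 4]
print(inorder(build_tree([5, 3, 6, 2, 4, None, None, 1])))    # [1, 2, 3, 4, 5, 6]
```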

Constraints

  • The number of nodes in the tree is n
  • 1 ≤ k ≤ n ≤ 10^4
  • 0 ≤ Node.val ≤ 10^4

Editorial

Brute Force

Intuition

The simplest way to find the kth smallest element is to collect all the values in the tree, sort them, and pick the kth one. Since we are dealing with a BST, the inorder traversal (left → root → right) already visits nodes in ascending order — but for the brute force, let us not rely on this property yet.

Imagine numbered balls scattered across different shelves (the tree nodes). The brute force approach is to gather every ball into a single pile, sort the pile from smallest to largest, and then count to the kth ball. It is straightforward but wasteful, because it collects and sorts everything even though only one specific element is needed.

We perform a complete traversal of the tree (any order — preorder, postorder, or inorder), store all values in an array, sort the array, and return the element at index k-1 (since arrays are 0-indexed).
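To make the "any order" point concrete, here is a hedged Python sketch that collects values breadth-first (level order) instead of preorder and still gets the right answer, since the sort does all the work. The function name `kth_smallest_bfs_sort` is hypothetical, and the tree is Example 2 built by hand.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_bfs_sort(root: Optional[TreeNode], k: int) -> int:
    """Brute force: collect values in level order, sort them, index k - 1."""
    vals = []
    queue = deque([root] if root is not None else [])
    while queue:
        node = queue.popleft()
        vals.append(node.val)
        if node.left is not None:
            queue.append(node.left)
        if node.right is not None:
            queue.append(node.right)
    # Level order is not sorted ([5, 3, 6, 2, 4, 1] here), so the sort does all the work.
    vals.sort()
    return vals[k - 1]


# Example 2: root = [5, 3, 6, 2, 4, null, null, 1], k = 3
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest_bfs_sort(root, 3))  # 3
```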

Step-by-Step Explanation

Let's trace with root = [5, 3, 6, 2, 4, null, null, 1], k = 3:

        5
       / \
      3   6
     / \
    2   4
   /
  1

Step 1: Start a traversal to collect all node values. We will use a simple preorder traversal (root → left → right).

Step 2: Visit node 5. Add 5 to the list. List: [5]

Step 3: Visit node 3. Add 3 to the list. List: [5, 3]

Step 4: Visit node 2. Add 2 to the list. List: [5, 3, 2]

Step 5: Visit node 1. Add 1 to the list. List: [5, 3, 2, 1]

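Step 6: Node 1 has no children, and node 2 has no right child. Backtrack to node 3 and go right. Visit node 4. Add 4 to the list. List: [5, 3, 2, 1, 4]

Step 7: Node 4 has no children. Backtrack to the root and go right. Visit node 6. Add 6 to the list. List: [5, 3, 2, 1, 4, 6]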

Step 8: All nodes collected. Sort the list: [1, 2, 3, 4, 5, 6]

Step 9: k = 3, so return the element at index 2 (0-indexed): the value 3.

Result: 3
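The trace can be checked mechanically. This small Python sketch (helper names are illustrative) reproduces the preorder collection order from Steps 2 to 7 and the sorted list from Step 8:

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def preorder(node):
    """Root, then left subtree, then right subtree."""
    if node is None:
        return []
    return [node.val] + preorder(node.left) + preorder(node.right)


root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
collected = preorder(root)      # Steps 2 to 7
print(collected)                # [5, 3, 2, 1, 4, 6]
ordered = sorted(collected)     # Step 8
print(ordered)                  # [1, 2, 3, 4, 5, 6]
print(ordered[3 - 1])           # Step 9 with k = 3: 3
```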


Algorithm

  1. Traverse the entire tree (any traversal order) and collect all node values into an array
  2. Sort the array in ascending order
  3. Return the element at index k-1 (converting from 1-indexed to 0-indexed)

Code

C++

#include <algorithm>
#include <vector>
using namespace std;

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    void collect(TreeNode* root, vector<int>& vals) {
        if (root == nullptr) return;
        vals.push_back(root->val);
        collect(root->left, vals);
        collect(root->right, vals);
    }
    
    int kthSmallest(TreeNode* root, int k) {
        vector<int> vals;
        collect(root, vals);
        sort(vals.begin(), vals.end());
        return vals[k - 1];
    }
};

Python

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        vals = []
        
        def collect(node):
            if node is None:
                return
            vals.append(node.val)
            collect(node.left)
            collect(node.right)
        
        collect(root)
        vals.sort()
        return vals[k - 1]

Java

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        List<Integer> vals = new ArrayList<>();
        collect(root, vals);
        Collections.sort(vals);
        return vals.get(k - 1);
    }
    
    private void collect(TreeNode root, List<Integer> vals) {
        if (root == null) return;
        vals.add(root.val);
        collect(root.left, vals);
        collect(root.right, vals);
    }
}

Complexity Analysis

Time Complexity: O(n log n)

Traversing the tree takes O(n) time since we visit each of the n nodes once. Sorting the collected values takes O(n log n). The overall time complexity is dominated by the sorting step, giving us O(n log n).

Space Complexity: O(n)

We store all n node values in an array, which requires O(n) space. Additionally, the recursion stack uses O(h) space where h is the tree height. The total space is O(n).

Why This Approach Is Not Efficient

The brute force approach completely ignores the binary search tree property. In a BST, the inorder traversal (left → root → right) already produces values in sorted ascending order. By collecting all values and then sorting, we are doing redundant work — the tree already encodes the sorted order in its structure.

With n up to 10^4, the O(n log n) sort is fast enough, but we can do better. Since inorder traversal gives us sorted values for free, we just need to perform an inorder traversal and count to the kth element. This eliminates the sorting step entirely.

Furthermore, we do not need to traverse the entire tree. Once we have found the kth element during inorder traversal, we can stop immediately. This means for small values of k, we might visit far fewer than n nodes.

Better Approach - Inorder Traversal (Full)

Intuition

The fundamental property of a BST is that an inorder traversal visits nodes in ascending order of their values. This means if we perform an inorder traversal and store the results, we get a sorted array without needing to sort.
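This property holds for any BST, not just the examples. As a quick hedged check in Python, the sketch below inserts 200 distinct random values into a BST with ordinary BST insertion and confirms that the inorder traversal equals the sorted input. `insert` is a hypothetical helper written for this sketch and assumes distinct values.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def insert(node: Optional[TreeNode], val: int) -> TreeNode:
    """Ordinary BST insertion; this sketch assumes all values are distinct."""
    if node is None:
        return TreeNode(val)
    if val < node.val:
        node.left = insert(node.left, val)
    else:
        node.right = insert(node.right, val)
    return node


def inorder(node: Optional[TreeNode], out: List[int]) -> None:
    """Append values in left -> node -> right order."""
    if node is None:
        return
    inorder(node.left, out)
    out.append(node.val)
    inorder(node.right, out)


random.seed(0)                              # fixed seed for a repeatable run
values = random.sample(range(1000), 200)    # 200 distinct values in random order
root = None
for v in values:
    root = insert(root, v)

result: List[int] = []
inorder(root, result)
print(result == sorted(values))  # True
```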

Think of a BST like a library where books are organized by number. The leftmost shelf has the smallest numbers, and as you walk right, the numbers increase. An inorder traversal is like walking through the library in order — you naturally encounter books from smallest to largest. To find the kth smallest book, just walk and count until you reach k.

In this approach, we perform a complete inorder traversal, store all values in an array, and return the element at index k-1. While this avoids sorting, it still traverses the entire tree even when k is small.

Step-by-Step Explanation

Let's trace with root = [5, 3, 6, 2, 4, null, null, 1], k = 3:

        5
       / \
      3   6
     / \
    2   4
   /
  1

Step 1: Start inorder traversal. Go left from 5 → 3 → 2 → 1.

Step 2: Node 1 has no left child. Visit 1. Inorder list: [1]. No right child. Backtrack.

Step 3: Back at node 2. Visit 2. Inorder list: [1, 2]. No right child. Backtrack.

Step 4: Back at node 3. Visit 3. Inorder list: [1, 2, 3]. Now go right.

Step 5: At node 4. No left child. Visit 4. Inorder list: [1, 2, 3, 4]. No right child. Backtrack.

Step 6: Back at root 5. Visit 5. Inorder list: [1, 2, 3, 4, 5]. Now go right.

Step 7: At node 6. No left child. Visit 6. Inorder list: [1, 2, 3, 4, 5, 6]. No right child.

Step 8: Traversal complete. The inorder list is [1, 2, 3, 4, 5, 6]. Return element at index k-1 = 2, which is 3.

Result: 3


Algorithm

  1. Perform an inorder traversal of the BST
  2. During traversal, collect every node's value into an array
  3. The array will be in sorted ascending order (BST property)
  4. Return the element at index k-1

Code

C++

#include <vector>
using namespace std;

class Solution {
public:
    void inorder(TreeNode* root, vector<int>& vals) {
        if (root == nullptr) return;
        inorder(root->left, vals);
        vals.push_back(root->val);
        inorder(root->right, vals);
    }
    
    int kthSmallest(TreeNode* root, int k) {
        vector<int> vals;
        inorder(root, vals);
        return vals[k - 1];
    }
};

Python

from typing import Optional

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        vals = []
        
        def inorder(node):
            if node is None:
                return
            inorder(node.left)
            vals.append(node.val)
            inorder(node.right)
        
        inorder(root)
        return vals[k - 1]

Java

import java.util.ArrayList;
import java.util.List;

class Solution {
    public int kthSmallest(TreeNode root, int k) {
        List<Integer> vals = new ArrayList<>();
        inorder(root, vals);
        return vals.get(k - 1);
    }
    
    private void inorder(TreeNode root, List<Integer> vals) {
        if (root == null) return;
        inorder(root.left, vals);
        vals.add(root.val);
        inorder(root.right, vals);
    }
}

Complexity Analysis

Time Complexity: O(n)

We visit every node exactly once during the inorder traversal. Each visit involves O(1) work (appending to a list). No sorting step is needed since inorder traversal of a BST naturally produces sorted output. Total time: O(n).

Space Complexity: O(n)

We store all n values in an array, requiring O(n) space. The recursion stack also uses O(h) space where h is the height of the tree. The total is O(n).

Why This Approach Is Not Efficient

While this approach eliminates the sorting step by leveraging the BST property, it still has two inefficiencies:

  1. It always traverses the entire tree. Even if k = 1 (we only need the smallest element), we visit all n nodes: in a tree with 10,000 nodes, every node is visited although the answer is simply the leftmost one.

  2. It stores all values in an array. We use O(n) extra space for the array, even though we only need one value from it.

The key insight for optimization: if we perform the inorder traversal but count nodes as we visit them, we can stop as soon as we reach the kth node. This way, we process only k nodes (plus the nodes passed on the way down to the leftmost node), and we do not need to store the sorted list at all, just a counter alongside the traversal stack.
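The counting idea also works in recursive form. Below is a minimal Python sketch, assuming the same `TreeNode` shape as the solutions above; `kth_smallest_recursive` and its inner `visit` helper are illustrative names. Each call returns `True` once the answer is found, so every pending call unwinds without exploring more nodes. The only extra state is a counter and the answer, although the recursion stack still costs O(h).

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_recursive(root: Optional[TreeNode], k: int) -> Optional[int]:
    """Recursive inorder that stops at the kth visited node (None if k > n)."""
    count = 0
    answer = None

    def visit(node: Optional[TreeNode]) -> bool:
        # Returns True once the answer is found, so callers stop immediately.
        nonlocal count, answer
        if node is None:
            return False
        if visit(node.left):
            return True
        count += 1
        if count == k:
            answer = node.val
            return True
        return visit(node.right)

    visit(root)
    return answer


root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest_recursive(root, 3))  # 3
```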

Optimal Approach - Iterative Inorder with Early Termination

Intuition

We can perform an inorder traversal iteratively using an explicit stack, and maintain a counter that increments each time we "visit" a node (the moment we process it in the left → node → right order). The instant the counter reaches k, we return that node's value and stop — no need to continue traversing.

Think of it like counting runners in a race. You stand at the finish line and count: "first place, second place, third place..." Once you count the kth finisher, you announce the winner and leave. You do not need to wait for every runner to finish.

The iterative approach uses a stack to simulate the recursion. We push nodes as we go left, then pop and process them (incrementing the counter), then move to the right child. This mirrors the recursive inorder traversal but gives us the ability to stop mid-traversal.
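Python offers another way to get the same stop-anywhere behavior: a generator suspends at each `yield`, so the caller can take the kth value and simply stop asking for more. This is a swapped-in Python idiom, not the method used in the code below; the generator still walks the tree with the same explicit stack. `inorder_values` and `kth_smallest_lazy` are hypothetical names.

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(root: Optional[TreeNode]) -> Iterator[int]:
    """Lazily yield BST values in ascending order using an explicit stack."""
    stack = []
    curr = root
    while curr is not None or stack:
        while curr is not None:
            stack.append(curr)
            curr = curr.left
        curr = stack.pop()
        yield curr.val  # execution pauses here until the next value is requested
        curr = curr.right


def kth_smallest_lazy(root: Optional[TreeNode], k: int) -> int:
    # islice skips the first k - 1 values; next() pulls exactly one more.
    return next(islice(inorder_values(root), k - 1, None))


root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest_lazy(root, 3))  # 3
```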

Step-by-Step Explanation

Let's trace with root = [5, 3, 6, 2, 4, null, null, 1], k = 3:

        5
       / \
      3   6
     / \
    2   4
   /
  1

Step 1: Initialize empty stack and set curr = root (5). Count = 0.

Step 2: curr = 5 is not null. Push 5 onto stack. Move left: curr = 3. Stack: [5]

Step 3: curr = 3 is not null. Push 3 onto stack. Move left: curr = 2. Stack: [5, 3]

Step 4: curr = 2 is not null. Push 2 onto stack. Move left: curr = 1. Stack: [5, 3, 2]

Step 5: curr = 1 is not null. Push 1 onto stack. Move left: curr = null. Stack: [5, 3, 2, 1]

Step 6: curr is null. Pop from stack: node 1. Count = 1. Is count == k (3)? No. Move right: curr = null (node 1 has no right child). Stack: [5, 3, 2]

Step 7: curr is null. Pop from stack: node 2. Count = 2. Is count == k (3)? No. Move right: curr = null (node 2 has no right child). Stack: [5, 3]

Step 8: curr is null. Pop from stack: node 3. Count = 3. Is count == k (3)? YES! Return 3 immediately.

Result: 3. Only 3 of the 6 nodes were popped and counted; node 5 was pushed but never processed, and nodes 4 and 6 were never reached. We stopped as soon as we found the answer.
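The trace above can be reproduced by instrumenting the loop to log every push and pop. This is a sketch for illustration only: `kth_smallest_traced` is a hypothetical name, and it raises an error instead of returning -1 when k is out of range.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_traced(root: Optional[TreeNode], k: int) -> Tuple[int, List[int], List[int]]:
    """Iterative early-stop inorder that also returns the push and pop logs."""
    stack, curr, count = [], root, 0
    pushed, popped = [], []
    while curr is not None or stack:
        while curr is not None:          # descend left, pushing each node
            stack.append(curr)
            pushed.append(curr.val)
            curr = curr.left
        curr = stack.pop()               # visit the next node in sorted order
        popped.append(curr.val)
        count += 1
        if count == k:
            return curr.val, pushed, popped
        curr = curr.right
    raise ValueError("k is larger than the number of nodes")


root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest_traced(root, 3))  # (3, [5, 3, 2, 1], [1, 2, 3])
```

The logs match Steps 2 to 8: nodes 4 and 6 are never pushed, and node 5 is pushed but never popped.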


Algorithm

  1. Initialize an empty stack and set curr = root, count = 0
  2. While curr is not null or the stack is not empty:
    a. While curr is not null: push curr onto the stack and move left (curr = curr.left)
    b. Pop the top node from the stack
    c. Increment count by 1
    d. If count equals k, return the popped node's value
    e. Move right: set curr = popped node's right child
  3. Return -1 (should never reach here if k is valid)

Code

C++

#include <stack>
using namespace std;

class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        stack<TreeNode*> stk;
        TreeNode* curr = root;
        int count = 0;
        
        while (curr != nullptr || !stk.empty()) {
            while (curr != nullptr) {
                stk.push(curr);
                curr = curr->left;
            }
            
            curr = stk.top();
            stk.pop();
            count++;
            
            if (count == k) {
                return curr->val;
            }
            
            curr = curr->right;
        }
        
        return -1;
    }
};

Python

from typing import Optional

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        stack = []
        curr = root
        count = 0
        
        while curr is not None or stack:
            while curr is not None:
                stack.append(curr)
                curr = curr.left
            
            curr = stack.pop()
            count += 1
            
            if count == k:
                return curr.val
            
            curr = curr.right
        
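        return -1

Java

import java.util.ArrayDeque;
import java.util.Deque;

class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // ArrayDeque is the usual stack choice; java.util.Stack is a legacy synchronized class
        Deque<TreeNode> stack = new ArrayDeque<>();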
        TreeNode curr = root;
        int count = 0;
        
        while (curr != null || !stack.isEmpty()) {
            while (curr != null) {
                stack.push(curr);
                curr = curr.left;
            }
            
            curr = stack.pop();
            count++;
            
            if (count == k) {
                return curr.val;
            }
            
            curr = curr.right;
        }
        
        return -1;
    }
}

Complexity Analysis

Time Complexity: O(H + k) where H is the height of the tree

We first descend to the leftmost node, which takes O(H) steps. Then we visit k nodes in inorder before stopping. The total is O(H + k). In the best case (balanced tree, small k), this is O(log n + k). In the worst case (k = n), this degenerates to O(n), which is no worse than the previous approach.
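The O(H + k) bound is easy to observe by counting stack operations. In this hedged Python sketch, `count_steps`, `balanced`, and `left_chain` are hypothetical helpers: the first runs the same early-stop loop and reports (pushes, pops), and the other two build a balanced BST and a left-skewed BST of 1,023 nodes each.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_steps(root: Optional[TreeNode], k: int) -> Tuple[int, int]:
    """Run the early-stop iterative inorder and return (pushes, pops)."""
    stack, curr = [], root
    pushes = pops = 0
    while curr is not None or stack:
        while curr is not None:
            stack.append(curr)
            pushes += 1
            curr = curr.left
        curr = stack.pop()
        pops += 1                # each pop is one visited node, so pops is the count
        if pops == k:
            return pushes, pops
        curr = curr.right
    raise ValueError("k exceeds the number of nodes")


def balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Balanced BST holding lo..hi by choosing the middle value as each root."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, balanced(lo, mid - 1), balanced(mid + 1, hi))


def left_chain(n: int) -> Optional[TreeNode]:
    """Left-skewed BST: root n, and each left child is one smaller (height n)."""
    root = None
    for v in range(1, n + 1):
        root = TreeNode(v, left=root)
    return root


print(count_steps(balanced(1, 1023), 1))  # (10, 1)
print(count_steps(left_chain(1023), 1))   # (1023, 1)
```

For k = 1 the work is dominated by the descent to the leftmost node: 10 pushes in the balanced tree versus 1,023 in the chain.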

Space Complexity: O(H)

The stack holds at most H nodes at any time (a subset of the path from the root to the current node). For a balanced tree, H = O(log n). For a skewed tree, H can be O(n). Unlike the previous approach, we do NOT store all n values; we only keep the traversal stack, which is far smaller than n when the tree is reasonably balanced.
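Peak memory can be measured the same way. The sketch below (`peak_stack_size`, `balanced`, and `left_chain` are hypothetical helpers) runs the loop through every node, the worst case k = n, records the largest stack size seen, and compares it with the n = 1,023 values that the full-inorder approach would store.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def peak_stack_size(root: Optional[TreeNode], k: int) -> int:
    """Run the early-stop iterative inorder and return the largest stack size."""
    stack, curr, count, peak = [], root, 0, 0
    while curr is not None or stack:
        while curr is not None:
            stack.append(curr)
            peak = max(peak, len(stack))
            curr = curr.left
        curr = stack.pop()
        count += 1
        if count == k:
            return peak
        curr = curr.right
    raise ValueError("k exceeds the number of nodes")


def balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Balanced BST holding lo..hi by choosing the middle value as each root."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, balanced(lo, mid - 1), balanced(mid + 1, hi))


def left_chain(n: int) -> Optional[TreeNode]:
    """Left-skewed BST of height n (root n, each left child one smaller)."""
    root = None
    for v in range(1, n + 1):
        root = TreeNode(v, left=root)
    return root


n = 1023
print(peak_stack_size(balanced(1, n), n))  # 10
print(peak_stack_size(left_chain(n), n))   # 1023
```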