All Nodes Distance K in Binary Tree

Difficulty: Medium

Description

Given the root of a binary tree, the value of a target node, and an integer k, return an array of the values of all nodes that have a distance k from the target node.

The distance between two nodes in a binary tree is the number of edges on the path connecting them. A node at distance k from the target could be:

  • A descendant of the target (reachable by going downward through children)
  • An ancestor of the target (reachable by going upward through the parent)
  • In a completely different branch (reachable by going up to an ancestor, then down a different subtree)

You can return the answer in any order.

[Figure: A binary tree with the target node highlighted, showing nodes at distance k in all three directions: downward in the subtree, upward to ancestors, and across to other branches.]

Examples

Example 1

Input: root = [3, 5, 1, 6, 2, 0, 8, null, null, 7, 4], target = 5, k = 2

Tree structure:

          3
         / \
        5    1
       / \  / \
      6   2 0   8
         / \
        7   4

Output: [7, 4, 1]

Explanation: The nodes at distance 2 from target node 5 are:

  • Node 7: target (5) → right child (2) → left child (7). Two edges downward.
  • Node 4: target (5) → right child (2) → right child (4). Two edges downward.
  • Node 1: target (5) → parent (3) → right child (1). One edge up, then one edge down into the sibling branch.

All three nodes are exactly 2 edges away from node 5.
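
The input list is a level-order encoding of the tree, with null marking a missing child. As a rough sketch of how such a list could be turned into nodes for local testing, the helper below rebuilds the Example 1 tree. The TreeNode class and the build_tree name are assumptions made for illustration, not part of the problem statement.

```python
from collections import deque
from typing import Optional


class TreeNode:
    """Minimal node type (assumed shape: val, left, right)."""

    def __init__(self, val: int = 0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: list) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries describe this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


# Example 1: node 2 (right child of 5) ends up with children 7 and 4.
root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
```

Only nodes that actually exist are queued, so the two null entries are consumed as the children of node 6 and the following 7 and 4 attach to node 2.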

Example 2

Input: root = [1], target = 1, k = 3

Tree structure:

  1

Output: []

Explanation: The tree has only one node, so no node is 3 edges away from it. Whenever k exceeds the distance from the target to the farthest node in the tree, the result is empty.

Constraints

  • The number of nodes in the tree is in the range [1, 500]
  • 0 ≤ Node.val ≤ 500
  • All the values Node.val are unique
  • target is the value of one of the nodes in the tree
  • 0 ≤ k ≤ 1000

Editorial

Brute Force

Intuition

The most straightforward way to find all nodes at distance k from the target is: for every node in the tree, compute its distance to the target, and check if that distance equals k.

How do we find the distance between two nodes in a binary tree? We use their Lowest Common Ancestor (LCA). The LCA of two nodes a and b is the deepest node that is an ancestor of both. Once we find the LCA, the distance is:

$$\text{distance}(a, b) = \text{depth}(a, \text{LCA}) + \text{depth}(b, \text{LCA})$$

where depth(node, LCA) is the number of edges from the LCA down to that node.
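
As a quick check of the formula on Example 1, take a = 5 (the target). For b = 1 the LCA is the root 3, which sits one edge above each node. For b = 7 the LCA is node 5 itself, so the first term is zero:

```latex
\text{distance}(5, 1) = \text{depth}(5, 3) + \text{depth}(1, 3) = 1 + 1 = 2
\qquad
\text{distance}(5, 7) = \text{depth}(5, 5) + \text{depth}(7, 5) = 0 + 2 = 2
```

Both results match the expected output, which lists 1 and 7 among the nodes at distance 2.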

The algorithm is:

  1. Traverse every node v in the tree.
  2. For each v, find the LCA of v and target.
  3. Compute the distance from v to target using the LCA.
  4. If the distance equals k, add v to the result.

Finding the LCA of two nodes takes O(N) in the worst case (traversing the tree), and we do this for each of the N nodes. So the total time is O(N²).

This approach is conceptually simple — it reduces the problem to repeated distance queries — but it performs a lot of redundant work because each LCA computation re-traverses much of the tree.

Algorithm

  1. Define a helper function findLCA(root, p, q) that returns the LCA of nodes p and q.
  2. Define a helper function findDist(root, target, depth) that returns the number of edges from root down to the target node within the subtree rooted at root, or -1 if the target is not found.
  3. For each node v in the tree:
    a. Find lca = findLCA(root, target, v).
    b. Find d1 = findDist(lca, target, 0) (distance from LCA to target).
    c. Find d2 = findDist(lca, v, 0) (distance from LCA to v).
    d. If d1 + d2 == k, add v.val to the result.
  4. Return the result list.

Code

C++

class Solution {
public:
    // Find LCA of nodes p and q
    TreeNode* findLCA(TreeNode* root, TreeNode* p, TreeNode* q) {
        if (!root || root == p || root == q) return root;
        TreeNode* left = findLCA(root->left, p, q);
        TreeNode* right = findLCA(root->right, p, q);
        if (left && right) return root;
        return left ? left : right;
    }
    
    // Find distance from root to target node
    int findDist(TreeNode* root, TreeNode* target, int depth) {
        if (!root) return -1;
        if (root == target) return depth;
        int left = findDist(root->left, target, depth + 1);
        if (left != -1) return left;
        return findDist(root->right, target, depth + 1);
    }
    
    // Collect all nodes into a list
    void collectNodes(TreeNode* root, vector<TreeNode*>& nodes) {
        if (!root) return;
        nodes.push_back(root);
        collectNodes(root->left, nodes);
        collectNodes(root->right, nodes);
    }
    
    vector<int> distanceK(TreeNode* root, TreeNode* target, int k) {
        vector<int> result;
        vector<TreeNode*> allNodes;
        collectNodes(root, allNodes);
        
        for (TreeNode* node : allNodes) {
            TreeNode* lca = findLCA(root, target, node);
            int d1 = findDist(lca, target, 0);
            int d2 = findDist(lca, node, 0);
            if (d1 + d2 == k) {
                result.push_back(node->val);
            }
        }
        return result;
    }
};

Python

class Solution:
    def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> list[int]:
        
        def find_lca(node, p, q):
            """Find lowest common ancestor of p and q."""
            if not node or node == p or node == q:
                return node
            left = find_lca(node.left, p, q)
            right = find_lca(node.right, p, q)
            if left and right:
                return node
            return left if left else right
        
        def find_dist(node, target, depth):
            """Find distance from node to target."""
            if not node:
                return -1
            if node == target:
                return depth
            left = find_dist(node.left, target, depth + 1)
            if left != -1:
                return left
            return find_dist(node.right, target, depth + 1)
        
        def collect_nodes(node, nodes):
            """Collect all nodes in the tree."""
            if not node:
                return
            nodes.append(node)
            collect_nodes(node.left, nodes)
            collect_nodes(node.right, nodes)
        
        all_nodes = []
        collect_nodes(root, all_nodes)
        
        result = []
        for node in all_nodes:
            lca = find_lca(root, target, node)
            d1 = find_dist(lca, target, 0)
            d2 = find_dist(lca, node, 0)
            if d1 + d2 == k:
                result.append(node.val)
        
        return result

Java

class Solution {
    private TreeNode findLCA(TreeNode root, TreeNode p, TreeNode q) {
        if (root == null || root == p || root == q) return root;
        TreeNode left = findLCA(root.left, p, q);
        TreeNode right = findLCA(root.right, p, q);
        if (left != null && right != null) return root;
        return left != null ? left : right;
    }
    
    private int findDist(TreeNode root, TreeNode target, int depth) {
        if (root == null) return -1;
        if (root == target) return depth;
        int left = findDist(root.left, target, depth + 1);
        if (left != -1) return left;
        return findDist(root.right, target, depth + 1);
    }
    
    private void collectNodes(TreeNode root, List<TreeNode> nodes) {
        if (root == null) return;
        nodes.add(root);
        collectNodes(root.left, nodes);
        collectNodes(root.right, nodes);
    }
    
    public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
        List<TreeNode> allNodes = new ArrayList<>();
        collectNodes(root, allNodes);
        
        List<Integer> result = new ArrayList<>();
        for (TreeNode node : allNodes) {
            TreeNode lca = findLCA(root, target, node);
            int d1 = findDist(lca, target, 0);
            int d2 = findDist(lca, node, 0);
            if (d1 + d2 == k) {
                result.add(node.val);
            }
        }
        return result;
    }
}

Complexity Analysis

Time Complexity: O(N²)

For each of the N nodes, we call findLCA which takes O(N) in the worst case, and findDist which also takes O(N). So each node requires O(N) work, and we process N nodes. Total: O(N) × O(N) = O(N²).

For the given constraint of N ≤ 500, this is on the order of 250,000 node visits (times a small constant factor), which is acceptable but wasteful.

Space Complexity: O(N)

The recursion stack for LCA and distance computations can go up to O(H) where H is the height. The allNodes list stores N nodes. Total: O(N).

Why This Approach Is Not Efficient

The brute force approach performs redundant tree traversals. For every single node, we re-traverse the tree to find the LCA and compute distances. This is wasteful because:

  1. Repeated LCA computations: The LCA of target and a node deep in the left subtree is the same for many nodes in that subtree. We recompute it from scratch each time.

  2. No structural insight: We treat each node independently, ignoring the fact that tree distance has a clear structure — nodes at distance k form a "ring" around the target.

  3. The core problem: In a binary tree, we can only traverse downward (parent → children). But nodes at distance k from the target can be in any direction, including upward and across branches. The brute force works around this by restarting from the root for every candidate node to locate the LCA, but there is a much cleaner approach.

Key insight: If we could traverse the tree in all directions — not just downward — we could simply start at the target and explore outward for exactly k steps, like BFS on a graph. The only thing stopping us is the lack of parent pointers.

What if we built those parent pointers ourselves? One DFS to map every node to its parent, then the tree becomes an undirected graph. From there, a simple BFS from the target — expanding to left child, right child, and parent — finds all nodes at distance k in O(N) time.

Optimal Approach - Parent Map + BFS from Target

Intuition

The fundamental challenge of this problem is that a binary tree only supports downward traversal (from parent to children). But nodes at distance k from the target can be in any direction — down into the target's subtree, up to ancestors, or across into sibling branches.

The elegant solution: convert the tree into an undirected graph by adding parent pointers, then perform a standard BFS from the target.

Phase 1: Build the Parent Map

Traverse the entire tree using DFS. For each node, store a mapping: parent[node] = node's parent. This one-time traversal gives every node the ability to "look up" at its parent.
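
Phase 1 can be sketched on its own. The version below uses an explicit stack instead of recursion, which is one possible variant rather than the editorial's method. The TreeNode class and the build_parent_map name are assumptions for illustration.

```python
class TreeNode:
    """Minimal node type (assumed shape: val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_parent_map(root):
    """Map every node to its parent; the root maps to None."""
    if root is None:
        return {}
    parent = {root: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node  # record the upward link
                stack.append(child)
    return parent


# Example 1 tree, wired by hand.
n2 = TreeNode(2, TreeNode(7), TreeNode(4))
n5 = TreeNode(5, TreeNode(6), n2)
n1 = TreeNode(1, TreeNode(0), TreeNode(8))
root = TreeNode(3, n5, n1)

parent = build_parent_map(root)
# Re-key by value for readability (values are unique per the constraints).
by_value = {
    node.val: (par.val if par is not None else None)
    for node, par in parent.items()
}
```

The map is keyed by node objects rather than values, so it would still work if values were not unique. An explicit stack also avoids deep recursion on a badly skewed tree.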

Phase 2: BFS from Target

Now treat the tree as an undirected graph. From any node, we can move in three directions:

  • Left child (downward-left)
  • Right child (downward-right)
  • Parent (upward)

Start a BFS from the target node. At each step, expand to all three neighbors. Use a visited set to avoid revisiting nodes (which would cause infinite loops, since the graph is now undirected). After exactly k levels of BFS, all nodes in the queue are at distance k from the target.
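
To see the "tree as undirected graph" idea in isolation, here is a minimal sketch that runs the level-by-level BFS over a plain adjacency dictionary keyed by node value. The nodes_at_distance function and the edge list are illustrative assumptions. Keying by value relies on the constraint that all values are unique.

```python
from collections import deque


def nodes_at_distance(adj, start, k):
    """Return the nodes exactly k edges from start in an undirected graph."""
    queue = deque([start])
    visited = {start}
    for _level in range(k):
        # Expand exactly one BFS level per outer iteration.
        for _ in range(len(queue)):
            node = queue.popleft()
            for neighbor in adj.get(node, ()):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
        if not queue:
            break  # ran out of nodes before reaching distance k
    return list(queue)


# Parent-child edges of the Example 1 tree, stored in both directions.
edges = [(3, 5), (3, 1), (5, 6), (5, 2), (1, 0), (1, 8), (2, 7), (2, 4)]
adj = {}
for a, b in edges:
    adj.setdefault(a, []).append(b)
    adj.setdefault(b, []).append(a)

print(sorted(nodes_at_distance(adj, 5, 2)))  # [1, 4, 7]
```

Once each edge is stored in both directions, the BFS no longer needs to know which neighbor is a child and which is the parent.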

This approach is clean, modular, and easy to reason about:

  • Phase 1 runs in O(N) time
  • Phase 2 runs in O(N) time
  • Total: O(N) time, O(N) space

Step-by-Step Explanation

Let us trace through Example 1: root = [3, 5, 1, 6, 2, 0, 8, null, null, 7, 4], target = 5, k = 2

          3
         / \
        5    1
       / \  / \
      6   2 0   8
         / \
        7   4

Phase 1: Build Parent Map (DFS)

Step 1: Start DFS at root (node 3). It has no parent, so parent[3] = null.

Step 2: Visit node 5 (left child of 3). parent[5] = 3.

Step 3: Visit node 6 (left child of 5). parent[6] = 5. It's a leaf.

Step 4: Visit node 2 (right child of 5). parent[2] = 5.

Step 5: Visit node 7 (left child of 2). parent[7] = 2. Leaf.

Step 6: Visit node 4 (right child of 2). parent[4] = 2. Leaf.

Step 7: Back to root, visit node 1 (right child of 3). parent[1] = 3.

Step 8: Visit node 0 (left child of 1). parent[0] = 1. Leaf.

Step 9: Visit node 8 (right child of 1). parent[8] = 1. Leaf.

Parent map complete: {3: null, 5: 3, 6: 5, 2: 5, 7: 2, 4: 2, 1: 3, 0: 1, 8: 1}

Phase 2: BFS from Target (node 5), k = 2

Step 10: Initialize BFS. Queue = [5]. Visited = {5}. Distance = 0.

Step 11: Distance 0 → 1. Process node 5. Its three neighbors are: left child 6, right child 2, parent 3. None are visited. Enqueue all three. Queue = [6, 2, 3]. Visited = {5, 6, 2, 3}. Distance = 1.

Step 12: Distance 1 → 2. Process the entire level [6, 2, 3].

  • Node 6: neighbors are left=null, right=null, parent=5. Parent 5 is visited. No new nodes from 6.
  • Node 2: neighbors are left=7, right=4, parent=5. Parent 5 is visited. Enqueue 7 and 4. Visited adds {7, 4}.
  • Node 3: neighbors are left=5, right=1, parent=null. Left 5 is visited. Enqueue 1. Visited adds {1}.

Queue = [7, 4, 1]. Distance = 2.

Step 13: Distance equals k = 2. All nodes currently in the queue are at distance 2 from target. Collect their values: [7, 4, 1].

Result: [7, 4, 1].

[Animation: Parent Map + BFS from Target. Phase 1 builds parent pointers for every node; Phase 2 runs BFS from the target node, expanding through left child, right child, and parent until it reaches distance k.]

Algorithm

Phase 1: Build Parent Map

  1. Create an empty hash map parent.
  2. Run DFS from root:
    • For each node, set parent[node] = node's parent.
    • Root's parent is null.

Phase 2: BFS from Target

  3. Create a queue and add the target node. Create a visited set and add the target.
  4. Set distance = 0.
  5. While the queue is not empty:
    a. If distance == k, return all values in the queue. These are the answer.
    b. Process all nodes at the current level:
      • For each node, examine its three neighbors: left child, right child, parent (from the map).
      • If a neighbor exists and is not visited, mark it as visited and enqueue it.
    c. Increment distance.
  6. If BFS completes without reaching distance k, return an empty list.

Code

C++

#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <queue>
using namespace std;

class Solution {
public:
    vector<int> distanceK(TreeNode* root, TreeNode* target, int k) {
        // Phase 1: Build parent map using DFS
        unordered_map<TreeNode*, TreeNode*> parent;
        buildParentMap(root, nullptr, parent);
        
        // Phase 2: BFS from target
        queue<TreeNode*> q;
        unordered_set<TreeNode*> visited;
        q.push(target);
        visited.insert(target);
        
        int distance = 0;
        while (!q.empty()) {
            if (distance == k) {
                vector<int> result;
                while (!q.empty()) {
                    result.push_back(q.front()->val);
                    q.pop();
                }
                return result;
            }
            
            int size = q.size();
            for (int i = 0; i < size; i++) {
                TreeNode* node = q.front();
                q.pop();
                
                // Explore three directions
                for (TreeNode* neighbor : {node->left, node->right, parent[node]}) {
                    if (neighbor && visited.find(neighbor) == visited.end()) {
                        visited.insert(neighbor);
                        q.push(neighbor);
                    }
                }
            }
            distance++;
        }
        return {};
    }
    
private:
    void buildParentMap(TreeNode* node, TreeNode* par, 
                        unordered_map<TreeNode*, TreeNode*>& parent) {
        if (!node) return;
        parent[node] = par;
        buildParentMap(node->left, node, parent);
        buildParentMap(node->right, node, parent);
    }
};

Python

from collections import deque

class Solution:
    def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> list[int]:
        # Phase 1: Build parent map using DFS
        parent = {}
        
        def build_parent_map(node, par):
            if not node:
                return
            parent[node] = par
            build_parent_map(node.left, node)
            build_parent_map(node.right, node)
        
        build_parent_map(root, None)
        
        # Phase 2: BFS from target
        queue = deque([target])
        visited = {target}
        distance = 0
        
        while queue:
            if distance == k:
                return [node.val for node in queue]
            
            for _ in range(len(queue)):
                node = queue.popleft()
                
                # Explore three directions: left, right, parent
                for neighbor in (node.left, node.right, parent.get(node)):
                    if neighbor and neighbor not in visited:
                        visited.add(neighbor)
                        queue.append(neighbor)
            
            distance += 1
        
        return []

Java

import java.util.*;

class Solution {
    public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
        // Phase 1: Build parent map
        Map<TreeNode, TreeNode> parent = new HashMap<>();
        buildParentMap(root, null, parent);
        
        // Phase 2: BFS from target
        Queue<TreeNode> queue = new LinkedList<>();
        Set<TreeNode> visited = new HashSet<>();
        queue.offer(target);
        visited.add(target);
        
        int distance = 0;
        while (!queue.isEmpty()) {
            if (distance == k) {
                List<Integer> result = new ArrayList<>();
                for (TreeNode node : queue) {
                    result.add(node.val);
                }
                return result;
            }
            
            int size = queue.size();
            for (int i = 0; i < size; i++) {
                TreeNode node = queue.poll();
                
                // Explore three directions
                TreeNode[] neighbors = {node.left, node.right, parent.get(node)};
                for (TreeNode neighbor : neighbors) {
                    if (neighbor != null && !visited.contains(neighbor)) {
                        visited.add(neighbor);
                        queue.offer(neighbor);
                    }
                }
            }
            distance++;
        }
        return new ArrayList<>();
    }
    
    private void buildParentMap(TreeNode node, TreeNode par, Map<TreeNode, TreeNode> parent) {
        if (node == null) return;
        parent.put(node, par);
        buildParentMap(node.left, node, parent);
        buildParentMap(node.right, node, parent);
    }
}

Complexity Analysis

Time Complexity: O(N)

Phase 1 (build parent map): We visit each of the N nodes exactly once via DFS. Each visit does O(1) work (one hash map insertion). Total: O(N).

Phase 2 (BFS from target): In the worst case, we visit all N nodes before reaching distance k (if k is large). Each node is dequeued and processed at most once due to the visited set. For each node, we check 3 neighbors in O(1) time. Total: O(N).

Overall: O(N) + O(N) = O(N).

Space Complexity: O(N)

  • Parent map: stores one entry per node → O(N).
  • Visited set: stores up to N entries → O(N).
  • BFS queue: holds nodes from at most two adjacent BFS levels at any moment → O(N) in the worst case.
  • DFS recursion stack for Phase 1: O(H) where H is tree height.

Total: O(N).

Why this is optimal:

  • In the worst case every node must be examined at least once, since any node could lie at distance k from the target → Ω(N) time is a lower bound.
  • The parent map takes O(N) space; the nodes carry no parent pointers, so this approach has to store the upward links itself.
  • The BFS guarantees that we stop as soon as we reach distance k, avoiding unnecessary exploration.
  • The visited set prevents infinite cycles that would otherwise occur in the undirected graph.