
Construct Quad Tree

Difficulty: Medium

Description

Given an n × n matrix grid containing only 0s and 1s, represent the grid using a Quad-Tree and return the root of that tree.

A Quad-Tree is a tree data structure where each internal node has exactly four children: topLeft, topRight, bottomLeft, and bottomRight. Every node has two attributes:

  • val: True if the node represents a region of all 1s, False if the node represents a region of all 0s. When isLeaf is False, val can be set to any value.
  • isLeaf: True if the node is a leaf (the entire region it covers has the same value), False if the node has four children (the region contains a mix of 0s and 1s).

The construction follows these rules:

  1. If the current grid region has all the same values (all 0s or all 1s), create a leaf node with isLeaf = True and val equal to that uniform value.
  2. If the current grid region has mixed values, create an internal node with isLeaf = False, divide the region into four equal quadrants, and recursively build a child node for each quadrant.

The grid size n is always a power of 2, which guarantees that the grid can be evenly subdivided at each level until individual cells are reached.

A 4x4 grid being recursively subdivided into quadrants to form a Quad-Tree structure

Examples

Example 1

Input: grid = [[0, 1], [1, 0]]

Output: [[0, 1], [1, 0], [1, 1], [1, 1], [1, 0]]

Explanation: The 2×2 grid has mixed values (contains both 0s and 1s), so the root is not a leaf. We divide it into four 1×1 quadrants:

  • topLeft = grid[0][0] = 0 → leaf node with val=False
  • topRight = grid[0][1] = 1 → leaf node with val=True
  • bottomLeft = grid[1][0] = 1 → leaf node with val=True
  • bottomRight = grid[1][1] = 0 → leaf node with val=False

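The root node has isLeaf=False, val=True (arbitrary), and four leaf children. In the serialized output, nodes are listed in level order and each node is represented as [isLeaf, val], using 1 for True and 0 for False; null stands for the absent children of a leaf (see Example 2).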
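The serialization appears to be a level-order (breadth-first) listing in which each node contributes its four children, a leaf's missing children are written as null, and trailing nulls are dropped. That rule is inferred from the two examples here, not from a published specification. A minimal Python sketch under that assumption; `serialize` and the defaulted `Node` constructor are illustrative helpers, not part of the problem's API:

```python
from collections import deque
from typing import List, Optional


class Node:
    """Quad-tree node; the default arguments are a convenience assumed for this sketch."""

    def __init__(self, val, isLeaf, topLeft=None, topRight=None,
                 bottomLeft=None, bottomRight=None):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight


def serialize(root: Optional[Node]) -> List[Optional[List[int]]]:
    """Level-order listing of [isLeaf, val]; None marks a missing child."""
    if root is None:
        return []
    out = [[int(root.isLeaf), int(root.val)]]
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in (node.topLeft, node.topRight, node.bottomLeft, node.bottomRight):
            if child is None:
                out.append(None)          # leaf nodes have no children
            else:
                out.append([int(child.isLeaf), int(child.val)])
                queue.append(child)
    while out and out[-1] is None:        # drop trailing nulls
        out.pop()
    return out


# Example 1: internal root (val=True is arbitrary) over the four single cells.
root = Node(True, False,
            Node(False, True), Node(True, True),
            Node(True, True), Node(False, True))
print(serialize(root))  # [[0, 1], [1, 0], [1, 1], [1, 1], [1, 0]]
```

The same rule also reproduces the Example 2 output, where the four nulls after the root's children belong to the leaf in the top-left quadrant.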

Example 2

Input: grid = [[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0],[1,1,1,1,1,1,1,1],[1,1,1,1,1,1,1,1],[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0],[1,1,1,1,0,0,0,0]]

Output: [[0,1],[1,1],[0,1],[1,1],[1,0],null,null,null,null,[1,0],[1,0],[1,1],[1,1]]

Explanation: The 8×8 grid has mixed values, so the root is not a leaf. Dividing into four 4×4 quadrants:

  • topLeft (rows 0-3, cols 0-3): All 1s → leaf with val=True
  • topRight (rows 0-3, cols 4-7): Mixed values → not a leaf, subdivide further into four 2×2 quadrants:
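    • topLeft of topRight (rows 0-1, cols 4-5): all 0s → leaf with val=False
    • topRight of topRight (rows 0-1, cols 6-7): all 0s → leaf with val=False
    • bottomLeft of topRight (rows 2-3, cols 4-5): all 1s → leaf with val=True
    • bottomRight of topRight (rows 2-3, cols 6-7): all 1s → leaf with val=True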
  • bottomLeft (rows 4-7, cols 0-3): All 1s → leaf with val=True
  • bottomRight (rows 4-7, cols 4-7): All 0s → leaf with val=False
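This quadrant breakdown can be double-checked mechanically. A small Python sketch that labels each block of the Example 2 grid; the helper names `classify` and `quadrant_origins` are illustrative only:

```python
from typing import Dict, List, Tuple

# Example 2 input grid
example2 = [[1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 0, 0, 0, 0]]


def classify(grid: List[List[int]], row: int, col: int, size: int) -> str:
    """Label the size x size block whose top-left cell is (row, col)."""
    values = {grid[i][j] for i in range(row, row + size)
              for j in range(col, col + size)}
    if values == {1}:
        return "all 1s"
    if values == {0}:
        return "all 0s"
    return "mixed"


def quadrant_origins(row: int, col: int, size: int) -> Dict[str, Tuple[int, int]]:
    """Top-left corners of the four quadrants of a size x size block."""
    h = size // 2
    return {"topLeft": (row, col), "topRight": (row, col + h),
            "bottomLeft": (row + h, col), "bottomRight": (row + h, col + h)}


# Level 1: the four 4x4 quadrants of the 8x8 grid
for name, (r, c) in quadrant_origins(0, 0, 8).items():
    print(name, classify(example2, r, c, 4))

# Level 2: the four 2x2 quadrants inside topRight (rows 0-3, cols 4-7)
for name, (r, c) in quadrant_origins(0, 4, 4).items():
    print("topRight ->", name, classify(example2, r, c, 2))
```

Only topRight comes back as "mixed" at the first level, which is why it is the only quadrant that gets subdivided.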

Constraints

  • n == grid.length == grid[i].length
  • n == 2^x where 0 ≤ x ≤ 6
  • grid[i][j] is either 0 or 1

Editorial

Brute Force

Intuition

The most direct approach follows the Quad-Tree construction rules literally. For any given square region of the grid:

  1. Scan every cell in that region to check if all values are the same.
  2. If they are all the same, create a leaf node.
  3. If they differ, split the region into four equal quadrants and recursively process each one.

Imagine you have a large tile mosaic and want to describe it compactly. You first look at the entire mosaic — if every tile is the same color, you say "it's all blue" (one leaf). If not, you divide it into four quarters and describe each quarter separately. You keep subdividing until you reach sections that are uniform.

The expensive part is step 1: to determine if a region is uniform, we must check every single cell in it. For a region of size k × k, that scan costs O(k²). Since we might do this at multiple recursion levels, the total scanning work adds up.

Step-by-Step Explanation

Let's trace through with grid = [[0, 1], [1, 0]] (a 2×2 grid):

Step 1: Consider the full 2×2 region (rows 0-1, cols 0-1). Check if all values are the same.

  • Scan: grid[0][0]=0, grid[0][1]=1. Found different values (0 ≠ 1).
  • The region is NOT uniform. Create an internal node (isLeaf=False).

Step 2: Divide into four 1×1 quadrants.

  • topLeft: rows 0-0, cols 0-0
  • topRight: rows 0-0, cols 1-1
  • bottomLeft: rows 1-1, cols 0-0
  • bottomRight: rows 1-1, cols 1-1

Step 3: Process topLeft quadrant (1×1, single cell). grid[0][0] = 0.

  • A single cell is always uniform. Create leaf node: isLeaf=True, val=False (0).

Step 4: Process topRight quadrant (1×1). grid[0][1] = 1.

  • Create leaf node: isLeaf=True, val=True (1).

Step 5: Process bottomLeft quadrant (1×1). grid[1][0] = 1.

  • Create leaf node: isLeaf=True, val=True (1).

Step 6: Process bottomRight quadrant (1×1). grid[1][1] = 0.

  • Create leaf node: isLeaf=True, val=False (0).

Step 7: Attach all four leaf children to the root internal node.

  • Root: isLeaf=False, val=True (arbitrary)
  • Children: topLeft(leaf,0), topRight(leaf,1), bottomLeft(leaf,1), bottomRight(leaf,0)

Step 8: Return the root. Serialized: [[0,1],[1,0],[1,1],[1,1],[1,0]].


Algorithm

  1. Define a recursive function build(row, col, size) that builds a Quad-Tree for the sub-grid starting at (row, col) with side length size.
  2. Check uniformity: Scan all size × size cells in the region. If all values equal the first cell's value:
    • Return a leaf node with isLeaf=True and val = that uniform value.
  3. Subdivide: If the region is mixed:
    • Create an internal node with isLeaf=False.
    • Set half = size / 2.
    • Recursively build four children:
      • topLeft = build(row, col, half)
      • topRight = build(row, col + half, half)
      • bottomLeft = build(row + half, col, half)
      • bottomRight = build(row + half, col + half, half)
  4. Return the root node from build(0, 0, n).

Code

C++

#include <vector>
using namespace std;

/*
// Definition for a QuadTree node.
class Node {
public:
    bool val;
    bool isLeaf;
    Node* topLeft;
    Node* topRight;
    Node* bottomLeft;
    Node* bottomRight;
    
    Node() : val(false), isLeaf(false), topLeft(nullptr), topRight(nullptr),
             bottomLeft(nullptr), bottomRight(nullptr) {}
    
    Node(bool _val, bool _isLeaf) : val(_val), isLeaf(_isLeaf),
        topLeft(nullptr), topRight(nullptr),
        bottomLeft(nullptr), bottomRight(nullptr) {}
    
    Node(bool _val, bool _isLeaf, Node* _topLeft, Node* _topRight,
         Node* _bottomLeft, Node* _bottomRight)
        : val(_val), isLeaf(_isLeaf), topLeft(_topLeft), topRight(_topRight),
          bottomLeft(_bottomLeft), bottomRight(_bottomRight) {}
};
*/

class Solution {
public:
    bool isUniform(vector<vector<int>>& grid, int row, int col, int size) {
        int val = grid[row][col];
        for (int i = row; i < row + size; i++) {
            for (int j = col; j < col + size; j++) {
                if (grid[i][j] != val) return false;
            }
        }
        return true;
    }
    
    Node* build(vector<vector<int>>& grid, int row, int col, int size) {
        if (isUniform(grid, row, col, size)) {
            return new Node(grid[row][col] == 1, true);
        }
        
        int half = size / 2;
        Node* node = new Node(
            true, false,
            build(grid, row, col, half),
            build(grid, row, col + half, half),
            build(grid, row + half, col, half),
            build(grid, row + half, col + half, half)
        );
        return node;
    }
    
    Node* construct(vector<vector<int>>& grid) {
        return build(grid, 0, 0, grid.size());
    }
};

Python

"""
# Definition for a QuadTree node.
class Node:
    def __init__(self, val, isLeaf, topLeft, topRight, bottomLeft, bottomRight):
        self.val = val
        self.isLeaf = isLeaf
        self.topLeft = topLeft
        self.topRight = topRight
        self.bottomLeft = bottomLeft
        self.bottomRight = bottomRight
"""

from typing import List

class Solution:
    def construct(self, grid: List[List[int]]) -> 'Node':
        
        def is_uniform(row, col, size):
            val = grid[row][col]
            for i in range(row, row + size):
                for j in range(col, col + size):
                    if grid[i][j] != val:
                        return False
            return True
        
        def build(row, col, size):
            if is_uniform(row, col, size):
                # Leaf: the Node constructor above takes all six arguments
                return Node(grid[row][col] == 1, True, None, None, None, None)
            
            half = size // 2
            return Node(
                True, False,
                build(row, col, half),
                build(row, col + half, half),
                build(row + half, col, half),
                build(row + half, col + half, half)
            )
        
        return build(0, 0, len(grid))

Java

/*
// Definition for a QuadTree node.
class Node {
    public boolean val;
    public boolean isLeaf;
    public Node topLeft;
    public Node topRight;
    public Node bottomLeft;
    public Node bottomRight;

    public Node() {
        this.val = false;
        this.isLeaf = false;
    }

    public Node(boolean val, boolean isLeaf) {
        this.val = val;
        this.isLeaf = isLeaf;
    }

    public Node(boolean val, boolean isLeaf, Node topLeft, Node topRight,
                Node bottomLeft, Node bottomRight) {
        this.val = val;
        this.isLeaf = isLeaf;
        this.topLeft = topLeft;
        this.topRight = topRight;
        this.bottomLeft = bottomLeft;
        this.bottomRight = bottomRight;
    }
};
*/

class Solution {
    private boolean isUniform(int[][] grid, int row, int col, int size) {
        int val = grid[row][col];
        for (int i = row; i < row + size; i++) {
            for (int j = col; j < col + size; j++) {
                if (grid[i][j] != val) return false;
            }
        }
        return true;
    }
    
    private Node build(int[][] grid, int row, int col, int size) {
        if (isUniform(grid, row, col, size)) {
            return new Node(grid[row][col] == 1, true);
        }
        
        int half = size / 2;
        return new Node(
            true, false,
            build(grid, row, col, half),
            build(grid, row, col + half, half),
            build(grid, row + half, col, half),
            build(grid, row + half, col + half, half)
        );
    }
    
    public Node construct(int[][] grid) {
        return build(grid, 0, 0, grid.length);
    }
}

Complexity Analysis

Time Complexity: O(n² log n)

At the top level, we scan all n² cells to check uniformity. If the region is mixed, we split into four n/2 × n/2 quadrants, and each of those scans (n/2)² = n²/4 cells. At each recursion level, the total scanning work across all sub-problems is O(n²), and there are O(log n) levels of recursion (since the grid size halves each time and n = 2^x). This gives O(n² log n) total work.

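For a scan that always reads the whole region, this bound is tight: on a checkerboard pattern (every cell differs from its neighbors), every region must be subdivided down to single cells, so each of the log₂ n + 1 levels scans all n² cells. The isUniform check shown in the code, however, returns at the first mismatch, so a checkerboard is rejected after comparing just two cells per mixed region. For that early-exit version the expensive regions are the ones that are uniform or whose first mismatch appears late in row-major order, and O(n² log n) should be read as an upper bound.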
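To see how the early exit changes the picture, the sketch below instruments the brute-force recursion and counts every cell comparison made by the uniformity check. It is a Python illustration only; `count_scanned_cells` and its `early_exit` switch are hypothetical names, not part of the solutions in this editorial.

```python
from typing import List


def count_scanned_cells(grid: List[List[int]], early_exit: bool) -> int:
    """Run the brute-force recursion and return the total number of cell
    comparisons performed by the uniformity checks."""
    scanned = 0

    def is_uniform(row: int, col: int, size: int) -> bool:
        nonlocal scanned
        first = grid[row][col]
        uniform = True
        for i in range(row, row + size):
            for j in range(col, col + size):
                scanned += 1
                if grid[i][j] != first:
                    uniform = False
                    if early_exit:
                        return False      # stop at the first mismatch
        return uniform

    def build(row: int, col: int, size: int) -> None:
        if is_uniform(row, col, size):
            return                        # leaf: no further scanning
        half = size // 2
        build(row, col, half)
        build(row, col + half, half)
        build(row + half, col, half)
        build(row + half, col + half, half)

    build(0, 0, len(grid))
    return scanned


n = 8
checkerboard = [[(i + j) % 2 for j in range(n)] for i in range(n)]
one_odd = [[0] * n for _ in range(n)]
one_odd[n - 1][n - 1] = 1                 # all 0s except the bottom-right cell

print(count_scanned_cells(checkerboard, early_exit=True))   # 106
print(count_scanned_cells(checkerboard, early_exit=False))  # 256
print(count_scanned_cells(one_odd, early_exit=True))        # 148
```

Without the early exit, the checkerboard costs 256 = 8² × 4 comparisons (n² per level across log₂ 8 + 1 levels). With it, the checkerboard drops to 106, while the nearly uniform grid still costs 148 because each mixed region is only detected at its last cell.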

Space Complexity: O(n²)

In the worst case (for example, a checkerboard), every cell becomes its own leaf, so the tree holds n² leaves plus (n² − 1)/3 internal nodes, which is O(n²). The recursion stack depth is O(log n). The dominant cost is the tree itself: O(n²).
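Counting the nodes of a fully subdivided tree confirms the O(n²) bound. With n = 2^x, depth k holds 4^k nodes, so the total is a geometric series:

$$
\sum_{k=0}^{x} 4^{k} \;=\; \frac{4^{x+1}-1}{3} \;=\; \frac{4n^{2}-1}{3} \;=\; \underbrace{n^{2}}_{\text{leaves}} + \underbrace{\frac{n^{2}-1}{3}}_{\text{internal nodes}} \;=\; O(n^{2})
$$

For n = 8 this gives 85 nodes: 64 leaves and 21 internal nodes.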

Why This Approach Is Not Efficient

The brute force approach scans the same cells repeatedly. When a region turns out to be mixed, the uniformity check may already have scanned most of its cells, and then each of the four child regions scans its own portion again. A single cell can therefore be scanned once at every recursion level.

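For example, with a 64×64 grid whose mixed regions are only detected late in the scan, the top level scans up to 4,096 cells. If the grid is mixed, the four 32×32 quadrants together can scan those 4,096 cells again, and so on at every level below. This repeated scanning is what multiplies the work by up to a factor of O(log n).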

The key insight: instead of scanning every cell in a region to determine uniformity, we can precompute a prefix sum of the grid. A 2D prefix sum lets us calculate the sum of any rectangular sub-region in O(1) time. If the sum equals 0, all cells are 0 (uniform). If the sum equals size², all cells are 1 (uniform). Otherwise, the region is mixed. This eliminates the O(n²) scanning cost at each recursion level, reducing total time from O(n² log n) to O(n²).

Optimal Approach - Prefix Sum Optimization

Intuition

The bottleneck in the brute force is checking whether a region is uniform. We scan k² cells for a k×k region. If we could answer "what is the sum of all values in this rectangle?" in O(1) time, we could determine uniformity instantly:

  • Sum = 0 → all cells are 0 → uniform (val=False)
  • Sum = k² → all cells are 1 → uniform (val=True)
  • Otherwise → mixed → must subdivide

A 2D prefix sum gives us exactly this ability. We precompute a matrix prefix where prefix[i][j] stores the sum of all values in the sub-grid from (0, 0) to (i-1, j-1), inclusive; row 0 and column 0 are zero padding. Using inclusion-exclusion, the sum of any rectangular region can then be computed in O(1).
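With P standing for prefix, g for grid, and a k × k region whose top-left cell is (r, c), the construction recurrence and the inclusion-exclusion query are:

$$
P_{i,j} = g_{i-1,\,j-1} + P_{i-1,\,j} + P_{i,\,j-1} - P_{i-1,\,j-1}, \qquad P_{0,j} = P_{i,0} = 0
$$

$$
\operatorname{sum}(r, c, k) = P_{r+k,\,c+k} - P_{r,\,c+k} - P_{r+k,\,c} + P_{r,\,c}
$$

The two subtracted terms remove the block above the region and the block to its left; P_{r,c} is added back because the top-left block they share was subtracted twice.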

Think of it like a running tally at a warehouse. Instead of counting every box in a section each time someone asks, you record, for every shelf position, the total number of boxes in the rectangle stretching from the warehouse's front corner to that position. When asked "how many boxes are in rows 3-5, columns 2-4?", you compute the answer with just 4 lookups and a few additions and subtractions, with no need to walk through each shelf.

With O(1) uniformity checks, the recursion only does O(1) work per node, and the total number of nodes is at most O(n²). So the total time drops to O(n²) — the same as just reading the input.

Step-by-Step Explanation

Let's trace through with grid = [[0, 1], [1, 0]] (a 2×2 grid):

Step 1: Build the 2D prefix sum matrix.

  • prefix[0][0..2] = [0, 0, 0] (padding row)
  • prefix[1][0..2] = [0, 0, 1] → prefix[1][1] = grid[0][0] = 0, prefix[1][2] = 0 + 1 = 1
  • prefix[2][0..2] = [0, 1, 2] → prefix[2][1] = 0 + 1 = 1, prefix[2][2] = 0 + 1 + 1 + 0 = 2

Step 2: Start building from the full 2×2 region (row=0, col=0, size=2).

  • Compute region sum using prefix sum: sum = prefix[2][2] - prefix[0][2] - prefix[2][0] + prefix[0][0] = 2 - 0 - 0 + 0 = 2.
  • Is sum == 0? No. Is sum == 2² = 4? No (sum=2 ≠ 4). The region is mixed.
  • Create an internal node and subdivide.

Step 3: Process topLeft (row=0, col=0, size=1).

  • Sum = prefix[1][1] - prefix[0][1] - prefix[1][0] + prefix[0][0] = 0 - 0 - 0 + 0 = 0.
  • Sum == 0 → all zeros → leaf with val=False.

Step 4: Process topRight (row=0, col=1, size=1).

  • Sum = prefix[1][2] - prefix[0][2] - prefix[1][1] + prefix[0][1] = 1 - 0 - 0 + 0 = 1.
  • Sum == 1² = 1 → all ones → leaf with val=True.

Step 5: Process bottomLeft (row=1, col=0, size=1).

  • Sum = prefix[2][1] - prefix[1][1] - prefix[2][0] + prefix[1][0] = 1 - 0 - 0 + 0 = 1.
  • Sum == 1 → all ones → leaf with val=True.

Step 6: Process bottomRight (row=1, col=1, size=1).

  • Sum = prefix[2][2] - prefix[1][2] - prefix[2][1] + prefix[1][1] = 2 - 1 - 1 + 0 = 0.
  • Sum == 0 → all zeros → leaf with val=False.

Step 7: Attach all four leaf children to the root. Return the root.

  • Result: [[0,1],[1,0],[1,1],[1,1],[1,0]].
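The prefix matrix and the five region sums in this trace can be reproduced with a few lines of Python. This is a standalone sketch; `build_prefix` and `region_sum` follow the same prefix-sum logic as the solution code in this section.

```python
from typing import List


def build_prefix(grid: List[List[int]]) -> List[List[int]]:
    """(n+1) x (n+1) matrix; prefix[i][j] = sum of grid[0..i-1][0..j-1]."""
    n = len(grid)
    prefix = [[0] * (n + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            prefix[i][j] = (grid[i - 1][j - 1] + prefix[i - 1][j]
                            + prefix[i][j - 1] - prefix[i - 1][j - 1])
    return prefix


def region_sum(prefix: List[List[int]], row: int, col: int, size: int) -> int:
    """Sum of the size x size block with top-left cell (row, col), in O(1)."""
    return (prefix[row + size][col + size] - prefix[row][col + size]
            - prefix[row + size][col] + prefix[row][col])


prefix = build_prefix([[0, 1], [1, 0]])
print(prefix)  # [[0, 0, 0], [0, 0, 1], [0, 1, 2]]

# (row, col, size) for the root and the four quadrants, in the order traced above
for row, col, size in [(0, 0, 2), (0, 0, 1), (0, 1, 1), (1, 0, 1), (1, 1, 1)]:
    print((row, col, size), region_sum(prefix, row, col, size))
```

The printed sums are 2, 0, 1, 1, 0, matching Steps 2 through 6.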


Algorithm

  1. Precompute 2D prefix sum: Build a (n+1) × (n+1) prefix matrix where prefix[i][j] = sum of all values in grid[0..i-1][0..j-1].
  2. Define region sum query: For a region starting at (row, col) with side length size:
    • sum = prefix[row+size][col+size] - prefix[row][col+size] - prefix[row+size][col] + prefix[row][col]
  3. Define recursive builder build(row, col, size):
    • Compute the region sum in O(1).
    • If sum == 0: return leaf node with val=False (all zeros).
    • If sum == size²: return leaf node with val=True (all ones).
    • Otherwise: create internal node and recursively build four children with half the size.
  4. Return build(0, 0, n).

Code

C++

#include <vector>
using namespace std;

class Solution {
public:
    vector<vector<int>> prefix;
    
    int regionSum(int row, int col, int size) {
        return prefix[row + size][col + size]
             - prefix[row][col + size]
             - prefix[row + size][col]
             + prefix[row][col];
    }
    
    Node* build(int row, int col, int size) {
        int sum = regionSum(row, col, size);
        int total = size * size;
        
        if (sum == 0) {
            return new Node(false, true);
        }
        if (sum == total) {
            return new Node(true, true);
        }
        
        int half = size / 2;
        return new Node(
            true, false,
            build(row, col, half),
            build(row, col + half, half),
            build(row + half, col, half),
            build(row + half, col + half, half)
        );
    }
    
    Node* construct(vector<vector<int>>& grid) {
        int n = grid.size();
        prefix.assign(n + 1, vector<int>(n + 1, 0));
        
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                prefix[i][j] = grid[i-1][j-1]
                             + prefix[i-1][j]
                             + prefix[i][j-1]
                             - prefix[i-1][j-1];
            }
        }
        
        return build(0, 0, n);
    }
};

Python

from typing import List

class Solution:
    def construct(self, grid: List[List[int]]) -> 'Node':
        n = len(grid)
        
        # Build 2D prefix sum
        prefix = [[0] * (n + 1) for _ in range(n + 1)]
        for i in range(1, n + 1):
            for j in range(1, n + 1):
                prefix[i][j] = (grid[i-1][j-1]
                              + prefix[i-1][j]
                              + prefix[i][j-1]
                              - prefix[i-1][j-1])
        
        def region_sum(row, col, size):
            return (prefix[row + size][col + size]
                  - prefix[row][col + size]
                  - prefix[row + size][col]
                  + prefix[row][col])
        
        def build(row, col, size):
            ones = region_sum(row, col, size)  # number of 1s in the region
            
            if ones == 0:
                return Node(False, True, None, None, None, None)  # all 0s
            if ones == size * size:
                return Node(True, True, None, None, None, None)  # all 1s
            
            half = size // 2
            return Node(
                True, False,
                build(row, col, half),
                build(row, col + half, half),
                build(row + half, col, half),
                build(row + half, col + half, half)
            )
        
        return build(0, 0, n)

Java

class Solution {
    private int[][] prefix;
    
    private int regionSum(int row, int col, int size) {
        return prefix[row + size][col + size]
             - prefix[row][col + size]
             - prefix[row + size][col]
             + prefix[row][col];
    }
    
    private Node build(int row, int col, int size) {
        int sum = regionSum(row, col, size);
        int total = size * size;
        
        if (sum == 0) {
            return new Node(false, true);
        }
        if (sum == total) {
            return new Node(true, true);
        }
        
        int half = size / 2;
        return new Node(
            true, false,
            build(row, col, half),
            build(row, col + half, half),
            build(row + half, col, half),
            build(row + half, col + half, half)
        );
    }
    
    public Node construct(int[][] grid) {
        int n = grid.length;
        prefix = new int[n + 1][n + 1];
        
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                prefix[i][j] = grid[i-1][j-1]
                             + prefix[i-1][j]
                             + prefix[i][j-1]
                             - prefix[i-1][j-1];
            }
        }
        
        return build(0, 0, n);
    }
}

Complexity Analysis

Time Complexity: O(n²)

Building the prefix sum matrix takes O(n²). Each node in the Quad-Tree requires O(1) work (one prefix sum query plus creating a node). The maximum number of nodes in the tree is O(n²) (in the worst case, every cell becomes its own leaf). So the recursive construction is O(n²). Total: O(n²) + O(n²) = O(n²).
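One way to see the O(n²) bound in practice is to count the O(1) region-sum queries the recursion makes, since each call issues exactly one query and creates exactly one node. A Python sketch; the name `count_queries` is hypothetical, and the function only counts, building no Node objects:

```python
from typing import List


def count_queries(grid: List[List[int]]) -> int:
    """Run the prefix-sum quad-tree recursion and count region-sum queries."""
    n = len(grid)
    prefix = [[0] * (n + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            prefix[i][j] = (grid[i - 1][j - 1] + prefix[i - 1][j]
                            + prefix[i][j - 1] - prefix[i - 1][j - 1])
    queries = 0

    def build(row: int, col: int, size: int) -> None:
        nonlocal queries
        queries += 1                      # one O(1) query per node
        total = (prefix[row + size][col + size] - prefix[row][col + size]
                 - prefix[row + size][col] + prefix[row][col])
        if total == 0 or total == size * size:
            return                        # uniform region becomes a leaf
        half = size // 2
        build(row, col, half)
        build(row, col + half, half)
        build(row + half, col, half)
        build(row + half, col + half, half)

    build(0, 0, n)
    return queries


n = 8
checkerboard = [[(i + j) % 2 for j in range(n)] for i in range(n)]
print(count_queries(checkerboard))                 # 85
print(count_queries([[0] * n for _ in range(n)]))  # 1
```

On an 8×8 checkerboard every cell ends up as its own leaf and the count is 85 = (4·8² − 1)/3, while a uniform grid needs a single query.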

This is a significant improvement over the brute force's O(n² log n), eliminating the redundant scanning at each recursion level.

Space Complexity: O(n²)

The prefix sum matrix uses O(n²) space. The Quad-Tree itself can have up to O(n²) nodes in the worst case. The recursion stack depth is O(log n). Total: O(n²).