
Detect Squares

Difficulty: Medium

Description

Design a data structure that can store points on a 2D plane and, given a query point, count how many axis-aligned squares can be formed using points already stored in the structure together with the query point.

An axis-aligned square is one whose sides are all parallel (or perpendicular) to the x-axis and y-axis — no rotated or diagonal squares. The square must have positive area (all four corners must be distinct points).

Implement the DetectSquares class:

  • DetectSquares() — Initializes an empty data structure.
  • void add(int[] point) — Adds a new point, point = [x, y], to the data structure. Duplicate points are allowed: if the same coordinate is added multiple times, each occurrence is treated as a separate point.
  • int count(int[] point) — Given a query point point = [x, y], return the number of ways to choose three points from the stored data such that these three points and the query point form an axis-aligned square with positive area.

Examples

Example 1

Input:

["DetectSquares", "add", "add", "add", "count", "count", "add", "count"]
[[], [[3, 10]], [[11, 2]], [[3, 2]], [[11, 10]], [[14, 8]], [[11, 2]], [[11, 10]]]

Output: [null, null, null, null, 1, 0, null, 2]

Explanation:

  • DetectSquares() — Initialize.
  • add([3, 10]) — Store point (3, 10).
  • add([11, 2]) — Store point (11, 2).
  • add([3, 2]) — Store point (3, 2).
  • count([11, 10]) — Query with (11, 10). The stored points (3, 10), (11, 2), (3, 2) together with the query form a valid axis-aligned square with side length 8. Result: 1.
  • count([14, 8]) — Query with (14, 8). No three stored points complete a valid square. Result: 0.
  • add([11, 2]) — Store another copy of (11, 2). Now (11, 2) appears twice.
  • count([11, 10]) — Query again with (11, 10). The same square exists, but now (11, 2) has 2 copies. We can pick either copy, giving 2 valid combinations.

Example 2

Input:

["DetectSquares", "add", "add", "add", "add", "count"]
[[], [[0, 0]], [[0, 2]], [[2, 0]], [[2, 2]], [[0, 0]]]

Output: [null, null, null, null, null, 1]

Explanation:

  • Add points (0,0), (0,2), (2,0), (2,2).
  • count([0, 0]) — Query with (0, 0). The stored points (0, 2), (2, 0), (2, 2) with the query point form a square:
    (0,2) ---- (2,2)
      |          |
      |          |
    (0,0) ---- (2,0)
    
    Result: 1.

Constraints

  • point.length == 2
  • 0 ≤ x, y ≤ 1000
  • At most 5000 calls will be made to add and count combined

Editorial

Brute Force - Check All Triplets

Intuition

The most naive approach: store every added point in a list, and when count is called, check every possible triplet of stored points to see if, together with the query point, they form an axis-aligned square.

To verify that four points form an axis-aligned square, we need:

  1. Exactly two distinct x-coordinates and exactly two distinct y-coordinates among the four points.
  2. The absolute difference in x-coordinates equals the absolute difference in y-coordinates (so the side length is the same horizontally and vertically).
  3. Each of the four combinations (x₁, y₁), (x₁, y₂), (x₂, y₁), (x₂, y₂) is present.
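The three conditions translate almost directly into a small predicate. This is a minimal Python sketch; the name `is_axis_aligned_square` is hypothetical, and it assumes the four points arrive as (x, y) tuples. Comparing sets in the last step also rejects inputs where two of the four points coincide, because the set then has fewer than four elements.

```python
def is_axis_aligned_square(points):
    """Hypothetical helper: True if four (x, y) tuples are the corners of an
    axis-aligned square with positive area."""
    if len(points) != 4:
        return False
    xs = {x for x, _ in points}
    ys = {y for _, y in points}
    # Condition 1: exactly two distinct x-values and two distinct y-values.
    if len(xs) != 2 or len(ys) != 2:
        return False
    x_lo, x_hi = sorted(xs)
    y_lo, y_hi = sorted(ys)
    # Condition 2: horizontal and vertical side lengths match.
    if x_hi - x_lo != y_hi - y_lo:
        return False
    # Condition 3: the points are exactly the four corner combinations
    # (a repeated point shrinks the set and fails this comparison).
    corners = {(x_lo, y_lo), (x_lo, y_hi), (x_hi, y_lo), (x_hi, y_hi)}
    return set(points) == corners


print(is_axis_aligned_square([(11, 10), (3, 10), (11, 2), (3, 2)]))  # True
print(is_axis_aligned_square([(14, 8), (3, 10), (11, 2), (3, 2)]))   # False
```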

Think of it like having a bag of push-pins on a board. For each query, you pull out every combination of three pins and check if they, along with the query pin, form the four corners of an upright square. This is exhaustive and correct, but painfully slow.

Step-by-Step Explanation

Stored points: (3, 10), (11, 2), (3, 2). Query: count((11, 10)).

Step 1: We have 3 stored points. Generate all C(3, 3) = 1 triplet: {(3, 10), (11, 2), (3, 2)}.

Step 2: Check if the query (11, 10) plus these three form a square.

  • Collect all four points: (11, 10), (3, 10), (11, 2), (3, 2).
  • Distinct x-values: {3, 11}. Distinct y-values: {2, 10}. Both have exactly 2 ✓.
  • |11 - 3| = 8, |10 - 2| = 8. Equal ✓.
  • Are all four corners (3,2), (3,10), (11,2), (11,10) present? YES ✓.

Step 3: This is a valid square. Count += 1.

Step 4: No more triplets. Return 1.

With n stored points, we check O(n³) triplets per query — clearly impractical for large inputs.

Algorithm

Data Structure: A list storing all added points.

add(point):

  1. Append the point to the list.

count(point):

  1. Let the query point be (x₁, y₁).
  2. For every combination of 3 points from the stored list:
    a. Collect the 4 points (query + 3 chosen).
    b. Check if they form an axis-aligned square.
    c. If yes, increment count.
  3. Return count.
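A compact Python sketch of the same brute-force count, using `itertools.combinations` in place of three nested index loops. The function name `brute_force_count` and the plain list of (x, y) tuples standing in for the stored points are assumptions for illustration, not the editorial's reference code.

```python
from itertools import combinations


def brute_force_count(stored, query):
    """Count triplets of stored points that form an axis-aligned square of
    positive area together with `query`. Examines C(n, 3) triplets per call."""
    total = 0
    # combinations() works on positions, so duplicate coordinates are
    # treated as separate points, as the problem requires.
    for triplet in combinations(stored, 3):
        # Sorting by (x, y) puts a square's corners in a fixed order:
        # bottom-left, top-left, bottom-right, top-right.
        a, b, c, d = sorted(tuple(p) for p in (query, *triplet))
        side = c[0] - a[0]
        if (a[0] == b[0] and c[0] == d[0]          # two vertical sides
                and a[1] == c[1] and b[1] == d[1]  # two horizontal sides
                and side > 0                       # positive area
                and b[1] - a[1] == side):          # width equals height
            total += 1
    return total


stored = [(3, 10), (11, 2), (3, 2)]
print(brute_force_count(stored, (11, 10)))  # 1
print(brute_force_count(stored, (14, 8)))   # 0
stored.append((11, 2))                      # second copy of (11, 2)
print(brute_force_count(stored, (11, 10)))  # 2
```

Running it on Example 1 reproduces the expected outputs 1, 0 and 2.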

Code

#include <algorithm>
#include <utility>
#include <vector>
using namespace std;

class DetectSquares {
    vector<pair<int,int>> points;
public:
    DetectSquares() {}

    void add(vector<int> point) {
        points.push_back({point[0], point[1]});
    }

    int count(vector<int> point) {
        int x1 = point[0], y1 = point[1];
        int n = points.size();
        int ans = 0;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                for (int k = j + 1; k < n; k++) {
                    vector<pair<int,int>> four = {{x1, y1}, points[i], points[j], points[k]};
                    sort(four.begin(), four.end());
                    if (four[0].first == four[1].first &&
                        four[2].first == four[3].first &&
                        four[0].second == four[2].second &&
                        four[1].second == four[3].second &&
                        four[2].first - four[0].first == four[1].second - four[0].second &&
                        four[2].first != four[0].first) {
                        ans++;
                    }
                }
            }
        }
        return ans;
    }
};

class DetectSquares:
    def __init__(self):
        self.points = []

    def add(self, point: list[int]) -> None:
        self.points.append(tuple(point))

    def count(self, point: list[int]) -> int:
        x1, y1 = point
        ans = 0
        n = len(self.points)
        for i in range(n):
            for j in range(i + 1, n):
                for k in range(j + 1, n):
                    four = sorted([(x1, y1), self.points[i], self.points[j], self.points[k]])
                    if (four[0][0] == four[1][0] and
                        four[2][0] == four[3][0] and
                        four[0][1] == four[2][1] and
                        four[1][1] == four[3][1] and
                        four[2][0] - four[0][0] == four[1][1] - four[0][1] and
                        four[2][0] != four[0][0]):
                        ans += 1
        return ans

import java.util.*;

class DetectSquares {
    private List<int[]> points;

    public DetectSquares() {
        points = new ArrayList<>();
    }

    public void add(int[] point) {
        points.add(point);
    }

    public int count(int[] point) {
        int x1 = point[0], y1 = point[1];
        int n = points.size();
        int ans = 0;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                for (int k = j + 1; k < n; k++) {
                    int[][] four = {{x1, y1}, points.get(i), points.get(j), points.get(k)};
                    Arrays.sort(four, (a, b) -> a[0] != b[0] ? a[0] - b[0] : a[1] - b[1]);
                    if (four[0][0] == four[1][0] &&
                        four[2][0] == four[3][0] &&
                        four[0][1] == four[2][1] &&
                        four[1][1] == four[3][1] &&
                        four[2][0] - four[0][0] == four[1][1] - four[0][1] &&
                        four[2][0] != four[0][0]) {
                        ans++;
                    }
                }
            }
        }
        return ans;
    }
}

Complexity Analysis

Time Complexity:

  • add: O(1) — simple list append.
  • count: O(n³), where n is the number of stored points. We examine all C(n, 3) = n(n − 1)(n − 2)/6 triplets, each needing a constant-size sort and comparison. For 5000 points that is roughly 2 × 10^10 triplets, far too slow.
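The 2 × 10^10 figure is C(5000, 3). A quick check with Python's `math.comb` (available since Python 3.8) shows how fast the triplet count grows with n:

```python
from math import comb

# Triplets examined by a single brute-force count() call with n stored points.
for n in (10, 100, 1000, 5000):
    print(n, comb(n, 3))
# 10 120
# 100 161700
# 1000 166167000
# 5000 20820835000
```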

Space Complexity: O(n) for storing all added points.

Why This Approach Is Not Efficient

The brute force blindly checks all triplets without using the geometric structure of the problem. With up to 5000 add + count calls, the point list can grow large, and O(n³) per count is catastrophically slow.

The key insight we're missing: an axis-aligned square is fully determined by any two diagonally opposite corners. If we know one corner (the query point) and pick a second point as the diagonal opposite, the other two corners are mathematically fixed. We don't need to search for triplets — we just need to check if those two fixed corners exist.

This reduces the problem from "search all triplets" (O(n³)) to "iterate over candidate diagonals and verify two lookups" (O(n)).

Optimal Approach - Diagonal Enumeration with Hash Map

Intuition

Here's the geometric insight that unlocks the efficient solution:

Given a query point (x₁, y₁), imagine it as one corner of a potential square. If we pick any other stored point as the diagonally opposite corner (x₃, y₃), the other two corners are completely determined:

  • Corner A: (x₁, y₃)
  • Corner B: (x₃, y₁)

But not just any point qualifies as a valid diagonal partner. For an axis-aligned square with positive area:

  1. The diagonal must have equal horizontal and vertical span: |x₃ - x₁| = |y₃ - y₁|
  2. Both spans must be non-zero: x₃ ≠ x₁. If x₃ = x₁, equal spans force y₃ = y₁ as well, and all four corners collapse onto a single point.
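The two validity checks and the corner formulas fit in a few lines. This is a minimal Python sketch; `other_corners` is a hypothetical name, and points are assumed to be (x, y) pairs.

```python
def other_corners(query, diag):
    """Hypothetical helper: if `diag` can be the corner diagonally opposite
    `query` in an axis-aligned square of positive area, return the two
    remaining corners (A, B); otherwise return None."""
    x1, y1 = query
    x3, y3 = diag
    dx, dy = abs(x3 - x1), abs(y3 - y1)
    if dx != dy or dx == 0:  # unequal spans, or zero area
        return None
    return (x1, y3), (x3, y1)  # corner A, corner B


print(other_corners((11, 10), (3, 2)))   # ((11, 2), (3, 10))
print(other_corners((11, 10), (3, 10)))  # None
```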

So the algorithm becomes:

  1. For each stored point (x₃, y₃), check if it could be the diagonal opposite of (x₁, y₁).
  2. If yes, the other two corners must be at (x₁, y₃) and (x₃, y₁). Look up how many copies of each exist.
  3. Multiply the counts: if (x₃, y₃) appears a times, (x₁, y₃) appears b times, and (x₃, y₁) appears c times, then there are a × b × c ways to form this particular square.

To make lookups O(1), we store point counts in a hash map: count_map[(x, y)] = number of times point (x, y) has been added.

We can iterate either over every individual stored point (duplicates included) or over unique stored points, multiplying in each unique point's count. The unique-point variant avoids redundant work when many duplicates exist; both give the same total.

Step-by-Step Explanation

Stored points (with counts): (3, 10)×1, (11, 2)×2, (3, 2)×1. Query: count((11, 10)).

Step 1: Query point: (x₁, y₁) = (11, 10). Initialize total = 0.

Step 2: Consider diagonal candidate (3, 10).

  • Check |11 - 3| = 8, |10 - 10| = 0. Not equal (8 ≠ 0). NOT a valid diagonal. Skip.

Step 3: Consider diagonal candidate (11, 2).

  • Check |11 - 11| = 0, |10 - 2| = 8. Not equal (0 ≠ 8). NOT a valid diagonal. Skip.

Step 4: Consider diagonal candidate (3, 2).

  • Check |11 - 3| = 8, |10 - 2| = 8. Equal ✓ Both non-zero ✓. Valid diagonal!
  • Other two corners: (x₁, y₃) = (11, 2) and (x₃, y₁) = (3, 10).
  • Look up counts: count(3, 2) = 1, count(11, 2) = 2, count(3, 10) = 1.
  • Ways for this square: 1 × 2 × 1 = 2.
  • total += 2.

Step 5: No more unique points. Return total = 2.

The answer is 2: one square using the first copy of (11, 2), another using the second copy.

(Animation: Diagonal Enumeration, finding squares via opposite corners. It steps through the stored points looking for valid diagonal partners, then performs O(1) lookups for the remaining two corners.)

Algorithm

Data Structure: A hash map cnt where cnt[(x, y)] = number of times point (x, y) has been added. Optionally, also maintain a list of all added points (for iteration).

add(point):

  1. Increment cnt[(point[0], point[1])] by 1.
  2. Optionally store the point in a list (for the alternative iteration approach).

count(point):

  1. Let the query point be (x₁, y₁). Initialize total = 0.
  2. For each unique stored point (x₃, y₃), i.e. each key of cnt:
    • If x₃ == x₁ and y₃ == y₁, skip (same position → degenerate square with zero area).
    • If |x₃ − x₁| ≠ |y₃ − y₁|, skip (not a valid diagonal).
    • Otherwise, compute the two remaining corners: (x₁, y₃) and (x₃, y₁).
    • total += cnt[(x₃, y₃)] × cnt[(x₁, y₃)] × cnt[(x₃, y₁)].
  3. Return total.

Note: If we instead iterate over the list of all added points (duplicates included), each copy of (x₃, y₃) is visited separately, so its count is already accounted for and the formula drops that factor:
total += cnt[(x₁, y₃)] × cnt[(x₃, y₁)]
(added once per visit to a copy of (x₃, y₃)).

Alternatively, iterate over unique points using the hash map keys and multiply all three counts.
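To see that the two formulas agree, here is a small Python sketch of both variants on Example 1's final state. The function names `count_via_list` and `count_via_keys` are hypothetical; a plain list and a `collections.Counter` stand in for the data structure's point list and hash map.

```python
from collections import Counter


def count_via_list(pts, cnt, query):
    """Visit every added point (duplicates included): two factors per visit."""
    x1, y1 = query
    total = 0
    for x3, y3 in pts:
        if abs(x3 - x1) != abs(y3 - y1) or x3 == x1:
            continue  # not a valid diagonal partner
        total += cnt[(x1, y3)] * cnt[(x3, y1)]
    return total


def count_via_keys(cnt, query):
    """Visit unique points only: multiply in the diagonal point's own count."""
    x1, y1 = query
    total = 0
    for (x3, y3), c in cnt.items():
        if abs(x3 - x1) != abs(y3 - y1) or x3 == x1:
            continue  # not a valid diagonal partner
        total += c * cnt[(x1, y3)] * cnt[(x3, y1)]
    return total


pts = [(3, 10), (11, 2), (3, 2), (11, 2)]
cnt = Counter(pts)  # missing keys read as 0 and are NOT inserted
print(count_via_list(pts, cnt, (11, 10)))  # 2
print(count_via_keys(cnt, (11, 10)))       # 2
```

`Counter` is chosen here over `defaultdict(int)` on purpose: looking up a missing key in a `defaultdict` inserts it, and doing that while iterating over the same dictionary's items raises `RuntimeError: dictionary changed size during iteration`.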

Code

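#include <cstdlib>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace std;

class DetectSquares {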
    unordered_map<long long, int> cnt;
    vector<pair<int,int>> pts;

    long long encode(int x, int y) {
        return (long long)x * 1001 + y;
    }

public:
    DetectSquares() {}

    void add(vector<int> point) {
        int x = point[0], y = point[1];
        cnt[encode(x, y)]++;
        pts.push_back({x, y});
    }

    int count(vector<int> point) {
        int x1 = point[0], y1 = point[1];
        int ans = 0;

        for (auto& [x3, y3] : pts) {
            // Check if (x3, y3) can be diagonal opposite of (x1, y1)
            if (abs(x3 - x1) != abs(y3 - y1) || x3 == x1) continue;

            // Other two corners: (x1, y3) and (x3, y1)
            long long keyA = encode(x1, y3);
            long long keyB = encode(x3, y1);

            if (cnt.count(keyA) && cnt.count(keyB)) {
                ans += cnt[keyA] * cnt[keyB];
            }
        }
        return ans;
    }
};

from collections import defaultdict

class DetectSquares:
    def __init__(self):
        self.cnt = defaultdict(int)
        self.pts = []

    def add(self, point: list[int]) -> None:
        self.cnt[(point[0], point[1])] += 1
        self.pts.append(point)

    def count(self, point: list[int]) -> int:
        x1, y1 = point
        ans = 0

        for x3, y3 in self.pts:
            # Check valid diagonal: equal spans, non-zero
            if abs(x3 - x1) != abs(y3 - y1) or x3 == x1:
                continue

            # Other two corners are (x1, y3) and (x3, y1). Use .get() so that
            # looking up a missing corner does not insert a zero-count key.
            ans += self.cnt.get((x1, y3), 0) * self.cnt.get((x3, y1), 0)

        return ans

import java.util.*;

class DetectSquares {
    private Map<Long, Integer> cnt;
    private List<int[]> pts;

    public DetectSquares() {
        cnt = new HashMap<>();
        pts = new ArrayList<>();
    }

    private long encode(int x, int y) {
        return (long) x * 1001 + y;
    }

    public void add(int[] point) {
        long key = encode(point[0], point[1]);
        cnt.merge(key, 1, Integer::sum);
        pts.add(point);
    }

    public int count(int[] point) {
        int x1 = point[0], y1 = point[1];
        int ans = 0;

        for (int[] p : pts) {
            int x3 = p[0], y3 = p[1];
            // Valid diagonal: equal absolute spans, non-zero
            if (Math.abs(x3 - x1) != Math.abs(y3 - y1) || x3 == x1) continue;

            long keyA = encode(x1, y3);
            long keyB = encode(x3, y1);

            ans += cnt.getOrDefault(keyA, 0) * cnt.getOrDefault(keyB, 0);
        }
        return ans;
    }
}
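The C++ and Java versions pack each point into a single integer key, encode(x, y) = x * 1001 + y. This is collision-free because the constraints bound y to 0..1000, so y never reaches the next multiple of 1001. A quick Python check of that claim over the whole constrained grid; the `decode` helper is hypothetical and only demonstrates that the mapping is invertible.

```python
def encode(x, y):
    # Same packing as the C++/Java encode(): y lies in 0..1000, so
    # multiplying x by 1001 keeps different (x, y) pairs apart.
    return x * 1001 + y


def decode(key):
    # Hypothetical inverse: quotient is x, remainder is y.
    return divmod(key, 1001)


keys = {encode(x, y) for x in range(1001) for y in range(1001)}
print(len(keys))              # 1002001 distinct keys for 1001 * 1001 points
print(decode(encode(11, 2)))  # (11, 2)
```

If coordinates could exceed 1000, the multiplier would have to grow with the maximum y-value for this scheme to stay collision-free.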

Complexity Analysis

Time Complexity:

  • add: O(1) — increment a hash map counter and append to a list.
  • count: O(n) where n is the total number of add calls. We iterate through all stored points and for each, perform O(1) hash map lookups. With at most 5000 total calls, each count operation processes at most 5000 points — fast enough.

Space Complexity: O(n) for storing all added points in both the list and the hash map.

Comparison:

Operation | Brute Force | Diagonal Enumeration
----------|-------------|---------------------
add       | O(1)        | O(1)
count     | O(n³)       | O(n)
Space     | O(n)        | O(n)

The improvement is dramatic: from O(n³) to O(n). The key insight was recognizing that an axis-aligned square is fully determined by two diagonal corners, reducing the search from combinations of 3 points to single-point iteration with constant-time verification.