Partition a Linked List around a given value

Difficulty: Medium

Description

Given a singly linked list and a value x, partition the linked list such that all nodes with values less than x come first, then all nodes with values equal to x, and finally all nodes with values greater than x.

The relative order of nodes within each of the three groups (less than, equal to, greater than) must be preserved. In other words, if node A appeared before node B in the original list and both are less than x, then node A must still appear before node B in the result.

The partition must work in-place — you rearrange the existing nodes by changing their next pointers, not by creating new nodes.
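For experimenting outside an online judge, here is a minimal Python sketch of the node type the problem assumes: fields `data` and `next`, matching the commented C++ struct in the editorial's code. The helpers `from_list` and `to_list` are hypothetical conveniences for converting between Python lists and linked lists. They are not part of any judge's API.

```python
class Node:
    """Singly linked list node with an integer payload."""

    def __init__(self, data):
        self.data = data
        self.next = None


def from_list(values):
    """Build a linked list from a Python list; returns the head (or None)."""
    dummy = Node(0)  # sentinel so the loop never special-cases the head
    tail = dummy
    for v in values:
        tail.next = Node(v)
        tail = tail.next
    return dummy.next


def to_list(head):
    """Collect node values iteratively (no recursion, safe for 10^5 nodes)."""
    out = []
    while head is not None:
        out.append(head.data)
        head = head.next
    return out


print(to_list(from_list([1, 4, 3, 2, 5, 2, 3])))  # [1, 4, 3, 2, 5, 2, 3]
```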

Examples

Example 1

Input: Linked list: 1 → 4 → 3 → 2 → 5 → 2 → 3, x = 3

Output: 1 → 2 → 2 → 3 → 3 → 4 → 5

Explanation: We partition around the value 3:

  • Nodes less than 3: 1, 2, 2 (original order preserved)
  • Nodes equal to 3: 3, 3 (original order preserved)
  • Nodes greater than 3: 4, 5 (original order preserved)

Concatenating these three groups gives: 1 → 2 → 2 → 3 → 3 → 4 → 5.

Example 2

Input: Linked list: 1 → 4 → 2 → 10, x = 3

Output: 1 → 2 → 4 → 10

Explanation: We partition around the value 3:

  • Nodes less than 3: 1, 2
  • Nodes equal to 3: (none)
  • Nodes greater than 3: 4, 10

Since no node equals 3, we just concatenate the less-than and greater-than groups: 1 → 2 → 4 → 10.

Example 3

Input: Linked list: 5 → 1, x = 5

Output: 1 → 5

Explanation: Node 1 is less than 5, and node 5 equals 5. So 1 comes first, then 5: 1 → 5.
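All three examples follow one rule: the output is a stable three-way split of the input values. As a reference for checking answers, here is a short sketch that applies the rule to a plain Python list. The name `expected_partition` is hypothetical. It captures only the expected order of values, not the in-place requirement.

```python
def expected_partition(values, x):
    """Stable three-way split: < x, then == x, then > x, each in input order."""
    less = [v for v in values if v < x]
    equal = [v for v in values if v == x]
    greater = [v for v in values if v > x]
    return less + equal + greater


# The three examples from the problem statement.
print(expected_partition([1, 4, 3, 2, 5, 2, 3], 3))  # [1, 2, 2, 3, 3, 4, 5]
print(expected_partition([1, 4, 2, 10], 3))          # [1, 2, 4, 10]
print(expected_partition([5, 1], 5))                 # [1, 5]
```

The examples happen to come out sorted, but the operation is not a sort: `expected_partition([5, 1, 4, 2], 3)` returns `[1, 2, 5, 4]`, because 5 precedes 4 in the input.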

Constraints

  • 1 ≤ size of linked list ≤ 10^5
  • 1 ≤ data of node ≤ 10^5
  • 1 ≤ x ≤ 10^5

Editorial

Brute Force

Intuition

The most straightforward approach is to extract all the node values into an array, partition the array into three groups (less than x, equal to x, greater than x) while preserving relative order, and then write the values back into the linked list nodes.

Think of it like sorting mail into three bins: one for letters addressed to surnames starting before 'M', one for surnames starting with 'M', and one for surnames starting after 'M'. You go through the pile once, drop each letter into the right bin, then stack the bins together in that order.

This approach is simple to implement but uses extra space proportional to the number of nodes.

Step-by-Step Explanation

Let's trace through with list = [1, 4, 3, 2, 5, 2, 3], x = 3:

Step 1: Extract all values into an array.

  • values = [1, 4, 3, 2, 5, 2, 3]

Step 2: Create three separate lists for less, equal, and greater.

  • less = [], equal = [], greater = []

Step 3: Scan through the array, placing each value in the correct group.

  • 1 < 3 → less = [1]
  • 4 > 3 → greater = [4]
  • 3 == 3 → equal = [3]
  • 2 < 3 → less = [1, 2]
  • 5 > 3 → greater = [4, 5]
  • 2 < 3 → less = [1, 2, 2]
  • 3 == 3 → equal = [3, 3]

Step 4: Concatenate: less + equal + greater = [1, 2, 2, 3, 3, 4, 5]

Step 5: Write values back into the linked list nodes.

  • Result: 1 → 2 → 2 → 3 → 3 → 4 → 5
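Steps 2 to 4 depend on each group keeping its input order. Another way to see that requirement: Python's `sorted` is guaranteed stable, so sorting the extracted values by a three-valued key (-1, 0 or 1) produces the same concatenation. This is a swapped-in alternative shown for illustration only. It costs O(n log n) instead of the single O(n) bucketing pass, and the helper name `classify_key` is hypothetical.

```python
def classify_key(v, x):
    """-1 for values below x, 0 for values equal to x, 1 for values above x."""
    return (v > x) - (v < x)


values = [1, 4, 3, 2, 5, 2, 3]
x = 3

# sorted() is stable: values with equal keys keep their original relative order.
result = sorted(values, key=lambda v: classify_key(v, x))
print(result)  # [1, 2, 2, 3, 3, 4, 5]
```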

Algorithm

  1. Traverse the linked list and extract all values into an array
  2. Create three empty lists: less, equal, greater
  3. For each value in the array:
    • If value < x, append to less
    • If value == x, append to equal
    • If value > x, append to greater
  4. Concatenate the three lists: result = less + equal + greater
  5. Traverse the linked list again and overwrite each node's value with the corresponding value from result
  6. Return the head

Code

C++

#include <vector>
using namespace std;

/*
struct Node {
    int data;
    Node* next;
    Node(int x) : data(x), next(nullptr) {}
};
*/
class Solution {
public:
    Node* partition(Node* head, int x) {
        vector<int> less, equal, greater;
        
        Node* curr = head;
        while (curr != nullptr) {
            if (curr->data < x) {
                less.push_back(curr->data);
            } else if (curr->data == x) {
                equal.push_back(curr->data);
            } else {
                greater.push_back(curr->data);
            }
            curr = curr->next;
        }
        
        // Concatenate
        vector<int> result;
        for (int v : less) result.push_back(v);
        for (int v : equal) result.push_back(v);
        for (int v : greater) result.push_back(v);
        
        // Write back
        curr = head;
        int i = 0;
        while (curr != nullptr) {
            curr->data = result[i++];
            curr = curr->next;
        }
        
        return head;
    }
};

Python

class Solution:
    def partition(self, head, x):
        less, equal, greater = [], [], []
        
        curr = head
        while curr:
            if curr.data < x:
                less.append(curr.data)
            elif curr.data == x:
                equal.append(curr.data)
            else:
                greater.append(curr.data)
            curr = curr.next
        
        # Concatenate
        result = less + equal + greater
        
        # Write back
        curr = head
        for val in result:
            curr.data = val
            curr = curr.next
        
        return head

Java

import java.util.ArrayList;
import java.util.List;

class Solution {
    Node partition(Node head, int x) {
        List<Integer> less = new ArrayList<>();
        List<Integer> equal = new ArrayList<>();
        List<Integer> greater = new ArrayList<>();
        
        Node curr = head;
        while (curr != null) {
            if (curr.data < x) {
                less.add(curr.data);
            } else if (curr.data == x) {
                equal.add(curr.data);
            } else {
                greater.add(curr.data);
            }
            curr = curr.next;
        }
        
        // Concatenate
        List<Integer> result = new ArrayList<>();
        result.addAll(less);
        result.addAll(equal);
        result.addAll(greater);
        
        // Write back
        curr = head;
        int i = 0;
        while (curr != null) {
            curr.data = result.get(i++);
            curr = curr.next;
        }
        
        return head;
    }
}

Complexity Analysis

Time Complexity: O(n)

We traverse the linked list once to extract values: O(n). Classifying each value into one of three groups is O(1) per element. Concatenating the three groups is O(n). Writing back is O(n). Total: O(n).

Space Complexity: O(n)

We use three auxiliary arrays/lists that together hold all n values, plus the concatenated result array. Total extra space is O(n).

Why This Approach Is Not Efficient

While the brute force has optimal O(n) time complexity, its O(n) space usage is unnecessary. The problem asks us to partition in-place by rearranging the node pointers, not by copying values to auxiliary arrays.

Overwriting node values is generally frowned upon in linked list problems because:

  1. It doesn't demonstrate understanding of pointer manipulation, which is the core skill being tested.
  2. In real-world scenarios, nodes may carry complex data beyond a single integer — copying data would be expensive.
  3. The problem explicitly requires rearranging the existing nodes by changing their next pointers, which overwriting values does not do.
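A concrete way to see why overwriting values is riskier than relinking: any outside reference to a node silently changes meaning. The sketch below uses a minimal hypothetical `Node` class and hypothetical helpers `build` and `write_back` (the latter standing in for the brute force's final step). It holds a reference to the node that originally stored 4, then writes the partitioned values back.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.next = None


def build(values):
    """Hypothetical helper: linked list from a Python list."""
    dummy = Node(0)
    tail = dummy
    for v in values:
        tail.next = Node(v)
        tail = tail.next
    return dummy.next


def write_back(head, x):
    """Brute-force style: collect values, regroup them, overwrite node data."""
    vals = []
    curr = head
    while curr:
        vals.append(curr.data)
        curr = curr.next
    result = ([v for v in vals if v < x] + [v for v in vals if v == x]
              + [v for v in vals if v > x])
    curr = head
    for v in result:
        curr.data = v
        curr = curr.next
    return head


head = build([1, 4, 3, 2, 5, 2, 3])
node_four = head.next      # the node object that currently holds 4
write_back(head, 3)
print(node_four.data)      # 2: same node object, different payload
```

With pointer relinking, the same reference would still read 4; only its position in the list would change.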

We can achieve the same result with O(1) extra space by maintaining three separate linked list chains (using just head/tail pointers for each) and then connecting them at the end. This eliminates the need for any auxiliary array.

Optimal Approach - Three-Pointer Partition

Intuition

Instead of copying values into arrays, we can build three separate linked list chains directly using the existing nodes. We maintain a head and tail pointer for each of three chains:

  • Less chain: collects all nodes with value < x
  • Equal chain: collects all nodes with value == x
  • Greater chain: collects all nodes with value > x

As we traverse the original list, we detach each node and append it to the appropriate chain. Since we process nodes in order and append (not prepend), the relative order within each chain is automatically preserved.
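The claim that appending preserves order, while prepending would not, can be checked directly. In this minimal sketch (hypothetical `Node` class and `to_list` helper), the values 7, 3, 9 are arbitrary and stand for nodes arriving, in that order, at the same chain. One chain attaches each node at the tail; the other attaches each node at the front.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.next = None


def to_list(head):
    out = []
    while head:
        out.append(head.data)
        head = head.next
    return out


incoming = [7, 3, 9]  # arrival order of nodes that all belong to one chain

# Append at the tail (what the three-pointer method does): order preserved.
dummy = Node(0)
tail = dummy
for v in incoming:
    node = Node(v)
    tail.next = node
    tail = node
appended = to_list(dummy.next)

# Prepend at the head (a tempting shortcut): order reversed.
front = None
for v in incoming:
    node = Node(v)
    node.next = front
    front = node
prepended = to_list(front)

print(appended)   # [7, 3, 9]
print(prepended)  # [9, 3, 7]
```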

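After processing all nodes, we connect the three chains: first the tail of 'equal' is pointed at the first node of 'greater', then the tail of 'less' is pointed at the first node of 'equal'. Doing the links in this order handles an empty equal chain for free: its tail is still the dummy, so the dummy ends up pointing at 'greater', and 'less' links straight past the empty chain. The overall head is the head of the first non-empty chain.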
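The connection step is where an empty middle chain can bite. Below is a sketch of just that step, assuming a minimal hypothetical `Node` class, dummy sentinels for each chain, and a hypothetical `connect` helper. It uses the chains from Example 2 (less = 1 → 2, equal empty, greater = 4 → 10).

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.next = None


def to_list(head):
    out = []
    while head:
        out.append(head.data)
        head = head.next
    return out


def connect(less_head, less_tail, equal_head, equal_tail, greater_head):
    """Join three dummy-headed chains as less -> equal -> greater.

    Link equal to greater FIRST: if the equal chain is empty, equal_tail is
    its dummy, so equal_head.next becomes the first greater node, and the
    second assignment then skips straight over the empty chain.
    """
    equal_tail.next = greater_head.next
    less_tail.next = equal_head.next
    return less_head.next  # overall head, even when some chains are empty


# Example 2: 1 -> 4 -> 2 -> 10, x = 3, already split into chains.
less_head, equal_head, greater_head = Node(0), Node(0), Node(0)
one, two, four, ten = Node(1), Node(2), Node(4), Node(10)
less_head.next = one
one.next = two
greater_head.next = four
four.next = ten

result = connect(less_head, two, equal_head, equal_head, greater_head)
print(to_list(result))  # [1, 2, 4, 10]
```

Doing the two assignments in the opposite order would set `two.next` to `None` here and drop 4 → 10.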

Imagine you're sorting a deck of cards into three piles on a table. You go through the deck one card at a time, placing each card on the bottom of its pile. When done, you stack the three piles together — small pile, then medium pile, then large pile. You never needed a second deck of cards — you just moved cards from one pile to another.

Step-by-Step Explanation

Let's trace through with list = [1, 4, 3, 2, 5, 2, 3], x = 3:

Step 1: Initialize three chains with dummy heads.

  • lessHead → dummy, lessTail → dummy
  • equalHead → dummy, equalTail → dummy
  • greaterHead → dummy, greaterTail → dummy

Step 2: Process node(1). Value 1 < 3 → append to less chain.

  • less: dummy → 1

Step 3: Process node(4). Value 4 > 3 → append to greater chain.

  • greater: dummy → 4

Step 4: Process node(3). Value 3 == 3 → append to equal chain.

  • equal: dummy → 3

Step 5: Process node(2). Value 2 < 3 → append to less chain.

  • less: dummy → 1 → 2

Step 6: Process node(5). Value 5 > 3 → append to greater chain.

  • greater: dummy → 4 → 5

Step 7: Process node(2). Value 2 < 3 → append to less chain.

  • less: dummy → 1 → 2 → 2

Step 8: Process node(3). Value 3 == 3 → append to equal chain.

  • equal: dummy → 3 → 3

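Step 9: Connect the three chains, linking equal → greater first and then less → equal.

  • equalTail(3).next → greaterHead.next (which is 4)
  • lessTail(2).next → equalHead.next (which is 3)
  • greaterTail(5).next is already null, because each node's next was cleared when it was appended
  • Result: 1 → 2 → 2 → 3 → 3 → 4 → 5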

Algorithm

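  1. Create three dummy nodes: lessHead, equalHead, greaterHead. Initialize a tail pointer for each to its dummy.
  2. Traverse the original list node by node, saving node.next and then setting node.next = null to detach the node:
    • If node.data < x: append to less chain (lessTail.next = node, lessTail = node)
    • If node.data == x: append to equal chain (equalTail.next = node, equalTail = node)
    • If node.data > x: append to greater chain (greaterTail.next = node, greaterTail = node)
  3. Because each node's next is cleared before it is appended, every chain is already null-terminated; in particular greaterTail.next is null, which ends the final list.
  4. Connect the chains, in this order:
    • equalTail.next = greaterHead.next
    • lessTail.next = equalHead.next
    • An empty chain's tail is still its dummy, so these two assignments skip over empty chains automatically. Reversing their order breaks the case where the equal chain is empty: lessTail.next would be set to null before equalHead.next is pointed at the greater chain.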
  5. Determine the overall head: head of first non-empty chain (less → equal → greater)
  6. Return the overall head

Code

C++

class Solution {
public:
    Node* partition(Node* head, int x) {
        // Dummy heads for three chains
        Node lessHead(0), equalHead(0), greaterHead(0);
        Node* lessTail = &lessHead;
        Node* equalTail = &equalHead;
        Node* greaterTail = &greaterHead;
        
        Node* curr = head;
        while (curr != nullptr) {
            Node* nextNode = curr->next;
            curr->next = nullptr;
            
            if (curr->data < x) {
                lessTail->next = curr;
                lessTail = curr;
            } else if (curr->data == x) {
                equalTail->next = curr;
                equalTail = curr;
            } else {
                greaterTail->next = curr;
                greaterTail = curr;
            }
            
            curr = nextNode;
        }
        
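        // Connect chains. Link equal → greater FIRST: if the equal chain is
        // empty, equalTail is still &equalHead, so equalHead.next becomes the
        // greater head and the next line links less straight past the empty chain.
        equalTail->next = greaterHead.next;
        lessTail->next = equalHead.next;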
        
        // Return head of first non-empty chain
        if (lessHead.next != nullptr) return lessHead.next;
        if (equalHead.next != nullptr) return equalHead.next;
        return greaterHead.next;
    }
};

Python

class Solution:
    def partition(self, head, x):
        # Dummy heads for three chains
        less_head = Node(0)
        equal_head = Node(0)
        greater_head = Node(0)
        
        less_tail = less_head
        equal_tail = equal_head
        greater_tail = greater_head
        
        curr = head
        while curr:
            next_node = curr.next
            curr.next = None
            
            if curr.data < x:
                less_tail.next = curr
                less_tail = curr
            elif curr.data == x:
                equal_tail.next = curr
                equal_tail = curr
            else:
                greater_tail.next = curr
                greater_tail = curr
            
            curr = next_node
        
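        # Connect chains. Link equal → greater FIRST: if the equal chain is
        # empty, equal_tail is still equal_head, so equal_head.next becomes the
        # greater head and the next line links less straight past the empty chain.
        equal_tail.next = greater_head.next
        less_tail.next = equal_head.next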
        
        # Return head of first non-empty chain
        if less_head.next:
            return less_head.next
        if equal_head.next:
            return equal_head.next
        return greater_head.next

Java

class Solution {
    Node partition(Node head, int x) {
        // Dummy heads for three chains
        Node lessHead = new Node(0);
        Node equalHead = new Node(0);
        Node greaterHead = new Node(0);
        
        Node lessTail = lessHead;
        Node equalTail = equalHead;
        Node greaterTail = greaterHead;
        
        Node curr = head;
        while (curr != null) {
            Node nextNode = curr.next;
            curr.next = null;
            
            if (curr.data < x) {
                lessTail.next = curr;
                lessTail = curr;
            } else if (curr.data == x) {
                equalTail.next = curr;
                equalTail = curr;
            } else {
                greaterTail.next = curr;
                greaterTail = curr;
            }
            
            curr = nextNode;
        }
        
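        // Connect chains. Link equal → greater FIRST: if the equal chain is
        // empty, equalTail is still equalHead, so equalHead.next becomes the
        // greater head and the next line links less straight past the empty chain.
        equalTail.next = greaterHead.next;
        lessTail.next = equalHead.next;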
        
        // Return head of first non-empty chain
        if (lessHead.next != null) return lessHead.next;
        if (equalHead.next != null) return equalHead.next;
        return greaterHead.next;
    }
}
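To gain confidence in the pointer-based version, including inputs where one or two chains are empty, it can be compared against the array-based definition on random lists. This sketch assumes a minimal stand-in `Node` class (a judge normally supplies its own), a `partition` function following the three-chain method with the equal → greater link made first, and a hypothetical `check` helper. Small value ranges make empty chains frequent. The comparison is by node identity, so it also confirms that no node is copied or lost and that each group keeps its original order.

```python
import random


class Node:
    """Stand-in node type; an online judge normally supplies its own."""

    def __init__(self, data):
        self.data = data
        self.next = None


def partition(head, x):
    """Stable three-way partition that relinks the existing nodes."""
    less_head, equal_head, greater_head = Node(0), Node(0), Node(0)
    less_tail, equal_tail, greater_tail = less_head, equal_head, greater_head
    curr = head
    while curr:
        nxt = curr.next
        curr.next = None  # detach, so every chain stays null-terminated
        if curr.data < x:
            less_tail.next = curr
            less_tail = curr
        elif curr.data == x:
            equal_tail.next = curr
            equal_tail = curr
        else:
            greater_tail.next = curr
            greater_tail = curr
        curr = nxt
    equal_tail.next = greater_head.next  # must run before the next line
    less_tail.next = equal_head.next
    return less_head.next  # first node of the first non-empty chain, or None


def check(values, x):
    """Partition a fresh list of `values` and compare node-by-node to the spec."""
    nodes = [Node(v) for v in values]
    for a, b in zip(nodes, nodes[1:]):
        a.next = b
    head = partition(nodes[0] if nodes else None, x)
    out = []
    while head:
        out.append(head)
        head = head.next
    # Node defines no __eq__, so list == compares object identity: this checks
    # the values, that no node was copied or lost, and per-group stability.
    expected = ([n for n in nodes if n.data < x]
                + [n for n in nodes if n.data == x]
                + [n for n in nodes if n.data > x])
    assert out == expected


random.seed(0)
for _ in range(1000):
    values = [random.randint(1, 5) for _ in range(random.randint(1, 12))]
    check(values, random.randint(1, 5))
print("all random checks passed")
```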

Complexity Analysis

Time Complexity: O(n)

We traverse the linked list exactly once, processing each of the n nodes. For each node, we perform a constant number of operations: one comparison with x, one pointer reassignment to append to the appropriate chain, and one pointer advance. Connecting the three chains at the end takes O(1). Total: O(n).

Space Complexity: O(1)

We use a constant number of extra variables regardless of input size: three dummy (sentinel) nodes, three tail pointers, and the traversal pointers curr and nextNode. Apart from those three sentinels, no nodes are created and no auxiliary structure grows with n; the original nodes are relinked, not copied. This is genuinely in-place.