310. Minimum Height Trees - Explanation

Description

A tree is an undirected graph in which any two vertices are connected by exactly one path. In other words, any connected graph without simple cycles is a tree.

You are given a tree of n nodes labelled from 0 to n - 1, and an array of n - 1 edges where edges[i] = [a[i], b[i]] indicates that there is an undirected edge between the two nodes a[i] and b[i] in the tree. You can choose any node of the tree as the root. When you select a node x as the root, the resulting tree has height h. Among all possible rooted trees, those with minimum height (i.e. min(h)) are called minimum height trees (MHTs).

Return a list of all MHTs' root labels. You can return the answer in any order.

The height of a rooted tree is the number of edges on the longest downward path between the root and a leaf.

Example 1:

Input: n = 5, edges = [[0,1],[3,1],[2,3],[4,1]]

Output: [3,1]

Explanation: As shown, the trees with root labels [3,1] are MHTs with a height of 2.

Example 2:

Input: n = 4, edges = [[1,0],[2,0],[3,0]]

Output: [0]

Explanation: As shown, the tree with root label [0] is an MHT with a height of 1.

Constraints:

  • 1 <= n <= 20,000
  • edges.length == n - 1
  • 0 <= a[i], b[i] < n
  • a[i] != b[i]
  • All the pairs (a[i], b[i]) are distinct.
  • The given input is guaranteed to be a tree and there will be no repeated edges.



Prerequisites

Before attempting this problem, you should be comfortable with:

  • Tree/Graph Representation - Understanding how to represent trees using adjacency lists
  • Depth First Search (DFS) - Ability to traverse trees recursively while computing subtree properties like height
  • Breadth First Search (BFS) / Topological Sort - Understanding level-by-level processing and leaf removal techniques
  • Tree Properties - Familiarity with concepts like tree diameter, centroids, and the relationship between tree structure and height
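
As a quick refresher on the first prerequisite, the sketch below turns an edge list into an adjacency list. The helper name build_adj is illustrative and not part of the problem statement.

```python
from typing import List


def build_adj(n: int, edges: List[List[int]]) -> List[List[int]]:
    """Return adj where adj[u] lists every neighbor of node u."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        # Undirected edge: record it in both directions.
        adj[u].append(v)
        adj[v].append(u)
    return adj


# Example 1 from the description.
print(build_adj(5, [[0, 1], [3, 1], [2, 3], [4, 1]]))
# [[1], [0, 3, 4], [3], [1, 2], [1]]
```

Every approach below starts from a structure of this shape.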

1. Brute Force (DFS)

Intuition

The most straightforward approach is to try each node as a potential root and measure the resulting tree height. The height of a tree rooted at any node is the maximum distance to any other node, which we can find using dfs.

By computing the height for every possible root, we can identify which nodes produce the minimum height. While simple to understand, this repeats a lot of work since we recompute distances from scratch for each candidate root.

Algorithm

  1. Build an adjacency list from the edges.
  2. Define a dfs function that computes the height of the tree when rooted at a given node.
  3. For each node from 0 to n-1, run dfs to get its tree height.
  4. Track the minimum height seen and collect all nodes that achieve this minimum.
  5. Return the list of nodes with minimum height.

from typing import List

class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        adj = [[] for _ in range(n)]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)

        def dfs(node, parent):
            hgt = 0
            for nei in adj[node]:
                if nei == parent:
                    continue
                hgt = max(hgt, 1 + dfs(nei, node))
            return hgt

        minHgt = n
        res = []
        for i in range(n):
            curHgt = dfs(i, -1)
            if curHgt == minHgt:
                res.append(i)
            elif curHgt < minHgt:
                res = [i]
                minHgt = curHgt

        return res

Time & Space Complexity

  • Time complexity: O(V * (V + E))
  • Space complexity: O(V)

Where V is the number of vertices and E is the number of edges.
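
One practical caveat: Python's default recursion limit is about 1000 frames, so a recursive dfs on a path-shaped tree with thousands of nodes may raise RecursionError unless the environment raises that limit. The sketch below is the same brute force idea with the height measured by BFS instead of recursion. The function names height_from and min_height_roots_bruteforce are made up for this illustration.

```python
from collections import deque
from typing import List


def height_from(root: int, adj: List[List[int]]) -> int:
    """Height of the tree rooted at `root`: the largest BFS distance from it."""
    dist = [-1] * len(adj)  # -1 marks a node that has not been reached yet
    dist[root] = 0
    queue = deque([root])
    height = 0
    while queue:
        node = queue.popleft()
        height = max(height, dist[node])
        for nei in adj[node]:
            if dist[nei] == -1:
                dist[nei] = dist[node] + 1
                queue.append(nei)
    return height


def min_height_roots_bruteforce(n: int, edges: List[List[int]]) -> List[int]:
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    # Try every node as the root, then keep the roots with the smallest height.
    heights = [height_from(i, adj) for i in range(n)]
    best = min(heights)
    return [i for i in range(n) if heights[i] == best]


print(min_height_roots_bruteforce(5, [[0, 1], [3, 1], [2, 3], [4, 1]]))  # [1, 3]
```

The running time is still quadratic in the number of nodes; only the recursion depth concern goes away.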


2. Dynamic Programming On Trees (Rerooting)

Intuition

Rather than recomputing everything for each root, we can reuse information. The tree height from any node depends on the longest path in two directions: down into its subtree and up through its parent to the rest of the tree.

We run two dfs passes. The first computes the two longest downward paths for each node. We need the second longest because, when a node passes information down to a child, its longest downward path may run through that very child and so cannot be used. The second pass propagates information from parent to children, combining the parent's best upward path with sibling subtree heights.

Algorithm

  1. Build an adjacency list and initialize a dp array storing the top two heights for each node.
  2. First dfs (post-order): for each node, compute the two longest paths into its children's subtrees.
  3. Second dfs (pre-order): propagate the "upward" height from parent to children, updating each node's best heights to include paths through the parent.
  4. After both passes, each node's maximum height (dp[i][0]) represents its tree height as root.
  5. Find the minimum value across all nodes and return all nodes achieving that minimum.

from typing import List

class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        adj = [[] for _ in range(n)]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)

        dp = [[0] * 2 for _ in range(n)] # top two heights for each node

        def dfs(node, parent):
            for nei in adj[node]:
                if nei == parent:
                    continue
                dfs(nei, node)
                curHgt = 1 + dp[nei][0]
                if curHgt > dp[node][0]:
                    dp[node][1] = dp[node][0]
                    dp[node][0] = curHgt
                elif curHgt > dp[node][1]:
                    dp[node][1] = curHgt

        def dfs1(node, parent, topHgt):
            if topHgt > dp[node][0]:
                dp[node][1] = dp[node][0]
                dp[node][0] = topHgt
            elif topHgt > dp[node][1]:
                dp[node][1] = topHgt

            for nei in adj[node]:
                if nei == parent:
                    continue
                toChild = 1 + (dp[node][1] if dp[node][0] == 1 + dp[nei][0] else dp[node][0])
                dfs1(nei, node, toChild)

        dfs(0, -1)
        dfs1(0, -1, 0)

        minHgt, res = n, []
        for i in range(n):
            minHgt = min(minHgt, dp[i][0])
        for i in range(n):
            if minHgt == dp[i][0]:
                res.append(i)
        return res

Time & Space Complexity

  • Time complexity: O(V + E)
  • Space complexity: O(V)

Where V is the number of vertices and E is the number of edges.
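
The two passes can also be written with separate arrays for the downward and upward lengths, which may be easier to trace than updating dp in place. The following is a sketch under that assumption, written iteratively to avoid deep recursion. The names down1, down2, up, and min_height_roots_reroot are illustrative only.

```python
from typing import List


def min_height_roots_reroot(n: int, edges: List[List[int]]) -> List[int]:
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Pre-order from node 0: every node appears after its parent in `order`.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for nei in adj[node]:
            if nei != parent[node]:
                parent[nei] = node
                stack.append(nei)

    # Pass 1 (children before parents): two longest downward paths per node.
    down1 = [0] * n
    down2 = [0] * n
    for node in reversed(order):
        p = parent[node]
        if p == -1:
            continue
        cand = down1[node] + 1
        if cand > down1[p]:
            down1[p], down2[p] = cand, down1[p]
        elif cand > down2[p]:
            down2[p] = cand

    # Pass 2 (parents before children): longest path that leaves via the parent.
    up = [0] * n
    for node in order:
        p = parent[node]
        if p == -1:
            continue
        # Longest downward path from p that does not enter node's subtree.
        via_sibling = down2[p] if down1[p] == down1[node] + 1 else down1[p]
        up[node] = 1 + max(up[p], via_sibling)

    heights = [max(down1[i], up[i]) for i in range(n)]
    best = min(heights)
    return [i for i in range(n) if heights[i] == best]


print(min_height_roots_reroot(5, [[0, 1], [3, 1], [2, 3], [4, 1]]))  # [1, 3]
```

For Example 1 the combined heights come out as [3, 2, 3, 2, 3], so nodes 1 and 3 are returned.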


3. Find Centroids of the Tree (DFS)

Intuition

The minimum height trees are rooted at the middle node(s) of the longest path (diameter) in the tree. Graph theory texts usually call these nodes the centers of the tree; this write-up calls them centroids. If the diameter has an odd number of edges, there are two centroids; if even, there is exactly one.

We find the diameter using two bfs/dfs passes: first find the farthest node from any starting point, then find the farthest node from that. The path between these two endpoints is the diameter, and its middle node(s) are the answer.

Algorithm

  1. Build an adjacency list from the edges.
  2. Run dfs from node 0 to find the farthest node (call it node_a).
  3. Run dfs from node_a to find the farthest node (node_b) and the diameter length.
  4. Trace the path from node_a to node_b, collecting all nodes along the way.
  5. If the diameter is even, return the single middle node; if odd, return the two middle nodes.

from typing import List

class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        if n == 1:
            return [0]

        adj = [[] for _ in range(n)]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)

        def dfs(node, parent):
            farthest_node = node
            max_distance = 0
            for nei in adj[node]:
                if nei != parent:
                    nei_node, nei_distance = dfs(nei, node)
                    if nei_distance + 1 > max_distance:
                        max_distance = nei_distance + 1
                        farthest_node = nei_node
            return farthest_node, max_distance

        node_a, _ = dfs(0, -1)
        node_b, diameter = dfs(node_a, -1)

        centroids = []

        def find_centroids(node, parent):
            if node == node_b:
                centroids.append(node)
                return True
            for nei in adj[node]:
                if nei != parent:
                    if find_centroids(nei, node):
                        centroids.append(node)
                        return True
            return False

        find_centroids(node_a, -1)
        L = len(centroids)
        if diameter % 2 == 0:
            return [centroids[L // 2]]
        else:
            return [centroids[L // 2 - 1], centroids[L // 2]]

Time & Space Complexity

  • Time complexity: O(V + E)
  • Space complexity: O(V)

Where V is the number of vertices and E is the number of edges.
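
Since the intuition allows either bfs or dfs for the two passes, here is a sketch of the BFS variant. It records a parent array during the second pass and walks it backwards to recover the diameter path, instead of running a third traversal. The names bfs_farthest and min_height_roots_diameter are hypothetical.

```python
from collections import deque
from typing import List, Tuple


def bfs_farthest(start: int, adj: List[List[int]]) -> Tuple[int, List[int]]:
    """Return the last node BFS dequeues from `start`, plus the BFS parent array."""
    parent = [-1] * len(adj)
    seen = [False] * len(adj)
    seen[start] = True
    queue = deque([start])
    last = start
    while queue:
        # BFS dequeues nodes in order of distance, so the final one is farthest.
        last = queue.popleft()
        for nei in adj[last]:
            if not seen[nei]:
                seen[nei] = True
                parent[nei] = last
                queue.append(nei)
    return last, parent


def min_height_roots_diameter(n: int, edges: List[List[int]]) -> List[int]:
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    node_a, _ = bfs_farthest(0, adj)
    node_b, parent = bfs_farthest(node_a, adj)

    # Walk parent pointers from node_b back to node_a to recover the path.
    path = []
    node = node_b
    while node != -1:
        path.append(node)
        node = parent[node]

    mid = len(path) // 2
    # Odd node count (even edge count): one middle node; otherwise two.
    return [path[mid]] if len(path) % 2 == 1 else [path[mid - 1], path[mid]]


print(sorted(min_height_roots_diameter(5, [[0, 1], [3, 1], [2, 3], [4, 1]])))  # [1, 3]
```

The order of the two returned nodes depends on which endpoint the first pass happens to find, which is acceptable because the problem allows any order.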


4. Topological Sorting (BFS)

Intuition

Imagine peeling the tree like an onion, removing leaves layer by layer. The nodes that remain at the very end (when only 1 or 2 nodes are left) must be the centroids, since they're the innermost points of the tree.

Each round of removal brings us one step closer to the center. Since a tree can have at most 2 centroids (on a diameter of odd length), we stop when 2 or fewer nodes remain.
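
To make the onion picture concrete, the sketch below records which nodes are peeled in each round. The helper peel_layers is a hypothetical name used only for this illustration; the last layer it returns holds the final one or two nodes.

```python
from typing import List


def peel_layers(n: int, edges: List[List[int]]) -> List[List[int]]:
    """Return the leaves of each round; the last entry is the surviving 1 or 2 nodes."""
    if n == 1:
        return [[0]]
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    degree = [len(neighbors) for neighbors in adj]

    layer = [i for i in range(n) if degree[i] == 1]
    layers = [layer]
    remaining = n
    while remaining > 2:
        remaining -= len(layer)  # the whole current layer is removed together
        next_layer = []
        for leaf in layer:
            for nei in adj[leaf]:
                degree[nei] -= 1
                if degree[nei] == 1:
                    next_layer.append(nei)
        layer = next_layer
        layers.append(layer)
    return layers


# A path 0-1-2-3-4-5 is peeled from both ends toward the middle.
print(peel_layers(6, [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]))
# [[0, 5], [1, 4], [2, 3]]
```

On Example 1 the same helper yields [[0, 2, 4], [3, 1]], matching the expected output [3,1].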

Algorithm

  1. Build an adjacency list and track each node's edge count (degree).
  2. Handle n = 1 as a special case (its only node has degree 0, so it would never be queued), then initialize a queue with all leaf nodes (degree = 1).
  3. While more than 2 nodes remain:
    • Remove all current leaves from the queue.
    • For each removed leaf, decrement its neighbor's degree.
    • If a neighbor becomes a leaf, add it to the queue.
  4. The remaining nodes in the queue are the minimum height tree roots.

from collections import defaultdict, deque
from typing import List

class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        if n == 1:
            return [0]

        adj = defaultdict(list)
        for n1, n2 in edges:
            adj[n1].append(n2)
            adj[n2].append(n1)

        edge_cnt = {}
        leaves = deque()

        for src, neighbors in adj.items():
            edge_cnt[src] = len(neighbors)
            if len(neighbors) == 1:
                leaves.append(src)

        while leaves:
            if n <= 2:
                return list(leaves)
            for _ in range(len(leaves)):
                node = leaves.popleft()
                n -= 1
                for nei in adj[node]:
                    edge_cnt[nei] -= 1
                    if edge_cnt[nei] == 1:
                        leaves.append(nei)

Time & Space Complexity

  • Time complexity: O(V + E)
  • Space complexity: O(V)

Where V is the number of vertices and E is the number of edges.


Common Pitfalls

Forgetting the Single Node Edge Case

When n = 1, there are no edges and the only node 0 is trivially the root of a minimum height tree. Many solutions fail to handle this case, leading to empty results or index errors when trying to process an empty edge list.
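
A small sketch of why the case slips through in degree-based solutions: with no edges, node 0 has degree 0 and never qualifies as a leaf. The helper initial_leaves is an illustrative name, not something from the solutions above.

```python
from typing import List


def initial_leaves(n: int, edges: List[List[int]]) -> List[int]:
    """Nodes of degree 1, as a leaf-removal solution would queue them."""
    degree = [0] * n
    for u, v in edges:
        degree[u] += 1
        degree[v] += 1
    return [i for i in range(n) if degree[i] == 1]


print(initial_leaves(1, []))        # [] because node 0 has degree 0
print(initial_leaves(2, [[0, 1]]))  # [0, 1]
```

An explicit check such as returning [0] when n == 1, placed before any other work, avoids the empty result.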

Using a Visited Set Instead of Parent Tracking in DFS

In tree traversal, using a full visited set is unnecessary overhead. Since trees have no cycles, simply passing the parent node to avoid revisiting the previous node is sufficient and more efficient. Using a visited set can also cause issues if not cleared properly between multiple DFS calls.

Incorrectly Implementing Leaf Removal in Topological Sort

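When peeling leaves layer by layer, a common mistake is failing to fix the size of the current batch before processing it. Decrementing neighbor degrees while removing leaves is required, but newly discovered leaves belong to the next round. If the remaining-node count is checked after each single removal rather than after a whole round, the loop can stop with a leftover leaf from the current round still in the queue and return the wrong nodes. Record the queue length at the start of each round, remove exactly that many leaves, and only then test whether 2 or fewer nodes remain.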
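
The difference shows up on a path of five nodes. The sketch below uses a hypothetical function peel with a flag, check_mid_round, that switches between removing a full round at a time and re-checking the stopping condition after every single removal.

```python
from collections import deque
from typing import List


def peel(n: int, edges: List[List[int]], check_mid_round: bool) -> List[int]:
    if n == 1:
        return [0]
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    degree = [len(neighbors) for neighbors in adj]
    leaves = deque(i for i in range(n) if degree[i] == 1)
    remaining = n

    while remaining > 2:
        # Correct version: fix the batch size before removing anything.
        # Buggy version: remove one leaf, then re-test the loop condition.
        batch = 1 if check_mid_round else len(leaves)
        for _ in range(batch):
            node = leaves.popleft()
            remaining -= 1
            for nei in adj[node]:
                degree[nei] -= 1
                if degree[nei] == 1:
                    leaves.append(nei)
    return list(leaves)


path = [[0, 1], [1, 2], [2, 3], [3, 4]]
print(peel(5, path, check_mid_round=False))  # [2]
print(peel(5, path, check_mid_round=True))   # [3, 2], which is wrong
```

In the buggy run, node 3 is a leaf from the second round that has not been removed yet when the count drops to 2, so it is returned alongside the true middle node.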