A tree is an undirected graph in which any two vertices are connected by exactly one path. In other words, any connected graph without simple cycles is a tree.
You are given a tree of n nodes labelled from 0 to n - 1, and an array of n - 1 edges where edges[i] = [a[i], b[i]] indicates that there is an undirected edge between the two nodes a[i] and b[i] in the tree. You can choose any node of the tree as the root. When you select a node x as the root, the resulting tree has height h. Among all possible rooted trees, those with minimum height (i.e. min(h)) are called minimum height trees (MHTs).
Return a list of all MHTs' root labels. You can return the answer in any order.
The height of a rooted tree is the number of edges on the longest downward path between the root and a leaf.
Example 1:
Input: n = 5, edges = [[0,1],[3,1],[2,3],[4,1]]
Output: [3,1]
Explanation: The trees rooted at nodes 3 and 1 are MHTs with height 2.
Example 2:
Input: n = 4, edges = [[1,0],[2,0],[3,0]]
Output: [0]
Explanation: The tree rooted at node 0 is the only MHT, with height 1.
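The sketch below makes the height definition concrete. It computes the height for every possible root by breadth-first search, then picks the roots of minimum height. The helper names rooted_height and heights_for_all_roots are hypothetical and chosen for illustration. The adjacency list is rebuilt on every call for simplicity, not efficiency.

```python
from collections import deque
from typing import List


def rooted_height(n: int, edges: List[List[int]], root: int) -> int:
    """Height of the tree rooted at `root`: the largest BFS distance from it."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    dist = [-1] * n          # -1 marks "not reached yet"
    dist[root] = 0
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for nei in adj[node]:
            if dist[nei] == -1:
                dist[nei] = dist[node] + 1
                queue.append(nei)
    return max(dist)


def heights_for_all_roots(n: int, edges: List[List[int]]) -> List[int]:
    # One BFS per candidate root (the brute-force definition of the problem).
    return [rooted_height(n, edges, r) for r in range(n)]


ex1 = heights_for_all_roots(5, [[0, 1], [3, 1], [2, 3], [4, 1]])
print(ex1)  # [3, 2, 3, 2, 3]
best = min(ex1)
print([r for r, h in enumerate(ex1) if h == best])  # [1, 3]

ex2 = heights_for_all_roots(4, [[1, 0], [2, 0], [3, 0]])
print(ex2)  # [1, 2, 2, 2]
```

For Example 1, roots 0 through 4 give heights 3, 2, 3, 2, 3. Roots 1 and 3 are therefore the MHT roots, which matches the expected output.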
Constraints:
- 1 <= n <= 20,000
- edges.length == n - 1
- 0 <= a[i], b[i] < n
- a[i] != b[i]
- All the pairs (a[i], b[i]) are distinct.
The most straightforward approach is to try each node as a potential root and measure the resulting tree height. The height of a tree rooted at any node is the maximum distance to any other node, which we can find using dfs.
By computing the height for every possible root, we can identify which nodes produce the minimum height. While simple to understand, this repeats a lot of work since we recompute distances from scratch for each candidate root.
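1. Build an adjacency list from the edges.
2. Define a dfs function that computes the height of the tree when rooted at a given node.
3. For each node from 0 to n-1, run dfs to get its tree height.
4. Keep every node whose height equals the minimum found.

```python
from typing import List

class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        # Undirected adjacency list.
        adj = [[] for _ in range(n)]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)

        # Height of the part of the tree below `node` when entered from `parent`.
        def dfs(node, parent):
            hgt = 0
            for nei in adj[node]:
                if nei == parent:
                    continue
                hgt = max(hgt, 1 + dfs(nei, node))
            return hgt

        minHgt = n  # any tree height is at most n - 1
        res = []
        for i in range(n):
            curHgt = dfs(i, -1)  # tree height with i as the root
            if curHgt == minHgt:
                res.append(i)
            elif curHgt < minHgt:
                res = [i]
                minHgt = curHgt
        return res
```

Time complexity: O(V * (V + E))
Space complexity: O(V)

Where V is the number of vertices and E is the number of edges.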
Rather than recomputing everything for each root, we can reuse information. The tree height from any node depends on the longest path in two directions: down into its subtree and up through its parent to the rest of the tree.
We run two dfs passes. The first computes the two longest downward paths for each node. We keep two because, when the second pass hands a node's best path to a child, that path may run through the child itself; in that case the child must receive the second-longest path instead. The second pass propagates information from parent to children, combining the parent's best upward path with sibling subtree heights.
1. Build an adjacency list and a dp array storing the top two heights for each node.
2. First dfs (post-order): for each node, compute the two longest paths into its children's subtrees.
3. Second dfs (pre-order): propagate the "upward" height from parent to children, updating each node's best heights to include paths through the parent.
4. After both passes, each node's largest height (dp[i][0]) is its tree height as root. Return every node whose height is minimal.

```python
from typing import List

class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        adj = [[] for _ in range(n)]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)

        dp = [[0] * 2 for _ in range(n)]  # top two heights for each node

        # Pass 1 (post-order): dp[node] holds the two longest downward paths.
        def dfs(node, parent):
            for nei in adj[node]:
                if nei == parent:
                    continue
                dfs(nei, node)
                curHgt = 1 + dp[nei][0]
                if curHgt > dp[node][0]:
                    dp[node][1] = dp[node][0]
                    dp[node][0] = curHgt
                elif curHgt > dp[node][1]:
                    dp[node][1] = curHgt

        # Pass 2 (pre-order): fold in topHgt, the longest path leaving `node`
        # upward through `parent`, then pass the updated info to each child.
        def dfs1(node, parent, topHgt):
            if topHgt > dp[node][0]:
                dp[node][1] = dp[node][0]
                dp[node][0] = topHgt
            elif topHgt > dp[node][1]:
                dp[node][1] = topHgt
            for nei in adj[node]:
                if nei == parent:
                    continue
                # If node's best path runs through nei, the child gets the second best.
                toChild = 1 + (dp[node][1] if dp[node][0] == 1 + dp[nei][0] else dp[node][0])
                dfs1(nei, node, toChild)

        dfs(0, -1)
        dfs1(0, -1, 0)

        # dp[i][0] is now the tree height when rooted at i.
        minHgt, res = n, []
        for i in range(n):
            minHgt = min(minHgt, dp[i][0])
        for i in range(n):
            if minHgt == dp[i][0]:
                res.append(i)
        return res
```

Time complexity: O(V + E)
Space complexity: O(V)

Where V is the number of vertices and E is the number of edges.
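Both recursive passes recurse once per tree level, so on a path-shaped tree the depth approaches n. CPython's default recursion limit is 1000 (see sys.getrecursionlimit()), which is well below the 20,000 nodes allowed here unless the environment raises it.

The sketch below is a hypothetical iterative rewrite of the same two-pass idea, with min_height_roots_iterative as an illustrative name. It orders the nodes with an explicit stack, fills the downward heights children-first, and then fills an up array parents-first. Keeping up separate from the downward values, instead of merging it into dp as the recursive code does, is a design choice that makes the "avoid this child's branch" test easier to read.

```python
from typing import List


def min_height_roots_iterative(n: int, edges: List[List[int]]) -> List[int]:
    # Hypothetical helper: two-pass rerooting DP with explicit loops, no recursion.
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Pre-order from root 0: every node appears after its parent.
    parent = [-1] * n
    order = []
    stack = [(0, -1)]
    while stack:
        node, par = stack.pop()
        parent[node] = par
        order.append(node)
        for nei in adj[node]:
            if nei != par:
                stack.append((nei, node))

    # Pass 1 (children before parents): two longest downward paths per node.
    down1 = [0] * n
    down2 = [0] * n
    for node in reversed(order):
        for nei in adj[node]:
            if nei == parent[node]:
                continue
            branch = 1 + down1[nei]
            if branch > down1[node]:
                down1[node], down2[node] = branch, down1[node]
            elif branch > down2[node]:
                down2[node] = branch

    # Pass 2 (parents before children): up[c] = longest path leaving c via its parent.
    up = [0] * n
    for node in order:
        for nei in adj[node]:
            if nei == parent[node]:
                continue
            # Best downward path from node that avoids the branch through nei.
            sibling_best = down2[node] if down1[node] == 1 + down1[nei] else down1[node]
            up[nei] = 1 + max(up[node], sibling_best)

    heights = [max(down1[i], up[i]) for i in range(n)]
    best = min(heights)
    return [i for i in range(n) if heights[i] == best]
```

On a 20,000-node path, this version returns the two middle nodes without touching the recursion limit.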
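The minimum height trees are rooted at the center(s) of the tree: the node or nodes whose maximum distance to any other node is smallest. (These are sometimes loosely called centroids, but a tree's centroid is a different notion, defined by subtree sizes.) The centers lie at the middle of the longest path (diameter) in the tree. If the diameter has odd length (counted in edges), there are two centers; if even, there is exactly one.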
We find the diameter using two bfs/dfs passes. First, find the farthest node from any starting point; this node is always an endpoint of some diameter. Then find the farthest node from that endpoint. The path between these two endpoints is a diameter, and its middle node(s) are the answer.
1. If n = 1, return [0].
2. Run dfs from node 0 to find the farthest node (call it node_a).
3. Run dfs from node_a to find the farthest node (node_b) and the diameter length.
4. Trace the path from node_a to node_b, collecting all nodes along the way.
5. Return the middle node of that path if the diameter is even, or the two middle nodes if it is odd.

```python
from typing import List

class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        if n == 1:
            return [0]
        adj = [[] for _ in range(n)]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)

        # Return (farthest node from `node`, its distance), never stepping back to `parent`.
        def dfs(node, parent):
            farthest_node = node
            max_distance = 0
            for nei in adj[node]:
                if nei != parent:
                    nei_node, nei_distance = dfs(nei, node)
                    if nei_distance + 1 > max_distance:
                        max_distance = nei_distance + 1
                        farthest_node = nei_node
            return farthest_node, max_distance

        node_a, _ = dfs(0, -1)              # one endpoint of a diameter
        node_b, diameter = dfs(node_a, -1)  # the other endpoint and the diameter length

        # Nodes on the node_a -> node_b path, appended from node_b back to node_a.
        path = []
        def find_path(node, parent):
            if node == node_b:
                path.append(node)
                return True
            for nei in adj[node]:
                if nei != parent:
                    if find_path(nei, node):
                        path.append(node)
                        return True
            return False

        find_path(node_a, -1)
        L = len(path)  # equals diameter + 1
        if diameter % 2 == 0:
            return [path[L // 2]]
        else:
            return [path[L // 2 - 1], path[L // 2]]
```

Time complexity: O(V + E)
Space complexity: O(V)

Where V is the number of vertices and E is the number of edges.
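Either bfs or dfs works for the two farthest-node searches. The sketch below takes the BFS route and records each node's BFS parent. That way the diameter path is recovered by walking parent pointers from node_b back to node_a, with no second recursive search. The names bfs_farthest and tree_centers are hypothetical. Because nothing recurses, long path-shaped trees are not a problem, and n = 1 needs no special case (the path is just [0]).

```python
from collections import deque
from typing import List, Tuple


def bfs_farthest(adj: List[List[int]], start: int) -> Tuple[int, List[int]]:
    """Return the node farthest from `start` and the BFS parent array."""
    n = len(adj)
    parent = [-1] * n
    dist = [-1] * n
    dist[start] = 0
    queue = deque([start])
    farthest = start
    while queue:
        node = queue.popleft()
        if dist[node] > dist[farthest]:
            farthest = node
        for nei in adj[node]:
            if dist[nei] == -1:
                dist[nei] = dist[node] + 1
                parent[nei] = node
                queue.append(nei)
    return farthest, parent


def tree_centers(n: int, edges: List[List[int]]) -> List[int]:
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    node_a, _ = bfs_farthest(adj, 0)            # endpoint of some diameter
    node_b, parent = bfs_farthest(adj, node_a)  # the opposite endpoint
    # Walk BFS parents from node_b back to node_a to recover the diameter path.
    path = [node_b]
    while path[-1] != node_a:
        path.append(parent[path[-1]])
    diameter = len(path) - 1
    mid = len(path) // 2
    if diameter % 2 == 0:
        return [path[mid]]
    return [path[mid - 1], path[mid]]
```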
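Imagine peeling the tree like an onion, removing leaves layer by layer. The nodes that remain at the very end (when only 1 or 2 nodes are left) must be the centers. The endpoints of every longest path are leaves, so each round shortens every longest path at both ends. As a result, the middle of the diameter is what survives.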
Each round of removal brings us one step closer to the center. A tree has at most 2 centers (exactly 2 when the diameter has odd length), so we stop when 2 or fewer nodes remain.
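1. Build an adjacency list and record each node's degree (edge count).
2. Start a queue with all leaves (nodes of degree 1), excluding the special case where n = 1.
3. While more than 2 nodes remain:
   - remove every leaf currently in the queue, decrementing n for each;
   - decrement the degree of each removed leaf's neighbors;
   - add any neighbor whose degree drops to 1 to the queue for the next round.
4. Return the 1 or 2 nodes left in the queue.

```python
from collections import defaultdict, deque
from typing import List

class Solution:
    def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
        if n == 1:
            return [0]
        adj = defaultdict(list)
        for n1, n2 in edges:
            adj[n1].append(n2)
            adj[n2].append(n1)

        # edge_cnt[v] = current degree of v; start with every leaf in the queue.
        edge_cnt = {}
        leaves = deque()
        for src, neighbors in adj.items():
            edge_cnt[src] = len(neighbors)
            if len(neighbors) == 1:
                leaves.append(src)

        while leaves:
            if n <= 2:  # n now counts the nodes still in the tree
                return list(leaves)
            for _ in range(len(leaves)):  # remove exactly the current layer
                node = leaves.popleft()
                n -= 1
                for nei in adj[node]:
                    edge_cnt[nei] -= 1
                    if edge_cnt[nei] == 1:  # nei becomes a leaf of the next layer
                        leaves.append(nei)
```

Time complexity: O(V + E)
Space complexity: O(V)

Where V is the number of vertices and E is the number of edges.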
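You can spot-check the claim that peeling converges to the minimum-height roots against the brute-force definition on random trees. The sketch below is a hypothetical test harness with three helpers:

- brute_force_roots measures every root's height with BFS.
- peel_roots is a list-based variant of the leaf-peeling loop with an explicit degree array.
- random_tree attaches each new node to a random earlier one.

It is a sanity check over a few hundred small random cases, not a proof.

```python
import random
from collections import deque
from typing import List


def brute_force_roots(n: int, edges: List[List[int]]) -> List[int]:
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    def height(root: int) -> int:
        dist = [-1] * n
        dist[root] = 0
        queue = deque([root])
        while queue:
            node = queue.popleft()
            for nei in adj[node]:
                if dist[nei] == -1:
                    dist[nei] = dist[node] + 1
                    queue.append(nei)
        return max(dist)

    heights = [height(r) for r in range(n)]
    best = min(heights)
    return [r for r in range(n) if heights[r] == best]  # ascending order


def peel_roots(n: int, edges: List[List[int]]) -> List[int]:
    if n <= 2:
        return list(range(n))  # 1 or 2 nodes: every node is an MHT root
    adj = [[] for _ in range(n)]
    degree = [0] * n
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
        degree[u] += 1
        degree[v] += 1
    leaves = [v for v in range(n) if degree[v] == 1]
    remaining = n
    while remaining > 2:
        remaining -= len(leaves)
        next_leaves = []
        for leaf in leaves:
            for nei in adj[leaf]:
                degree[nei] -= 1
                if degree[nei] == 1:
                    next_leaves.append(nei)
        leaves = next_leaves
    return leaves


def random_tree(n: int, rng: random.Random) -> List[List[int]]:
    # Attach each node v >= 1 to a uniformly chosen earlier node.
    return [[v, rng.randrange(v)] for v in range(1, n)]


rng = random.Random(0)
for _ in range(300):
    size = rng.randint(1, 40)
    tree = random_tree(size, rng)
    assert sorted(peel_roots(size, tree)) == brute_force_roots(size, tree)
print("ok")
```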
When n = 1, there are no edges and the only node 0 is trivially the root of a minimum height tree. Many solutions fail to handle this case, leading to empty results or index errors when trying to process an empty edge list.
In tree traversal, using a full visited set is unnecessary overhead. Since trees have no cycles, simply passing the parent node to avoid revisiting the previous node is sufficient and more efficient. Using a visited set can also cause issues if not cleared properly between multiple DFS calls.
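The minimal sketch below illustrates parent tracking. It computes distances from a source node with an explicit stack of (node, parent) pairs and no visited set; the function name distances_from is hypothetical. It relies on the input being a tree. On a graph with a cycle, skipping only the parent would not prevent revisits, and the loop would not terminate.

```python
from typing import List


def distances_from(adj: List[List[int]], source: int) -> List[int]:
    """Distances from `source` in a tree, tracking only each node's parent."""
    dist = [0] * len(adj)
    stack = [(source, -1)]  # (node, the node we arrived from)
    while stack:
        node, parent = stack.pop()
        for nei in adj[node]:
            if nei != parent:  # in a tree, the only way back is the edge we came along
                dist[nei] = dist[node] + 1
                stack.append((nei, node))
    return dist


adj = [[1], [0, 3, 4], [3], [1, 2], [1]]  # Example 1 as an adjacency list
print(distances_from(adj, 0))  # [0, 1, 3, 2, 2]
```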
When peeling leaves layer by layer, a common mistake is to let nodes that become leaves during the current round be removed in that same round. The loop can then strip away a center or miss the moment when only 1 or 2 nodes remain. Record how many leaves the round starts with (for example, range(len(leaves)), which is evaluated once) and remove exactly that many. Leaves exposed during the round wait for the next one.
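You can see the round-by-round behavior with a variant that records which leaves each round removes. In this hypothetical helper, peel_layers, the batch size is read once per round, so nodes that become leaves mid-round are queued for the next one. The final entry holds the 1 or 2 surviving nodes.

```python
from collections import deque
from typing import List


def peel_layers(n: int, edges: List[List[int]]) -> List[List[int]]:
    """Leaves removed in each round, followed by the final 1 or 2 nodes."""
    if n == 1:
        return [[0]]
    adj = [[] for _ in range(n)]
    degree = [0] * n
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
        degree[u] += 1
        degree[v] += 1
    leaves = deque(v for v in range(n) if degree[v] == 1)
    remaining = n
    layers = []
    while remaining > 2:
        batch_size = len(leaves)  # snapshot: only these leaves belong to this round
        layer = []
        for _ in range(batch_size):
            node = leaves.popleft()
            layer.append(node)
            remaining -= 1
            for nei in adj[node]:
                degree[nei] -= 1
                if degree[nei] == 1:
                    leaves.append(nei)  # becomes a leaf now, removed next round
        layers.append(layer)
    layers.append(list(leaves))  # the surviving center(s)
    return layers


print(peel_layers(5, [[0, 1], [3, 1], [2, 3], [4, 1]]))  # [[0, 2, 4], [3, 1]]
```

For a five-node path 0-1-2-3-4, the rounds remove [0, 4] and then [1, 3], leaving [2].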