1489. Find Critical and Pseudo-Critical Edges in Minimum Spanning Tree - Explanation

Description

You are given a weighted undirected connected graph with n vertices numbered from 0 to n - 1, and an array edges where edges[i] = [a[i], b[i], weight[i]] represents a bidirectional and weighted edge between nodes a[i] and b[i]. A minimum spanning tree (MST) is a subset of the graph's edges that connects all vertices without cycles and with the minimum possible total edge weight.

Find all the critical and pseudo-critical edges in the given graph's minimum spanning tree (MST). An MST edge whose deletion from the graph would cause the MST weight to increase is called a critical edge. A pseudo-critical edge, on the other hand, is one that appears in some MSTs but not in all of them.

Note that you can return the indices of the edges in any order.

Example 1:

Input: n = 4, edges = [[0,3,2],[0,2,5],[1,2,4]]

Output: [[0,2,1],[]]

Example 2:

Input: n = 5, edges = [[0,3,2],[0,4,2],[1,3,2],[3,4,2],[2,3,1],[1,2,3],[0,1,1]]

Output: [[4,6],[0,1,2,3]]
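
To make the two definitions concrete, here is a brute-force sketch that enumerates every spanning tree, keeps the minimum-weight ones, and reads the answer off them: critical edges are in every MST, pseudo-critical edges are in some but not all. This is an exhaustive check, not one of the approaches below, and it is only practical for tiny graphs because it tries every (n - 1)-edge subset; brute_force_classify is a hypothetical name.

```python
from itertools import combinations


def brute_force_classify(n, edges):
    """Hypothetical helper: (critical, pseudo) indices via every spanning tree."""

    def is_spanning_tree(subset):
        # n - 1 edges with no cycle always connect all n vertices.
        parent = list(range(n))

        def find(x):
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x

        for i in subset:
            ra, rb = find(edges[i][0]), find(edges[i][1])
            if ra == rb:  # this edge would close a cycle
                return False
            parent[ra] = rb
        return True

    def weight(tree):
        return sum(edges[i][2] for i in tree)

    trees = [set(t) for t in combinations(range(len(edges)), n - 1) if is_spanning_tree(t)]
    best = min(weight(t) for t in trees)
    msts = [t for t in trees if weight(t) == best]
    in_all = set.intersection(*msts)  # edges used by every MST
    in_some = set.union(*msts)        # edges used by at least one MST
    return sorted(in_all), sorted(in_some - in_all)


edges = [[0, 3, 2], [0, 4, 2], [1, 3, 2], [3, 4, 2], [2, 3, 1], [1, 2, 3], [0, 1, 1]]
print(brute_force_classify(5, edges))  # ([4, 6], [0, 1, 2, 3])
```

On Example 2 the only MST weight is 6; every MST uses both weight-1 edges (4 and 6) plus two of the four weight-2 edges, which is why edges 0 to 3 are pseudo-critical.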

Constraints:

  • 2 <= n <= 100
  • 1 <= edges.length <= min(200, n * (n - 1) / 2)
  • edges[i].length == 3
  • 0 <= a[i] < b[i] < n
  • 1 <= weight[i] <= 1000
  • All pairs (a[i], b[i]) are distinct.



1. Kruskal's Algorithm - I

Intuition

An edge is critical if removing it increases the MST weight or disconnects the graph. An edge is pseudo-critical if it can appear in some MST but is not mandatory. We test each edge by building the MST without it (to check criticality) and by forcing it into the MST first (to check if it can be part of a valid MST without increasing weight).

Algorithm

  1. Append original indices to each edge, then sort edges by weight.
  2. Build the standard MST using Kruskal's algorithm to get mstWeight.
  3. For each edge:
    • Build an MST excluding this edge. If the graph becomes disconnected or the weight increases, the edge is critical.
    • Otherwise, build an MST that includes this edge first. If the total weight equals mstWeight, the edge is pseudo-critical.
  4. Return the lists of critical and pseudo-critical edge indices.

from typing import List


class UnionFind:
    def __init__(self, n):
        self.par = [i for i in range(n)]
        self.rank = [1] * n  # despite the name, stores the size of each root's component

    def find(self, v1):
        while v1 != self.par[v1]:
            self.par[v1] = self.par[self.par[v1]]
            v1 = self.par[v1]
        return v1

    def union(self, v1, v2):
        p1, p2 = self.find(v1), self.find(v2)
        if p1 == p2:
            return False
        if self.rank[p1] > self.rank[p2]:
            self.par[p2] = p1
            self.rank[p1] += self.rank[p2]
        else:
            self.par[p1] = p2
            self.rank[p2] += self.rank[p1]
        return True


class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        for i, e in enumerate(edges):
            e.append(i)  # [v1, v2, weight, original_index]

        edges.sort(key=lambda e: e[2])

        mst_weight = 0
        uf = UnionFind(n)
        for v1, v2, w, i in edges:
            if uf.union(v1, v2):
                mst_weight += w

        critical, pseudo = [], []
        for n1, n2, e_weight, i in edges:
            # Try without curr edge
            weight = 0
            uf = UnionFind(n)
            for v1, v2, w, j in edges:
                if i != j and uf.union(v1, v2):
                    weight += w
            # No component of size n (disconnected) or a heavier tree: critical.
            if max(uf.rank) != n or weight > mst_weight:
                critical.append(i)
                continue

            # Try with curr edge
            uf = UnionFind(n)
            uf.union(n1, n2)
            weight = e_weight
            for v1, v2, w, j in edges:
                if uf.union(v1, v2):
                    weight += w
            if weight == mst_weight:
                pseudo.append(i)
        return [critical, pseudo]
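
To see the exclusion and forcing checks on Example 2, here is a standalone sketch with its own minimal union-find. The name kruskal_weight and its skip/force parameters are hypothetical; both take original edge indices, and the function returns infinity when the result is disconnected.

```python
def kruskal_weight(n, sorted_edges, skip=None, force=None):
    """Hypothetical helper: MST weight with edge `skip` left out or edge `force` added first.

    sorted_edges holds [u, v, w, original_index] rows sorted by w.
    """
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return False
        parent[ra] = rb
        return True

    total, used = 0, 0
    for u, v, w, i in sorted_edges:
        if i == force:  # union the forced edge before anything else
            union(u, v)
            total, used = w, 1
    for u, v, w, i in sorted_edges:
        if i != skip and union(u, v):  # the forced edge is rejected here: already joined
            total += w
            used += 1
    return total if used == n - 1 else float("inf")


edges = [[0, 3, 2], [0, 4, 2], [1, 3, 2], [3, 4, 2], [2, 3, 1], [1, 2, 3], [0, 1, 1]]
indexed = sorted((e + [i] for i, e in enumerate(edges)), key=lambda e: e[2])
print(kruskal_weight(5, indexed))           # 6: baseline mstWeight
print(kruskal_weight(5, indexed, skip=6))   # 7: larger than 6, so edge 6 is critical
print(kruskal_weight(5, indexed, skip=3))   # 6: edge 3 is not critical
print(kruskal_weight(5, indexed, force=3))  # 6: equal to 6, so edge 3 is pseudo-critical
print(kruskal_weight(5, indexed, force=5))  # 7: not 6, so edge 5 is in no MST
```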

Time & Space Complexity

  • Time complexity: O(E^2)
  • Space complexity: O(V + E)

Where V is the number of vertices and E is the number of edges.
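
A rough count behind this bound, assuming union-find operations cost amortized O(α(V)) with union by size and path halving: one sort, then at most one baseline pass plus two passes (exclude, force) per edge, each pass also allocating a fresh O(V) union-find:

$$
\underbrace{O(E \log E)}_{\text{one sort}} \;+\; \underbrace{(2E + 1)}_{\text{Kruskal passes}} \cdot \underbrace{O\big(V + E\,\alpha(V)\big)}_{\text{one pass}} \;=\; O\big(E^2\,\alpha(V)\big) \approx O(E^2)
$$

The simplification uses V - 1 ≤ E for a connected graph. Under the constraints (E ≤ 200) that is at most 401 passes over 200 edges, roughly 8 × 10^4 union calls.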


2. Kruskal's Algorithm - II

Intuition

This is a cleaner implementation of the same approach. We use a helper function findMST that optionally skips one edge or forces one edge to be included first. By comparing results against the baseline MST weight, we classify each edge as critical or pseudo-critical.

Algorithm

  1. Append indices to edges and sort by weight.
  2. Compute mstWeight by calling findMST(-1, false) (no exclusions or inclusions).
  3. For each edge at index i:
    • If findMST(i, false) returns a weight greater than mstWeight, the edge is critical.
    • Else if findMST(i, true) equals mstWeight, the edge is pseudo-critical.
  4. Return the two lists.

from typing import List


class UnionFind:
    def __init__(self, n):
        self.n = n  # number of connected components remaining
        self.Parent = list(range(n + 1))
        self.Size = [1] * (n + 1)

    def find(self, node):
        if self.Parent[node] != node:
            self.Parent[node] = self.find(self.Parent[node])
        return self.Parent[node]

    def union(self, u, v):
        pu = self.find(u)
        pv = self.find(v)
        if pu == pv:
            return False
        self.n -= 1
        if self.Size[pu] < self.Size[pv]:
            pu, pv = pv, pu
        self.Size[pu] += self.Size[pv]
        self.Parent[pv] = pu
        return True

    def isConnected(self):
        return self.n == 1

class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        for i, e in enumerate(edges):
            e.append(i)
        edges.sort(key = lambda e : e[2])

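        def findMST(index, include):
            # index is a position in the sorted list (-1 means no edge).
            # include=True forces edges[index] in first; False leaves it out.
            uf = UnionFind(n)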
            wgt = 0
            if include:
                wgt += edges[index][2]
                uf.union(edges[index][0], edges[index][1])

            for i, e in enumerate(edges):
                if i == index:
                    continue
                if uf.union(e[0], e[1]):
                    wgt += e[2]
            return wgt if uf.isConnected() else float("inf")

        mst_wgt = findMST(-1, False)
        critical, pseudo = [], []
        for i, e in enumerate(edges):
            if mst_wgt < findMST(i, False):
                critical.append(e[3])
            elif mst_wgt == findMST(i, True):
                pseudo.append(e[3])

        return [critical, pseudo]

Time & Space Complexity

  • Time complexity: O(E^2)
  • Space complexity: O(V + E)

Where V is the number of vertices and E is the number of edges.


3. Dijkstra's Algorithm

Intuition

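For each edge connecting nodes u and v with weight w, we ask: is there an alternate path from u to v using only edges with weight at most w? We use a modified Dijkstra that finds the minimax path (the path minimizing the maximum edge weight along it) while ignoring the edge itself. If w is strictly less than this minimax value, every alternate path contains a heavier edge, so the edge is critical. If w equals it, the edge can be swapped with an equal-weight edge on that path, so it is pseudo-critical. If w is larger, a path of strictly lighter edges exists and the edge belongs to no MST.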
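
The same three-way decision can be phrased as two plain reachability questions, which is a handy way to check the reasoning. The sketch below swaps Dijkstra for a BFS restricted by a weight threshold: first "can u reach v through strictly lighter edges?", then "through edges of weight at most w, other than this one?". classify_edge is a hypothetical name, and it rebuilds the adjacency list on every call for simplicity.

```python
from collections import defaultdict, deque


def classify_edge(edges, idx):
    """Hypothetical helper: 'critical', 'pseudo', or 'none' (in no MST) for edges[idx]."""
    u, v, w = edges[idx]
    adj = defaultdict(list)  # rebuilt per call for simplicity
    for j, (a, b, wt) in enumerate(edges):
        adj[a].append((b, wt, j))
        adj[b].append((a, wt, j))

    def reachable(allowed):
        # BFS from u to v over edges other than idx whose weight passes allowed().
        seen, queue = {u}, deque([u])
        while queue:
            x = queue.popleft()
            if x == v:
                return True
            for y, wt, j in adj[x]:
                if j != idx and allowed(wt) and y not in seen:
                    seen.add(y)
                    queue.append(y)
        return False

    if reachable(lambda wt: wt < w):
        return "none"      # a path of strictly lighter edges exists
    if reachable(lambda wt: wt <= w):
        return "pseudo"    # an alternative path whose heaviest edge weighs exactly w
    return "critical"      # every other u-v path (if any) uses a heavier edge


edges = [[0, 3, 2], [0, 4, 2], [1, 3, 2], [3, 4, 2], [2, 3, 1], [1, 2, 3], [0, 1, 1]]
print([classify_edge(edges, i) for i in range(len(edges))])
# ['pseudo', 'pseudo', 'pseudo', 'pseudo', 'critical', 'none', 'critical']
```

The first query corresponds to "minimax without the edge < w" and the second to "minimax without the edge ≤ w"; the Dijkstra version answers both with a single number.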

Algorithm

  1. Build an adjacency list with edge indices.
  2. For each edge (u, v, w, idx):
    • Compute minimax(u, v, idx): the minimum possible maximum edge weight on any path from u to v, excluding edge idx.
    • If w < minimax(u, v, idx), the edge is critical.
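    • Else compute minimax(u, v, -1) (no exclusion). If w == minimax(u, v, -1), the edge is pseudo-critical. Because the edge itself is now allowed, this value is min(w, minimax(u, v, idx)), which equals minimax(u, v, idx) in this branch, so the test is the same as w == minimax(u, v, idx).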
  3. Return both lists.

from collections import defaultdict
from heapq import heappop, heappush
from typing import List


class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        for i, edge in enumerate(edges):
            edge.append(i)

        adj = defaultdict(list)
        for u, v, w, idx in edges:
            adj[u].append((v, w, idx))
            adj[v].append((u, w, idx))

        def minimax(src, dst, exclude_idx):
            dist = [float('inf')] * n
            dist[src] = 0
            pq = [(0, src)]

            while pq:
                max_w, u = heappop(pq)
                if max_w > dist[u]:  # stale entry: u was already reached more cheaply
                    continue
                if u == dst:
                    return max_w

                for v, weight, edge_idx in adj[u]:
                    if edge_idx == exclude_idx:
                        continue
                    new_w = max(max_w, weight)
                    if new_w < dist[v]:
                        dist[v] = new_w
                        heappush(pq, (new_w, v))

            return float('inf')

        critical, pseudo = [], []
        for u, v, w, idx in edges:
            if w < minimax(u, v, idx):
                critical.append(idx)
            elif w == minimax(u, v, -1):
                pseudo.append(idx)

        return [critical, pseudo]

Time & Space Complexity

  • Time complexity: O(E^2 log V)
  • Space complexity: O(V + E)

Where V is the number of vertices and E is the number of edges.
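
Where the log V comes from, assuming each minimax call costs as much as one standard binary-heap Dijkstra run, with at most two calls per edge:

$$
\underbrace{2E}_{\text{minimax calls}} \cdot \underbrace{O\big((V + E)\log V\big)}_{\text{one Dijkstra run}} = O\big(E^2 \log V\big)
$$

The simplification uses V - 1 ≤ E for a connected graph; a heap holding up to O(E) entries still costs O(log V) per operation because log E ≤ log V^2 = 2 log V.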


4. Kruskal's Algorithm + DFS

Intuition

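After building one MST, adding any non-MST edge to it closes exactly one cycle. We use DFS to find the path in the MST between the endpoints of each non-MST edge. By the cycle property, no edge on that path is heavier than the non-MST edge, so if some path edge has the same weight, the two can be swapped to produce another MST of equal weight, and both edges are pseudo-critical. MST edges that are never swappable in this way are critical, and non-MST edges with no equal-weight partner belong to no MST.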

Algorithm

  1. Build an MST using Kruskal's algorithm, recording which edges are included and building an adjacency list for the MST.
  2. For each edge not in the MST:
    • Use DFS to find the unique path in the MST between its endpoints.
    • For each edge on this path with equal weight, mark both edges as pseudo-critical.
  3. Critical edges are MST edges that are not pseudo-critical.
  4. Return both lists.

from typing import List


class UnionFind:
    def __init__(self, n):
        self.Parent = list(range(n + 1))
        self.Size = [1] * (n + 1)

    def find(self, node):
        if self.Parent[node] != node:
            self.Parent[node] = self.find(self.Parent[node])
        return self.Parent[node]

    def union(self, u, v):
        pu = self.find(u)
        pv = self.find(v)
        if pu == pv:
            return False
        if self.Size[pu] < self.Size[pv]:
            pu, pv = pv, pu
        self.Size[pu] += self.Size[pv]
        self.Parent[pv] = pu
        return True

class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        mst = [[] for _ in range(n)]
        mstEdge = []

        edge_list = [(w, u, v, i) for i, (u, v, w) in enumerate(edges)]
        edge_list.sort()

        uf = UnionFind(n)
        for w, u, v, i in edge_list:
            if uf.union(u, v):
                mst[u].append((v, i))
                mst[v].append((u, i))
                mstEdge.append(i)

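        def dfs(node):
            # Walk the MST from node toward dst, recording edge indices in path.
            # Skipping the edge just taken (path[-1]) prevents walking back.
            for nei, ind in mst[node]:
                if path and ind == path[-1]:
                    continue
                path.append(ind)
                if nei == dst or dfs(nei):
                    return True
                path.pop()
            return False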

        pseudo, mstEdge = set(), set(mstEdge)
        for ind in range(len(edges)):
            if ind in mstEdge:
                continue
            path, dst = [], edges[ind][1]
            dfs(edges[ind][0])
            for i in path:
                if edges[i][2] == edges[ind][2]:
                    pseudo.add(i)
                    pseudo.add(ind)

        return [list(mstEdge - pseudo), list(pseudo)]
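
Step 2 can also be written without recursion. The sketch below is an iterative DFS with an explicit stack: it records, for each node, the tree edge used to reach it, then walks back from dst to recover the path. tree_path is a hypothetical name, and adj uses the same (neighbor, edge_index) pairs as the mst lists in the code above.

```python
def tree_path(adj, src, dst):
    """Hypothetical helper: edge indices on the unique src -> dst path in a tree."""
    parent_edge = {src: None}  # node -> (previous node, edge index used to reach it)
    stack = [src]
    while stack:
        node = stack.pop()
        if node == dst:
            break
        for nxt, ind in adj[node]:
            if nxt not in parent_edge:  # in a tree, first discovery comes from the parent
                parent_edge[nxt] = (node, ind)
                stack.append(nxt)
    path, node = [], dst
    while parent_edge[node] is not None:  # walk back from dst to src
        prev, ind = parent_edge[node]
        path.append(ind)
        node = prev
    return path[::-1]


edges = [[0, 3, 2], [0, 4, 2], [1, 3, 2], [3, 4, 2], [2, 3, 1], [1, 2, 3], [0, 1, 1]]
adj = [[] for _ in range(5)]
for i in (6, 4, 0, 1):  # MST edge indices Kruskal picks on Example 2
    u, v, _ = edges[i]
    adj[u].append((v, i))
    adj[v].append((u, i))
print(tree_path(adj, 1, 3))  # [6, 0]
```

Edge 0 on that path weighs 2, the same as non-MST edge 2 ([1, 3, 2]), so both are marked pseudo-critical.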

Time & Space Complexity

  • Time complexity: O(E^2)
  • Space complexity: O(V + E)

Where V is the number of vertices and E is the number of edges.
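
A slightly tighter count for this approach, assuming each tree-path DFS touches each of the V - 1 MST edges at most once:

$$
\underbrace{O(E \log E)}_{\text{sort}} + \underbrace{O\big(E\,\alpha(V)\big)}_{\text{Kruskal}} + \underbrace{E \cdot O(V)}_{\text{one tree-path DFS per non-MST edge}} = O(E \cdot V) \subseteq O(E^2)
$$

The last step uses V - 1 ≤ E, which holds because the graph is connected, so the stated O(E^2) remains a valid upper bound.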