1489. Find Critical and Pseudo Critical Edges in Minimum Spanning Tree - Explanation

Problem Link

Description

You are given a weighted undirected connected graph with n vertices numbered from 0 to n - 1, and an array edges where edges[i] = [a[i], b[i], weight[i]] represents a bidirectional and weighted edge between nodes a[i] and b[i]. A minimum spanning tree (MST) is a subset of the graph's edges that connects all vertices without cycles and with the minimum possible total edge weight.

Find all the critical and pseudo-critical edges in the given graph's minimum spanning tree (MST). An MST edge whose deletion from the graph would cause the MST weight to increase (or would leave the graph disconnected) is called a critical edge. A pseudo-critical edge, on the other hand, is one that appears in some MSTs but not in all of them.

Note that you can return the indices of the edges in any order.

Example 1:

Input: n = 4, edges = [[0,3,2],[0,2,5],[1,2,4]]

Output: [[0,2,1],[]]

Example 2:

Input: n = 5, edges = [[0,3,2],[0,4,2],[1,3,2],[3,4,2],[2,3,1],[1,2,3],[0,1,1]]

Output: [[4,6],[0,1,2,3]]

Constraints:

  • 2 <= n <= 100
  • 1 <= edges.length <= min(200, n * (n - 1) / 2)
  • edges[i].length == 3
  • 0 <= a[i] < b[i] < n
  • 1 <= weight[i] <= 1000
  • All pairs (a[i], b[i]) are distinct.

Company Tags

Please upgrade to NeetCode Pro to view company tags.



1. Kruskal's Algorithm - I

class UnionFind:
    """Disjoint-set union with path halving and union by size.

    Despite its name, ``rank[r]`` holds the number of vertices in the
    component rooted at ``r`` (a size, not a rank). Entries for vertices
    that are no longer roots are left stale.
    """

    def __init__(self, n):
        self.par = list(range(n))
        self.rank = [1] * n

    def find(self, v1):
        """Return the root of ``v1``'s component, halving the path as it goes."""
        while v1 != self.par[v1]:
            # Point v1 at its grandparent to shorten future lookups.
            self.par[v1] = self.par[self.par[v1]]
            v1 = self.par[v1]
        return v1

    def union(self, v1, v2):
        """Merge the components containing ``v1`` and ``v2``.

        Returns:
            ``False`` if both were already in the same component,
            ``True`` if two components were merged.
        """
        p1, p2 = self.find(v1), self.find(v2)
        if p1 == p2:
            return False
        # Attach the smaller component under the larger one (ties go under p2).
        if self.rank[p1] > self.rank[p2]:
            self.par[p2] = p1
            self.rank[p1] += self.rank[p2]
        else:
            self.par[p1] = p2
            self.rank[p2] += self.rank[p1]
        return True


class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        """Classify edges as critical or pseudo-critical for the graph's MSTs.

        Runs Kruskal's algorithm once to get the MST weight. Then, for every
        edge, it reruns Kruskal (a) without that edge and, if the edge is not
        critical, (b) with that edge forced in first.

        Args:
            n: Number of vertices, labelled ``0 .. n - 1``.
            edges: ``[a, b, weight]`` triples. The list is not modified.

        Returns:
            ``[critical, pseudo]``: original indices of the critical and of
            the pseudo-critical edges.
        """
        # Work on a private, weight-sorted copy tagged with each edge's
        # original index so the caller's list (and its inner lists) stay
        # untouched. sorted() is stable, matching the in-place list.sort().
        ordered = sorted(
            ((a, b, w, i) for i, (a, b, w) in enumerate(edges)),
            key=lambda e: e[2],
        )

        mst_weight = 0
        uf = UnionFind(n)
        for v1, v2, w, _ in ordered:
            if uf.union(v1, v2):
                mst_weight += w

        critical, pseudo = [], []
        for n1, n2, e_weight, i in ordered:
            # Try without the current edge.
            weight = 0
            uf = UnionFind(n)
            for v1, v2, w, j in ordered:
                if i != j and uf.union(v1, v2):
                    weight += w
            # rank holds component sizes, so the rest of the edges span the
            # graph only if some root's component has all n vertices.
            if max(uf.rank) != n or weight > mst_weight:
                critical.append(i)
                continue

            # Try with the current edge forced in first: an equal total means
            # some MST uses this edge (and, being non-critical, not all do).
            uf = UnionFind(n)
            uf.union(n1, n2)
            weight = e_weight
            for v1, v2, w, _ in ordered:
                if uf.union(v1, v2):
                    weight += w
            if weight == mst_weight:
                pseudo.append(i)
        return [critical, pseudo]

Time & Space Complexity

  • Time complexity: $O(E^2)$
  • Space complexity: $O(V + E)$

Where $V$ is the number of vertices and $E$ is the number of edges.


2. Kruskal's Algorithm - II

class UnionFind:
    """Disjoint-set union with path compression and union by size that also
    counts how many components remain."""

    def __init__(self, n):
        self.n = n  # number of components still separate
        # Sized n + 1; vertices are presumably 0 .. n - 1 (per the problem
        # constraints), so the final slot goes unused and is not counted in n.
        self.Parent = list(range(n + 1))
        self.Size = [1] * (n + 1)

    def find(self, node):
        """Return the root of ``node``, compressing the path recursively."""
        if self.Parent[node] != node:
            self.Parent[node] = self.find(self.Parent[node])
        return self.Parent[node]

    def union(self, u, v):
        """Merge the components of ``u`` and ``v``.

        Returns ``False`` (changing nothing) if they already share a root;
        otherwise attaches the smaller component under the larger one,
        decrements the component count, and returns ``True``.
        """
        pu = self.find(u)
        pv = self.find(v)
        if pu == pv:
            return False
        self.n -= 1
        if self.Size[pu] < self.Size[pv]:
            pu, pv = pv, pu
        self.Size[pu] += self.Size[pv]
        self.Parent[pv] = pu
        return True

    def isConnected(self):
        """Return ``True`` once exactly one component remains."""
        return self.n == 1

class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        """Classify edges as critical or pseudo-critical for the graph's MSTs.

        Computes the MST weight once. Then, for every edge, reruns Kruskal's
        algorithm with that edge excluded and, if needed, with it forced in
        first.

        Args:
            n: Number of vertices, labelled ``0 .. n - 1``.
            edges: ``[a, b, weight]`` triples. The list is not modified.

        Returns:
            ``[critical, pseudo]``: original indices of the critical and of
            the pseudo-critical edges.
        """
        # Private weight-sorted copy tagged with each edge's original index
        # (element 3); the caller's list and its inner lists stay untouched.
        ordered = sorted(
            ([a, b, w, i] for i, (a, b, w) in enumerate(edges)),
            key=lambda e: e[2],
        )

        def findMST(index, include):
            """Kruskal's MST weight constrained around edge ``ordered[index]``.

            With ``include=False`` the edge is skipped; with ``include=True``
            it is added before all others. ``index=-1`` with ``include=False``
            skips nothing and yields the unconstrained MST weight. Returns
            ``inf`` when the chosen edges fail to connect the graph.
            """
            uf = UnionFind(n)
            wgt = 0
            if include:
                wgt += ordered[index][2]
                uf.union(ordered[index][0], ordered[index][1])

            for i, e in enumerate(ordered):
                if i == index:
                    continue
                if uf.union(e[0], e[1]):
                    wgt += e[2]
            return wgt if uf.isConnected() else float("inf")

        mst_wgt = findMST(-1, False)
        critical, pseudo = [], []
        for i, e in enumerate(ordered):
            if mst_wgt < findMST(i, False):
                # Dropping the edge raises the cost or disconnects the graph.
                critical.append(e[3])
            elif mst_wgt == findMST(i, True):
                # Some optimal tree can use the edge, but not every one must.
                pseudo.append(e[3])

        return [critical, pseudo]

Time & Space Complexity

  • Time complexity: $O(E^2)$
  • Space complexity: $O(V + E)$

Where $V$ is the number of vertices and $E$ is the number of edges.


3. Dijkstra's Algorithm

class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        """Classify edges as critical or pseudo-critical via bottleneck paths.

        For an edge ``(u, v, w)``, let ``best`` be the smallest possible
        maximum edge weight over all u-v paths that avoid this edge. The edge
        is critical when ``w < best``: every alternative route needs a heavier
        edge, or none exists. It is pseudo-critical when ``w == best``: it can
        be swapped for an equally heavy edge. Otherwise it is in no MST.

        Args:
            n: Number of vertices, labelled ``0 .. n - 1``.
            edges: ``[a, b, weight]`` triples. The list is not modified.

        Returns:
            ``[critical, pseudo]``: original indices, in input order.
        """
        adj = defaultdict(list)
        for idx, (u, v, w) in enumerate(edges):
            adj[u].append((v, w, idx))
            adj[v].append((u, w, idx))

        def minimax(src, dst, exclude_idx):
            """Smallest achievable max edge weight on a src->dst path that
            avoids edge ``exclude_idx``; ``inf`` if no such path exists."""
            dist = [float('inf')] * n
            dist[src] = 0
            pq = [(0, src)]

            while pq:
                max_w, u = heappop(pq)
                if u == dst:
                    return max_w
                if max_w > dist[u]:
                    continue  # stale entry superseded by a better one

                for v, weight, edge_idx in adj[u]:
                    if edge_idx == exclude_idx:
                        continue
                    new_w = max(max_w, weight)
                    if new_w < dist[v]:
                        dist[v] = new_w
                        heappush(pq, (new_w, v))

            return float('inf')

        critical, pseudo = [], []
        for idx, (u, v, w) in enumerate(edges):
            # Allowing the edge itself would give min(w, best), so this one
            # search decides both the critical and the pseudo-critical case.
            best = minimax(u, v, idx)
            if w < best:
                critical.append(idx)
            elif w == best:
                pseudo.append(idx)

        return [critical, pseudo]

Time & Space Complexity

  • Time complexity: $O(E^2 \log V)$
  • Space complexity: $O(V + E)$

Where $V$ is the number of vertices and $E$ is the number of edges.


4. Kruskal's Algorithm + DFS

class UnionFind:
    """Disjoint-set union with full path compression and union by size."""

    def __init__(self, n):
        # One slot per vertex plus a spare at index n.
        self.Parent = list(range(n + 1))
        self.Size = [1] * (n + 1)

    def find(self, node):
        """Return the root of ``node`` and re-point every visited node at it."""
        root = node
        while self.Parent[root] != root:
            root = self.Parent[root]
        while node != root:
            up = self.Parent[node]
            self.Parent[node] = root
            node = up
        return root

    def union(self, u, v):
        """Merge the sets holding ``u`` and ``v``; ``False`` if already joined."""
        ru, rv = self.find(u), self.find(v)
        if ru == rv:
            return False
        big, small = (rv, ru) if self.Size[ru] < self.Size[rv] else (ru, rv)
        self.Size[big] += self.Size[small]
        self.Parent[small] = big
        return True

class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        """Classify edges using one Kruskal pass plus tree-path checks.

        A non-tree edge may replace any tree edge of equal weight on the tree
        path between its endpoints; both then belong to some, but not every,
        MST (pseudo-critical). Tree edges that no non-tree edge can replace
        this way are critical.
        """
        tree = [[] for _ in range(n)]  # chosen MST: vertex -> [(neighbour, edge index)]
        tree_ids = []

        dsu = UnionFind(n)
        for w, a, b, idx in sorted((w, a, b, idx) for idx, (a, b, w) in enumerate(edges)):
            if dsu.union(a, b):
                tree[a].append((b, idx))
                tree[b].append((a, idx))
                tree_ids.append(idx)

        def tree_path(src, dst):
            """Edge indices along the tree path from src to dst, in order."""
            came_from = {src: None}  # vertex -> (previous vertex, edge index)
            stack = [src]
            while stack:
                node = stack.pop()
                if node == dst:
                    break
                for nei, idx in tree[node]:
                    if nei not in came_from:
                        came_from[nei] = (node, idx)
                        stack.append(nei)
            if dst not in came_from:
                return []
            path = []
            node = dst
            while came_from[node] is not None:
                node, idx = came_from[node]
                path.append(idx)
            path.reverse()
            return path

        tree_set = set(tree_ids)
        pseudo = set()
        for ind, (a, b, w) in enumerate(edges):
            if ind in tree_set:
                continue
            for idx in tree_path(a, b):
                if edges[idx][2] == w:
                    pseudo.add(idx)
                    pseudo.add(ind)

        return [list(tree_set - pseudo), list(pseudo)]

Time & Space Complexity

  • Time complexity: $O(E^2)$
  • Space complexity: $O(V + E)$

Where $V$ is the number of vertices and $E$ is the number of edges.