LeetCode/NeetCode

[최소신장트리 MST: Prim, Kruskal] 1489. Find Critical and Pseudo-Critical Edges in Minimum Spanning Tree ★★★★★

hyunkookim 2025. 4. 11. 08:01

1489. Find Critical and Pseudo-Critical Edges in Minimum Spanning Tree

 

Given a weighted undirected connected graph with n vertices numbered from 0 to n - 1,

👉 가중치가 있는 무방향 연결 그래프가 주어지며, 이 그래프는 0부터 n - 1까지 번호가 매겨진 정점들을 포함합니다.

 

and an array edges where edges[i] = [ai, bi, weighti] represents a bidirectional and weighted edge between nodes ai and bi.

👉 edges[i] = [ai, bi, weighti] 형식의 배열이 주어지고, 이는 정점 ai와 bi 사이의 양방향 가중치 간선을 의미합니다.

 

A minimum spanning tree (MST) is a subset of the graph's edges that connects all vertices without cycles and with the minimum possible total edge weight.

👉 최소 신장 트리(MST)는 사이클 없이 모든 정점을 연결하면서 총 간선 가중치의 합이 최소가 되는 간선들의 부분 집합입니다.

 

Find all the critical and pseudo-critical edges in the given graph's minimum spanning tree (MST).

👉 이 그래프의 MST에서 중요 간선(critical edge)준중요 간선(pseudo-critical edge) 을 모두 찾아야 합니다.

 

An MST edge whose deletion from the graph would cause the MST weight to increase is called a critical edge.

👉 MST에서 해당 간선을 제거하면 전체 MST의 비용이 증가하는 간선중요 간선(critical edge) 이라고 부릅니다.

 

On the other hand, a pseudo-critical edge is that which can appear in some MSTs but not all.

👉 반면, 어떤 MST에는 포함되지만, 다른 MST에는 포함되지 않을 수도 있는 간선준중요 간선(pseudo-critical edge) 라고 합니다.

 

Note that you can return the indices of the edges in any order.

👉 결과로 반환할 간선들의 인덱스는 어떤 순서든 상관 없습니다.

 

💡 다시 문제 정리

주어진 간선 중 어떤 간선이:

  1. 🔴 Critical Edge: 빼면 MST 비용이 증가함 (필수!)
  2. 🟡 Pseudo-Critical Edge: 넣어도 MST 비용이 동일하게 유지됨 (필수는 아니지만 가능)

'Prim' 알고리즘 기반 풀이 전략

우선 기본적으로 해야 할 것:

1️⃣ 기본 MST 비용 계산

  • 아무 간선도 강제로 포함하거나 제외하지 않고, Prim 알고리즘으로 MST를 만들고 그 최소 비용을 기록해 둡니다.
    • 예: original_cost = prim(points, edges)

2️⃣ 각 간선을 제거해보기 (Critical 여부 판단)

  • 특정 간선을 제외한 상태로 Prim을 돌림
    • 만약 이때 MST를 만들 수 없거나 비용이 증가하면 → Critical

3️⃣ 각 간선을 강제로 포함해보기 (Pseudo-Critical 여부 판단)

  • 특정 간선을 무조건 포함한 상태에서 Prim을 시작
    • 이 간선을 먼저 연결한 후, 나머지 MST를 구성
    • 이때 전체 비용이 original_cost와 같으면 → Pseudo-Critical
import heapq

class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        # 각 간선에 index 붙이기 (뒤에서 결과에 쓸 수 있도록)
        for i in range(len(edges)):
            edges[i].append(i)  # edges[i] = [u, v, weight, index]

        # 가중치 기준으로 정렬 (Prim에서도 최소 힙이긴 하지만 일관성 위해)
        edges.sort(key=lambda x: x[2])  # x[2] = weight

        # -------------------------------
        # MST를 구성하는 prim 알고리즘 정의
        # exclude_idx: 제외할 간선 인덱스
        # include_edge: 먼저 포함할 간선 (강제로 MST에 넣기)
        # -------------------------------
        def prim(n, edges, exclude_idx=None, include_edge=None):
            adj = {i: [] for i in range(n)}  # 인접 리스트 초기화

            # 인접 리스트 구성 (exclude_idx 제외)
            for u, v, w, idx in edges:
                if idx == exclude_idx:
                    continue  # 특정 간선은 제외
                adj[u].append((w, v, idx))
                adj[v].append((w, u, idx))

            total_cost = 0
            visited = set()
            min_heap = []

            # include_edge가 있으면 먼저 포함하고 시작
            if include_edge:
                u, v, w, idx = include_edge
                total_cost += w
                visited.add(u)
                visited.add(v)
                for next_w, next_v, next_idx in adj[u]:
                    if next_v not in visited:
                        heapq.heappush(min_heap, (next_w, v, next_v, next_idx))
                for next_w, next_u, next_idx in adj[v]:
                    if next_u not in visited:
                        heapq.heappush(min_heap, (next_w, u, next_u, next_idx))
            else:
                # 아무것도 포함하지 않았다면 0번 노드부터 시작
                visited.add(0)
                for w, v, idx in adj[0]:
                    heapq.heappush(min_heap, (w, 0, v, idx))

            # Prim 알고리즘 시작
            while len(visited) < n and min_heap:
                w, u, v, idx = heapq.heappop(min_heap)
                if v in visited:
                    continue
                visited.add(v)
                total_cost += w
                for next_w, next_v, next_idx in adj[v]:
                    if next_v not in visited:
                        heapq.heappush(min_heap, (next_w, v, next_v, next_idx))

            # 모든 노드를 방문하지 못했다면 유효하지 않은 MST (그래프가 끊어진 경우)
            if len(visited) < n:
                return float('inf')

            return total_cost

        # ---------------------------------------
        # 1. 기본 MST 비용 계산 (원본 전체 그래프 기준)
        # ---------------------------------------
        original_cost = prim(n, edges)

        critical = []
        pseudo_critical = []

        # ---------------------------------------
        # 2. 각 간선에 대해 critical/pseudo 판단
        # ---------------------------------------
        for edge in edges:
            idx = edge[3]

            # (1) 이 간선을 제외한 MST 구성 → 비용 증가 or 연결 불가 → critical
            cost_without = prim(n, edges, exclude_idx=idx)
            if cost_without > original_cost:
                critical.append(idx)
                continue  # 이미 critical이면 pseudo 검사할 필요 없음

            # (2) 이 간선을 강제로 포함한 MST 구성 → 비용 동일 → pseudo
            cost_with = prim(n, edges, include_edge=edge)
            if cost_with == original_cost:
                pseudo_critical.append(idx)

        return [critical, pseudo_critical]

 

 'Kruskal' 알고리즘 기반 풀이 전략

🔍 전체 흐름 요약

  1. 간선에 인덱스를 붙이고, 가중치 기준으로 정렬
  2. build_mst() 함수로:
    • 특정 간선을 제외하거나
    • 특정 간선을 강제로 포함한 후 MST 구성
  3. 기본 MST 비용과 비교하여:
    • 비용 증가 → Critical
    • 비용 동일 → Pseudo-Critical
class UnionFind:
    def __init__(self, n):
        # 각 노드는 처음에 자기 자신이 부모 (자기 자신이 루트)
        self.parent = [i for i in range(n)]
        self.size = [1] * n  # 각 집합의 크기
        self.components = n   # 현재 연결된 컴포넌트(집합)의 개수

    def find(self, x):
        # x의 루트를 찾아감 (경로 압축)
        if x != self.parent[x]:
            self.parent[x] = self.find(self.parent[x])  # 루트를 찾아가면서 바로 부모 갱신
        return self.parent[x]

    def union(self, x, y):
        # x와 y를 같은 집합으로 합침
        root_x = self.find(x)
        root_y = self.find(y)

        if root_x != root_y:
            # size가 큰 쪽으로 작은 쪽을 붙임 (union by size)
            if self.size[root_x] < self.size[root_y]:
                self.parent[root_x] = root_y
                self.size[root_y] += self.size[root_x]
            else:
                self.parent[root_y] = root_x
                self.size[root_x] += self.size[root_y]
            self.components -= 1  # 컴포넌트 하나 줄어듦
            return True  # 성공적으로 합쳐짐
        return False  # 이미 같은 집합이었음 → 합치지 않음

class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        # 1. 간선에 인덱스를 붙임: edges[i] = [u, v, weight] → [weight, u, v, index]
        indexed_edges = []
        for i, (u, v, w) in enumerate(edges):
            indexed_edges.append([w, u, v, i])

        # 2. 가중치 기준으로 정렬 (Kruskal 알고리즘을 위한 준비)
        indexed_edges.sort()

        # 3. MST를 구성하고 비용을 반환하는 함수 정의
        def build_mst(n, edges, exclude_idx=None, include_edge=None):
            uf = UnionFind(n)
            total_cost = 0

            # 강제로 포함할 간선이 있다면 먼저 union 처리
            if include_edge:
                w, u, v, idx = include_edge
                if uf.union(u, v):
                    total_cost += w

            # 정렬된 간선들을 하나씩 확인
            for w, u, v, idx in edges:
                if idx == exclude_idx:
                    continue  # 제외 대상이면 넘어감
                if uf.union(u, v):
                    total_cost += w  # 연결되면 비용 추가

            # 모든 노드가 하나로 연결된 경우만 MST로 인정
            return total_cost if uf.components == 1 else float('inf')

        # 4. 전체 그래프의 MST 비용을 기준으로 먼저 구함
        base_cost = build_mst(n, indexed_edges)

        critical = []          # 반드시 MST에 포함되어야 하는 간선들
        pseudo_critical = []   # 어떤 MST에는 포함될 수 있는 간선들

        # 5. 각 간선을 하나씩 검사
        for edge in indexed_edges:
            w, u, v, idx = edge

            # (1) 현재 간선을 제외하고 MST를 구성해보기
            cost_without = build_mst(n, indexed_edges, exclude_idx=idx)

            # MST 비용이 증가하거나 연결이 불가능하면 → critical
            if cost_without > base_cost:
                critical.append(idx)
                continue  # critical이면 pseudo 검사는 생략

            # (2) 현재 간선을 강제로 포함하고 MST 구성
            cost_with = build_mst(n, indexed_edges, include_edge=edge)

            # 비용이 동일하면 → pseudo-critical
            if cost_with == base_cost:
                pseudo_critical.append(idx)

        # 6. 결과 반환 (인덱스 순서는 문제에서 상관없다고 명시됨)
        return [critical, pseudo_critical]

 

https://youtu.be/83JnUxrLKJU

 

class UnionFind:
    def __init__(self, n):
        self.n = n  # 현재 연결된 컴포넌트 수
        self.Parent = list(range(n + 1))  # 노드의 부모 초기화 (1-index 고려, n+1까지)
        self.Size = [1] * (n + 1)         # 각 집합의 크기

    def find(self, node):
        # 경로 압축: 조상을 찾아서 parent를 직접 연결
        if self.Parent[node] != node:
            self.Parent[node] = self.find(self.Parent[node])
        return self.Parent[node]

    def union(self, u, v):
        # u와 v의 조상 찾기
        pu = self.find(u)
        pv = self.find(v)

        if pu == pv:
            return False  # 이미 같은 집합 → 사이클이므로 union 안 함

        # 항상 큰 집합 쪽으로 병합 (Size 기준)
        if self.Size[pu] < self.Size[pv]:
            pu, pv = pv, pu
        self.Size[pu] += self.Size[pv]
        self.Parent[pv] = pu
        self.n -= 1  # 컴포넌트 수 하나 줄이기
        return True  # union 성공

    def isConnected(self):
        # 컴포넌트 수가 1이면 모든 노드가 연결됨 (즉, MST 완성)
        return self.n == 1

class Solution:
    def findCriticalAndPseudoCriticalEdges(self, n: int, edges: List[List[int]]) -> List[List[int]]:
        # 각 간선에 원래의 인덱스를 붙여줌: [u, v, weight, original_index]
        for i, e in enumerate(edges):
            e.append(i)

        # 가중치 기준 정렬 (Kruskal을 위한 준비)
        edges.sort(key=lambda e: e[2])

        # ✅ MST를 구성하는 함수
        # index: 제외하거나 포함할 간선의 인덱스
        # include: 포함(True) / 제외(False)
        def findMST(index, include):
            uf = UnionFind(n)
            wgt = 0  # MST 총 가중치

            # 특정 간선을 강제로 먼저 포함하는 경우
            if include:
                u, v, weight = edges[index][0], edges[index][1], edges[index][2]
                wgt += weight
                uf.union(u, v)

            # 모든 간선 탐색
            for i, e in enumerate(edges):
                if i == index:
                    continue  # 제외하려는 간선은 스킵
                if uf.union(e[0], e[1]):
                    wgt += e[2]

            return wgt if uf.isConnected() else float("inf")

        # 먼저 기본 MST의 최소 가중치를 구함 (어떤 간선도 제외/포함 없이)
        mst_wgt = findMST(-1, False)

        critical = []        # 반드시 MST에 포함되어야 하는 간선 리스트
        pseudo = []          # MST에 포함될 수도 있는 간선 리스트

        # 간선 하나씩 검사
        for i, e in enumerate(edges):
            # (1) 현재 간선을 "제외"하고 MST 구성 → 비용 증가 or MST 불가능하면 critical
            if findMST(i, False) > mst_wgt:
                critical.append(e[3])  # 원래 인덱스 추가
            # (2) 현재 간선을 "강제로 포함"하고 MST 구성 → 비용 동일하면 pseudo
            elif findMST(i, True) == mst_wgt:
                pseudo.append(e[3])    # 원래 인덱스 추가

        # [Critical 간선 리스트, Pseudo-Critical 간선 리스트] 형태로 반환
        return [critical, pseudo]

 

✅ 한눈에 흐름 요약

  1. MST 기준 비용 먼저 구함 (findMST(-1, False))
  2. 모든 간선에 대해:
    • 제거했을 때 비용이 증가 or 연결 실패 → Critical
    • 강제로 넣었을 때 비용 유지 → Pseudo-Critical
  3. 결과는 [critical, pseudo] 형태로 반환

🧠 참고 포인트

  • edges[i].append(i)로 인덱스를 같이 들고 다니는 방식 아주 좋음
  • findMST() 함수로 로직을 캡슐화해서 코드가 반복되지 않음 → 유지보수 & 가독성 👍
  • uf.isConnected() 로 MST 완성 여부를 정확히 확인함