LeetCode/NeetCode

[최소신장트리 MST: Prim, Kruskal] 1584. Min Cost to Connect All Points

hyunkookim 2025. 4. 11. 07:43

1584. Min Cost to Connect All Points

 

1) 프림 알고리즘으로..

import heapq

class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        # points: [xi, yi] 형태의 2차원 좌표들이 주어짐
        # 모든 점들을 연결하는 최소 비용을 구해야 하며,
        # 비용은 맨해튼 거리 |xi - xj| + |yi - yj|

        N = len(points)  # 총 점의 개수

        adj = {}  # 인접 리스트 생성: 각 노드마다 [거리, 이웃노드] 목록 저장
        for x in range(N):
            adj[x] = []

        # 모든 노드쌍에 대해 맨해튼 거리 계산하여 인접 리스트에 저장
        for i in range(N):
            xi, yi = points[i]
            for j in range(N):
                if i == j:
                    continue  # 자기 자신은 제외
                xj, yj = points[j]
                dist = abs(xi - xj) + abs(yi - yj)  # 맨해튼 거리
                adj[i].append([dist, j])  # 양방향 그래프 저장
                adj[j].append([dist, i])

        visit = set()      # MST에 포함된 노드들
        mst = []           # MST를 구성하는 간선들 저장 (문제에선 안 써도 됨)
        minCost = 0        # 최종 최소 비용
        minHeap = []       # 우선순위 큐: [거리, 출발노드, 도착노드]

        # 0번 노드에서 시작
        visit.add(0)
        for dist, neighbor in adj[0]:
            heapq.heappush(minHeap, [dist, 0, neighbor])  # 0번 노드와 연결된 간선들을 힙에 추가

        # Prim 알고리즘: MST의 노드 수가 N이 될 때까지 반복
        while len(visit) < N and minHeap:
            dist, snode, dnode = heapq.heappop(minHeap)  # 가장 짧은 간선을 선택

            if dnode in visit:
                continue  # 이미 방문한 노드라면 스킵

            visit.add(dnode)          # 새 노드 MST에 포함
            minCost += dist           # 비용 누적
            mst.append([snode, dnode])  # 간선 기록 (필수는 아님)

            # 새로 방문한 노드의 인접 노드들을 힙에 추가
            for cost, nei in adj[dnode]:
                if nei not in visit:
                    heapq.heappush(minHeap, [cost, dnode, nei])

        return minCost  # 모든 노드를 연결하는 최소 비용 반환

 

2) 크루스칼 알고리즘으로.. : 유니온 파인드(공통 조상) 활용

class UnionFind:
    def __init__(self, n):
        self.parent = [i for i in range(n)]
        self.size = [1] * n

    def find(self, x):
        # find root of x
        # x 가 조상 즉 root 이면 x == self.parent[x] 만족
        if x != self.parent[x]:
            self.parent[x] = self.find(self.parent[x])  # 경로 압축 (Path Compression)
        # 조상 반환
        return self.parent[x]

    def union(self, x, y):
        # 두개 노드 합침.
        # 여기서 합쳐지면 true 반환
        # 이미 합쳐져있어서, 여기서 합치지않으면 false 반환
        root_x = self.find(x)
        root_y = self.find(y)

        if root_x != root_y:
            # 두개 조상이 다르면, 합치고, return True
            # size 가 크면, 더 조상.. 그쪽으로 합침
            if self.size[root_x] < self.size[root_y]:
                self.parent[root_x] = root_y
                self.size[root_y] += self.size[root_x]
            else:
                self.parent[root_y] = root_x
                self.size[root_x] += self.size[root_y]
            return True  # 성공적으로 합쳤음
        return False  # 이미 같은 집합이었음 → 합치지 않음

class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        # 1. 간선들을 최소 힙(minHeap)에 넣기 (weight 기준 정렬)
        minHeap = []
        N = len(points)
        for n1 in range(N):
            x1, y1 = points[n1]
            for n2 in range(N):
                if n1 == n2:
                    continue  # 자기 자신하고는 연결하지 않음
                x2, y2 = points[n2]
                dist = abs(x1 - x2) + abs(y1 - y2)  # 맨해튼 거리 계산
                heapq.heappush(minHeap, [dist, n1, n2])  # 간선을 힙에 넣음
                heapq.heappush(minHeap, [dist, n2, n1])  # 양방향이지만 중복됨

        # 👉 여기서 중복 간선 (n1, n2)와 (n2, n1)을 둘 다 넣었기 때문에,
        #     정렬은 정확하더라도 실제로 같은 간선이 두 번 고려됩니다.
        #     Kruskal은 어차피 union에서 중복 체크를 하므로 작동은 하긴 해요.
        #     다만 일반적으로는 (i < j) 조건으로 한 번만 넣는 게 더 효율적입니다.

        # 2. 유니온 파인드 초기화
        minCost = 0
        unionFind = UnionFind(N)
        components = N  # 현재 연결되지 않은 컴포넌트 개수 (초기에는 노드 각각이 따로)

        while components > 1 and minHeap:
            dist, n1, n2 = heapq.heappop(minHeap)

            if unionFind.union(n1, n2):  # true 라 함은, 연결안되어 있어서, 이번에 연결됐다는 의미
                # 연결됐으므로,
                minCost += dist  # 비용 추가
                components -= 1  # 연결된 컴포넌트 수 하나 줄어듦

        # 모든 노드가 하나의 컴포넌트로 연결되었으면 MST 완성
        return minCost if components == 1 else -1

 

권장 코드: components 변수를 UnionFind 클래스 내로..

import heapq

class UnionFind:
    def __init__(self, n):
        self.parent = [i for i in range(n)]  # 각 노드는 자기 자신이 부모
        self.size = [1] * n                  # 각 집합의 크기
        
        """
        # 현재 연결된 컴포넌트 수
        """
        self.components = n                   

    def find(self, x):
        # find root of x
        if x != self.parent[x]:
            self.parent[x] = self.find(self.parent[x])  # 경로 압축
        return self.parent[x]

    def union(self, x, y):
        # 두 노드가 같은 집합이 아니면 합침, true 반환
        root_x = self.find(x)
        root_y = self.find(y)

        if root_x != root_y:
            # size가 큰 쪽으로 합치기 (union by size)
            if self.size[root_x] < self.size[root_y]:
                self.parent[root_x] = root_y
                self.size[root_y] += self.size[root_x]
            else:
                self.parent[root_y] = root_x
                self.size[root_x] += self.size[root_y]

			"""
            # 연결된 컴포넌트 수 하나 줄임
            """
            self.components -= 1  
            return True
        return False  # 이미 같은 집합이면 연결하지 않음

class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        N = len(points)
        minHeap = []

        # 모든 쌍의 간선들 계산 후 힙에 넣음
        for n1 in range(N):
            x1, y1 = points[n1]
            for n2 in range(n1 + 1, N):  # 중복 제거: n1 < n2 만 고려
                x2, y2 = points[n2]
                dist = abs(x1 - x2) + abs(y1 - y2)
                heapq.heappush(minHeap, [dist, n1, n2])

        # 유니온 파인드 초기화
        uf = UnionFind(N)
        minCost = 0

        # 크루스칼 알고리즘
        while uf.components > 1 and minHeap:
            dist, n1, n2 = heapq.heappop(minHeap)

            if uf.union(n1, n2):  # 연결 성공하면 비용 추가
                minCost += dist

        # 모든 노드가 연결되어 있으면 MST 완성
        return minCost if uf.components == 1 else -1

 

https://youtu.be/f7JOBJIC-NA