LeetCode/NeetCode

[Trees: Segment Tree] 307. Range Sum Query - Mutable

hyunkookim 2025. 4. 8. 04:20

307. Range Sum Query - Mutable

 

이 문제는 세그먼트 트리로 아주 잘 풀리는 전형적인 문제예요.

구간 합을 빠르게 구하면서 값도 실시간으로 바꾸는 기능이 필요하기 때문이에요.

 

아래는 NumArray 클래스를 세그먼트 트리를 이용해 구현한 코드예요.

 

 

class SegmentTree:
    def __init__(self, total, L, R):
        # 현재 노드가 담당하는 구간 [L, R]의 합을 저장
        self.sum = total
        # 왼쪽 자식 노드
        self.left = None
        # 오른쪽 자식 노드
        self.right = None
        # 구간의 왼쪽 경계 (시작 인덱스)
        self.L = L
        # 구간의 오른쪽 경계 (끝 인덱스)
        self.R = R

    @staticmethod
    def build(nums, L, R):
        # 리프 노드인 경우 (하나의 값만 가지는 구간)
        if L == R:
            return SegmentTree(nums[L], L, R)
        
        # 중간 지점을 기준으로 왼쪽, 오른쪽 서브트리로 분할
        M = (L + R) // 2
        root = SegmentTree(0, L, R)  # 임시 sum=0으로 노드 생성
        # 왼쪽 자식 트리 재귀 생성
        root.left = SegmentTree.build(nums, L, M)
        # 오른쪽 자식 트리 재귀 생성
        root.right = SegmentTree.build(nums, M + 1, R)
        # 현재 노드의 합은 자식 노드들의 합
        root.sum = root.left.sum + root.right.sum
        return root

    def update(self, index, val):
        # 리프 노드 도달 시 값 갱신
        if self.L == self.R:
            self.sum = val
            return

        # 중간 지점 계산
        M = (self.L + self.R) // 2
        # 인덱스가 오른쪽에 있다면 오른쪽 자식으로 내려감
        if index > M:
            self.right.update(index, val)
        # 왼쪽에 있다면 왼쪽 자식으로 내려감
        else:
            self.left.update(index, val)
        # 갱신된 자식의 합을 기반으로 현재 노드의 합 재계산
        self.sum = self.left.sum + self.right.sum

    def rangeQuery(self, L, R):
        # 현재 노드의 구간이 쿼리와 정확히 일치할 경우 바로 반환
        if L == self.L and R == self.R:
            return self.sum

        # 중간 인덱스 계산 (노드 기준)
        M = (self.L + self.R) // 2

        # 쿼리 구간이 오른쪽 자식에만 걸쳐있는 경우
        if L > M:
            return self.right.rangeQuery(L, R)
        # 쿼리 구간이 왼쪽 자식에만 걸쳐있는 경우
        elif R <= M:
            return self.left.rangeQuery(L, R)
        # 쿼리 구간이 양쪽 자식에 걸쳐있는 경우, 나눠서 각각 호출 후 합침
        else:
            return self.left.rangeQuery(L, M) + self.right.rangeQuery(M + 1, R)


class NumArray:

    def __init__(self, nums: List[int]):
        # nums가 비어있을 수 있으므로 예외 처리하면 더 안전함 (여기선 생략)
        # 세그먼트 트리를 nums 전체 범위에 대해 생성
        self.root = SegmentTree.build(nums, 0, len(nums) - 1)

    def update(self, index: int, val: int) -> None:
        # 인덱스 위치 값을 val로 갱신
        self.root.update(index, val)

    def sumRange(self, left: int, right: int) -> int:
        # 구간 [left, right]의 합을 계산해서 반환
        return self.root.rangeQuery(left, right)