【AtCoder】Pythonで使いこなすUnion-Find木

はじめに

Union-Find木について、どんなものなのか、AtCoderなどではどんな問題で使うのか、Pythonで実装したサンプルコードについて紹介していきます。

実際のAtCoderで利用したプログラムも掲載しています。

Union-Find木とは

Union-Find木とは、グループが互いに重ならない(素集合)ようにグループ分けを管理するデータ構造です。

以下のことが効率的にできます。

  • 併合:グループ同士をまとめる
  • 判定:要素同士が同じグループに属しているか

Union-Find木は、名前の通り木を使って実現します。
各要素がノード、グループが木となります。

初期化

まずはじめは、n個の要素を用意します。

このときはそれぞれのノードは木になっていません。

併合

2つのグループをまとめて1つにします。

片方の根にもう片方の木の根に辺を張り、1つの木にします。


判定

同じグループかの判定は木をたどり、同じ根になるかどうかで判定します。

使いどころ

AtCoderなどでは、グループ分けをするような問題で使えることが多いですが、一度覚えるとなんでもできるのではないかと思ってしまうので、注意が必要です。

逆にUninon-Find木で考えると簡単になるのに、思いつかなかったという問題もあったので、問題の定式化が重要になるかと思います。

Union-Find木のサンプルコード

Union-Find木のサンプルコードになります。

クラスとして実装しているので、利用する場合はそのままコピペで使えます。

from collections import defaultdict

class UnionFind():
    """
    Union Find木クラス

    Attributes
    --------------------
    n : int
        要素数
    root : list
        木の要素数
        0未満であればそのノードが根であり、添字の値が要素数
    rank : list
        木の深さ
    """

    def __init__(self, n):
        """
        Parameters
        ---------------------
        n : int
            要素数
        """
        self.n = n
        self.root = [-1]*(n+1)
        self.rank = [0]*(n+1)

    def find(self, x):
        """
        ノードxの根を見つける

        Parameters
        ---------------------
        x : int
            見つけるノード

        Returns
        ---------------------
        root : int
            根のノード
        """
        if(self.root[x] < 0):
            return x
        else:
            self.root[x] = self.find(self.root[x])
            return self.root[x]

    def unite(self, x, y):
        """
        木の併合

        Parameters
        ---------------------
        x : int
            併合したノード
        y : int
            併合したノード
        """
        x = self.find(x)
        y = self.find(y)

        if(x == y):
            return
        elif(self.rank[x] > self.rank[y]):
            self.root[x] += self.root[y]
            self.root[y] = x
        else:
            self.root[y] += self.root[x]
            self.root[x] = y
            if(self.rank[x] == self.rank[y]):
                self.rank[y] += 1

    def same(self, x, y):
        """
        同じグループに属するか判定

        Parameters
        ---------------------
        x : int
            判定したノード
        y : int
            判定したノード

        Returns
        ---------------------
        ans : bool
            同じグループに属しているか
        """
        return self.find(x) == self.find(y)

    def size(self, x):
        """
        木のサイズを計算

        Parameters
        ---------------------
        x : int
            計算したい木のノード

        Returns
        ---------------------
        size : int
            木のサイズ
        """
        return -self.root[self.find(x)]

    def roots(self):
        """
        根のノードを取得

        Returns
        ---------------------
        roots : list
            根のノード
        """
        return [i for i, x in enumerate(self.root) if x < 0]

    def group_size(self):
        """
        グループ数を取得

        Returns
        ---------------------
        size : int
            グループ数
        """
        return len(self.roots())

    def group_members(self):
        """
        全てのグループごとのノードを取得

        Returns
        ---------------------
        group_members : defaultdict
            根をキーとしたノードのリスト
        """
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

使い方

以下のようにUnion-Find木で併合や判定が可能です。

n = 5
uf = UnionFind(n)
uf.unite(1, 2)
uf.unite(4, 3)
uf.unite(4, 5)

uf.find(1)
uf.find(4)

uf.same(1, 2)
uf.same(1, 3)

AtCoderのサンプル問題

ARC106 B問題

B - Values
AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.

N頂点M辺の無向グラフが与えられ、それぞれの頂点を$a{i}$から$b{i}$に変更可能か判定する問題です。

各連結成分ごとの総和で判定するため、Union-Find木でグループ分けし、グループごとの総和を求めています。

import numpy as np

n, m = map(int, input().split())
a = np.array(list(map(int, input().split())))
b = np.array(list(map(int, input().split())))
c = []
d = []
for _ in range(m):
    tmp_c, tmp_d = map(int, input().split())
    c.append(tmp_c-1)
    d.append(tmp_d-1)

uf = UnionFind(n)
for i in range(m):
    uf.unite(c[i], d[i])

for g in uf.group_members().values():
    if sum(a[g]) != sum(b[g]):
        print('No')
        exit()

print('Yes')

ABC177 D問題

D - Friends
AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.

N人のM個の関係性(友達かどうか)が与えられ、友達の友達は友達のとき、N人を同じグループに友達がいない最小のグループ数を求めます。

友達同士のグループを作り、友達同士が最大のグループを求める必要があるのでUnion-Find木でグループ分けし、それぞれのグループの人数で判定しています。

n, m = map(int, input().split())
A = []
B = []
for _ in range(m):
    a, b = map(int, input().split())
    A.append(a-1)
    B.append(b-1)

uf = UnionFind(n)

for i in range(m):
    uf.unite(A[i], B[i])

ans = 0
for g in uf.group_members().values():
    g_size = len(g)
    if g_size > ans:
        ans = g_size

print(ans)

ABC206 D問題

D - KAIBUNsyo
AtCoder is a programming contest site for anyone from beginners to experts. We hold weekly programming contests online.

N項からなる数列を回文に変える問題です。

回文になる組み合わせ(同じになるべき数字同士)のグループに分けるのにUnion-Find木を利用します。
Union-Find木で解くと考えるまでが大変ですが、解き方がわかれば簡単に実装できます。

n = int(input())
A = list(map(int, input().split()))
uf = UnionFind(2 * 10**5 + 1)

ans = 0
for i in range(n//2):
    uf.unite(A[i], A[-(i+1)])

for g in uf.group_members().values():
    ans += len(g)-1

print(ans)

参考

タイトルとURLをコピーしました