はじめに
Union-Find木について、どんなものなのか、AtCoderなどではどんな問題で使うのか、Pythonで実装したサンプルコードについて紹介していきます。
実際のAtCoderで利用したプログラムも掲載しています。
Union-Find木とは
Union-Find木とは、グループが互いに重ならない(素集合)ようにグループ分けを管理するデータ構造です。
以下のことが効率的にできます。
- 併合:グループ同士をまとめる
- 判定:要素同士が同じグループに属しているか
Union-Find木は、名前の通り木を使って実現します。
各要素がノード、グループが木となります。
初期化
まずはじめは、n個の要素を用意します。
このときはそれぞれのノードは木になっていません。
併合
2つのグループをまとめて1つにします。
片方の根にもう片方の木の根に辺を張り、1つの木にします。
判定
同じグループかの判定は木をたどり、同じ根になるかどうかで判定します。
使いどころ
AtCoderなどでは、グループ分けをするような問題で使えることが多いですが、一度覚えるとなんでもできるのではないかと思ってしまうので、注意が必要です。
逆にUninon-Find木で考えると簡単になるのに、思いつかなかったという問題もあったので、問題の定式化が重要になるかと思います。
Union-Find木のサンプルコード
Union-Find木のサンプルコードになります。
クラスとして実装しているので、利用する場合はそのままコピペで使えます。
from collections import defaultdict
class UnionFind():
"""
Union Find木クラス
Attributes
--------------------
n : int
要素数
root : list
木の要素数
0未満であればそのノードが根であり、添字の値が要素数
rank : list
木の深さ
"""
def __init__(self, n):
"""
Parameters
---------------------
n : int
要素数
"""
self.n = n
self.root = [-1]*(n+1)
self.rank = [0]*(n+1)
def find(self, x):
"""
ノードxの根を見つける
Parameters
---------------------
x : int
見つけるノード
Returns
---------------------
root : int
根のノード
"""
if(self.root[x] < 0):
return x
else:
self.root[x] = self.find(self.root[x])
return self.root[x]
def unite(self, x, y):
"""
木の併合
Parameters
---------------------
x : int
併合したノード
y : int
併合したノード
"""
x = self.find(x)
y = self.find(y)
if(x == y):
return
elif(self.rank[x] > self.rank[y]):
self.root[x] += self.root[y]
self.root[y] = x
else:
self.root[y] += self.root[x]
self.root[x] = y
if(self.rank[x] == self.rank[y]):
self.rank[y] += 1
def same(self, x, y):
"""
同じグループに属するか判定
Parameters
---------------------
x : int
判定したノード
y : int
判定したノード
Returns
---------------------
ans : bool
同じグループに属しているか
"""
return self.find(x) == self.find(y)
def size(self, x):
"""
木のサイズを計算
Parameters
---------------------
x : int
計算したい木のノード
Returns
---------------------
size : int
木のサイズ
"""
return -self.root[self.find(x)]
def roots(self):
"""
根のノードを取得
Returns
---------------------
roots : list
根のノード
"""
return [i for i, x in enumerate(self.root) if x < 0]
def group_size(self):
"""
グループ数を取得
Returns
---------------------
size : int
グループ数
"""
return len(self.roots())
def group_members(self):
"""
全てのグループごとのノードを取得
Returns
---------------------
group_members : defaultdict
根をキーとしたノードのリスト
"""
group_members = defaultdict(list)
for member in range(self.n):
group_members[self.find(member)].append(member)
return group_members
使い方
以下のようにUnion-Find木で併合や判定が可能です。
n = 5
uf = UnionFind(n)
uf.unite(1, 2)
uf.unite(4, 3)
uf.unite(4, 5)
uf.find(1)
uf.find(4)
uf.same(1, 2)
uf.same(1, 3)
AtCoderのサンプル問題
ARC106 B問題

N頂点M辺の無向グラフが与えられ、それぞれの頂点を$a{i}$から$b{i}$に変更可能か判定する問題です。
各連結成分ごとの総和で判定するため、Union-Find木でグループ分けし、グループごとの総和を求めています。
import numpy as np
n, m = map(int, input().split())
a = np.array(list(map(int, input().split())))
b = np.array(list(map(int, input().split())))
c = []
d = []
for _ in range(m):
tmp_c, tmp_d = map(int, input().split())
c.append(tmp_c-1)
d.append(tmp_d-1)
uf = UnionFind(n)
for i in range(m):
uf.unite(c[i], d[i])
for g in uf.group_members().values():
if sum(a[g]) != sum(b[g]):
print('No')
exit()
print('Yes')
ABC177 D問題

N人のM個の関係性(友達かどうか)が与えられ、友達の友達は友達のとき、N人を同じグループに友達がいない最小のグループ数を求めます。
友達同士のグループを作り、友達同士が最大のグループを求める必要があるのでUnion-Find木でグループ分けし、それぞれのグループの人数で判定しています。
n, m = map(int, input().split())
A = []
B = []
for _ in range(m):
a, b = map(int, input().split())
A.append(a-1)
B.append(b-1)
uf = UnionFind(n)
for i in range(m):
uf.unite(A[i], B[i])
ans = 0
for g in uf.group_members().values():
g_size = len(g)
if g_size > ans:
ans = g_size
print(ans)
ABC206 D問題

N項からなる数列を回文に変える問題です。
回文になる組み合わせ(同じになるべき数字同士)のグループに分けるのにUnion-Find木を利用します。
Union-Find木で解くと考えるまでが大変ですが、解き方がわかれば簡単に実装できます。
n = int(input())
A = list(map(int, input().split()))
uf = UnionFind(2 * 10**5 + 1)
ans = 0
for i in range(n//2):
uf.unite(A[i], A[-(i+1)])
for g in uf.group_members().values():
ans += len(g)-1
print(ans)