나의 풀이:

from collections import defaultdict


def find(a):
    """Return the root of a's set, compressing the path on the way.

    A node is a root when ``parent[node]`` is falsy (0 is the defaultdict
    default), so node ids are assumed to be non-zero -- TODO confirm.
    ``weight[x]`` holds x's offset relative to ``parent[x]``. After this call,
    every node on the walked path points directly at the root and its weight
    is the accumulated offset to that root.

    Reads and mutates the module-level ``parent`` and ``weight`` dicts.
    """
    # Walk up to the root, remembering every non-root node we pass.
    trail = []
    node = a
    while parent[node]:
        trail.append(node)
        node = parent[node]
    root = node

    # Rebase from the node nearest the root outward. Each node's old parent
    # has then already been rebased, so its weight is relative to the root.
    for x in reversed(trail):
        weight[x] += weight[parent[x]]
        parent[x] = root
    return root


def union(a, b, w):
    """Merge the sets of a and b, recording the relation w between them.

    After this call (and a ``find`` on both nodes), ``weight[a] - weight[b]``
    equals ``w``. If a and b already share a root, nothing changes. The new
    relation is not checked against existing data.

    Uses union by rank: the root with the smaller rank is hung under the
    other, and the surviving root's rank grows only on a tie. Reads and
    mutates the module-level ``parent``, ``weight`` and ``rank`` dicts.
    """
    root_a, root_b = find(a), find(b)
    if root_a == root_b:
        return

    # Offset of root_a relative to root_b that makes weight[a] - weight[b] == w.
    offset = w + weight[b] - weight[a]

    if rank[root_a] > rank[root_b]:
        # Attach the other way round; the offset flips sign accordingly.
        child, new_root, offset = root_b, root_a, -offset
    else:
        child, new_root = root_a, root_b

    weight[child] = offset
    parent[child] = new_root
    if rank[child] == rank[new_root]:
        rank[new_root] += 1


# Driver: read T test cases from stdin. Each test case has a line "N M",
# followed by M command lines.
#   "! a b w"     -> union(a, b, w)
#   any other line -> query "X a b": print weight[a] - weight[b], or
#                     'UNKNOWN' when a and b are in different sets.
T = int(input())
for tc in range(1, T + 1):
    N, M = map(int, input().split())  # N is read but not otherwise used

    # Fresh state per test case. find/union read these module-level names,
    # so they must stay globals with exactly these names.
    parent = defaultdict(int)
    weight = defaultdict(int)
    rank = defaultdict(int)

    answers = []
    for _ in range(M):
        line = input()
        numbers = map(int, line.split()[1:])  # lazily converted on unpack
        if line[0] == '!':
            a, b, w = numbers
            union(a, b, w)
            continue
        a, b = numbers
        answers.append(weight[a] - weight[b] if find(a) == find(b) else 'UNKNOWN')

    joined = ' '.join(map(str, answers))
    print('#{} {}'.format(tc, joined))


한마디:

이번 기회에 Union-Find를 확실히 익혀야겠다.