Start with n components. Each successful union reduces the count by 1:
class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Attributes:
        parent: ``parent[i]`` is the parent of ``i``; a root is its own parent.
        rank: per-root rank used to decide which tree is attached to which.
        count: current number of disjoint sets.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets, one for each index in ``range(n)``."""
        self.count = n
        self.rank = [0] * n
        self.parent = list(range(n))

    def find(self, x):
        """Return the root of the set containing ``x``.

        Every node visited on the way up is re-pointed directly at the
        root, so later lookups along the same path are shorter.
        """
        # First pass: climb to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: compress the path just walked.
        node = x
        while node != root:
            above = self.parent[node]
            self.parent[node] = root
            node = above
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        Returns:
            ``True`` if two distinct sets were merged (``count`` drops by
            one), ``False`` if ``x`` and ``y`` already shared a root.
        """
        keep, absorb = self.find(x), self.find(y)
        if keep == absorb:
            return False  # already connected

        # Make ``keep`` the root with the larger rank; on a tie the root
        # of ``x`` stays on top.
        if self.rank[keep] < self.rank[absorb]:
            keep, absorb = absorb, keep

        self.parent[absorb] = keep
        if self.rank[keep] == self.rank[absorb]:
            # Two trees of equal rank were joined: the result is one deeper.
            self.rank[keep] += 1

        self.count -= 1
        return True
def countComponents(n, edges):
    """Return the number of connected components in an undirected graph.

    Args:
        n: number of nodes; passed straight to ``UnionFind(n)``.
        edges: iterable of two-element pairs, each naming the two
            endpoints of an edge.

    Returns:
        The ``count`` of the union-find structure after every edge has
        been merged.
    """
    components = UnionFind(n)
    for endpoints in edges:
        u, v = endpoints
        components.union(u, v)
    return components.count