forked from salaee/pegbis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path disjoint_set.py
40 lines (34 loc) · 1.12 KB
/
disjoint_set.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
# Disjoint-set forest using union-by-rank and full path compression.
class universe:
    """Disjoint-set (union-find) structure over the elements 0 .. n_elements-1.

    Each row of ``self.elts`` stores ``[rank, size, parent]`` for one element.
    ``rank`` and ``size`` are only meaningful on a set's root (an element whose
    parent is itself).
    """

    def __init__(self, n_elements):
        """Create ``n_elements`` singleton sets.

        Args:
            n_elements: number of elements; they are labelled 0 .. n_elements-1.
        """
        self.num = n_elements
        self.elts = np.empty(shape=(n_elements, 3), dtype=int)
        self.elts[:, 0] = 0                        # rank
        self.elts[:, 1] = 1                        # size
        self.elts[:, 2] = np.arange(n_elements)    # parent: each element is its own root

    def size(self, x):
        """Return the number of elements in the set that contains ``x``.

        ``x`` may be any element, not just a root.
        """
        return self.elts[self.find(x), 1]

    def num_sets(self):
        """Return the current number of disjoint sets."""
        return self.num

    def find(self, x):
        """Return the root of the set containing ``x``.

        Every node on the path from ``x`` to the root is re-pointed directly
        at the root (full path compression), which keeps later lookups short.
        """
        parent = self.elts[:, 2]  # a view, so writes below update self.elts
        root = int(x)
        while root != parent[root]:
            root = int(parent[root])
        # Second pass: compress the whole path, not just the starting node.
        node = int(x)
        while node != root:
            nxt = int(parent[node])
            parent[node] = root
            node = nxt
        return root

    def join(self, x, y):
        """Merge the sets containing ``x`` and ``y`` using union by rank.

        ``x`` and ``y`` may be any elements; they are resolved to their roots
        first. Joining two elements that are already in the same set is a no-op
        and does not change ``num_sets()``.
        """
        x = self.find(x)
        y = self.find(y)
        if x == y:
            # Already merged: without this guard, rank and size would be
            # corrupted and the set count decremented wrongly.
            return
        if self.elts[x, 0] > self.elts[y, 0]:
            # Attach the shallower tree (y) under the deeper one (x).
            self.elts[y, 2] = x
            self.elts[x, 1] += self.elts[y, 1]
        else:
            self.elts[x, 2] = y
            self.elts[y, 1] += self.elts[x, 1]
            if self.elts[x, 0] == self.elts[y, 0]:
                # Equal ranks: the merged tree grows one level deeper.
                self.elts[y, 0] += 1
        self.num -= 1