Skip to content

Commit

Permalink
Merging (Identically Specified) MinHashLSH objects
Browse files Browse the repository at this point in the history
Fixes #205
  • Loading branch information
rupeshkumaar committed Jan 21, 2024
1 parent 48501a0 commit d96b6e7
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
10 changes: 6 additions & 4 deletions datasketch/lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,16 @@ def _insert(
for H, hashtable in zip(Hs, self.hashtables):
hashtable.insert(H, key, buffer=buffer)

def __eq__(self, other:MinHashLSH) -> bool:
def __equivalent(self, other:MinHashLSH) -> bool:
"""
Returns:
bool: If the two MinHashLSH has equal num_perm then two are equivalent.
bool: True if the two MinHashLSH objects have equal num_perm, number of bands, and size of each band, in which case they are considered equivalent.
"""
return (
type(self) is type(other) and
self.h == other.h
self.h == other.h and
self.b == other.b and
self.r == other.r
)

def _merge(
Expand All @@ -317,7 +319,7 @@ def _merge(
check_disjointness: bool = False,
buffer: bool = False
) -> MinHashLSH:
if self == other:
if self.__equivalent(other):
if check_disjointness and set(self.keys).intersection(set(other.keys)):
raise ValueError("The keys are not disjoint, duplicate key exists.")
for key in other.keys:
Expand Down
52 changes: 52 additions & 0 deletions test/test_lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ def test_merge(self):
self.assertTrue("d" in items)
self.assertTrue("a" in lsh1)
self.assertTrue("b" in lsh1)
self.assertTrue("c" in lsh1)
self.assertTrue("d" in lsh1)
for i, H in enumerate(lsh1.keys["c"]):
self.assertTrue("c" in lsh1.hashtables[i][H])

Expand All @@ -280,6 +282,56 @@ def test_merge(self):

self.assertRaises(ValueError, lsh1.merge, lsh3, check_disjointness=True)

def test_merge_redis(self):
    """Check that merging Redis-backed MinHashLSH indexes works end to end.

    Covers three behaviours:
      * after ``lsh1.merge(lsh2)`` every key of ``lsh2`` is reachable
        through ``lsh1`` (hashtables, key index and ``in`` operator);
      * merging again without ``check_disjointness`` is accepted, while
        ``check_disjointness=True`` raises ``ValueError`` because the keys
        now overlap;
      * merging an index built with a different ``num_perm`` raises
        ``ValueError``.
    """
    def make_lsh(num_perm):
        # Build a fresh storage_config per index so no two instances share
        # the same dict object.
        return MinHashLSH(threshold=0.5, num_perm=num_perm, storage_config={
            'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379}
        })

    def make_minhash(text, num_perm):
        minhash = MinHash(num_perm)
        minhash.update(text.encode("utf8"))
        return minhash

    # Everything that talks to Redis runs under the patch so the fake
    # client is used instead of a real server.
    with patch('redis.Redis', fake_redis):
        lsh1 = make_lsh(16)
        lsh2 = make_lsh(16)

        lsh1.insert("a", make_minhash("a", 16))
        lsh1.insert("b", make_minhash("b", 16))
        lsh2.insert("c", make_minhash("c", 16))
        lsh2.insert("d", make_minhash("d", 16))

        lsh1.merge(lsh2)

        # NOTE(review): the Redis-backed containers appear to hand keys
        # back as pickled bytes, hence the pickle.dumps() on lookups --
        # confirm against the storage implementation.
        pickled_c = pickle.dumps("c")
        pickled_d = pickle.dumps("d")
        for table in lsh1.hashtables:
            self.assertGreaterEqual(len(table), 1)
            items = []
            for H in table:
                items.extend(table[H])
            self.assertIn(pickled_c, items)
            self.assertIn(pickled_d, items)
        for key in ("a", "b", "c", "d"):
            self.assertIn(key, lsh1)
        for i, H in enumerate(lsh1.keys[pickled_c]):
            self.assertIn(pickled_c, lsh1.hashtables[i][H])

        # The original line was ``self.assertTrue(lsh1.merge, lsh2)``, which
        # only asserted that the bound method object is truthy and never
        # invoked it. Actually call merge: without check_disjointness an
        # overlapping merge must not raise.
        lsh1.merge(lsh2)
        self.assertRaises(ValueError, lsh1.merge, lsh2, check_disjointness=True)

        # lsh3 differs in num_perm (32 vs 16) and also reuses key "a";
        # either condition is expected to make the merge fail.
        lsh3 = make_lsh(32)
        lsh3.insert("a", make_minhash("e", 32))

        self.assertRaises(ValueError, lsh1.merge, lsh3, check_disjointness=True)

class TestWeightedMinHashLSH(unittest.TestCase):

def test_init(self):
Expand Down

0 comments on commit d96b6e7

Please sign in to comment.