From d96b6e705d2a6de5b8ac928317472e29a84d6c22 Mon Sep 17 00:00:00 2001 From: "rupesh.kumar" <57129475+rupeshkumaar@users.noreply.github.com> Date: Sun, 21 Jan 2024 12:04:24 +0530 Subject: [PATCH] Merging (Identically Specified) MinHashLSH objects Fixes #205 --- datasketch/lsh.py | 10 +++++---- test/test_lsh.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/datasketch/lsh.py b/datasketch/lsh.py index a25f5891..069e23be 100644 --- a/datasketch/lsh.py +++ b/datasketch/lsh.py @@ -301,14 +301,16 @@ def _insert( for H, hashtable in zip(Hs, self.hashtables): hashtable.insert(H, key, buffer=buffer) - def __eq__(self, other:MinHashLSH) -> bool: + def __equivalent(self, other:MinHashLSH) -> bool: """ Returns: - bool: If the two MinHashLSH has equal num_perm then two are equivalent. + bool: True if the two MinHashLSH objects have equal num_perm (h), number of bands (b) and size of each band (r), in which case the two are equivalent. """ return ( type(self) is type(other) and - self.h == other.h + self.h == other.h and + self.b == other.b and + self.r == other.r ) def _merge( @@ -317,7 +319,7 @@ def _merge( check_disjointness: bool = False, buffer: bool = False ) -> MinHashLSH: - if self == other: + if self.__equivalent(other): if check_disjointness and set(self.keys).intersection(set(other.keys)): raise ValueError("The keys are not disjoint, duplicate key exists.") for key in other.keys: diff --git a/test/test_lsh.py b/test/test_lsh.py index a15be323..c2b080f8 100644 --- a/test/test_lsh.py +++ b/test/test_lsh.py @@ -267,6 +267,8 @@ def test_merge(self): self.assertTrue("d" in items) self.assertTrue("a" in lsh1) self.assertTrue("b" in lsh1) + self.assertTrue("c" in lsh1) + self.assertTrue("d" in lsh1) for i, H in enumerate(lsh1.keys["c"]): self.assertTrue("c" in lsh1.hashtables[i][H]) @@ -280,6 +282,56 @@ def test_merge(self): self.assertRaises(ValueError, lsh1.merge, lsh3, check_disjointness=True) + def test_merge_redis(self): + with
patch('redis.Redis', fake_redis) as mock_redis: + lsh1 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={ + 'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379} + }) + lsh2 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={ + 'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379} + }) + + m1 = MinHash(16) + m1.update("a".encode("utf8")) + m2 = MinHash(16) + m2.update("b".encode("utf8")) + lsh1.insert("a", m1) + lsh1.insert("b", m2) + + m3 = MinHash(16) + m3.update("c".encode("utf8")) + m4 = MinHash(16) + m4.update("d".encode("utf8")) + lsh2.insert("c", m3) + lsh2.insert("d", m4) + + lsh1.merge(lsh2) + for t in lsh1.hashtables: + self.assertTrue(len(t) >= 1) + items = [] + for H in t: + items.extend(t[H]) + self.assertTrue(pickle.dumps("c") in items) + self.assertTrue(pickle.dumps("d") in items) + self.assertTrue("a" in lsh1) + self.assertTrue("b" in lsh1) + self.assertTrue("c" in lsh1) + self.assertTrue("d" in lsh1) + for i, H in enumerate(lsh1.keys[pickle.dumps("c")]): + self.assertTrue(pickle.dumps("c") in lsh1.hashtables[i][H]) + + self.assertTrue(lsh1.merge, lsh2) + self.assertRaises(ValueError, lsh1.merge, lsh2, check_disjointness=True) + + m5 = MinHash(32) + m5.update("e".encode("utf-8")) + lsh3 = MinHashLSH(threshold=0.5, num_perm=32, storage_config={ + 'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379} + }) + lsh3.insert("a",m5) + + self.assertRaises(ValueError, lsh1.merge, lsh3, check_disjointness=True) + class TestWeightedMinHashLSH(unittest.TestCase): def test_init(self):