diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6259a6b3..495104a3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,7 +8,7 @@ jobs: runs-on: "ubuntu-latest" strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/datasketch/lsh.py b/datasketch/lsh.py index 2f1c9b40..4a682f84 100644 --- a/datasketch/lsh.py +++ b/datasketch/lsh.py @@ -229,21 +229,21 @@ def insert( def merge( self, other: MinHashLSH, - check_disjointness: bool = False + check_overlap: bool = False ): """Merge the other MinHashLSH with this one, making this one the union of both the MinHashLSH. Args: other (MinHashLSH): The other MinHashLSH. - check_duplication (bool): To avoid duplicate keys in the storage + check_overlap (bool): Check if there are any overlapping keys before merging and raise if there are any. (`default=True`) Raises: ValueError: If the two MinHashLSH have different initialization - parameters. + parameters, or if `check_overlap` is `True` and there are overlapping keys. """ - self._merge(other, check_disjointness=check_disjointness, buffer=False) + self._merge(other, check_overlap=check_overlap, buffer=False) def insertion_session(self, buffer_size: int = 50000) -> MinHashLSHInsertionSession: """ @@ -304,25 +304,24 @@ def _insert( def __equivalent(self, other:MinHashLSH) -> bool: """ Returns: - bool: If the two MinHashLSH have equal num_perm, number of bands, size of each band and hashfunc (if provided) then two are equivalent. + bool: If the two MinHashLSH have equal num_perm, number of bands, size of each band then two are equivalent. """ return ( type(self) is type(other) and self.h == other.h and self.b == other.b and - self.r == other.r and - type(self.hashfunc) == type(other.hashfunc) + self.r == other.r ) def _merge( self, other: MinHashLSH, - check_disjointness: bool = False, + check_overlap: bool = False, buffer: bool = False ) -> MinHashLSH: if self.__equivalent(other): - if check_disjointness and set(self.keys).intersection(set(other.keys)): - raise ValueError("The keys are not disjoint, duplicate key exists.") + if check_overlap and set(self.keys).intersection(set(other.keys)): + raise ValueError("The keys are overlapping, duplicate key exists.") for key in other.keys: Hs = other.keys.get(key) self.keys.insert(key, *Hs, buffer=buffer) diff --git a/examples/lsh_examples.py b/examples/lsh_examples.py index 6d50563c..007e1399 100644 --- a/examples/lsh_examples.py +++ b/examples/lsh_examples.py @@ -47,8 +47,8 @@ def eg1(): lsh1.merge(lsh2) print("Does m1 exist in the lsh1...", "m1" in lsh1.keys) - # if check_disjointness flag is set to True then it will check the disjointness of the keys in the two MinHashLSH - lsh1.merge(lsh2,check_disjointness=True) + # if check_overlap flag is set to True then it will check the overlapping of the keys in the two MinHashLSH + lsh1.merge(lsh2,check_overlap=True) def eg2(): mg = WeightedMinHashGenerator(10, 5) diff --git a/test/test_lsh.py b/test/test_lsh.py index c2b080f8..a2893753 100644 --- a/test/test_lsh.py +++ b/test/test_lsh.py @@ -273,14 +273,24 @@ def test_merge(self): self.assertTrue("c" in lsh1.hashtables[i][H]) self.assertTrue(lsh1.merge, lsh2) - self.assertRaises(ValueError, lsh1.merge, lsh2, check_disjointness=True) + self.assertRaises(ValueError, lsh1.merge, lsh2, check_overlap=True) - m5 = MinHash(32) + m5 = MinHash(16) m5.update("e".encode("utf-8")) - lsh3 = MinHashLSH(threshold=0.5, num_perm=32) + lsh3 = MinHashLSH(threshold=0.5, num_perm=16) lsh3.insert("a",m5) - self.assertRaises(ValueError, lsh1.merge, lsh3, check_disjointness=True) + self.assertRaises(ValueError, lsh1.merge, lsh3, check_overlap=True) + + lsh1.merge(lsh3) + + m6 = MinHash(16) + m6.update("e".encode("utf-8")) + lsh4 = MinHashLSH(threshold=0.5, num_perm=16) + lsh4.insert("a",m6) + + lsh1.merge(lsh4, check_overlap=False) + def test_merge_redis(self): with patch('redis.Redis', fake_redis) as mock_redis: @@ -321,16 +331,26 @@ def test_merge_redis(self): self.assertTrue(pickle.dumps("c") in lsh1.hashtables[i][H]) self.assertTrue(lsh1.merge, lsh2) - self.assertRaises(ValueError, lsh1.merge, lsh2, check_disjointness=True) + self.assertRaises(ValueError, lsh1.merge, lsh2, check_overlap=True) - m5 = MinHash(32) + m5 = MinHash(16) m5.update("e".encode("utf-8")) - lsh3 = MinHashLSH(threshold=0.5, num_perm=32, storage_config={ + lsh3 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={ 'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379} }) lsh3.insert("a",m5) - self.assertRaises(ValueError, lsh1.merge, lsh3, check_disjointness=True) + self.assertRaises(ValueError, lsh1.merge, lsh3, check_overlap=True) + + m6 = MinHash(16) + m6.update("e".encode("utf-8")) + lsh4 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={ + 'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379} + }) + lsh4.insert("a",m6) + + lsh1.merge(lsh4, check_overlap=False) + class TestWeightedMinHashLSH(unittest.TestCase):