Skip to content

Commit

Permalink
Merging (Identically Specified) MinHashLSH objects
Browse files Browse the repository at this point in the history
  • Loading branch information
rupeshkumaar committed Mar 11, 2024
1 parent 6628db8 commit ce29b01
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
runs-on: "ubuntu-latest"
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
19 changes: 9 additions & 10 deletions datasketch/lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,21 +229,21 @@ def insert(
def merge(
self,
other: MinHashLSH,
check_disjointness: bool = False
check_overlap: bool = False
):
"""Merge the other MinHashLSH with this one, making this one the union
of the two MinHashLSH objects.
Args:
other (MinHashLSH): The other MinHashLSH.
check_duplication (bool): To avoid duplicate keys in the storage
check_overlap (bool): If `True`, check for overlapping keys before merging and raise if any are found.
(`default=False`)
Raises:
ValueError: If the two MinHashLSH have different initialization
parameters.
parameters, or if `check_overlap` is `True` and there are overlapping keys.
"""
self._merge(other, check_disjointness=check_disjointness, buffer=False)
self._merge(other, check_overlap=check_overlap, buffer=False)

def insertion_session(self, buffer_size: int = 50000) -> MinHashLSHInsertionSession:
"""
Expand Down Expand Up @@ -304,25 +304,24 @@ def _insert(
def __equivalent(self, other:MinHashLSH) -> bool:
"""
Returns:
bool: If the two MinHashLSH have equal num_perm, number of bands, size of each band and hashfunc (if provided) then two are equivalent.
bool: `True` if the two MinHashLSH are of the same type and have equal num_perm, number of bands, and size of each band (i.e. they are equivalent); `False` otherwise.
"""
return (
type(self) is type(other) and
self.h == other.h and
self.b == other.b and
self.r == other.r and
type(self.hashfunc) == type(other.hashfunc)
self.r == other.r
)

def _merge(
self,
other: MinHashLSH,
check_disjointness: bool = False,
check_overlap: bool = False,
buffer: bool = False
) -> MinHashLSH:
if self.__equivalent(other):
if check_disjointness and set(self.keys).intersection(set(other.keys)):
raise ValueError("The keys are not disjoint, duplicate key exists.")
if check_overlap and set(self.keys).intersection(set(other.keys)):
raise ValueError("The keys are overlapping, duplicate key exists.")
for key in other.keys:
Hs = other.keys.get(key)
self.keys.insert(key, *Hs, buffer=buffer)
Expand Down
4 changes: 2 additions & 2 deletions examples/lsh_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def eg1():

lsh1.merge(lsh2)
print("Does m1 exist in the lsh1...", "m1" in lsh1.keys)
# if check_disjointness flag is set to True then it will check the disjointness of the keys in the two MinHashLSH
lsh1.merge(lsh2,check_disjointness=True)
# If the check_overlap flag is set to True, merge() first checks whether the two MinHashLSH share any keys and raises a ValueError if they do
lsh1.merge(lsh2,check_overlap=True)

def eg2():
mg = WeightedMinHashGenerator(10, 5)
Expand Down
36 changes: 28 additions & 8 deletions test/test_lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,24 @@ def test_merge(self):
self.assertTrue("c" in lsh1.hashtables[i][H])

self.assertTrue(lsh1.merge, lsh2)
self.assertRaises(ValueError, lsh1.merge, lsh2, check_disjointness=True)
self.assertRaises(ValueError, lsh1.merge, lsh2, check_overlap=True)

m5 = MinHash(32)
m5 = MinHash(16)
m5.update("e".encode("utf-8"))
lsh3 = MinHashLSH(threshold=0.5, num_perm=32)
lsh3 = MinHashLSH(threshold=0.5, num_perm=16)
lsh3.insert("a",m5)

self.assertRaises(ValueError, lsh1.merge, lsh3, check_disjointness=True)
self.assertRaises(ValueError, lsh1.merge, lsh3, check_overlap=True)

lsh1.merge(lsh3)

m6 = MinHash(16)
m6.update("e".encode("utf-8"))
lsh4 = MinHashLSH(threshold=0.5, num_perm=16)
lsh4.insert("a",m6)

lsh1.merge(lsh4, check_overlap=False)


def test_merge_redis(self):
with patch('redis.Redis', fake_redis) as mock_redis:
Expand Down Expand Up @@ -321,16 +331,26 @@ def test_merge_redis(self):
self.assertTrue(pickle.dumps("c") in lsh1.hashtables[i][H])

self.assertTrue(lsh1.merge, lsh2)
self.assertRaises(ValueError, lsh1.merge, lsh2, check_disjointness=True)
self.assertRaises(ValueError, lsh1.merge, lsh2, check_overlap=True)

m5 = MinHash(32)
m5 = MinHash(16)
m5.update("e".encode("utf-8"))
lsh3 = MinHashLSH(threshold=0.5, num_perm=32, storage_config={
lsh3 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={
'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379}
})
lsh3.insert("a",m5)

self.assertRaises(ValueError, lsh1.merge, lsh3, check_disjointness=True)
self.assertRaises(ValueError, lsh1.merge, lsh3, check_overlap=True)

m6 = MinHash(16)
m6.update("e".encode("utf-8"))
lsh4 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={
'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379}
})
lsh4.insert("a",m6)

lsh1.merge(lsh4, check_overlap=False)


class TestWeightedMinHashLSH(unittest.TestCase):

Expand Down

0 comments on commit ce29b01

Please sign in to comment.