From 300df97492188ddf8c5eb5740019d86f7ee72bb5 Mon Sep 17 00:00:00 2001 From: haianhng31 Date: Thu, 14 Nov 2024 02:44:27 +0000 Subject: [PATCH 1/8] create pdq_index2 and test cases --- .../signal_type/pdq/pdq_index2.py | 185 +++++++++++++++++ .../signal_type/tests/test_pdq_index2.py | 195 ++++++++++++++++++ 2 files changed, 380 insertions(+) create mode 100644 python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py create mode 100644 python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py diff --git a/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py b/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py new file mode 100644 index 000000000..3f8a708f1 --- /dev/null +++ b/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +""" +Implementation of SignalTypeIndex abstraction for PDQ by wrapping +hashing.pdq_faiss_matcher. +""" + +import typing as t +import faiss +import numpy as np +import pickle + + +from threatexchange.signal_type.index import ( + IndexMatchUntyped, + SignalSimilarityInfoWithIntDistance, + SignalTypeIndex, + T as IndexT, +) +from threatexchange.signal_type.pdq.pdq_faiss_matcher import ( + PDQMultiHashIndex, + PDQFlatHashIndex, + PDQHashIndex, +) + +DEFAULT_MATCH_DIST = 31 +DIMENSIONALITY = 256 + +PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT] + + +class PDQIndex2(SignalTypeIndex[IndexT]): + """ + Wrapper around the pdq faiss index lib using PDQMultiHashIndex + """ + + def __init__( + self, + threshold: int = DEFAULT_MATCH_DIST, + index: t.Optional[faiss.Index] = None, + entries: t.Iterable[t.Tuple[str, IndexT]] = (), + ) -> None: + super().__init__() + self.threshold = threshold + + if index is None: + index = faiss.IndexFlatL2(DIMENSIONALITY) + self.index = _PDQHashIndex(index) + + # Matches hash to Faiss index + self._deduper: t.Dict[str, faiss.IndexFlatL2] = {} + # Entry mapping: Each list[entries]'s index is its hash's index + self._idx_to_entries: t.List[t.List[IndexT]] = [] + + self.add_all(entries=entries) + + def __len__(self) -> int: + return len(self._idx_to_entries) + + def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]: + """ + Look up entries against the index, up to the max supported distance. + """ + results: t.List[PDQIndexMatch[IndexT]] = [] + matches_list: t.List[t.List[t.Any]] = self.index.search( + queries=[hash], threshold=self.threshold + ) + + for matches in matches_list: + for match_hash, distance in matches: + entries = self._idx_to_entries[match_hash] # Get the Faiss index + # Create match objects for each entry + results.extend( + PDQIndexMatch( + SignalSimilarityInfoWithIntDistance(distance=int(distance)), + entry, + ) + for entry in entries + ) + return results + + def add(self, signal_str: str, entry: IndexT) -> None: + self.add_all(((signal_str, entry),)) + + def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None: + for h, i in entries: + existing_faiss_id = self._deduper.get(h) + if existing_faiss_id is None: + self.index.add([h]) + self._idx_to_entries.append([i]) + next_id = len(self._deduper) # Because faiss index starts from 0 up + self._deduper[h] = next_id + else: + # Since this already exists, we don't add it to Faiss because Faiss cannot handle duplication + self._idx_to_entries[existing_faiss_id].append(h) + + def serialize(self, fout: t.BinaryIO) -> None: + """ + Serialize the PDQ index to a binary stream. 
+ """ + fout.write(pickle.dumps(self)) + + @classmethod + def deserialize(cls, fin: t.BinaryIO) -> "PDQIndex2[IndexT]": + """ + Deserialize a PDQ index from a binary stream. + """ + return pickle.loads(fin.read()) + + +class _PDQHashIndex: + """ + A wrapper around the faiss index for pickle serialization + """ + + def __init__(self, faiss_index: faiss.Index) -> None: + self.faiss_index = faiss_index + + def add(self, pdq_strings: t.Sequence[str]) -> None: + """ + Add PDQ hashes to the FAISS index. + Args: + pdq_strings (Sequence[str]): PDQ hash strings to add + """ + vectors = self._convert_pdq_strings_to_ndarray(pdq_strings) + self.faiss_index.add(vectors) + + def search( + self, queries: t.Sequence[str], threshold: int = DEFAULT_MATCH_DIST + ) -> t.List[t.List[t.Any]]: + """ + Search the FAISS index for matches to the given PDQ queries. + Args: + queries (Sequence[str]): The PDQ signal strings to search for. + threshold (int): The maximum distance threshold for matches. + Returns: + 2D list of tuples that store (matches, distances) for each query + """ + query_array: np.ndarray = self._convert_pdq_strings_to_ndarray(queries) + limits, distances, indices = self.faiss_index.range_search( + query_array, threshold + 1 + ) + + results: t.List[t.List[t.Any]] = [] + for i in range(len(queries)): + matches = [idx.item() for idx in indices[limits[i] : limits[i + 1]]] + dists = [dist for dist in distances[limits[i] : limits[i + 1]]] + results.append(list(zip(matches, dists))) + return results + + def __getstate__(self): + data = faiss.serialize_index(self.faiss_index) + return data + + def __setstate__(self, data): + self.faiss_index = faiss.deserialize_index(data) + + def _convert_pdq_strings_to_ndarray( + self, pdq_strings: t.Sequence[str] + ) -> np.ndarray: + """ + Convert multiple PDQ hash strings to a numpy array. 
+ Args: + pdq_strings (Sequence[str]): A sequence of 64-character hexadecimal PDQ hash strings + Returns: + np.ndarray: A 2D array of shape (n_queries, 256) where each row is the full PDQ hash as a bit array + """ + hash_arrays = [] + for pdq_str in pdq_strings: + print("converting string:", pdq_str) + try: + # Convert hex string to integer + hash_int = int(pdq_str, 16) + # Convert to binary string, padding to ensure 256 bits + binary_str = format(hash_int, "0256b") + # Convert to numpy array + hash_array = np.array( + [int(bit) for bit in binary_str], dtype=np.float32 + ) + hash_arrays.append(hash_array) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid PDQ hash string: {pdq_str}") from e + + # Convert list of arrays to a single 2D array + return np.array(hash_arrays, dtype=np.float32) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py new file mode 100644 index 000000000..7570c7492 --- /dev/null +++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py @@ -0,0 +1,195 @@ +import pytest +import io +import faiss +from threatexchange.signal_type.pdq.pdq_index2 import ( + PDQIndex2, + _PDQHashIndex, + DIMENSIONALITY, + DEFAULT_MATCH_DIST, +) + + +@pytest.fixture +def empty_index(): + """Fixture for an empty index.""" + return PDQIndex2[str]() + + +@pytest.fixture +def custom_index_with_threshold(): + """Fixture for an index with custom index and threshold.""" + custom_index = faiss.IndexFlatL2(DIMENSIONALITY + 1) + custom_threshold = DEFAULT_MATCH_DIST + 1 + return PDQIndex2[str](index=custom_index, threshold=custom_threshold) + + +@pytest.fixture +def sample_index(): + """Fixture for an index with a small sample set.""" + pdq_hashes = [ + "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", + "f" * 64, + "0" * 64, + "a" * 64, + ] + index = PDQIndex2[str](entries=[(h, pdq_hashes.index(h)) for h in pdq_hashes]) + return index, pdq_hashes + + +@pytest.fixture +def sample_index_with_one_entry(): + """Fixture for an index with a small sample set.""" + pdq_hashes = [ + "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", + "f" * 64, + "0" * 64, + "a" * 64, + ] + index = PDQIndex2[str]( + entries=[ + ("f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", 0) + ] + ) + return index, pdq_hashes + + +def test_init(empty_index) -> None: + assert empty_index.threshold == DEFAULT_MATCH_DIST + assert isinstance(empty_index.index, _PDQHashIndex) + assert isinstance(empty_index.index.faiss_index, faiss.IndexFlatL2) + assert empty_index.index.faiss_index.d == DIMENSIONALITY + assert empty_index._deduper == dict() + assert empty_index._idx_to_entries == [] + + +def test_serialize_deserialize(empty_index) -> None: + buffer = io.BytesIO() + empty_index.serialize(buffer) + buffer.seek(0) + deserialized_index: PDQIndex2[str] = PDQIndex2.deserialize(buffer) + + assert isinstance(deserialized_index, PDQIndex2) + assert deserialized_index.threshold == empty_index.threshold + assert isinstance(deserialized_index.index, _PDQHashIndex) + assert isinstance(deserialized_index.index.faiss_index, faiss.IndexFlatL2) + assert deserialized_index.index.faiss_index.d == DIMENSIONALITY + assert deserialized_index._deduper == empty_index._deduper + assert deserialized_index._idx_to_entries == empty_index._idx_to_entries + + +def test_serialize_deserialize_with_custom_index_threshold( + custom_index_with_threshold, +) -> None: + buffer = 
io.BytesIO() + custom_index_with_threshold.serialize(buffer) + buffer.seek(0) + deserialized_index: PDQIndex2[str] = PDQIndex2.deserialize(buffer) + + assert isinstance(deserialized_index, PDQIndex2) + assert deserialized_index.threshold == custom_index_with_threshold.threshold + assert isinstance(deserialized_index.index, _PDQHashIndex) + assert isinstance(deserialized_index.index.faiss_index, faiss.IndexFlatL2) + assert deserialized_index.index.faiss_index.d == DIMENSIONALITY + 1 + assert deserialized_index._deduper == custom_index_with_threshold._deduper + assert ( + deserialized_index._idx_to_entries + == custom_index_with_threshold._idx_to_entries + ) + + +def test_empty_index_query(empty_index): + """Test querying an empty index.""" + query_hash = "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22" + + # Query should return empty list + results = empty_index.query(query_hash) + assert len(results) == 0 + + +def test_sample_set_exact_match(sample_index): + """Test exact matches in sample set.""" + index, pdq_hashes = sample_index + + # Query with existing hash + query_hash = pdq_hashes[0] + results = index.query(query_hash) + + assert len(results) == 1 + assert ( + results[0].similarity_info.distance == 0 + ) # Exact match should have distance 0 + + +def test_sample_set_no_match(sample_index): + """Test no matches in sample set.""" + index, _ = sample_index + results = index.query("b" * 64) + assert len(results) == 0 + + +def test_sample_set_near_match(sample_index): + """Test near matches in sample set.""" + index, pdq_hashes = sample_index + + # Create a near-match by flipping a few bits + base_hash = pdq_hashes[0] + near_hash = hex(int(base_hash, 16) ^ 0xF)[2:].zfill(64) # Flip 4 bits + + results = index.query(near_hash) + assert len(results) > 0 # Should find near matches + assert results[0].similarity_info.distance > 0 + + +def test_sample_set_threshold(sample_index): + """Test distance threshold behavior.""" + _, pdq_hashes = sample_index + + narrow_index = PDQIndex2[str](threshold=10) # Strict matching + wide_index = PDQIndex2[str](threshold=50) # Loose matching + + for hash_str in pdq_hashes: + narrow_index.add(hash_str, hash_str) + wide_index.add(hash_str, hash_str) + + # Create a test hash with known distance + base_hash = pdq_hashes[0] + test_hash = hex(int(base_hash, 16) ^ ((1 << 20) - 1))[2:].zfill( + 64 + ) # ~20 bits different + + narrow_results = narrow_index.query(test_hash) + wide_results = wide_index.query(test_hash) + + assert len(wide_results) > len(narrow_results) # Wide threshold should match more + + +def test_duplicate_handling(sample_index): + """Test how the index handles duplicate entries.""" + index, pdq_hashes = sample_index + + # Add same hash multiple times + test_hash = pdq_hashes[0] + index.add_all(entries=[(test_hash, i) for i in range(3)]) + + results = index.query(test_hash) + + # Should find all entries associated with the hash + assert len(results) == 4 + for result in results: + assert result.similarity_info.distance == 0 + + +def test_one_entry_sample_index(sample_index_with_one_entry): + """Test how the index handles when it only has one entry.""" + index, pdq_hashes = sample_index_with_one_entry + + matching_test_hash = pdq_hashes[0] # This is the existing hash in index + unmatching_test_hash = pdq_hashes[1] + + results = index.query(matching_test_hash) + # Should find 1 entry associated with the hash + assert len(results) == 1 + assert results[0].similarity_info.distance == 0 + + results = index.query(unmatching_test_hash) + 
assert len(results) == 0 From 9fd77447c751969dcf2359cc150e58eb0ae58278 Mon Sep 17 00:00:00 2001 From: haianhng31 Date: Mon, 18 Nov 2024 18:56:51 +0000 Subject: [PATCH 2/8] pytx - continue fixing reimplementation pdq index --- .../signal_type/pdq/pdq_index2.py | 113 ++++--------- .../signal_type/pdq/pdq_utils.py | 19 +++ .../threatexchange/signal_type/pdq/signal.py | 4 +- .../signal_type/tests/test_pdq_index2.py | 159 ++++-------------- 4 files changed, 87 insertions(+), 208 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py b/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py index 3f8a708f1..8f0eb40c4 100644 --- a/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py +++ b/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py @@ -8,7 +8,6 @@ import typing as t import faiss import numpy as np -import pickle from threatexchange.signal_type.index import ( @@ -17,38 +16,36 @@ SignalTypeIndex, T as IndexT, ) -from threatexchange.signal_type.pdq.pdq_faiss_matcher import ( - PDQMultiHashIndex, - PDQFlatHashIndex, - PDQHashIndex, +from threatexchange.signal_type.pdq.signal import PDQ_CONFIDENT_MATCH_THRESHOLD +from threatexchange.signal_type.pdq.pdq_utils import ( + BITS_IN_PDQ, + convert_pdq_strings_to_ndarray, ) -DEFAULT_MATCH_DIST = 31 -DIMENSIONALITY = 256 - PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT] class PDQIndex2(SignalTypeIndex[IndexT]): """ - Wrapper around the pdq faiss index lib using PDQMultiHashIndex + Indexing and querying PDQ signals using Faiss for approximate nearest neighbor search. """ def __init__( self, - threshold: int = DEFAULT_MATCH_DIST, index: t.Optional[faiss.Index] = None, entries: t.Iterable[t.Tuple[str, IndexT]] = (), + *, + threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD, ) -> None: super().__init__() self.threshold = threshold if index is None: - index = faiss.IndexFlatL2(DIMENSIONALITY) - self.index = _PDQHashIndex(index) + index = faiss.IndexFlatL2(BITS_IN_PDQ) + self.index = _PDQFaissIndex(index) # Matches hash to Faiss index - self._deduper: t.Dict[str, faiss.IndexFlatL2] = {} + self._deduper: t.Dict[str, int] = {} # Entry mapping: Each list[entries]'s index is its hash's index self._idx_to_entries: t.List[t.List[IndexT]] = [] @@ -62,21 +59,20 @@ def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]: Look up entries against the index, up to the max supported distance. 
""" results: t.List[PDQIndexMatch[IndexT]] = [] - matches_list: t.List[t.List[t.Any]] = self.index.search( + matches_list: t.List[t.Tuple[int, int]] = self.index.search( queries=[hash], threshold=self.threshold ) - for matches in matches_list: - for match_hash, distance in matches: - entries = self._idx_to_entries[match_hash] # Get the Faiss index - # Create match objects for each entry - results.extend( - PDQIndexMatch( - SignalSimilarityInfoWithIntDistance(distance=int(distance)), - entry, - ) - for entry in entries + for match, distance in matches_list: + entries = self._idx_to_entries[match] + # Create match objects for each entry + results.extend( + PDQIndexMatch( + SignalSimilarityInfoWithIntDistance(distance=int(distance)), + entry, ) + for entry in entries + ) return results def add(self, signal_str: str, entry: IndexT) -> None: @@ -92,23 +88,10 @@ def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None: self._deduper[h] = next_id else: # Since this already exists, we don't add it to Faiss because Faiss cannot handle duplication - self._idx_to_entries[existing_faiss_id].append(h) - - def serialize(self, fout: t.BinaryIO) -> None: - """ - Serialize the PDQ index to a binary stream. - """ - fout.write(pickle.dumps(self)) - - @classmethod - def deserialize(cls, fin: t.BinaryIO) -> "PDQIndex2[IndexT]": - """ - Deserialize a PDQ index from a binary stream. - """ - return pickle.loads(fin.read()) + self._idx_to_entries[existing_faiss_id].append(i) -class _PDQHashIndex: +class _PDQFaissIndex: """ A wrapper around the faiss index for pickle serialization """ @@ -119,67 +102,31 @@ def __init__(self, faiss_index: faiss.Index) -> None: def add(self, pdq_strings: t.Sequence[str]) -> None: """ Add PDQ hashes to the FAISS index. - Args: - pdq_strings (Sequence[str]): PDQ hash strings to add """ - vectors = self._convert_pdq_strings_to_ndarray(pdq_strings) + vectors = convert_pdq_strings_to_ndarray(pdq_strings) self.faiss_index.add(vectors) def search( - self, queries: t.Sequence[str], threshold: int = DEFAULT_MATCH_DIST - ) -> t.List[t.List[t.Any]]: + self, queries: t.Sequence[str], threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD + ) -> t.List[t.Tuple[int, int]]: """ Search the FAISS index for matches to the given PDQ queries. - Args: - queries (Sequence[str]): The PDQ signal strings to search for. - threshold (int): The maximum distance threshold for matches. - Returns: - 2D list of tuples that store (matches, distances) for each query """ - query_array: np.ndarray = self._convert_pdq_strings_to_ndarray(queries) + query_array: np.ndarray = convert_pdq_strings_to_ndarray(queries) limits, distances, indices = self.faiss_index.range_search( query_array, threshold + 1 ) - results: t.List[t.List[t.Any]] = [] + results: t.List[t.Tuple[int, int]] = [] for i in range(len(queries)): matches = [idx.item() for idx in indices[limits[i] : limits[i + 1]]] dists = [dist for dist in distances[limits[i] : limits[i + 1]]] - results.append(list(zip(matches, dists))) + for j in range(len(matches)): + results.append((matches[j], dists[j])) return results def __getstate__(self): - data = faiss.serialize_index(self.faiss_index) - return data + return faiss.serialize_index(self.faiss_index) def __setstate__(self, data): self.faiss_index = faiss.deserialize_index(data) - - def _convert_pdq_strings_to_ndarray( - self, pdq_strings: t.Sequence[str] - ) -> np.ndarray: - """ - Convert multiple PDQ hash strings to a numpy array. 
- Args: - pdq_strings (Sequence[str]): A sequence of 64-character hexadecimal PDQ hash strings - Returns: - np.ndarray: A 2D array of shape (n_queries, 256) where each row is the full PDQ hash as a bit array - """ - hash_arrays = [] - for pdq_str in pdq_strings: - print("converting string:", pdq_str) - try: - # Convert hex string to integer - hash_int = int(pdq_str, 16) - # Convert to binary string, padding to ensure 256 bits - binary_str = format(hash_int, "0256b") - # Convert to numpy array - hash_array = np.array( - [int(bit) for bit in binary_str], dtype=np.float32 - ) - hash_arrays.append(hash_array) - except (ValueError, TypeError) as e: - raise ValueError(f"Invalid PDQ hash string: {pdq_str}") from e - - # Convert list of arrays to a single 2D array - return np.array(hash_arrays, dtype=np.float32) diff --git a/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py b/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py index c7a3ae351..3dd8cc42d 100644 --- a/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py +++ b/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py @@ -1,6 +1,9 @@ #!/usr/bin/env python # Copyright (c) Meta Platforms, Inc. and affiliates. +import numpy as np +import typing as t + BITS_IN_PDQ = 256 PDQ_HEX_STR_LEN = int(BITS_IN_PDQ / 4) @@ -49,3 +52,19 @@ def pdq_match(pdq_hex_a: str, pdq_hex_b: str, threshold: int) -> bool: """ distance = simple_distance(pdq_hex_a, pdq_hex_b) return distance <= threshold + + +def convert_pdq_strings_to_ndarray(pdq_strings: t.Sequence[str]) -> np.ndarray: + """ + Convert multiple PDQ hash strings to a numpy array. + """ + if not all(len(pdq_str) == PDQ_HEX_STR_LEN for pdq_str in pdq_strings): + raise ValueError("All PDQ hash strings must be 64 hex characters long") + + binary_arrays = [] + for pdq_str in pdq_strings: + hash_bytes = bytes.fromhex(pdq_str) + binary_array = np.unpackbits(np.frombuffer(hash_bytes, dtype=np.uint8)) + binary_arrays.append(binary_array) + + return np.array(binary_arrays, dtype=np.uint8) diff --git a/python-threatexchange/threatexchange/signal_type/pdq/signal.py b/python-threatexchange/threatexchange/signal_type/pdq/signal.py index 3c0aa0dce..e6684d588 100644 --- a/python-threatexchange/threatexchange/signal_type/pdq/signal.py +++ b/python-threatexchange/threatexchange/signal_type/pdq/signal.py @@ -18,6 +18,8 @@ ) from threatexchange.signal_type.pdq.pdq_index import PDQIndex +PDQ_CONFIDENT_MATCH_THRESHOLD = 31 + class PdqSignal( signal_base.SimpleSignalType, @@ -43,7 +45,7 @@ class PdqSignal( # This may need to be updated (TODO make more configurable) # Hashes of distance less than or equal to this threshold are considered a 'match' - PDQ_CONFIDENT_MATCH_THRESHOLD = 31 + PDQ_CONFIDENT_MATCH_THRESHOLD = PDQ_CONFIDENT_MATCH_THRESHOLD # Images with less than quality 50 are too unreliable to match on QUALITY_THRESHOLD = 50 diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py index 7570c7492..6c75f1a0e 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py @@ -1,118 +1,35 @@ import pytest import io import faiss -from threatexchange.signal_type.pdq.pdq_index2 import ( - PDQIndex2, - _PDQHashIndex, - DIMENSIONALITY, - DEFAULT_MATCH_DIST, -) - - -@pytest.fixture -def empty_index(): - """Fixture for an empty index.""" - return PDQIndex2[str]() - - 
-@pytest.fixture -def custom_index_with_threshold(): - """Fixture for an index with custom index and threshold.""" - custom_index = faiss.IndexFlatL2(DIMENSIONALITY + 1) - custom_threshold = DEFAULT_MATCH_DIST + 1 - return PDQIndex2[str](index=custom_index, threshold=custom_threshold) - - -@pytest.fixture -def sample_index(): - """Fixture for an index with a small sample set.""" - pdq_hashes = [ - "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", - "f" * 64, - "0" * 64, - "a" * 64, - ] - index = PDQIndex2[str](entries=[(h, pdq_hashes.index(h)) for h in pdq_hashes]) - return index, pdq_hashes - - -@pytest.fixture -def sample_index_with_one_entry(): - """Fixture for an index with a small sample set.""" - pdq_hashes = [ - "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", - "f" * 64, - "0" * 64, - "a" * 64, - ] - index = PDQIndex2[str]( - entries=[ - ("f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", 0) - ] - ) - return index, pdq_hashes - - -def test_init(empty_index) -> None: - assert empty_index.threshold == DEFAULT_MATCH_DIST - assert isinstance(empty_index.index, _PDQHashIndex) - assert isinstance(empty_index.index.faiss_index, faiss.IndexFlatL2) - assert empty_index.index.faiss_index.d == DIMENSIONALITY - assert empty_index._deduper == dict() - assert empty_index._idx_to_entries == [] - - -def test_serialize_deserialize(empty_index) -> None: - buffer = io.BytesIO() - empty_index.serialize(buffer) - buffer.seek(0) - deserialized_index: PDQIndex2[str] = PDQIndex2.deserialize(buffer) - - assert isinstance(deserialized_index, PDQIndex2) - assert deserialized_index.threshold == empty_index.threshold - assert isinstance(deserialized_index.index, _PDQHashIndex) - assert isinstance(deserialized_index.index.faiss_index, faiss.IndexFlatL2) - assert deserialized_index.index.faiss_index.d == DIMENSIONALITY - assert deserialized_index._deduper == empty_index._deduper - assert deserialized_index._idx_to_entries == empty_index._idx_to_entries - - -def test_serialize_deserialize_with_custom_index_threshold( - custom_index_with_threshold, -) -> None: - buffer = io.BytesIO() - custom_index_with_threshold.serialize(buffer) - buffer.seek(0) - deserialized_index: PDQIndex2[str] = PDQIndex2.deserialize(buffer) - - assert isinstance(deserialized_index, PDQIndex2) - assert deserialized_index.threshold == custom_index_with_threshold.threshold - assert isinstance(deserialized_index.index, _PDQHashIndex) - assert isinstance(deserialized_index.index.faiss_index, faiss.IndexFlatL2) - assert deserialized_index.index.faiss_index.d == DIMENSIONALITY + 1 - assert deserialized_index._deduper == custom_index_with_threshold._deduper - assert ( - deserialized_index._idx_to_entries - == custom_index_with_threshold._idx_to_entries - ) +from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2, _PDQFaissIndex +from threatexchange.signal_type.pdq.signal import PDQ_CONFIDENT_MATCH_THRESHOLD +from threatexchange.signal_type.pdq.pdq_utils import BITS_IN_PDQ + +SAMPLE_HASH = "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22" + +SAMPLE_HASHES = [ + SAMPLE_HASH, + "f" * 64, + "0" * 64, + "a" * 64, +] -def test_empty_index_query(empty_index): +def test_empty_index_query(): """Test querying an empty index.""" - query_hash = "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22" + index = PDQIndex2() # Query should return empty list - results = empty_index.query(query_hash) + results = index.query(SAMPLE_HASH) assert len(results) == 0 -def 
test_sample_set_exact_match(sample_index): +def test_sample_set_exact_match(): """Test exact matches in sample set.""" - index, pdq_hashes = sample_index + index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES]) # Query with existing hash - query_hash = pdq_hashes[0] - results = index.query(query_hash) + results = index.query(SAMPLE_HASH) assert len(results) == 1 assert ( @@ -120,40 +37,35 @@ def test_sample_set_exact_match(sample_index): ) # Exact match should have distance 0 -def test_sample_set_no_match(sample_index): +def test_sample_set_no_match(): """Test no matches in sample set.""" - index, _ = sample_index + index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES]) results = index.query("b" * 64) assert len(results) == 0 -def test_sample_set_near_match(sample_index): +def test_sample_set_near_match(): """Test near matches in sample set.""" - index, pdq_hashes = sample_index - + index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES]) # Create a near-match by flipping a few bits - base_hash = pdq_hashes[0] - near_hash = hex(int(base_hash, 16) ^ 0xF)[2:].zfill(64) # Flip 4 bits + near_hash = hex(int(SAMPLE_HASH, 16) ^ 0xF)[2:].zfill(64) results = index.query(near_hash) assert len(results) > 0 # Should find near matches assert results[0].similarity_info.distance > 0 -def test_sample_set_threshold(sample_index): - """Test distance threshold behavior.""" - _, pdq_hashes = sample_index - +def test_sample_set_threshold(): + """Verify that the sample set respects the specified distance threshold.""" narrow_index = PDQIndex2[str](threshold=10) # Strict matching wide_index = PDQIndex2[str](threshold=50) # Loose matching - for hash_str in pdq_hashes: + for hash_str in SAMPLE_HASHES: narrow_index.add(hash_str, hash_str) wide_index.add(hash_str, hash_str) # Create a test hash with known distance - base_hash = pdq_hashes[0] - test_hash = hex(int(base_hash, 16) ^ ((1 << 20) - 1))[2:].zfill( + test_hash = hex(int(SAMPLE_HASH, 16) ^ ((1 << 20) - 1))[2:].zfill( 64 ) # ~20 bits different @@ -163,15 +75,14 @@ def test_sample_set_threshold(sample_index): assert len(wide_results) > len(narrow_results) # Wide threshold should match more -def test_duplicate_handling(sample_index): +def test_duplicate_handling(): """Test how the index handles duplicate entries.""" - index, pdq_hashes = sample_index + index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES]) # Add same hash multiple times - test_hash = pdq_hashes[0] - index.add_all(entries=[(test_hash, i) for i in range(3)]) + index.add_all(entries=[(SAMPLE_HASH, i) for i in range(3)]) - results = index.query(test_hash) + results = index.query(SAMPLE_HASH) # Should find all entries associated with the hash assert len(results) == 4 @@ -179,12 +90,12 @@ def test_duplicate_handling(sample_index): assert result.similarity_info.distance == 0 -def test_one_entry_sample_index(sample_index_with_one_entry): +def test_one_entry_sample_index(): """Test how the index handles when it only has one entry.""" - index, pdq_hashes = sample_index_with_one_entry + index = PDQIndex2(entries=[(SAMPLE_HASH, 0)]) - matching_test_hash = pdq_hashes[0] # This is the existing hash in index - unmatching_test_hash = pdq_hashes[1] + matching_test_hash = SAMPLE_HASHES[0] # This is the existing hash in index + unmatching_test_hash = SAMPLE_HASHES[1] results = index.query(matching_test_hash) # Should find 1 entry associated with the hash From 3cf072bd19d9e87af783d8bc4af7b239aec4614f Mon Sep 17 
00:00:00 2001 From: haianhng31 Date: Mon, 18 Nov 2024 19:54:53 +0000 Subject: [PATCH 3/8] pytx - add unittest for pdq index2 --- .../signal_type/tests/test_pdq_index2.py | 50 ++++++++++++++++--- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py index 6c75f1a0e..5e90adfcd 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py @@ -1,9 +1,8 @@ -import pytest -import io -import faiss -from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2, _PDQFaissIndex -from threatexchange.signal_type.pdq.signal import PDQ_CONFIDENT_MATCH_THRESHOLD -from threatexchange.signal_type.pdq.pdq_utils import BITS_IN_PDQ +from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2 +from threatexchange.signal_type.pdq.signal import PdqSignal +import typing as t +import numpy as np +from threatexchange.signal_type.pdq.pdq_utils import convert_pdq_strings_to_ndarray SAMPLE_HASH = "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22" @@ -15,6 +14,45 @@ ] +def test_pdq_index(): + common_hashes = [PdqSignal.get_random_signal() for _ in range(100)] # Make sure they have at least 100 similar hashes + base_hashes = common_hashes + [PdqSignal.get_random_signal() for _ in range(1000)] + query_hashes = common_hashes + [PdqSignal.get_random_signal() for _ in range(10000)] + + def brute_force_match( + base: t.List[str], query: str, threshold: int = 32 + ) -> t.Set[int]: + matches = set() + query_arr = convert_pdq_strings_to_ndarray([query])[0] + + for i, base_hash in enumerate(base): + base_arr = convert_pdq_strings_to_ndarray([base_hash])[0] + distance = np.count_nonzero(query_arr != base_arr) + if distance <= threshold: + matches.add(i) + return matches + + brute_force_matches = { + query_hash: brute_force_match(base_hashes, query_hash) + for query_hash in query_hashes + } + + index = PDQIndex2() + for i, base_hash in enumerate(base_hashes): + index.add(base_hash, i) + + for query_hash in query_hashes: + expected_indices = brute_force_matches[query_hash] + index_results = index.query(query_hash) + + result_indices = {result.metadata for result in index_results} + + assert result_indices == expected_indices, ( + f"Mismatch for hash {query_hash}: " + f"Expected {expected_indices}, Got {result_indices}" + ) + + def test_empty_index_query(): """Test querying an empty index.""" index = PDQIndex2() From cd44cbded112bfb402f3df1a9a1bf0aae296bb33 Mon Sep 17 00:00:00 2001 From: haianhng31 Date: Mon, 18 Nov 2024 19:57:03 +0000 Subject: [PATCH 4/8] pytx fix lint --- .../threatexchange/signal_type/tests/test_pdq_index2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py index 5e90adfcd..2958d26c8 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py @@ -15,7 +15,9 @@ def test_pdq_index(): - common_hashes = [PdqSignal.get_random_signal() for _ in range(100)] # Make sure they have at least 100 similar hashes + common_hashes = [ + PdqSignal.get_random_signal() for _ in range(100) + ] # Make sure they have at least 100 similar hashes base_hashes = common_hashes + 
[PdqSignal.get_random_signal() for _ in range(1000)] query_hashes = common_hashes + [PdqSignal.get_random_signal() for _ in range(10000)] From a5bfb6d2102b6b1b6ad3e166e97c9a0c84a740d0 Mon Sep 17 00:00:00 2001 From: haianhng31 Date: Tue, 19 Nov 2024 19:57:11 +0000 Subject: [PATCH 5/8] pytx - add test case for pdq index2 --- .../signal_type/tests/test_pdq_index2.py | 144 +++++++++--------- 1 file changed, 68 insertions(+), 76 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py index 2958d26c8..eef380b7d 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py @@ -1,41 +1,53 @@ -from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2 -from threatexchange.signal_type.pdq.signal import PdqSignal import typing as t import numpy as np +import random + +from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2 +from threatexchange.signal_type.pdq.signal import PdqSignal from threatexchange.signal_type.pdq.pdq_utils import convert_pdq_strings_to_ndarray -SAMPLE_HASH = "f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22" +SAMPLE_HASHES = [PdqSignal.get_random_signal() for _ in range(100)] + + +def _brute_force_match( + base: t.List[str], query: str, threshold: int = 32 +) -> t.Set[int]: + matches = set() + query_arr = convert_pdq_strings_to_ndarray([query])[0] + + for i, base_hash in enumerate(base): + base_arr = convert_pdq_strings_to_ndarray([base_hash])[0] + distance = np.count_nonzero(query_arr != base_arr) + if distance <= threshold: + matches.add(i) + return matches + -SAMPLE_HASHES = [ - SAMPLE_HASH, - "f" * 64, - "0" * 64, - "a" * 64, -] +def _generate_random_hash_with_distance(hash: str, distance: int) -> str: + if len(hash) != 64 or not all(c in "0123456789abcdef" for c in hash.lower()): + raise ValueError("Hash must be a 64-character hexadecimal string") + if distance < 0 or distance > 256: + raise ValueError("Distance must be between 0 and 256") + + hash_bits = bin(int(hash, 16))[2:].zfill(256) # Convert hash to binary + bits = list(hash_bits) + positions = random.sample( + range(256), distance + ) # Randomly select unique positions to flip + for pos in positions: + bits[pos] = "0" if bits[pos] == "1" else "1" # Flip selected bit positions + modified_hash = hex(int("".join(bits), 2))[2:].zfill(64) # Convert back to hex + + return modified_hash def test_pdq_index(): - common_hashes = [ - PdqSignal.get_random_signal() for _ in range(100) - ] # Make sure they have at least 100 similar hashes - base_hashes = common_hashes + [PdqSignal.get_random_signal() for _ in range(1000)] - query_hashes = common_hashes + [PdqSignal.get_random_signal() for _ in range(10000)] - - def brute_force_match( - base: t.List[str], query: str, threshold: int = 32 - ) -> t.Set[int]: - matches = set() - query_arr = convert_pdq_strings_to_ndarray([query])[0] - - for i, base_hash in enumerate(base): - base_arr = convert_pdq_strings_to_ndarray([base_hash])[0] - distance = np.count_nonzero(query_arr != base_arr) - if distance <= threshold: - matches.add(i) - return matches + # Make sure base_hashes and query_hashes have at least 100 similar hashes + base_hashes = SAMPLE_HASHES + [PdqSignal.get_random_signal() for _ in range(1000)] + query_hashes = SAMPLE_HASHES + [PdqSignal.get_random_signal() for _ in range(10000)] brute_force_matches = { - query_hash: 
brute_force_match(base_hashes, query_hash) + query_hash: _brute_force_match(base_hashes, query_hash) for query_hash in query_hashes } @@ -55,28 +67,39 @@ def brute_force_match( ) +def test_pdq_index_with_exact_distance(): + thresholds: t.List[int] = [10, 31, 50] + indexes: t.List[PDQIndex2] = [] + for thres in thresholds: + index = PDQIndex2( + entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES], + threshold=thres, + ) + indexes.append(index) + + distances: t.List[int] = [0, 1, 20, 30, 31, 60] + query_hash = SAMPLE_HASHES[0] + + for i in range(len(indexes)): + index = indexes[i] + + for dist in distances: + query_hash_w_dist = _generate_random_hash_with_distance(query_hash, dist) + results = index.query(query_hash_w_dist) + result_indices = {result.similarity_info.distance for result in results} + if dist <= thresholds[i]: + assert dist in result_indices + + def test_empty_index_query(): """Test querying an empty index.""" index = PDQIndex2() # Query should return empty list - results = index.query(SAMPLE_HASH) + results = index.query(PdqSignal.get_random_signal()) assert len(results) == 0 -def test_sample_set_exact_match(): - """Test exact matches in sample set.""" - index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES]) - - # Query with existing hash - results = index.query(SAMPLE_HASH) - - assert len(results) == 1 - assert ( - results[0].similarity_info.distance == 0 - ) # Exact match should have distance 0 - - def test_sample_set_no_match(): """Test no matches in sample set.""" index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES]) @@ -84,45 +107,14 @@ def test_sample_set_no_match(): assert len(results) == 0 -def test_sample_set_near_match(): - """Test near matches in sample set.""" - index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES]) - # Create a near-match by flipping a few bits - near_hash = hex(int(SAMPLE_HASH, 16) ^ 0xF)[2:].zfill(64) - - results = index.query(near_hash) - assert len(results) > 0 # Should find near matches - assert results[0].similarity_info.distance > 0 - - -def test_sample_set_threshold(): - """Verify that the sample set respects the specified distance threshold.""" - narrow_index = PDQIndex2[str](threshold=10) # Strict matching - wide_index = PDQIndex2[str](threshold=50) # Loose matching - - for hash_str in SAMPLE_HASHES: - narrow_index.add(hash_str, hash_str) - wide_index.add(hash_str, hash_str) - - # Create a test hash with known distance - test_hash = hex(int(SAMPLE_HASH, 16) ^ ((1 << 20) - 1))[2:].zfill( - 64 - ) # ~20 bits different - - narrow_results = narrow_index.query(test_hash) - wide_results = wide_index.query(test_hash) - - assert len(wide_results) > len(narrow_results) # Wide threshold should match more - - def test_duplicate_handling(): """Test how the index handles duplicate entries.""" index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES]) # Add same hash multiple times - index.add_all(entries=[(SAMPLE_HASH, i) for i in range(3)]) + index.add_all(entries=[(SAMPLE_HASHES[0], i) for i in range(3)]) - results = index.query(SAMPLE_HASH) + results = index.query(SAMPLE_HASHES[0]) # Should find all entries associated with the hash assert len(results) == 4 @@ -132,7 +124,7 @@ def test_duplicate_handling(): def test_one_entry_sample_index(): """Test how the index handles when it only has one entry.""" - index = PDQIndex2(entries=[(SAMPLE_HASH, 0)]) + index = PDQIndex2(entries=[(SAMPLE_HASHES[0], 0)]) matching_test_hash = SAMPLE_HASHES[0] # 
This is the existing hash in index unmatching_test_hash = SAMPLE_HASHES[1] From 0a32d650d1fa63c7ae3845422f7953711e7802ae Mon Sep 17 00:00:00 2001 From: haianhng31 Date: Wed, 20 Nov 2024 20:58:11 +0000 Subject: [PATCH 6/8] pytx continue edit pdq index2 --- .../signal_type/pdq/pdq_index2.py | 19 +++++---- .../signal_type/pdq/pdq_utils.py | 1 + .../threatexchange/signal_type/pdq/signal.py | 7 ++-- .../signal_type/tests/test_pdq_index2.py | 42 +++++++++++-------- 4 files changed, 40 insertions(+), 29 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py b/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py index 8f0eb40c4..312be505f 100644 --- a/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py +++ b/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py @@ -1,8 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. """ -Implementation of SignalTypeIndex abstraction for PDQ by wrapping -hashing.pdq_faiss_matcher. +Implementation of SignalTypeIndex abstraction for PDQ """ import typing as t @@ -16,9 +15,9 @@ SignalTypeIndex, T as IndexT, ) -from threatexchange.signal_type.pdq.signal import PDQ_CONFIDENT_MATCH_THRESHOLD from threatexchange.signal_type.pdq.pdq_utils import ( BITS_IN_PDQ, + PDQ_CONFIDENT_MATCH_THRESHOLD, convert_pdq_strings_to_ndarray, ) @@ -28,6 +27,10 @@ class PDQIndex2(SignalTypeIndex[IndexT]): """ Indexing and querying PDQ signals using Faiss for approximate nearest neighbor search. + + This is a redo of the existing PDQ index, + designed to be simpler and fix hard-to-squash bugs in the existing implementation. + Purpose of this class: to replace the original index in pytx 2.0 """ def __init__( @@ -42,7 +45,7 @@ def __init__( if index is None: index = faiss.IndexFlatL2(BITS_IN_PDQ) - self.index = _PDQFaissIndex(index) + self._index = _PDQFaissIndex(index) # Matches hash to Faiss index self._deduper: t.Dict[str, int] = {} @@ -56,10 +59,10 @@ def __len__(self) -> int: def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]: """ - Look up entries against the index, up to the max supported distance. + Look up entries against the index, up to the threshold. """ results: t.List[PDQIndexMatch[IndexT]] = [] - matches_list: t.List[t.Tuple[int, int]] = self.index.search( + matches_list: t.List[t.Tuple[int, int]] = self._index.search( queries=[hash], threshold=self.threshold ) @@ -82,7 +85,7 @@ def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None: for h, i in entries: existing_faiss_id = self._deduper.get(h) if existing_faiss_id is None: - self.index.add([h]) + self._index.add([h]) self._idx_to_entries.append([i]) next_id = len(self._deduper) # Because faiss index starts from 0 up self._deduper[h] = next_id @@ -107,7 +110,7 @@ def add(self, pdq_strings: t.Sequence[str]) -> None: self.faiss_index.add(vectors) def search( - self, queries: t.Sequence[str], threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD + self, queries: t.Sequence[str], threshold: int ) -> t.List[t.Tuple[int, int]]: """ Search the FAISS index for matches to the given PDQ queries. 
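
(Illustrative aside, not part of the patch itself: a minimal usage sketch of the PDQIndex2 API as it stands at this point in the series — construction from (hash, entry) pairs, querying, and the serialize/deserialize round-trip exercised in the tests. The hash values below are randomly generated for demonstration only.)

    import io

    from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2
    from threatexchange.signal_type.pdq.signal import PdqSignal

    # Index three random PDQ hashes, using each hash's position as its entry metadata.
    hashes = [PdqSignal.get_random_signal() for _ in range(3)]
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(hashes)])

    # Querying with an indexed hash returns PDQIndexMatch objects carrying the
    # stored metadata and the integer distance to the query.
    for match in index.query(hashes[0]):
        print(match.metadata, match.similarity_info.distance)

    # The index round-trips through the pickle-based serialization used in the tests.
    buf = io.BytesIO()
    index.serialize(buf)
    buf.seek(0)
    restored = PDQIndex2.deserialize(buf)
    assert len(restored) == len(index)
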
diff --git a/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py b/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py index 3dd8cc42d..f9ff126c9 100644 --- a/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py +++ b/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py @@ -6,6 +6,7 @@ BITS_IN_PDQ = 256 PDQ_HEX_STR_LEN = int(BITS_IN_PDQ / 4) +PDQ_CONFIDENT_MATCH_THRESHOLD = 31 def simple_distance_binary(bin_a: str, bin_b: str) -> int: diff --git a/python-threatexchange/threatexchange/signal_type/pdq/signal.py b/python-threatexchange/threatexchange/signal_type/pdq/signal.py index e6684d588..35617c4ac 100644 --- a/python-threatexchange/threatexchange/signal_type/pdq/signal.py +++ b/python-threatexchange/threatexchange/signal_type/pdq/signal.py @@ -12,14 +12,15 @@ from threatexchange.content_type.content_base import ContentType from threatexchange.content_type.photo import PhotoContent from threatexchange.signal_type import signal_base -from threatexchange.signal_type.pdq.pdq_utils import simple_distance +from threatexchange.signal_type.pdq.pdq_utils import ( + simple_distance, + PDQ_CONFIDENT_MATCH_THRESHOLD, +) from threatexchange.exchanges.impl.fb_threatexchange_signal import ( HasFbThreatExchangeIndicatorType, ) from threatexchange.signal_type.pdq.pdq_index import PDQIndex -PDQ_CONFIDENT_MATCH_THRESHOLD = 31 - class PdqSignal( signal_base.SimpleSignalType, diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py index eef380b7d..863ee34cc 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py @@ -4,29 +4,31 @@ from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2 from threatexchange.signal_type.pdq.signal import PdqSignal -from threatexchange.signal_type.pdq.pdq_utils import convert_pdq_strings_to_ndarray +from threatexchange.signal_type.pdq.pdq_utils import simple_distance -SAMPLE_HASHES = [PdqSignal.get_random_signal() for _ in range(100)] + +def _generate_sample_hashes(size: int, seed: int = 42): + random.seed(seed) + return [PdqSignal.get_random_signal() for _ in range(size)] + + +SAMPLE_HASHES = _generate_sample_hashes(100) def _brute_force_match( base: t.List[str], query: str, threshold: int = 32 -) -> t.Set[int]: +) -> t.Set[t.Tuple[int, int]]: matches = set() - query_arr = convert_pdq_strings_to_ndarray([query])[0] for i, base_hash in enumerate(base): - base_arr = convert_pdq_strings_to_ndarray([base_hash])[0] - distance = np.count_nonzero(query_arr != base_arr) + distance = simple_distance(base_hash, query) if distance <= threshold: - matches.add(i) + matches.add((i, distance)) return matches def _generate_random_hash_with_distance(hash: str, distance: int) -> str: - if len(hash) != 64 or not all(c in "0123456789abcdef" for c in hash.lower()): - raise ValueError("Hash must be a 64-character hexadecimal string") - if distance < 0 or distance > 256: + if not (0 <= distance <= 256): raise ValueError("Distance must be between 0 and 256") hash_bits = bin(int(hash, 16))[2:].zfill(256) # Convert hash to binary @@ -42,9 +44,9 @@ def _generate_random_hash_with_distance(hash: str, distance: int) -> str: def test_pdq_index(): - # Make sure base_hashes and query_hashes have at least 100 similar hashes - base_hashes = SAMPLE_HASHES + [PdqSignal.get_random_signal() for _ in range(1000)] - query_hashes = SAMPLE_HASHES 
+ [PdqSignal.get_random_signal() for _ in range(10000)] + # Make sure base_hashes and query_hashes have at least 10 similar hashes + base_hashes = SAMPLE_HASHES + query_hashes = SAMPLE_HASHES[:10] + _generate_sample_hashes(10) brute_force_matches = { query_hash: _brute_force_match(base_hashes, query_hash) @@ -59,7 +61,10 @@ def test_pdq_index(): expected_indices = brute_force_matches[query_hash] index_results = index.query(query_hash) - result_indices = {result.metadata for result in index_results} + result_indices: t.Set[t.Tuple[t.Any, int]] = { + (result.metadata, result.similarity_info.distance) + for result in index_results + } assert result_indices == expected_indices, ( f"Mismatch for hash {query_hash}: " @@ -69,13 +74,14 @@ def test_pdq_index(): def test_pdq_index_with_exact_distance(): thresholds: t.List[int] = [10, 31, 50] - indexes: t.List[PDQIndex2] = [] - for thres in thresholds: - index = PDQIndex2( + + indexes = [ + PDQIndex2( entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES], threshold=thres, ) - indexes.append(index) + for thres in thresholds + ] distances: t.List[int] = [0, 1, 20, 30, 31, 60] query_hash = SAMPLE_HASHES[0] From 610af2ee3f86ac8e407b78e25e65034b123aef77 Mon Sep 17 00:00:00 2001 From: haianhng31 Date: Wed, 20 Nov 2024 21:10:46 +0000 Subject: [PATCH 7/8] pytx resolve from nit comments for pdq index2 --- .../threatexchange/signal_type/pdq/pdq_utils.py | 8 ++++---- .../threatexchange/signal_type/pdq/signal.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py b/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py index f9ff126c9..67b518d2a 100644 --- a/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py +++ b/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py @@ -6,6 +6,7 @@ BITS_IN_PDQ = 256 PDQ_HEX_STR_LEN = int(BITS_IN_PDQ / 4) +# Hashes of distance less than or equal to this threshold are considered a 'match' PDQ_CONFIDENT_MATCH_THRESHOLD = 31 @@ -55,15 +56,14 @@ def pdq_match(pdq_hex_a: str, pdq_hex_b: str, threshold: int) -> bool: return distance <= threshold -def convert_pdq_strings_to_ndarray(pdq_strings: t.Sequence[str]) -> np.ndarray: +def convert_pdq_strings_to_ndarray(pdq_strings: t.Iterable[str]) -> np.ndarray: """ Convert multiple PDQ hash strings to a numpy array. 
""" - if not all(len(pdq_str) == PDQ_HEX_STR_LEN for pdq_str in pdq_strings): - raise ValueError("All PDQ hash strings must be 64 hex characters long") - binary_arrays = [] for pdq_str in pdq_strings: + if len(pdq_str) != PDQ_HEX_STR_LEN: + raise ValueError("PDQ hash string must be 64 hex characters long") hash_bytes = bytes.fromhex(pdq_str) binary_array = np.unpackbits(np.frombuffer(hash_bytes, dtype=np.uint8)) binary_arrays.append(binary_array) diff --git a/python-threatexchange/threatexchange/signal_type/pdq/signal.py b/python-threatexchange/threatexchange/signal_type/pdq/signal.py index 35617c4ac..64ef69cfc 100644 --- a/python-threatexchange/threatexchange/signal_type/pdq/signal.py +++ b/python-threatexchange/threatexchange/signal_type/pdq/signal.py @@ -45,8 +45,6 @@ class PdqSignal( INDICATOR_TYPE = "HASH_PDQ" # This may need to be updated (TODO make more configurable) - # Hashes of distance less than or equal to this threshold are considered a 'match' - PDQ_CONFIDENT_MATCH_THRESHOLD = PDQ_CONFIDENT_MATCH_THRESHOLD # Images with less than quality 50 are too unreliable to match on QUALITY_THRESHOLD = 50 From 2ea86a18985cf8074722531c7f77c4d2ebe9a35e Mon Sep 17 00:00:00 2001 From: haianhng31 Date: Thu, 21 Nov 2024 21:11:48 +0000 Subject: [PATCH 8/8] pytx - edit pdq index2 --- .../signal_type/pdq/pdq_index2.py | 2 +- .../signal_type/tests/test_pdq_index2.py | 69 ++++++++++++++----- 2 files changed, 51 insertions(+), 20 deletions(-) diff --git a/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py b/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py index 312be505f..b1e80dbb6 100644 --- a/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py +++ b/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py @@ -71,7 +71,7 @@ def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]: # Create match objects for each entry results.extend( PDQIndexMatch( - SignalSimilarityInfoWithIntDistance(distance=int(distance)), + SignalSimilarityInfoWithIntDistance(distance=distance), entry, ) for entry in entries diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py index 863ee34cc..b06dd8b1a 100644 --- a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py +++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py @@ -1,18 +1,20 @@ import typing as t -import numpy as np import random +import io +import faiss from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2 from threatexchange.signal_type.pdq.signal import PdqSignal from threatexchange.signal_type.pdq.pdq_utils import simple_distance -def _generate_sample_hashes(size: int, seed: int = 42): +def _get_hash_generator(seed: int = 42): random.seed(seed) - return [PdqSignal.get_random_signal() for _ in range(size)] + def get_n_hashes(n: int): + return [PdqSignal.get_random_signal() for _ in range(n)] -SAMPLE_HASHES = _generate_sample_hashes(100) + return get_n_hashes def _brute_force_match( @@ -44,18 +46,17 @@ def _generate_random_hash_with_distance(hash: str, distance: int) -> str: def test_pdq_index(): + get_random_hashes = _get_hash_generator() + base_hashes = get_random_hashes(100) # Make sure base_hashes and query_hashes have at least 10 similar hashes - base_hashes = SAMPLE_HASHES - query_hashes = SAMPLE_HASHES[:10] + _generate_sample_hashes(10) + query_hashes = base_hashes[:10] + get_random_hashes(1000) brute_force_matches = { 
query_hash: _brute_force_match(base_hashes, query_hash) for query_hash in query_hashes } - index = PDQIndex2() - for i, base_hash in enumerate(base_hashes): - index.add(base_hash, i) + index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes]) for query_hash in query_hashes: expected_indices = brute_force_matches[query_hash] @@ -73,18 +74,21 @@ def test_pdq_index(): def test_pdq_index_with_exact_distance(): + get_random_hashes = _get_hash_generator() + base_hashes = get_random_hashes(100) + thresholds: t.List[int] = [10, 31, 50] indexes = [ PDQIndex2( - entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES], + entries=[(h, base_hashes.index(h)) for h in base_hashes], threshold=thres, ) for thres in thresholds ] distances: t.List[int] = [0, 1, 20, 30, 31, 60] - query_hash = SAMPLE_HASHES[0] + query_hash = base_hashes[0] for i in range(len(indexes)): index = indexes[i] @@ -97,6 +101,23 @@ def test_pdq_index_with_exact_distance(): assert dist in result_indices +def test_serialize_deserialize_index(): + get_random_hashes = _get_hash_generator() + base_hashes = get_random_hashes(100) + index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes]) + + buffer = io.BytesIO() + index.serialize(buffer) + buffer.seek(0) + deserialized_index = PDQIndex2.deserialize(buffer) + + assert isinstance(deserialized_index, PDQIndex2) + assert isinstance(deserialized_index._index.faiss_index, faiss.IndexFlatL2) + assert deserialized_index.threshold == index.threshold + assert deserialized_index._deduper == index._deduper + assert deserialized_index._idx_to_entries == index._idx_to_entries + + def test_empty_index_query(): """Test querying an empty index.""" index = PDQIndex2() @@ -108,19 +129,23 @@ def test_empty_index_query(): def test_sample_set_no_match(): """Test no matches in sample set.""" - index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES]) + get_random_hashes = _get_hash_generator() + base_hashes = get_random_hashes(100) + index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes]) results = index.query("b" * 64) assert len(results) == 0 def test_duplicate_handling(): """Test how the index handles duplicate entries.""" - index = PDQIndex2(entries=[(h, SAMPLE_HASHES.index(h)) for h in SAMPLE_HASHES]) + get_random_hashes = _get_hash_generator() + base_hashes = get_random_hashes(100) + index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes]) # Add same hash multiple times - index.add_all(entries=[(SAMPLE_HASHES[0], i) for i in range(3)]) + index.add_all(entries=[(base_hashes[0], i) for i in range(3)]) - results = index.query(SAMPLE_HASHES[0]) + results = index.query(base_hashes[0]) # Should find all entries associated with the hash assert len(results) == 4 @@ -129,11 +154,17 @@ def test_duplicate_handling(): def test_one_entry_sample_index(): - """Test how the index handles when it only has one entry.""" - index = PDQIndex2(entries=[(SAMPLE_HASHES[0], 0)]) + """ + Test how the index handles when it only has one entry. 
+ + See issue github.com/facebook/ThreatExchange/issues/1318 + """ + get_random_hashes = _get_hash_generator() + base_hashes = get_random_hashes(100) + index = PDQIndex2(entries=[(base_hashes[0], 0)]) - matching_test_hash = SAMPLE_HASHES[0] # This is the existing hash in index - unmatching_test_hash = SAMPLE_HASHES[1] + matching_test_hash = base_hashes[0] # This is the existing hash in index + unmatching_test_hash = base_hashes[1] results = index.query(matching_test_hash) # Should find 1 entry associated with the hash