diff --git a/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py b/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py
new file mode 100644
index 000000000..b1e80dbb6
--- /dev/null
+++ b/python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py
@@ -0,0 +1,135 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+"""
+Implementation of SignalTypeIndex abstraction for PDQ
+"""
+
+import typing as t
+import faiss
+import numpy as np
+
+
+from threatexchange.signal_type.index import (
+    IndexMatchUntyped,
+    SignalSimilarityInfoWithIntDistance,
+    SignalTypeIndex,
+    T as IndexT,
+)
+from threatexchange.signal_type.pdq.pdq_utils import (
+    BITS_IN_PDQ,
+    PDQ_CONFIDENT_MATCH_THRESHOLD,
+    convert_pdq_strings_to_ndarray,
+)
+
+PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT]
+
+
+class PDQIndex2(SignalTypeIndex[IndexT]):
+    """
+    Indexing and querying PDQ signals using Faiss for approximate nearest neighbor search.
+
+    This is a redo of the existing PDQ index,
+    designed to be simpler and fix hard-to-squash bugs in the existing implementation.
+    Purpose of this class: to replace the original index in pytx 2.0
+    """
+
+    def __init__(
+        self,
+        index: t.Optional[faiss.Index] = None,
+        entries: t.Iterable[t.Tuple[str, IndexT]] = (),
+        *,
+        threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD,
+    ) -> None:
+        super().__init__()
+        self.threshold = threshold
+
+        if index is None:
+            index = faiss.IndexFlatL2(BITS_IN_PDQ)
+        self._index = _PDQFaissIndex(index)
+
+        # Matches hash to Faiss index
+        self._deduper: t.Dict[str, int] = {}
+        # Entry mapping: Each list[entries]'s index is its hash's index
+        self._idx_to_entries: t.List[t.List[IndexT]] = []
+
+        self.add_all(entries=entries)
+
+    def __len__(self) -> int:
+        return len(self._idx_to_entries)
+
+    def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]:
+        """
+        Look up entries against the index, up to the threshold.
+        """
+        results: t.List[PDQIndexMatch[IndexT]] = []
+        matches_list: t.List[t.Tuple[int, int]] = self._index.search(
+            queries=[hash], threshold=self.threshold
+        )
+
+        for match, distance in matches_list:
+            entries = self._idx_to_entries[match]
+            # Create match objects for each entry
+            results.extend(
+                PDQIndexMatch(
+                    SignalSimilarityInfoWithIntDistance(distance=distance),
+                    entry,
+                )
+                for entry in entries
+            )
+        return results
+
+    def add(self, signal_str: str, entry: IndexT) -> None:
+        self.add_all(((signal_str, entry),))
+
+    def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None:
+        for h, i in entries:
+            existing_faiss_id = self._deduper.get(h)
+            if existing_faiss_id is None:
+                self._index.add([h])
+                self._idx_to_entries.append([i])
+                next_id = len(self._deduper)  # Because faiss index starts from 0 up
+                self._deduper[h] = next_id
+            else:
+                # Since this already exists, we don't add it to Faiss because Faiss cannot handle duplication
+                self._idx_to_entries[existing_faiss_id].append(i)
+
+
+class _PDQFaissIndex:
+    """
+    A wrapper around the faiss index for pickle serialization
+    """
+
+    def __init__(self, faiss_index: faiss.Index) -> None:
+        self.faiss_index = faiss_index
+
+    def add(self, pdq_strings: t.Sequence[str]) -> None:
+        """
+        Add PDQ hashes to the FAISS index.
+        """
+        vectors = convert_pdq_strings_to_ndarray(pdq_strings)
+        self.faiss_index.add(vectors)
+
+    def search(
+        self, queries: t.Sequence[str], threshold: int
+    ) -> t.List[t.Tuple[int, int]]:
+        """
+        Search the FAISS index for matches to the given PDQ queries.
+        """
+        query_array: np.ndarray = convert_pdq_strings_to_ndarray(queries)
+        limits, distances, indices = self.faiss_index.range_search(
+            query_array, threshold + 1
+        )
+
+        results: t.List[t.Tuple[int, int]] = []
+        for i in range(len(queries)):
+            matches = [idx.item() for idx in indices[limits[i] : limits[i + 1]]]
+            # range_search returns float32 squared-L2; for 0/1 vectors this equals
+            # the integer Hamming distance, so cast to honor the declared int contract.
+            dists = [int(dist) for dist in distances[limits[i] : limits[i + 1]]]
+            for j in range(len(matches)):
+                results.append((matches[j], dists[j]))
+        return results
+
+    def __getstate__(self):
+        return faiss.serialize_index(self.faiss_index)
+
+    def __setstate__(self, data):
+        self.faiss_index = faiss.deserialize_index(data)
diff --git a/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py b/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py
index c7a3ae351..67b518d2a 100644
--- a/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py
+++ b/python-threatexchange/threatexchange/signal_type/pdq/pdq_utils.py
@@ -1,8 +1,13 @@
 #!/usr/bin/env python
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 
+import numpy as np
+import typing as t
+
 BITS_IN_PDQ = 256
 PDQ_HEX_STR_LEN = int(BITS_IN_PDQ / 4)
+# Hashes of distance less than or equal to this threshold are considered a 'match'
+PDQ_CONFIDENT_MATCH_THRESHOLD = 31
 
 
 def simple_distance_binary(bin_a: str, bin_b: str) -> int:
@@ -49,3 +54,18 @@ def pdq_match(pdq_hex_a: str, pdq_hex_b: str, threshold: int) -> bool:
     """
     distance = simple_distance(pdq_hex_a, pdq_hex_b)
     return distance <= threshold
+
+
+def convert_pdq_strings_to_ndarray(pdq_strings: t.Iterable[str]) -> np.ndarray:
+    """
+    Convert multiple PDQ hash strings to a numpy array.
+    """
+    binary_arrays = []
+    for pdq_str in pdq_strings:
+        if len(pdq_str) != PDQ_HEX_STR_LEN:
+            raise ValueError(f"PDQ hash string must be {PDQ_HEX_STR_LEN} hex characters long")
+        hash_bytes = bytes.fromhex(pdq_str)
+        binary_array = np.unpackbits(np.frombuffer(hash_bytes, dtype=np.uint8))
+        binary_arrays.append(binary_array)
+
+    return np.array(binary_arrays, dtype=np.uint8)
diff --git a/python-threatexchange/threatexchange/signal_type/pdq/signal.py b/python-threatexchange/threatexchange/signal_type/pdq/signal.py
index 3c0aa0dce..64ef69cfc 100644
--- a/python-threatexchange/threatexchange/signal_type/pdq/signal.py
+++ b/python-threatexchange/threatexchange/signal_type/pdq/signal.py
@@ -12,7 +12,10 @@
 from threatexchange.content_type.content_base import ContentType
 from threatexchange.content_type.photo import PhotoContent
 from threatexchange.signal_type import signal_base
-from threatexchange.signal_type.pdq.pdq_utils import simple_distance
+from threatexchange.signal_type.pdq.pdq_utils import (
+    simple_distance,
+    PDQ_CONFIDENT_MATCH_THRESHOLD,
+)
 from threatexchange.exchanges.impl.fb_threatexchange_signal import (
     HasFbThreatExchangeIndicatorType,
 )
@@ -42,8 +45,6 @@ class PdqSignal(
     INDICATOR_TYPE = "HASH_PDQ"
 
     # This may need to be updated (TODO make more configurable)
-    # Hashes of distance less than or equal to this threshold are considered a 'match'
-    PDQ_CONFIDENT_MATCH_THRESHOLD = 31
     # Images with less than quality 50 are too unreliable to match on
     QUALITY_THRESHOLD = 50
 
diff --git a/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py
new file mode 100644
index 000000000..b06dd8b1a
--- /dev/null
+++ b/python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py
@@ -0,0 +1,175 @@
+import typing as t
+import random
+import io
+import faiss
+
+from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2
+from threatexchange.signal_type.pdq.signal import PdqSignal
+from threatexchange.signal_type.pdq.pdq_utils import simple_distance
+
+
+def _get_hash_generator(seed: int = 42):
+    random.seed(seed)
+
+    def get_n_hashes(n: int):
+        return [PdqSignal.get_random_signal() for _ in range(n)]
+
+    return get_n_hashes
+
+
+def _brute_force_match(
+    base: t.List[str], query: str, threshold: int = 31
+) -> t.Set[t.Tuple[int, int]]:
+    matches = set()
+
+    for i, base_hash in enumerate(base):
+        distance = simple_distance(base_hash, query)
+        if distance <= threshold:
+            matches.add((i, distance))
+    return matches
+
+
+def _generate_random_hash_with_distance(hash: str, distance: int) -> str:
+    if not (0 <= distance <= 256):
+        raise ValueError("Distance must be between 0 and 256")
+
+    hash_bits = bin(int(hash, 16))[2:].zfill(256)  # Convert hash to binary
+    bits = list(hash_bits)
+    positions = random.sample(
+        range(256), distance
+    )  # Randomly select unique positions to flip
+    for pos in positions:
+        bits[pos] = "0" if bits[pos] == "1" else "1"  # Flip selected bit positions
+    modified_hash = hex(int("".join(bits), 2))[2:].zfill(64)  # Convert back to hex
+
+    return modified_hash
+
+
+def test_pdq_index():
+    get_random_hashes = _get_hash_generator()
+    base_hashes = get_random_hashes(100)
+    # Make sure base_hashes and query_hashes have at least 10 similar hashes
+    query_hashes = base_hashes[:10] + get_random_hashes(1000)
+
+    brute_force_matches = {
+        query_hash: _brute_force_match(base_hashes, query_hash)
+        for query_hash in query_hashes
+    }
+
+    index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes])
+
+    for query_hash in query_hashes:
+        expected_indices = brute_force_matches[query_hash]
+        index_results = index.query(query_hash)
+
+        result_indices: t.Set[t.Tuple[t.Any, int]] = {
+            (result.metadata, result.similarity_info.distance)
+            for result in index_results
+        }
+
+        assert result_indices == expected_indices, (
+            f"Mismatch for hash {query_hash}: "
+            f"Expected {expected_indices}, Got {result_indices}"
+        )
+
+
+def test_pdq_index_with_exact_distance():
+    get_random_hashes = _get_hash_generator()
+    base_hashes = get_random_hashes(100)
+
+    thresholds: t.List[int] = [10, 31, 50]
+
+    indexes = [
+        PDQIndex2(
+            entries=[(h, base_hashes.index(h)) for h in base_hashes],
+            threshold=thres,
+        )
+        for thres in thresholds
+    ]
+
+    distances: t.List[int] = [0, 1, 20, 30, 31, 60]
+    query_hash = base_hashes[0]
+
+    for i in range(len(indexes)):
+        index = indexes[i]
+
+        for dist in distances:
+            query_hash_w_dist = _generate_random_hash_with_distance(query_hash, dist)
+            results = index.query(query_hash_w_dist)
+            result_indices = {result.similarity_info.distance for result in results}
+            if dist <= thresholds[i]:
+                assert dist in result_indices
+
+
+def test_serialize_deserialize_index():
+    get_random_hashes = _get_hash_generator()
+    base_hashes = get_random_hashes(100)
+    index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes])
+
+    buffer = io.BytesIO()
+    index.serialize(buffer)
+    buffer.seek(0)
+    deserialized_index = PDQIndex2.deserialize(buffer)
+
+    assert isinstance(deserialized_index, PDQIndex2)
+    assert isinstance(deserialized_index._index.faiss_index, faiss.IndexFlatL2)
+    assert deserialized_index.threshold == index.threshold
+    assert deserialized_index._deduper == index._deduper
+    assert deserialized_index._idx_to_entries == index._idx_to_entries
+
+
+def test_empty_index_query():
+    """Test querying an empty index."""
+    index = PDQIndex2()
+
+    # Query should return empty list
+    results = index.query(PdqSignal.get_random_signal())
+    assert len(results) == 0
+
+
+def test_sample_set_no_match():
+    """Test no matches in sample set."""
+    get_random_hashes = _get_hash_generator()
+    base_hashes = get_random_hashes(100)
+    index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes])
+    results = index.query("b" * 64)
+    assert len(results) == 0
+
+
+def test_duplicate_handling():
+    """Test how the index handles duplicate entries."""
+    get_random_hashes = _get_hash_generator()
+    base_hashes = get_random_hashes(100)
+    index = PDQIndex2(entries=[(h, base_hashes.index(h)) for h in base_hashes])
+
+    # Add same hash multiple times
+    index.add_all(entries=[(base_hashes[0], i) for i in range(3)])
+
+    results = index.query(base_hashes[0])
+
+    # Should find all entries associated with the hash
+    assert len(results) == 4
+    for result in results:
+        assert result.similarity_info.distance == 0
+
+
+def test_one_entry_sample_index():
+    """
+    Test how the index handles when it only has one entry.
+
+    See issue github.com/facebook/ThreatExchange/issues/1318
+    """
+    get_random_hashes = _get_hash_generator()
+    base_hashes = get_random_hashes(100)
+    index = PDQIndex2(entries=[(base_hashes[0], 0)])
+
+    matching_test_hash = base_hashes[0]  # This is the existing hash in index
+    unmatching_test_hash = base_hashes[1]
+
+    results = index.query(matching_test_hash)
+    # Should find 1 entry associated with the hash
+    assert len(results) == 1
+    assert results[0].similarity_info.distance == 0
+
+    results = index.query(unmatching_test_hash)
+    assert len(results) == 0