-
Notifications
You must be signed in to change notification settings - Fork 322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pytx] Implement a new cleaner PDQ index solution #1698
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
300df97
create pdq_index2 and test cases
haianhng31 9fd7744
pytx - continue fixing reimplementation pdq index
haianhng31 3cf072b
pytx - add unittest for pdq index2
haianhng31 cd44cbd
pytx fix lint
haianhng31 a5bfb6d
pytx - add test case for pdq index2
haianhng31 0a32d65
pytx continue edit pdq index2
haianhng31 610af2e
pytx resolve from nit comments for pdq index2
haianhng31 2ea86a1
pytx - edit pdq index2
haianhng31 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
135 changes: 135 additions & 0 deletions
135
python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
""" | ||
Implementation of SignalTypeIndex abstraction for PDQ | ||
""" | ||
|
||
import typing as t | ||
import faiss | ||
import numpy as np | ||
|
||
|
||
from threatexchange.signal_type.index import ( | ||
IndexMatchUntyped, | ||
SignalSimilarityInfoWithIntDistance, | ||
SignalTypeIndex, | ||
T as IndexT, | ||
) | ||
from threatexchange.signal_type.pdq.pdq_utils import ( | ||
BITS_IN_PDQ, | ||
PDQ_CONFIDENT_MATCH_THRESHOLD, | ||
convert_pdq_strings_to_ndarray, | ||
) | ||
|
||
PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT] | ||
|
||
|
||
class PDQIndex2(SignalTypeIndex[IndexT]):
    """
    Faiss-backed SignalTypeIndex for PDQ hashes.

    This is a redo of the existing PDQ index,
    designed to be simpler and fix hard-to-squash bugs in the existing implementation.
    Purpose of this class: to replace the original index in pytx 2.0
    """

    def __init__(
        self,
        index: t.Optional[faiss.Index] = None,
        entries: t.Iterable[t.Tuple[str, IndexT]] = (),
        *,
        threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD,
    ) -> None:
        super().__init__()
        self.threshold = threshold

        # Default to an exact (flat) L2 index over the 256 PDQ bits.
        self._index = _PDQFaissIndex(
            faiss.IndexFlatL2(BITS_IN_PDQ) if index is None else index
        )

        # Hash string -> its faiss id (position in the faiss index).
        self._deduper: t.Dict[str, int] = {}
        # Faiss id -> every entry stored under that hash.
        self._idx_to_entries: t.List[t.List[IndexT]] = []

        self.add_all(entries=entries)

    def __len__(self) -> int:
        # One slot per distinct hash that was added.
        return len(self._idx_to_entries)

    def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]:
        """
        Look up entries against the index, up to the threshold.
        """
        hits = self._index.search(queries=[hash], threshold=self.threshold)
        # Fan each faiss hit out into one match per stored entry for that hash.
        return [
            PDQIndexMatch(
                SignalSimilarityInfoWithIntDistance(distance=dist),
                entry,
            )
            for faiss_id, dist in hits
            for entry in self._idx_to_entries[faiss_id]
        ]

    def add(self, signal_str: str, entry: IndexT) -> None:
        # Single-pair convenience wrapper around add_all().
        self.add_all(((signal_str, entry),))

    def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None:
        for signal_str, entry in entries:
            faiss_id = self._deduper.get(signal_str)
            if faiss_id is not None:
                # Faiss cannot hold the same vector under two ids, so extra
                # entries for a known hash are appended to its existing slot.
                self._idx_to_entries[faiss_id].append(entry)
            else:
                # Faiss assigns ids sequentially from 0, so the next id is
                # simply the current number of distinct hashes.
                self._deduper[signal_str] = len(self._deduper)
                self._index.add([signal_str])
                self._idx_to_entries.append([entry])
|
||
|
||
class _PDQFaissIndex:
    """
    A wrapper around the faiss index for pickle serialization
    """

    def __init__(self, faiss_index: faiss.Index) -> None:
        self.faiss_index = faiss_index

    def add(self, pdq_strings: t.Sequence[str]) -> None:
        """
        Add PDQ hashes to the FAISS index.
        """
        # Hex strings are converted to the bit-vector form faiss stores.
        self.faiss_index.add(convert_pdq_strings_to_ndarray(pdq_strings))

    def search(
        self, queries: t.Sequence[str], threshold: int
    ) -> t.List[t.Tuple[int, int]]:
        """
        Search the FAISS index for matches to the given PDQ queries.

        Returns a flat list of (faiss_id, distance) pairs across all queries.
        """
        query_array: np.ndarray = convert_pdq_strings_to_ndarray(queries)
        # range_search's radius bound is exclusive, so +1 keeps matches at
        # exactly `threshold` in the result set.
        limits, distances, indices = self.faiss_index.range_search(
            query_array, threshold + 1
        )

        results: t.List[t.Tuple[int, int]] = []
        # limits[i]:limits[i+1] delimits the hits belonging to query i.
        for start, end in zip(limits, limits[1:]):
            results.extend(
                (idx.item(), dist)
                for idx, dist in zip(indices[start:end], distances[start:end])
            )
        return results

    def __getstate__(self):
        # Pickle the faiss index as its portable serialized byte form.
        return faiss.serialize_index(self.faiss_index)

    def __setstate__(self, data):
        self.faiss_index = faiss.deserialize_index(data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
175 changes: 175 additions & 0 deletions
175
python-threatexchange/threatexchange/signal_type/tests/test_pdq_index2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import typing as t | ||
import random | ||
import io | ||
import faiss | ||
|
||
from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2 | ||
from threatexchange.signal_type.pdq.signal import PdqSignal | ||
from threatexchange.signal_type.pdq.pdq_utils import simple_distance | ||
|
||
|
||
def _get_hash_generator(seed: int = 42):
    """Seed the global RNG once and return a maker of n random PDQ hashes."""
    random.seed(seed)

    def _make_hashes(count: int):
        return [PdqSignal.get_random_signal() for _ in range(count)]

    return _make_hashes
|
||
|
||
def _brute_force_match(
    base: t.List[str], query: str, threshold: int = 32
) -> t.Set[t.Tuple[int, int]]:
    """Reference matcher: every (index, distance) pair within the threshold."""
    return {
        (position, dist)
        for position, candidate in enumerate(base)
        if (dist := simple_distance(candidate, query)) <= threshold
    }
|
||
|
||
def _generate_random_hash_with_distance(hash: str, distance: int) -> str: | ||
if not (0 <= distance <= 256): | ||
raise ValueError("Distance must be between 0 and 256") | ||
|
||
hash_bits = bin(int(hash, 16))[2:].zfill(256) # Convert hash to binary | ||
bits = list(hash_bits) | ||
positions = random.sample( | ||
range(256), distance | ||
) # Randomly select unique positions to flip | ||
for pos in positions: | ||
bits[pos] = "0" if bits[pos] == "1" else "1" # Flip selected bit positions | ||
modified_hash = hex(int("".join(bits), 2))[2:].zfill(64) # Convert back to hex | ||
Dcallies marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return modified_hash | ||
|
||
|
||
def test_pdq_index():
    """Index results must agree exactly with a brute-force reference match."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    # Make sure base_hashes and query_hashes have at least 10 similar hashes
    query_hashes = base_hashes[:10] + get_random_hashes(1000)

    brute_force_matches = {
        query_hash: _brute_force_match(base_hashes, query_hash)
        for query_hash in query_hashes
    }

    # enumerate() instead of base_hashes.index(h): index() was an accidental
    # O(n^2) scan and would mislabel duplicate hashes (always first position).
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(base_hashes)])

    for query_hash in query_hashes:
        expected_indices = brute_force_matches[query_hash]
        index_results = index.query(query_hash)

        result_indices: t.Set[t.Tuple[t.Any, int]] = {
            (result.metadata, result.similarity_info.distance)
            for result in index_results
        }

        assert result_indices == expected_indices, (
            f"Mismatch for hash {query_hash}: "
            f"Expected {expected_indices}, Got {result_indices}"
        )
|
||
|
||
def test_pdq_index_with_exact_distance():
    """Queries at a known bit distance must match iff within the threshold."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)

    thresholds: t.List[int] = [10, 31, 50]

    # enumerate() instead of base_hashes.index(h): avoids an O(n^2) scan.
    indexes = [
        PDQIndex2(
            entries=[(h, i) for i, h in enumerate(base_hashes)],
            threshold=thres,
        )
        for thres in thresholds
    ]

    distances: t.List[int] = [0, 1, 20, 30, 31, 60]
    query_hash = base_hashes[0]

    # zip() pairs each index with its threshold directly, replacing the
    # range(len(...)) parallel-list indexing.
    for threshold, index in zip(thresholds, indexes):
        for dist in distances:
            query_hash_w_dist = _generate_random_hash_with_distance(query_hash, dist)
            results = index.query(query_hash_w_dist)
            result_indices = {result.similarity_info.distance for result in results}
            if dist <= threshold:
                assert dist in result_indices
|
||
|
||
def test_serialize_deserialize_index():
    """A pickle round-trip must preserve the index's state and faiss backend."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    # enumerate() instead of base_hashes.index(h): avoids an O(n^2) scan.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(base_hashes)])

    buffer = io.BytesIO()
    index.serialize(buffer)
    buffer.seek(0)
    deserialized_index = PDQIndex2.deserialize(buffer)

    assert isinstance(deserialized_index, PDQIndex2)
    assert isinstance(deserialized_index._index.faiss_index, faiss.IndexFlatL2)
    assert deserialized_index.threshold == index.threshold
    assert deserialized_index._deduper == index._deduper
    assert deserialized_index._idx_to_entries == index._idx_to_entries
|
||
|
||
def test_empty_index_query():
    """Test querying an empty index."""
    empty_index = PDQIndex2()

    # An index with no entries can never produce a match.
    assert not empty_index.query(PdqSignal.get_random_signal())
|
||
|
||
def test_sample_set_no_match():
    """Test no matches in sample set."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    # enumerate() instead of base_hashes.index(h): avoids an O(n^2) scan.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(base_hashes)])
    # "b" * 64 is a fixed hash far from all random entries — expect no hits.
    results = index.query("b" * 64)
    assert len(results) == 0
|
||
|
||
def test_duplicate_handling():
    """Test how the index handles duplicate entries."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    # enumerate() instead of base_hashes.index(h): avoids an O(n^2) scan.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(base_hashes)])

    # Add same hash multiple times
    index.add_all(entries=[(base_hashes[0], i) for i in range(3)])

    results = index.query(base_hashes[0])

    # Should find all entries associated with the hash: the original one
    # plus the three duplicates just added.
    assert len(results) == 4
    for result in results:
        assert result.similarity_info.distance == 0
|
||
|
||
def test_one_entry_sample_index():
    """
    Test how the index handles when it only has one entry.

    See issue github.com/facebook/ThreatExchange/issues/1318
    """
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    index = PDQIndex2(entries=[(base_hashes[0], 0)])

    # The stored hash must come back as a single exact (distance 0) match.
    hits = index.query(base_hashes[0])
    assert len(hits) == 1
    assert hits[0].similarity_info.distance == 0

    # An unrelated random hash should not match the lone entry.
    assert len(index.query(base_hashes[1])) == 0
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can probably leave this off and rely on
build