Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pytx] Implement a new cleaner PDQ index solution #1698

Merged
merged 8 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions python-threatexchange/threatexchange/signal_type/pdq/pdq_index2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Implementation of SignalTypeIndex abstraction for PDQ
"""

import typing as t
import faiss
import numpy as np


from threatexchange.signal_type.index import (
IndexMatchUntyped,
SignalSimilarityInfoWithIntDistance,
SignalTypeIndex,
T as IndexT,
)
from threatexchange.signal_type.pdq.pdq_utils import (
BITS_IN_PDQ,
PDQ_CONFIDENT_MATCH_THRESHOLD,
convert_pdq_strings_to_ndarray,
)

PDQIndexMatch = IndexMatchUntyped[SignalSimilarityInfoWithIntDistance, IndexT]


class PDQIndex2(SignalTypeIndex[IndexT]):
    """
    Indexing and querying PDQ signals using Faiss for approximate nearest neighbor search.

    This is a redo of the existing PDQ index,
    designed to be simpler and fix hard-to-squash bugs in the existing implementation.
    Purpose of this class: to replace the original index in pytx 2.0
    """

    def __init__(
        self,
        index: t.Optional[faiss.Index] = None,
        entries: t.Iterable[t.Tuple[str, IndexT]] = (),
        *,
        threshold: int = PDQ_CONFIDENT_MATCH_THRESHOLD,
    ) -> None:
        super().__init__()
        self.threshold = threshold

        # Default to an exact (flat) L2 index over the 256-bit PDQ space.
        backing_index = faiss.IndexFlatL2(BITS_IN_PDQ) if index is None else index
        self._index = _PDQFaissIndex(backing_index)

        # Maps a hash string to its Faiss row id (each unique hash stored once).
        self._deduper: t.Dict[str, int] = {}
        # Maps a Faiss row id (list position) to every entry recorded for that hash.
        self._idx_to_entries: t.List[t.List[IndexT]] = []

        self.add_all(entries=entries)

    def __len__(self) -> int:
        # Number of unique hashes in the index (not total entries).
        return len(self._idx_to_entries)

    def query(self, hash: str) -> t.Sequence[PDQIndexMatch[IndexT]]:
        """
        Look up entries against the index, up to the threshold.
        """
        matches: t.List[PDQIndexMatch[IndexT]] = []
        hits = self._index.search(queries=[hash], threshold=self.threshold)
        for faiss_id, distance in hits:
            similarity = SignalSimilarityInfoWithIntDistance(distance=distance)
            # One match object per entry stored under this hash.
            matches.extend(
                PDQIndexMatch(similarity, entry)
                for entry in self._idx_to_entries[faiss_id]
            )
        return matches

    def add(self, signal_str: str, entry: IndexT) -> None:
        self.add_all(((signal_str, entry),))

    def add_all(self, entries: t.Iterable[t.Tuple[str, IndexT]]) -> None:
        for signal_str, entry in entries:
            faiss_id = self._deduper.get(signal_str)
            if faiss_id is not None:
                # Faiss cannot handle duplicate vectors, so only record the
                # extra entry under the already-indexed hash.
                self._idx_to_entries[faiss_id].append(entry)
            else:
                # New unique hash: the next Faiss row id equals the current
                # count of unique hashes (Faiss ids start at 0 and increment).
                self._deduper[signal_str] = len(self._deduper)
                self._index.add([signal_str])
                self._idx_to_entries.append([entry])


class _PDQFaissIndex:
    """
    A wrapper around the faiss index for pickle serialization
    """

    def __init__(self, faiss_index: faiss.Index) -> None:
        self.faiss_index = faiss_index

    def add(self, pdq_strings: t.Sequence[str]) -> None:
        """
        Add PDQ hashes to the FAISS index.
        """
        vectors = convert_pdq_strings_to_ndarray(pdq_strings)
        self.faiss_index.add(vectors)

    def search(
        self, queries: t.Sequence[str], threshold: int
    ) -> t.List[t.Tuple[int, int]]:
        """
        Search the FAISS index for matches to the given PDQ queries.

        Returns (faiss_id, distance) pairs for every indexed hash within
        `threshold` (inclusive) of any query.
        """
        query_array: np.ndarray = convert_pdq_strings_to_ndarray(queries)
        # range_search returns results with distance strictly less than the
        # radius. On 0/1 vectors, squared-L2 distance equals the integral
        # hamming distance, so radius=threshold+1 makes the bound inclusive.
        limits, distances, indices = self.faiss_index.range_search(
            query_array, threshold + 1
        )

        results: t.List[t.Tuple[int, int]] = []
        for i in range(len(queries)):
            start, end = limits[i], limits[i + 1]
            for idx, dist in zip(indices[start:end], distances[start:end]):
                # Fix: cast numpy scalars to plain Python ints. The previous
                # code appended float32 distances, violating the declared
                # t.Tuple[int, int] return type.
                results.append((int(idx), int(dist)))
        return results

    def __getstate__(self):
        # Faiss indexes are not picklable directly; serialize to a byte buffer.
        return faiss.serialize_index(self.faiss_index)

    def __setstate__(self, data):
        self.faiss_index = faiss.deserialize_index(data)
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import typing as t

# A PDQ hash is a 256-bit perceptual hash.
BITS_IN_PDQ = 256
# Length of a PDQ hash serialized as a hex string (4 bits per hex character).
PDQ_HEX_STR_LEN = int(BITS_IN_PDQ / 4)
# Hashes of distance less than or equal to this threshold are considered a 'match'
PDQ_CONFIDENT_MATCH_THRESHOLD = 31


def simple_distance_binary(bin_a: str, bin_b: str) -> int:
Expand Down Expand Up @@ -49,3 +54,18 @@ def pdq_match(pdq_hex_a: str, pdq_hex_b: str, threshold: int) -> bool:
"""
distance = simple_distance(pdq_hex_a, pdq_hex_b)
return distance <= threshold


def convert_pdq_strings_to_ndarray(pdq_strings: t.Iterable[str]) -> np.ndarray:
    """
    Convert multiple PDQ hash strings to a numpy array.

    Args:
        pdq_strings: hex-encoded PDQ hashes, PDQ_HEX_STR_LEN characters each.

    Returns:
        uint8 array of shape (n_hashes, BITS_IN_PDQ), one bit per column.

    Raises:
        ValueError: if any string has the wrong length or is not valid hex.
    """
    binary_arrays = []
    for pdq_str in pdq_strings:
        if len(pdq_str) != PDQ_HEX_STR_LEN:
            # Fix: report the expected length from the constant (previously
            # hard-coded "64") and include the offending length.
            raise ValueError(
                f"PDQ hash string must be {PDQ_HEX_STR_LEN} hex characters long, "
                f"got {len(pdq_str)}"
            )
        hash_bytes = bytes.fromhex(pdq_str)  # raises ValueError on non-hex input
        binary_arrays.append(np.unpackbits(np.frombuffer(hash_bytes, dtype=np.uint8)))

    if not binary_arrays:
        # Preserve the (n, bits) shape even with no hashes, so downstream
        # consumers (e.g. faiss add/search) still see 2-D input.
        return np.empty((0, PDQ_HEX_STR_LEN * 4), dtype=np.uint8)
    return np.array(binary_arrays, dtype=np.uint8)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.photo import PhotoContent
from threatexchange.signal_type import signal_base
from threatexchange.signal_type.pdq.pdq_utils import simple_distance
from threatexchange.signal_type.pdq.pdq_utils import (
simple_distance,
PDQ_CONFIDENT_MATCH_THRESHOLD,
)
from threatexchange.exchanges.impl.fb_threatexchange_signal import (
HasFbThreatExchangeIndicatorType,
)
Expand Down Expand Up @@ -42,8 +45,6 @@ class PdqSignal(
INDICATOR_TYPE = "HASH_PDQ"

# This may need to be updated (TODO make more configurable)
# Hashes of distance less than or equal to this threshold are considered a 'match'
PDQ_CONFIDENT_MATCH_THRESHOLD = 31
# Images with less than quality 50 are too unreliable to match on
QUALITY_THRESHOLD = 50

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import typing as t
import random
import io
import faiss

from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2
from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.pdq.pdq_utils import simple_distance


def _get_hash_generator(seed: int = 42):
    """Build a deterministic producer of random PDQ hashes."""
    # Seed the module RNG once so repeated runs generate identical hashes.
    random.seed(seed)

    def _make(count: int):
        return [PdqSignal.get_random_signal() for _ in range(count)]

    return _make


def _brute_force_match(
    base: t.List[str], query: str, threshold: int = 31
) -> t.Set[t.Tuple[int, int]]:
    """
    Reference oracle: (index, distance) of every base hash within threshold.

    Fix: the default threshold must match PDQIndex2's inclusive default of 31
    (PDQ_CONFIDENT_MATCH_THRESHOLD). The previous default of 32 was off by
    one: a base hash at distance exactly 32 would be expected by this oracle
    but correctly rejected by the index, making test_pdq_index flaky.
    """
    matches = set()

    for i, base_hash in enumerate(base):
        distance = simple_distance(base_hash, query)
        if distance <= threshold:
            matches.add((i, distance))
    return matches


def _generate_random_hash_with_distance(hash: str, distance: int) -> str:
if not (0 <= distance <= 256):
raise ValueError("Distance must be between 0 and 256")

hash_bits = bin(int(hash, 16))[2:].zfill(256) # Convert hash to binary
bits = list(hash_bits)
positions = random.sample(
range(256), distance
) # Randomly select unique positions to flip
for pos in positions:
bits[pos] = "0" if bits[pos] == "1" else "1" # Flip selected bit positions
modified_hash = hex(int("".join(bits), 2))[2:].zfill(64) # Convert back to hex
Dcallies marked this conversation as resolved.
Show resolved Hide resolved

return modified_hash


def test_pdq_index():
    """Index results must agree exactly with a brute-force distance oracle."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    # Make sure base_hashes and query_hashes have at least 10 similar hashes
    query_hashes = base_hashes[:10] + get_random_hashes(1000)

    brute_force_matches = {
        query_hash: _brute_force_match(base_hashes, query_hash)
        for query_hash in query_hashes
    }

    # Fix: enumerate() instead of base_hashes.index(h), which was an
    # accidental O(n^2) scan per entry (and wrong if a hash repeated).
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(base_hashes)])

    for query_hash in query_hashes:
        expected_indices = brute_force_matches[query_hash]
        index_results = index.query(query_hash)

        result_indices: t.Set[t.Tuple[t.Any, int]] = {
            (result.metadata, result.similarity_info.distance)
            for result in index_results
        }

        assert result_indices == expected_indices, (
            f"Mismatch for hash {query_hash}: "
            f"Expected {expected_indices}, Got {result_indices}"
        )


def test_pdq_index_with_exact_distance():
    """Queries at a known distance must be found iff within the index threshold."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)

    thresholds: t.List[int] = [10, 31, 50]

    indexes = [
        PDQIndex2(
            # Fix: enumerate() instead of base_hashes.index(h) — the latter
            # was an accidental O(n^2) scan per entry.
            entries=[(h, i) for i, h in enumerate(base_hashes)],
            threshold=thres,
        )
        for thres in thresholds
    ]

    distances: t.List[int] = [0, 1, 20, 30, 31, 60]
    query_hash = base_hashes[0]

    # zip keeps each index paired with the threshold it was built with.
    for index, threshold in zip(indexes, thresholds):
        for dist in distances:
            query_hash_w_dist = _generate_random_hash_with_distance(query_hash, dist)
            results = index.query(query_hash_w_dist)
            result_distances = {result.similarity_info.distance for result in results}
            if dist <= threshold:
                assert dist in result_distances


def test_serialize_deserialize_index():
    """A pickled-and-restored index must preserve all internal state."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    # Fix: enumerate() instead of base_hashes.index(h) — avoids an
    # accidental O(n^2) scan per entry.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(base_hashes)])

    buffer = io.BytesIO()
    index.serialize(buffer)
    buffer.seek(0)
    deserialized_index = PDQIndex2.deserialize(buffer)

    assert isinstance(deserialized_index, PDQIndex2)
    assert isinstance(deserialized_index._index.faiss_index, faiss.IndexFlatL2)
    assert deserialized_index.threshold == index.threshold
    assert deserialized_index._deduper == index._deduper
    assert deserialized_index._idx_to_entries == index._idx_to_entries


def test_empty_index_query():
    """Querying an index with no entries returns no matches."""
    empty_index = PDQIndex2()

    matches = empty_index.query(PdqSignal.get_random_signal())
    assert len(matches) == 0


def test_sample_set_no_match():
    """Test no matches in sample set."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    # Fix: enumerate() instead of base_hashes.index(h) — avoids an
    # accidental O(n^2) scan per entry.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(base_hashes)])
    # "b" * 64 is overwhelmingly unlikely to be within threshold of any
    # random hash.
    results = index.query("b" * 64)
    assert len(results) == 0


def test_duplicate_handling():
    """Test how the index handles duplicate entries."""
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    # Fix: enumerate() instead of base_hashes.index(h) — avoids an
    # accidental O(n^2) scan per entry.
    index = PDQIndex2(entries=[(h, i) for i, h in enumerate(base_hashes)])

    # Add same hash multiple times
    index.add_all(entries=[(base_hashes[0], i) for i in range(3)])

    results = index.query(base_hashes[0])

    # Should find all entries associated with the hash: the original one
    # plus the three duplicates just added.
    assert len(results) == 4
    for result in results:
        assert result.similarity_info.distance == 0


def test_one_entry_sample_index():
    """
    Test how the index handles when it only has one entry.

    See issue github.com/facebook/ThreatExchange/issues/1318
    """
    get_random_hashes = _get_hash_generator()
    base_hashes = get_random_hashes(100)
    index = PDQIndex2(entries=[(base_hashes[0], 0)])

    present_hash = base_hashes[0]  # the one hash that was indexed
    absent_hash = base_hashes[1]

    hits = index.query(present_hash)
    # Exactly one entry is associated with the indexed hash, at distance 0.
    assert len(hits) == 1
    assert hits[0].similarity_info.distance == 0

    assert len(index.query(absent_hash)) == 0