diff --git a/python-threatexchange/threatexchange/cli/hash_cmd.py b/python-threatexchange/threatexchange/cli/hash_cmd.py index cda59a8cc..947b45a6a 100644 --- a/python-threatexchange/threatexchange/cli/hash_cmd.py +++ b/python-threatexchange/threatexchange/cli/hash_cmd.py @@ -19,7 +19,6 @@ from threatexchange.signal_type.signal_base import FileHasher, SignalType from threatexchange.cli import command_base from threatexchange.cli.helpers import FlexFilesInputAction -from pathlib import Path class HashCommand(command_base.Command): @@ -55,8 +54,6 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No s.get_name() for s in signal_types if issubclass(s, FileHasher) ) - subparsers = ap.add_subparsers(dest="content_type", required=True) - ap.add_argument( "content_type", **common.argparse_choices_pre_type_kwargs( @@ -83,10 +80,8 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No help="only generate these signal types", ) - photo_parser = subparsers.add_parser("photo", help="Hash and preprocess photos") - - photo_parser.add_argument( - "--preprocess", + ap.add_argument( + "--photo-preprocess", choices=["unletterbox", "rotations"], help=( "Apply one of the preprocessing steps to the image before hashing. " @@ -95,17 +90,17 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No ), ) - photo_parser.add_argument( + ap.add_argument( "--black-threshold", type=int, - default=5, + default=10, help=( "Set the black threshold for unletterboxing (default: 5)." "Only applies when 'unletterbox' is selected in --preprocess." ), ) - photo_parser.add_argument( + ap.add_argument( "--save-output", action="store_true", help="If true, saves the processed image as a new file.", @@ -116,20 +111,20 @@ def __init__( content_type: t.Type[ContentType], signal_type: t.Optional[t.Type[SignalType]], files: t.List[pathlib.Path], - preprocess: t.Optional[str] = None, - black_threshold: int = 40, + photo_preprocess: t.Optional[str] = None, + black_threshold: int = 0, save_output: bool = False, ) -> None: - if not issubclass(self.preprocess, PhotoContent): - raise CommandError( - "--preprocess flag is only available for Photo content type", 2 - ) self.content_type = content_type self.signal_type = signal_type - self.preprocess = preprocess + self.photo_preprocess = photo_preprocess self.black_threshold = black_threshold self.save_output = save_output self.files = files + if self.photo_preprocess and not issubclass(self.content_type, PhotoContent): + raise CommandError( + "--photo-preprocess flag is only available for Photo content type", 2 + ) def execute(self, settings: CLISettings) -> None: hashers = [ @@ -146,19 +141,20 @@ def execute(self, settings: CLISettings) -> None: hashers = [self.signal_type] # type: ignore # can't detect intersection types - if self.preprocess: + if self.photo_preprocess: for file in self.files: - updated_bytes = [] + updated_bytes: t.List[bytes] = [] rotation_type = [] - if self.preprocess == "unletterbox": + if self.photo_preprocess == "unletterbox": updated_bytes.append( - PhotoContent.unletterbox(file, self.black_threshold) + PhotoContent.unletterbox(str(file), self.black_threshold) ) - elif self.preprocess == "rotations": + elif self.photo_preprocess == "rotations": with open(file, "rb") as f: image_bytes = f.read() - rotation_type, updated_bytes = ( - PhotoContent.all_simple_rotations(image_bytes).items() + rotations = PhotoContent.all_simple_rotations(image_bytes) + rotation_type, updated_bytes = list(rotations.keys()), list( + rotations.values() ) for idx, bytes_data in enumerate(updated_bytes): with tempfile.NamedTemporaryFile() as temp_file: @@ -168,12 +164,12 @@ def execute(self, settings: CLISettings) -> None: hash_str = hasher.hash_from_file(temp_file_path) if hash_str: print( - f"{rotation_type[idx].name if rotation_type[idx] else ''} {hasher.get_name()} {hash_str}" + f"{rotation_type[idx].name if rotation_type else ''} {hasher.get_name()} {hash_str}" ) if self.save_output: suffix = ( f"_{rotation_type[idx].name}" - if rotation_type[idx] + if rotation_type else "_unletterboxed" ) output_path = file.with_stem(f"{file.stem}{suffix}") @@ -186,36 +182,3 @@ def execute(self, settings: CLISettings) -> None: hash_str = hasher.hash_from_file(file) if hash_str: print(hasher.get_name(), hash_str) - - # if not self.rotations: - # for file in self.files: - # for hasher in hashers: - # if self.preprocess == "unletterbox": - # hash_str = PdqSignal.hash_from_bytes( - # PhotoContent.unletterbox( - # file, self.save_output, self.black_threshold - # ) - # ) - # else: - # hash_str = hasher.hash_from_file(file) - # if hash_str: - # print(hasher.get_name(), hash_str) - # return - - # if not issubclass(self.content_type, PhotoContent): - # raise CommandError( - # "--rotations flag is only available for Photo content type", 2 - # ) - - # for file in self.files: - # with open(file, "rb") as f: - # image_bytes = f.read() - # rotated_images = PhotoContent.all_simple_rotations(image_bytes) - # for rotation_type, rotated_bytes in rotated_images.items(): - # with tempfile.NamedTemporaryFile() as temp_file: # Create a temporary file to hold the byte data - # temp_file.write(rotated_bytes) - # temp_file_path = pathlib.Path(temp_file.name) - # for hasher in hashers: - # hash_str = hasher.hash_from_file(temp_file_path) - # if hash_str: - # print(rotation_type.name, hasher.get_name(), hash_str) diff --git a/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py b/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py index 0eaf96a1e..789e03b0f 100644 --- a/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py +++ b/python-threatexchange/threatexchange/cli/tests/hash_cmd_test.py @@ -83,8 +83,8 @@ def test_rotations_with_non_photo_content( """Test that rotation flag raises error with non-photo content""" for content_type in ["url", "text", "video"]: hash_cli.assert_cli_usage_error( - ("--rotations", content_type, str(tmp_file)), - msg_regex="--rotations flag is only available for Photo content type", + ("--photo-preprocess=rotations", content_type, str(tmp_file)), + msg_regex="--photo-preprocess flag is only available for Photo content type", ) @@ -93,7 +93,7 @@ def test_rotations_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper): test_file = pathlib.Path("threatexchange/tests/hashing/resources/LA.png") hash_cli.assert_cli_output( - ("--rotations", "photo", str(test_file)), + ("--photo-preprocess=rotations", "photo", str(test_file)), [ "ORIGINAL pdq accb6d39648035f8125c8ce6ba65007de7b54c67a2d93ef7b8f33b0611306715", "ROTATE90 pdq 1f70cbbc77edc5f9524faa1b18f3b76cd0a04a833e20f645d229d0acc8499c56", @@ -105,3 +105,49 @@ def test_rotations_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper): "FLIPMINUS1 pdq 5bb15db9e8a1f03c174a380a55aeaa2985bde9c60abce301bde48df918b5c15b", ], ) + + +def test_unletterbox_with_non_photo_content( + hash_cli: ThreatExchangeCLIE2eHelper, tmp_file: pathlib.Path +): + """Test that unletterbox flag raises error with non-photo content""" + for content_type in ["url", "text", "video"]: + hash_cli.assert_cli_usage_error( + ("--photo-preprocess=unletterbox", content_type, str(tmp_file)), + msg_regex="--photo-preprocess flag is only available for Photo content type", + ) + + +def test_unletterbox_with_photo_content(hash_cli: ThreatExchangeCLIE2eHelper): + """Test that photo unletterboxing is properly processed""" + test_file = pathlib.Path( + "threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg" + ) + clean_file = pathlib.Path("threatexchange/tests/hashing/resources/sample-b.jpg") + + hash_cli.assert_cli_output( + ("photo", str(clean_file)), + [ + "pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", + ], + ) + + """Test that photo unletterboxing is chnaged based off of allowed threshold""" + hash_cli.assert_cli_output( + ("--photo-preprocess=unletterbox", "photo", str(test_file)), + [ + "pdq 58f870cce0f4e84d8e378a32028f63f4b36e26f597621e1d33e6b39c4a9c9b22", + ], + ) + + hash_cli.assert_cli_output( + ( + "--photo-preprocess=unletterbox", + "--black-threshold=25", + "photo", + str(test_file), + ), + [ + "pdq f8f8f0cee0f4a84f06370a22038f63f0b36e2ed596621e1d33e6b39c4e9c9b22", + ], + ) diff --git a/python-threatexchange/threatexchange/content_type/preprocess/unletterboxing.py b/python-threatexchange/threatexchange/content_type/preprocess/unletterboxing.py index 315251230..7feac455f 100644 --- a/python-threatexchange/threatexchange/content_type/preprocess/unletterboxing.py +++ b/python-threatexchange/threatexchange/content_type/preprocess/unletterboxing.py @@ -1,16 +1,22 @@ from PIL import Image +def is_pixel_black(pixel, threshold): + # Check if each color channel in the pixel is below the threshold + r, g, b = pixel + return r < threshold and g < threshold and b < threshold + + def detect_top_border(image: Image.Image, black_threshold: int = 0) -> int: """ Detect the top black border by counting rows with only black pixels. - Sums all rgb values for pixel in each row to check for letterboxing - Returns the first row that is not all blacked out from the top. + Checks each RGB channel of each pixel in each row. + Returns the first row that is not all black from the top. """ width, height = image.size for y in range(height): row_pixels = list(image.crop((0, y, width, y + 1)).getdata()) - if all(sum(pixel[:3]) < black_threshold for pixel in row_pixels): + if all(is_pixel_black(pixel, black_threshold) for pixel in row_pixels): continue return y return height @@ -19,13 +25,13 @@ def detect_top_border(image: Image.Image, black_threshold: int = 0) -> int: def detect_bottom_border(image: Image.Image, black_threshold: int = 0) -> int: """ Detect the bottom black border by counting rows with only black pixels from the bottom up. - Sums all rgb values for pixel in each row to check for letterboxing - Returns the first row that is not all blacked out from the bottom. + Checks each RGB channel of each pixel in each row. + Returns the first row that is not all black from the bottom. """ width, height = image.size for y in range(height - 1, -1, -1): row_pixels = list(image.crop((0, y, width, y + 1)).getdata()) - if all(sum(pixel[:3]) < black_threshold for pixel in row_pixels): + if all(is_pixel_black(pixel, black_threshold) for pixel in row_pixels): continue return height - y - 1 return height @@ -34,13 +40,13 @@ def detect_bottom_border(image: Image.Image, black_threshold: int = 0) -> int: def detect_left_border(image: Image.Image, black_threshold: int = 0) -> int: """ Detect the left black border by counting columns with only black pixels. - Sums all rgb values for pixel in each column to check for letterboxing - Returns the first column from the left that is not all blacked out. + Checks each RGB channel of each pixel in each column. + Returns the first column from the left that is not all black. """ width, height = image.size for x in range(width): col_pixels = list(image.crop((x, 0, x + 1, height)).getdata()) - if all(sum(pixel[:3]) < black_threshold for pixel in col_pixels): + if all(is_pixel_black(pixel, black_threshold) for pixel in col_pixels): continue return x return width @@ -49,13 +55,13 @@ def detect_left_border(image: Image.Image, black_threshold: int = 0) -> int: def detect_right_border(image: Image.Image, black_threshold: int = 0) -> int: """ Detect the right black border by counting columns with only black pixels from the right. - Sums all rgb values for pixel in each column to check for letterboxing - Returns the first column from the right that is not all blacked out. + Checks each RGB channel of each pixel in each column. + Returns the first column from the right that is not all black. """ width, height = image.size for x in range(width - 1, -1, -1): col_pixels = list(image.crop((x, 0, x + 1, height)).getdata()) - if all(sum(pixel[:3]) < black_threshold for pixel in col_pixels): + if all(is_pixel_black(pixel, black_threshold) for pixel in col_pixels): continue return width - x - 1 return width diff --git a/python-threatexchange/data/letterboxed_sample-b.jpg b/python-threatexchange/threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg similarity index 100% rename from python-threatexchange/data/letterboxed_sample-b.jpg rename to python-threatexchange/threatexchange/tests/hashing/resources/letterboxed_sample-b.jpg diff --git a/python-threatexchange/threatexchange/tests/hashing/resources/sample-b.jpg b/python-threatexchange/threatexchange/tests/hashing/resources/sample-b.jpg new file mode 100644 index 000000000..66ad092df Binary files /dev/null and b/python-threatexchange/threatexchange/tests/hashing/resources/sample-b.jpg differ diff --git a/python-threatexchange/threatexchange/tests/hashing/test_pdq_letterboxing.py b/python-threatexchange/threatexchange/tests/hashing/test_pdq_letterboxing.py deleted file mode 100644 index d79c851e8..000000000 --- a/python-threatexchange/threatexchange/tests/hashing/test_pdq_letterboxing.py +++ /dev/null @@ -1,55 +0,0 @@ -import unittest -from pathlib import Path -from threatexchange.signal_type.pdq.signal import PdqSignal -from threatexchange.content_type.photo import PhotoContent - - -class TestUnletterboxFunction(unittest.TestCase): - def setUp(self): - # Load the file paths - self.data_dir = (Path(__file__).parent / "../../../data").resolve() - self.letterbox_path = Path(f"{self.data_dir}/letterboxed_sample-b.jpg") - self.clean_path = Path(f"{self.data_dir}/sample-b.jpg") - self.output_path = Path(f"{self.data_dir}/sample-b.png_unletterboxed.jpg") - - def clean(self): - # Removes generated output file if already exists - if self.output_path.exists(): - self.output_path.unlink() - - def test_letterbox_image_without_unletterbox(self): - with self.letterbox_path.open("rb") as f: - letterbox_data = f.read() - - letterbox_hash = PdqSignal.hash_from_bytes(letterbox_data) - - with self.clean_path.open("rb") as f: - clean_data = f.read() - clean_hash = PdqSignal.hash_from_bytes(clean_data) - - # Assert that the hash of the original letterbox image is different from the clean image's hash - self.assertNotEqual( - letterbox_hash, - clean_hash, - "Letterbox image unexpectedly matches the clean image", - ) - - def test_unletterbox_image(self): - unletterboxed_hash = PdqSignal.hash_from_bytes( - PhotoContent.unletterbox(self.letterbox_path,30) - ) - - # Read the clean image data and generate PDQ hash - with self.clean_path.open("rb") as f: - clean_data = f.read() - clean_hash = PdqSignal.hash_from_bytes(clean_data) - - self.assertEqual( - unletterboxed_hash, - clean_hash, - "Unletterboxed image does not match the clean image", - ) - - -if __name__ == "__main__": - unittest.main()