diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py index 5bce87b7..bf79a490 100644 --- a/tests/processing/test_functions.py +++ b/tests/processing/test_functions.py @@ -2,10 +2,14 @@ import sys from unittest import mock +import cryptography +import numpy as np import pytest from litdata import StreamingDataset, merge_datasets, optimize, walk from litdata.processing.functions import _get_input_dir, _resolve_dir from litdata.streaming.cache import Cache +from litdata.utilities.encryption import FernetEncryption, RSAEncryption +from PIL import Image @pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.") @@ -64,6 +68,11 @@ def another_fn(i: int): return i, i**2 +def random_image(index): + fake_img = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)) + return {"image": fake_img, "class": index} + + @pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow") def test_optimize_append_overwrite(tmpdir): output_dir = str(tmpdir / "output_dir") @@ -272,3 +281,114 @@ def test_merge_datasets(tmpdir): assert len(ds) == 20 assert ds[:] == list(range(20)) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows") +def test_optimize_with_fernet_encryption(tmpdir): + output_dir = str(tmpdir / "output_dir") + + # ----------------- sample level ----------------- + fernet = FernetEncryption(password="password", level="sample") + optimize( + fn=compress, + inputs=list(range(5)), + num_workers=1, + output_dir=output_dir, + chunk_bytes="64MB", + encryption=fernet, + ) + + ds = StreamingDataset(output_dir, encryption=fernet) + assert len(ds) == 5 + assert ds[:] == [(i, i**2) for i in range(5)] + + # ----------------- chunk level ----------------- + fernet = FernetEncryption(password="password", level="chunk") + optimize( + fn=compress, + inputs=list(range(5)), + num_workers=1, + output_dir=output_dir, + chunk_bytes="64MB", + encryption=fernet, + mode="overwrite", + ) + + ds = StreamingDataset(output_dir, encryption=fernet) + assert len(ds) == 5 + assert ds[:] == [(i, i**2) for i in range(5)] + + # ----------------- decrypt with different conf ----------------- + + fernet.level = "sample" + ds = StreamingDataset(output_dir, encryption=fernet) + with pytest.raises(ValueError, match="Encryption level mismatch."): + ds[0] + + fernet = FernetEncryption(password="password", level="chunk") + ds = StreamingDataset(output_dir, encryption=fernet) + with pytest.raises(cryptography.fernet.InvalidToken, match=""): + ds[0] + + # ----------------- test with other alg ----------------- + rsa = RSAEncryption(password="password", level="sample") + ds = StreamingDataset(output_dir, encryption=rsa) + with pytest.raises(ValueError, match="Encryption algorithm mismatch."): + ds[0] + + # ----------------- test with random images ----------------- + + fernet = FernetEncryption(password="password", level="chunk") + optimize( + fn=random_image, + inputs=list(range(5)), + num_workers=1, + output_dir=output_dir, + chunk_bytes="64MB", + encryption=fernet, + mode="overwrite", + ) + + ds = StreamingDataset(output_dir, encryption=fernet) + + assert len(ds) == 5 + assert ds[0]["class"] == 0 + + +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows") +def test_optimize_with_rsa_encryption(tmpdir): + output_dir = str(tmpdir / "output_dir") + + # ----------------- sample level ----------------- + rsa = RSAEncryption(password="password", level="sample") + optimize( + fn=compress, + inputs=list(range(5)), + num_workers=1, + output_dir=output_dir, + chunk_bytes="64MB", + encryption=rsa, + ) + + ds = StreamingDataset(output_dir, encryption=rsa) + assert len(ds) == 5 + assert ds[:] == [(i, i**2) for i in range(5)] + + # ----------------- chunk level ----------------- + rsa = RSAEncryption(password="password", level="chunk") + optimize( + fn=compress, + inputs=list(range(5)), + num_workers=1, + output_dir=output_dir, + chunk_bytes="64MB", + encryption=rsa, + mode="overwrite", + ) + + ds = StreamingDataset(output_dir, encryption=rsa) + assert len(ds) == 5 + assert ds[:] == [(i, i**2) for i in range(5)] + + # ----------------- test with random images ----------------- + # RSA Encryption throws an error when trying to encrypt large data