Skip to content

Commit

Permalink
feat: Adds test cases for optimizing and reading dataset with encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
bhimrazy committed Jul 14, 2024
1 parent 184b46c commit cfb08dc
Showing 1 changed file with 120 additions and 0 deletions.
120 changes: 120 additions & 0 deletions tests/processing/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
import sys
from unittest import mock

import cryptography
import numpy as np
import pytest
from litdata import StreamingDataset, merge_datasets, optimize, walk
from litdata.processing.functions import _get_input_dir, _resolve_dir
from litdata.streaming.cache import Cache
from litdata.utilities.encryption import FernetEncryption, RSAEncryption
from PIL import Image


@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
Expand Down Expand Up @@ -64,6 +68,11 @@ def another_fn(i: int):
return i, i**2


def random_image(index):
fake_img = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8))
return {"image": fake_img, "class": index}


@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow")
def test_optimize_append_overwrite(tmpdir):
output_dir = str(tmpdir / "output_dir")
Expand Down Expand Up @@ -272,3 +281,114 @@ def test_merge_datasets(tmpdir):

assert len(ds) == 20
assert ds[:] == list(range(20))


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows")
def test_optimize_with_fernet_encryption(tmpdir):
output_dir = str(tmpdir / "output_dir")

# ----------------- sample level -----------------
fernet = FernetEncryption(password="password", level="sample")
optimize(
fn=compress,
inputs=list(range(5)),
num_workers=1,
output_dir=output_dir,
chunk_bytes="64MB",
encryption=fernet,
)

ds = StreamingDataset(output_dir, encryption=fernet)
assert len(ds) == 5
assert ds[:] == [(i, i**2) for i in range(5)]

# ----------------- chunk level -----------------
fernet = FernetEncryption(password="password", level="chunk")
optimize(
fn=compress,
inputs=list(range(5)),
num_workers=1,
output_dir=output_dir,
chunk_bytes="64MB",
encryption=fernet,
mode="overwrite",
)

ds = StreamingDataset(output_dir, encryption=fernet)
assert len(ds) == 5
assert ds[:] == [(i, i**2) for i in range(5)]

# ----------------- decrypt with different conf -----------------

fernet.level = "sample"
ds = StreamingDataset(output_dir, encryption=fernet)
with pytest.raises(ValueError, match="Encryption level mismatch."):
ds[0]

fernet = FernetEncryption(password="password", level="chunk")
ds = StreamingDataset(output_dir, encryption=fernet)
with pytest.raises(cryptography.fernet.InvalidToken, match=""):
ds[0]

# ----------------- test with other alg -----------------
rsa = RSAEncryption(password="password", level="sample")
ds = StreamingDataset(output_dir, encryption=rsa)
with pytest.raises(ValueError, match="Encryption algorithm mismatch."):
ds[0]

# ----------------- test with random images -----------------

fernet = FernetEncryption(password="password", level="chunk")
optimize(
fn=random_image,
inputs=list(range(5)),
num_workers=1,
output_dir=output_dir,
chunk_bytes="64MB",
encryption=fernet,
mode="overwrite",
)

ds = StreamingDataset(output_dir, encryption=fernet)

assert len(ds) == 5
assert ds[0]["class"] == 0


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows")
def test_optimize_with_rsa_encryption(tmpdir):
output_dir = str(tmpdir / "output_dir")

# ----------------- sample level -----------------
rsa = RSAEncryption(password="password", level="sample")
optimize(
fn=compress,
inputs=list(range(5)),
num_workers=1,
output_dir=output_dir,
chunk_bytes="64MB",
encryption=rsa,
)

ds = StreamingDataset(output_dir, encryption=rsa)
assert len(ds) == 5
assert ds[:] == [(i, i**2) for i in range(5)]

# ----------------- chunk level -----------------
rsa = RSAEncryption(password="password", level="chunk")
optimize(
fn=compress,
inputs=list(range(5)),
num_workers=1,
output_dir=output_dir,
chunk_bytes="64MB",
encryption=rsa,
mode="overwrite",
)

ds = StreamingDataset(output_dir, encryption=rsa)
assert len(ds) == 5
assert ds[:] == [(i, i**2) for i in range(5)]

# ----------------- test with random images -----------------
# RSA Encryption throws an error when trying to encrypt large data

0 comments on commit cfb08dc

Please sign in to comment.