diff --git a/README.md b/README.md index dea406f9..6ae91bea 100644 --- a/README.md +++ b/README.md @@ -654,6 +654,75 @@ Explore an example setup of litdata with MinIO in the [LitData with MinIO](https +
+ ✅ Supports encryption and decryption of data at chunk/sample level +  + +Secure your data by applying encryption to individual samples or chunks, ensuring sensitive information is protected during storage. + +This example demonstrates how to use the `FernetEncryption` class for sample-level encryption with a data optimization function. + +```python +from litdata import optimize +from litdata.utilities.encryption import FernetEncryption +import numpy as np +from PIL import Image + +# Initialize FernetEncryption with a password for sample-level encryption +fernet = FernetEncryption(password="your_secure_password", level="sample") +data_dir = "s3://my-bucket/optimized_data" + +def random_image(index): + """Generate a random image for demonstration purposes.""" + fake_img = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)) + return {"image": fake_img, "class": index} + +# Optimize data while applying encryption +optimize( + fn=random_image, + inputs=list(range(5)), # Example inputs: [0, 1, 2, 3, 4] + num_workers=1, + output_dir=data_dir, + chunk_bytes="64MB", + encryption=fernet, +) + +# Save the encryption settings (algorithm, level and salt) for later use; the password is not stored +fernet.save("fernet.pem") +``` + +You can load the encrypted data using the `StreamingDataset` class as follows: + +```python +from litdata import StreamingDataset +from litdata.utilities.encryption import FernetEncryption + +# Rebuild the encryption object from the saved file and the same password +fernet = FernetEncryption.load("fernet.pem", password="your_secure_password") +data_dir = "s3://my-bucket/optimized_data" + +# Create a streaming dataset for reading the encrypted samples +ds = StreamingDataset(input_dir=data_dir, encryption=fernet) +``` + +If you want to implement your own encryption method, you can subclass the `Encryption` class and implement its abstract members (`algorithm`, `encrypt`, `decrypt`, `state_dict`, `save` and `load`) plus a `level` attribute (`"sample"` or `"chunk"`); the skeleton below shows only `encrypt` and `decrypt`: + +```python +from litdata.utilities.encryption import Encryption + +class CustomEncryption(Encryption): + def encrypt(self, data): + # Implement your custom encryption logic here + return data + + def 
decrypt(self, data): + # Implement your custom decryption logic here + return data +``` + +With this setup, you can ensure that your data remains secure while maintaining flexibility in how you handle encryption. +
+   ---- diff --git a/requirements/test.txt b/requirements/test.txt index 71de5307..4b337154 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,4 +1,5 @@ coverage ==7.5.3 +cryptography==42.0.8 mosaicml-streaming==0.7.6 pytest ==8.2.* pytest-cov ==5.0.0 diff --git a/src/litdata/constants.py b/src/litdata/constants.py index b001507b..0cb759fa 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -32,6 +32,7 @@ _BOTO3_AVAILABLE = RequirementCache("boto3") _TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio") _ZSTD_AVAILABLE = RequirementCache("zstd") +_CRYPTOGRAPHY_AVAILABLE = RequirementCache("cryptography") _GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage") _TQDM_AVAILABLE = RequirementCache("tqdm") _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk") diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index f7d2ce89..56480a9f 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -52,6 +52,7 @@ from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads from litdata.utilities.broadcast import broadcast_object from litdata.utilities.dataset_utilities import load_index_file +from litdata.utilities.encryption import Encryption from litdata.utilities.packing import _pack_greedily logger = logging.Logger(__name__) @@ -519,12 +520,12 @@ def _create_cache(self) -> None: if isinstance(self.data_recipe, DataTransformRecipe): return - self.cache = Cache( self.cache_chunks_dir, chunk_bytes=self.data_recipe.chunk_bytes, chunk_size=self.data_recipe.chunk_size, compression=self.data_recipe.compression, + encryption=self.data_recipe.encryption, writer_chunk_index=self.writer_starting_chunk_index, ) self.cache._reader._rank = _get_node_rank() * self.num_workers + self.worker_index @@ -714,6 +715,7 @@ class _Result: num_bytes: Optional[str] = None data_format: Optional[str] = None compression: 
Optional[str] = None + encryption: Optional[Encryption] = None num_chunks: Optional[int] = None num_bytes_per_chunk: Optional[List[int]] = None @@ -743,6 +745,7 @@ def __init__( chunk_size: Optional[int] = None, chunk_bytes: Optional[Union[int, str]] = None, compression: Optional[str] = None, + encryption: Optional[Encryption] = None, ): super().__init__() if chunk_size is not None and chunk_bytes is not None: @@ -751,6 +754,7 @@ def __init__( self.chunk_size = chunk_size self.chunk_bytes = 1 << 26 if chunk_size is None and chunk_bytes is None else chunk_bytes self.compression = compression + self.encryption = encryption @abstractmethod def prepare_structure(self, input_dir: Optional[str]) -> List[T]: @@ -794,7 +798,6 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul # The platform can't store more than 1024 entries. # Note: This isn't really used right now, so it is fine to skip if too big. num_bytes_per_chunk = [c["chunk_size"] for c in config["chunks"]] if num_chunks < 1024 else [] - return _Result( size=size, num_bytes=num_bytes, @@ -944,7 +947,6 @@ def run(self, data_recipe: DataRecipe) -> None: """The `DataProcessor.run(...)` method triggers the data recipe processing over your dataset.""" if not isinstance(data_recipe, DataRecipe): raise ValueError("The provided value should be a data recipe.") - if not self.use_checkpoint and isinstance(data_recipe, DataChunkRecipe): # clean up checkpoints if not using checkpoints self._cleanup_checkpoints() @@ -959,7 +961,6 @@ def run(self, data_recipe: DataRecipe) -> None: # Call the setup method of the user user_items: List[Any] = data_recipe.prepare_structure(self.input_dir.path if self.input_dir else None) - if not isinstance(user_items, (list, StreamingDataLoader)): raise ValueError("The `prepare_structure` should return a list of item metadata.") diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 2734d5f0..86c2f3f4 100644 --- 
a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -46,6 +46,7 @@ _resolve_dir, ) from litdata.utilities._pytree import tree_flatten +from litdata.utilities.encryption import Encryption from litdata.utilities.format import _get_tqdm_iterator_if_available @@ -147,9 +148,10 @@ def __init__( chunk_size: Optional[int], chunk_bytes: Optional[Union[int, str]], compression: Optional[str], + encryption: Optional[Encryption] = None, existing_index: Optional[Dict[str, Any]] = None, ): - super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression) + super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption) self._fn = fn self._inputs = inputs self.is_generator = False @@ -296,6 +298,7 @@ def optimize( chunk_size: Optional[int] = None, chunk_bytes: Optional[Union[int, str]] = None, compression: Optional[str] = None, + encryption: Optional[Encryption] = None, num_workers: Optional[int] = None, fast_dev_run: bool = False, num_nodes: Optional[int] = None, @@ -321,6 +324,7 @@ def optimize( chunk_size: The maximum number of elements to hold within a chunk. chunk_bytes: The maximum number of bytes to hold within a chunk. compression: The compression algorithm to use over the chunks. + encryption: The encryption algorithm to use over the chunks. num_workers: The number of workers to use during processing fast_dev_run: Whether to use process only a sub part of the inputs num_nodes: When doing remote execution, the number of nodes to use. Only supported on https://lightning.ai/. @@ -350,7 +354,6 @@ def optimize( if len(inputs) == 0: raise ValueError(f"The provided inputs should be non empty. 
Found {inputs}.") - if chunk_size is None and chunk_bytes is None: raise ValueError("Either `chunk_size` or `chunk_bytes` needs to be defined.") @@ -433,6 +436,7 @@ def optimize( chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, + encryption=encryption, existing_index=existing_index_file_content, ) ) diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index d303066e..17cf5534 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -24,6 +24,7 @@ from litdata.streaming.sampler import ChunkedIndex from litdata.streaming.serializers import Serializer from litdata.streaming.writer import BinaryWriter +from litdata.utilities.encryption import Encryption from litdata.utilities.env import _DistributedEnv, _WorkerEnv from litdata.utilities.format import _convert_bytes_to_int @@ -37,6 +38,7 @@ def __init__( subsampled_files: Optional[List[str]] = None, region_of_interest: Optional[List[Tuple[int, int]]] = None, compression: Optional[str] = None, + encryption: Optional[Encryption] = None, chunk_size: Optional[int] = None, chunk_bytes: Optional[Union[int, str]] = None, item_loader: Optional[BaseItemLoader] = None, @@ -69,6 +71,7 @@ def __init__( chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, + encryption=encryption, serializers=serializers, chunk_index=writer_chunk_index or 0, ) @@ -79,6 +82,7 @@ def __init__( max_cache_size=_convert_bytes_to_int(max_cache_size) if isinstance(max_cache_size, str) else max_cache_size, remote_input_dir=input_dir.url, compression=compression, + encryption=encryption, item_loader=item_loader, serializers=serializers, ) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index e0c82087..579ccf6b 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -30,6 +30,7 @@ from litdata.streaming.serializers import Serializer from litdata.streaming.shuffle import FullShuffle, NoShuffle, Shuffle 
from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset +from litdata.utilities.encryption import Encryption from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv from litdata.utilities.shuffle import _find_chunks_per_ranks_on_which_to_skip_deletion @@ -49,6 +50,7 @@ def __init__( serializers: Optional[Dict[str, Serializer]] = None, max_cache_size: Union[int, str] = "100GB", subsample: float = 1.0, + encryption: Optional[Encryption] = None, ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. @@ -64,6 +66,7 @@ def __init__( serializers: The serializers used to serialize and deserialize the chunks. max_cache_size: The maximum cache size used by the StreamingDataset. subsample: Float representing fraction of the dataset to be randomly sampled (e.g., 0.1 => 10% of dataset). + encryption: The encryption object to use for decrypting the data. 
""" super().__init__() @@ -119,6 +122,7 @@ def __init__( self._state_dict: Optional[Dict[str, Any]] = None self.num_workers: Optional[int] = None self.batch_size: Optional[int] = None + self._encryption = encryption def set_shuffle(self, shuffle: bool) -> None: self.shuffle = shuffle @@ -153,6 +157,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache: chunk_bytes=1, serializers=self.serializers, max_cache_size=self.max_cache_size, + encryption=self._encryption, ) cache._reader._try_load_config() diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 2213aaca..75bd3ef6 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -16,8 +16,9 @@ from abc import ABC, abstractmethod from collections import namedtuple from copy import deepcopy +from io import BytesIO, FileIO from time import sleep -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -27,6 +28,7 @@ ) from litdata.streaming.serializers import Serializer from litdata.utilities._pytree import PyTree, tree_unflatten +from litdata.utilities.encryption import Encryption, EncryptionLevel Interval = namedtuple("Interval", ["chunk_start", "roi_start_idx", "roi_end_idx", "chunk_end"]) @@ -83,7 +85,12 @@ def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None: @abstractmethod def load_item_from_chunk( - self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int + self, + index: int, + chunk_index: int, + chunk_filepath: str, + begin: int, + chunk_bytes: int, ) -> Any: """Returns an item loaded from a chunk.""" pass @@ -99,6 +106,7 @@ class PyTreeLoader(BaseItemLoader): def __init__(self) -> None: self._chunk_filepaths: Dict[str, bool] = {} + self._decrypted_chunks: Dict[int, bytes] = {} def generate_intervals(self) -> List[Interval]: intervals = [] @@ -119,7 +127,13 @@ def pre_load_chunk(self, 
chunk_index: int, chunk_filepath: str) -> None: pass def load_item_from_chunk( - self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int + self, + index: int, + chunk_index: int, + chunk_filepath: str, + begin: int, + chunk_bytes: int, + encryption: Optional[Encryption] = None, ) -> bytes: offset = (1 + (index - begin) if index >= begin else index + 1) * 4 @@ -135,18 +149,56 @@ def load_item_from_chunk( self._chunk_filepaths[chunk_filepath] = True - with open(chunk_filepath, "rb", 0) as fp: - fp.seek(offset) - pair = fp.read(8) - begin, end = np.frombuffer(pair, np.uint32) - fp.seek(begin) - data = fp.read(end - begin) + if self._config["encryption"]: + data = self._load_encrypted_data(chunk_filepath, chunk_index, offset, encryption) + else: + with open(chunk_filepath, "rb", 0) as fp: + data = self._load_data(fp, offset) # check for mosaic mds format if "format" in self._config and self._config["format"] == "mds": return self.mds_deserialize(data, chunk_index) return self.deserialize(data) + def _load_encrypted_data( + self, chunk_filepath: str, chunk_index: int, offset: int, encryption: Optional[Encryption] + ) -> bytes: + """Load and decrypt data from chunk based on the encryption configuration.""" + + # Validate the provided encryption object against the expected configuration. 
+ self._validate_encryption(encryption) + + # chunk-level decryption + if self._config["encryption"]["level"] == EncryptionLevel.CHUNK: + decrypted_data = self._decrypted_chunks.get(chunk_index, None) + if decrypted_data is None: + with open(chunk_filepath, "rb", 0) as fp: + encrypted_data = fp.read() + decrypted_data = encryption.decrypt(encrypted_data) # type: ignore + # Store the decrypted chunk to avoid re-decryption, + # also allows to free the previous chunk from the memory + self._decrypted_chunks = {chunk_index: decrypted_data} + data = self._load_data(BytesIO(decrypted_data), offset) + + # sample-level decryption + elif self._config["encryption"]["level"] == EncryptionLevel.SAMPLE: + with open(chunk_filepath, "rb", 0) as fp: + data = self._load_data(fp, offset) + data = encryption.decrypt(data) # type: ignore + + else: + raise ValueError("Invalid encryption level.") + + return data + + def _load_data(self, fp: Union[FileIO, BytesIO], offset: int) -> bytes: + """Load the data from the file pointer.""" + fp.seek(offset) + pair = fp.read(8) + begin, end = np.frombuffer(pair, np.uint32) + fp.seek(begin) + return fp.read(end - begin) + def mds_deserialize(self, raw_item_data: bytes, chunk_index: int) -> "PyTree": """Deserialize the mds raw bytes into their python equivalent.""" idx = 0 @@ -184,6 +236,15 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None: if os.path.exists(chunk_filepath): os.remove(chunk_filepath) + def _validate_encryption(self, encryption: Optional[Encryption]) -> None: + """Validate the encryption object.""" + if not encryption: + raise ValueError("Data is encrypted but no encryption object was provided.") + if encryption.algorithm != self._config["encryption"]["algorithm"]: + raise ValueError("Encryption algorithm mismatch.") + if encryption.level != self._config["encryption"]["level"]: + raise ValueError("Encryption level mismatch.") + class TokensLoader(BaseItemLoader): def __init__(self, block_size: int): @@ -256,7 
+317,12 @@ def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None: self._load_chunk(chunk_index, chunk_filepath) def load_item_from_chunk( - self, index: int, chunk_index: int, chunk_filepath: str, begin: int, chunk_bytes: int + self, + index: int, + chunk_index: int, + chunk_filepath: str, + begin: int, + chunk_bytes: int, ) -> torch.Tensor: if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath): del self._chunk_filepaths[chunk_filepath] diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 7ec85a94..0361f573 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -24,6 +24,7 @@ from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader from litdata.streaming.sampler import ChunkedIndex from litdata.streaming.serializers import Serializer, _get_serializers +from litdata.utilities.encryption import Encryption from litdata.utilities.env import _DistributedEnv, _WorkerEnv warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*") @@ -165,6 +166,7 @@ def __init__( max_cache_size: Optional[Union[int, str]] = None, remote_input_dir: Optional[str] = None, compression: Optional[str] = None, + encryption: Optional[Encryption] = None, item_loader: Optional[BaseItemLoader] = None, serializers: Optional[Dict[str, Serializer]] = None, ) -> None: @@ -177,6 +179,7 @@ def __init__( remote_input_dir: The path to a remote folder where the data are located. The scheme needs to be added to the path. compression: The algorithm to decompress the chunks. + encryption: The algorithm to decrypt the chunks or samples. item_loader: The chunk sampler to create sub arrays from a chunk. max_cache_size: The maximum cache size used by the reader when fetching the chunks. serializers: Provide your own serializers. 
@@ -192,6 +195,7 @@ def __init__( raise FileNotFoundError(f"The provided cache_dir `{self._cache_dir}` doesn't exist.") self._compression = compression + self._encryption = encryption self._intervals: Optional[List[str]] = None self.subsampled_files = subsampled_files self.region_of_interest = region_of_interest @@ -272,9 +276,15 @@ def read(self, index: ChunkedIndex) -> Any: # Fetch the element chunk_filepath, begin, chunk_bytes = self.config[index] - item = self._item_loader.load_item_from_chunk( - index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes - ) + + if isinstance(self._item_loader, PyTreeLoader): + item = self._item_loader.load_item_from_chunk( + index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes, self._encryption + ) + else: + item = self._item_loader.load_item_from_chunk( + index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes + ) # We need to request deletion after the latest element has been loaded. # Otherwise, this could trigger segmentation fault error depending on the item loader used. 
diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index 56891047..577fbbe3 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -27,6 +27,7 @@ from litdata.streaming.compression import _COMPRESSORS, Compressor from litdata.streaming.serializers import Serializer, _get_serializers from litdata.utilities._pytree import PyTree, tree_flatten, treespec_dumps +from litdata.utilities.encryption import Encryption, EncryptionLevel from litdata.utilities.env import _DistributedEnv, _WorkerEnv from litdata.utilities.format import _convert_bytes_to_int, _human_readable_bytes @@ -49,6 +50,7 @@ def __init__( chunk_size: Optional[int] = None, chunk_bytes: Optional[Union[int, str]] = None, compression: Optional[str] = None, + encryption: Optional[Encryption] = None, follow_tensor_dimension: bool = True, serializers: Optional[Dict[str, Serializer]] = None, chunk_index: Optional[int] = None, @@ -60,6 +62,7 @@ def __init__( chunk_bytes: The maximum number of bytes within a chunk. chunk_size: The maximum number of items within a chunk. compression: The compression algorithm to use. + encryption: The encryption algorithm to use. serializers: Provide your own serializers. chunk_index: The index of the chunk to start from. 
@@ -79,6 +82,7 @@ def __init__( self._chunk_size = chunk_size self._chunk_bytes = _convert_bytes_to_int(chunk_bytes) if isinstance(chunk_bytes, str) else chunk_bytes self._compression = compression + self._encryption = encryption self._data_format: Optional[List[str]] = None self._data_spec: Optional[PyTree] = None @@ -143,6 +147,7 @@ def get_config(self) -> Dict[str, Any]: "chunk_bytes": self._chunk_bytes, "data_format": self._data_format, "data_spec": treespec_dumps(self._data_spec) if self._data_spec else None, + "encryption": self._encryption.state_dict() if self._encryption else None, } def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]: @@ -238,6 +243,11 @@ def _create_chunk(self, filename: str, on_done: bool = False) -> bytes: current_chunk_bytes = sum([item.bytes for item in items]) + # Whether to encrypt the data at the chunk level + if self._encryption and self._encryption.level == EncryptionLevel.CHUNK: + data = self._encryption.encrypt(data) + current_chunk_bytes = len(data) + if self._chunk_bytes and current_chunk_bytes > self._chunk_bytes: warnings.warn( f"An item was larger than the target chunk size ({_human_readable_bytes(self._chunk_bytes)})." 
@@ -293,6 +303,11 @@ def add_item(self, index: int, items: Any) -> Optional[str]: raise ValueError(f"The provided index {index} already exists in the cache.") data, dim = self.serialize(items) + + # Whether to encrypt the data at the sample level + if self._encryption and self._encryption.level == EncryptionLevel.SAMPLE: + data = self._encryption.encrypt(data) + self._serialized_items[index] = Item( index=index, data=data, diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py index 7e0c4838..a45b2f51 100644 --- a/src/litdata/utilities/dataset_utilities.py +++ b/src/litdata/utilities/dataset_utilities.py @@ -185,5 +185,6 @@ def adapt_mds_shards_to_chunks(data: Dict[str, Any]) -> Dict[str, Any]: "data_format": shards[0]["column_encodings"], "format": shards[0]["format"], "data_spec": json.dumps(data_spec), + "encryption": None, } return data diff --git a/src/litdata/utilities/encryption.py b/src/litdata/utilities/encryption.py new file mode 100644 index 00000000..ea93e728 --- /dev/null +++ b/src/litdata/utilities/encryption.py @@ -0,0 +1,268 @@ +import base64 +import json +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Literal, Tuple, Union, get_args + +from litdata.constants import _CRYPTOGRAPHY_AVAILABLE + +if _CRYPTOGRAPHY_AVAILABLE: + from cryptography.fernet import Fernet + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import padding, rsa + from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + +@dataclass +class EncryptionLevel: + SAMPLE = "sample" + CHUNK = "chunk" + + +EncryptionLevelType = Literal["sample", "chunk"] + + +class Encryption(ABC): + """Base class for encryption algorithm.""" + + @property + @abstractmethod + def algorithm(self) -> str: + pass + + @abstractmethod + def encrypt(self, data: bytes) -> bytes: + pass + + @abstractmethod + def decrypt(self, 
data: bytes) -> bytes: + pass + + @abstractmethod + def state_dict(self) -> dict: + pass + + @abstractmethod + def save(self, file_path: str) -> None: + pass + + @classmethod + @abstractmethod + def load(cls, file_path: str, password: str) -> Any: + pass + + +class FernetEncryption(Encryption): + """Encryption for the Fernet package. + + Adapted from: https://cryptography.io/en/latest/fernet/ + + """ + + def __init__( + self, + password: str, + level: EncryptionLevelType = "sample", + ) -> None: + super().__init__() + if not _CRYPTOGRAPHY_AVAILABLE: + raise ModuleNotFoundError(str(_CRYPTOGRAPHY_AVAILABLE)) + + if level not in get_args(EncryptionLevelType): + raise ValueError("The provided `level` should be either `sample` or `chunk`") + + self.password = password + self.level = level + self.salt = os.urandom(16) + self.key = self._derive_key(password, self.salt) + self.fernet = Fernet(self.key) + + @property + def algorithm(self) -> str: + return "fernet" + + def _derive_key(self, password: str, salt: bytes) -> bytes: + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=480000, + ) + return base64.urlsafe_b64encode(kdf.derive(password.encode())) + + def encrypt(self, data: bytes) -> bytes: + return self.fernet.encrypt(data) + + def decrypt(self, data: bytes) -> bytes: + return self.fernet.decrypt(data) + + def state_dict(self) -> Dict[str, Any]: + return { + "algorithm": self.algorithm, + "level": self.level, + } + + def save(self, file_path: str) -> None: + state = self.state_dict() + state["salt"] = base64.urlsafe_b64encode(self.salt).decode("utf-8") + with open(file_path, "wb") as file: + file.write(json.dumps(state).encode("utf-8")) + + @classmethod + def load(cls, file_path: str, password: str) -> "FernetEncryption": + with open(file_path, "rb") as file: + state = json.load(file) + + salt = base64.urlsafe_b64decode(state["salt"]) + instance = cls(password=password, level=state["level"]) + instance.salt = salt + instance.key = 
instance._derive_key(password, salt) + instance.fernet = Fernet(instance.key) + return instance + + +class RSAEncryption(Encryption): + """Encryption for the RSA package. + + Adapted from: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/rsa/ + + """ + + def __init__( + self, + password: str, + level: EncryptionLevelType = "sample", + ) -> None: + if not _CRYPTOGRAPHY_AVAILABLE: + raise ModuleNotFoundError(str(_CRYPTOGRAPHY_AVAILABLE)) + if level not in get_args(EncryptionLevelType): + raise ValueError("The provided `level` should be either `sample` or `chunk`") + + self.password = password + self.level = level + self.private_key, self.public_key = self._generate_keys() + + @property + def algorithm(self) -> str: + return "rsa" + + def _generate_keys(self) -> Tuple[Any, Any]: + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + public_key = private_key.public_key() + return private_key, public_key + + def encrypt(self, data: bytes) -> bytes: + if not self.public_key: + raise AttributeError("Public key not found.") + return self.public_key.encrypt( + data, + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + + def decrypt(self, data: bytes) -> bytes: + if not self.private_key: + raise AttributeError("Private key not found.") + return self.private_key.decrypt( + data, + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + + def state_dict(self) -> Dict[str, Union[str, None]]: + return { + "algorithm": self.algorithm, + "level": self.level, + } + + def __getstate__(self) -> Dict[str, Union[str, None]]: + encryption_algorithm = ( + serialization.BestAvailableEncryption(self.password.encode()) + if self.password + else serialization.NoEncryption() + ) + return { + "private_key": self.private_key.private_bytes( + encoding=serialization.Encoding.PEM, + 
format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=encryption_algorithm, + ).decode("utf-8") + if self.private_key + else None, + "public_key": self.public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode("utf-8") + if self.public_key + else None, + "password": self.password, + "level": self.level, + } + + def __setstate__(self, state: Dict[str, Union[str, None]]) -> None: + # Restore the state from the serialized data + self.password = state["password"] if state["password"] else "" + self.level = state["level"] # type: ignore + + if state["private_key"]: + self.private_key = serialization.load_pem_private_key( + state["private_key"].encode("utf-8"), + password=self.password.encode() if self.password else None, + ) + else: + self.private_key = None + + if state["public_key"]: + self.public_key = serialization.load_pem_public_key( + state["public_key"].encode("utf-8"), + ) + else: + self.public_key = None + + def _load_private_key(self, key_path: str, password: str) -> Any: + with open(key_path, "rb") as key_file: + return serialization.load_pem_private_key( + key_file.read(), + password=password.encode(), + ) + + def _load_public_key(self, key_path: str) -> Any: + with open(key_path, "rb") as key_file: + return serialization.load_pem_public_key(key_file.read()) + + def save(self, file_path: str) -> None: + with open(file_path, "wb") as file: + file.write(json.dumps(self.__getstate__()).encode("utf-8")) + + @classmethod + def load(cls, file_path: str, password: str) -> "RSAEncryption": + with open(file_path, "rb") as file: + state = json.load(file) + + instance = cls(password=password, level=state["level"]) + instance.__setstate__(state) + return instance + + def save_keys(self, private_key_path: str, public_key_path: str) -> None: + state = self.__getstate__() + if not state["private_key"] or not state["public_key"]: + raise AttributeError("Keys not found.") + 
with open(private_key_path, "wb") as key_file: + key_file.write(state["private_key"].encode("utf-8")) + + with open(public_key_path, "wb") as key_file: + key_file.write(state["public_key"].encode("utf-8")) + + def load_keys(self, private_key_path: str, public_key_path: str, password: str) -> None: + self.private_key = self._load_private_key(private_key_path, password) + self.public_key = self._load_public_key(public_key_path) diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py index 5bce87b7..a750ad5f 100644 --- a/tests/processing/test_functions.py +++ b/tests/processing/test_functions.py @@ -2,10 +2,14 @@ import sys from unittest import mock +import cryptography +import numpy as np import pytest from litdata import StreamingDataset, merge_datasets, optimize, walk from litdata.processing.functions import _get_input_dir, _resolve_dir from litdata.streaming.cache import Cache +from litdata.utilities.encryption import FernetEncryption, RSAEncryption +from PIL import Image @pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.") @@ -64,6 +68,11 @@ def another_fn(i: int): return i, i**2 +def random_image(index): + fake_img = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)) + return {"image": fake_img, "class": index} + + @pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow") def test_optimize_append_overwrite(tmpdir): output_dir = str(tmpdir / "output_dir") @@ -272,3 +281,159 @@ def test_merge_datasets(tmpdir): assert len(ds) == 20 assert ds[:] == list(range(20)) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows") +def test_optimize_with_fernet_encryption(tmpdir): + output_dir = str(tmpdir / "output_dir") + + # ----------------- sample level ----------------- + fernet = FernetEncryption(password="password", level="sample") + optimize( + fn=compress, + inputs=list(range(5)), + num_workers=1, + output_dir=output_dir, 
+ chunk_bytes="64MB", + encryption=fernet, + ) + + ds = StreamingDataset(output_dir, encryption=fernet) + assert len(ds) == 5 + assert ds[:] == [(i, i**2) for i in range(5)] + + # ----------------- chunk level ----------------- + fernet = FernetEncryption(password="password", level="chunk") + optimize( + fn=compress, + inputs=list(range(5)), + num_workers=1, + output_dir=output_dir, + chunk_bytes="64MB", + encryption=fernet, + mode="overwrite", + ) + + ds = StreamingDataset(output_dir, encryption=fernet) + assert len(ds) == 5 + assert ds[:] == [(i, i**2) for i in range(5)] + + # ----------------- test with appending more ----------------- + optimize( + fn=compress, + inputs=list(range(5, 10)), + num_workers=1, + output_dir=output_dir, + chunk_bytes="64MB", + encryption=fernet, + mode="append", + ) + ds = StreamingDataset(output_dir, encryption=fernet) + assert len(ds) == 10 + assert ds[:] == [(i, i**2) for i in range(10)] + + # ----------------- decrypt with different conf ----------------- + ds = StreamingDataset(output_dir) + with pytest.raises(ValueError, match="Data is encrypted but no encryption object was provided."): + ds[0] + + fernet.level = "sample" + ds = StreamingDataset(output_dir, encryption=fernet) + with pytest.raises(ValueError, match="Encryption level mismatch."): + ds[0] + + fernet = FernetEncryption(password="password", level="chunk") + ds = StreamingDataset(output_dir, encryption=fernet) + with pytest.raises(cryptography.fernet.InvalidToken, match=""): + ds[0] + + # ----------------- test with other alg ----------------- + rsa = RSAEncryption(password="password", level="sample") + ds = StreamingDataset(output_dir, encryption=rsa) + with pytest.raises(ValueError, match="Encryption algorithm mismatch."): + ds[0] + + # ----------------- test with random images ----------------- + + fernet = FernetEncryption(password="password", level="chunk") + optimize( + fn=random_image, + inputs=list(range(5)), + num_workers=1, + output_dir=output_dir, + 
chunk_bytes="64MB", + encryption=fernet, + mode="overwrite", + ) + + ds = StreamingDataset(output_dir, encryption=fernet) + + assert len(ds) == 5 + assert ds[0]["class"] == 0 + + +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows") +def test_optimize_with_rsa_encryption(tmpdir): + output_dir = str(tmpdir / "output_dir") + + # ----------------- sample level ----------------- + rsa = RSAEncryption(password="password", level="sample") + optimize( + fn=compress, + inputs=list(range(5)), + num_workers=1, + output_dir=output_dir, + chunk_bytes="64MB", + encryption=rsa, + ) + + ds = StreamingDataset(output_dir, encryption=rsa) + assert len(ds) == 5 + assert ds[:] == [(i, i**2) for i in range(5)] + + # ----------------- chunk level ----------------- + rsa = RSAEncryption(password="password", level="chunk") + optimize( + fn=compress, + inputs=list(range(5)), + num_workers=1, + output_dir=output_dir, + chunk_bytes="64MB", + encryption=rsa, + mode="overwrite", + ) + + ds = StreamingDataset(output_dir, encryption=rsa) + assert len(ds) == 5 + assert ds[:] == [(i, i**2) for i in range(5)] + + # ----------------- test with appending more ----------------- + optimize( + fn=compress, + inputs=list(range(5, 10)), + num_workers=1, + output_dir=output_dir, + chunk_bytes="64MB", + encryption=rsa, + mode="append", + ) + ds = StreamingDataset(output_dir, encryption=rsa) + assert len(ds) == 10 + assert ds[:] == [(i, i**2) for i in range(10)] + + # ----------------- decrypt with different conf ----------------- + ds = StreamingDataset(output_dir) + with pytest.raises(ValueError, match="Data is encrypted but no encryption object was provided."): + ds[0] + + # ----------------- test with random images ----------------- + # RSA Encryption throws an error: ValueError: Encryption failed, when trying to encrypt large data + # optimize( + # fn=random_image, + # inputs=list(range(5)), + # num_workers=1, + # output_dir=output_dir, + # chunk_bytes="64MB", + # encryption=rsa, 
+ # mode="overwrite", + # ) diff --git a/tests/utilities/test_encryption.py b/tests/utilities/test_encryption.py new file mode 100644 index 00000000..4cbe44be --- /dev/null +++ b/tests/utilities/test_encryption.py @@ -0,0 +1,79 @@ +import os + +import pytest +from litdata.utilities.encryption import FernetEncryption, RSAEncryption + + +def test_fernet_encryption(tmpdir): + password = "password" + data = b"test data" + fernet = FernetEncryption(password) + encrypted_data = fernet.encrypt(data) + decrypted_data = fernet.decrypt(encrypted_data) + assert data == decrypted_data + assert data != encrypted_data + assert decrypted_data != encrypted_data + assert isinstance(encrypted_data, bytes) + assert isinstance(decrypted_data, bytes) + assert isinstance(fernet.algorithm, str) + assert fernet.algorithm == "fernet" + assert fernet.password == password + assert fernet.key == fernet._derive_key(password, fernet.salt) + assert isinstance(fernet._derive_key(password, os.urandom(16)), bytes) + + # ------ Test for ValueError ------ + with pytest.raises(ValueError, match="The provided `level` should be either `sample` or `chunk`"): + fernet = FernetEncryption(password, level="test") + + # ------ Test for saving and loading fernet instance------ + file_path = tmpdir.join("fernet.txt") + fernet.save(file_path) + fernet_loaded = FernetEncryption.load(file_path, password) + assert fernet_loaded.password == password + assert fernet_loaded.level == fernet.level + assert fernet_loaded.salt == fernet.salt + assert fernet_loaded.key == fernet.key + + decrypted_data_loaded = fernet_loaded.decrypt(encrypted_data) + assert data == decrypted_data_loaded + + +def test_rsa_encryption(tmpdir): + password = "password" + data = b"test data" + rsa = RSAEncryption(password) + encrypted_data = rsa.encrypt(data) + decrypted_data = rsa.decrypt(encrypted_data) + assert data == decrypted_data + assert data != encrypted_data + assert decrypted_data != encrypted_data + assert isinstance(encrypted_data, 
bytes) + assert isinstance(decrypted_data, bytes) + assert isinstance(rsa.algorithm, str) + assert rsa.algorithm == "rsa" + + # ------ Test for ValueError ------ + with pytest.raises(ValueError, match="The provided `level` should be either `sample` or `chunk`"): + rsa = RSAEncryption(password, level="test") + + # ------ Test for saving and loading rsa instance------ + file_path = tmpdir.join("rsa.txt") + rsa.save(file_path) + + rsa_loaded = RSAEncryption.load(file_path, password) + assert rsa_loaded.level == rsa.level + assert rsa_loaded.password == rsa.password + + decrypted_data_loaded = rsa_loaded.decrypt(encrypted_data) + assert data == decrypted_data_loaded + + # ------ Test for saving and loading rsa instance with password------ + private_key_path = tmpdir.join("rsa_private.pem") + public_key_path = tmpdir.join("rsa_public.pem") + rsa.save_keys(private_key_path, public_key_path) + + rsa_keys_loaded = RSAEncryption(password) + rsa_keys_loaded.load_keys(private_key_path, public_key_path, password) + + decrypted_data_loaded = rsa_keys_loaded.decrypt(encrypted_data) + assert data == decrypted_data_loaded