Feature: Add support for encryption and decryption of data at chunk/sample level (#219)
bhimrazy authored Jul 18, 2024
1 parent 44b8a40 commit 36c1f10
Showing 14 changed files with 708 additions and 19 deletions.
69 changes: 69 additions & 0 deletions README.md
@@ -654,6 +654,75 @@ Explore an example setup of litdata with MinIO in the [LitData with MinIO](https

</details>

<details>
<summary> ✅ Supports encryption and decryption of data at chunk/sample level</summary>
&nbsp;

Secure your data by applying encryption to individual samples or chunks, ensuring sensitive information is protected during storage.

This example demonstrates how to use the `FernetEncryption` class for sample-level encryption with a data optimization function.

```python
from litdata import optimize
from litdata.utilities.encryption import FernetEncryption
import numpy as np
from PIL import Image

# Initialize FernetEncryption with a password for sample-level encryption
fernet = FernetEncryption(password="your_secure_password", level="sample")
data_dir = "s3://my-bucket/optimized_data"

def random_image(index):
    """Generate a random image for demonstration purposes."""
    fake_img = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8))
    return {"image": fake_img, "class": index}

# Optimize data while applying encryption
optimize(
    fn=random_image,
    inputs=list(range(5)),  # Example inputs: [0, 1, 2, 3, 4]
    num_workers=1,
    output_dir=data_dir,
    chunk_bytes="64MB",
    encryption=fernet,
)

# Save the encryption key to a file for later use
fernet.save("fernet.pem")
```

You can load the encrypted data using the `StreamingDataset` class as follows:

```python
from litdata import StreamingDataset
from litdata.utilities.encryption import FernetEncryption

# Load the encryption key
fernet = FernetEncryption(password="your_secure_password", level="sample")
fernet.load("fernet.pem")

# Create a streaming dataset for reading the encrypted samples
ds = StreamingDataset(input_dir=data_dir, encryption=fernet)
```
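
Reading back works like any other `StreamingDataset`: samples are decrypted transparently on access. A minimal sketch, continuing from the snippets above (and assuming the default, unshuffled order):

```python
# Continues from above: `ds` streams the encrypted samples from `data_dir`
sample = ds[0]                      # decrypted on access
print(sample["class"])              # the integer label written by `random_image`
print(sample["image"])              # the PIL image written by `random_image`

# Iteration decrypts each sample the same way
for sample in ds:
    image, label = sample["image"], sample["class"]
```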

If you want to implement your own encryption method, you can subclass the `Encryption` class and define the necessary methods:

```python
from litdata.utilities.encryption import Encryption

class CustomEncryption(Encryption):
    def encrypt(self, data):
        # Implement your custom encryption logic here
        return data

    def decrypt(self, data):
        # Implement your custom decryption logic here
        return data
```

With this setup, you can ensure that your data remains secure while maintaining flexibility in how you handle encryption.
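
Encryption can also be applied per chunk rather than per sample. A minimal sketch of the chunk-level variant, assuming the same `random_image` function and arguments as above (the output path below is just an illustrative name):

```python
from litdata import optimize
from litdata.utilities.encryption import FernetEncryption

# Encrypt whole chunks instead of individual samples
fernet = FernetEncryption(password="your_secure_password", level="chunk")

optimize(
    fn=random_image,                                    # same function as in the sample-level example
    inputs=list(range(5)),
    num_workers=1,
    output_dir="s3://my-bucket/optimized_data_chunks",  # illustrative output location
    chunk_bytes="64MB",
    encryption=fernet,
)
```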
</details>

&nbsp;

----
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -1,4 +1,5 @@
coverage ==7.5.3
cryptography==42.0.8
mosaicml-streaming==0.7.6
pytest ==8.2.*
pytest-cov ==5.0.0
1 change: 1 addition & 0 deletions src/litdata/constants.py
@@ -32,6 +32,7 @@
_BOTO3_AVAILABLE = RequirementCache("boto3")
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_ZSTD_AVAILABLE = RequirementCache("zstd")
_CRYPTOGRAPHY_AVAILABLE = RequirementCache("cryptography")
_GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
_TQDM_AVAILABLE = RequirementCache("tqdm")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
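The new `_CRYPTOGRAPHY_AVAILABLE` flag follows the same `RequirementCache` pattern as the other optional dependencies above. As a hedged illustration (the real guard lives in the encryption utilities, which are not shown in this excerpt), such a flag is typically consulted like this:

```python
from litdata.constants import _CRYPTOGRAPHY_AVAILABLE

if _CRYPTOGRAPHY_AVAILABLE:
    # Imported lazily so litdata still works without the optional dependency
    from cryptography.fernet import Fernet

def _require_cryptography() -> None:
    # Hypothetical helper: fail fast with the RequirementCache message
    # when `cryptography` is not installed.
    if not _CRYPTOGRAPHY_AVAILABLE:
        raise ModuleNotFoundError(str(_CRYPTOGRAPHY_AVAILABLE))
```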
9 changes: 5 additions & 4 deletions src/litdata/processing/data_processor.py
@@ -52,6 +52,7 @@
from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads
from litdata.utilities.broadcast import broadcast_object
from litdata.utilities.dataset_utilities import load_index_file
from litdata.utilities.encryption import Encryption
from litdata.utilities.packing import _pack_greedily

logger = logging.Logger(__name__)
@@ -519,12 +520,12 @@ def _create_cache(self) -> None:

if isinstance(self.data_recipe, DataTransformRecipe):
return

self.cache = Cache(
self.cache_chunks_dir,
chunk_bytes=self.data_recipe.chunk_bytes,
chunk_size=self.data_recipe.chunk_size,
compression=self.data_recipe.compression,
encryption=self.data_recipe.encryption,
writer_chunk_index=self.writer_starting_chunk_index,
)
self.cache._reader._rank = _get_node_rank() * self.num_workers + self.worker_index
@@ -714,6 +715,7 @@ class _Result:
num_bytes: Optional[str] = None
data_format: Optional[str] = None
compression: Optional[str] = None
encryption: Optional[Encryption] = None
num_chunks: Optional[int] = None
num_bytes_per_chunk: Optional[List[int]] = None

@@ -743,6 +745,7 @@ def __init__(
chunk_size: Optional[int] = None,
chunk_bytes: Optional[Union[int, str]] = None,
compression: Optional[str] = None,
encryption: Optional[Encryption] = None,
):
super().__init__()
if chunk_size is not None and chunk_bytes is not None:
@@ -751,6 +754,7 @@
self.chunk_size = chunk_size
self.chunk_bytes = 1 << 26 if chunk_size is None and chunk_bytes is None else chunk_bytes
self.compression = compression
self.encryption = encryption

@abstractmethod
def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
@@ -794,7 +798,6 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul
# The platform can't store more than 1024 entries.
# Note: This isn't really used right now, so it is fine to skip if too big.
num_bytes_per_chunk = [c["chunk_size"] for c in config["chunks"]] if num_chunks < 1024 else []

return _Result(
size=size,
num_bytes=num_bytes,
@@ -944,7 +947,6 @@ def run(self, data_recipe: DataRecipe) -> None:
"""The `DataProcessor.run(...)` method triggers the data recipe processing over your dataset."""
if not isinstance(data_recipe, DataRecipe):
raise ValueError("The provided value should be a data recipe.")

if not self.use_checkpoint and isinstance(data_recipe, DataChunkRecipe):
# clean up checkpoints if not using checkpoints
self._cleanup_checkpoints()
@@ -959,7 +961,6 @@ def run(self, data_recipe: DataRecipe) -> None:

# Call the setup method of the user
user_items: List[Any] = data_recipe.prepare_structure(self.input_dir.path if self.input_dir else None)

if not isinstance(user_items, (list, StreamingDataLoader)):
raise ValueError("The `prepare_structure` should return a list of item metadata.")

8 changes: 6 additions & 2 deletions src/litdata/processing/functions.py
@@ -46,6 +46,7 @@
_resolve_dir,
)
from litdata.utilities._pytree import tree_flatten
from litdata.utilities.encryption import Encryption
from litdata.utilities.format import _get_tqdm_iterator_if_available


@@ -147,9 +148,10 @@ def __init__(
chunk_size: Optional[int],
chunk_bytes: Optional[Union[int, str]],
compression: Optional[str],
encryption: Optional[Encryption] = None,
existing_index: Optional[Dict[str, Any]] = None,
):
super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption)
self._fn = fn
self._inputs = inputs
self.is_generator = False
@@ -296,6 +298,7 @@ def optimize(
chunk_size: Optional[int] = None,
chunk_bytes: Optional[Union[int, str]] = None,
compression: Optional[str] = None,
encryption: Optional[Encryption] = None,
num_workers: Optional[int] = None,
fast_dev_run: bool = False,
num_nodes: Optional[int] = None,
@@ -321,6 +324,7 @@
chunk_size: The maximum number of elements to hold within a chunk.
chunk_bytes: The maximum number of bytes to hold within a chunk.
compression: The compression algorithm to use over the chunks.
encryption: The encryption algorithm to use over the chunks.
num_workers: The number of workers to use during processing
fast_dev_run: Whether to use process only a sub part of the inputs
num_nodes: When doing remote execution, the number of nodes to use. Only supported on https://lightning.ai/.
@@ -350,7 +354,6 @@ def optimize(

if len(inputs) == 0:
raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")

if chunk_size is None and chunk_bytes is None:
raise ValueError("Either `chunk_size` or `chunk_bytes` needs to be defined.")

@@ -433,6 +436,7 @@ def optimize(
chunk_size=chunk_size,
chunk_bytes=chunk_bytes,
compression=compression,
encryption=encryption,
existing_index=existing_index_file_content,
)
)
4 changes: 4 additions & 0 deletions src/litdata/streaming/cache.py
@@ -24,6 +24,7 @@
from litdata.streaming.sampler import ChunkedIndex
from litdata.streaming.serializers import Serializer
from litdata.streaming.writer import BinaryWriter
from litdata.utilities.encryption import Encryption
from litdata.utilities.env import _DistributedEnv, _WorkerEnv
from litdata.utilities.format import _convert_bytes_to_int

@@ -37,6 +38,7 @@ def __init__(
subsampled_files: Optional[List[str]] = None,
region_of_interest: Optional[List[Tuple[int, int]]] = None,
compression: Optional[str] = None,
encryption: Optional[Encryption] = None,
chunk_size: Optional[int] = None,
chunk_bytes: Optional[Union[int, str]] = None,
item_loader: Optional[BaseItemLoader] = None,
@@ -69,6 +71,7 @@ def __init__(
chunk_size=chunk_size,
chunk_bytes=chunk_bytes,
compression=compression,
encryption=encryption,
serializers=serializers,
chunk_index=writer_chunk_index or 0,
)
@@ -79,6 +82,7 @@ def __init__(
max_cache_size=_convert_bytes_to_int(max_cache_size) if isinstance(max_cache_size, str) else max_cache_size,
remote_input_dir=input_dir.url,
compression=compression,
encryption=encryption,
item_loader=item_loader,
serializers=serializers,
)
5 changes: 5 additions & 0 deletions src/litdata/streaming/dataset.py
@@ -30,6 +30,7 @@
from litdata.streaming.serializers import Serializer
from litdata.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset
from litdata.utilities.encryption import Encryption
from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
from litdata.utilities.shuffle import _find_chunks_per_ranks_on_which_to_skip_deletion

@@ -49,6 +50,7 @@ def __init__(
serializers: Optional[Dict[str, Serializer]] = None,
max_cache_size: Union[int, str] = "100GB",
subsample: float = 1.0,
encryption: Optional[Encryption] = None,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -64,6 +66,7 @@ def __init__(
serializers: The serializers used to serialize and deserialize the chunks.
max_cache_size: The maximum cache size used by the StreamingDataset.
subsample: Float representing fraction of the dataset to be randomly sampled (e.g., 0.1 => 10% of dataset).
encryption: The encryption object to use for decrypting the data.
"""
super().__init__()
@@ -119,6 +122,7 @@ def __init__(
self._state_dict: Optional[Dict[str, Any]] = None
self.num_workers: Optional[int] = None
self.batch_size: Optional[int] = None
self._encryption = encryption

def set_shuffle(self, shuffle: bool) -> None:
self.shuffle = shuffle
@@ -153,6 +157,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
chunk_bytes=1,
serializers=self.serializers,
max_cache_size=self.max_cache_size,
encryption=self._encryption,
)
cache._reader._try_load_config()
