Skip to content

Commit

Permalink
fix: Update encryption level handling with enums
Browse files (browse the repository at this point in the history)
  • Loading branch information
bhimrazy committed Jul 18, 2024
1 parent b24e1fe commit bc43c45
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/litdata/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def _load_encrypted_data(
self._validate_encryption(encryption)

# chunk-level decryption
if self._config["encryption"]["level"] == EncryptionLevel.CHUNK.value:
if self._config["encryption"]["level"] == EncryptionLevel.CHUNK:
decrypted_data = self._decrypted_chunks.get(chunk_index, None)
if decrypted_data is None:
with open(chunk_filepath, "rb", 0) as fp:
Expand All @@ -181,7 +181,7 @@ def _load_encrypted_data(
data = self._load_data(BytesIO(decrypted_data), offset)

# sample-level decryption
elif self._config["encryption"]["level"] == EncryptionLevel.SAMPLE.value:
elif self._config["encryption"]["level"] == EncryptionLevel.SAMPLE:
with open(chunk_filepath, "rb", 0) as fp:
data = self._load_data(fp, offset)
data = encryption.decrypt(data) # type: ignore
Expand Down
6 changes: 3 additions & 3 deletions src/litdata/streaming/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from litdata.streaming.compression import _COMPRESSORS, Compressor
from litdata.streaming.serializers import Serializer, _get_serializers
from litdata.utilities._pytree import PyTree, tree_flatten, treespec_dumps
from litdata.utilities.encryption import Encryption
from litdata.utilities.encryption import Encryption, EncryptionLevel
from litdata.utilities.env import _DistributedEnv, _WorkerEnv
from litdata.utilities.format import _convert_bytes_to_int, _human_readable_bytes

Expand Down Expand Up @@ -244,7 +244,7 @@ def _create_chunk(self, filename: str, on_done: bool = False) -> bytes:
current_chunk_bytes = sum([item.bytes for item in items])

# Whether to encrypt the data at the chunk level
if self._encryption and self._encryption.level == "chunk":
if self._encryption and self._encryption.level == EncryptionLevel.CHUNK:
data = self._encryption.encrypt(data)
current_chunk_bytes = len(data)

Expand Down Expand Up @@ -305,7 +305,7 @@ def add_item(self, index: int, items: Any) -> Optional[str]:
data, dim = self.serialize(items)

# Whether to encrypt the data at the sample level
if self._encryption and self._encryption.level == "sample":
if self._encryption and self._encryption.level == EncryptionLevel.SAMPLE:
data = self._encryption.encrypt(data)

self._serialized_items[index] = Item(
Expand Down
9 changes: 5 additions & 4 deletions src/litdata/utilities/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import os
from abc import ABC, abstractmethod
from enum import Enum
from dataclasses import dataclass
from typing import Any, Dict, Literal, Tuple, Union, get_args

from litdata.constants import _CRYPTOGRAPHY_AVAILABLE
Expand All @@ -14,7 +14,8 @@
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC


class EncryptionLevel(Enum):
@dataclass
class EncryptionLevel:
SAMPLE = "sample"
CHUNK = "chunk"

Expand Down Expand Up @@ -62,7 +63,7 @@ class FernetEncryption(Encryption):
def __init__(
self,
password: str,
level: EncryptionLevelType = EncryptionLevel.SAMPLE.value,
level: EncryptionLevelType = "sample",
) -> None:
super().__init__()
if not _CRYPTOGRAPHY_AVAILABLE:
Expand Down Expand Up @@ -131,7 +132,7 @@ class RSAEncryption(Encryption):
def __init__(
self,
password: str,
level: EncryptionLevelType = EncryptionLevel.SAMPLE.value,
level: EncryptionLevelType = "sample",
) -> None:
if not _CRYPTOGRAPHY_AVAILABLE:
raise ModuleNotFoundError(str(_CRYPTOGRAPHY_AVAILABLE))
Expand Down

0 comments on commit bc43c45

Please sign in to comment.