Skip to content

Commit

Permalink
refactor: load encrypted data to avoid reading of chunk if already in…
Browse files Browse the repository at this point in the history
… memory
  • Loading branch information
bhimrazy committed Jul 17, 2024
1 parent 2fce2cc commit e0a53be
Showing 1 changed file with 26 additions and 22 deletions.
48 changes: 26 additions & 22 deletions src/litdata/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from litdata.streaming.serializers import Serializer
from litdata.utilities._pytree import PyTree, tree_unflatten
from litdata.utilities.encryption import Encryption
from litdata.utilities.encryption import Encryption, EncryptionLevel

Interval = namedtuple("Interval", ["chunk_start", "roi_start_idx", "roi_end_idx", "chunk_end"])

Expand Down Expand Up @@ -149,10 +149,10 @@ def load_item_from_chunk(

self._chunk_filepaths[chunk_filepath] = True

with open(chunk_filepath, "rb", 0) as fp:
if self._config["encryption"]:
data = self._load_encrypted_data(fp, chunk_index, offset, encryption)
else:
if self._config["encryption"]:
data = self._load_encrypted_data(chunk_filepath, chunk_index, offset, encryption)
else:
with open(chunk_filepath, "rb", 0) as fp:
data = self._load_data(fp, offset)

# check for mosaic mds format
Expand All @@ -161,29 +161,33 @@ def load_item_from_chunk(
return self.deserialize(data)

def _load_encrypted_data(
self, fp: Union[FileIO, BytesIO], chunk_index: int, offset: int, encryption: Optional[Encryption]
self, chunk_filepath: str, chunk_index: int, offset: int, encryption: Optional[Encryption]
) -> bytes:
"""Load and decrypt data from a file pointer based on the encryption configuration."""
"""Load and decrypt data from chunk based on the encryption configuration."""

# Validate the provided encryption object against the expected configuration.
self._validate_encryption(encryption)

# If the encryption level is set to 'chunk', decrypt the entire chunk first.
if self._config["encryption"]["level"] == "chunk":
if chunk_index in self._decrypted_chunks:
decrypted_data = self._decrypted_chunks[chunk_index]
else:
encrypted_data = fp.read()
decrypted_data = encryption.decrypt(encrypted_data) # type: ignore
# This would enable us to free the previous chunk from the memory.
self._decrypted_chunks = {chunk_index: decrypted_data}
fp = BytesIO(decrypted_data)

data = self._load_data(fp, offset)
# chunk-level decryption
if self._config["encryption"]["level"] == EncryptionLevel.CHUNK.value:
decrypted_data = self._decrypted_chunks.get(chunk_index, None)
if decrypted_data is None:
with open(chunk_filepath, "rb", 0) as fp:
encrypted_data = fp.read()
decrypted_data = encryption.decrypt(encrypted_data) # type: ignore
# Store the decrypted chunk to avoid re-decryption;
# rebinding the dict to a single entry also lets the previous chunk be freed from memory.
self._decrypted_chunks = {chunk_index: decrypted_data}
data = self._load_data(BytesIO(decrypted_data), offset)

# sample-level decryption
elif self._config["encryption"]["level"] == EncryptionLevel.SAMPLE.value:
with open(chunk_filepath, "rb", 0) as fp:
data = self._load_data(fp, offset)
data = encryption.decrypt(data) # type: ignore

# If the encryption level is set to 'sample', decrypt each individual sample after loading.
if self._config["encryption"]["level"] == "sample":
data = encryption.decrypt(data) # type: ignore
else:
raise ValueError("Invalid encryption level.")

return data

Expand Down

0 comments on commit e0a53be

Please sign in to comment.