From e0a53be1fd2f854872d2da7a6e6d218f3fbd7090 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Wed, 17 Jul 2024 12:25:58 +0545 Subject: [PATCH] refactor: load encrypted data to avoid reading of chunk if already in memeory --- src/litdata/streaming/item_loader.py | 48 +++++++++++++++------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index f13ec765..0a4adc98 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -28,7 +28,7 @@ ) from litdata.streaming.serializers import Serializer from litdata.utilities._pytree import PyTree, tree_unflatten -from litdata.utilities.encryption import Encryption +from litdata.utilities.encryption import Encryption, EncryptionLevel Interval = namedtuple("Interval", ["chunk_start", "roi_start_idx", "roi_end_idx", "chunk_end"]) @@ -149,10 +149,10 @@ def load_item_from_chunk( self._chunk_filepaths[chunk_filepath] = True - with open(chunk_filepath, "rb", 0) as fp: - if self._config["encryption"]: - data = self._load_encrypted_data(fp, chunk_index, offset, encryption) - else: + if self._config["encryption"]: + data = self._load_encrypted_data(chunk_filepath, chunk_index, offset, encryption) + else: + with open(chunk_filepath, "rb", 0) as fp: data = self._load_data(fp, offset) # check for mosaic mds format @@ -161,29 +161,33 @@ def load_item_from_chunk( return self.deserialize(data) def _load_encrypted_data( - self, fp: Union[FileIO, BytesIO], chunk_index: int, offset: int, encryption: Optional[Encryption] + self, chunk_filepath: str, chunk_index: int, offset: int, encryption: Optional[Encryption] ) -> bytes: - """Load and decrypt data from a file pointer based on the encryption configuration.""" + """Load and decrypt data from chunk based on the encryption configuration.""" # Validate the provided encryption object against the expected configuration. self._validate_encryption(encryption) - # If the encryption level is set to 'chunk', decrypt the entire chunk first. - if self._config["encryption"]["level"] == "chunk": - if chunk_index in self._decrypted_chunks: - decrypted_data = self._decrypted_chunks[chunk_index] - else: - encrypted_data = fp.read() - decrypted_data = encryption.decrypt(encrypted_data) # type: ignore - # This would enable us to free the previous chunk from the memory. - self._decrypted_chunks = {chunk_index: decrypted_data} - fp = BytesIO(decrypted_data) - - data = self._load_data(fp, offset) + # chunk-level decryption + if self._config["encryption"]["level"] == EncryptionLevel.CHUNK.value: + decrypted_data = self._decrypted_chunks.get(chunk_index, None) + if decrypted_data is None: + with open(chunk_filepath, "rb", 0) as fp: + encrypted_data = fp.read() + decrypted_data = encryption.decrypt(encrypted_data) # type: ignore + # Store the decrypted chunk to avoid re-decryption, + # also allows to free the previous chunk from the memory + self._decrypted_chunks = {chunk_index: decrypted_data} + data = self._load_data(BytesIO(decrypted_data), offset) + + # sample-level decryption + elif self._config["encryption"]["level"] == EncryptionLevel.SAMPLE.value: + with open(chunk_filepath, "rb", 0) as fp: + data = self._load_data(fp, offset) + data = encryption.decrypt(data) # type: ignore - # If the encryption level is set to 'sample', decrypt each individual sample after loading. - if self._config["encryption"]["level"] == "sample": - data = encryption.decrypt(data) # type: ignore + else: + raise ValueError("Invalid encryption level.") return data