From 3a832aebd4009d16023bf8e57368b75f2beb40d3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 11 Oct 2024 13:39:03 +0100 Subject: [PATCH] update --- src/litdata/streaming/config.py | 3 ++- src/litdata/streaming/item_loader.py | 14 +++++++------- src/litdata/streaming/reader.py | 6 +++--- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index df0ea012..985ddeee 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -226,7 +226,8 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]: begin = self._intervals[index.chunk_index][0] - return local_chunkpath, begin, chunk["chunk_bytes"] + filesize_bytes = (1 + chunk["chunk_size"]) * 4 + chunk["chunk_bytes"] + return local_chunkpath, begin, filesize_bytes def _get_chunk_index_from_filename(self, chunk_filename: str) -> int: """Retrieves the associated chunk_index for a given chunk filename.""" diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 17a2d534..43e981e9 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -88,7 +88,7 @@ def load_item_from_chunk( chunk_index: int, chunk_filepath: str, begin: int, - chunk_bytes: int, + filesize_bytes: int, ) -> Any: """Returns an item loaded from a chunk.""" @@ -132,7 +132,7 @@ def load_item_from_chunk( chunk_index: int, chunk_filepath: str, begin: int, - chunk_bytes: int, + filesize_bytes: int, encryption: Optional[Encryption] = None, ) -> bytes: offset = (1 + (index - begin) if index >= begin else index + 1) * 4 @@ -141,11 +141,11 @@ def load_item_from_chunk( del self._chunk_filepaths[chunk_filepath] if chunk_filepath not in self._chunk_filepaths: - exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= chunk_bytes + exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes while not exists: sleep(0.1) - exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= chunk_bytes + exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes self._chunk_filepaths[chunk_filepath] = True @@ -329,7 +329,7 @@ def load_item_from_chunk( chunk_index: int, chunk_filepath: str, begin: int, - chunk_bytes: int, + filesize_bytes: int, ) -> torch.Tensor: assert self._block_size @@ -337,11 +337,11 @@ def load_item_from_chunk( del self._chunk_filepaths[chunk_filepath] if chunk_filepath not in self._chunk_filepaths: - exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0 + exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > filesize_bytes while not exists: sleep(0.1) - exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0 + exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > filesize_bytes self._chunk_filepaths[chunk_filepath] = True diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py index 8cb28ef7..078c9fa2 100644 --- a/src/litdata/streaming/reader.py +++ b/src/litdata/streaming/reader.py @@ -278,15 +278,15 @@ def read(self, index: ChunkedIndex) -> Any: self._last_chunk_index = index.chunk_index # Fetch the element - chunk_filepath, begin, chunk_bytes = self.config[index] + chunk_filepath, begin, filesize_bytes = self.config[index] if isinstance(self._item_loader, PyTreeLoader): item = self._item_loader.load_item_from_chunk( - index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes, self._encryption + index.index, index.chunk_index, chunk_filepath, begin, filesize_bytes, self._encryption ) else: item = self._item_loader.load_item_from_chunk( - index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes + index.index, index.chunk_index, chunk_filepath, begin, filesize_bytes ) # We need to request deletion after the latest element has been loaded. # Otherwise, this could trigger segmentation fault error depending on the item loader used.