
Commit

Merge branch 'main' into fix/race-condition-worker-failed
Borda authored Oct 9, 2024
2 parents 53b1099 + 6ceea8a commit a18e45b
Showing 3 changed files with 36 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/litdata/streaming/item_loader.py
@@ -360,6 +360,14 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None:
        del self._mmaps[chunk_index]
        os.remove(chunk_filepath)

    def close(self, chunk_index: int) -> None:
        """Release the memory-mapped file for a specific chunk index."""
        if chunk_index in self._mmaps:
            self._mmaps[chunk_index]._mmap.close()
            del self._mmaps[chunk_index]
        if chunk_index in self._buffers:
            del self._buffers[chunk_index]

    @classmethod
    def encode_data(cls, data: List[bytes], _: List[int], flattened: List[Any]) -> Tuple[bytes, Optional[int]]:
        return data[0], flattened[0].shape[0]
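
For context on the new close() method: each memory-mapped chunk keeps an OS file descriptor open until its underlying mmap object is released, so streaming through many chunks without closing them can exhaust the per-process descriptor limit. A minimal sketch of the same idea outside litdata, using only NumPy (the file path and the mmaps dictionary are illustrative, not taken from the commit):

# Minimal sketch (not part of this commit): an np.memmap holds an open file
# descriptor until its mmap object is closed.
import os
import tempfile

import numpy as np

path = os.path.join(tempfile.mkdtemp(), "chunk-0.bin")
np.arange(1024, dtype=np.uint16).tofile(path)

mmaps = {0: np.memmap(path, mode="r", dtype=np.uint16)}

# Equivalent of the new TokensLoader.close(chunk_index): close the mapping and
# drop the reference so the descriptor is returned to the OS.
mmaps[0]._mmap.close()
del mmaps[0]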
8 changes: 6 additions & 2 deletions src/litdata/streaming/reader.py
Expand Up @@ -20,7 +20,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

from litdata.streaming.config import ChunksConfig, Interval
from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader
from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
from litdata.streaming.sampler import ChunkedIndex
from litdata.streaming.serializers import Serializer, _get_serializers
from litdata.utilities.encryption import Encryption
@@ -288,7 +288,6 @@ def read(self, index: ChunkedIndex) -> Any:
        item = self._item_loader.load_item_from_chunk(
            index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes
        )

        # We need to request deletion after the latest element has been loaded.
        # Otherwise, this could trigger segmentation fault error depending on the item loader used.
        if (
@@ -302,6 +301,11 @@ def read(self, index: ChunkedIndex) -> Any:
            # inform the chunk has been completely consumed
            self._prepare_thread.delete([self._last_chunk_index])

        if index.chunk_index != self._last_chunk_index:
            # Close the memory-mapped file for the last chunk index
            if isinstance(self._item_loader, TokensLoader) and self._last_chunk_index is not None:
                self._item_loader.close(self._last_chunk_index)

            # track the new chunk index as the latest one
            self._last_chunk_index = index.chunk_index
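To summarize the ordering in read() above: the item is loaded first, deletion of the previous chunk is requested only afterwards, and when the reader moves to a different chunk the previous chunk's memory map is released for TokensLoader. A simplified, illustrative sketch of that transition step (the helper name and its reader/index arguments are placeholders, not litdata API):

# Illustrative sketch, not the actual BinaryReader code: cleanup runs only
# once the reader has moved on to a different chunk, so the previous mapping
# is never released while an item from it is still being produced.
def finish_chunk_transition(reader, index) -> None:
    if index.chunk_index != reader._last_chunk_index:
        if reader._last_chunk_index is not None:
            # release the memory map kept for the previous chunk
            reader._item_loader.close(reader._last_chunk_index)
        # remember the chunk the reader is now serving items from
        reader._last_chunk_index = index.chunk_index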
22 changes: 22 additions & 0 deletions tests/streaming/test_dataset.py
@@ -507,6 +507,28 @@ def test_dataset_for_text_tokens(tmpdir):
            break


@pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported")
def test_dataset_for_text_tokens_with_large_num_chunks(tmpdir):
    import resource

    resource.setrlimit(resource.RLIMIT_NOFILE, (1024, 1024))

    block_size = 1024
    cache = Cache(input_dir=str(tmpdir), chunk_bytes="10KB", item_loader=TokensLoader(block_size))

    for i in range(10000):
        text_ids = torch.randint(0, 10001, (torch.randint(100, 1001, (1,)).item(),)).numpy()
        cache._add_item(i, text_ids)

    cache.done()
    cache.merge()

    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)

    for _ in dataset:
        pass


def test_dataset_with_1d_array(tmpdir):
    seed_everything(42)

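The new test shrinks the process's soft file-descriptor limit so that leaked memory maps would be expected to surface as "Too many open files" while iterating over many small chunks, instead of going unnoticed. A standalone sketch of that technique, assuming a POSIX system (the 1024 value is illustrative and must stay at or below the hard limit):

# Sketch of the idea used by the test above: lower the soft RLIMIT_NOFILE so
# descriptor leaks fail loudly, then restore it afterwards. POSIX only.
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (1024, hard))
try:
    pass  # run the descriptor-hungry workload here
finally:
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard))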
