diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 699ddc29..e0c82087 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -331,7 +331,7 @@ def __next__(self) -> Any:
                 index=index,
                 chunk_index=self.worker_chunks[self.chunk_index - 1],
                 # We provide the chunks indexes only one the first
-                chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
+                chunk_indexes=None if self.has_triggered_download else self.worker_chunks[self.chunk_index - 1 :],
                 is_last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
             )
         )
@@ -520,9 +520,12 @@ def _replay_chunks_sampling(
 
     for worker_idx, intervals in workers_intervals.items():
         for interval in intervals:
-            size = interval[-1] - interval[0]
+            size = interval[2] - interval[1]
             if indexes[worker_idx] >= size:
                 indexes[worker_idx] -= size
                 chunks_index[worker_idx] += 1
+            else:
+                # We've reached the chunk where resuming needs to take place (for this worker)
+                break
 
     return chunks_index, indexes
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index 91de1a95..88176ec2 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -14,6 +14,7 @@
 import json
 import os
 import random
+import shutil
 import sys
 from time import sleep
 from unittest import mock
@@ -21,7 +22,7 @@
 import numpy as np
 import pytest
 import torch
-from litdata import train_test_split
+from litdata import optimize, train_test_split
 from litdata.constants import _ZSTD_AVAILABLE
 from litdata.processing import functions
 from litdata.streaming import Cache
@@ -793,6 +794,57 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir):
     assert not torch.equal(batch_1, batch_2)
 
 
+def _simple_preprocess(_):
+    for _ in range(10):
+        yield torch.randint(0, 100, size=(10,), dtype=torch.int64)
+
+
+def _get_simulated_s3_dataloader(cache_dir, data_dir):
+    dataset = EmulateS3StreamingDataset(
+        input_dir=Dir(cache_dir, data_dir),
+        item_loader=TokensLoader(block_size=10),
+    )
+    return StreamingDataLoader(dataset, batch_size=2, num_workers=1)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
+    """Test resuming from a chunk past the first chunk, when subsequent chunks don't all
+    have the same size."""
+    s3_cache_dir = str(tmpdir / "s3cache")
+    optimize_cache_dir = str(tmpdir / "optimize_cache")
+    data_dir = str(tmpdir / "optimized")
+    monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", optimize_cache_dir)
+
+    optimize(
+        fn=_simple_preprocess,
+        inputs=list(range(8)),
+        output_dir=str(tmpdir / "optimized"),
+        chunk_size=190,
+        num_workers=4,
+    )
+    assert len(os.listdir(tmpdir / "optimized")) > 0
+
+    os.mkdir(s3_cache_dir)
+    train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
+    batches_to_fetch = 16
+    batch_to_resume_from = None
+    for i, batch in enumerate(train_dataloader):
+        if i == batches_to_fetch:
+            dataloader_state = train_dataloader.state_dict()
+        if i == batches_to_fetch + 1:
+            batch_to_resume_from = batch
+            break
+
+    shutil.rmtree(s3_cache_dir)
+    os.mkdir(s3_cache_dir)
+    train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
+    train_dataloader.load_state_dict(dataloader_state)
+    # The next batch after resuming must match what we should have gotten next in the initial loop
+    assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
 def test_dataset_valid_state(tmpdir, monkeypatch):
     seed_everything(42)
@@ -915,19 +967,27 @@ def test_replay_sampling():
 
 def test_replay_chunks_sampling():
     chunks_replica = range(10)
-    intervals_replica = [(i, i + 5) for i in range(0, 50, 5)]
+    intervals_replica = [(i, i, i + 5, i + 5) for i in range(0, 50, 5)]
     workers_chunks, workers_intervals = _associate_chunks_to_workers(
         _WorkerEnv(2, 0), chunks_replica, intervals_replica
     )
     assert workers_chunks == {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]}
     assert workers_intervals == {
-        0: [(0, 5), (10, 15), (20, 25), (30, 35), (40, 45)],
-        1: [(5, 10), (15, 20), (25, 30), (35, 40), (45, 50)],
+        0: [(0, 0, 5, 5), (10, 10, 15, 15), (20, 20, 25, 25), (30, 30, 35, 35), (40, 40, 45, 45)],
+        1: [(5, 5, 10, 10), (15, 15, 20, 20), (25, 25, 30, 30), (35, 35, 40, 40), (45, 45, 50, 50)],
     }
     assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1})
     assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3})
     assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2})
+
+    # Test that replay stops at the right chunk
+    workers_intervals = {0: [(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 21, 21), (21, 21, 30, 30)]}
+    indexes = {0: 15}
+    # Replay should stop at chunk index 1, because 15 - 10 = 5, which fits into chunk idx 1
+    chunk_indexes, indexes = _replay_chunks_sampling(workers_intervals, indexes)
+    assert chunk_indexes == {0: 1}
+    assert indexes == {0: 5}
 
 
 @pytest.mark.parametrize(
     "compression",
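
For reference, here is a minimal standalone sketch of the patched replay logic, assembled from the `_replay_chunks_sampling` hunk above. The `chunks_index` initialization is assumed, since it falls outside the hunk. It shows why the chunk size must be computed from `interval[2] - interval[1]` once intervals are 4-tuples, and why the early `break` matters when chunks have different sizes:

```python
from typing import Dict, List, Tuple

# 4-tuple interval; positions 1 and 2 bound the samples a chunk actually
# contributes, so its size is interval[2] - interval[1].
Interval = Tuple[int, int, int, int]


def _replay_chunks_sampling(
    workers_intervals: Dict[int, List[Interval]], indexes: Dict[int, int]
) -> Tuple[Dict[int, int], Dict[int, int]]:
    # Assumed initialization (outside the hunk): start every worker at chunk 0.
    chunks_index = {worker_idx: 0 for worker_idx in workers_intervals}
    for worker_idx, intervals in workers_intervals.items():
        for interval in intervals:
            size = interval[2] - interval[1]
            if indexes[worker_idx] >= size:
                # The whole chunk was consumed before the checkpoint: skip past it.
                indexes[worker_idx] -= size
                chunks_index[worker_idx] += 1
            else:
                # We've reached the chunk where resuming takes place for this worker.
                # Without this break, a later chunk smaller than the remaining index
                # (e.g. the 1-sample chunk below) would wrongly be skipped as well.
                break
    return chunks_index, indexes


# Mirrors the regression case added to test_replay_chunks_sampling: chunk sizes
# are 10, 10, 1, 9 and 15 samples were consumed, so replay must stop at chunk
# index 1 with 5 samples left over, not advance past the 1-sample chunk.
assert _replay_chunks_sampling(
    {0: [(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 21, 21), (21, 21, 30, 30)]},
    {0: 15},
) == ({0: 1}, {0: 5})
```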