diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 92e65511..e0c82087 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -525,7 +525,7 @@ def _replay_chunks_sampling(
             indexes[worker_idx] -= size
             chunks_index[worker_idx] += 1
         else:
-            # We've reached the chunk where resuming needs to take place
+            # We've reached the chunk where resuming needs to take place (for this worker)
             break
 
     return chunks_index, indexes
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index 29b5a058..910fbdbb 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -807,7 +807,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir):
     return StreamingDataLoader(dataset, batch_size=2, num_workers=1)
 
 
-@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="Not tested on windows and MacOs")
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on Windows")
 @mock.patch.dict(os.environ, {}, clear=True)
 def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
     """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have
@@ -824,7 +824,7 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
         chunk_size=190,
         num_workers=4,
     )
-    assert len(os.listdir(tmpdir / "optimized")) > 1
+    assert len(os.listdir(tmpdir / "optimized")) == 9  # 8 chunks + 1 index file
 
     os.mkdir(s3_cache_dir)
     train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
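For context on the first hunk: the touched comment sits inside the per-worker replay loop of `_replay_chunks_sampling`, which discounts fully consumed chunks until it finds, for each worker, the chunk where resumption should start. Below is a minimal, self-contained sketch of that logic, reconstructed only from the lines visible in the hunk; the `workers_chunk_sizes` parameter name and the overall signature are hypothetical, not litdata's actual API.

```python
from typing import Dict, List, Tuple


def replay_chunks_sampling(
    workers_chunk_sizes: Dict[int, List[int]],  # hypothetical: per-worker chunk sizes
    indexes: Dict[int, int],  # per-worker count of samples consumed before the checkpoint
) -> Tuple[Dict[int, int], Dict[int, int]]:
    # Each worker starts from its first assigned chunk.
    chunks_index = {worker_idx: 0 for worker_idx in workers_chunk_sizes}

    for worker_idx, sizes in workers_chunk_sizes.items():
        for size in sizes:
            if indexes[worker_idx] >= size:
                # This chunk was fully consumed before the checkpoint:
                # discount its samples and advance to the next chunk.
                indexes[worker_idx] -= size
                chunks_index[worker_idx] += 1
            else:
                # We've reached the chunk where resuming needs to take place (for this worker)
                break

    return chunks_index, indexes


# Example: worker 0 consumed 5 samples across chunks of sizes [3, 3, 3],
# so it should resume inside chunk 1 at local sample index 2.
print(replay_chunks_sampling({0: [3, 3, 3]}, {0: 5}))  # ({0: 1}, {0: 2})
```

The `break` the comment refers to only exits the inner loop, which is why the "(for this worker)" qualifier is the accurate reading: the outer loop continues replaying the remaining workers independently.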