From e0cf2ba3faae440684cb249daa2da2b007f0127a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 9 Jul 2024 16:06:38 +0000 Subject: [PATCH 01/16] fix chunk_indexes --- src/litdata/streaming/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 699ddc29..91be8862 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -331,7 +331,7 @@ def __next__(self) -> Any: index=index, chunk_index=self.worker_chunks[self.chunk_index - 1], # We provide the chunks indexes only one the first - chunk_indexes=None if self.has_triggered_download else self.worker_chunks, + chunk_indexes=None if self.has_triggered_download else self.worker_chunks[self.chunk_index - 1:], is_last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1, ) ) From aaab88ff7cb1bebe9f8bce0c425c56e0935f828d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 14:54:28 +0000 Subject: [PATCH 02/16] fix2 --- src/litdata/streaming/dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 91be8862..b96e8f8f 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -520,9 +520,11 @@ def _replay_chunks_sampling( for worker_idx, intervals in workers_intervals.items(): for interval in intervals: - size = interval[-1] - interval[0] + size = (interval[2] - interval[1]) if indexes[worker_idx] >= size: indexes[worker_idx] -= size chunks_index[worker_idx] += 1 + else: + break return chunks_index, indexes From dcdc7adf7cb3a427f5144ef7dba3f9831e6f95c3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:37:23 +0000 Subject: [PATCH 03/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index b96e8f8f..556bd4d6 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -331,7 +331,7 @@ def __next__(self) -> Any: index=index, chunk_index=self.worker_chunks[self.chunk_index - 1], # We provide the chunks indexes only one the first - chunk_indexes=None if self.has_triggered_download else self.worker_chunks[self.chunk_index - 1:], + chunk_indexes=None if self.has_triggered_download else self.worker_chunks[self.chunk_index - 1 :], is_last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1, ) ) @@ -520,7 +520,7 @@ def _replay_chunks_sampling( for worker_idx, intervals in workers_intervals.items(): for interval in intervals: - size = (interval[2] - interval[1]) + size = interval[2] - interval[1] if indexes[worker_idx] >= size: indexes[worker_idx] -= size chunks_index[worker_idx] += 1 From 877f7411cecaea228f3c548763c4dc48fd0c2df1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 15:59:28 +0000 Subject: [PATCH 04/16] add e2e test --- tests/streaming/test_dataset.py | 48 ++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 91de1a95..3abdcc7c 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -15,13 +15,14 @@ import os import random import sys +import shutil from time import sleep from unittest import mock import numpy as np import pytest import torch -from litdata import train_test_split +from litdata import train_test_split, optimize from litdata.constants import _ZSTD_AVAILABLE from litdata.processing import functions from litdata.streaming import Cache @@ -793,6 +794,51 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir): assert not torch.equal(batch_1, batch_2) +def _simple_preprocess(_): + for _ in range(10): + yield torch.randint(0, 100, size=(10,), dtype=torch.int64) + + +def _get_simulated_s3_dataloader(tmpdir): + dataset = EmulateS3StreamingDataset( + input_dir=Dir(str(tmpdir / "s3cache"), str(tmpdir / "optimized")), + item_loader=TokensLoader(block_size=10), + ) + return StreamingDataLoader(dataset, batch_size=2, num_workers=1) + + +def test_dataset_resume_on_future_chunks(tmpdir): + optimize( + fn=_simple_preprocess, + inputs=list(range(8)), + output_dir=str(tmpdir / "optimized"), + chunk_size=190, + num_workers=4, + ) + assert len(os.listdir(tmpdir / "optimized")) == 9 # 8 chunks + 1 index file + + os.mkdir(tmpdir / "s3cache") + shutil.rmtree("/cache/chunks", ignore_errors=True) # TODO + + train_dataloader = _get_simulated_s3_dataloader(tmpdir) + batches_to_fetch = 16 + batch_to_resume_from = None + for i, batch in enumerate(train_dataloader): + if i == batches_to_fetch: + dataloader_state = train_dataloader.state_dict() + if i == batches_to_fetch + 1: + batch_to_resume_from = batch + break + assert i == batches_to_fetch + 1 + + shutil.rmtree(tmpdir / "s3cache") + os.mkdir(tmpdir / "s3cache") + shutil.rmtree("/cache/chunks", ignore_errors=True) + train_dataloader = _get_simulated_s3_dataloader(tmpdir) + train_dataloader.load_state_dict(dataloader_state) + assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from) + + @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") def test_dataset_valid_state(tmpdir, monkeypatch): seed_everything(42) From 7d368f9f81dfd550b54cc3aec5863a675ae61f9c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:59:45 +0000 Subject: [PATCH 05/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_dataset.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 3abdcc7c..e22a244e 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -14,15 +14,15 @@ import json import os import random -import sys import shutil +import sys from time import sleep from unittest import mock import numpy as np import pytest import torch -from litdata import train_test_split, optimize +from litdata import optimize, train_test_split from litdata.constants import _ZSTD_AVAILABLE from litdata.processing import functions from litdata.streaming import Cache @@ -809,17 +809,17 @@ def _get_simulated_s3_dataloader(tmpdir): def test_dataset_resume_on_future_chunks(tmpdir): optimize( - fn=_simple_preprocess, - inputs=list(range(8)), - output_dir=str(tmpdir / "optimized"), - chunk_size=190, - num_workers=4, + fn=_simple_preprocess, + inputs=list(range(8)), + output_dir=str(tmpdir / "optimized"), + chunk_size=190, + num_workers=4, ) - assert len(os.listdir(tmpdir / "optimized")) == 9 # 8 chunks + 1 index file + assert len(os.listdir(tmpdir / "optimized")) == 9 # 8 chunks + 1 index file os.mkdir(tmpdir / "s3cache") shutil.rmtree("/cache/chunks", ignore_errors=True) # TODO - + train_dataloader = _get_simulated_s3_dataloader(tmpdir) batches_to_fetch = 16 batch_to_resume_from = None From 6273c8d9037bbe69f0b6c4de6952effc82406abd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 16:04:14 +0000 Subject: [PATCH 06/16] update --- src/litdata/streaming/dataset.py | 1 + tests/streaming/test_dataset.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 556bd4d6..92e65511 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -525,6 +525,7 @@ def _replay_chunks_sampling( indexes[worker_idx] -= size chunks_index[worker_idx] += 1 else: + # We've reached the chunk where resuming needs to take place break return chunks_index, indexes diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index e22a244e..c18e1ed0 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -829,7 +829,6 @@ def test_dataset_resume_on_future_chunks(tmpdir): if i == batches_to_fetch + 1: batch_to_resume_from = batch break - assert i == batches_to_fetch + 1 shutil.rmtree(tmpdir / "s3cache") os.mkdir(tmpdir / "s3cache") From ea41a03d16fdbe5bf40c250593ef39a1a75841e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 16:24:37 +0000 Subject: [PATCH 07/16] debug --- tests/streaming/test_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index c18e1ed0..f868a453 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -807,6 +807,7 @@ def _get_simulated_s3_dataloader(tmpdir): return StreamingDataLoader(dataset, batch_size=2, num_workers=1) +@mock.patch.dict(os.environ, {}, clear=True) def test_dataset_resume_on_future_chunks(tmpdir): optimize( fn=_simple_preprocess, From 14db351a045c37811c3b8742180fc23574ba8ca3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 21:33:44 +0000 Subject: [PATCH 08/16] update test --- tests/streaming/test_dataset.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index f868a453..aafb509f 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -799,16 +799,22 @@ def _simple_preprocess(_): yield torch.randint(0, 100, size=(10,), dtype=torch.int64) -def _get_simulated_s3_dataloader(tmpdir): +def _get_simulated_s3_dataloader(cache_dir, data_dir): dataset = EmulateS3StreamingDataset( - input_dir=Dir(str(tmpdir / "s3cache"), str(tmpdir / "optimized")), + input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size=10), ) return StreamingDataLoader(dataset, batch_size=2, num_workers=1) +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) def test_dataset_resume_on_future_chunks(tmpdir): + """This test is constructed to test resuming from a chunk past the first chunk, when + subsequent chunks don't have the same size.""" + s3_cache_dir = str(tmpdir / "s3cache") + data_dir = str(tmpdir / "optimized") + optimize( fn=_simple_preprocess, inputs=list(range(8)), @@ -816,12 +822,12 @@ def test_dataset_resume_on_future_chunks(tmpdir): chunk_size=190, num_workers=4, ) - assert len(os.listdir(tmpdir / "optimized")) == 9 # 8 chunks + 1 index file + assert len(os.listdir(tmpdir / "optimized")) > 1 - os.mkdir(tmpdir / "s3cache") - shutil.rmtree("/cache/chunks", ignore_errors=True) # TODO + os.mkdir(s3_cache_dir) + shutil.rmtree("/cache/chunks", ignore_errors=True) - train_dataloader = _get_simulated_s3_dataloader(tmpdir) + train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) batches_to_fetch = 16 batch_to_resume_from = None for i, batch in enumerate(train_dataloader): @@ -831,11 +837,12 @@ def test_dataset_resume_on_future_chunks(tmpdir): batch_to_resume_from = batch break - shutil.rmtree(tmpdir / "s3cache") - os.mkdir(tmpdir / "s3cache") + shutil.rmtree(s3_cache_dir) + os.mkdir(s3_cache_dir) shutil.rmtree("/cache/chunks", ignore_errors=True) - train_dataloader = _get_simulated_s3_dataloader(tmpdir) + train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) train_dataloader.load_state_dict(dataloader_state) + # The next batch after resuming must match what we should have gotten next in the initial loop assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from) From 7ee3a2de7b8f0a771e40e87cd740a2c953d6bd42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 21:51:42 +0000 Subject: [PATCH 09/16] fix test --- tests/streaming/test_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index aafb509f..7bbbe522 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -968,14 +968,14 @@ def test_replay_sampling(): def test_replay_chunks_sampling(): chunks_replica = range(10) - intervals_replica = [(i, i + 5) for i in range(0, 50, 5)] + intervals_replica = [(i, i, i + 5, i + 5) for i in range(0, 50, 5)] workers_chunks, workers_intervals = _associate_chunks_to_workers( _WorkerEnv(2, 0), chunks_replica, intervals_replica ) assert workers_chunks == {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]} assert workers_intervals == { - 0: [(0, 5), (10, 15), (20, 25), (30, 35), (40, 45)], - 1: [(5, 10), (15, 20), (25, 30), (35, 40), (45, 50)], + 0: [(0, 0, 5, 5), (10, 10, 15, 15), (20, 20, 25, 25), (30, 30, 35, 35), (40, 40, 45, 45)], + 1: [(5, 5, 10, 10), (15, 15, 20, 20), (25, 25, 30, 30), (35, 35, 40, 40), (45, 45, 50, 50)], } assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1}) assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3}) From cc6be066effb62083cc73a5dd35e8d5a98ef7ec9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jul 2024 21:53:05 +0000 Subject: [PATCH 10/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 7bbbe522..7f278819 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -810,8 +810,8 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir): @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) def test_dataset_resume_on_future_chunks(tmpdir): - """This test is constructed to test resuming from a chunk past the first chunk, when - subsequent chunks don't have the same size.""" + """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have + the same size.""" s3_cache_dir = str(tmpdir / "s3cache") data_dir = str(tmpdir / "optimized") From 7733468a2cbe59bf1a62b38bd5ec65b2aa133f04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 22:28:15 +0000 Subject: [PATCH 11/16] extend test --- tests/streaming/test_dataset.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 7f278819..35d509a3 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -981,6 +981,14 @@ def test_replay_chunks_sampling(): assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3}) assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2}) + # Test that replay stops at the right chunk + workers_intervals={0: [(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 21, 21), (21, 21, 30, 30)]} + indexes={0: 15} + # Replay should stop at chunk index 1, because 15 - 10 = 5, which fits into with chunk idx 1 + chunk_indexes, indexes = _replay_chunks_sampling(workers_intervals, indexes) + assert chunk_indexes == {0: 1} + assert indexes == {0: 5} + @pytest.mark.parametrize( "compression", From 22c5dcaadb1fc3134402a330580d73e2e8adc7ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 22:30:56 +0000 Subject: [PATCH 12/16] skip on macos --- tests/streaming/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 35d509a3..c1c14592 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -807,7 +807,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir): return StreamingDataLoader(dataset, batch_size=2, num_workers=1) -@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") +@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) def test_dataset_resume_on_future_chunks(tmpdir): """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have From 7f921f86b1f153e03b32171f6ca267b3a35a95ce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jul 2024 22:31:10 +0000 Subject: [PATCH 13/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index c1c14592..4ec132d4 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -982,8 +982,8 @@ def test_replay_chunks_sampling(): assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2}) # Test that replay stops at the right chunk - workers_intervals={0: [(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 21, 21), (21, 21, 30, 30)]} - indexes={0: 15} + workers_intervals = {0: [(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 21, 21), (21, 21, 30, 30)]} + indexes = {0: 15} # Replay should stop at chunk index 1, because 15 - 10 = 5, which fits into with chunk idx 1 chunk_indexes, indexes = _replay_chunks_sampling(workers_intervals, indexes) assert chunk_indexes == {0: 1} From 499bb9dcf473e01ac6c0748fec5539527e1be127 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 22:57:52 +0000 Subject: [PATCH 14/16] debug --- tests/streaming/test_dataset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 4ec132d4..29b5a058 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -809,11 +809,13 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir): @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) -def test_dataset_resume_on_future_chunks(tmpdir): +def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" s3_cache_dir = str(tmpdir / "s3cache") + optimize_cache_dir = str(tmpdir / "optimize_cache") data_dir = str(tmpdir / "optimized") + monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", optimize_cache_dir) optimize( fn=_simple_preprocess, @@ -825,8 +827,6 @@ def test_dataset_resume_on_future_chunks(tmpdir): assert len(os.listdir(tmpdir / "optimized")) > 1 os.mkdir(s3_cache_dir) - shutil.rmtree("/cache/chunks", ignore_errors=True) - train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) batches_to_fetch = 16 batch_to_resume_from = None @@ -839,7 +839,6 @@ def test_dataset_resume_on_future_chunks(tmpdir): shutil.rmtree(s3_cache_dir) os.mkdir(s3_cache_dir) - shutil.rmtree("/cache/chunks", ignore_errors=True) train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) train_dataloader.load_state_dict(dataloader_state) # The next batch after resuming must match what we should have gotten next in the initial loop From 62775f30fd19a2b7e1748b723a54c34d91de4833 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 23:10:28 +0000 Subject: [PATCH 15/16] update --- src/litdata/streaming/dataset.py | 2 +- tests/streaming/test_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 92e65511..e0c82087 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -525,7 +525,7 @@ def _replay_chunks_sampling( indexes[worker_idx] -= size chunks_index[worker_idx] += 1 else: - # We've reached the chunk where resuming needs to take place + # We've reached the chunk where resuming needs to take place (for this worker) break return chunks_index, indexes diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 29b5a058..910fbdbb 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -807,7 +807,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir): return StreamingDataLoader(dataset, batch_size=2, num_workers=1) -@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="Not tested on windows and MacOs") +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have @@ -824,7 +824,7 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): chunk_size=190, num_workers=4, ) - assert len(os.listdir(tmpdir / "optimized")) > 1 + assert len(os.listdir(tmpdir / "optimized")) == 9 # 8 chunks + 1 index file os.mkdir(s3_cache_dir) train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) From d4594fb030ac7e1af3feedfe2c1df93d9c08ea36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Jul 2024 23:41:26 +0000 Subject: [PATCH 16/16] update --- tests/streaming/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 910fbdbb..88176ec2 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -824,7 +824,7 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): chunk_size=190, num_workers=4, ) - assert len(os.listdir(tmpdir / "optimized")) == 9 # 8 chunks + 1 index file + assert len(os.listdir(tmpdir / "optimized")) > 0 os.mkdir(s3_cache_dir) train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)