From 3658d6477bdae9a5b4eaf15845a719ba76466b88 Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Fri, 9 Aug 2024 12:45:47 +0545
Subject: [PATCH 01/22] chore: Add reset_state_dict method to StreamingDataset

---
 src/litdata/streaming/dataset.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 2453ace2..af9e194b 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -407,6 +407,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # the state is restored within the workers
         self._state_dict = state_dict
 
+    def reset_state_dict(self) -> None:
+        self._state_dict = None
+
     def _validate_state_dict(self) -> None:
         assert self._state_dict
         assert self.worker_env

From 8eb1c7f7f555762dc71890663c5b24ea150d7696 Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Fri, 9 Aug 2024 12:46:08 +0545
Subject: [PATCH 02/22] chore: Update num_workers fallback value in
 StreamingDataset

---
 src/litdata/streaming/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index af9e194b..896a6faa 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -388,7 +388,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int
 
         return {
             "num_samples_yielded": num_samples_yielded,
-            "num_workers": num_workers,
+            "num_workers": num_workers if num_workers > 0 else 1,
             "batch_size": batch_size,
             "current_epoch": self.current_epoch,
             "input_dir_path": self.input_dir.path,

From 10c10b3056f80ba5f0128abf67e67d64aa13dcdf Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Fri, 9 Aug 2024 12:46:35 +0545
Subject: [PATCH 03/22] fix: Reset dataset state after each epoch

---
 src/litdata/streaming/dataloader.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 4ad656db..88303174 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -615,6 +615,7 @@ def __iter__(self) -> Any:
             self.current_epoch += 1
             self._num_samples_yielded_combined = {}
             self._num_samples_yielded_streaming = 0
+            self.dataset.reset_state_dict()
 
         self.dataset.set_epoch(self.current_epoch)

From 391c68b23639a63097a0c843752804a845193f36 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Fri, 9 Aug 2024 08:23:48 +0100
Subject: [PATCH 04/22] update

---
 src/litdata/streaming/combined.py  | 5 +++++
 src/litdata/streaming/dataset.py   | 2 ++
 tests/streaming/test_combined.py   | 3 +++
 tests/streaming/test_dataloader.py | 3 +++
 4 files changed, 13 insertions(+)

diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 8f789949..43970b95 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -134,6 +134,11 @@ def set_drop_last(self, drop_last: bool) -> None:
         for dataset in self._datasets:
             dataset.set_drop_last(drop_last)
 
+    def reset_state_dict(self) -> None:
+        """Reset the state of the dataset."""
+        for dataset in self._datasets:
+            dataset.reset_state_dict()
+
     def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
         if any(not isinstance(d, StreamingDataset) for d in datasets):
             raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 896a6faa..96a074ea 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -338,12 +338,14 @@ def __next__(self) -> Any:
         # Prevent creating more batches on a given process
         if self.global_index >= self.stop_length:
             self.current_epoch += 1
+            self.reset_state_dict()
             raise StopIteration
 
         # Lazily re-populate the interval to reduce memory usage.
         if len(self.current_indexes) == 0:
             if self.chunk_index == self.num_chunks:
                 self.current_epoch += 1
+                self.reset_state_dict()
                 raise StopIteration
 
             # reset index
diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py
index 7079f564..48966829 100644
--- a/tests/streaming/test_combined.py
+++ b/tests/streaming/test_combined.py
@@ -17,6 +17,9 @@ class TestCombinedStreamingDataset(CombinedStreamingDataset):
     def _check_datasets(self, datasets) -> None:
         pass
 
+    def reset_state_dict(self):
+        pass
+
 
 def test_combined_dataset_num_samples_yield():
     dataset = TestCombinedStreamingDataset(
diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index 768cb4b5..4366e4fa 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -56,6 +56,9 @@ class TestCombinedStreamingDataset(CombinedStreamingDataset):
     def _check_datasets(self, datasets) -> None:
         pass
 
+    def reset_state_dict(self):
+        pass
+
 
 def test_streaming_dataloader():
     dataset = TestCombinedStreamingDataset(

From 5d74ed84442e942ccf5f733b05f6a9299ca68799 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Fri, 9 Aug 2024 08:24:10 +0100
Subject: [PATCH 05/22] Update src/litdata/streaming/dataset.py

---
 src/litdata/streaming/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 96a074ea..81c2fc60 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -390,7 +390,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int
 
         return {
             "num_samples_yielded": num_samples_yielded,
-            "num_workers": num_workers if num_workers > 0 else 1,
+            "num_workers": num_workers or 1,
             "batch_size": batch_size,
             "current_epoch": self.current_epoch,
             "input_dir_path": self.input_dir.path,

From 7412064243579fb80cc2c43222f4d2a1d9721e2a Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Fri, 9 Aug 2024 13:27:37 +0545
Subject: [PATCH 06/22] feat: Add test for dataloader with loading states

---
 tests/streaming/test_dataloader.py | 43 ++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index 4366e4fa..9dc7b687 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -205,3 +205,46 @@ def test_dataloader_no_workers(tmpdir):
     assert len(dataset) == 1000
     assert len(dataloader) == 1000
     assert len(dataset) == 1000
+
+
+@pytest.mark.timeout(120)
+def test_dataloader_with_loading_states(tmpdir):
+    cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
+    for i in range(100):
+        cache[i] = i
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(str(tmpdir), shuffle=True)
+
+    # Test dataloader without explicit num workers
+    dataloader = StreamingDataLoader(dataset, batch_size=4)
+    dataloader.load_state_dict(dataloader.state_dict())
+    batch = next(iter(dataloader))
+    assert len(batch) == 4, "Batch size should be 4"
+    assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"
+
+    # Test dataloader with num workers
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
+    assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"
+
+    # Verify dataloader state after partial iteration
+    for batch_idx, batch in enumerate(dataloader):
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+        if batch_idx == 10:
+            break
+    dataloader.load_state_dict(dataloader.state_dict())
+
+    # Verify remaining batches in the first epoch
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+        count += 1
+    assert count == 15, "There should be 15 batches remaining in the first epoch"
+
+    # Verify batches in the second epoch
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 2, "Current epoch should be 2"
+        count += 1
+    assert count >= 25, "There should be at least 25 batches in the second epoch"

From 0290a30b57ab886394bf849a4769006081574c1e Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Fri, 9 Aug 2024 14:00:55 +0545
Subject: [PATCH 07/22] chore: Add test for dataloader with loading states
 with persistent workers

---
 tests/streaming/test_dataloader.py | 39 ++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index 9dc7b687..85ab60e6 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -248,3 +248,42 @@ def test_dataloader_with_loading_states(tmpdir):
         assert dataloader.current_epoch == 2, "Current epoch should be 2"
         count += 1
     assert count >= 25, "There should be at least 25 batches in the second epoch"
+
+
+@pytest.mark.timeout(120)
+def test_dataloader_states_with_persistent_workers(tmpdir):
+    cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
+    for i in range(100):
+        cache[i] = i
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(str(tmpdir), shuffle=True)
+
+    # Test dataloader with persistent workers
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
+    assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"
+
+    # Verify dataloader state after partial iteration
+    for batch_idx, batch in enumerate(dataloader):
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+        if batch_idx == 10:
+            break
+
+    prev_dataloader_state = dataloader.state_dict()
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2, persistent_workers=True)
+    dataloader.load_state_dict(prev_dataloader_state)
+
+    # Verify remaining batches in the first epoch
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+        count += 1
+    assert count == 15, "There should be 15 batches remaining in the first epoch"
+
+    # Verify batches in the second epoch
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 2, "Current epoch should be 2"
+        count += 1
+    assert count >= 25, "There should be at least 25 batches in the second epoch"

From 00c2928c5b7bf440d064d58ae395e6538b71e581 Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Fri, 9 Aug 2024 14:02:52 +0545
Subject: [PATCH 08/22] rm comment

---
 tests/streaming/test_dataloader.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index 85ab60e6..ccd80ffa 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -260,7 +260,6 @@ def test_dataloader_states_with_persistent_workers(tmpdir):
 
     dataset = StreamingDataset(str(tmpdir), shuffle=True)
 
-    # Test dataloader with persistent workers
    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
     assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"
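
Note: the two tests above exercise the checkpoint/resume contract end to end. Outside of a test, the workflow is a plain state_dict round-trip. The sketch below is illustrative only and is not part of the series; `data_dir` is a placeholder for a directory already populated with litdata (e.g., via `Cache` or `litdata.optimize`):

    from litdata import StreamingDataset, StreamingDataLoader

    dataset = StreamingDataset(data_dir, shuffle=True)
    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)

    for batch_idx, batch in enumerate(dataloader):
        ...  # training step
        if batch_idx == 10:
            # Capture the loader state mid-epoch, e.g., inside a checkpoint.
            state = dataloader.state_dict()
            break

    # A fresh loader resumes from batch 11 of the interrupted epoch.
    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
    dataloader.load_state_dict(state)
    for batch in dataloader:
        ...
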
From 25a87b7bbfe36a9ba0a8cf8f1a31812a30e4f2de Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Sun, 11 Aug 2024 20:44:57 +0545
Subject: [PATCH 09/22] 🐛 fix: restore only if there are any remaining
 batches/samples to stream from last epoch
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/litdata/streaming/dataloader.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 88303174..290e29cf 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -701,7 +701,8 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
 
         # Inform we are resuming and disable resetting the StreamingDataLoader state.
         # This is toggled back to False when the `__iter__` method of the StreamingDataLoader completes.
-        self.restore = True
+        if obj["num_samples_yielded"] < len(self.dataset):
+            self.restore = True
 
         if isinstance(self.dataset, CombinedStreamingDataset):
             self.dataset._set_use_streaming_dataloader(True)

From 678c3fc2b0d12d91d03e37ae47ab61da4fc605eb Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Sun, 11 Aug 2024 20:51:47 +0545
Subject: [PATCH 10/22] added notes to check out later

---
 src/litdata/streaming/dataset.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 81c2fc60..fe9e77d1 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -543,6 +543,7 @@ def _replay_chunks_sampling(
             if indexes[worker_idx] >= size:
                 indexes[worker_idx] -= size
                 chunks_index[worker_idx] += 1
+                # chunks_index[worker_idx] = (chunks_index[worker_idx] + 1) % len(intervals)
             else:
                 # We've reached the chunk where resuming needs to take place (for this worker)
                 break

From 98669924401b31aeedd37f850d7445f3e3b76f77 Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Sun, 11 Aug 2024 20:55:07 +0545
Subject: [PATCH 11/22] add note

---
 src/litdata/streaming/dataset.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 9d0b0c06..c8ac1ec7 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -547,6 +547,7 @@ def _replay_chunks_sampling(
             if indexes[worker_idx] >= size:
                 indexes[worker_idx] -= size
                 chunks_index[worker_idx] += 1
+                # TODO: find a robust solution, as it only seems to work for 1 worker
                 # chunks_index[worker_idx] = (chunks_index[worker_idx] + 1) % len(intervals)
             else:
                 # We've reached the chunk where resuming needs to take place (for this worker)
                 break
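
Note: the two annotations above target `_replay_chunks_sampling`, which fast-forwards each worker to its pre-checkpoint position. The arithmetic can be read in isolation; the toy function below (simplified names, not litdata's actual signature) shows how the consumed-sample count is walked across chunk sizes until it lands inside a chunk:

    # Given the number of samples a worker already consumed, find the chunk it
    # stopped in and the offset within that chunk.
    def replay_position(samples_consumed: int, chunk_sizes: list) -> tuple:
        chunk_index, index = 0, samples_consumed
        for size in chunk_sizes:
            if index >= size:
                index -= size      # this chunk was fully consumed
                chunk_index += 1   # advance to the next chunk
            else:
                break              # resuming lands inside this chunk
        return chunk_index, index

    assert replay_position(7, [4, 4, 4]) == (1, 3)  # sample 3 of chunk 1
    assert replay_position(8, [4, 4, 4]) == (2, 0)  # exactly at a chunk boundary

The commented-out modulo variant would wrap `chunk_index` past the last chunk when a full epoch was consumed, which is the multi-worker concern flagged in the TODO.
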
From 16bc40fe511ed3f1b8c3fe3998a20b0a8d10053e Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Sun, 11 Aug 2024 21:36:15 +0545
Subject: [PATCH 12/22] chore: Add test for dataloader resuming after
 completing last epoch

---
 tests/streaming/test_dataloader.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index ccd80ffa..b06d5d3c 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -249,6 +249,14 @@ def test_dataloader_with_loading_states(tmpdir):
         count += 1
     assert count >= 25, "There should be at least 25 batches in the second epoch"
 
+    # Verify that the dataloader can resume after completing the last epoch
+    dataloader.load_state_dict(dataloader.state_dict())
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 3, "Current epoch should be 3"
+        count += 1
+    assert count >= 25, "There should be at least 25 batches in the third epoch"
+
 
 @pytest.mark.timeout(120)
 def test_dataloader_states_with_persistent_workers(tmpdir):
@@ -286,3 +294,11 @@ def test_dataloader_states_with_persistent_workers(tmpdir):
         assert dataloader.current_epoch == 2, "Current epoch should be 2"
         count += 1
     assert count >= 25, "There should be at least 25 batches in the second epoch"
+
+    # Verify that the dataloader can resume after completing the last epoch
+    dataloader.load_state_dict(dataloader.state_dict())
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 3, "Current epoch should be 3"
+        count += 1
+    assert count >= 25, "There should be at least 25 batches in the third epoch"

From d3f9498ece05ff3722b51fd00d922d4924ff5d88 Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Sun, 11 Aug 2024 21:37:22 +0545
Subject: [PATCH 13/22] feat: Add test for resuming dataloader with new dataset

---
 tests/streaming/test_dataloader.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index b06d5d3c..5ea0e160 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -302,3 +302,25 @@ def test_dataloader_states_with_persistent_workers(tmpdir):
         assert dataloader.current_epoch == 3, "Current epoch should be 3"
         count += 1
     assert count >= 25, "There should be at least 25 batches in the third epoch"
+
+
+def test_resume_dataloader_with_new_dataset(tmpdir):
+    dataset_1_path = tmpdir.join("dataset_1")
+    dataset_2_path = tmpdir.join("dataset_2")
+    for dataset in [dataset_1_path, dataset_2_path]:
+        cache = Cache(input_dir=str(dataset), chunk_bytes="64MB")
+        for i in range(100):
+            cache[i] = i
+        cache.done()
+        cache.merge()
+    dataset = StreamingDataset(str(dataset_1_path), shuffle=True)
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
+    for _ in dataloader:
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+
+    dataloader_state = dataloader.state_dict()
+    dataset = StreamingDataset(str(dataset_2_path), shuffle=True)
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
+    dataloader.load_state_dict(dataloader_state)
+    for _ in dataloader:
+        assert dataloader.current_epoch == 2, "Current epoch should be 2"

From 6769694593e53b06bd571de442f49f05b1cd6a11 Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Sun, 11 Aug 2024 21:51:14 +0545
Subject: [PATCH 14/22] adds type ignore

---
 src/litdata/streaming/dataloader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 290e29cf..dd3831b6 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -701,7 +701,7 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
 
         # Inform we are resuming and disable resetting the StreamingDataLoader state.
         # This is toggled back to False when the `__iter__` method of the StreamingDataLoader completes.
-        if obj["num_samples_yielded"] < len(self.dataset):
+        if obj["num_samples_yielded"] < len(self.dataset):  # type: ignore
             self.restore = True
 
         if isinstance(self.dataset, CombinedStreamingDataset):

From 81bc5378a339bb587d37d6b3c08bd857efcdff9b Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Sun, 11 Aug 2024 21:54:42 +0545
Subject: [PATCH 15/22] update timeout and num of samples

---
 tests/streaming/test_dataloader.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index 5ea0e160..f39c95db 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -304,12 +304,13 @@ def test_dataloader_states_with_persistent_workers(tmpdir):
     assert count >= 25, "There should be at least 25 batches in the third epoch"
 
 
+@pytest.mark.timeout(60)
 def test_resume_dataloader_with_new_dataset(tmpdir):
     dataset_1_path = tmpdir.join("dataset_1")
     dataset_2_path = tmpdir.join("dataset_2")
     for dataset in [dataset_1_path, dataset_2_path]:
         cache = Cache(input_dir=str(dataset), chunk_bytes="64MB")
-        for i in range(100):
+        for i in range(50):
             cache[i] = i
         cache.done()
         cache.merge()

From 998fe5a0ca36ab897bc8708ab2d098b6d4d2ff9e Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Sun, 11 Aug 2024 23:01:44 +0545
Subject: [PATCH 16/22] Add explicit test for resuming dataloader with new
 dataset

---
 tests/streaming/test_dataloader.py | 31 ++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index f39c95db..00b1036f 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -325,3 +325,34 @@ def test_resume_dataloader_with_new_dataset(tmpdir):
     dataloader.load_state_dict(dataloader_state)
     for _ in dataloader:
         assert dataloader.current_epoch == 2, "Current epoch should be 2"
+
+
+def test_dataloader_resume_after_epoch_completion(tmpdir):
+    cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
+    for i in range(50):
+        cache[i] = i
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(str(tmpdir), shuffle=True)
+    # Test dataloader without explicit num workers
+    dataloader = StreamingDataLoader(dataset, batch_size=4)
+    for _ in dataloader:
+        pass
+    assert dataloader.current_epoch == 1
+    dataloader.load_state_dict(dataloader.state_dict())
+    # force restore
+    dataloader.restore = True
+    batch = next(iter(dataloader))
+    assert len(batch) == 4, "Batch size should be 4"
+
+    # Test dataloader with num workers > 1
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
+    for _ in dataloader:
+        pass
+    assert dataloader.current_epoch == 1
+    dataloader.load_state_dict(dataloader.state_dict())
+    # force restore
+    dataloader.restore = True
+    batch = next(iter(dataloader))
+    assert len(batch) == 4, "Batch size should be 4"

From 61120a4ea76cdc118e5f9321673889f0ae880c10 Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Mon, 12 Aug 2024 00:00:25 +0545
Subject: [PATCH 17/22] chore: add validation for num_samples_yielded

---
 src/litdata/streaming/dataset.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index c8ac1ec7..a2a7d64c 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -422,7 +422,6 @@ def _validate_state_dict(self) -> None:
         assert self.cache
 
         state: Dict[str, Any] = self._state_dict
-
         if state["shuffle"] != self.shuffle:
             raise ValueError(
                 "The provided `shuffle` state doesn't match the current one. "
@@ -476,6 +475,12 @@ def _validate_state_dict(self) -> None:
                 f"Found `{self.drop_last}` instead of `{state['drop_last']}`."
             )
 
+        if state["num_samples_yielded"] > len(self):
+            raise ValueError(
+                "The provided `num_samples_yielded` state is greater than the dataset length. "
+                f"Found `{state['num_samples_yielded']}` instead of `{len(self)}`."
+            )
+
     def reset(self) -> None:
         # undo all the properties associated with original dataset
         default_properties: Dict[str, Any] = {

From d98681c5008c8e9ad5cb7ad6aaf31a044c7fc77d Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Tue, 13 Aug 2024 01:10:33 +0545
Subject: [PATCH 18/22] removed unneeded test, as it was testing the wrong
 thing; a state reset is required after a completed epoch

---
 tests/streaming/test_dataloader.py | 31 ------------------------------
 1 file changed, 31 deletions(-)

diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index 00b1036f..f39c95db 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -325,34 +325,3 @@ def test_resume_dataloader_with_new_dataset(tmpdir):
     dataloader.load_state_dict(dataloader_state)
     for _ in dataloader:
         assert dataloader.current_epoch == 2, "Current epoch should be 2"
-
-
-def test_dataloader_resume_after_epoch_completion(tmpdir):
-    cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
-    for i in range(50):
-        cache[i] = i
-    cache.done()
-    cache.merge()
-
-    dataset = StreamingDataset(str(tmpdir), shuffle=True)
-    # Test dataloader without explicit num workers
-    dataloader = StreamingDataLoader(dataset, batch_size=4)
-    for _ in dataloader:
-        pass
-    assert dataloader.current_epoch == 1
-    dataloader.load_state_dict(dataloader.state_dict())
-    # force restore
-    dataloader.restore = True
-    batch = next(iter(dataloader))
-    assert len(batch) == 4, "Batch size should be 4"
-
-    # Test dataloader with num workers > 1
-    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
-    for _ in dataloader:
-        pass
-    assert dataloader.current_epoch == 1
-    dataloader.load_state_dict(dataloader.state_dict())
-    # force restore
-    dataloader.restore = True
-    batch = next(iter(dataloader))
-    assert len(batch) == 4, "Batch size should be 4"

From 743f0dded3953a2f49539fce4c8f3c452375f909 Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Tue, 13 Aug 2024 01:11:02 +0545
Subject: [PATCH 19/22] removed the unnecessary todo

---
 src/litdata/streaming/dataset.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index a2a7d64c..5c57cd69 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -552,8 +552,6 @@ def _replay_chunks_sampling(
             if indexes[worker_idx] >= size:
                 indexes[worker_idx] -= size
                 chunks_index[worker_idx] += 1
-                # TODO: find a robust solution, as it only seems to work for 1 worker
-                # chunks_index[worker_idx] = (chunks_index[worker_idx] + 1) % len(intervals)
             else:
                 # We've reached the chunk where resuming needs to take place (for this worker)
                 break

From 2db07e02b087b2ee216b5d2a3fcd7f7e7bdc5ca3 Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Tue, 13 Aug 2024 01:27:56 +0545
Subject: [PATCH 20/22] chore: Add restore flag to dataloader tests

---
 tests/streaming/test_dataloader.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index f39c95db..9ffed4d6 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -234,13 +234,14 @@ def test_dataloader_with_loading_states(tmpdir):
         if batch_idx == 10:
             break
     dataloader.load_state_dict(dataloader.state_dict())
-
+    assert dataloader.restore
     # Verify remaining batches in the first epoch
     count = 0
     for _ in dataloader:
         assert dataloader.current_epoch == 1, "Current epoch should be 1"
         count += 1
     assert count == 15, "There should be 15 batches remaining in the first epoch"
+    assert not dataloader.restore
 
     # Verify batches in the second epoch
     count = 0
@@ -251,6 +252,7 @@ def test_dataloader_with_loading_states(tmpdir):
 
     # Verify that the dataloader can resume after completing the last epoch
     dataloader.load_state_dict(dataloader.state_dict())
+    assert not dataloader.restore
     count = 0
     for _ in dataloader:
         assert dataloader.current_epoch == 3, "Current epoch should be 3"
@@ -280,6 +282,7 @@ def test_dataloader_states_with_persistent_workers(tmpdir):
     prev_dataloader_state = dataloader.state_dict()
     dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2, persistent_workers=True)
     dataloader.load_state_dict(prev_dataloader_state)
+    assert dataloader.restore
 
     # Verify remaining batches in the first epoch
     count = 0
@@ -287,6 +290,7 @@ def test_dataloader_states_with_persistent_workers(tmpdir):
         assert dataloader.current_epoch == 1, "Current epoch should be 1"
         count += 1
     assert count == 15, "There should be 15 batches remaining in the first epoch"
+    assert not dataloader.restore
 
     # Verify batches in the second epoch
     count = 0
@@ -297,6 +301,7 @@ def test_dataloader_states_with_persistent_workers(tmpdir):
 
     # Verify that the dataloader can resume after completing the last epoch
     dataloader.load_state_dict(dataloader.state_dict())
+    assert not dataloader.restore
     count = 0
     for _ in dataloader:
         assert dataloader.current_epoch == 3, "Current epoch should be 3"

From fc3a96065fee7fc3a390f452c9b01c11974bdeec Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Tue, 13 Aug 2024 13:10:09 +0545
Subject: [PATCH 21/22] chore: Add restore flag to dataloader for
 StreamingDataset

---
 src/litdata/streaming/dataloader.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index dd3831b6..fa9cf15d 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -701,14 +701,24 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
 
         # Inform we are resuming and disable resetting the StreamingDataLoader state.
         # This is toggled back to False when the `__iter__` method of the StreamingDataLoader completes.
-        if obj["num_samples_yielded"] < len(self.dataset):  # type: ignore
-            self.restore = True
+        # if obj["num_samples_yielded"] < len(self.dataset):  # type: ignore
+        #     self.restore = True
 
         if isinstance(self.dataset, CombinedStreamingDataset):
             self.dataset._set_use_streaming_dataloader(True)
             self.dataset.load_state_dict(obj)
+
+            # Inform that the dataloader is resuming.
+            # TODO: Check if the number of samples yielded is less than the length of the dataset.
+            # Also, len is not available for CombinedStreamingDataset in case of provided weights.
+            self.restore = True
+
         elif isinstance(self.dataset, StreamingDataset):
             self.dataset.load_state_dict(obj["dataset"])
+
+            # Inform that the dataloader is resuming.
+            if self._num_samples_yielded_streaming < len(self.dataset):
+                self.restore = True
         else:
             raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")

From 4a50cacfef607bb981641051ddeb6e5f9e26d0cf Mon Sep 17 00:00:00 2001
From: bhimrazy
Date: Tue, 13 Aug 2024 13:13:31 +0545
Subject: [PATCH 22/22] update

---
 src/litdata/streaming/dataloader.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index fa9cf15d..17924d69 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -701,8 +701,7 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
 
         # Inform we are resuming and disable resetting the StreamingDataLoader state.
         # This is toggled back to False when the `__iter__` method of the StreamingDataLoader completes.
-        # if obj["num_samples_yielded"] < len(self.dataset):  # type: ignore
-        #     self.restore = True
+        # self.restore = True
 
         if isinstance(self.dataset, CombinedStreamingDataset):
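
Note: after this series, `restore` is derived from whether any samples of the interrupted epoch remain to replay, so loading the state of a fully consumed epoch simply begins the next epoch instead of replaying stale state. A typical cross-restart usage looks roughly like the sketch below; the `torch.save`/`torch.load` checkpointing choice and the dataset path are illustrative, not prescribed by the series:

    import torch
    from litdata import StreamingDataset, StreamingDataLoader

    # Run 1: train and checkpoint the loader state alongside the model.
    dataset = StreamingDataset("path/or/uri/to/dataset", shuffle=True)
    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
    for batch_idx, batch in enumerate(dataloader):
        ...  # training step
        if batch_idx == 10:
            torch.save(dataloader.state_dict(), "dataloader.pt")
            break

    # Run 2 (e.g., after a preemption): reload and keep iterating; the loader
    # resumes the interrupted epoch only if batches remain in it.
    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
    dataloader.load_state_dict(torch.load("dataloader.pt"))
    for batch in dataloader:
        ...
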