diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 8f789949..43970b95 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -134,6 +134,11 @@ def set_drop_last(self, drop_last: bool) -> None:
         for dataset in self._datasets:
             dataset.set_drop_last(drop_last)
 
+    def reset_state_dict(self) -> None:
+        """Reset the state of each wrapped dataset."""
+        for dataset in self._datasets:
+            dataset.reset_state_dict()
+
     def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
         if any(not isinstance(d, StreamingDataset) for d in datasets):
             raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 4ad656db..17924d69 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -615,6 +615,7 @@ def __iter__(self) -> Any:
             self.current_epoch += 1
             self._num_samples_yielded_combined = {}
             self._num_samples_yielded_streaming = 0
+            self.dataset.reset_state_dict()
 
         self.dataset.set_epoch(self.current_epoch)
 
@@ -700,13 +701,23 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
 
         # Inform we are resuming and disable resetting the StreamingDataLoader state.
         # This is toggled back to False when the `__iter__` method of the StreamingDataLoader completes.
-        self.restore = True
+        # self.restore = True
 
         if isinstance(self.dataset, CombinedStreamingDataset):
             self.dataset._set_use_streaming_dataloader(True)
             self.dataset.load_state_dict(obj)
+
+            # Inform that the dataloader is resuming.
+            # TODO: Check if the number of samples yielded is less than the length of the dataset.
+            # Also, `len` is not available for a `CombinedStreamingDataset` when weights are provided.
+            self.restore = True
+
         elif isinstance(self.dataset, StreamingDataset):
             self.dataset.load_state_dict(obj["dataset"])
+
+            # Inform that the dataloader is resuming.
+            if self._num_samples_yielded_streaming < len(self.dataset):
+                self.restore = True
         else:
             raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")
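With this change, a resumed `StreamingDataLoader` only sets `restore` for a plain `StreamingDataset` when the saved state still has samples left to yield; for a `CombinedStreamingDataset` it is always set (the TODO above tracks tightening that). A minimal sketch of the intended round trip, assuming `litdata`'s top-level exports and an illustrative dataset directory `data/optimized`:

```python
from litdata import StreamingDataLoader, StreamingDataset

# "data/optimized" is an illustrative path to a previously optimized dataset.
dataset = StreamingDataset("data/optimized", shuffle=True)
loader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)

# Stop partway through the first epoch and capture the loader state.
for batch_idx, _batch in enumerate(loader):
    if batch_idx == 10:
        break
state = loader.state_dict()

# Loading a mid-epoch state marks the loader as resuming ...
loader.load_state_dict(state)
assert loader.restore

# ... and the flag flips back to False once the resumed epoch completes.
for _batch in loader:
    pass
assert not loader.restore
```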
diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index a2e056fd..5c57cd69 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -342,12 +342,14 @@ def __next__(self) -> Any:
         # Prevent creating more batches on a given process
         if self.global_index >= self.stop_length:
             self.current_epoch += 1
+            self.reset_state_dict()
             raise StopIteration
 
         # Lazily re-populate the interval to reduce memory usage.
         if len(self.current_indexes) == 0:
             if self.chunk_index == self.num_chunks:
                 self.current_epoch += 1
+                self.reset_state_dict()
                 raise StopIteration
 
             # reset index
@@ -392,7 +394,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int
         return {
             "num_samples_yielded": num_samples_yielded,
-            "num_workers": num_workers,
+            "num_workers": num_workers or 1,
             "batch_size": batch_size,
             "current_epoch": self.current_epoch,
             "input_dir_path": self.input_dir.path,
@@ -411,13 +413,15 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             # the state is restored within the workers
             self._state_dict = state_dict
 
+    def reset_state_dict(self) -> None:
+        self._state_dict = None
+
     def _validate_state_dict(self) -> None:
         assert self._state_dict
         assert self.worker_env
         assert self.cache
 
         state: Dict[str, Any] = self._state_dict
-
         if state["shuffle"] != self.shuffle:
             raise ValueError(
                 "The provided `shuffle` state doesn't match the current one. "
@@ -471,6 +475,12 @@ def _validate_state_dict(self) -> None:
                 f"Found `{self.drop_last}` instead of `{state['drop_last']}`."
             )
 
+        if state["num_samples_yielded"] > len(self):
+            raise ValueError(
+                "The provided `num_samples_yielded` state is greater than the dataset length. "
+                f"Found `{state['num_samples_yielded']}` but the dataset length is `{len(self)}`."
+            )
+
     def reset(self) -> None:
         # undo all the properties associated with original dataset
         default_properties: Dict[str, Any] = {
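The new `num_samples_yielded` guard means a stale or mismatched state can no longer silently over-report progress. Note that `load_state_dict` only stores the state ("the state is restored within the workers"), so the error surfaces when iteration resumes rather than at load time. A hedged sketch of triggering the check by tampering with a saved state (path illustrative):

```python
from litdata import StreamingDataLoader, StreamingDataset

dataset = StreamingDataset("data/optimized", shuffle=True)  # illustrative path
loader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)

state = loader.state_dict()
# Over-report progress: claim more samples than the dataset holds.
state["dataset"]["num_samples_yielded"] = len(dataset) + 1
loader.load_state_dict(state)  # accepted here: the state is only stored

# The guard fires once the workers validate the restored state, raising
# ValueError: The provided `num_samples_yielded` state is greater than the dataset length. ...
next(iter(loader))
```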
diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py
index d6a2b4b8..fb3aa5e4 100644
--- a/tests/streaming/test_combined.py
+++ b/tests/streaming/test_combined.py
@@ -17,6 +17,9 @@ class TestCombinedStreamingDataset(CombinedStreamingDataset):
     def _check_datasets(self, datasets) -> None:
         pass
 
+    def reset_state_dict(self):
+        pass
+
 
 def test_combined_dataset_num_samples_yield():
     dataset = TestCombinedStreamingDataset(
diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py
index 768cb4b5..9ffed4d6 100644
--- a/tests/streaming/test_dataloader.py
+++ b/tests/streaming/test_dataloader.py
@@ -56,6 +56,9 @@ class TestCombinedStreamingDataset(CombinedStreamingDataset):
     def _check_datasets(self, datasets) -> None:
         pass
 
+    def reset_state_dict(self):
+        pass
+
 
 def test_streaming_dataloader():
     dataset = TestCombinedStreamingDataset(
@@ -202,3 +205,128 @@ def test_dataloader_no_workers(tmpdir):
     assert len(dataset) == 1000
     assert len(dataloader) == 1000
     assert len(dataset) == 1000
+
+
+@pytest.mark.timeout(120)
+def test_dataloader_with_loading_states(tmpdir):
+    cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
+    for i in range(100):
+        cache[i] = i
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(str(tmpdir), shuffle=True)
+
+    # Test dataloader without explicit num workers
+    dataloader = StreamingDataLoader(dataset, batch_size=4)
+    dataloader.load_state_dict(dataloader.state_dict())
+    batch = next(iter(dataloader))
+    assert len(batch) == 4, "Batch size should be 4"
+    assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"
+
+    # Test dataloader with num workers
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
+    assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"
+
+    # Verify dataloader state after partial iteration
+    for batch_idx, batch in enumerate(dataloader):
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+        if batch_idx == 10:
+            break
+    dataloader.load_state_dict(dataloader.state_dict())
+    assert dataloader.restore
+
+    # Verify remaining batches in the first epoch
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+        count += 1
+    assert count == 15, "There should be 15 batches remaining in the first epoch"
+    assert not dataloader.restore
+
+    # Verify batches in the second epoch
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 2, "Current epoch should be 2"
+        count += 1
+    assert count >= 25, "There should be at least 25 batches in the second epoch"
+
+    # Verify that the dataloader can resume after completing the last epoch
+    dataloader.load_state_dict(dataloader.state_dict())
+    assert not dataloader.restore
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 3, "Current epoch should be 3"
+        count += 1
+    assert count >= 25, "There should be at least 25 batches in the third epoch"
+
+
+@pytest.mark.timeout(120)
+def test_dataloader_states_with_persistent_workers(tmpdir):
+    cache = Cache(input_dir=str(tmpdir), chunk_bytes="64MB")
+    for i in range(100):
+        cache[i] = i
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(str(tmpdir), shuffle=True)
+
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
+    assert len(dataloader) == 25, "Dataloader length should be 25 (100 items / batch size 4)"
+
+    # Verify dataloader state after partial iteration
+    for batch_idx, batch in enumerate(dataloader):
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+        if batch_idx == 10:
+            break
+
+    prev_dataloader_state = dataloader.state_dict()
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2, persistent_workers=True)
+    dataloader.load_state_dict(prev_dataloader_state)
+    assert dataloader.restore
+
+    # Verify remaining batches in the first epoch
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+        count += 1
+    assert count == 15, "There should be 15 batches remaining in the first epoch"
+    assert not dataloader.restore
+
+    # Verify batches in the second epoch
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 2, "Current epoch should be 2"
+        count += 1
+    assert count >= 25, "There should be at least 25 batches in the second epoch"
+
+    # Verify that the dataloader can resume after completing the last epoch
+    dataloader.load_state_dict(dataloader.state_dict())
+    assert not dataloader.restore
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 3, "Current epoch should be 3"
+        count += 1
+    assert count >= 25, "There should be at least 25 batches in the third epoch"
+
+
+@pytest.mark.timeout(60)
+def test_resume_dataloader_with_new_dataset(tmpdir):
+    dataset_1_path = tmpdir.join("dataset_1")
+    dataset_2_path = tmpdir.join("dataset_2")
+    for dataset in [dataset_1_path, dataset_2_path]:
+        cache = Cache(input_dir=str(dataset), chunk_bytes="64MB")
+        for i in range(50):
+            cache[i] = i
+        cache.done()
+        cache.merge()
+    dataset = StreamingDataset(str(dataset_1_path), shuffle=True)
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
+    for _ in dataloader:
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+
+    dataloader_state = dataloader.state_dict()
+    dataset = StreamingDataset(str(dataset_2_path), shuffle=True)
+    dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)
+    dataloader.load_state_dict(dataloader_state)
+    for _ in dataloader:
+        assert dataloader.current_epoch == 2, "Current epoch should be 2"
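These tests round-trip the loader state in memory; in a training loop, the same dict is typically persisted next to the model checkpoint and fed back through `load_state_dict` on restart. A minimal sketch, assuming an illustrative `checkpoint.pt` path (the `state_dict`/`load_state_dict` calls are the ones exercised above):

```python
import os

import torch
from litdata import StreamingDataLoader, StreamingDataset

dataset = StreamingDataset("data/optimized", shuffle=True)  # illustrative path
loader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)

# Resume mid-epoch (or into the next epoch) if a previous run left a checkpoint.
if os.path.exists("checkpoint.pt"):
    loader.load_state_dict(torch.load("checkpoint.pt"))

for step, batch in enumerate(loader):
    ...  # training step
    if step % 100 == 0:
        torch.save(loader.state_dict(), "checkpoint.pt")
```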