diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 736e04b9..05286ff1 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -690,7 +690,7 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None: self._num_samples_yielded_combined = obj["num_samples_yielded"] # Used to restart on the next DataLoader worker from the previous run. - self._latest_worker_idx = (obj["latest_worker_idx"] + 1) % (self.num_workers if self.num_workers > 0 else 1) + self._latest_worker_idx = obj["latest_worker_idx"] + 1 self._worker_idx_iter = iter(self._worker_idx) for _ in range(self._latest_worker_idx): next(self._worker_idx_iter) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 3083b5ee..699ddc29 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -370,8 +370,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if state_dict: # the state is restored within the workers self._state_dict = state_dict - self.subsampled_files = state_dict["subsampled_files"] - self.region_of_interest = state_dict["region_of_interest"] def _validate_state_dict(self) -> None: assert self._state_dict