Move the start-iteration offset logic shared by the resumable samplers into
ResumableSampler._adjust_to_start_iter; batch_size is now stored by the base
class (defaulting to 1 so existing ``super().__init__()`` callers keep working).

NOTE(review): this patch was recovered from a copy whose line breaks were
lost. Blank context lines were placed to match the hunk line counts - confirm
them against src/pyro/data/samplers.py before applying. The post-image hash on
the index line predates the review edits.

diff --git a/src/pyro/data/samplers.py b/src/pyro/data/samplers.py
index 6dff0cc..eb74998 100644
--- a/src/pyro/data/samplers.py
+++ b/src/pyro/data/samplers.py
@@ -14,12 +14,23 @@ class ResumableSampler(tud.Sampler):
     always the same and that they will continue from where the previous run left off.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, batch_size: int = 1) -> None:
         """Create the sampler."""
         super().__init__()
 
         self._start_iter: int = 0
         self._epoch: int = 0
+        self._batch_size: int = batch_size
+
+    def _adjust_to_start_iter(self, indices: list[int]) -> list[int]:
+        """Drop the indices consumed before ``start_iter``.
+
+        Returns ``indices`` without its first ``start_iter * batch_size`` entries.
+        """
+        assert self._batch_size > 0
+
+        start_index = self._start_iter * self._batch_size
+        return indices[start_index:]
 
     @property
     def start_iter(self) -> int:
@@ -53,11 +64,10 @@ def __init__(
         self, data_source: typing.Sized, seed: int = 0, batch_size: int = 1
     ) -> None:
         """Create the sampler."""
-        super().__init__()
+        super().__init__(batch_size)
 
         self._data_source = data_source
         self._seed = seed
-        self._batch_size = batch_size
         self._num_samples = len(self._data_source)
 
     def __len__(self) -> int:
@@ -72,10 +82,7 @@ def __iter__(self) -> typing.Iterator[int]:
         indices = shuffling.tolist()
 
         assert len(indices) == self._num_samples
-        assert self._batch_size > 0
-
-        start_index = self._start_iter * self._batch_size
-        indices = indices[start_index:]
+        indices = self._adjust_to_start_iter(indices)
 
         return iter(indices)
 
@@ -94,9 +101,8 @@ class ResumableSequentialSampler(ResumableSampler):
 
     def __init__(self, data_source: typing.Sized, batch_size: int = 1):
         """Create the sampler."""
-        super().__init__()
+        super().__init__(batch_size)
 
-        self._batch_size = batch_size
         self._data_source = data_source
         self._num_samples = len(data_source)
         self._indices = list(range(self._num_samples))
@@ -109,8 +115,6 @@ def __iter__(self) -> typing.Iterator[int]:
         """Retrieve the index of the next sample."""
-        assert self._batch_size > 0
         indices = self._indices
-        start_index = self._start_iter * self._batch_size
-        indices = indices[start_index:]
+        indices = self._adjust_to_start_iter(indices)
 
         return iter(indices)
 