Skip to content

Commit

Permalink
[brief] Folds the adjusting of the indices for the samplers into the …
Browse files Browse the repository at this point in the history
…base class.

[detailed]
- It makes more sense than having the same code duplicated everywhere.
  • Loading branch information
marovira committed Mar 28, 2024
1 parent 794c49c commit 6a8472a
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/pyro/data/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@ class ResumableSampler(tud.Sampler):
always the same and that they will continue from where the previous run left off.
"""

def __init__(self) -> None:
def __init__(self, batch_size: int) -> None:
"""Create the sampler."""
super().__init__()

self._start_iter: int = 0
self._epoch: int = 0
self._batch_size = batch_size

def _adjust_to_start_iter(self, indices: list[int]) -> list[int]:
assert self._batch_size > 0

start_index = self._start_iter * self._batch_size
return indices[start_index:]

@property
def start_iter(self) -> int:
Expand Down Expand Up @@ -53,11 +60,10 @@ def __init__(
self, data_source: typing.Sized, seed: int = 0, batch_size: int = 1
) -> None:
"""Create the sampler."""
super().__init__()
super().__init__(batch_size)

self._data_source = data_source
self._seed = seed
self._batch_size = batch_size
self._num_samples = len(self._data_source)

def __len__(self) -> int:
Expand All @@ -72,10 +78,7 @@ def __iter__(self) -> typing.Iterator[int]:
indices = shuffling.tolist()

assert len(indices) == self._num_samples
assert self._batch_size > 0

start_index = self._start_iter * self._batch_size
indices = indices[start_index:]
indices = self._adjust_to_start_iter(indices)
return iter(indices)


Expand All @@ -94,9 +97,8 @@ class ResumableSequentialSampler(ResumableSampler):

def __init__(self, data_source: typing.Sized, batch_size: int = 1):
"""Create the sampler."""
super().__init__()
super().__init__(batch_size)

self._batch_size = batch_size
self._data_source = data_source
self._num_samples = len(data_source)
self._indices = list(range(self._num_samples))
Expand All @@ -109,8 +111,7 @@ def __iter__(self) -> typing.Iterator[int]:
"""Retrieve the index of the next sample."""
assert self._batch_size > 0
indices = self._indices
start_index = self._start_iter * self._batch_size
indices = indices[start_index:]
indices = self._adjust_to_start_iter(indices)
return iter(indices)


Expand Down

0 comments on commit 6a8472a

Please sign in to comment.