fix: updates the training sampling strategy to complete the last batch
Fixes #438
Showing 2 changed files with 80 additions and 14 deletions.
@@ -0,0 +1,59 @@
from typing import Iterator, Sized

import torch
from torch.utils.data.sampler import Sampler, SequentialSampler


class BatchOversampler(Sampler[list[int]]):
    r"""Samples elements sequentially, always in the same order. Completes the
    last incomplete batch with random samples drawn from the full batches.

    Args:
        data_source (Sized): dataset to sample from (anything with ``__len__``)
        batch_size (int): number of items in a batch
    """

    def __init__(self, data_source: Sized, batch_size: int) -> None:
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
            )
        self.batch_size = batch_size
        self.n = len(data_source)
        self.n_full_batches = self.n // self.batch_size
        self.remaining_samples = self.n % self.batch_size
        self.sampler = SequentialSampler(data_source)

    def __iter__(self) -> Iterator[list[int]]:
        batch = [0] * self.batch_size
        idx_in_batch = 0
        # Yield full batches in sequential order.
        for idx in self.sampler:
            batch[idx_in_batch] = idx
            idx_in_batch += 1
            if idx_in_batch == self.batch_size:
                yield batch
                idx_in_batch = 0
                batch = [0] * self.batch_size
        if idx_in_batch > 0:
            # Pad the trailing incomplete batch with indices drawn without
            # replacement from the full batches, using a fresh random seed
            # on every iteration (i.e. every epoch).
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
            oversampler = map(
                int,
                torch.randperm(
                    self.n_full_batches * self.batch_size, generator=generator
                )[: self.batch_size - self.remaining_samples].numpy(),
            )
            # Note: if the dataset is smaller than batch_size there are no full
            # batches to draw from, and the unfilled slots keep their initial
            # index 0.
            for idx in oversampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
            yield batch

    def __len__(self) -> int:
        # One extra batch is emitted when the dataset size is not a multiple
        # of batch_size.
        if self.remaining_samples:
            return self.n_full_batches + 1
        else:
            return self.n_full_batches
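For context, a batch sampler like this is meant to be passed to a DataLoader via its batch_sampler argument. Below is a minimal usage sketch; the import path is an assumption (the diff does not show the new file's name), and toy_ds is a made-up toy dataset.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical import path; the diff above does not show the module name.
# from mypackage.samplers import BatchOversampler

# 10 samples with batch_size=4 -> two full batches plus a remainder of 2.
toy_ds = TensorDataset(torch.arange(10).float().unsqueeze(1))

loader = DataLoader(toy_ds, batch_sampler=BatchOversampler(toy_ds, batch_size=4))
for (features,) in loader:
    print(features.squeeze(1).tolist())
# [0.0, 1.0, 2.0, 3.0]
# [4.0, 5.0, 6.0, 7.0]
# [8.0, 9.0, x, y]  where x and y are random samples from indices 0..7

Note that when batch_sampler is given, DataLoader must not also receive batch_size, shuffle, sampler, or drop_last; the sampler fully controls batching.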
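The new __len__ can also be checked against the iterator directly. Under the same toy setup, 10 samples at batch_size=4 should yield three batches, each of full length:

sampler = BatchOversampler(toy_ds, batch_size=4)
batches = list(sampler)
assert len(batches) == len(sampler) == 3
assert all(len(b) == 4 for b in batches)  # the last batch is completed, not dropped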