fix: updates the training sampling strategy to complete the last batch
Fixes #438
wiitt committed Aug 30, 2024
1 parent 8fc4099 commit ad8cd7f
Showing 2 changed files with 80 additions and 14 deletions.
35 changes: 21 additions & 14 deletions everyvoice/dataloader/__init__.py
@@ -7,6 +7,7 @@
 from torch.utils.data import DataLoader
 
 from everyvoice.dataloader.imbalanced_sampler import ImbalancedDatasetSampler
+from everyvoice.dataloader.oversampler import BatchOversampler
 from everyvoice.model.aligner.config import AlignerConfig
 from everyvoice.model.e2e.config import EveryVoiceConfig
 from everyvoice.model.feature_prediction.config import FeaturePredictionConfig
@@ -49,20 +50,26 @@ def setup(self, stage: Optional[str] = None):
             self.predict_dataset = torch.load(self.predict_path)
 
     def train_dataloader(self):
-        sampler = (
-            ImbalancedDatasetSampler(self.train_dataset)
-            if self.use_weighted_sampler
-            else None
-        )
-        return DataLoader(
-            self.train_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.config.training.train_data_workers,
-            pin_memory=False,
-            drop_last=True,
-            collate_fn=self.collate_fn,
-            sampler=sampler,
-        )
+        if self.use_weighted_sampler:
+            sampler = ImbalancedDatasetSampler(self.train_dataset)
+            return DataLoader(
+                self.train_dataset,
+                batch_size=self.batch_size,
+                num_workers=self.config.training.train_data_workers,
+                pin_memory=False,
+                drop_last=True,
+                collate_fn=self.collate_fn,
+                sampler=sampler,
+            )
+        else:
+            batch_sampler = BatchOversampler(self.train_dataset, self.batch_size)
+            return DataLoader(
+                self.train_dataset,
+                batch_sampler=batch_sampler,
+                num_workers=self.config.training.train_data_workers,
+                pin_memory=False,
+                collate_fn=self.collate_fn,
+            )
 
     def predict_dataloader(self):
         return DataLoader(
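
The practical effect of this change: the old path passed drop_last=True, so up to batch_size - 1 trailing samples were silently skipped every epoch, while the new BatchOversampler path keeps them and tops the final batch up with repeats instead. A minimal Python sketch of the batch arithmetic; the 103 samples and batch size 16 are invented numbers for illustration, not values from the repository:

# Sketch only: invented numbers illustrating the effect of the change.
n_samples, batch_size = 103, 16

# Old behaviour: drop_last=True discards the trailing partial batch.
old_batches = n_samples // batch_size               # 6 full batches
dropped = n_samples % batch_size                    # 7 samples never seen this epoch

# New behaviour: the partial batch is completed with random repeats,
# so every sample is seen and one extra batch is yielded.
new_batches = old_batches + (1 if dropped else 0)   # 7 batches
print(old_batches, dropped, new_batches)            # 6 7 7
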
59 changes: 59 additions & 0 deletions everyvoice/dataloader/oversampler.py
@@ -0,0 +1,59 @@
+from typing import Iterator, Sized
+
+import torch
+from torch.utils.data.sampler import Sampler, SequentialSampler
+
+
+class BatchOversampler(Sampler[list[int]]):
+    r"""Samples elements sequentially, always in the same order. Completes the
+    last incomplete batch with random samples from other batches.
+    Args:
+        data_source (Dataset): dataset to sample from
+        batch_size (int): number of items in a batch
+    """
+
+    def __init__(self, data_source: Sized, batch_size: int) -> None:
+        if (
+            not isinstance(batch_size, int)
+            or isinstance(batch_size, bool)
+            or batch_size <= 0
+        ):
+            raise ValueError(
+                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
+            )
+        self.batch_size = batch_size
+        self.n = len(data_source)
+        self.n_full_batches = self.n // self.batch_size
+        self.remaining_samples = self.n % self.batch_size
+        self.sampler = SequentialSampler(data_source)
+
+    def __iter__(self) -> Iterator[list[int]]:
+        batch = [0] * self.batch_size
+        idx_in_batch = 0
+        for idx in self.sampler:
+            batch[idx_in_batch] = idx
+            idx_in_batch += 1
+            if idx_in_batch == self.batch_size:
+                yield batch
+                idx_in_batch = 0
+                batch = [0] * self.batch_size
+
+        if idx_in_batch > 0:
+            seed = int(torch.empty((), dtype=torch.int64).random_().item())
+            generator = torch.Generator()
+            generator.manual_seed(seed)
+            oversampler = map(
+                int,
+                torch.randperm(
+                    self.n_full_batches * self.batch_size, generator=generator
+                )[: self.batch_size - self.remaining_samples].numpy(),
+            )
+            for idx in oversampler:
+                batch[idx_in_batch] = idx
+                idx_in_batch += 1
+            yield batch
+
+    def __len__(self) -> int:
+        if self.remaining_samples:
+            return self.n_full_batches + 1
+        else:
+            return self.n_full_batches
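
To see what the new sampler yields, here is a small usage sketch; the toy dataset and numbers are illustrative assumptions, but BatchOversampler and its import path are the ones added above:

# Usage sketch: a toy 10-sample dataset with batch_size=4 (invented numbers).
import torch
from torch.utils.data import DataLoader, TensorDataset

from everyvoice.dataloader.oversampler import BatchOversampler

dataset = TensorDataset(torch.arange(10))
batch_sampler = BatchOversampler(dataset, batch_size=4)

print(len(batch_sampler))  # 3: two full batches plus the completed last one
for indices in batch_sampler:
    # yields [0, 1, 2, 3], [4, 5, 6, 7], then [8, 9, r1, r2], where r1 and r2
    # are random indices drawn from the full batches (0..7)
    print(indices)

# Plugged into a DataLoader the same way train_dataloader() now does:
loader = DataLoader(dataset, batch_sampler=batch_sampler)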

