diff --git a/everyvoice/dataloader/__init__.py b/everyvoice/dataloader/__init__.py
index 606a446e..5148e391 100644
--- a/everyvoice/dataloader/__init__.py
+++ b/everyvoice/dataloader/__init__.py
@@ -7,6 +7,7 @@
 from torch.utils.data import DataLoader
 
 from everyvoice.dataloader.imbalanced_sampler import ImbalancedDatasetSampler
+from everyvoice.dataloader.oversampler import BatchOversampler
 from everyvoice.model.aligner.config import AlignerConfig
 from everyvoice.model.e2e.config import EveryVoiceConfig
 from everyvoice.model.feature_prediction.config import FeaturePredictionConfig
@@ -49,20 +50,26 @@ def setup(self, stage: Optional[str] = None):
         self.predict_dataset = torch.load(self.predict_path)
 
     def train_dataloader(self):
-        sampler = (
-            ImbalancedDatasetSampler(self.train_dataset)
-            if self.use_weighted_sampler
-            else None
-        )
-        return DataLoader(
-            self.train_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.config.training.train_data_workers,
-            pin_memory=False,
-            drop_last=True,
-            collate_fn=self.collate_fn,
-            sampler=sampler,
-        )
+        if self.use_weighted_sampler:
+            sampler = ImbalancedDatasetSampler(self.train_dataset)
+            return DataLoader(
+                self.train_dataset,
+                batch_size=self.batch_size,
+                num_workers=self.config.training.train_data_workers,
+                pin_memory=False,
+                drop_last=True,
+                collate_fn=self.collate_fn,
+                sampler=sampler,
+            )
+        else:
+            batch_sampler = BatchOversampler(self.train_dataset, self.batch_size)
+            return DataLoader(
+                self.train_dataset,
+                batch_sampler=batch_sampler,
+                num_workers=self.config.training.train_data_workers,
+                pin_memory=False,
+                collate_fn=self.collate_fn,
+            )
 
     def predict_dataloader(self):
         return DataLoader(
diff --git a/everyvoice/dataloader/oversampler.py b/everyvoice/dataloader/oversampler.py
new file mode 100644
index 00000000..2b4afb75
--- /dev/null
+++ b/everyvoice/dataloader/oversampler.py
@@ -0,0 +1,59 @@
+from typing import Iterator, Sized
+
+import torch
+from torch.utils.data.sampler import Sampler, SequentialSampler
+
+
+class BatchOversampler(Sampler[list[int]]):
+    r"""Samples elements sequentially, always in the same order. Completes the last incomplete batch with random samples from other batches.
+
+    Args:
+        data_source (Dataset): dataset to sample from
+        batch_size (int): number of items in a batch
+    """
+
+    def __init__(self, data_source: Sized, batch_size: int) -> None:
+        if (
+            not isinstance(batch_size, int)
+            or isinstance(batch_size, bool)
+            or batch_size <= 0
+        ):
+            raise ValueError(
+                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
+            )
+        self.batch_size = batch_size
+        self.n = len(data_source)
+        self.n_full_batches = self.n // self.batch_size
+        self.remaining_samples = self.n % self.batch_size
+        self.sampler = SequentialSampler(data_source)
+
+    def __iter__(self) -> Iterator[list[int]]:
+        batch = [0] * self.batch_size
+        idx_in_batch = 0
+        for idx in self.sampler:
+            batch[idx_in_batch] = idx
+            idx_in_batch += 1
+            if idx_in_batch == self.batch_size:
+                yield batch
+                idx_in_batch = 0
+                batch = [0] * self.batch_size
+        if idx_in_batch > 0:
+            seed = int(torch.empty((), dtype=torch.int64).random_().item())
+            generator = torch.Generator()
+            generator.manual_seed(seed)
+            oversampler = map(
+                int,
+                torch.randperm(
+                    self.n_full_batches * self.batch_size, generator=generator
+                )[: self.batch_size - self.remaining_samples].numpy(),
+            )
+            for idx in oversampler:
+                batch[idx_in_batch] = idx
+                idx_in_batch += 1
+            yield batch
+
+    def __len__(self) -> int:
+        if self.remaining_samples:
+            return self.n_full_batches + 1
+        else:
+            return self.n_full_batches
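
As a quick behavioural check (not part of the diff), the sampler can be exercised on its own. This is a minimal sketch assuming the module path introduced above: with 10 items and a batch size of 4, `BatchOversampler` should yield two sequential full batches plus a final batch whose two leftover samples are topped up with indices drawn at random from the full batches.

```python
from everyvoice.dataloader.oversampler import BatchOversampler

# Any Sized object works here, since the sampler only yields indices.
dataset = list(range(10))
sampler = BatchOversampler(dataset, batch_size=4)

assert len(sampler) == 3  # 2 full batches + 1 oversampled final batch
batches = list(sampler)
assert batches[0] == [0, 1, 2, 3]
assert batches[1] == [4, 5, 6, 7]
assert batches[2][:2] == [8, 9]                 # the 2 leftover samples
assert all(0 <= i < 8 for i in batches[2][2:])  # padding drawn from indices 0-7
```

Because `__iter__` draws a fresh seed on every call, the padding indices vary from epoch to epoch while the sequential order of the real samples stays fixed.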
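
On the `train_dataloader()` side, the sampler is passed via `batch_sampler=`, which in `torch.utils.data.DataLoader` is mutually exclusive with `batch_size`, `shuffle`, `sampler`, and `drop_last`; `drop_last=True` is no longer needed because every yielded batch is already full. A hypothetical minimal wiring outside the Lightning data module would look like:

```python
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_sampler=BatchOversampler(dataset, batch_size=4))
for batch in loader:
    ...  # every batch holds exactly 4 items; no sample is ever dropped
```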