From 12e17702ea207b41b0790fbfd2a5fb0f49b59690 Mon Sep 17 00:00:00 2001
From: EnricoTrizio
Date: Wed, 15 Nov 2023 15:25:15 +0100
Subject: [PATCH 1/2] created split_dataset function for sequential and random
 splitting, deprecated sequential_split function

---
 mlcolvar/data/datamodule.py | 74 +++++++++++++++++++++++--------------
 1 file changed, 47 insertions(+), 27 deletions(-)

diff --git a/mlcolvar/data/datamodule.py b/mlcolvar/data/datamodule.py
index 87856a2b..f9ea716e 100644
--- a/mlcolvar/data/datamodule.py
+++ b/mlcolvar/data/datamodule.py
@@ -23,6 +23,7 @@
 import numpy as np
 import lightning
 from torch.utils.data import random_split, Subset
+from torch import default_generator, randperm
 from torch._utils import _accumulate
 
 from mlcolvar.data import DictLoader, DictDataset
@@ -123,11 +124,14 @@ def __init__(
         super().__init__()
         self.dataset = dataset
         self.lengths = lengths
-        self.generator = generator
-
         # Keeping this private for now. Changing it at runtime would
         # require changing dataset_split and the dataloaders.
         self._random_split = random_split
+
+        # save generator if given, otherwise set it to torch.default_generator
+        self.generator = generator if generator is not None else default_generator
+        if self.generator is not None and not self._random_split:
+            warnings.warn("A torch.generator was provided but it is not used with random_split=False")
 
         # Make sure batch_size and shuffle are lists.
         if isinstance(batch_size, int):
@@ -215,15 +219,11 @@ def __repr__(self) -> str:
 
     def _split(self, dataset):
         """Perform the random or sequential spliting of a single dataset.
-        
+
         Returns a list of Subset[DictDataset] objects.
         """
-        if self._random_split:
-            dataset_split = random_split(
-                dataset, self.lengths, generator=self.generator
-            )
-        else:
-            dataset_split = sequential_split(dataset, self.lengths)
+
+        dataset_split = split_dataset(dataset, self.lengths, self._random_split, self.generator)
         return dataset_split
 
     def _check_setup(self):
@@ -234,12 +234,14 @@ def _check_setup(self):
             "outside a Lightning trainer please call .setup() first."
         )
 
-
-def sequential_split(dataset, lengths: Sequence) -> list:
+def split_dataset(dataset,
+                  lengths: Sequence,
+                  random_split : bool,
+                  generator : Optional[torch.Generator] = default_generator) -> list:
     """
-    Sequentially split a dataset into non-overlapping new datasets of given lengths.
+    Sequentially or randomly split a dataset into non-overlapping new datasets of given lengths.
 
-    The behavior is the same as torch.utils.data.dataset.random_split.
+    If random_split=True the behavior is the same as torch.utils.data.dataset.random_split.
 
     If a list of fractions that sum up to 1 is given,
     the lengths will be computed automatically as
@@ -248,6 +250,8 @@ def sequential_split(dataset, lengths: Sequence) -> list:
     After computing the lengths, if there are any remainders, 1 count will be
     distributed in round-robin fashion to the lengths
     until there are no remainders left.
+
+    Optionally fix the generator for reproducible results.
     """
 
     if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
@@ -267,24 +271,40 @@ def sequential_split(dataset, lengths: Sequence) -> list:
         lengths = subset_lengths
         for i, length in enumerate(lengths):
             if length == 0:
-                warnings.warn(
-                    f"Length of split at index {i} is 0. "
-                    f"This might result in an empty dataset."
-                )
-
-        # Cannot verify that dataset is Sized
-        if sum(lengths) != len(dataset):  # type: ignore[arg-type]
-            raise ValueError(
-                "Sum of input lengths does not equal the length of the input dataset!"
-            )
-
-        # LB change: do sequential rather then random splitting
+                warnings.warn(f"Length of split at index {i} is 0. "
+                              f"This might result in an empty dataset.")
+
+    # Cannot verify that dataset is Sized
+    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
+        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
+    if random_split:
+        indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[call-overload]
+        return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
+    else:
         return [
             Subset(dataset, np.arange(offset - length, offset))
            for offset, length in zip(_accumulate(lengths), lengths)
         ]
-    else:
-        raise NotImplementedError("The lengths must sum to 1.")
+
+
+def sequential_split(dataset, lengths: Sequence) -> list:
+    """
+    Sequentially split a dataset into non-overlapping new datasets of given lengths.
+
+    The behavior is the same as torch.utils.data.dataset.random_split.
+
+    If a list of fractions that sum up to 1 is given,
+    the lengths will be computed automatically as
+    floor(frac * len(dataset)) for each fraction provided.
+
+    After computing the lengths, if there are any remainders, 1 count will be
+    distributed in round-robin fashion to the lengths
+    until there are no remainders left.
+    """
+
+    warnings.warn("The function sequential_split is deprecated, use split_dataset(.., .., random_split=False, ..)", DeprecationWarning)
+
+    return split_dataset(dataset=dataset, lengths=lengths, random_split=False)
 
 
 if __name__ == "__main__":

From c9f6138106b9d6f3b14b6b96a1c9577914dc605c Mon Sep 17 00:00:00 2001
From: EnricoTrizio
Date: Thu, 16 Nov 2023 11:24:26 +0100
Subject: [PATCH 2/2] Changed deprecation warning to future warning in
 sequential_split

---
 mlcolvar/data/datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlcolvar/data/datamodule.py b/mlcolvar/data/datamodule.py
index f9ea716e..a8bea2fa 100644
--- a/mlcolvar/data/datamodule.py
+++ b/mlcolvar/data/datamodule.py
@@ -302,7 +302,7 @@ def sequential_split(dataset, lengths: Sequence) -> list:
     until there are no remainders left.
     """
 
-    warnings.warn("The function sequential_split is deprecated, use split_dataset(.., .., random_split=False, ..)", DeprecationWarning)
+    warnings.warn("The function sequential_split is deprecated, use split_dataset(.., .., random_split=False, ..)", FutureWarning, stacklevel=2)
 
     return split_dataset(dataset=dataset, lengths=lengths, random_split=False)
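For context, a minimal usage sketch of the new split_dataset helper introduced by this series. It is not part of the patch; the import path follows the file touched above, and the toy DictDataset construction (a dictionary of tensors) is an assumption based on the rest of the library.

```python
import torch
from mlcolvar.data import DictDataset
from mlcolvar.data.datamodule import split_dataset  # location assumed from the patch

# toy dataset with 10 one-dimensional samples (assumes DictDataset wraps a dict of tensors)
dataset = DictDataset({"data": torch.arange(10, dtype=torch.float32).unsqueeze(-1)})

# sequential split: first 8 samples go to training, last 2 to validation
train, valid = split_dataset(dataset, lengths=[0.8, 0.2], random_split=False)

# random split, reproducible thanks to the explicit generator
gen = torch.Generator().manual_seed(42)
train, valid = split_dataset(dataset, lengths=[0.8, 0.2], random_split=True, generator=gen)

print(len(train), len(valid))  # 8 2
```

The sequential mode is mainly useful for time-ordered data, where shuffling before the train/validation split would mix temporally correlated samples across the two sets.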