
Commit

Merge pull request #100 from luigibonati/fix_split_function
general split_dataset functions
EnricoTrizio authored Nov 16, 2023
2 parents 6e378a6 + c9f6138 commit d14cf54
Showing 1 changed file with 47 additions and 27 deletions.
74 changes: 47 additions & 27 deletions mlcolvar/data/datamodule.py
@@ -23,6 +23,7 @@
import numpy as np
import lightning
from torch.utils.data import random_split, Subset
from torch import default_generator, randperm
from torch._utils import _accumulate

from mlcolvar.data import DictLoader, DictDataset
@@ -123,11 +124,14 @@ def __init__(
super().__init__()
self.dataset = dataset
self.lengths = lengths
self.generator = generator

# Keeping this private for now. Changing it at runtime would
# require changing dataset_split and the dataloaders.
self._random_split = random_split

# save generator if given, otherwise set it to torch.default_generator
self.generator = generator if generator is not None else default_generator
if generator is not None and not self._random_split:
warnings.warn("A torch.Generator was provided, but it is not used when random_split=False")

# Make sure batch_size and shuffle are lists.
if isinstance(batch_size, int):
@@ -215,15 +219,11 @@ def __repr__(self) -> str:

def _split(self, dataset):
"""Perform the random or sequential spliting of a single dataset.
Returns a list of Subset[DictDataset] objects.
"""
if self._random_split:
dataset_split = random_split(
dataset, self.lengths, generator=self.generator
)
else:
dataset_split = sequential_split(dataset, self.lengths)

dataset_split = split_dataset(dataset, self.lengths, self._random_split, self.generator)
return dataset_split

def _check_setup(self):
@@ -234,12 +234,14 @@ def _check_setup(self):
"outside a Lightning trainer please call .setup() first."
)


def sequential_split(dataset, lengths: Sequence) -> list:
def split_dataset(dataset,
lengths: Sequence,
random_split: bool,
generator: Optional[torch.Generator] = default_generator) -> list:
"""
Sequentially split a dataset into non-overlapping new datasets of given lengths.
Sequentially or randomly split a dataset into non-overlapping new datasets of given lengths.
The behavior is the same as torch.utils.data.dataset.random_split.
If random_split=True the behavior is the same as torch.utils.data.dataset.random_split.
If a list of fractions that sum up to 1 is given,
the lengths will be computed automatically as
@@ -248,6 +250,8 @@ def sequential_split(dataset, lengths: Sequence) -> list:
After computing the lengths, if there are any remainders, 1 count will be
distributed in round-robin fashion to the lengths
until there are no remainders left.
Optionally fix the generator for reproducible results.
"""

if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
@@ -267,24 +271,40 @@ def sequential_split(dataset, lengths: Sequence) -> list:
lengths = subset_lengths
for i, length in enumerate(lengths):
if length == 0:
warnings.warn(
f"Length of split at index {i} is 0. "
f"This might result in an empty dataset."
)

# Cannot verify that dataset is Sized
if sum(lengths) != len(dataset): # type: ignore[arg-type]
raise ValueError(
"Sum of input lengths does not equal the length of the input dataset!"
)

# LB change: do sequential rather than random splitting
warnings.warn(f"Length of split at index {i} is 0. "
f"This might result in an empty dataset.")

# Cannot verify that dataset is Sized
if sum(lengths) != len(dataset): # type: ignore[arg-type]
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
if random_split:
indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[call-overload]
return [
Subset(dataset, indices[offset - length : offset])
for offset, length in zip(_accumulate(lengths), lengths)
]
else:
return [
Subset(dataset, np.arange(offset - length, offset))
for offset, length in zip(_accumulate(lengths), lengths)
]
else:
raise NotImplementedError("The lengths must sum to 1.")
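# Editor's sketch (not part of the commit): assuming a DictDataset with 100 entries,
# an 80/20 random split with a fixed seed and the equivalent sequential split would be
#
#   from torch import Generator
#   train, valid = split_dataset(dataset, lengths=[0.8, 0.2], random_split=True,
#                                generator=Generator().manual_seed(42))
#   train, valid = split_dataset(dataset, lengths=[0.8, 0.2], random_split=False)
#
# Both calls return a list of torch.utils.data.Subset objects wrapping the same dataset.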


def sequential_split(dataset, lengths: Sequence) -> list:
"""
Sequentially split a dataset into non-overlapping new datasets of given lengths.
The behavior is the same as torch.utils.data.dataset.random_split.
If a list of fractions that sum up to 1 is given,
the lengths will be computed automatically as
floor(frac * len(dataset)) for each fraction provided.
After computing the lengths, if there are any remainders, 1 count will be
distributed in round-robin fashion to the lengths
until there are no remainders left.
"""

warnings.warn("The function sequential_split is deprecated, use split_dataset(.., .., random_split=False, ..)", FutureWarning, stacklevel=2)

return split_dataset(dataset=dataset, lengths=lengths, random_split=False)
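# Editor's migration note (not part of the commit): an existing call such as
#   train, valid = sequential_split(dataset, lengths=[0.8, 0.2])
# still works but now emits a FutureWarning; the equivalent call is
#   train, valid = split_dataset(dataset, lengths=[0.8, 0.2], random_split=False)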


if __name__ == "__main__":
