general split_dataset functions #100

Merged: 2 commits, Nov 16, 2023
74 changes: 47 additions & 27 deletions mlcolvar/data/datamodule.py
@@ -23,6 +23,7 @@
import numpy as np
import lightning
from torch.utils.data import random_split, Subset
from torch import default_generator, randperm
from torch._utils import _accumulate

from mlcolvar.data import DictLoader, DictDataset
@@ -123,11 +124,14 @@ def __init__(
        super().__init__()
        self.dataset = dataset
        self.lengths = lengths

        # Keeping this private for now. Changing it at runtime would
        # require changing dataset_split and the dataloaders.
        self._random_split = random_split

        # Save the generator if given, otherwise fall back to torch.default_generator.
        self.generator = generator if generator is not None else default_generator
        if generator is not None and not self._random_split:
            warnings.warn(
                "A torch.Generator was provided, but it is not used when random_split=False."
            )

        # Make sure batch_size and shuffle are lists.
        if isinstance(batch_size, int):
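
For illustration, a minimal sketch of how the new generator argument might be used when constructing the datamodule. The DictModule keyword names are taken from the signature in this diff; the toy DictDataset contents are hypothetical.

import torch
from mlcolvar.data import DictDataset, DictModule

# Hypothetical toy dataset; any DictDataset would do.
dataset = DictDataset({"data": torch.randn(100, 2)})

datamodule = DictModule(
    dataset,
    lengths=[0.8, 0.2],
    random_split=True,                            # shuffle before splitting
    generator=torch.Generator().manual_seed(42),  # fixed seed -> reproducible split
)
datamodule.setup()  # triggers _split, which now delegates to split_dataset
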
@@ -215,15 +219,11 @@ def __repr__(self) -> str:

    def _split(self, dataset):
        """Perform the random or sequential splitting of a single dataset.

        Returns a list of Subset[DictDataset] objects.
        """
        dataset_split = split_dataset(dataset, self.lengths, self._random_split, self.generator)
        return dataset_split

    def _check_setup(self):
@@ -234,12 +234,14 @@ def _check_setup(self):
"outside a Lightning trainer please call .setup() first."
)


def split_dataset(
    dataset,
    lengths: Sequence,
    random_split: bool,
    generator: Optional[torch.Generator] = default_generator,
) -> list:
"""
Sequentially split a dataset into non-overlapping new datasets of given lengths.
Sequentially or randomly split a dataset into non-overlapping new datasets of given lengths.

The behavior is the same as torch.utils.data.dataset.random_split.
If random_split=True the behavior is the same as torch.utils.data.dataset.random_split.

If a list of fractions that sum up to 1 is given,
the lengths will be computed automatically as
Expand All @@ -248,6 +250,8 @@ def sequential_split(dataset, lengths: Sequence) -> list:
After computing the lengths, if there are any remainders, 1 count will be
distributed in round-robin fashion to the lengths
until there are no remainders left.

Optionally fix the generator for reproducible results.
"""

    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
@@ -267,24 +271,40 @@
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(
                    f"Length of split at index {i} is 0. "
                    f"This might result in an empty dataset."
                )

        # Cannot verify that dataset is Sized
        if sum(lengths) != len(dataset):  # type: ignore[arg-type]
            raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

        if random_split:
            indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[call-overload]
            return [
                Subset(dataset, indices[offset - length : offset])
                for offset, length in zip(_accumulate(lengths), lengths)
            ]
        else:
            return [
                Subset(dataset, np.arange(offset - length, offset))
                for offset, length in zip(_accumulate(lengths), lengths)
            ]
    else:
        raise NotImplementedError("The lengths must be fractions that sum to 1.")
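
A minimal usage sketch of split_dataset on a toy dataset (the TensorDataset here is illustrative, not part of this PR):

import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.arange(10).unsqueeze(1))  # 10 samples

# Sequential split: contiguous indices [0..7] and [8, 9].
train, val = split_dataset(dataset, [0.8, 0.2], random_split=False)
assert (len(train), len(val)) == (8, 2)

# Random split: a fixed generator makes the permutation reproducible.
g = torch.Generator().manual_seed(42)
train_r, val_r = split_dataset(dataset, [0.8, 0.2], random_split=True, generator=g)

# Round-robin remainders: floor(10 / 3) = 3 per split leaves 1 left over,
# which goes to the first split -> lengths (4, 3, 3).
a, b, c = split_dataset(dataset, [1 / 3, 1 / 3, 1 / 3], random_split=False)
assert (len(a), len(b), len(c)) == (4, 3, 3)
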


def sequential_split(dataset, lengths: Sequence) -> list:
"""
Sequentially split a dataset into non-overlapping new datasets of given lengths.

The behavior is the same as torch.utils.data.dataset.random_split.

If a list of fractions that sum up to 1 is given,
the lengths will be computed automatically as
floor(frac * len(dataset)) for each fraction provided.

After computing the lengths, if there are any remainders, 1 count will be
distributed in round-robin fashion to the lengths
until there are no remainders left.
"""

warnings.warn("The function sequential_split is deprecated, use split_dataset(.., .., random_split=False, ..)", FutureWarning, stacklevel=2)

return split_dataset(dataset=dataset, lengths=lengths, random_split=False)
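
Callers of the old API keep working but now see a FutureWarning; a quick sketch of the deprecation path (toy dataset assumed):

import warnings
import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.arange(10).unsqueeze(1))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Equivalent to split_dataset(dataset, [0.5, 0.5], random_split=False)
    first, second = sequential_split(dataset, [0.5, 0.5])

assert any(issubclass(w.category, FutureWarning) for w in caught)
assert (len(first), len(second)) == (5, 5)
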


if __name__ == "__main__":