
Bug in building DictModule with old version of PyTorch. #91

Closed
jintuzhang opened this issue Oct 16, 2023 · 0 comments · Fixed by #100
Labels
bug Something isn't working

Comments

@jintuzhang

When using older versions of PyTorch (e.g., 1.10), building an mlcolvars.data.DictModule may raise the following error:

raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

This is caused by a change in the torch.utils.data.random_split method:

random_split in PyTorch 2.1:

def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """

random_split in PyTorch 1.10:

def random_split(dataset: Dataset[T], lengths: Sequence[int],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.
    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """

This method is invoked by

def _split(self, dataset):

The _split method passes dataset-length fractions to random_split, but the old random_split accepts only explicit integer lengths. It would therefore be reasonable to convert the fractions to actual lengths before the call.
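Concretely, a backward-compatible _split could normalize the fractions into explicit integer lengths before handing them to random_split. The sketch below is an illustration, not the actual mlcolvars patch; the helper name to_lengths and the attribute self.lengths are assumptions:

```python
import math

def to_lengths(dataset_len, lengths):
    """Pass explicit integer lengths through unchanged; convert
    fractions to integer counts so the result is also accepted by
    the integer-only random_split of PyTorch 1.10."""
    if all(isinstance(l, int) for l in lengths):
        return list(lengths)
    counts = [math.floor(frac * dataset_len) for frac in lengths]
    for i in range(dataset_len - sum(counts)):
        counts[i % len(counts)] += 1  # round-robin remainder
    return counts

# Inside _split, the call would then become something like:
#   torch.utils.data.random_split(dataset, to_lengths(len(dataset), self.lengths))
```

Because the conversion mirrors the floor-plus-round-robin rule that PyTorch 2.1 applies internally, the resulting splits have the same sizes on both old and new versions.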

@EnricoTrizio EnricoTrizio mentioned this issue Oct 24, 2023
@EnricoTrizio EnricoTrizio reopened this Oct 25, 2023
@EnricoTrizio EnricoTrizio added the bug Something isn't working label Nov 14, 2023