diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 12bb2af0..605d1489 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -23,7 +23,8 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [macOS-latest, ubuntu-latest, windows-latest] + # os: [macOS-latest, ubuntu-latest, windows-latest] # TODO use this when macOS-latest becomes stable again + os: [macOS-13, ubuntu-latest, windows-latest] python-version: [3.8, 3.9, "3.10"] steps: diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 1cc60519..831b4dbd 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -1,8 +1,8 @@ name: test channels: - - conda-forge - pytorch + - conda-forge - defaults dependencies: diff --git a/mlcolvar/data/datamodule.py b/mlcolvar/data/datamodule.py index 613d600a..b0446d55 100644 --- a/mlcolvar/data/datamodule.py +++ b/mlcolvar/data/datamodule.py @@ -24,7 +24,6 @@ import lightning from torch.utils.data import random_split, Subset from torch import default_generator, randperm -from torch._utils import _accumulate from mlcolvar.data import DictLoader, DictDataset @@ -324,6 +323,20 @@ def sequential_split(dataset, lengths: Sequence) -> list: return split_dataset(dataset=dataset, lengths=lengths, random_split=False) +# Taken from python 3.5 docs, removed from PyTorch 2.3 onward +def _accumulate(iterable, fn=lambda x, y: x + y): + "Return running totals" + # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15 + # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 + it = iter(iterable) + try: + total = next(it) + except StopIteration: + return + yield total + for element in it: + total = fn(total, element) + yield total if __name__ == "__main__": import doctest