diff --git a/.github/build.yml b/.github/build.yml
deleted file mode 100644
index 31241667..00000000
--- a/.github/build.yml
+++ /dev/null
@@ -1,39 +0,0 @@
-# This workflow will install Python dependencies, run tests and lint with a single version of Python
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
-
-name: Python application
-
-on:
-  push:
-    branches: [ "master" ]
-  pull_request:
-    branches: [ "master" ]
-
-permissions:
-  contents: read
-
-jobs:
-  build:
-
-    runs-on: self-hosted
-
-    steps:
-    - uses: actions/checkout@v4
-    - name: Set up Python 3.12
-      uses: actions/setup-python@v4
-      with:
-        python-version: '3.12'
-    - run: pip install pre-commit
-    - run: pre-commit --version
-    - run: pre-commit install
-    - run: pre-commit run --all-files
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install -e .
-    - name: Test with pytest (very fast)
-      run: pytest -v --shorter-than=1.0
-    - name: Test with pytest (fast)
-      run: pytest -v
-    - name: Test with pytest (with slow tests)
-      run: pytest -v -m slow --slow
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 00000000..9718018c
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,110 @@
+# This workflow will install Python dependencies, run tests and lint with a single version of Python
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: Python application
+
+on:
+  pull_request:
+
+permissions:
+  contents: read
+
+# https://stackoverflow.com/a/72408109/6388696
+# https://docs.github.com/en/actions/using-jobs/using-concurrency#example-using-concurrency-to-cancel-any-in-progress-job-or-run
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  linting:
+    name: Run linting/pre-commit checks
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v4
+        with:
+          python-version: '3.12'
+      - run: pip install pre-commit
+      - run: pre-commit --version
+      - run: pre-commit install
+      - run: pre-commit run --all-files
+
+  unit_tests:
+    needs: [linting]
+    runs-on: ${{ matrix.platform }}
+    strategy:
+      max-parallel: 4
+      matrix:
+        platform: [ubuntu-latest]
+        python-version: ['3.12']
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - run: pip install pdm
+      - name: Install dependencies
+        run: pdm install
+      - name: Test with pytest (very fast)
+        run: pdm run pytest -v --shorter-than=1.0 --cov=project --cov-report=xml --cov-append
+      - name: Test with pytest (fast)
+        run: pdm run pytest -v --cov=project --cov-report=xml --cov-append
+
+      - name: Store coverage report as an artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: coverage-reports-unit-tests-${{ matrix.platform }}-${{ matrix.python-version }}
+          path: ./coverage.xml
+
+  integration_tests:
+    needs: [unit_tests]
+    runs-on: self-hosted
+    strategy:
+      max-parallel: 1
+      matrix:
+        python-version: ['3.12']
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - run: pip install pdm
+      - name: Install dependencies
+        run: pdm install
+
+      - name: Test with pytest
+        run: pdm run pytest -v --cov=project --cov-report=xml --cov-append
+
+      - name: Test with pytest (only slow tests)
+        run: pdm run pytest -v -m slow --slow --cov=project --cov-report=xml --cov-append
+
+      - name: Store coverage report as an artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: coverage-reports-integration-tests-${{ matrix.python-version }}
+          path: ./coverage.xml
+
+  # https://about.codecov.io/blog/uploading-code-coverage-in-a-separate-job-on-github-actions/
+  upload-coverage-codecov:
+    needs: [integration_tests]
+    runs-on: ubuntu-latest
+    name: Upload coverage reports to Codecov
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Download artifacts
+        uses: actions/download-artifact@v4
+        with:
+          pattern: coverage-reports-*
+          merge-multiple: false
+          # Download all the artifacts into this directory (each coverage.xml will be in a subdirectory).
+          # Next step if this doesn't work would be to give the coverage files a unique name and use merge-multiple: true
+          path: coverage_reports
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v4
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          directory: coverage_reports
+          fail_ci_if_error: true
diff --git a/README.md b/README.md
index cc996363..a71fa43b 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,4 @@
 # research_template
+
+![Build](https://github.com/mila-iqia/ResearchTemplate/actions/workflows/build.yml/badge.svg)
+[![codecov](https://codecov.io/gh/mila-iqia/ResearchTemplate/graph/badge.svg?token=I2DYLK8NTD)](https://codecov.io/gh/mila-iqia/ResearchTemplate)
diff --git a/pdm.lock b/pdm.lock
index 81b60c1c..1a03ae58 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "dev"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.4.1"
-content_hash = "sha256:23ae16466ebc9fa2ce51513a1bc26cc634d47afc591ace665dbccd82d34117c7"
+content_hash = "sha256:1e5611b1f430e5820256e84761ec57ca9de8a29a13612e23f80653c080095c5b"

 [[package]]
 name = "absl-py"
@@ -388,6 +388,52 @@ files = [
     {file = "contourpy-1.2.0.tar.gz", hash = "sha256:171f311cb758de7da13fc53af221ae47a5877be5a0843a9fe150818c51ed276a"},
 ]

+[[package]]
+name = "coverage"
+version = "7.5.3"
+requires_python = ">=3.8"
+summary = "Code coverage measurement for Python"
+groups = ["dev"]
+files = [
+    {file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"},
+    {file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"},
+    {file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"},
+    {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"},
+    {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"},
+    {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"},
+    {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"},
+    {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"},
+    {file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"},
+    {file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"},
+    {file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"},
+    {file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"},
+]
+
+[[package]]
+name = "coverage"
+version = "7.5.3"
+extras = ["toml"]
+requires_python = ">=3.8"
+summary = "Code coverage measurement for Python"
+groups = ["dev"]
+dependencies = [
+    "coverage==7.5.3",
+]
+files = [
+    {file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"},
+    {file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"},
+    {file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"},
+    {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"},
+    {file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"},
+    {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"},
+    {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"},
+    {file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"},
+    {file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"},
+    {file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"},
+    {file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"},
+    {file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"},
+]
+
 [[package]]
 name = "croniter"
 version = "1.3.15"
@@ -1138,7 +1184,7 @@ files = [

 [[package]]
 name = "jaxlib"
-version = "0.4.28+cuda12.cudnn89"
+version = "0.4.28"
 requires_python = ">=3.9"
 summary = "XLA library for JAX"
 groups = ["default"]
@@ -2120,6 +2166,21 @@ files = [
     {file = "pytest_benchmark-4.0.0-py3-none-any.whl", hash = "sha256:fdb7db64e31c8b277dff9850d2a2556d8b60bcb0ea6524e36e28ffd7c87f71d6"},
 ]

+[[package]]
+name = "pytest-cov"
+version = "5.0.0"
+requires_python = ">=3.8"
+summary = "Pytest plugin for measuring coverage."
+groups = ["dev"]
+dependencies = [
+    "coverage[toml]>=5.2.1",
+    "pytest>=4.6",
+]
+files = [
+    {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
+    {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
+]
+
 [[package]]
 name = "pytest-datadir"
 version = "1.5.0"
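Note on the coverage setup in the workflow above: every pytest step passes --cov=project --cov-report=xml --cov-append, so the test steps within a job accumulate into a single .coverage data file and regenerate coverage.xml after each step. Each job then uploads its own coverage.xml as an artifact, and the final job hands all of them to Codecov, which merges the reports server-side. The sketch below approximates what --cov-append amounts to, written directly against the coverage.py API; pytest-cov drives this machinery for you, so this is an illustration of the behavior, not pytest-cov's actual implementation:

    # Sketch: approximate effect of two pytest runs that both pass
    # --cov=project --cov-report=xml --cov-append.
    import coverage

    cov = coverage.Coverage(data_file=".coverage", source=["project"])
    cov.load()    # --cov-append: start from data recorded by earlier runs
    cov.start()
    # ... run one batch of tests here ...
    cov.stop()
    cov.save()    # the data file now covers all runs so far
    cov.xml_report(outfile="coverage.xml")  # --cov-report=xml

Without --cov-append, each subsequent pytest invocation would overwrite the data file, and the uploaded coverage.xml would only reflect the last batch of tests.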
diff --git a/project/conftest.py b/project/conftest.py
index 0cd0f5dd..e55d490b 100644
--- a/project/conftest.py
+++ b/project/conftest.py
@@ -5,7 +5,7 @@
 import sys
 import typing
 import warnings
-from collections.abc import Callable, Generator
+from collections.abc import Generator
 from contextlib import contextmanager
 from logging import getLogger as get_logger
 from pathlib import Path
@@ -27,7 +27,7 @@
 from project.datamodules.image_classification import (
     ImageClassificationDataModule,
 )
-from project.datamodules.vision.base import VisionDataModule, num_cpus_on_node
+from project.datamodules.vision.base import VisionDataModule
 from project.experiment import (
     instantiate_algorithm,
     instantiate_datamodule,
@@ -183,9 +183,6 @@ def accelerator(request: pytest.FixtureRequest):
     return accelerator


-_cuda_available = torch.cuda.is_available()
-
-
 @pytest.fixture(
     scope="session",
     params=None,
@@ -198,24 +195,7 @@ def num_devices_to_use(accelerator: str, request: pytest.FixtureRequest) -> int
         return num_gpus  # Use only one GPU by default.
     else:
         assert accelerator == "cpu"
-        return request.param
-
-
-def run_with_multiple_devices(test_fn: Callable) -> pytest.MarkDecorator:
-    if torch.cuda.is_available():
-        gpus = torch.cuda.device_count()
-        return pytest.mark.parametrize(
-            num_devices_to_use.__name__,
-            list(range(1, gpus + 1)),
-            indirect=True,
-            ids=[f"gpus={i}" for i in range(1, gpus + 1)],
-        )
-    return pytest.mark.parametrize(
-        num_devices_to_use.__name__,
-        [num_cpus_on_node()],
-        indirect=True,
-        ids=[""],
-    )(test_fn)
+        return getattr(request, "param", 1)


 @pytest.fixture(scope="session")
diff --git a/project/datamodules/datamodules_test.py b/project/datamodules/datamodules_test.py
index 28a073cf..f647fa8d 100644
--- a/project/datamodules/datamodules_test.py
+++ b/project/datamodules/datamodules_test.py
@@ -70,7 +70,7 @@ def test_first_batch(
     fig.suptitle(f"First batch of datamodule {type(datamodule).__name__}")

     figure_path, _ = get_test_source_and_temp_file_paths(
-        ".png", request=request, original_datadir=original_datadir, datadir=datadir
+        extension=".png", request=request, original_datadir=original_datadir, datadir=datadir
     )
     figure_path.parent.mkdir(exist_ok=True, parents=True)
     fig.savefig(figure_path)
diff --git a/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml b/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml
index 1f7a549c..e027265f 100644
--- a/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml
+++ b/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml
@@ -1,20 +1,20 @@
 '0':
   device: cpu
   hash: 1082905456378942323
-  max: 2.125603675842285
-  mean: -0.007423439994454384
-  min: -1.9888888597488403
+  max: 2.126
+  mean: -0.007
+  min: -1.989
   shape:
   - 128
   - 3
   - 32
   - 32
-  sum: -2919.015380859375
+  sum: -2919.015
 '1':
   device: cpu
   hash: 3692171093056153318
   max: 9
-  mean: 4.5546875
+  mean: 4.555
   min: 0
   shape:
   - 128
diff --git a/project/datamodules/datamodules_test/test_first_batch/fashion_mnist.yaml b/project/datamodules/datamodules_test/test_first_batch/fashion_mnist.yaml
index df3fc261..0c4df8a7 100644
--- a/project/datamodules/datamodules_test/test_first_batch/fashion_mnist.yaml
+++ b/project/datamodules/datamodules_test/test_first_batch/fashion_mnist.yaml
@@ -1,20 +1,20 @@
 '0':
   device: cpu
   hash: -3706536913713083016
-  max: 2.821486711502075
-  mean: 0.47488248348236084
-  min: -0.4242129623889923
+  max: 2.821
+  mean: 0.475
+  min: -0.424
   shape:
   - 128
   - 1
   - 28
   - 28
-  sum: 47655.40625
+  sum: 47655.406
 '1':
   device: cpu
   hash: -4023601292826392021
   max: 9
-  mean: 4.5546875
+  mean: 4.555
   min: 0
   shape:
   - 128
diff --git a/project/datamodules/datamodules_test/test_first_batch/mnist.yaml b/project/datamodules/datamodules_test/test_first_batch/mnist.yaml
index 08ec89f3..6b988420 100644
--- a/project/datamodules/datamodules_test/test_first_batch/mnist.yaml
+++ b/project/datamodules/datamodules_test/test_first_batch/mnist.yaml
@@ -1,20 +1,20 @@
 '0':
   device: cpu
   hash: 4338584025941619046
-  max: 2.821486711502075
-  mean: 0.014241953380405903
-  min: -0.4242129623889923
+  max: 2.821
+  mean: 0.014
+  min: -0.424
   shape:
   - 128
   - 1
   - 28
   - 28
-  sum: 1429.20849609375
+  sum: 1429.208
 '1':
   device: cpu
   hash: 1596942422053415325
   max: 9
-  mean: 4.2421875
+  mean: 4.242
   min: 0
   shape:
   - 128
diff --git a/project/main_test.py b/project/main_test.py
index 39cdbfac..12b415d0 100644
--- a/project/main_test.py
+++ b/project/main_test.py
@@ -11,7 +11,7 @@
 from project.configs.datamodule import CIFAR10DataModuleConfig
 from project.conftest import setup_hydra_for_tests_and_compose, use_overrides
 from project.datamodules.image_classification.cifar10 import CIFAR10DataModule
-from project.networks.fcnet import FcNet
+from project.networks import FcNetConfig
 from project.utils.hydra_utils import resolve_dictconfig

 if typing.TYPE_CHECKING:
@@ -77,7 +77,7 @@ def test_setting_algorithm(
 @pytest.mark.parametrize(
     ("overrides", "expected_type"),
     [
-        (["algorithm=example_algo", "network=fcnet"], FcNet.HParams),
+        (["algorithm=example_algo", "network=fcnet"], FcNetConfig),
     ],
     ids=_ids,
 )
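The regression-file updates above (stats rounded to three decimals) and the tensor_regression.py changes below are two halves of the same fix: the min/max/sum/mean of a batch can differ in the low decimal places between the GitHub CI runners and a local machine, so the human-readable stats stored in the checked-in YAML files are now rounded before comparison, while the tensor values themselves are still checked at full precision by the ndarrays regression fixture. A condensed sketch of the rounding, mirroring tensor_simple_attributes in the diff below (tensor_stats is a hypothetical helper name used only for illustration):

    import torch

    PRECISION = 3  # decimals kept for the human-readable summary stats

    def tensor_stats(tensor: torch.Tensor, precision: int = PRECISION) -> dict:
        # Rounded stats are stable across machines; the exact values are
        # still compared elsewhere at full precision.
        return {
            "shape": tuple(tensor.shape),
            "min": round(tensor.min().item(), precision),
            "max": round(tensor.max().item(), precision),
            "sum": round(tensor.sum().item(), precision),
            "mean": round(tensor.float().mean().item(), precision),
        }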
diff --git a/project/utils/tensor_regression.py b/project/utils/tensor_regression.py
index 601abb59..677c2929 100644
--- a/project/utils/tensor_regression.py
+++ b/project/utils/tensor_regression.py
@@ -20,6 +20,14 @@
 logger = get_logger(__name__)


+PRECISION = 3
+"""Number of decimals used when rounding the simple stats of Tensor / ndarray in the pre-check.
+
+Full precision is used in the actual regression check, but this is just for the simple attributes
+(min, max, mean, etc.) which seem to be slightly different on the GitHub CI than on a local
+machine.
+"""
+
 @functools.singledispatch
 def to_ndarray(v: Any) -> np.ndarray | None:
@@ -85,6 +93,7 @@
         ndarrays_regression: NDArraysRegressionFixture,
         data_regression: DataRegressionFixture,
         monkeypatch: pytest.MonkeyPatch,
+        simple_attributes_precision: int = PRECISION,
     ) -> None:
         self.request = request
         self.datadir = datadir
@@ -93,6 +102,7 @@
         self.ndarrays_regression = ndarrays_regression
         self.data_regression = data_regression
         self.monkeypatch = monkeypatch
+        self.simple_attributes_precision = simple_attributes_precision
         self.generate_missing_files: bool | None = self.request.config.getoption(
             "--gen-missing",
             default=None,  # type: ignore
         )
@@ -238,7 +248,9 @@ def check(
         )

     def pre_check(self, data_dict: dict[str, Any], simple_attributes_source_file: Path) -> None:
-        version_controlled_simple_attributes = get_version_controlled_attributes(data_dict)
+        version_controlled_simple_attributes = get_version_controlled_attributes(
+            data_dict, precision=self.simple_attributes_precision
+        )
         # Run the regression check with the hashes (and don't fail if they don't exist)
         __tracebackhide__ = True
         # TODO: Figure out how to include/use the names of the GPUs:
@@ -325,25 +337,25 @@ def get_test_source_and_temp_file_paths(


 @functools.singledispatch
-def get_simple_attributes(value: Any) -> Any:
+def get_simple_attributes(value: Any, precision: int) -> Any:
     raise NotImplementedError(
         f"get_simple_attributes doesn't have a registered handler for values of type {type(value)}"
     )


 @get_simple_attributes.register(type(None))
-def _get_none_attributes(value: None):
+def _get_none_attributes(value: None, precision: int):
     return {"type": "None"}


 @get_simple_attributes.register(bool)
 @get_simple_attributes.register(int | float | str)
-def _get_bool_attributes(value: Any):
+def _get_bool_attributes(value: Any, precision: int):
     return {"value": value, "type": type(value).__name__}


 @get_simple_attributes.register(list)
-def list_simple_attributes(some_list: list[Any]):
+def list_simple_attributes(some_list: list[Any], precision: int):
     return {
         "length": len(some_list),
         "item_types": sorted(set(type(item).__name__ for item in some_list)),
@@ -351,24 +363,24 @@


 @get_simple_attributes.register(dict)
-def dict_simple_attributes(some_dict: dict[str, Any]):
-    return {k: get_simple_attributes(v) for k, v in some_dict.items()}
+def dict_simple_attributes(some_dict: dict[str, Any], precision: int):
+    return {k: get_simple_attributes(v, precision=precision) for k, v in some_dict.items()}


 @get_simple_attributes.register(np.ndarray)
-def ndarray_simple_attributes(array: np.ndarray) -> dict:
+def ndarray_simple_attributes(array: np.ndarray, precision: int) -> dict:
     return {
         "shape": tuple(array.shape),
         "hash": _hash(array),
-        "min": array.min().item(),
-        "max": array.max().item(),
-        "sum": array.sum().item(),
-        "mean": array.mean(),
+        "min": round(array.min().item(), precision),
+        "max": round(array.max().item(), precision),
+        "sum": round(array.sum().item(), precision),
+        "mean": round(array.mean(), precision),
     }


 @get_simple_attributes.register(Tensor)
-def tensor_simple_attributes(tensor: Tensor) -> dict:
+def tensor_simple_attributes(tensor: Tensor, precision: int) -> dict:
     if tensor.is_nested:
         # assert not [tensor_i.any() for tensor_i in tensor.unbind()], tensor
         # TODO: It might be a good idea to make a distinction here between '0' as the default, and
@@ -378,10 +390,10 @@
     return {
         "shape": tuple(tensor.shape) if not tensor.is_nested else get_shape_ish(tensor),
         "hash": _hash(tensor),
-        "min": tensor.min().item(),
-        "max": tensor.max().item(),
-        "sum": tensor.sum().item(),
-        "mean": tensor.float().mean().item(),
+        "min": round(tensor.min().item(), precision),
+        "max": round(tensor.max().item(), precision),
+        "sum": round(tensor.sum().item(), precision),
+        "mean": round(tensor.float().mean().item(), precision),
         "device": (
             "cpu" if tensor.device.type == "cpu" else f"{tensor.device.type}:{tensor.device.index}"
         ),
@@ -399,8 +411,10 @@ def get_gpu_names(data_dict: dict[str, Any]) -> list[str]:
     )


-def get_version_controlled_attributes(data_dict: dict[str, Any]) -> dict[str, Any]:
-    return {key: get_simple_attributes(value) for key, value in data_dict.items()}
+def get_version_controlled_attributes(data_dict: dict[str, Any], precision: int) -> dict[str, Any]:
+    return {
+        key: get_simple_attributes(value, precision=precision) for key, value in data_dict.items()
+    }


 class FilesDidntExist(Failed):
diff --git a/project/utils/utils.py b/project/utils/utils.py
index c8914661..8f02c6b8 100644
--- a/project/utils/utils.py
+++ b/project/utils/utils.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import warnings
-from collections.abc import Iterable, Sequence
+from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import field
 from logging import getLogger as get_logger
 from pathlib import Path
@@ -22,7 +22,7 @@
 )
 from project.utils.types.protocols import DataModule, Module

-from .types import NestedDict
+from .types import NestedDict, NestedMapping

 logger = get_logger(__name__)
@@ -221,7 +221,7 @@ def print_config(
     # rich.print(tree, file=file)


-def flatten[K, V](nested: NestedDict[K, V]) -> dict[tuple[K, ...], V]:
+def flatten[K, V](nested: NestedMapping[K, V]) -> dict[tuple[K, ...], V]:
     """Flatten a dictionary of dictionaries.

     The returned dictionary's keys are tuples, one entry per layer.
@@ -230,7 +230,7 @@ def flatten[K, V](nested: NestedMapping[K, V]) -> dict[tuple[K, ...], V]:
     """
     flattened: dict[tuple[K, ...], V] = {}
     for k, v in nested.items():
-        if isinstance(v, dict):
+        if isinstance(v, Mapping):
             for subkeys, subv in flatten(v).items():
                 collision_key = (k, *subkeys)
                 assert collision_key not in flattened
@@ -257,7 +257,7 @@ def unflatten[K, V](flattened: dict[tuple[K, ...], V]) -> NestedDict[K, V]:
     return nested


-def flatten_dict[V](nested: NestedDict[str, V], sep: str = ".") -> dict[str, V]:
+def flatten_dict[V](nested: NestedMapping[str, V], sep: str = ".") -> dict[str, V]:
     """Flatten a dictionary of dictionaries.

     Joins different nesting levels with `sep` as separator.
diff --git a/pyproject.toml b/pyproject.toml
index 21d81221..e07e050a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ dev = [
     "pytest-xdist>=3.5.0",
     "ruff>=0.3.3",
     "pytest-benchmark>=4.0.0",
+    "pytest-cov>=5.0.0",
 ]

 [[tool.pdm.source]]
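The utils.py change above widens flatten and flatten_dict from NestedDict to NestedMapping and swaps the recursion test from isinstance(v, dict) to isinstance(v, Mapping), so any Mapping implementation, not just a plain dict, is flattened recursively instead of being kept as a single leaf value. A self-contained sketch of the widened behavior, with an untyped Mapping signature standing in for the project's NestedMapping alias:

    from collections.abc import Mapping

    def flatten(nested: Mapping) -> dict[tuple, object]:
        # Recurse into any Mapping, not just dict, building tuple keys
        # with one entry per nesting level.
        flattened: dict[tuple, object] = {}
        for k, v in nested.items():
            if isinstance(v, Mapping):
                for subkeys, subv in flatten(v).items():
                    flattened[(k, *subkeys)] = subv
            else:
                flattened[(k,)] = v
        return flattened

    assert flatten({"a": {"b": 1}, "c": 2}) == {("a", "b"): 1, ("c",): 2}

A nested mapping that is not a dict subclass (for example, a types.MappingProxyType) previously stayed a single leaf value; with the Mapping check it now takes the recursive branch.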