From 37fd403d44042cefe9d5e8ac25b4c46dd2ed54e0 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 28 Oct 2024 10:07:06 +0100 Subject: [PATCH] FIX: use `itertools.pairwise()` correctly --- src/compwa_policy/utilities/pyproject/setters.py | 12 ++++++------ tests/utilities/pyproject/test_setters.py | 13 ++++++++++++- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/compwa_policy/utilities/pyproject/setters.py b/src/compwa_policy/utilities/pyproject/setters.py index 361b79ca..2d8d6249 100644 --- a/src/compwa_policy/utilities/pyproject/setters.py +++ b/src/compwa_policy/utilities/pyproject/setters.py @@ -90,16 +90,16 @@ def _add_to_optional_dependencies( ) return True if isinstance(optional_key, abc.Sequence): - if len(optional_key) < 2: # noqa: PLR2004 - msg = "Need at least two keys to define nested optional dependencies" + if len(optional_key) == 0: + msg = "Need at least one key to define nested optional dependencies" raise ValueError(msg) this_package = get_package_name(pyproject, raise_on_missing=True) updated = False + updated &= add_dependency(pyproject, package, optional_key=optional_key[0]) for previous, key in itertools.pairwise(optional_key): - if previous is None: - updated &= add_dependency(pyproject, package, key) - else: - updated &= add_dependency(pyproject, f"{this_package}[{previous}]", key) + updated &= add_dependency( + pyproject, f"{this_package}[{previous}]", optional_key=key + ) return updated msg = f"Unsupported type for optional_key: {type(optional_key)}" raise NotImplementedError(msg) diff --git a/tests/utilities/pyproject/test_setters.py b/tests/utilities/pyproject/test_setters.py index 349d116d..bb9a33c8 100644 --- a/tests/utilities/pyproject/test_setters.py +++ b/tests/utilities/pyproject/test_setters.py @@ -52,7 +52,6 @@ def test_add_dependency_nested(): """) pyproject = load_pyproject_toml(src, modifiable=True) add_dependency(pyproject, "ruff", optional_key=["lint", "style", "dev"]) - new_content = tomlkit.dumps(pyproject) expected = dedent(""" [project] @@ -65,6 +64,18 @@ def test_add_dependency_nested(): """) assert new_content == expected + pyproject = load_pyproject_toml(src, modifiable=True) + add_dependency(pyproject, "ruff", optional_key=["lint"]) + new_content = tomlkit.dumps(pyproject) + expected = dedent(""" + [project] + name = "my-package" + + [project.optional-dependencies] + lint = ["ruff"] + """) + assert new_content == expected + def test_add_dependency_optional(): src = dedent("""