diff --git a/.cspell.json b/.cspell.json index 65c8fb42..34d65eac 100644 --- a/.cspell.json +++ b/.cspell.json @@ -53,6 +53,7 @@ "PyPA", "pytest", "PYTHONHASHSEED", + "QRules", "rtoml", "sympy", "toctree", diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8feda57c..6d8079bd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ on: jobs: doc: - uses: ComPWA/actions/.github/workflows/ci-docs.yml@v2 + uses: ComPWA/actions/.github/workflows/ci-docs.yml@v2.1 permissions: pages: write id-token: write @@ -36,7 +36,7 @@ jobs: gh-pages: true specific-pip-packages: ${{ inputs.specific-pip-packages }} pytest: - uses: ComPWA/actions/.github/workflows/pytest.yml@v2 + uses: ComPWA/actions/.github/workflows/pytest.yml@v2.1 with: coverage-target: compwa_policy macos-python-version: "3.9" @@ -45,4 +45,4 @@ jobs: if: inputs.specific-pip-packages == '' secrets: token: ${{ secrets.PAT }} - uses: ComPWA/actions/.github/workflows/pre-commit.yml@v2 + uses: ComPWA/actions/.github/workflows/pre-commit.yml@v2.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aeaef5f3..c942a86f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: check-useless-excludes - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.1 hooks: - id: ruff args: [--fix] @@ -78,7 +78,7 @@ repos: exclude: (?x)^(labels/.*\.toml)$ - repo: https://github.com/streetsidesoftware/cspell-cli - rev: v8.15.1 + rev: v8.15.2 hooks: - id: cspell @@ -104,6 +104,6 @@ repos: - python - repo: https://github.com/ComPWA/pyright-pre-commit - rev: v1.1.384 + rev: v1.1.386 hooks: - id: pyright diff --git a/pyproject.toml b/pyproject.toml index a46369dd..d976275b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,28 +29,43 @@ maintainers = [{email = "compwa-admin@ep1.rub.de"}] name = "compwa-policy" requires-python = ">=3.9" -[project.optional-dependencies] +[project.readme] +content-type = "text/markdown" +file = "README.md" + +[project.scripts] +check-dev-files = "compwa_policy.check_dev_files:main" +colab-toc-visible = "compwa_policy.colab_toc_visible:main" +fix-nbformat-version = "compwa_policy.fix_nbformat_version:main" +remove-empty-tags = "compwa_policy.remove_empty_tags:main" +self-check = "compwa_policy.self_check:main" +set-nb-cells = "compwa_policy.set_nb_cells:main" + +[project.urls] +Source = "https://github.com/ComPWA/policy" +Tracker = "https://github.com/ComPWA/policy/issues" + +[dependency-groups] dev = [ - "compwa-policy[doc]", - "compwa-policy[sty]", - "compwa-policy[test]", "labels", + "ruff", "sphinx-autobuild", + {include-group = "doc"}, + {include-group = "style"}, + {include-group = "test"}, ] doc = [ - "Sphinx", "myst-parser", - "sphinx-api-relink >=0.0.4", + "sphinx", "sphinx-api-relink", "sphinx-argparse", "sphinx-book-theme", "sphinx-codeautolink", "sphinx-copybutton", ] -sty = [ - "compwa-policy[types]", +style = [ "mypy", - "ruff", + {include-group = "types"}, ] test = [ "pytest", @@ -59,27 +74,11 @@ test = [ ] types = [ "pytest", - "sphinx-api-relink >=0.0.3", + "sphinx-api-relink", "types-PyYAML", "types-toml", ] -[project.readme] -content-type = "text/markdown" -file = "README.md" - -[project.scripts] -check-dev-files = "compwa_policy.check_dev_files:main" -colab-toc-visible = "compwa_policy.colab_toc_visible:main" -fix-nbformat-version = "compwa_policy.fix_nbformat_version:main" -remove-empty-tags = "compwa_policy.remove_empty_tags:main" -self-check = "compwa_policy.self_check:main" -set-nb-cells = "compwa_policy.set_nb_cells:main" - -[project.urls] -Source = "https://github.com/ComPWA/policy" -Tracker = "https://github.com/ComPWA/policy/issues" - [tool.setuptools] include-package-data = false license-files = ["LICENSE"] @@ -169,6 +168,8 @@ reportUnusedFunction = true reportUnusedImport = true reportUnusedVariable = true typeCheckingMode = "strict" +venv = ".venv" +venvPath = "." [tool.pytest.ini_options] addopts = [ @@ -358,8 +359,6 @@ setenv = allowlist_externals = pre-commit commands = - pre-commit run {posargs} --all-files + pre-commit run --all-files {posargs} description = Perform all linting, formatting, and spelling checks -setenv = - SKIP = pyright """ diff --git a/src/compwa_policy/.github/workflows/ci.yml b/src/compwa_policy/.github/workflows/ci.yml index ae13957f..6840cf0f 100644 --- a/src/compwa_policy/.github/workflows/ci.yml +++ b/src/compwa_policy/.github/workflows/ci.yml @@ -28,18 +28,18 @@ on: jobs: doc: - uses: ComPWA/actions/.github/workflows/ci-docs.yml@v2 + uses: ComPWA/actions/.github/workflows/ci-docs.yml@v2.1 permissions: pages: write id-token: write with: specific-pip-packages: ${{ inputs.specific-pip-packages }} pytest: - uses: ComPWA/actions/.github/workflows/pytest.yml@v2 + uses: ComPWA/actions/.github/workflows/pytest.yml@v2.1 with: specific-pip-packages: ${{ inputs.specific-pip-packages }} style: if: inputs.specific-pip-packages == '' secrets: token: ${{ secrets.PAT }} - uses: ComPWA/actions/.github/workflows/pre-commit.yml@v2 + uses: ComPWA/actions/.github/workflows/pre-commit.yml@v2.1 diff --git a/src/compwa_policy/check_dev_files/binder.py b/src/compwa_policy/check_dev_files/binder.py index 6ce23101..2999b735 100644 --- a/src/compwa_policy/check_dev_files/binder.py +++ b/src/compwa_policy/check_dev_files/binder.py @@ -8,15 +8,15 @@ import os from dataclasses import dataclass from textwrap import dedent -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from compwa_policy.errors import PrecommitError from compwa_policy.utilities import CONFIG_PATH from compwa_policy.utilities.executor import Executor -from compwa_policy.utilities.match import git_ls_files from compwa_policy.utilities.pyproject import Pyproject if TYPE_CHECKING: + from collections.abc import Mapping from pathlib import Path from compwa_policy.check_dev_files.conda import PackageManagerChoice @@ -84,28 +84,16 @@ def __get_post_builder_for_pixi_with_uv() -> str: for script in activation.scripts: expected_content += "\nbash " + script expected_content += "\npixi clean cache --yes\n" - notebook_extras = __get_notebook_extras() - if "uv.lock" in set(git_ls_files(untracked=True)): - expected_content += "\nuv export \\" - for extra in notebook_extras: - expected_content += f"\n --extra {extra} \\" - expected_content += dedent(R""" - > requirements.txt - uv pip install \ - --requirement requirements.txt \ - --system - uv cache clean - """) - else: - package = "." - if notebook_extras: - package = f"'.[{','.join(notebook_extras)}]'" - expected_content += dedent(Rf""" - uv pip install \ - --editable {package} \ - --no-cache \ - --system - """) + expected_content += "\nuv export \\" + for groups in __get_notebook_groups(): + expected_content += f"\n --group {groups} \\" + expected_content += dedent(R""" + > requirements.txt + uv pip install \ + --requirement requirements.txt \ + --system + uv cache clean + """) return expected_content @@ -135,41 +123,33 @@ def __get_post_builder_for_uv() -> str: curl -LsSf https://astral.sh/uv/install.sh | sh source $HOME/.cargo/env """).strip() - notebook_extras = __get_notebook_extras() - if "uv.lock" in set(git_ls_files(untracked=True)): - expected_content += "\nuv export \\" - for extra in notebook_extras: - expected_content += f"\n --extra {extra} \\" - expected_content += dedent(R""" - > requirements.txt - uv pip install \ - --requirement requirements.txt \ - --system - uv cache clean - """) - else: - package = "." - if notebook_extras: - package = f"'.[{','.join(notebook_extras)}]'" - expected_content += dedent(Rf""" - uv pip install \ - --editable {package} \ - --no-cache \ - --system - """) + expected_content += "\nuv export \\" + for group in __get_notebook_groups(): + expected_content += f"\n --group {group} \\" + expected_content += dedent(R""" + > requirements.txt + uv pip install \ + --requirement requirements.txt \ + --system + rm requirements.txt + uv cache clean + """) return expected_content -def __get_notebook_extras() -> list[str]: +def __get_notebook_groups() -> list[str]: + dependency_groups = ___safe_get_table("dependency-groups") + allowed_groups = {"jupyter", "notebooks"} + return sorted(allowed_groups & set(dependency_groups)) + + +def ___safe_get_table(dotted_header: str) -> Mapping[str, Any]: if not CONFIG_PATH.pyproject.exists(): - return [] + return {} pyproject = Pyproject.load() - table_key = "project.optional-dependencies" - if not pyproject.has_table(table_key): - return [] - optional_dependencies = pyproject.get_table(table_key) - allowed_sections = {"jupyter", "notebooks"} - return sorted(allowed_sections & set(optional_dependencies)) + if not pyproject.has_table(dotted_header): + return {} + return pyproject.get_table(dotted_header) def _make_executable(path: Path) -> None: diff --git a/src/compwa_policy/check_dev_files/jupyter.py b/src/compwa_policy/check_dev_files/jupyter.py index 201712f3..3de6c15d 100644 --- a/src/compwa_policy/check_dev_files/jupyter.py +++ b/src/compwa_policy/check_dev_files/jupyter.py @@ -35,4 +35,4 @@ def _update_dev_requirements(no_ruff: bool) -> None: } packages.update(ruff_packages) for package in sorted(packages): - pyproject.add_dependency(package, optional_key=["jupyter", "dev"]) + pyproject.add_dependency(package, dependency_group=["jupyter", "dev"]) diff --git a/src/compwa_policy/check_dev_files/pyproject.py b/src/compwa_policy/check_dev_files/pyproject.py index 6baba1b6..6749707e 100644 --- a/src/compwa_policy/check_dev_files/pyproject.py +++ b/src/compwa_policy/check_dev_files/pyproject.py @@ -2,6 +2,8 @@ from __future__ import annotations +import re + from compwa_policy.utilities import CONFIG_PATH from compwa_policy.utilities.pyproject import ModifiablePyproject from compwa_policy.utilities.pyproject.getters import ( @@ -15,10 +17,105 @@ def main(excluded_python_versions: set[str], no_pypi: bool) -> None: if not CONFIG_PATH.pyproject.exists(): return with ModifiablePyproject.load() as pyproject: + _convert_to_dependency_groups(pyproject) + _rename_sty_to_style(pyproject) _update_requires_python(pyproject) _update_python_version_classifiers(pyproject, excluded_python_versions, no_pypi) +def _convert_to_dependency_groups(pyproject: ModifiablePyproject) -> None: + table_key = "project.optional-dependencies" + if not pyproject.has_table(table_key): + return + optional_dependencies = pyproject.get_table(table_key) + dependency_groups = pyproject.get_table("dependency-groups", create=True) + dev_groups = { + "dev", + "doc", + "jupyter", + "lint", + "mypy", + "notebooks", + "sty", + "style", + "test", + "types", + } + package_name = pyproject.get_package_name() + updated = False + for group, dependencies in dict(optional_dependencies).items(): + if group not in dev_groups: + continue + dependencies = __convert_to_dependency_group( + dependencies, package_name, dev_groups + ) + dependency_groups[group] = to_toml_array(dependencies) + optional_dependencies.pop(group) + updated = True + if len(optional_dependencies) == 0: + del pyproject.get_table("project")["optional-dependencies"] + if updated: + msg = "Converted optional-dependencies to dependency-groups" + pyproject.changelog.append(msg) + + +def _rename_sty_to_style(pyproject: ModifiablePyproject) -> None: + dependency_groups = pyproject.get_table("dependency-groups", create=True) + if "sty" not in dependency_groups: + return + dependency_groups["style"] = to_toml_array(dependency_groups["sty"]) + del dependency_groups["sty"] + for dependencies in dependency_groups.values(): + for dependency in dependencies: + if not isinstance(dependency, dict): + continue + include_group = dependency.get("include-group") + if include_group == "sty": + dependency["include-group"] = "style" + pyproject.changelog.append("Renamed 'sty' dependency group to 'style'") + + +def __convert_to_dependency_group( + dependencies: list[str], package_name: str | None, dev_dependencies: set[str] +) -> list[str | dict]: + """Convert a list of optional dependencies to a dependency group. + + >>> __convert_to_dependency_group( + ... ["qrules[dev]", "qrules[viz]", "mypy"], + ... package_name="qrules", + ... dev_dependencies={"dev"}, + ... ) + [{'include-group': 'dev'}, 'qrules[viz]', 'mypy'] + """ + new_dependencies = [] + for dependency in dependencies: + converted = __convert_to_include(dependency, package_name, dev_dependencies) + if converted is not None: + new_dependencies.append(converted) + return new_dependencies + + +def __convert_to_include( + dependency: str, package_name: str | None, dev_dependencies: set[str] +) -> str | dict | None: + """Convert a recursive optional dependency to an include group entry. + + >>> __convert_to_include("compwa-policy[dev]", "compwa-policy", {"dev"}) + {'include-group': 'dev'} + >>> __convert_to_include("ruff", "compwa-policy", {"dev"}) + 'ruff' + >>> __convert_to_include("qrules[viz]", "qrules", {"dev"}) + 'qrules[viz]' + """ + if package_name is not None: + matches = re.match(rf"{package_name}\[(.+)\]", dependency) + if matches is not None: + include_name = matches.group(1) + if include_name in dev_dependencies: + return {"include-group": include_name} + return dependency + + def _update_requires_python(pyproject: ModifiablePyproject) -> None: if not pyproject.has_table("project"): return diff --git a/src/compwa_policy/check_dev_files/readthedocs.py b/src/compwa_policy/check_dev_files/readthedocs.py index d31e5706..af142c54 100644 --- a/src/compwa_policy/check_dev_files/readthedocs.py +++ b/src/compwa_policy/check_dev_files/readthedocs.py @@ -109,7 +109,7 @@ def __get_pixi_packages(cmd: str) -> list[str] | None: def _install_pixi(config: ReadTheDocs, packages: set[str]) -> None: pixi_cmd = __get_pixi_install_statement() if packages: - pixi_cmd += "\n" f"pixi global install {' '.join(sorted(packages))}" + pixi_cmd += f"\npixi global install {' '.join(sorted(packages))}" commands = __get_commands(config) idx: int | None = None for i, cmd in enumerate(commands): @@ -178,7 +178,7 @@ def _update_build_step_for_pixi(config: ReadTheDocs) -> None: export UV_LINK_MODE=copy pixi run \ uv run \ - --extra doc \ + --group doc \ --locked \ --with tox \ tox -e doc @@ -196,12 +196,12 @@ def _update_build_step_for_uv(config: ReadTheDocs) -> None: new_command = "export UV_LINK_MODE=copy" if "uv.lock" in set(git_ls_files(untracked=True)): new_command += dedent(R""" - uv run --extra doc --locked --with tox \ + uv run --group doc --locked --with tox \ tox -e doc """) else: new_command += dedent(R""" - uv run --extra doc --with tox \ + uv run --group doc --with tox \ tox -e doc """) new_command += dedent(R""" @@ -271,7 +271,7 @@ def __get_install_steps( pip_install = "python -m uv pip install" constraints_file = get_constraints_file(python_version) if package_manager == "uv": - install_statement = "python -m uv sync --extra=doc" + install_statement = "python -m uv sync --group=doc" elif constraints_file is None: install_statement = f"{pip_install} -e .[doc]" else: diff --git a/src/compwa_policy/check_dev_files/ruff.py b/src/compwa_policy/check_dev_files/ruff.py index 8e60e574..ff5c2c4f 100644 --- a/src/compwa_policy/check_dev_files/ruff.py +++ b/src/compwa_policy/check_dev_files/ruff.py @@ -260,7 +260,7 @@ def ___get_target_version(pyproject: Pyproject) -> str: 'py39' """ supported_python_versions = pyproject.get_supported_python_versions() - versions = {f'py{v.replace(".", "")}' for v in supported_python_versions} + versions = {f"py{v.replace('.', '')}" for v in supported_python_versions} versions &= {"py37", "py38", "py39", "py310", "py311", "py312"} if not versions: return "py37" @@ -631,7 +631,8 @@ def _update_lint_dependencies(pyproject: ModifiablePyproject) -> None: ruff = 'ruff; python_version >="3.7.0"' else: ruff = "ruff" - pyproject.add_dependency(ruff, optional_key=["sty", "dev"]) + pyproject.add_dependency(ruff, dependency_group="dev") + pyproject.remove_dependency(ruff, ignored_sections=["dev"]) def _update_vscode_settings() -> None: diff --git a/src/compwa_policy/check_dev_files/update_lock.py b/src/compwa_policy/check_dev_files/update_lock.py index 707e18dd..a39a79ac 100644 --- a/src/compwa_policy/check_dev_files/update_lock.py +++ b/src/compwa_policy/check_dev_files/update_lock.py @@ -70,7 +70,7 @@ def overwrite_workflow(workflow_file: str) -> None: if not existing_paths: msg = ( "No paths defined for pull_request trigger. Expecting any of " - ", ".join(original_paths) + + ", ".join(original_paths) ) raise ValueError(msg) expected_data["on"]["pull_request"]["paths"] = existing_paths diff --git a/src/compwa_policy/utilities/pyproject/__init__.py b/src/compwa_policy/utilities/pyproject/__init__.py index fcc6fc94..0b4d6c9b 100644 --- a/src/compwa_policy/utilities/pyproject/__init__.py +++ b/src/compwa_policy/utilities/pyproject/__init__.py @@ -191,10 +191,15 @@ def get_table( return super().get_table(dotted_header) # type:ignore[return-value] def add_dependency( - self, package: str, optional_key: str | Sequence[str] | None = None + self, + package: str, + dependency_group: str | Sequence[str] | None = None, + optional_key: str | Sequence[str] | None = None, ) -> None: self.__assert_is_in_context() - updated = add_dependency(self._document, package, optional_key) + updated = add_dependency( + self._document, package, dependency_group, optional_key + ) if updated: msg = f"Listed {package} as a dependency" self._changelog.append(msg) diff --git a/src/compwa_policy/utilities/pyproject/_struct.py b/src/compwa_policy/utilities/pyproject/_struct.py index 7adc7984..33f79138 100644 --- a/src/compwa_policy/utilities/pyproject/_struct.py +++ b/src/compwa_policy/utilities/pyproject/_struct.py @@ -12,11 +12,13 @@ else: from typing import NotRequired +IncludeGroup = TypedDict("IncludeGroup", {"include-group": str}) PyprojectTOML = TypedDict( "PyprojectTOML", { "build-system": NotRequired["BuildSystem"], "project": "Project", + "dependency-groups": NotRequired[dict[str, list[str | IncludeGroup]]], "tool": NotRequired[dict[str, dict[str, str]]], }, ) diff --git a/src/compwa_policy/utilities/pyproject/setters.py b/src/compwa_policy/utilities/pyproject/setters.py index c4ac5bbc..2d8d6249 100644 --- a/src/compwa_policy/utilities/pyproject/setters.py +++ b/src/compwa_policy/utilities/pyproject/setters.py @@ -2,6 +2,7 @@ from __future__ import annotations +import itertools import re from collections import abc from collections.abc import Iterable, Mapping, MutableMapping, Sequence @@ -18,22 +19,64 @@ if TYPE_CHECKING: from tomlkit.items import Table - from compwa_policy.utilities.pyproject._struct import PyprojectTOML + from compwa_policy.utilities.pyproject._struct import IncludeGroup, PyprojectTOML def add_dependency( pyproject: PyprojectTOML, package: str, + dependency_group: str | Sequence[str] | None = None, optional_key: str | Sequence[str] | None = None, ) -> bool: - if optional_key is None: - project = get_sub_table(pyproject, "project") - existing_dependencies: set[str] = set(project.get("dependencies", [])) - if package in existing_dependencies: + if optional_key is None and dependency_group is None: + return _add_direct_dependency(pyproject, package) + if dependency_group is not None: + return _add_to_dependency_group(pyproject, package, dependency_group) + if optional_key is not None: + return _add_to_optional_dependencies(pyproject, package, optional_key) + return False + + +def _add_direct_dependency(pyproject: PyprojectTOML, package: str) -> bool: + project = get_sub_table(pyproject, "project") + existing_dependencies = set(project.get("dependencies", [])) + if package in existing_dependencies: + return False + existing_dependencies.add(package) + project["dependencies"] = to_toml_array(_sort_taplo(existing_dependencies)) + return True + + +def _add_to_dependency_group( + pyproject: PyprojectTOML, package: str, dependency_group: str | Sequence[str] +) -> bool: + if "dependency-groups" not in pyproject: + pyproject["dependency-groups"] = tomlkit.table(is_super_table=False) + dependency_groups = pyproject["dependency-groups"] + if isinstance(dependency_group, str): + dependencies = dependency_groups.get(dependency_group, []) + if package in dependencies: return False - existing_dependencies.add(package) - project["dependencies"] = to_toml_array(_sort_taplo(existing_dependencies)) + dependencies.append(package) + dependency_groups[dependency_group] = to_toml_array(dependencies) return True + if isinstance(dependency_group, abc.Sequence) and len(dependency_group): + updated = add_dependency(pyproject, package, dependency_group[0]) + for previous, current in itertools.pairwise(dependency_group): + dependencies = dependency_groups.get(current, []) + expected: IncludeGroup = {"include-group": previous} + if expected in dependencies: + continue + updated &= True + dependencies.append(expected) + return updated + msg = f"Unsupported type for dependency group: {type(dependency_group)}" + raise NotImplementedError(msg) + + +def _add_to_optional_dependencies( + pyproject: PyprojectTOML, package: str, optional_key: str | Sequence[str] +) -> bool: if isinstance(optional_key, str): table_key = "project.optional-dependencies" optional_dependencies = get_sub_table(pyproject, table_key) @@ -43,20 +86,20 @@ def add_dependency( existing_dependencies.add(package) existing_dependencies = set(existing_dependencies) optional_dependencies[optional_key] = to_toml_array( - _sort_taplo(existing_dependencies) + _sort_taplo(existing_dependencies) # type:ignore[arg-type] ) return True if isinstance(optional_key, abc.Sequence): - if len(optional_key) < 2: # noqa: PLR2004 - msg = "Need at least two keys to define nested optional dependencies" + if len(optional_key) == 0: + msg = "Need at least one key to define nested optional dependencies" raise ValueError(msg) this_package = get_package_name(pyproject, raise_on_missing=True) updated = False - for key, previous in zip(optional_key, [None, *optional_key]): - if previous is None: - updated &= add_dependency(pyproject, package, key) - else: - updated &= add_dependency(pyproject, f"{this_package}[{previous}]", key) + updated &= add_dependency(pyproject, package, optional_key=optional_key[0]) + for previous, key in itertools.pairwise(optional_key): + updated &= add_dependency( + pyproject, f"{this_package}[{previous}]", optional_key=key + ) return updated msg = f"Unsupported type for optional_key: {type(optional_key)}" raise NotImplementedError(msg) @@ -86,7 +129,7 @@ def get_sub_table( return cast(MutableMapping[str, Any], table) -def remove_dependency( # noqa: C901 +def remove_dependency( # noqa: C901, PLR0912 pyproject: PyprojectTOML, package: str, ignored_sections: Iterable[str] | None = None, @@ -102,12 +145,12 @@ def remove_dependency( # noqa: C901 idx = package_names.index(package) dependencies.pop(idx) updated = True + if ignored_sections is None: + ignored_sections = set() + else: + ignored_sections = set(ignored_sections) optional_dependencies = project.get("optional-dependencies") if optional_dependencies is not None: - if ignored_sections is None: - ignored_sections = set() - else: - ignored_sections = set(ignored_sections) for section, dependencies in optional_dependencies.items(): if section in ignored_sections: continue @@ -120,6 +163,27 @@ def remove_dependency( # noqa: C901 empty_sections = [k for k, v in optional_dependencies.items() if not v] for section in empty_sections: del optional_dependencies[section] + if not optional_dependencies: + del project["optional-dependencies"] + dependency_groups = pyproject.get("dependency-groups") + if dependency_groups is not None: + for section, dependencies in dependency_groups.items(): # type:ignore[assignment] + if section in ignored_sections: + continue + package_names = [ + split_dependency_definition(p)[0] if isinstance(p, str) else p + for p in dependencies # type:ignore[union-attr] + ] + if package in package_names: + idx = package_names.index(package) + dependencies.pop(idx) # type:ignore[union-attr] + updated = True + if updated: + empty_sections = [k for k, v in dependency_groups.items() if not v] + for section in empty_sections: + del dependency_groups[section] + if not dependency_groups: + del pyproject["dependency-groups"] return updated diff --git a/tests/utilities/pyproject/test_setters.py b/tests/utilities/pyproject/test_setters.py index 50ab2d6c..bb9a33c8 100644 --- a/tests/utilities/pyproject/test_setters.py +++ b/tests/utilities/pyproject/test_setters.py @@ -51,8 +51,21 @@ def test_add_dependency_nested(): name = "my-package" """) pyproject = load_pyproject_toml(src, modifiable=True) - add_dependency(pyproject, "ruff", optional_key=["lint", "sty", "dev"]) + add_dependency(pyproject, "ruff", optional_key=["lint", "style", "dev"]) + new_content = tomlkit.dumps(pyproject) + expected = dedent(""" + [project] + name = "my-package" + + [project.optional-dependencies] + lint = ["ruff"] + style = ["my-package[lint]"] + dev = ["my-package[style]"] + """) + assert new_content == expected + pyproject = load_pyproject_toml(src, modifiable=True) + add_dependency(pyproject, "ruff", optional_key=["lint"]) new_content = tomlkit.dumps(pyproject) expected = dedent(""" [project] @@ -60,8 +73,6 @@ def test_add_dependency_nested(): [project.optional-dependencies] lint = ["ruff"] - sty = ["my-package[lint]"] - dev = ["my-package[sty]"] """) assert new_content == expected @@ -97,7 +108,7 @@ def pyproject_example() -> PyprojectTOML: "mypy", "ruff", ] - sty = ["ruff"] + style = ["ruff"] """) return load_pyproject_toml(src, modifiable=True) @@ -114,14 +125,14 @@ def test_remove_dependency(pyproject_example: PyprojectTOML): "mypy", "ruff", ] - sty = ["ruff"] + style = ["ruff"] """) new_content = tomlkit.dumps(pyproject_example) assert new_content == expected def test_remove_dependency_nested(pyproject_example: PyprojectTOML): - remove_dependency(pyproject_example, "ruff", ignored_sections=["sty"]) + remove_dependency(pyproject_example, "ruff", ignored_sections=["sty", "style"]) new_content = tomlkit.dumps(pyproject_example) expected = dedent(""" [project] @@ -132,7 +143,7 @@ def test_remove_dependency_nested(pyproject_example: PyprojectTOML): lint = [ "mypy", ] - sty = ["ruff"] + style = ["ruff"] """) assert new_content == expected