From 4cb38218be5d0621603e5b4bc9eec1578160d0b1 Mon Sep 17 00:00:00 2001 From: tandemdude <43570299+tandemdude@users.noreply.github.com> Date: Wed, 28 Aug 2024 23:34:51 +0100 Subject: [PATCH] fix: container dependency resolution refactor, bump pyright dependency --- lightbulb/di/container.py | 63 +++++++++++++------------------------ lightbulb/di/solver.py | 4 +++ noxfile.py | 12 +++---- pyproject.toml | 2 +- tests/di/test_container.py | 2 +- tests/prefab/test_checks.py | 8 ++--- 6 files changed, 38 insertions(+), 53 deletions(-) diff --git a/lightbulb/di/container.py b/lightbulb/di/container.py index 179cd425..199c438d 100644 --- a/lightbulb/di/container.py +++ b/lightbulb/di/container.py @@ -22,6 +22,7 @@ __all__ = ["Container"] +import logging import typing as t import networkx as nx @@ -39,6 +40,7 @@ from lightbulb.internal import types as lb_types T = t.TypeVar("T") +LOGGER = logging.getLogger(__name__) class Container: @@ -57,29 +59,15 @@ def __init__( ) -> None: self._registry = registry self._registry._freeze(self) + self._parent = parent self._tag = tag self._closed = False - self._graph: nx.DiGraph[str] = nx.DiGraph(self._parent._graph) if self._parent is not None else nx.DiGraph() + self._graph: nx.DiGraph[str] = nx.DiGraph(self._registry._graph) self._instances: dict[str, t.Any] = {} - # Add our registry entries to the graphs - for node, node_data in self._registry._graph.nodes.items(): - new_node_data = dict(node_data) - - # Set the origin container if this is a concrete dependency instead of a transient one - if node_data.get("factory") is not None: - new_node_data["container"] = self - - # If we are overriding a previously defined dependency with our own - if node in self._graph and node_data.get("factory") is not None: - self._graph.remove_edges_from(list(self._graph.out_edges(node))) - - self._graph.add_node(node, **new_node_data) - self._graph.add_edges_from(self._registry._graph.edges) - self.add_value(Container, self) async def __aenter__(self) -> Container: @@ -156,7 +144,7 @@ def add_value( if dependency_id in self._graph: self._graph.remove_edges_from(list(self._graph.out_edges(dependency_id))) - self._graph.add_node(dependency_id, container=self, teardown=teardown) + self._graph.add_node(dependency_id, factory=lambda: None, teardown=teardown) async def _get(self, dependency_id: str) -> t.Any: if self._closed: @@ -164,15 +152,17 @@ async def _get(self, dependency_id: str) -> t.Any: # TODO - look into whether locking is necessary - how likely are we to have race conditions - data = self._graph.nodes.get(dependency_id) - if data is None or data.get("container") is None: - raise exceptions.DependencyNotSatisfiableException( - f"could not create dependency {dependency_id!r} - not provided by this or a parent container" - ) + if (existing := self._instances.get(dependency_id)) is not None: + return existing - existing_dependency = data["container"]._instances.get(dependency_id) - if existing_dependency is not None: - return existing_dependency + if (data := self._graph.nodes.get(dependency_id)) is None or data.get("factory") is None: + if self._parent is None: + raise exceptions.DependencyNotSatisfiableException( + f"cannot create dependency {dependency_id!r} - not provided by this or a parent container" + ) + + LOGGER.debug("dependency %r not provided by this container - checking parent", dependency_id) + return await self._parent._get(dependency_id) # TODO - look into caching individual dependency creation order globally # - may speed up using subsequent containers (i.e. for each command) @@ -181,47 +171,38 @@ async def _get(self, dependency_id: str) -> t.Any: assert isinstance(subgraph, nx.DiGraph) try: - creation_order = reversed(list(nx.topological_sort(subgraph))) + creation_order = list(reversed(list(nx.topological_sort(subgraph)))) except nx.NetworkXUnfeasible: raise exceptions.CircularDependencyException( f"cannot provide {dependency_id!r} - circular dependency found during creation" ) + LOGGER.debug("dependency %r depends on %s", dependency_id, creation_order) for dep_id in creation_order: - if (container := self._graph.nodes[dep_id].get("container")) is None: - raise exceptions.DependencyNotSatisfiableException( - f"could not create dependency {dep_id!r} - not provided by this or a parent container" - ) - # We already have the dependency we need - if dep_id in container._instances: + if dep_id in self._instances: continue node_data = self._graph.nodes[dep_id] - # Check that we actually know how to create the dependency - this should have been caught earlier - # by checking that node["container"] was present - but just in case, we check for the factory if node_data.get("factory") is None: raise exceptions.DependencyNotSatisfiableException( f"could not create dependency {dep_id!r} - do not know how to instantiate" ) - # Get the dependencies for this dependency from the container this dependency was defined in. - # This prevents 'scope promotion' - a dependency from the parent container requiring one from the - # child container, and hence the lifecycle of the child dependency being extended to - # that of the parent. sub_dependencies: dict[str, t.Any] = {} try: + LOGGER.debug("checking sub-dependencies for %r", dep_id) for sub_dependency_id, param_name in node_data["factory_params"].items(): - sub_dependencies[param_name] = await node_data["container"]._get(sub_dependency_id) + sub_dependencies[param_name] = await self._get(sub_dependency_id) except exceptions.DependencyNotSatisfiableException as e: raise exceptions.DependencyNotSatisfiableException( f"could not create dependency {dep_id!r} - failed creating sub-dependency" ) from e # Cache the created dependency in the correct container to ensure the correct lifecycle - container._instances[dep_id] = await utils.maybe_await(node_data["factory"](**sub_dependencies)) + self._instances[dep_id] = await utils.maybe_await(node_data["factory"](**sub_dependencies)) - return self._graph.nodes[dependency_id]["container"]._instances[dependency_id] + return self._instances[dependency_id] async def get(self, typ: type[T]) -> T: """ diff --git a/lightbulb/di/solver.py b/lightbulb/di/solver.py index 0ed2e9ad..935a0ea1 100644 --- a/lightbulb/di/solver.py +++ b/lightbulb/di/solver.py @@ -389,6 +389,10 @@ async def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: LOGGER.debug("requesting dependency for type %r", type) new_kwargs[name] = await di_container.get(type) + if len(new_kwargs) > len(kwargs): + func_name = ((self._self.__class__.__name__ + ".") if self._self else "") + self._func.__name__ + LOGGER.debug("calling function %r with resolved dependencies", func_name) + if self._self is not None: return await utils.maybe_await(self._func(self._self, *args, **new_kwargs)) return await utils.maybe_await(self._func(*args, **new_kwargs)) diff --git a/noxfile.py b/noxfile.py index 7ea98bb1..29ac6717 100644 --- a/noxfile.py +++ b/noxfile.py @@ -50,33 +50,33 @@ def inner(func: Callable[[nox.Session], None]) -> Callable[[nox.Session], None]: @nox_session() def format_fix(session: nox.Session) -> None: - session.install(".[localization,crontrigger,dev.format]") + session.install("-U", ".[localization,crontrigger,dev.format]") session.run("python", "-m", "ruff", "format", *SCRIPT_PATHS) session.run("python", "-m", "ruff", "check", "--fix", *SCRIPT_PATHS) @nox_session() def format_check(session: nox.Session) -> None: - session.install(".[localization,crontrigger,dev.format]") + session.install("-U", ".[localization,crontrigger,dev.format]") session.run("python", "-m", "ruff", "format", *SCRIPT_PATHS, "--check") session.run("python", "-m", "ruff", "check", "--output-format", "github", *SCRIPT_PATHS) @nox_session() def typecheck(session: nox.Session) -> None: - session.install(".[localization,crontrigger,dev.typecheck,dev.test]") + session.install("-U", ".[localization,crontrigger,dev.typecheck,dev.test]") session.run("python", "-m", "pyright") @nox_session() def slotscheck(session: nox.Session) -> None: - session.install(".[localization,crontrigger,dev.slotscheck]") + session.install("-U", ".[localization,crontrigger,dev.slotscheck]") session.run("python", "-m", "slotscheck", "-m", "lightbulb") @nox_session() def test(session: nox.Session) -> None: - session.install(".[localization,crontrigger,dev.test]") + session.install("-U", ".[localization,crontrigger,dev.test]") args = ["python", "-m", "pytest"] if session.posargs: @@ -88,6 +88,6 @@ def test(session: nox.Session) -> None: @nox_session() def sphinx(session: nox.Session) -> None: - session.install(".[localization,crontrigger,dev.docs]") + session.install("-U", ".[localization,crontrigger,dev.docs]") session.run("python", "./scripts/docs/api_reference_generator.py") session.run("python", "-m", "sphinx.cmd.build", "docs/source", "docs/build", "-b", "html") diff --git a/pyproject.toml b/pyproject.toml index 3fae6a48..131528cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dev = ["nox==2024.4.15"] ] "dev.format" = ["ruff==0.6.2"] "dev.typecheck" = [ - "pyright==1.1.377", + "pyright==1.1.378", "typing-extensions>=4.12.2, <5", "types-networkx>=3.2.1.20240703, <4", "types-polib>=1.2.0.20240327, <2", diff --git a/tests/di/test_container.py b/tests/di/test_container.py index 4b5e3e4e..30a97e58 100644 --- a/tests/di/test_container.py +++ b/tests/di/test_container.py @@ -191,7 +191,7 @@ def f3(_: A, __: B) -> object: return object() async with di.Container(registry) as container: await container.get(C) - @pytest.fixture + @pytest.fixture(scope="function") def complicated_registry(self) -> di.Registry: # fmt: off def f_a() -> object: return object() diff --git a/tests/prefab/test_checks.py b/tests/prefab/test_checks.py index 7d424e20..1919e5ac 100644 --- a/tests/prefab/test_checks.py +++ b/tests/prefab/test_checks.py @@ -27,7 +27,7 @@ class TestOwnerOnly: - @pytest.fixture + @pytest.fixture(scope="function") def application(self) -> hikari.Application: app = mock.Mock(spec=hikari.Application) app.owner.id = 123 @@ -70,7 +70,7 @@ async def test_gets_correct_owner_ids_from_application(self, application: hikari class TestHasPermissions: - @pytest.fixture + @pytest.fixture(scope="function") def context(self) -> lightbulb.Context: ctx = mock.Mock(spec=lightbulb.Context) ctx.member = mock.Mock(permissions=hikari.Permissions.all_permissions()) @@ -104,7 +104,7 @@ async def test_passes_when_not_in_guild_fail_flag_disabled(self, context: lightb class TestBotHasPermissions: - @pytest.fixture + @pytest.fixture(scope="function") def context(self) -> lightbulb.Context: ctx = mock.Mock(spec=lightbulb.Context) ctx.interaction = mock.Mock(app_permissions=hikari.Permissions.all_permissions()) @@ -140,7 +140,7 @@ async def test_passes_when_not_in_guild_fail_flag_disabled(self, context: lightb class TestHasRoles: - @pytest.fixture + @pytest.fixture(scope="function") def context(self) -> lightbulb.Context: ctx = mock.Mock(spec=lightbulb.Context) ctx.member = mock.Mock(role_ids=[123, 456, 789])