Skip to content

Commit

Permalink
fix: container dependency resolution refactor, bump pyright dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
tandemdude committed Aug 28, 2024
1 parent df39cdd commit 4cb3821
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 53 deletions.
63 changes: 22 additions & 41 deletions lightbulb/di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

__all__ = ["Container"]

import logging
import typing as t

import networkx as nx
Expand All @@ -39,6 +40,7 @@
from lightbulb.internal import types as lb_types

T = t.TypeVar("T")
LOGGER = logging.getLogger(__name__)


class Container:
Expand All @@ -57,29 +59,15 @@ def __init__(
) -> None:
self._registry = registry
self._registry._freeze(self)

self._parent = parent
self._tag = tag

self._closed = False

self._graph: nx.DiGraph[str] = nx.DiGraph(self._parent._graph) if self._parent is not None else nx.DiGraph()
self._graph: nx.DiGraph[str] = nx.DiGraph(self._registry._graph)
self._instances: dict[str, t.Any] = {}

# Add our registry entries to the graphs
for node, node_data in self._registry._graph.nodes.items():
new_node_data = dict(node_data)

# Set the origin container if this is a concrete dependency instead of a transient one
if node_data.get("factory") is not None:
new_node_data["container"] = self

# If we are overriding a previously defined dependency with our own
if node in self._graph and node_data.get("factory") is not None:
self._graph.remove_edges_from(list(self._graph.out_edges(node)))

self._graph.add_node(node, **new_node_data)
self._graph.add_edges_from(self._registry._graph.edges)

self.add_value(Container, self)

async def __aenter__(self) -> Container:
Expand Down Expand Up @@ -156,23 +144,25 @@ def add_value(

if dependency_id in self._graph:
self._graph.remove_edges_from(list(self._graph.out_edges(dependency_id)))
self._graph.add_node(dependency_id, container=self, teardown=teardown)
self._graph.add_node(dependency_id, factory=lambda: None, teardown=teardown)

async def _get(self, dependency_id: str) -> t.Any:
if self._closed:
raise exceptions.ContainerClosedException

# TODO - look into whether locking is necessary - how likely are we to have race conditions

data = self._graph.nodes.get(dependency_id)
if data is None or data.get("container") is None:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dependency_id!r} - not provided by this or a parent container"
)
if (existing := self._instances.get(dependency_id)) is not None:
return existing

existing_dependency = data["container"]._instances.get(dependency_id)
if existing_dependency is not None:
return existing_dependency
if (data := self._graph.nodes.get(dependency_id)) is None or data.get("factory") is None:
if self._parent is None:
raise exceptions.DependencyNotSatisfiableException(
f"cannot create dependency {dependency_id!r} - not provided by this or a parent container"
)

LOGGER.debug("dependency %r not provided by this container - checking parent", dependency_id)
return await self._parent._get(dependency_id)

# TODO - look into caching individual dependency creation order globally
# - may speed up using subsequent containers (i.e. for each command)
Expand All @@ -181,47 +171,38 @@ async def _get(self, dependency_id: str) -> t.Any:
assert isinstance(subgraph, nx.DiGraph)

try:
creation_order = reversed(list(nx.topological_sort(subgraph)))
creation_order = list(reversed(list(nx.topological_sort(subgraph))))
except nx.NetworkXUnfeasible:
raise exceptions.CircularDependencyException(
f"cannot provide {dependency_id!r} - circular dependency found during creation"
)

LOGGER.debug("dependency %r depends on %s", dependency_id, creation_order)
for dep_id in creation_order:
if (container := self._graph.nodes[dep_id].get("container")) is None:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - not provided by this or a parent container"
)

# We already have the dependency we need
if dep_id in container._instances:
if dep_id in self._instances:
continue

node_data = self._graph.nodes[dep_id]
# Check that we actually know how to create the dependency - this should have been caught earlier
# by checking that node["container"] was present - but just in case, we check for the factory
if node_data.get("factory") is None:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - do not know how to instantiate"
)

# Get the dependencies for this dependency from the container this dependency was defined in.
# This prevents 'scope promotion' - a dependency from the parent container requiring one from the
# child container, and hence the lifecycle of the child dependency being extended to
# that of the parent.
sub_dependencies: dict[str, t.Any] = {}
try:
LOGGER.debug("checking sub-dependencies for %r", dep_id)
for sub_dependency_id, param_name in node_data["factory_params"].items():
sub_dependencies[param_name] = await node_data["container"]._get(sub_dependency_id)
sub_dependencies[param_name] = await self._get(sub_dependency_id)
except exceptions.DependencyNotSatisfiableException as e:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - failed creating sub-dependency"
) from e

# Cache the created dependency in the correct container to ensure the correct lifecycle
container._instances[dep_id] = await utils.maybe_await(node_data["factory"](**sub_dependencies))
self._instances[dep_id] = await utils.maybe_await(node_data["factory"](**sub_dependencies))

return self._graph.nodes[dependency_id]["container"]._instances[dependency_id]
return self._instances[dependency_id]

async def get(self, typ: type[T]) -> T:
"""
Expand Down
4 changes: 4 additions & 0 deletions lightbulb/di/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ async def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
LOGGER.debug("requesting dependency for type %r", type)
new_kwargs[name] = await di_container.get(type)

if len(new_kwargs) > len(kwargs):
func_name = ((self._self.__class__.__name__ + ".") if self._self else "") + self._func.__name__
LOGGER.debug("calling function %r with resolved dependencies", func_name)

if self._self is not None:
return await utils.maybe_await(self._func(self._self, *args, **new_kwargs))
return await utils.maybe_await(self._func(*args, **new_kwargs))
Expand Down
12 changes: 6 additions & 6 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,33 +50,33 @@ def inner(func: Callable[[nox.Session], None]) -> Callable[[nox.Session], None]:

@nox_session()
def format_fix(session: nox.Session) -> None:
session.install(".[localization,crontrigger,dev.format]")
session.install("-U", ".[localization,crontrigger,dev.format]")
session.run("python", "-m", "ruff", "format", *SCRIPT_PATHS)
session.run("python", "-m", "ruff", "check", "--fix", *SCRIPT_PATHS)


@nox_session()
def format_check(session: nox.Session) -> None:
session.install(".[localization,crontrigger,dev.format]")
session.install("-U", ".[localization,crontrigger,dev.format]")
session.run("python", "-m", "ruff", "format", *SCRIPT_PATHS, "--check")
session.run("python", "-m", "ruff", "check", "--output-format", "github", *SCRIPT_PATHS)


@nox_session()
def typecheck(session: nox.Session) -> None:
session.install(".[localization,crontrigger,dev.typecheck,dev.test]")
session.install("-U", ".[localization,crontrigger,dev.typecheck,dev.test]")
session.run("python", "-m", "pyright")


@nox_session()
def slotscheck(session: nox.Session) -> None:
session.install(".[localization,crontrigger,dev.slotscheck]")
session.install("-U", ".[localization,crontrigger,dev.slotscheck]")
session.run("python", "-m", "slotscheck", "-m", "lightbulb")


@nox_session()
def test(session: nox.Session) -> None:
session.install(".[localization,crontrigger,dev.test]")
session.install("-U", ".[localization,crontrigger,dev.test]")

args = ["python", "-m", "pytest"]
if session.posargs:
Expand All @@ -88,6 +88,6 @@ def test(session: nox.Session) -> None:

@nox_session()
def sphinx(session: nox.Session) -> None:
session.install(".[localization,crontrigger,dev.docs]")
session.install("-U", ".[localization,crontrigger,dev.docs]")
session.run("python", "./scripts/docs/api_reference_generator.py")
session.run("python", "-m", "sphinx.cmd.build", "docs/source", "docs/build", "-b", "html")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dev = ["nox==2024.4.15"]
]
"dev.format" = ["ruff==0.6.2"]
"dev.typecheck" = [
"pyright==1.1.377",
"pyright==1.1.378",
"typing-extensions>=4.12.2, <5",
"types-networkx>=3.2.1.20240703, <4",
"types-polib>=1.2.0.20240327, <2",
Expand Down
2 changes: 1 addition & 1 deletion tests/di/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def f3(_: A, __: B) -> object: return object()
async with di.Container(registry) as container:
await container.get(C)

@pytest.fixture
@pytest.fixture(scope="function")
def complicated_registry(self) -> di.Registry:
# fmt: off
def f_a() -> object: return object()
Expand Down
8 changes: 4 additions & 4 deletions tests/prefab/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


class TestOwnerOnly:
@pytest.fixture
@pytest.fixture(scope="function")
def application(self) -> hikari.Application:
app = mock.Mock(spec=hikari.Application)
app.owner.id = 123
Expand Down Expand Up @@ -70,7 +70,7 @@ async def test_gets_correct_owner_ids_from_application(self, application: hikari


class TestHasPermissions:
@pytest.fixture
@pytest.fixture(scope="function")
def context(self) -> lightbulb.Context:
ctx = mock.Mock(spec=lightbulb.Context)
ctx.member = mock.Mock(permissions=hikari.Permissions.all_permissions())
Expand Down Expand Up @@ -104,7 +104,7 @@ async def test_passes_when_not_in_guild_fail_flag_disabled(self, context: lightb


class TestBotHasPermissions:
@pytest.fixture
@pytest.fixture(scope="function")
def context(self) -> lightbulb.Context:
ctx = mock.Mock(spec=lightbulb.Context)
ctx.interaction = mock.Mock(app_permissions=hikari.Permissions.all_permissions())
Expand Down Expand Up @@ -140,7 +140,7 @@ async def test_passes_when_not_in_guild_fail_flag_disabled(self, context: lightb


class TestHasRoles:
@pytest.fixture
@pytest.fixture(scope="function")
def context(self) -> lightbulb.Context:
ctx = mock.Mock(spec=lightbulb.Context)
ctx.member = mock.Mock(role_ids=[123, 456, 789])
Expand Down

0 comments on commit 4cb3821

Please sign in to comment.