Skip to content

Commit

Permalink
feat: union dependency support within factories
Browse files Browse the repository at this point in the history
  • Loading branch information
tandemdude committed Sep 12, 2024
1 parent 0d737f6 commit 531a8f4
Show file tree
Hide file tree
Showing 9 changed files with 447 additions and 236 deletions.
54 changes: 53 additions & 1 deletion lightbulb/di/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# SOFTWARE.
from __future__ import annotations

__all__ = ["If", "Try"]
__all__ = ["DependencyExpression", "If", "Try"]

import abc
import types
Expand All @@ -30,6 +30,8 @@
from lightbulb.di import utils as di_utils

if t.TYPE_CHECKING:
from collections.abc import Sequence

from lightbulb.di import container as container_

T = t.TypeVar("T")
Expand Down Expand Up @@ -110,6 +112,56 @@ async def _get_from(self, container: container_.Container) -> tuple[bool, t.Any]
return False, None


class DependencyExpression(t.Generic[T]):
__slots__ = ("_order", "_required")

def __init__(self, order: Sequence[BaseCondition], required: bool) -> None:
self._order = order
self._required = required

def __repr__(self) -> str:
return f"DependencyExpression({self._order}, required={self._required})"

async def resolve(self, container: container_.Container, /) -> T | None:
if len(self._order) == 1 and self._required:
return await container._get(self._order[0].inner_id)

for dependency in self._order:
succeeded, found = await dependency._get_from(container)
if succeeded:
return found

if not self._required:
return None

raise exceptions.DependencyNotSatisfiableException("no dependencies can satisfy the requested type")

# TODO - TypeExpr
@classmethod
def create(cls, expr: t.Any, /) -> DependencyExpression[t.Any]:
requested_dependencies: list[BaseCondition] = []
required: bool = True

args: Sequence[t.Any] = (expr,)
if isinstance(expr, types.UnionType):
args = t.get_args(expr)
elif isinstance(expr, BaseCondition):
args = expr.order

for arg in args:
if arg is types.NoneType or arg is None:
required = False
continue

if not isinstance(arg, BaseCondition):
# a concrete type T implicitly means If[T]
arg = If(arg)

requested_dependencies.append(arg)

return cls(requested_dependencies, required)


if t.TYPE_CHECKING:
If = t.Annotated[T, None] # type: ignore[reportAssignmentType]
Try = t.Annotated[T, None] # type: ignore[reportAssignmentType]
139 changes: 34 additions & 105 deletions lightbulb/di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,22 @@
# SOFTWARE.
from __future__ import annotations

__all__ = ["Container", "DependencyExpression"]
__all__ = ["Container"]

import logging
import types
import typing as t

import networkx as nx

from lightbulb import utils
from lightbulb.di import conditions
from lightbulb.di import exceptions
from lightbulb.di import graph
from lightbulb.di import registry as registry_
from lightbulb.di import utils as di_utils
from lightbulb.di.graph import DependencyData

if t.TYPE_CHECKING:
import types
from collections.abc import Callable
from collections.abc import Sequence

from lightbulb.di import solver
from lightbulb.internal import types as lb_types
Expand All @@ -46,56 +45,6 @@
LOGGER = logging.getLogger(__name__)


class DependencyExpression(t.Generic[T]):
__slots__ = ("_order", "_required")

def __init__(self, order: Sequence[conditions.BaseCondition], required: bool) -> None:
self._order = order
self._required = required

def __repr__(self) -> str:
return f"DependencyExpression({self._order}, required={self._required})"

async def resolve(self, container: Container, /) -> T | None:
if len(self._order) == 1 and self._required:
return await container._get(self._order[0].inner_id)

for dependency in self._order:
succeeded, found = await dependency._get_from(container)
if succeeded:
return found

if not self._required:
return None

raise exceptions.DependencyNotSatisfiableException("no dependencies can satisfy the requested type")

# TODO - TypeExpr
@classmethod
def create(cls, expr: t.Any, /) -> DependencyExpression[t.Any]:
requested_dependencies: list[conditions.BaseCondition] = []
required: bool = True

args: Sequence[t.Any] = (expr,)
if isinstance(expr, types.UnionType):
args = t.get_args(expr)
elif isinstance(expr, conditions.BaseCondition):
args = expr.order

for arg in args:
if arg is types.NoneType or arg is None:
required = False
continue

if not isinstance(arg, conditions.BaseCondition):
# a concrete type T implicitly means If[T]
arg = conditions.If(arg)

requested_dependencies.append(arg)

return cls(requested_dependencies, required)


class Container:
"""
A container for managing and supplying dependencies.
Expand All @@ -118,7 +67,7 @@ def __init__(

self._closed = False

self._graph: nx.DiGraph[str] = nx.DiGraph(self._registry._graph)
self._graph: graph.DiGraph = graph.DiGraph(registry._graph)
self._instances: dict[str, t.Any] = {}

self.add_value(Container, self)
Expand All @@ -134,7 +83,7 @@ def __contains__(self, item: t.Any) -> bool:
return True

node = self._graph.nodes.get(item)
if node is not None and node.get("factory") is not None:
if node is not None:
return True

if self._parent is None:
Expand All @@ -153,7 +102,10 @@ async def __aexit__(
async def close(self) -> None:
"""Closes the container, running teardown procedures for each created dependency belonging to this container."""
for dependency_id, instance in self._instances.items():
if (td := self._graph.nodes[dependency_id]["teardown"]) is None:
if (node := self._graph.nodes.get(dependency_id)) is None:
continue

if (td := node.teardown_method) is None:
continue

await utils.maybe_await(td(instance))
Expand Down Expand Up @@ -186,8 +138,10 @@ def add_factory(
dependency_id = di_utils.get_dependency_id(typ)

if dependency_id in self._graph:
self._graph.remove_edges_from(list(self._graph.out_edges(dependency_id)))
di_utils.populate_graph_for_dependency(self._graph, dependency_id, factory, teardown, container=self)
for edge in self._graph.out_edges(dependency_id):
self._graph.remove_edge(dependency_id, edge)

graph.populate_graph_for_dependency(self._graph, dependency_id, factory, teardown)

def add_value(
self,
Expand Down Expand Up @@ -215,8 +169,10 @@ def add_value(
self._instances[dependency_id] = value

if dependency_id in self._graph:
self._graph.remove_edges_from(list(self._graph.out_edges(dependency_id)))
self._graph.add_node(dependency_id, factory=lambda: None, teardown=teardown)
for edge in self._graph.out_edges(dependency_id):
self._graph.remove_edge(dependency_id, edge)

self._graph.add_node(dependency_id, DependencyData(lambda: None, {}, teardown))

async def _get(self, dependency_id: str) -> t.Any:
if self._closed:
Expand All @@ -226,7 +182,7 @@ async def _get(self, dependency_id: str) -> t.Any:
if (existing := self._instances.get(dependency_id)) is not None:
return existing

if (data := self._graph.nodes.get(dependency_id)) is None or data.get("factory") is None:
if (data := self._graph.nodes.get(dependency_id)) is None:
if self._parent is None:
raise exceptions.DependencyNotSatisfiableException(
f"cannot create dependency {dependency_id!r} - not provided by this or a parent container"
Expand All @@ -235,53 +191,26 @@ async def _get(self, dependency_id: str) -> t.Any:
LOGGER.debug("dependency %r not provided by this container - checking parent", dependency_id)
return await self._parent._get(dependency_id)

# TODO - look into caching individual dependency creation order globally
# - may speed up using subsequent containers (i.e. for each command)
# - would need to consider how to handle invalidating the cache
subgraph = self._graph.subgraph(nx.descendants(self._graph, dependency_id) | {dependency_id})
assert isinstance(subgraph, nx.DiGraph)

try:
creation_order = list(reversed(list(nx.topological_sort(subgraph))))
except nx.NetworkXUnfeasible:
children = self._graph.children(dependency_id)
if dependency_id in children:
raise exceptions.CircularDependencyException(
f"cannot provide {dependency_id!r} - circular dependency found during creation"
f"cannot provide {dependency_id!r} - circular dependency found"
)

LOGGER.debug("dependency %r depends on %s", dependency_id, creation_order[:-1])
for dep_id in creation_order:
# We already have the dependency we need
if dep_id in self._instances:
continue

node_data = self._graph.nodes[dep_id]
if node_data.get("factory") is None:
if self._parent is None:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - do not know how to instantiate"
)
# Ensure that the dependency is available from the parent container
await self._parent._get(dep_id)
continue

sub_dependencies: dict[str, t.Any] = {}
injectable_params: dict[str, t.Any] = {}
for param_name, expr in data.factory_params.items():
LOGGER.debug("evaluating expression %r for factory parameter %r", expr, param_name)
try:
LOGGER.debug("checking sub-dependencies for %r", dep_id)
for sub_dependency_id, param_name in node_data["factory_params"].items():
sub_dependencies[param_name] = await self._get(sub_dependency_id)
injectable_params[param_name] = await expr.resolve(self)
except exceptions.DependencyNotSatisfiableException as e:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - failed creating sub-dependency"
) from e
raise exceptions.DependencyNotSatisfiableException("failed evaluating sub-dependency expression") from e

# Cache the created dependency in the correct container to ensure the correct lifecycle
try:
self._instances[dep_id] = await utils.maybe_await(node_data["factory"](**sub_dependencies))
except Exception as e:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - factory raised exception"
) from e
LOGGER.debug("instantiated dependency %r", dep_id)
try:
self._instances[dependency_id] = await utils.maybe_await(data.factory_method(**injectable_params))
except Exception as e:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dependency_id!r} - factory raised exception"
) from e

return self._instances[dependency_id]

Expand Down Expand Up @@ -310,5 +239,5 @@ async def get(self, type_: t.Any, /) -> t.Any:
if self._closed:
raise exceptions.ContainerClosedException("the container is closed")

expr = DependencyExpression.create(type_)
expr = conditions.DependencyExpression.create(type_)
return await expr.resolve(self)
Loading

0 comments on commit 531a8f4

Please sign in to comment.