Skip to content

Commit

Permalink
feat: implement optional and fallback dependency injection functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
tandemdude committed Aug 22, 2024
1 parent 7a7a7ee commit 5e7ac7f
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 34 deletions.
35 changes: 31 additions & 4 deletions lightbulb/di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from lightbulb.di import exceptions
from lightbulb.di import registry as registry_
from lightbulb.di import utils as di_utils
from lightbulb.internal import marker

if t.TYPE_CHECKING:
import types
Expand All @@ -38,6 +39,9 @@
from lightbulb.internal import types as lb_types

T = t.TypeVar("T")
D = t.TypeVar("D")

_MISSING = marker.Marker("MISSING")


class Container:
Expand Down Expand Up @@ -154,14 +158,16 @@ def add_value(
self._graph.remove_edges_from(list(self._graph.out_edges(dependency_id)))
self._graph.add_node(dependency_id, container=self, teardown=teardown)

async def _get(self, dependency_id: str) -> t.Any:
async def _get(self, dependency_id: str, *, allow_missing: bool = False) -> t.Any:
if self._closed:
raise exceptions.ContainerClosedException

# TODO - look into whether locking is necessary - how likely are we to have race conditions

data = self._graph.nodes.get(dependency_id)
if data is None or data.get("container") is None:
if allow_missing:
return _MISSING
raise exceptions.DependencyNotSatisfiableException

existing_dependency = data["container"]._instances.get(dependency_id)
Expand All @@ -183,6 +189,9 @@ async def _get(self, dependency_id: str) -> t.Any:

for dep_id in creation_order:
if (container := self._graph.nodes[dep_id].get("container")) is None:
if allow_missing:
return _MISSING

raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - not provided by this or a parent container"
)
Expand All @@ -206,7 +215,13 @@ async def _get(self, dependency_id: str) -> t.Any:
sub_dependencies: dict[str, t.Any] = {}
try:
for sub_dependency_id, param_name in node_data["factory_params"].items():
sub_dependencies[param_name] = await node_data["container"]._get(sub_dependency_id)
sub_dependency = await node_data["container"]._get(sub_dependency_id, allow_missing=allow_missing)
if sub_dependency is _MISSING:
# I'm not sure that this branch can actually be hit, but I'm going to leave it here
# just in case...
return _MISSING

sub_dependencies[param_name] = sub_dependency
except exceptions.DependencyNotSatisfiableException as e:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - failed creating sub-dependency"
Expand All @@ -217,12 +232,19 @@ async def _get(self, dependency_id: str) -> t.Any:

return self._graph.nodes[dependency_id]["container"]._instances[dependency_id]

async def get(self, typ: type[T]) -> T:
@t.overload
async def get(self, typ: type[T], /) -> T: ...
@t.overload
async def get(self, typ: type[T], /, *, default: D) -> T | D: ...

async def get(self, typ: type[T], /, *, default: D = _MISSING) -> T | D:
"""
Get a dependency from this container, instantiating it and sub-dependencies if necessary.
Args:
typ: The type used when registering the dependency.
default: The default value to return if the dependency is not satisfiable. If not provided, this will
raise a :obj:`~lightbulb.di.exceptions.DependencyNotSatisfiableException`.
Returns:
The dependency for the given type.
Expand All @@ -235,4 +257,9 @@ async def get(self, typ: type[T]) -> T:
for any other reason.
"""
dependency_id = di_utils.get_dependency_id(typ)
return t.cast(T, await self._get(dependency_id))

dependency = await self._get(dependency_id, allow_missing=default is not _MISSING)
if dependency is _MISSING:
return default

return t.cast(T, dependency)
83 changes: 60 additions & 23 deletions lightbulb/di/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@
import logging
import os
import sys
import types
import typing as t
from collections.abc import AsyncIterator
from collections.abc import Awaitable
from collections.abc import Callable
from collections.abc import Sequence

from lightbulb import utils
from lightbulb.di import container
Expand All @@ -54,7 +56,7 @@
from lightbulb.internal import marker

if t.TYPE_CHECKING:
from lightbulb.internal import types
from lightbulb.internal import types as lb_types

P = t.ParamSpec("P")
R = t.TypeVar("R")
Expand Down Expand Up @@ -236,12 +238,24 @@ async def close(self) -> None:
self._default_container = None


CANNOT_INJECT = object()
class ParamInfo(t.NamedTuple):
name: str
types: Sequence[t.Any]
optional: bool
injectable: bool


def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[tuple[str, t.Any]], dict[str, t.Any]]:
positional_or_keyword_params: list[tuple[str, t.Any]] = []
keyword_only_params: dict[str, t.Any] = {}
def _get_requested_types(annotation: t.Any) -> tuple[Sequence[t.Any], bool]:
if t.get_origin(annotation) in (t.Union, types.UnionType):
args = t.get_args(annotation)

return tuple(a for a in args if a is not types.NoneType), types.NoneType in args
return (annotation,), False


def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[ParamInfo], list[ParamInfo]]:
positional_or_keyword_params: list[ParamInfo] = []
keyword_only_params: list[ParamInfo] = []

parameters = inspect.signature(func, locals={"lightbulb": sys.modules["lightbulb"]}, eval_str=True).parameters
for parameter in parameters.values():
Expand All @@ -254,15 +268,19 @@ def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[tuple[str
# If it has a default that isn't INJECTED
or ((default := parameter.default) is not inspect.Parameter.empty and default is not INJECTED)
):
# We need to know about ALL pos-or-kw arguments so that we can exclude ones passed in
# when the injection-enabled function is called - this isn't the same for kw-only args
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
positional_or_keyword_params.append((parameter.name, CANNOT_INJECT))
positional_or_keyword_params.append(ParamInfo(parameter.name, (), False, False))
continue

requested_types, optional = _get_requested_types(parameter.annotation)
param_info = ParamInfo(parameter.name, requested_types, optional, True)
if parameter.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
positional_or_keyword_params.append((parameter.name, parameter.annotation))
positional_or_keyword_params.append(param_info)
else:
# It has to be a keyword-only parameter
keyword_only_params[parameter.name] = parameter.annotation
keyword_only_params.append(param_info)

return positional_or_keyword_params, keyword_only_params

Expand All @@ -288,8 +306,8 @@ def __init__(
self,
func: Callable[..., Awaitable[t.Any]],
self_: t.Any = None,
_cached_pos_or_kw_params: list[tuple[str, t.Any]] | None = None,
_cached_kw_only_params: dict[str, t.Any] | None = None,
_cached_pos_or_kw_params: list[ParamInfo] | None = None,
_cached_kw_only_params: list[ParamInfo] | None = None,
) -> None:
self._func = func
self._self: t.Any = self_
Expand Down Expand Up @@ -320,22 +338,41 @@ async def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:

di_container: container.Container | None = DI_CONTAINER.get(None)

injectables = {
name: type
for name, type in self._pos_or_kw_params[len(args) + (self._self is not None) :]
if name not in new_kwargs
}
injectables.update({name: type for name, type in self._kw_only_params.items() if name not in new_kwargs})

for name, type in injectables.items():
# Skip any arguments that we can't inject
if type is CANNOT_INJECT:
maybe_injectables = [*self._pos_or_kw_params[len(args) + (self._self is not None) :], *self._kw_only_params]
for param in maybe_injectables:
# Skip any parameters we already have a value for, or is not valid to be injected
if param.name in new_kwargs or not param.injectable:
continue

if di_container is None:
raise exceptions.DependencyNotSatisfiableException("no DI context is available")

new_kwargs[name] = await di_container.get(type)
# Resolve the dependency, or None if the dependency is unsatisfied and is optional
if len(param.types) == 1:
default_kwarg = {"default": None} if param.optional else {}
new_kwargs[param.name] = await di_container.get(param.types[0], **default_kwarg)
continue

for i, type in enumerate(param.types):
resolved = await di_container.get(type, default=None)

# Check if this is the last type to check, and we couldn't resolve a dependency for it
if resolved is None and i == (len(param.types) - 1):
# If this dependency is optional then set value to 'None'
if param.optional:
new_kwargs[param.name] = None
break
# We can't supply this dependency, so raise an exception
raise exceptions.DependencyNotSatisfiableException(
f"could not satisfy any dependencies for types {param.types}"
)

# We couldn't supply this type, so continue and check the next one
if resolved is None:
continue
# We could supply this type, set the parameter to the dependency and skip to the next parameter
new_kwargs[param.name] = resolved
break

if self._self is not None:
return await utils.maybe_await(self._func(self._self, *args, **new_kwargs))
Expand All @@ -347,10 +384,10 @@ def with_di(func: AsyncFnT) -> AsyncFnT: ...


@t.overload
def with_di(func: Callable[P, types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]: ...
def with_di(func: Callable[P, lb_types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]: ...


def with_di(func: Callable[P, types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]:
def with_di(func: Callable[P, lb_types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]:
"""
Decorator that enables dependency injection on the decorated function. If dependency injection
has been disabled globally then this function does nothing and simply returns the object that was passed in.
Expand Down
19 changes: 19 additions & 0 deletions tests/di/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,22 @@ async def test_get_from_closed_container_raises_exception(self) -> None:
async with di.Container(registry) as c:
pass
await c.get(object)

@pytest.mark.asyncio
async def test_get_with_default_when_dependency_not_available_returns_default(self) -> None:
registry = di.Registry()

async with di.Container(registry) as c:
assert await c.get(object, default=None) is None

@pytest.mark.asyncio
async def test_get_with_default_when_sub_dependency_not_available_returns_default(self) -> None:
registry = di.Registry()

def f1(_: str) -> object:
return object()

registry.register_factory(object, f1)

async with di.Container(registry) as c:
assert await c.get(object, default=None) is None
Loading

0 comments on commit 5e7ac7f

Please sign in to comment.