Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement optional and fallback dependency injection functionality #443

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/source/by-examples/050_dependencies.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,8 @@ Lightbulb will enable dependency injection on a specific subset of your methods

These are listed below:
- {meth}`@lightbulb.invoke <lightbulb.commands.execution.invoke>`
- {meth}`@Client.register <lightbulb.client.Client.register>`
- {meth}`@Client.error_handler <lightbulb.client.Client.error_handler>`
- {meth}`@Client.task <lightbulb.client.Client.task>`
- {meth}`@Loader.command <lightbulb.loaders.Loader.command>` (due to it calling `Client.register` internally)
- {meth}`@Loader.listener <lightbulb.loaders.Loader.listener>`
- {meth}`@Loader.task <lightbulb.loaders.Loader.task>`

Expand Down
3 changes: 3 additions & 0 deletions fragments/443.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Add `__contains__` method to `di.Container` to allow checking if a dependency is registered.

- Allow parameter-injected dependencies to be optional, and have fallbacks if one is not available.
58 changes: 53 additions & 5 deletions lightbulb/di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

__all__ = ["Container"]

import types
import typing as t

import networkx as nx
Expand All @@ -30,14 +31,17 @@
from lightbulb.di import exceptions
from lightbulb.di import registry as registry_
from lightbulb.di import utils as di_utils
from lightbulb.internal import marker

if t.TYPE_CHECKING:
import types
from collections.abc import Callable

from lightbulb.internal import types as lb_types

T = t.TypeVar("T")
D = t.TypeVar("D")

_MISSING = marker.Marker("MISSING")


class Container:
Expand Down Expand Up @@ -78,6 +82,17 @@ def __init__(self, registry: registry_.Registry, *, parent: Container | None = N

self.add_value(Container, self)

def __contains__(self, item: type[t.Any]) -> bool:
dep_id = di_utils.get_dependency_id(item)
if dep_id not in self._graph:
return False

container = self._graph.nodes[dep_id]["container"]
if dep_id in container._instances:
return True

return container._graph.nodes[dep_id].get("factory") is not None

async def __aenter__(self) -> Container:
return self

Expand Down Expand Up @@ -116,9 +131,15 @@ def add_factory(
Returns:
:obj:`None`

Raises:
:obj:`ValueError`: When attempting to add a dependency for ``NoneType``.

See Also:
:meth:`lightbulb.di.registry.Registry.add_factory` for factory and teardown function spec.
"""
if typ is types.NoneType:
raise ValueError("cannot register type 'NoneType' - 'None' is used for optional dependencies")

dependency_id = di_utils.get_dependency_id(typ)

if dependency_id in self._graph:
Expand All @@ -144,24 +165,32 @@ def add_value(
Returns:
:obj:`None`

Raises:
:obj:`ValueError`: When attempting to add a dependency for ``NoneType``.

See Also:
:meth:`lightbulb.di.registry.Registry.add_value` for teardown function spec.
"""
if typ is types.NoneType:
raise ValueError("cannot register type 'NoneType' - 'None' is used for optional dependencies")

dependency_id = di_utils.get_dependency_id(typ)
self._instances[dependency_id] = value

if dependency_id in self._graph:
self._graph.remove_edges_from(list(self._graph.out_edges(dependency_id)))
self._graph.add_node(dependency_id, container=self, teardown=teardown)

async def _get(self, dependency_id: str) -> t.Any:
async def _get(self, dependency_id: str, *, allow_missing: bool = False) -> t.Any:
if self._closed:
raise exceptions.ContainerClosedException

# TODO - look into whether locking is necessary - how likely are we to have race conditions

data = self._graph.nodes.get(dependency_id)
if data is None or data.get("container") is None:
if allow_missing:
return _MISSING
raise exceptions.DependencyNotSatisfiableException

existing_dependency = data["container"]._instances.get(dependency_id)
Expand All @@ -183,6 +212,9 @@ async def _get(self, dependency_id: str) -> t.Any:

for dep_id in creation_order:
if (container := self._graph.nodes[dep_id].get("container")) is None:
if allow_missing:
return _MISSING

raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - not provided by this or a parent container"
)
Expand All @@ -206,7 +238,11 @@ async def _get(self, dependency_id: str) -> t.Any:
sub_dependencies: dict[str, t.Any] = {}
try:
for sub_dependency_id, param_name in node_data["factory_params"].items():
sub_dependencies[param_name] = await node_data["container"]._get(sub_dependency_id)
sub_dependency = await node_data["container"]._get(sub_dependency_id, allow_missing=allow_missing)
if sub_dependency is _MISSING:
return _MISSING

sub_dependencies[param_name] = sub_dependency
except exceptions.DependencyNotSatisfiableException as e:
raise exceptions.DependencyNotSatisfiableException(
f"could not create dependency {dep_id!r} - failed creating sub-dependency"
Expand All @@ -217,12 +253,19 @@ async def _get(self, dependency_id: str) -> t.Any:

return self._graph.nodes[dependency_id]["container"]._instances[dependency_id]

async def get(self, typ: type[T]) -> T:
@t.overload
async def get(self, typ: type[T], /) -> T: ...
@t.overload
async def get(self, typ: type[T], /, *, default: D) -> T | D: ...

async def get(self, typ: type[T], /, *, default: D = _MISSING) -> T | D:
"""
Get a dependency from this container, instantiating it and sub-dependencies if necessary.

Args:
typ: The type used when registering the dependency.
default: The default value to return if the dependency is not satisfiable. If not provided, this will
raise a :obj:`~lightbulb.di.exceptions.DependencyNotSatisfiableException`.

Returns:
The dependency for the given type.
Expand All @@ -235,4 +278,9 @@ async def get(self, typ: type[T]) -> T:
for any other reason.
"""
dependency_id = di_utils.get_dependency_id(typ)
return t.cast(T, await self._get(dependency_id))

dependency = await self._get(dependency_id, allow_missing=default is not _MISSING)
if dependency is _MISSING:
return default

return t.cast(T, dependency)
14 changes: 10 additions & 4 deletions lightbulb/di/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

__all__ = ["Registry"]

import types
import typing as t

import networkx as nx
Expand All @@ -33,7 +34,7 @@
from collections.abc import Callable

from lightbulb.di import container
from lightbulb.internal import types
from lightbulb.internal import types as lb_types

T = t.TypeVar("T")

Expand Down Expand Up @@ -84,7 +85,7 @@ def register_value(
typ: type[T],
value: T,
*,
teardown: Callable[[T], types.MaybeAwaitable[None]] | None = None,
teardown: Callable[[T], lb_types.MaybeAwaitable[None]] | None = None,
) -> None:
"""
Registers a pre-existing value as a dependency.
Expand All @@ -100,15 +101,16 @@ def register_value(

Raises:
:obj:`lightbulb.di.exceptions.RegistryFrozenException`: If the registry is frozen.
:obj:`ValueError`: When attempting to register a dependency for ``NoneType``.
"""
self.register_factory(typ, lambda: value, teardown=teardown)

def register_factory(
self,
typ: type[T],
factory: Callable[..., types.MaybeAwaitable[T]],
factory: Callable[..., lb_types.MaybeAwaitable[T]],
*,
teardown: Callable[[T], types.MaybeAwaitable[None]] | None = None,
teardown: Callable[[T], lb_types.MaybeAwaitable[None]] | None = None,
) -> None:
"""
Registers a factory for creating a dependency.
Expand All @@ -127,10 +129,14 @@ def register_factory(
Raises:
:obj:`lightbulb.di.exceptions.RegistryFrozenException`: If the registry is frozen.
:obj:`lightbulb.di.exceptions.CircularDependencyException`: If the factory requires itself as a dependency.
:obj:`ValueError`: When attempting to register a dependency for ``NoneType``.
"""
if self._active_containers:
raise exceptions.RegistryFrozenException

if typ is types.NoneType:
raise ValueError("cannot register type 'NoneType' - 'None' is used for optional dependencies")

dependency_id = di_utils.get_dependency_id(typ)

# We are overriding a previously defined dependency and want to strip the edges, so we don't have
Expand Down
83 changes: 60 additions & 23 deletions lightbulb/di/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@
import logging
import os
import sys
import types
import typing as t
from collections.abc import AsyncIterator
from collections.abc import Awaitable
from collections.abc import Callable
from collections.abc import Sequence

from lightbulb import utils
from lightbulb.di import container
Expand All @@ -54,7 +56,7 @@
from lightbulb.internal import marker

if t.TYPE_CHECKING:
from lightbulb.internal import types
from lightbulb.internal import types as lb_types

P = t.ParamSpec("P")
R = t.TypeVar("R")
Expand Down Expand Up @@ -236,12 +238,24 @@ async def close(self) -> None:
self._default_container = None


CANNOT_INJECT = object()
class ParamInfo(t.NamedTuple):
name: str
types: Sequence[t.Any]
optional: bool
injectable: bool


def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[tuple[str, t.Any]], dict[str, t.Any]]:
positional_or_keyword_params: list[tuple[str, t.Any]] = []
keyword_only_params: dict[str, t.Any] = {}
def _get_requested_types(annotation: t.Any) -> tuple[Sequence[t.Any], bool]:
if t.get_origin(annotation) in (t.Union, types.UnionType):
args = t.get_args(annotation)

return tuple(a for a in args if a is not types.NoneType), types.NoneType in args
return (annotation,), False


def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[ParamInfo], list[ParamInfo]]:
positional_or_keyword_params: list[ParamInfo] = []
keyword_only_params: list[ParamInfo] = []

parameters = inspect.signature(func, locals={"lightbulb": sys.modules["lightbulb"]}, eval_str=True).parameters
for parameter in parameters.values():
Expand All @@ -254,15 +268,19 @@ def _parse_injectable_params(func: Callable[..., t.Any]) -> tuple[list[tuple[str
# If it has a default that isn't INJECTED
or ((default := parameter.default) is not inspect.Parameter.empty and default is not INJECTED)
):
# We need to know about ALL pos-or-kw arguments so that we can exclude ones passed in
# when the injection-enabled function is called - this isn't the same for kw-only args
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
positional_or_keyword_params.append((parameter.name, CANNOT_INJECT))
positional_or_keyword_params.append(ParamInfo(parameter.name, (), False, False))
continue

requested_types, optional = _get_requested_types(parameter.annotation)
param_info = ParamInfo(parameter.name, requested_types, optional, True)
if parameter.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
positional_or_keyword_params.append((parameter.name, parameter.annotation))
positional_or_keyword_params.append(param_info)
else:
# It has to be a keyword-only parameter
keyword_only_params[parameter.name] = parameter.annotation
keyword_only_params.append(param_info)

return positional_or_keyword_params, keyword_only_params

Expand All @@ -288,8 +306,8 @@ def __init__(
self,
func: Callable[..., Awaitable[t.Any]],
self_: t.Any = None,
_cached_pos_or_kw_params: list[tuple[str, t.Any]] | None = None,
_cached_kw_only_params: dict[str, t.Any] | None = None,
_cached_pos_or_kw_params: list[ParamInfo] | None = None,
_cached_kw_only_params: list[ParamInfo] | None = None,
) -> None:
self._func = func
self._self: t.Any = self_
Expand Down Expand Up @@ -320,22 +338,41 @@ async def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:

di_container: container.Container | None = DI_CONTAINER.get(None)

injectables = {
name: type
for name, type in self._pos_or_kw_params[len(args) + (self._self is not None) :]
if name not in new_kwargs
}
injectables.update({name: type for name, type in self._kw_only_params.items() if name not in new_kwargs})

for name, type in injectables.items():
# Skip any arguments that we can't inject
if type is CANNOT_INJECT:
maybe_injectables = [*self._pos_or_kw_params[len(args) + (self._self is not None) :], *self._kw_only_params]
for param in maybe_injectables:
# Skip any parameters we already have a value for, or is not valid to be injected
if param.name in new_kwargs or not param.injectable:
continue

if di_container is None:
raise exceptions.DependencyNotSatisfiableException("no DI context is available")

new_kwargs[name] = await di_container.get(type)
# Resolve the dependency, or None if the dependency is unsatisfied and is optional
if len(param.types) == 1:
default_kwarg = {"default": None} if param.optional else {}
new_kwargs[param.name] = await di_container.get(param.types[0], **default_kwarg)
continue

for i, type in enumerate(param.types):
resolved = await di_container.get(type, default=None)

# Check if this is the last type to check, and we couldn't resolve a dependency for it
if resolved is None and i == (len(param.types) - 1):
# If this dependency is optional then set value to 'None'
if param.optional:
new_kwargs[param.name] = None
break
# We can't supply this dependency, so raise an exception
raise exceptions.DependencyNotSatisfiableException(
f"could not satisfy any dependencies for types {param.types}"
)

# We couldn't supply this type, so continue and check the next one
if resolved is None:
continue
# We could supply this type, set the parameter to the dependency and skip to the next parameter
new_kwargs[param.name] = resolved
break

if self._self is not None:
return await utils.maybe_await(self._func(self._self, *args, **new_kwargs))
Expand All @@ -347,10 +384,10 @@ def with_di(func: AsyncFnT) -> AsyncFnT: ...


@t.overload
def with_di(func: Callable[P, types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]: ...
def with_di(func: Callable[P, lb_types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]: ...


def with_di(func: Callable[P, types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]:
def with_di(func: Callable[P, lb_types.MaybeAwaitable[R]]) -> Callable[P, Awaitable[R]]:
"""
Decorator that enables dependency injection on the decorated function. If dependency injection
has been disabled globally then this function does nothing and simply returns the object that was passed in.
Expand Down
Loading