diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index e5d8170..d838cf5 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @RB387 @jerry-git +* @RB387 @jerry-git @erhosen diff --git a/CHANGELOG.md b/CHANGELOG.md index f15c23f..b228a7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- Added `magic_di.healthcheck.DependenciesHealthcheck` class to make health checks of injected dependencies that implement `magic_di.healthcheck.PingableProtocol` interface +### Fixed +- Inject dependencies inside of event loop in `magic_di.utils.inject_and_run` to prevent wrong event loop usage inside of the injected dependencies ### Changed - Cruft update to get changes from the cookiecutter template - Renamed LICENCE -> LICENSE, now it's automatically included in the wheel created by poetry diff --git a/README.md b/README.md index 04d83aa..96ee5c1 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Dependency Injector with minimal boilerplate code, built-in support for FastAPI * [Custom integrations](#custom-integrations) * [Manual injection](#manual-injection) * [Forced injections](#forced-injections) +* [Healthcheck](#healthcheck) * [Testing](#testing) * [Default simple mock](#default-simple-mock) * [Custom mocks](#custom-mocks) @@ -330,6 +331,38 @@ class Service(Connectable): dependency: Annotated[NonConnectableDependency, Injectable] ``` +## Healthchecks +You can implement `Pingable` protocol to define healthchecks for your clients. The `DependenciesHealthcheck` will call the `__ping__` method on all injected clients that implement this protocol. + +```python +from magic_di.healthcheck import DependenciesHealthcheck + + +class Service(Connectable): + def __init__(self, db: Database): + self.db = db + + def is_connected(self): + return self.db.connected + + async def __ping__(self) -> None: + if not self.is_connected(): + raise Exception("Service is not connected") + + +@app.get(path="/hello-world") +def hello_world(service: Provide[Service]) -> dict: + return { + "is_connected": service.is_connected() + } + + +@app.get(path="/healthcheck") +async def healthcheck_handler(healthcheck: Provide[DependenciesHealthcheck]) -> dict: + await healthcheck.ping_dependencies() + return {"alive": True} +``` + ## Testing If you need to mock a dependency in tests, you can easily do so by using the `injector.override` context manager and still use this dependency injector. diff --git a/src/magic_di/_container.py b/src/magic_di/_container.py index db36a96..bc9c60b 100644 --- a/src/magic_di/_container.py +++ b/src/magic_di/_container.py @@ -4,12 +4,11 @@ import inspect from dataclasses import dataclass from threading import Lock -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar if TYPE_CHECKING: from collections.abc import Iterable - from magic_di import ConnectableProtocol T = TypeVar("T") @@ -53,7 +52,7 @@ def iter_instances( self, *, reverse: bool = False, - ) -> Iterable[tuple[type, ConnectableProtocol]]: + ) -> Iterable[tuple[type, object]]: with self._lock: deps_iter: Iterable[Dependency[Any]] = list( reversed(self._deps.values()) if reverse else self._deps.values(), @@ -70,8 +69,8 @@ def _get(self, obj: type[T]) -> type[T] | None: def _wrap(obj: type[T], *args: Any, **kwargs: Any) -> type[T]: if not inspect.isclass(obj): - partial = functools.wraps(obj)(functools.partial(obj, *args, **kwargs)) # type: ignore[var-annotated] - return cast(type[T], partial) + partial: type[T] = functools.wraps(obj)(functools.partial(obj, *args, **kwargs)) # type: ignore[assignment] + return partial _instance: T | None = None diff --git a/src/magic_di/_injector.py b/src/magic_di/_injector.py index 3d8f2c9..61aca0c 100644 --- a/src/magic_di/_injector.py +++ b/src/magic_di/_injector.py @@ -18,7 +18,8 @@ from magic_di._utils import ( get_cls_from_optional, get_type_hints, - is_injectable, + is_connectable, + safe_is_instance, safe_is_subclass, ) from magic_di.exceptions import InjectionError, InspectionError @@ -26,7 +27,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator - from magic_di._connectable import ConnectableProtocol # flag to use in typing.Annotated # to forcefully mark dependency as injectable @@ -111,14 +111,14 @@ def inspect(self, obj: AnyObject) -> Signature[AnyObject]: hints_with_extras = get_type_hints(obj, include_extras=True) if not hints: - return Signature(obj, is_injectable=is_injectable(obj)) + return Signature(obj, is_injectable=bool(is_connectable(obj))) if inspect.ismethod(obj): hints.pop("self", None) hints.pop("return", None) - signature = Signature(obj, is_injectable=is_injectable(obj)) + signature = Signature(obj, is_injectable=bool(is_connectable(obj))) for name, hint_ in hints.items(): hint = self._unwrap_type_hint(hint_) @@ -126,7 +126,7 @@ def inspect(self, obj: AnyObject) -> Signature[AnyObject]: if is_injector(hint): signature.injector_arg = name - elif not is_injectable(hint) and not is_forcefully_marked_as_injectable( + elif not is_connectable(hint) and not is_forcefully_marked_as_injectable( hint_with_extra, ): signature.kwargs[name] = hint @@ -149,21 +149,32 @@ async def connect(self) -> None: self.inject(postponed) for cls, instance in self._deps.iter_instances(): - if is_injectable(instance): + if connectable_instance := is_connectable(instance): self.logger.debug("Connecting %s...", cls.__name__) - await instance.__connect__() + await connectable_instance.__connect__() async def disconnect(self) -> None: """ Disconnect all injected dependencies """ for cls, instance in self._deps.iter_instances(reverse=True): - if is_injectable(instance): + if connectable_instance := is_connectable(instance): try: - await instance.__disconnect__() + await connectable_instance.__disconnect__() except Exception: self.logger.exception("Failed to disconnect %s", cls.__name__) + def get_dependencies_by_interface( + self, + interface: Callable[..., AnyObject], + ) -> Iterable[AnyObject]: + """ + Get all injected dependencies that implement a particular interface. + """ + for _, instance in self._deps.iter_instances(): + if safe_is_instance(instance, interface): # type: ignore[arg-type] + yield instance # type: ignore[misc] + async def __aenter__(self) -> DependencyInjector: # noqa: PYI034 await self.connect() return self @@ -171,8 +182,8 @@ async def __aenter__(self) -> DependencyInjector: # noqa: PYI034 async def __aexit__(self, *args: object, **kwargs: Any) -> None: await self.disconnect() - def iter_deps(self) -> Iterable[ConnectableProtocol]: - instance: ConnectableProtocol + def iter_deps(self) -> Iterable[object]: + instance: object for _, instance in self._deps.iter_instances(): yield instance diff --git a/src/magic_di/_utils.py b/src/magic_di/_utils.py index 63825f5..f3589ce 100644 --- a/src/magic_di/_utils.py +++ b/src/magic_di/_utils.py @@ -27,6 +27,7 @@ def get_cls_from_optional(cls: T) -> T: Extract the actual class from a union that includes None. If it is not a union type hint, it returns the same type hint. Example: + ```python >>> get_cls_from_optional(Union[str, None]) str >>> get_cls_from_optional(str | None) @@ -35,6 +36,7 @@ def get_cls_from_optional(cls: T) -> T: str >>> get_cls_from_optional(int) int + ``` Args: cls (T): Type hint for class Returns: @@ -73,7 +75,7 @@ def safe_is_instance(sub_cls: Any, cls: type) -> bool: return False -def is_injectable(cls: Any) -> bool: +def is_connectable(cls: Any) -> ConnectableProtocol | None: """ Check if a class is a subclass of ConnectableProtocol. @@ -81,9 +83,14 @@ def is_injectable(cls: Any) -> bool: cls (Any): The class to check. Returns: - bool: True if the class is a subclass of ConnectableProtocol, False otherwise. + ConnectableProtocol | None: return instance if the class + is a subclass of ConnectableProtocol, None otherwise. """ - return safe_is_subclass(cls, ConnectableProtocol) or safe_is_instance(cls, ConnectableProtocol) + connectable = safe_is_subclass(cls, ConnectableProtocol) or safe_is_instance( + cls, + ConnectableProtocol, + ) + return cls if connectable else None def get_type_hints(obj: Any, *, include_extras: bool = False) -> dict[str, type]: diff --git a/src/magic_di/healthcheck.py b/src/magic_di/healthcheck.py new file mode 100644 index 0000000..786ace7 --- /dev/null +++ b/src/magic_di/healthcheck.py @@ -0,0 +1,59 @@ +import asyncio +from asyncio import Future +from contextlib import suppress +from dataclasses import dataclass +from typing import Any, Protocol + +from magic_di import Connectable, ConnectableProtocol, DependencyInjector + + +class PingableProtocol(ConnectableProtocol, Protocol): + async def __ping__(self) -> None: ... + + +@dataclass +class DependenciesHealthcheck(Connectable): + """ + Injectable Healthcheck component that pings all injected dependencies + that implement the PingableProtocol + + Example usage: + + ```python + from app.components.services.health import DependenciesHealthcheck + + async def main(redis: Redis, deps_healthcheck: DependenciesHealthcheck) -> None: + await deps_healthcheck.ping_dependencies() # redis will be pinged if it has method __ping__ + + inject_and_run(main) + ``` + """ + + injector: DependencyInjector + + async def ping_dependencies(self, max_concurrency: int = 1) -> None: + """ + Ping all dependencies that implement the PingableProtocol + + :param max_concurrency: Maximum number of concurrent pings + """ + tasks: set[Future[Any]] = set() + + try: + for dependency in self.injector.get_dependencies_by_interface(PingableProtocol): + tasks.add(asyncio.ensure_future(dependency.__ping__())) + + if len(tasks) >= max_concurrency: + tasks, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if tasks: + await asyncio.gather(*tasks) + tasks = set() + + finally: + for task in tasks: + task.cancel() + + if tasks: + with suppress(asyncio.CancelledError): + await asyncio.gather(*tasks) diff --git a/src/magic_di/testing.py b/src/magic_di/testing.py index 180f0b5..a22135d 100644 --- a/src/magic_di/testing.py +++ b/src/magic_di/testing.py @@ -22,18 +22,20 @@ class InjectableMock(AsyncMock): and use AsyncMock instead of a real class instance Example: - @pytest.fixture() - def client(): - injector = DependencyInjector() + ```python + @pytest.fixture() + def client(): + injector = DependencyInjector() - with injector.override({Service: InjectableMock().mock_cls}): - with TestClient(app) as client: - yield client + with injector.override({Service: InjectableMock().mock_cls}): + with TestClient(app) as client: + yield client - def test_http_handler(client): - resp = client.post('/hello-world') + def test_http_handler(client): + resp = client.post('/hello-world') - assert resp.status_code == 200 + assert resp.status_code == 200 + ``` """ @property diff --git a/src/magic_di/utils.py b/src/magic_di/utils.py index 72c481f..bf33bc4 100644 --- a/src/magic_di/utils.py +++ b/src/magic_di/utils.py @@ -53,9 +53,9 @@ def inject_and_run( """ injector = injector or DependencyInjector() - injected = injector.inject(fn) - async def run() -> T: + injected = injector.inject(fn) + async with injector: if inspect.iscoroutinefunction(fn): return await injected() # type: ignore[misc,no-any-return] diff --git a/tests/test_healthcheck.py b/tests/test_healthcheck.py new file mode 100644 index 0000000..a50d2cd --- /dev/null +++ b/tests/test_healthcheck.py @@ -0,0 +1,56 @@ +from dataclasses import dataclass + +import pytest +from magic_di import Connectable, DependencyInjector +from magic_di.healthcheck import DependenciesHealthcheck + + +@dataclass +class PingableDatabase(Connectable): + ping_count: int = 0 + + async def __ping__(self) -> None: + self.ping_count += 1 + + +@dataclass +class Service(Connectable): + ping_count: int = 0 + + +@dataclass +class PingableService(Connectable): + db: PingableDatabase + ping_count: int = 0 + + async def __ping__(self) -> None: + self.ping_count += 1 + + +@pytest.mark.asyncio() +async def test_healthcheck(injector: DependencyInjector) -> None: + async def main(_: PingableService) -> None: ... + + await injector.inject(main)() + + injected_db = injector.inject(PingableDatabase)() + injected_srv = injector.inject(PingableService)() + injected_srv_not_pingable = injector.inject(Service)() + + assert injected_db.ping_count == 0 + assert injected_srv.ping_count == 0 + assert injected_srv_not_pingable.ping_count == 0 + + healthcheck = injector.inject(DependenciesHealthcheck)() + + await healthcheck.ping_dependencies(max_concurrency=1) + + assert injected_db.ping_count == 1 + assert injected_srv.ping_count == 1 + assert injected_srv_not_pingable.ping_count == 0 + + await healthcheck.ping_dependencies(max_concurrency=3) + + assert injected_db.ping_count == 2 # noqa: PLR2004 + assert injected_srv.ping_count == 2.0 # noqa: PLR2004 + assert injected_srv_not_pingable.ping_count == 0 diff --git a/tests/test_utils.py b/tests/test_utils.py index b445d46..070d5fe 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,6 @@ -from magic_di import DependencyInjector +import asyncio + +from magic_di import ConnectableProtocol, DependencyInjector from magic_di.utils import inject_and_run from tests.conftest import Repository @@ -25,3 +27,15 @@ async def main(repo: Repository) -> Repository: repo = inject_and_run(main, injector=injector) assert not repo.connected assert not repo.db.connected + + +def test_inject_and_run_async_proper_event_loop(injector: DependencyInjector) -> None: + class DependencyWithEventLoop(ConnectableProtocol): + def __init__(self) -> None: + self.event_loop = asyncio.get_event_loop() + + async def main(dependency: DependencyWithEventLoop) -> None: + event_loop = asyncio.get_running_loop() + assert event_loop is dependency.event_loop + + inject_and_run(main, injector=injector)