Skip to content

Commit

Permalink
Add DependenciesHealthcheck dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
RB387 committed Sep 23, 2024
1 parent 532f86d commit 399516e
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1 +1 @@
* @RB387 @jerry-git
* @RB387 @jerry-git @erhosen
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- Added `magic_di.healthcheck.DependenciesHealthcheck` class to make health checks of injected dependencies that implement `magic_di.healthcheck.PingableProtocol` interface
### Fixed
- Inject dependencies inside of event loop in `magic_di.utils.inject_and_run` to prevent wrong event loop usage inside of the injected dependencies
### Changed
- Cruft update to get changes from the cookiecutter template
- Renamed LICENCE -> LICENSE, now it's automatically included in the wheel created by poetry
Expand Down
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Dependency Injector with minimal boilerplate code, built-in support for FastAPI
* [Custom integrations](#custom-integrations)
* [Manual injection](#manual-injection)
* [Forced injections](#forced-injections)
* [Healthcheck](#healthcheck)
* [Testing](#testing)
* [Default simple mock](#default-simple-mock)
* [Custom mocks](#custom-mocks)
Expand Down Expand Up @@ -330,6 +331,38 @@ class Service(Connectable):
dependency: Annotated[NonConnectableDependency, Injectable]
```

## Healthchecks
You can implement `Pingable` protocol to define healthchecks for your clients. The `DependenciesHealthcheck` will call the `__ping__` method on all injected clients that implement this protocol.

```python
from magic_di.healthcheck import DependenciesHealthcheck


class Service(Connectable):
def __init__(self, db: Database):
self.db = db

def is_connected(self):
return self.db.connected

async def __ping__(self) -> None:
if not self.is_connected():
raise Exception("Service is not connected")


@app.get(path="/hello-world")
def hello_world(service: Provide[Service]) -> dict:
return {
"is_connected": service.is_connected()
}


@app.get(path="/healthcheck")
async def healthcheck_handler(healthcheck: Provide[DependenciesHealthcheck]) -> dict:
await healthcheck.ping_dependencies()
return {"alive": True}
```

## Testing
If you need to mock a dependency in tests, you can easily do so by using the `injector.override` context manager and still use this dependency injector.

Expand Down
9 changes: 4 additions & 5 deletions src/magic_di/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import inspect
from dataclasses import dataclass
from threading import Lock
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, TypeVar

if TYPE_CHECKING:
from collections.abc import Iterable

from magic_di import ConnectableProtocol

T = TypeVar("T")

Expand Down Expand Up @@ -53,7 +52,7 @@ def iter_instances(
self,
*,
reverse: bool = False,
) -> Iterable[tuple[type, ConnectableProtocol]]:
) -> Iterable[tuple[type, object]]:
with self._lock:
deps_iter: Iterable[Dependency[Any]] = list(
reversed(self._deps.values()) if reverse else self._deps.values(),
Expand All @@ -70,8 +69,8 @@ def _get(self, obj: type[T]) -> type[T] | None:

def _wrap(obj: type[T], *args: Any, **kwargs: Any) -> type[T]:
if not inspect.isclass(obj):
partial = functools.wraps(obj)(functools.partial(obj, *args, **kwargs)) # type: ignore[var-annotated]
return cast(type[T], partial)
partial: type[T] = functools.wraps(obj)(functools.partial(obj, *args, **kwargs)) # type: ignore[assignment]
return partial

_instance: T | None = None

Expand Down
33 changes: 22 additions & 11 deletions src/magic_di/_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
from magic_di._utils import (
get_cls_from_optional,
get_type_hints,
is_injectable,
is_connectable,
safe_is_instance,
safe_is_subclass,
)
from magic_di.exceptions import InjectionError, InspectionError

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator

from magic_di._connectable import ConnectableProtocol

# flag to use in typing.Annotated
# to forcefully mark dependency as injectable
Expand Down Expand Up @@ -111,22 +111,22 @@ def inspect(self, obj: AnyObject) -> Signature[AnyObject]:
hints_with_extras = get_type_hints(obj, include_extras=True)

if not hints:
return Signature(obj, is_injectable=is_injectable(obj))
return Signature(obj, is_injectable=bool(is_connectable(obj)))

if inspect.ismethod(obj):
hints.pop("self", None)

hints.pop("return", None)

signature = Signature(obj, is_injectable=is_injectable(obj))
signature = Signature(obj, is_injectable=bool(is_connectable(obj)))

for name, hint_ in hints.items():
hint = self._unwrap_type_hint(hint_)
hint_with_extra = hints_with_extras[name]

if is_injector(hint):
signature.injector_arg = name
elif not is_injectable(hint) and not is_forcefully_marked_as_injectable(
elif not is_connectable(hint) and not is_forcefully_marked_as_injectable(
hint_with_extra,
):
signature.kwargs[name] = hint
Expand All @@ -149,30 +149,41 @@ async def connect(self) -> None:
self.inject(postponed)

for cls, instance in self._deps.iter_instances():
if is_injectable(instance):
if connectable_instance := is_connectable(instance):
self.logger.debug("Connecting %s...", cls.__name__)
await instance.__connect__()
await connectable_instance.__connect__()

async def disconnect(self) -> None:
"""
Disconnect all injected dependencies
"""
for cls, instance in self._deps.iter_instances(reverse=True):
if is_injectable(instance):
if connectable_instance := is_connectable(instance):
try:
await instance.__disconnect__()
await connectable_instance.__disconnect__()
except Exception:
self.logger.exception("Failed to disconnect %s", cls.__name__)

def get_dependencies_by_interface(
self,
interface: Callable[..., AnyObject],
) -> Iterable[AnyObject]:
"""
Get all injected dependencies that implement a particular interface.
"""
for _, instance in self._deps.iter_instances():
if safe_is_instance(instance, interface): # type: ignore[arg-type]
yield instance # type: ignore[misc]

async def __aenter__(self) -> DependencyInjector: # noqa: PYI034
await self.connect()
return self

async def __aexit__(self, *args: object, **kwargs: Any) -> None:
await self.disconnect()

def iter_deps(self) -> Iterable[ConnectableProtocol]:
instance: ConnectableProtocol
def iter_deps(self) -> Iterable[object]:
instance: object

for _, instance in self._deps.iter_instances():
yield instance
Expand Down
13 changes: 10 additions & 3 deletions src/magic_di/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_cls_from_optional(cls: T) -> T:
Extract the actual class from a union that includes None.
If it is not a union type hint, it returns the same type hint.
Example:
```python
>>> get_cls_from_optional(Union[str, None])
str
>>> get_cls_from_optional(str | None)
Expand All @@ -35,6 +36,7 @@ def get_cls_from_optional(cls: T) -> T:
str
>>> get_cls_from_optional(int)
int
```
Args:
cls (T): Type hint for class
Returns:
Expand Down Expand Up @@ -73,17 +75,22 @@ def safe_is_instance(sub_cls: Any, cls: type) -> bool:
return False


def is_injectable(cls: Any) -> bool:
def is_connectable(cls: Any) -> ConnectableProtocol | None:
"""
Check if a class is a subclass of ConnectableProtocol.
Args:
cls (Any): The class to check.
Returns:
bool: True if the class is a subclass of ConnectableProtocol, False otherwise.
ConnectableProtocol | None: return instance if the class
is a subclass of ConnectableProtocol, None otherwise.
"""
return safe_is_subclass(cls, ConnectableProtocol) or safe_is_instance(cls, ConnectableProtocol)
connectable = safe_is_subclass(cls, ConnectableProtocol) or safe_is_instance(
cls,
ConnectableProtocol,
)
return cls if connectable else None


def get_type_hints(obj: Any, *, include_extras: bool = False) -> dict[str, type]:
Expand Down
59 changes: 59 additions & 0 deletions src/magic_di/healthcheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import asyncio
from asyncio import Future
from contextlib import suppress
from dataclasses import dataclass
from typing import Any, Protocol

from magic_di import Connectable, ConnectableProtocol, DependencyInjector


class PingableProtocol(ConnectableProtocol, Protocol):
async def __ping__(self) -> None: ...


@dataclass
class DependenciesHealthcheck(Connectable):
"""
Injectable Healthcheck component that pings all injected dependencies
that implement the PingableProtocol
Example usage:
```python
from app.components.services.health import DependenciesHealthcheck
async def main(redis: Redis, deps_healthcheck: DependenciesHealthcheck) -> None:
await deps_healthcheck.ping_dependencies() # redis will be pinged if it has method __ping__
inject_and_run(main)
```
"""

injector: DependencyInjector

async def ping_dependencies(self, max_concurrency: int = 1) -> None:
"""
Ping all dependencies that implement the PingableProtocol
:param max_concurrency: Maximum number of concurrent pings
"""
tasks: set[Future[Any]] = set()

try:
for dependency in self.injector.get_dependencies_by_interface(PingableProtocol):
tasks.add(asyncio.ensure_future(dependency.__ping__()))

if len(tasks) >= max_concurrency:
tasks, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

if tasks:
await asyncio.gather(*tasks)
tasks = set()

finally:
for task in tasks:
task.cancel()

if tasks:
with suppress(asyncio.CancelledError):
await asyncio.gather(*tasks)
20 changes: 11 additions & 9 deletions src/magic_di/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,20 @@ class InjectableMock(AsyncMock):
and use AsyncMock instead of a real class instance
Example:
@pytest.fixture()
def client():
injector = DependencyInjector()
```python
@pytest.fixture()
def client():
injector = DependencyInjector()
with injector.override({Service: InjectableMock().mock_cls}):
with TestClient(app) as client:
yield client
with injector.override({Service: InjectableMock().mock_cls}):
with TestClient(app) as client:
yield client
def test_http_handler(client):
resp = client.post('/hello-world')
def test_http_handler(client):
resp = client.post('/hello-world')
assert resp.status_code == 200
assert resp.status_code == 200
```
"""

@property
Expand Down
4 changes: 2 additions & 2 deletions src/magic_di/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def inject_and_run(
"""
injector = injector or DependencyInjector()

injected = injector.inject(fn)

async def run() -> T:
injected = injector.inject(fn)

async with injector:
if inspect.iscoroutinefunction(fn):
return await injected() # type: ignore[misc,no-any-return]
Expand Down
56 changes: 56 additions & 0 deletions tests/test_healthcheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from dataclasses import dataclass

import pytest
from magic_di import Connectable, DependencyInjector
from magic_di.healthcheck import DependenciesHealthcheck


@dataclass
class PingableDatabase(Connectable):
ping_count: int = 0

async def __ping__(self) -> None:
self.ping_count += 1


@dataclass
class Service(Connectable):
ping_count: int = 0


@dataclass
class PingableService(Connectable):
db: PingableDatabase
ping_count: int = 0

async def __ping__(self) -> None:
self.ping_count += 1


@pytest.mark.asyncio()
async def test_healthcheck(injector: DependencyInjector) -> None:
async def main(_: PingableService) -> None: ...

await injector.inject(main)()

injected_db = injector.inject(PingableDatabase)()
injected_srv = injector.inject(PingableService)()
injected_srv_not_pingable = injector.inject(Service)()

assert injected_db.ping_count == 0
assert injected_srv.ping_count == 0
assert injected_srv_not_pingable.ping_count == 0

healthcheck = injector.inject(DependenciesHealthcheck)()

await healthcheck.ping_dependencies(max_concurrency=1)

assert injected_db.ping_count == 1
assert injected_srv.ping_count == 1
assert injected_srv_not_pingable.ping_count == 0

await healthcheck.ping_dependencies(max_concurrency=3)

assert injected_db.ping_count == 2 # noqa: PLR2004
assert injected_srv.ping_count == 2.0 # noqa: PLR2004
assert injected_srv_not_pingable.ping_count == 0
Loading

0 comments on commit 399516e

Please sign in to comment.