Skip to content

Commit

Permalink
Add DependenciesHealthcheck dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
RB387 committed Sep 20, 2024
1 parent 532f86d commit 7508c83
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 31 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- Added `magic_di.healthcheck.DependenciesHealthcheck` class to make health checks of injected dependencies that implement `magic_di.healthcheck.PingableProtocol` interface
### Fixed
- Inject dependencies inside of event loop in `magic_di.utils.inject_and_run` to prevent wrong event loop usage inside of the injected dependencies
### Changed
- Cruft update to get changes from the cookiecutter template
- Renamed LICENCE -> LICENSE, now it's automatically included in the wheel created by poetry
Expand Down
7 changes: 3 additions & 4 deletions src/magic_di/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
if TYPE_CHECKING:
from collections.abc import Iterable

from magic_di import ConnectableProtocol

T = TypeVar("T")

Expand Down Expand Up @@ -53,7 +52,7 @@ def iter_instances(
self,
*,
reverse: bool = False,
) -> Iterable[tuple[type, ConnectableProtocol]]:
) -> Iterable[tuple[type, object]]:
with self._lock:
deps_iter: Iterable[Dependency[Any]] = list(
reversed(self._deps.values()) if reverse else self._deps.values(),
Expand All @@ -70,7 +69,7 @@ def _get(self, obj: type[T]) -> type[T] | None:

def _wrap(obj: type[T], *args: Any, **kwargs: Any) -> type[T]:
if not inspect.isclass(obj):
partial = functools.wraps(obj)(functools.partial(obj, *args, **kwargs)) # type: ignore[var-annotated]
partial = functools.wraps(obj)(functools.partial(obj, *args, **kwargs))
return cast(type[T], partial)

_instance: T | None = None
Expand Down Expand Up @@ -99,7 +98,7 @@ def new(_: Any) -> T:
#
# Here we manually create a new singleton class factory using the `type` metaclass
# Since the original class was not modified, it will use its own metaclass.
return functools.wraps( # type: ignore[return-value]
return functools.wraps(
obj,
updated=(),
)(
Expand Down
35 changes: 23 additions & 12 deletions src/magic_di/_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
from magic_di._utils import (
get_cls_from_optional,
get_type_hints,
is_injectable,
is_connectable,
safe_is_instance,
safe_is_subclass,
)
from magic_di.exceptions import InjectionError, InspectionError

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator

from magic_di._connectable import ConnectableProtocol

# flag to use in typing.Annotated
# to forcefully mark dependency as injectable
Expand Down Expand Up @@ -111,29 +111,29 @@ def inspect(self, obj: AnyObject) -> Signature[AnyObject]:
hints_with_extras = get_type_hints(obj, include_extras=True)

if not hints:
return Signature(obj, is_injectable=is_injectable(obj))
return Signature(obj, is_injectable=bool(is_connectable(obj)))

if inspect.ismethod(obj):
hints.pop("self", None)

hints.pop("return", None)

signature = Signature(obj, is_injectable=is_injectable(obj))
signature = Signature(obj, is_injectable=bool(is_connectable(obj)))

for name, hint_ in hints.items():
hint = self._unwrap_type_hint(hint_)
hint_with_extra = hints_with_extras[name]

if is_injector(hint):
signature.injector_arg = name
elif not is_injectable(hint) and not is_forcefully_marked_as_injectable(
elif not is_connectable(hint) and not is_forcefully_marked_as_injectable(
hint_with_extra,
):
signature.kwargs[name] = hint
else:
signature.deps[name] = hint

except Exception as exc:
except Exception as exc: # noqa: BLE001
raise InspectionError(obj) from exc

return signature
Expand All @@ -149,30 +149,41 @@ async def connect(self) -> None:
self.inject(postponed)

for cls, instance in self._deps.iter_instances():
if is_injectable(instance):
if connectable_instance := is_connectable(instance):
self.logger.debug("Connecting %s...", cls.__name__)
await instance.__connect__()
await connectable_instance.__connect__()

async def disconnect(self) -> None:
"""
Disconnect all injected dependencies
"""
for cls, instance in self._deps.iter_instances(reverse=True):
if is_injectable(instance):
if connectable_instance := is_connectable(instance):
try:
await instance.__disconnect__()
await connectable_instance.__disconnect__()
except Exception:
self.logger.exception("Failed to disconnect %s", cls.__name__)

def get_dependencies_by_interface(
self,
interface: Callable[..., AnyObject],
) -> Iterable[AnyObject]:
"""
Get all injected dependencies that implement a particular interface.
"""
for _, instance in self._deps.iter_instances():
if safe_is_instance(instance, interface): # type: ignore[arg-type]
yield instance # type: ignore[misc]

async def __aenter__(self) -> DependencyInjector: # noqa: PYI034
await self.connect()
return self

async def __aexit__(self, *args: object, **kwargs: Any) -> None:
await self.disconnect()

def iter_deps(self) -> Iterable[ConnectableProtocol]:
instance: ConnectableProtocol
def iter_deps(self) -> Iterable[object]:
instance: object

for _, instance in self._deps.iter_instances():
yield instance
Expand Down
13 changes: 10 additions & 3 deletions src/magic_di/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_cls_from_optional(cls: T) -> T:
Extract the actual class from a union that includes None.
If it is not a union type hint, it returns the same type hint.
Example:
```python
>>> get_cls_from_optional(Union[str, None])
str
>>> get_cls_from_optional(str | None)
Expand All @@ -35,6 +36,7 @@ def get_cls_from_optional(cls: T) -> T:
str
>>> get_cls_from_optional(int)
int
```
Args:
cls (T): Type hint for class
Returns:
Expand Down Expand Up @@ -73,17 +75,22 @@ def safe_is_instance(sub_cls: Any, cls: type) -> bool:
return False


def is_injectable(cls: Any) -> bool:
def is_connectable(cls: Any) -> ConnectableProtocol | None:
"""
Check if a class is a subclass of ConnectableProtocol.
Args:
cls (Any): The class to check.
Returns:
bool: True if the class is a subclass of ConnectableProtocol, False otherwise.
ConnectableProtocol | None: return instance if the class
is a subclass of ConnectableProtocol, None otherwise.
"""
return safe_is_subclass(cls, ConnectableProtocol) or safe_is_instance(cls, ConnectableProtocol)
connectable = safe_is_subclass(cls, ConnectableProtocol) or safe_is_instance(
cls,
ConnectableProtocol,
)
return cls if connectable else None


def get_type_hints(obj: Any, *, include_extras: bool = False) -> dict[str, type]:
Expand Down
47 changes: 47 additions & 0 deletions src/magic_di/healthcheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import asyncio
from asyncio import Future
from dataclasses import dataclass
from typing import Protocol, Any

from magic_di import DependencyInjector, ConnectableProtocol, Connectable


class PingableProtocol(ConnectableProtocol, Protocol):
async def __ping__(self) -> None:
...


@dataclass
class DependenciesHealthcheck(Connectable):
"""
Injectable Healthcheck component that pings all dependencies that implement the PingableProtocol
Example usage:
```python
from app.components.services.health import DependenciesHealthcheck
async def main(redis: Redis, deps_healthcheck: DependenciesHealthcheck) -> None:
await deps_healthcheck.ping_dependencies() # redis will be pinged if it has method __ping__
inject_and_run(main)
```
"""
injector: DependencyInjector

async def ping_dependencies(self, max_concurrency: int = 1) -> None:
"""
Ping all dependencies that implement the PingableProtocol
:param max_concurrency: Maximum number of concurrent pings
"""
tasks: set[Future[Any]] = set()

for dependency in self.injector.get_dependencies_by_interface(PingableProtocol):
tasks.add(asyncio.ensure_future(dependency.__ping__()))

if len(tasks) >= max_concurrency:
tasks, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

if tasks:
await asyncio.gather(*tasks)
20 changes: 11 additions & 9 deletions src/magic_di/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,20 @@ class InjectableMock(AsyncMock):
and use AsyncMock instead of a real class instance
Example:
@pytest.fixture()
def client():
injector = DependencyInjector()
```python
@pytest.fixture()
def client():
injector = DependencyInjector()
with injector.override({Service: InjectableMock().mock_cls}):
with TestClient(app) as client:
yield client
with injector.override({Service: InjectableMock().mock_cls}):
with TestClient(app) as client:
yield client
def test_http_handler(client):
resp = client.post('/hello-world')
def test_http_handler(client):
resp = client.post('/hello-world')
assert resp.status_code == 200
assert resp.status_code == 200
```
"""

@property
Expand Down
4 changes: 2 additions & 2 deletions src/magic_di/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def inject_and_run(
"""
injector = injector or DependencyInjector()

injected = injector.inject(fn)

async def run() -> T:
injected = injector.inject(fn)

async with injector:
if inspect.iscoroutinefunction(fn):
return await injected() # type: ignore[misc,no-any-return]
Expand Down
58 changes: 58 additions & 0 deletions tests/test_healthcheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from dataclasses import dataclass

import pytest

from magic_di import Connectable, DependencyInjector
from magic_di.healthcheck import DependenciesHealthcheck


@dataclass
class PingableDatabase(Connectable):
ping_count: int = 0

async def __ping__(self) -> None:
self.ping_count += 1


@dataclass
class Service(Connectable):
ping_count: int = 0


@dataclass
class PingableService(Connectable):
db: PingableDatabase
ping_count: int = 0

async def __ping__(self) -> None:
self.ping_count += 1


@pytest.mark.asyncio()
async def test_healthcheck(injector: DependencyInjector) -> None:
async def main(_: PingableService) -> None:
...

await injector.inject(main)()

injected_db = injector.inject(PingableDatabase)()
injected_srv = injector.inject(PingableService)()
injected_srv_not_pingable = injector.inject(Service)()

assert injected_db.ping_count == 0
assert injected_srv.ping_count == 0
assert injected_srv_not_pingable.ping_count == 0

healthcheck = injector.inject(DependenciesHealthcheck)()

await healthcheck.ping_dependencies(max_concurrency=1)

assert injected_db.ping_count == 1
assert injected_srv.ping_count == 1
assert injected_srv_not_pingable.ping_count == 0

await healthcheck.ping_dependencies(max_concurrency=3)

assert injected_db.ping_count == 2
assert injected_srv.ping_count == 2
assert injected_srv_not_pingable.ping_count == 0
16 changes: 15 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from magic_di import DependencyInjector
import asyncio

from magic_di import ConnectableProtocol, DependencyInjector
from magic_di.utils import inject_and_run

from tests.conftest import Repository
Expand All @@ -25,3 +27,15 @@ async def main(repo: Repository) -> Repository:
repo = inject_and_run(main, injector=injector)
assert not repo.connected
assert not repo.db.connected


def test_inject_and_run_async_proper_event_loop(injector: DependencyInjector) -> None:
class DependencyWithEventLoop(ConnectableProtocol):
def __init__(self) -> None:
self.event_loop = asyncio.get_event_loop()

async def main(dependency: DependencyWithEventLoop) -> None:
event_loop = asyncio.get_running_loop()
assert event_loop is dependency.event_loop

inject_and_run(main, injector=injector)

0 comments on commit 7508c83

Please sign in to comment.