diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8e81856..05e14ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: - id: ruff name: ruff # Add --fix, in case you want it to autofix when this hook runs - entry: poetry run ruff check --force-exclude + entry: poetry run ruff check --fix --force-exclude require_serial: true language: system types: [ python ] diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/magic_di/_container.py b/src/magic_di/_container.py index c069d23..e72f314 100644 --- a/src/magic_di/_container.py +++ b/src/magic_di/_container.py @@ -4,7 +4,10 @@ import inspect from dataclasses import dataclass from threading import Lock -from typing import Generic, Iterable, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar, cast + +if TYPE_CHECKING: + from magic_di import ConnectableProtocol T = TypeVar("T") @@ -16,11 +19,11 @@ class Dependency(Generic[T]): class SingletonDependencyContainer: - def __init__(self): - self._deps: dict[type[T], Dependency[T]] = {} + def __init__(self) -> None: + self._deps: dict[type, Dependency[Any]] = {} self._lock: Lock = Lock() - def add(self, obj: type[T], **kwargs) -> type[T]: + def add(self, obj: type[T], **kwargs: Any) -> type[T]: with self._lock: if dep := self._get(obj): return dep @@ -44,9 +47,13 @@ def get(self, obj: type[T]) -> type[T] | None: with self._lock: return self._get(obj) - def iter_instances(self, *, reverse: bool = False) -> Iterable[tuple[type[T], T]]: + def iter_instances( + self, + *, + reverse: bool = False, + ) -> Iterable[tuple[type, ConnectableProtocol]]: with self._lock: - deps_iter: Iterable = list( + deps_iter: Iterable[Dependency[Any]] = list( reversed(self._deps.values()) if reverse else self._deps.values(), ) @@ -59,14 +66,14 @@ def _get(self, obj: type[T]) -> type[T] | None: return dep.object if dep else None -def _wrap(obj: type[T], *args, **kwargs) -> type[T]: +def _wrap(obj: type[T], *args: Any, **kwargs: Any) -> type[T]: if not inspect.isclass(obj): partial = functools.wraps(obj)(functools.partial(obj, *args, **kwargs)) return cast(type[T], partial) - _instance = None + _instance: T | None = None - def new(_): + def new(_: Any) -> T: nonlocal _instance if _instance is not None: diff --git a/src/magic_di/_injector.py b/src/magic_di/_injector.py index 8d089a7..cd5404a 100644 --- a/src/magic_di/_injector.py +++ b/src/magic_di/_injector.py @@ -10,6 +10,8 @@ Any, Callable, Iterable, + Iterator, + Self, TypeVar, cast, get_origin, @@ -65,10 +67,10 @@ def is_forcefully_marked_as_injectable(cls: Any) -> bool: class DependencyInjector: def __init__( self, - bindings: dict | None = None, + bindings: dict[type, type] | None = None, logger: logging.Logger = logger, ): - self.bindings: dict = bindings or {} + self.bindings = bindings or {} self.logger: logging.Logger = logger self._deps = SingletonDependencyContainer() @@ -88,14 +90,14 @@ def inject(self, obj: Callable[..., T]) -> Callable[..., T]: """ obj = self._unwrap_type_hint(obj) # type: ignore[arg-type] - if dep := self._deps.get(obj): # type: ignore[arg-type] + if dep := self._deps.get(obj): return dep signature = self.inspect(obj) clients: dict[str, object] = {} for name, dep in signature.deps.items(): - clients[name] = self.inject(dep)() # type: ignore[misc] + clients[name] = self.inject(dep)() if signature.injector_arg is not None: clients[signature.injector_arg] = self @@ -107,7 +109,7 @@ def inject(self, obj: Callable[..., T]) -> Callable[..., T]: def inspect(self, obj: AnyObject) -> Signature[AnyObject]: try: - hints = get_type_hints(obj) + hints: dict[str, type[Any]] = get_type_hints(obj) hints_with_extras = get_type_hints(obj, include_extras=True) if not hints: @@ -138,7 +140,7 @@ def inspect(self, obj: AnyObject) -> Signature[AnyObject]: return signature - async def connect(self): + async def connect(self) -> None: """ Connect all injected dependencies """ @@ -153,7 +155,7 @@ async def connect(self): self.logger.debug("Connecting %s...", cls.__name__) await instance.__connect__() - async def disconnect(self): + async def disconnect(self) -> None: """ Disconnect all injected dependencies """ @@ -164,11 +166,11 @@ async def disconnect(self): except Exception: self.logger.exception("Failed to disconnect %s", cls.__name__) - async def __aenter__(self): + async def __aenter__(self) -> Self: await self.connect() return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: object, **kwargs: Any) -> None: await self.disconnect() def iter_deps(self) -> Iterable[ConnectableProtocol]: @@ -194,9 +196,9 @@ def lazy_inject(self, obj: Callable[..., T]) -> Callable[..., T]: self._postponed.append(obj) # type: ignore[arg-type] # incompatible type "Callable[..., T]"; expected "Callable[..., T]" - injected = None + injected: T | None = None - def inject(): + def inject() -> T: nonlocal injected if injected is not None: @@ -207,7 +209,7 @@ def inject(): return cast(type[T], inject) - def bind(self, bindings: dict[type, type]): + def bind(self, bindings: dict[type, type]) -> None: """ Bind new bindings to the injector. @@ -234,7 +236,7 @@ def bind(self, bindings: dict[type, type]): self.bindings = self.bindings | bindings @contextmanager - def override(self, bindings: dict[type, type]): + def override(self, bindings: dict[type, type]) -> Iterator[None]: """ Temporarily override the bindings and dependencies of the injector. diff --git a/src/magic_di/_utils.py b/src/magic_di/_utils.py index d0e6cfd..38cd427 100644 --- a/src/magic_di/_utils.py +++ b/src/magic_di/_utils.py @@ -8,7 +8,7 @@ LegacyUnionType = type(Union[object, None]) try: - from types import UnionType # type: ignore[import-error] + from types import UnionType # type: ignore[import-error,unused-ignore] except ImportError: UnionType = LegacyUnionType # type: ignore[misc, assignment] @@ -79,10 +79,10 @@ def is_injectable(cls: Any) -> bool: return safe_is_subclass(cls, ConnectableProtocol) -def get_type_hints(obj: Any, *, include_extras=False) -> dict[str, type]: +def get_type_hints(obj: Any, *, include_extras: bool = False) -> dict[str, type]: try: if is_class(obj): - return _get_type_hints(obj.__init__, include_extras=include_extras) # type: ignore[misc] + return _get_type_hints(obj.__init__, include_extras=include_extras) return _get_type_hints(obj, include_extras=include_extras) except TypeError: diff --git a/src/magic_di/celery/_loader.py b/src/magic_di/celery/_loader.py index 28fab6c..472fedb 100644 --- a/src/magic_di/celery/_loader.py +++ b/src/magic_di/celery/_loader.py @@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable from celery import signals -from celery.loaders.app import AppLoader # type: ignore[import] +from celery.loaders.app import AppLoader # type: ignore[import-untyped] + from magic_di import DependencyInjector from magic_di.celery._async_utils import EventLoop, EventLoopGetter, run_in_event_loop @@ -32,8 +33,8 @@ def get_celery_loader( _injector = injector or DependencyInjector() _event_loop_getter = event_loop_getter or EventLoopGetter() - class CeleryLoader(AppLoader): # type: ignore[no-any-unimported] - def __init__(self, *args, **kwargs): + class CeleryLoader(AppLoader): # type: ignore[no-any-unimported, misc] + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.injector = _injector self.loaded = False diff --git a/src/magic_di/celery/_task.py b/src/magic_di/celery/_task.py index 752fbe1..e039e92 100644 --- a/src/magic_di/celery/_task.py +++ b/src/magic_di/celery/_task.py @@ -5,6 +5,7 @@ from typing import Any, Callable, cast, get_type_hints from celery.app.task import Task + from magic_di import Connectable, DependencyInjector from magic_di.celery._async_utils import EventLoop, run_in_event_loop from magic_di.celery._loader import InjectedCeleryLoaderProtocol @@ -14,7 +15,7 @@ class BaseCeleryConnectableDeps(Connectable): ... class InjectableCeleryTaskMetaclass(type): - def __new__(cls, name: str, bases: tuple, dct: dict) -> type: + def __new__(cls, name: str, bases: tuple[type, ...], dct: dict[str, Any]) -> type: run = dct.get("run") run_wrapper = dct.get("run_wrapper") @@ -32,7 +33,7 @@ def __new__(cls, name: str, bases: tuple, dct: dict) -> type: return super().__new__(cls, name, bases, dct) -class InjectableCeleryTask(Task, Connectable, metaclass=InjectableCeleryTaskMetaclass): +class InjectableCeleryTask(Task, Connectable, metaclass=InjectableCeleryTaskMetaclass): # type: ignore[type-arg] __annotations__ = {} def __init__( @@ -83,9 +84,10 @@ def load(self) -> None: InjectedCeleryLoaderProtocol, self.app.loader, ) - return loader.on_worker_process_init() + loader.on_worker_process_init() + return - return self.app.loader.on_worker_process_init() # type: ignore[attr-defined] + self.app.loader.on_worker_process_init() # type: ignore[attr-defined] @property def loaded(self) -> bool: @@ -110,9 +112,9 @@ def get_event_loop(self) -> EventLoop | None: return None @staticmethod - def run_wrapper(orig_run: Callable) -> Callable: + def run_wrapper(orig_run: Callable[..., Any]) -> Callable[..., Any]: @wraps(orig_run) - def runner(self: InjectableCeleryTask, *args, **kwargs) -> Any: + def runner(self: InjectableCeleryTask, *args: Any, **kwargs: Any) -> Any: if not self.loaded: self.load() diff --git a/src/magic_di/exceptions.py b/src/magic_di/exceptions.py index 00263f3..aa7d13f 100644 --- a/src/magic_di/exceptions.py +++ b/src/magic_di/exceptions.py @@ -9,7 +9,7 @@ class InjectorError(Exception): ... class InjectionError(InjectorError): - def __init__(self, obj: Callable, signature: Signature): + def __init__(self, obj: Callable[..., Any], signature: Signature[Any]): self.obj = obj self.signature = signature @@ -64,7 +64,7 @@ def _build_error_message(self) -> str: return f"Failed to inspect {self.obj}. \n{object_location}\nSee the exception above" -def _get_object_source_location(obj: Callable) -> str: +def _get_object_source_location(obj: Callable[..., Any]) -> str: try: _, obj_line_number = inspect.getsourcelines(obj) source_file = inspect.getsourcefile(obj) diff --git a/src/magic_di/fastapi/__init__.py b/src/magic_di/fastapi/__init__.py index e39a413..a29c2b5 100644 --- a/src/magic_di/fastapi/__init__.py +++ b/src/magic_di/fastapi/__init__.py @@ -3,6 +3,6 @@ """ from ._app import inject_app -from ._provide import Provide, Provider +from ._provide import Provide, Provider # type: ignore[attr-defined] __all__ = ("inject_app", "Provide", "Provider") diff --git a/src/magic_di/fastapi/_app.py b/src/magic_di/fastapi/_app.py index f81bbc1..4cefd8c 100644 --- a/src/magic_di/fastapi/_app.py +++ b/src/magic_di/fastapi/_app.py @@ -6,6 +6,7 @@ TYPE_CHECKING, Annotated, Any, + AsyncIterator, Callable, Iterator, Protocol, @@ -14,6 +15,7 @@ ) from fastapi.params import Depends + from magic_di._injector import DependencyInjector if TYPE_CHECKING: @@ -67,7 +69,7 @@ def inject_app( """ injector = injector or DependencyInjector() - def collect_deps(): + def collect_deps() -> None: _collect_dependencies(injector, app.router) app.state.dependency_injector = injector @@ -80,12 +82,12 @@ def collect_deps(): return app -def _inject_app_with_lifespan(app: FastAPI, collect_deps_fn: Callable) -> None: +def _inject_app_with_lifespan(app: FastAPI, collect_deps_fn: Callable[[], None]) -> None: app_router: routing.APIRouter = app.router app_lifespan = app_router.lifespan_context @asynccontextmanager - async def injector_lifespan(app: FastAPI): + async def injector_lifespan(app: FastAPI) -> AsyncIterator[None]: collect_deps_fn() injector = app.state.dependency_injector @@ -99,7 +101,7 @@ async def injector_lifespan(app: FastAPI): app_router.lifespan_context = injector_lifespan -def _inject_app_with_events(app: FastAPI, collect_deps_fn: Callable) -> None: +def _inject_app_with_events(app: FastAPI, collect_deps_fn: Callable[[], None]) -> None: app.on_event("startup")(collect_deps_fn) app.on_event("startup")(app.state.dependency_injector.connect) app.on_event("shutdown")(app.state.dependency_injector.disconnect) @@ -141,7 +143,7 @@ def _inspect_and_lazy_inject(obj: object, injector: DependencyInjector) -> None: injector.lazy_inject(dependency) -def _find_fastapi_dependencies(dependency: Callable) -> Iterator[Callable]: +def _find_fastapi_dependencies(dependency: Callable[..., Any]) -> Iterator[Callable[..., Any]]: """ Recursively finds all FastAPI dependencies. It looks for FastAPI's Depends() in default arguments and in type annotations. diff --git a/src/magic_di/fastapi/_provide.py b/src/magic_di/fastapi/_provide.py index d38b39e..8bd5459 100644 --- a/src/magic_di/fastapi/_provide.py +++ b/src/magic_di/fastapi/_provide.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Annotated, TypeVar from fastapi import Depends, Request + from magic_di._injector import Injectable if TYPE_CHECKING: diff --git a/src/magic_di/testing.py b/src/magic_di/testing.py index 3fb271b..180f0b5 100644 --- a/src/magic_di/testing.py +++ b/src/magic_di/testing.py @@ -1,13 +1,16 @@ +from __future__ import annotations + +from typing import Any from unittest.mock import AsyncMock from magic_di import ConnectableProtocol -def get_injectable_mock_cls(return_value): +def get_injectable_mock_cls(return_value: Any) -> type[ConnectableProtocol]: class ClientMetaclassMock(ConnectableProtocol): __annotations__ = {} - def __new__(cls, *_, **__): + def __new__(cls, *_: Any, **__: Any) -> Any: return return_value return ClientMetaclassMock @@ -34,12 +37,12 @@ def test_http_handler(client): """ @property - def mock_cls(self): + def mock_cls(self) -> type[ConnectableProtocol]: return get_injectable_mock_cls(self) - async def __connect__(self): ... + async def __connect__(self) -> None: ... - async def __disconnect__(self): ... + async def __disconnect__(self) -> None: ... - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> InjectableMock: return self.__class__(*args, **kwargs) diff --git a/src/magic_di/utils.py b/src/magic_di/utils.py index 1638592..0c87902 100644 --- a/src/magic_di/utils.py +++ b/src/magic_di/utils.py @@ -2,7 +2,7 @@ import asyncio import inspect -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Awaitable, TypeVar, overload from magic_di import DependencyInjector @@ -10,7 +10,27 @@ from collections.abc import Callable -def inject_and_run(fn: Callable, injector: DependencyInjector | None = None) -> Any: +T = TypeVar("T") + + +@overload +def inject_and_run( + fn: Callable[..., Awaitable[T]], + injector: DependencyInjector | None = None, +) -> T: ... + + +@overload +def inject_and_run( + fn: Callable[..., T], + injector: DependencyInjector | None = None, +) -> T: ... + + +def inject_and_run( + fn: Callable[..., T], + injector: DependencyInjector | None = None, +) -> T: """ This function takes a callable, injects dependencies into it using the provided injector, and then runs the function. If the function is a coroutine, it will be awaited. @@ -33,12 +53,12 @@ def inject_and_run(fn: Callable, injector: DependencyInjector | None = None) -> """ injector = injector or DependencyInjector() - injected: Callable = injector.inject(fn) + injected = injector.inject(fn) - async def run(): + async def run() -> T: async with injector: if inspect.iscoroutinefunction(fn): - return await injected() + return await injected() # type: ignore[misc,no-any-return] return injected() diff --git a/tests/conftest.py b/tests/conftest.py index 7ee44fb..23728ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,10 +11,10 @@ class ConnectableClient(Connectable): connected: bool = False - async def __connect__(self): + async def __connect__(self) -> None: self.connected = True - async def __disconnect__(self): + async def __disconnect__(self) -> None: self.connected = False @@ -25,11 +25,11 @@ class AnotherDatabase(ConnectableClient): ... class Repository(ConnectableClient): - def __init__(self, db: Database, some_params: int = 1): + def __init__(self, db: Database, some_params: int = 1) -> None: self.db = db - self.some_params = 1 + self.some_params = some_params - async def do_something(self): + async def do_something(self) -> bool: await asyncio.sleep(0.1) return self.connected and self.db.connected @@ -42,7 +42,8 @@ class Service(ConnectableClient): repo: Repository workers: AsyncWorkers | None - def is_alive(self): + def is_alive(self) -> bool: + assert self.workers return self.repo.connected and self.workers.connected diff --git a/tests/test_app.py b/tests/test_app.py index 5174474..3bbaa55 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -10,25 +10,25 @@ from tests.conftest import Database, Service -def test_app_injection(injector): +def test_app_injection(injector: DependencyInjector) -> None: app = inject_app(FastAPI(), injector=injector) @app.get(path="/hello-world") - def hello_world(service: Provide[Service], some_query: str) -> dict: + def hello_world(service: Provide[Service], some_query: str) -> dict[str, str | bool]: # noqa: FA102 assert isinstance(service, Service) - return {"query": some_query, "is_alive": service.is_alive()} # type: ignore[attr-defined] + return {"query": some_query, "is_alive": service.is_alive()} with TestClient(app) as client: resp = client.get("/hello-world?some_query=my-query") assert resp.json() == {"query": "my-query", "is_alive": True} -def test_app_injection_with_depends(injector): +def test_app_injection_with_depends(injector: DependencyInjector) -> None: connected_global_dependency = False class GlobalConnect(Database): ... - def global_dependency(dep: Provide[GlobalConnect]): + def global_dependency(dep: Provide[GlobalConnect]) -> None: nonlocal connected_global_dependency connected_global_dependency = dep.connected @@ -55,7 +55,7 @@ def get_creds( mw: Provide[MiddlewareNonConnectable], *, db_connect: bool = Depends(assert_db_connected), - ): + ) -> str: if db_connect: return mw.get_creds() @@ -65,9 +65,9 @@ def get_creds( def hello_world( service: Provide[Service], creds: Annotated[str, Depends(get_creds)], - ) -> dict: + ) -> dict[str, str | bool]: # noqa: FA102 assert isinstance(service, Service) - return {"creds": creds, "is_alive": service.is_alive()} # type: ignore[attr-defined] + return {"creds": creds, "is_alive": service.is_alive()} with TestClient(app) as client: resp = client.get("/hello-world?some_query=my-query") @@ -80,7 +80,7 @@ def test_app_injection_clients_connect( injector: DependencyInjector, *, use_deprecated_events: bool, -): +) -> None: app = inject_app( FastAPI(), injector=injector, @@ -90,12 +90,14 @@ def test_app_injection_clients_connect( router = APIRouter() @router.get(path="/hello-world") - def hello_world(service: Provide[Service]) -> dict: + def hello_world(service: Provide[Service]) -> dict[str, bool]: # noqa: FA102 + assert service.workers + return { - "service_connected": service.connected, # type: ignore[attr-defined] - "workers_connected": service.workers.connected, # type: ignore[union-attr] - "repo_connected": service.repo.connected, # type: ignore[attr-defined] - "db_connected": service.repo.db.connected, # type: ignore[attr-defined] + "service_connected": service.connected, + "workers_connected": service.workers.connected, + "repo_connected": service.repo.connected, + "db_connected": service.repo.db.connected, } app.include_router(router) @@ -118,7 +120,7 @@ def hello_world(service: Provide[Service]) -> dict: } -def test_app_injection_without_registered_injector(injector: DependencyInjector): +def test_app_injection_without_registered_injector(injector: DependencyInjector) -> None: app = FastAPI() @app.get(path="/hello-world") diff --git a/tests/test_celery.py b/tests/test_celery.py index 592ea46..20430b4 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -3,12 +3,12 @@ from contextlib import contextmanager from dataclasses import dataclass from threading import Thread -from typing import Iterator, cast +from typing import Any, Iterator, cast from unittest.mock import MagicMock, call import pytest from celery import Celery -from celery.bin.celery import celery # type: ignore[import] +from celery.bin.celery import celery # type: ignore[import-untyped] from fastapi import FastAPI from magic_di import DependencyInjector from magic_di.celery import ( @@ -53,7 +53,11 @@ def celery_app(injector: DependencyInjector) -> Iterator[Celery]: @pytest.fixture(scope="module") def service_ping_task(celery_app: Celery) -> InjectableCeleryTask: @celery_app.task - async def service_ping(arg1: int, arg2: str, service: Service = PROVIDE) -> tuple: + async def service_ping( + arg1: int, + arg2: str, + service: Service = PROVIDE, + ) -> tuple[int, str, bool]: # noqa: FA102 return arg1, arg2, service.is_alive() return cast(InjectableCeleryTask, service_ping) @@ -62,7 +66,11 @@ async def service_ping(arg1: int, arg2: str, service: Service = PROVIDE) -> tupl @pytest.fixture(scope="module") def service_ping_task_sync(celery_app: Celery) -> InjectableCeleryTask: @celery_app.task - def service_ping_sync(arg1: int, arg2: str, service: Service = PROVIDE) -> tuple: + def service_ping_sync( + arg1: int, + arg2: str, + service: Service = PROVIDE, + ) -> tuple[int, str, bool]: # noqa: FA102 return arg1, arg2, service.is_alive() return cast(InjectableCeleryTask, service_ping_sync) @@ -77,7 +85,7 @@ class Deps(BaseCeleryConnectableDeps): class SyncServicePingTask(InjectableCeleryTask): deps: Deps - async def run(self, arg1: int, arg2: str): + async def run(self, arg1: int, arg2: str) -> tuple[int, str, bool]: # noqa: FA102 return arg1, arg2, self.deps.db.connected return SyncServicePingTask() @@ -92,7 +100,7 @@ class Deps(BaseCeleryConnectableDeps): class ServicePingTask(InjectableCeleryTask): deps: Deps - def run(self, arg1: int, arg2: str): + def run(self, arg1: int, arg2: str) -> tuple[int, str, bool]: # noqa: FA102 return arg1, arg2, self.deps.db.connected return ServicePingTask() @@ -100,12 +108,12 @@ def run(self, arg1: int, arg2: str): @pytest.fixture(scope="module") def run_celery( - celery_app, - service_ping_task, - service_ping_task_sync, - service_ping_class_based_task, - service_ping_class_based_task_sync, -): + celery_app: Celery, + service_ping_task: InjectableCeleryTask, + service_ping_task_sync: InjectableCeleryTask, + service_ping_class_based_task: InjectableCeleryTask, + service_ping_class_based_task_sync: InjectableCeleryTask, +) -> Iterator[Celery]: celery_app.register_task(service_ping_class_based_task) celery_app.register_task(service_ping_class_based_task_sync) @@ -119,7 +127,7 @@ def run_celery( yield celery_app with contextlib.suppress(SystemExit): - (celery.main(args=["control", "shutdown"])) + celery.main(args=["control", "shutdown"]) thread.join() @@ -127,7 +135,7 @@ def run_celery( def test_async_function_based_tasks( run_celery: Celery, service_ping_task: InjectableCeleryTask, -): +) -> None: result = service_ping_task.apply_async(args=(1337, "leet")).get( disable_sync_subtasks=False, ) @@ -148,7 +156,7 @@ def test_async_function_based_tasks( async def test_sync_function_based_tasks( run_celery: Celery, service_ping_task_sync: InjectableCeleryTask, -): +) -> None: result = service_ping_task_sync.apply_async(args=(1337, "leet")).get( disable_sync_subtasks=False, ) @@ -168,7 +176,7 @@ async def test_sync_function_based_tasks( def test_async_class_based_tasks( run_celery: Celery, service_ping_class_based_task: InjectableCeleryTask, -): +) -> None: result = service_ping_class_based_task.apply_async(args=(1337, "leet")).get( disable_sync_subtasks=False, ) @@ -188,7 +196,7 @@ def test_async_class_based_tasks( def test_sync_class_based_tasks( run_celery: Celery, service_ping_class_based_task_sync: InjectableCeleryTask, -): +) -> None: result = service_ping_class_based_task_sync.apply_async(args=(1337, "leet")).get( disable_sync_subtasks=False, ) @@ -205,7 +213,7 @@ def test_sync_class_based_tasks( assert list(result) == [1010, "123", True] -def test_retries_func_based_task(): +def test_retries_func_based_task() -> None: with create_celery(DependencyInjector(), use_broker_and_backend=False) as app: app.conf.update({"task_always_eager": True}) mock = MagicMock() @@ -220,7 +228,7 @@ async def ping_task(service: Service = PROVIDE) -> None: _ = ping_task.apply_async().get(disable_sync_subtasks=False) -def test_retries_class_based_task(): +def test_retries_class_based_task() -> None: with create_celery(DependencyInjector(), use_broker_and_backend=False) as app: app.conf.update({"task_always_eager": True}) @@ -231,7 +239,7 @@ class PingTask(InjectableCeleryTask): autoretry_for = (ValueError,) retry_backoff = 0 - async def run(self, service: Service = PROVIDE): + async def run(self, service: Service = PROVIDE) -> None: assert service.is_alive() mock() raise ValueError(TEST_ERR_MSG) @@ -258,8 +266,8 @@ async def test_async_function_based_tasks_inside_event_loop( *, task_always_eager: bool, use_broker_and_backend: bool, - expected_mock_calls: list, -): + expected_mock_calls: list[Any], # noqa: FA102 +) -> None: injector = DependencyInjector() with create_celery(injector, use_broker_and_backend=use_broker_and_backend) as app: @@ -272,14 +280,14 @@ async def ping_task( arg1: int, arg2: str, service: Service = PROVIDE, - ) -> tuple: + ) -> tuple[int, str, bool]: # noqa: FA102 mock() return arg1, arg2, service.is_alive() fastapi_app = FastAPI() @fastapi_app.get("/") - async def handler(): + async def handler() -> dict[str, bool]: # noqa: FA102 ping_task.apply_async(args=(1337, "leet")) ping_task.apply(args=(1337, "leet-2")) return {"ok": True} diff --git a/tests/test_injector.py b/tests/test_injector.py index 66c03cb..04559ca 100644 --- a/tests/test_injector.py +++ b/tests/test_injector.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Annotated, Generic, TypeVar +from typing import Annotated, Any, Generic, TypeVar import pytest from magic_di import Connectable, DependencyInjector, Injectable @@ -21,7 +21,7 @@ @pytest.mark.asyncio() -async def test_class_injection_success(injector): +async def test_class_injection_success(injector: DependencyInjector) -> None: injected_service = injector.inject(Service)() assert not injected_service.is_alive() @@ -48,8 +48,8 @@ async def test_class_injection_success(injector): @pytest.mark.asyncio() -async def test_function_injection_success(injector): - def run_service(service: Service): +async def test_function_injection_success(injector: DependencyInjector) -> None: + def run_service(service: Service) -> Service: return service injected = injector.inject(run_service) @@ -61,13 +61,13 @@ def run_service(service: Service): assert isinstance(service, Service) -def test_class_injection_missing_class(injector): +def test_class_injection_missing_class(injector: DependencyInjector) -> None: with pytest.raises(InjectionError): injector.inject(BrokenService) @pytest.mark.asyncio() -async def test_class_injection_with_bindings(injector): +async def test_class_injection_with_bindings(injector: DependencyInjector) -> None: injector.bind({RepoInterface: Repository}) injected_service = injector.inject(ServiceWithBindings)() @@ -86,7 +86,7 @@ async def test_class_injection_with_bindings(injector): assert not injected_service.repo.db.connected -def test_lazy_inject(injector): +def test_lazy_inject(injector: DependencyInjector) -> None: get_injected_cls = injector.lazy_inject(Service) injected_service = get_injected_cls() @@ -94,7 +94,7 @@ def test_lazy_inject(injector): assert injected_service is get_injected_cls() -def test_overriden_injection(injector): +def test_overriden_injection(injector: DependencyInjector) -> None: service = injector.inject(Service)() with injector.override({Database: AnotherDatabase}): @@ -110,7 +110,7 @@ def test_overriden_injection(injector): assert service is service_after_overriden_injection -def test_embedded_injection(injector): +def test_embedded_injection(injector: DependencyInjector) -> None: class ClsWithEmbeddedInjection(Connectable): def __init__(self, injected_injector: DependencyInjector): assert injected_injector is injector @@ -120,16 +120,16 @@ def __init__(self, injected_injector: DependencyInjector): assert isinstance(injected.service, Service) -def test_injector_iter_deps(injector): +def test_injector_iter_deps(injector: DependencyInjector) -> None: injector.inject(Service)() deps = [type(dep) for dep in injector.iter_deps()] assert deps == [Database, Repository, AsyncWorkers, Service] -def test_injector_with_metaclass(injector): +def test_injector_with_metaclass(injector: DependencyInjector) -> None: class _ServiceMetaClass(type): - def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict) -> type: + def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> type[Any]: for _ in attrs["__orig_bases__"]: ... @@ -148,7 +148,7 @@ class WrappedService(ServiceGeneric[Service]): injector.inject(WrappedService)() -def test_injector_flag_injectable(injector): +def test_injector_flag_injectable(injector: DependencyInjector) -> None: @dataclass class ServiceWithNonConnectable: db: Annotated[NonConnectableDatabase, Injectable] diff --git a/tests/test_utils.py b/tests/test_utils.py index 756f75a..b445d46 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,10 @@ +from magic_di import DependencyInjector from magic_di.utils import inject_and_run from tests.conftest import Repository -def test_inject_and_run_sync(injector): +def test_inject_and_run_sync(injector: DependencyInjector) -> None: def main(repo: Repository) -> Repository: assert repo.connected assert repo.db.connected @@ -14,7 +15,7 @@ def main(repo: Repository) -> Repository: assert not repo.db.connected -def test_inject_and_run_async(injector): +def test_inject_and_run_async(injector: DependencyInjector) -> None: async def main(repo: Repository) -> Repository: assert await repo.do_something() assert repo.connected