Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
RB387 committed Mar 25, 2024
1 parent 92b3828 commit 2bb8df8
Show file tree
Hide file tree
Showing 18 changed files with 164 additions and 114 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
- id: ruff
name: ruff
# Add --fix, in case you want it to autofix when this hook runs
entry: poetry run ruff check --force-exclude
entry: poetry run ruff check --fix --force-exclude
require_serial: true
language: system
types: [ python ]
Expand Down
Empty file removed src/__init__.py
Empty file.
25 changes: 16 additions & 9 deletions src/magic_di/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import inspect
from dataclasses import dataclass
from threading import Lock
from typing import Generic, Iterable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar, cast

if TYPE_CHECKING:
from magic_di import ConnectableProtocol

T = TypeVar("T")

Expand All @@ -16,11 +19,11 @@ class Dependency(Generic[T]):


class SingletonDependencyContainer:
def __init__(self):
self._deps: dict[type[T], Dependency[T]] = {}
def __init__(self) -> None:
self._deps: dict[type, Dependency[Any]] = {}
self._lock: Lock = Lock()

def add(self, obj: type[T], **kwargs) -> type[T]:
def add(self, obj: type[T], **kwargs: Any) -> type[T]:
with self._lock:
if dep := self._get(obj):
return dep
Expand All @@ -44,9 +47,13 @@ def get(self, obj: type[T]) -> type[T] | None:
with self._lock:
return self._get(obj)

def iter_instances(self, *, reverse: bool = False) -> Iterable[tuple[type[T], T]]:
def iter_instances(
self,
*,
reverse: bool = False,
) -> Iterable[tuple[type, ConnectableProtocol]]:
with self._lock:
deps_iter: Iterable = list(
deps_iter: Iterable[Dependency[Any]] = list(
reversed(self._deps.values()) if reverse else self._deps.values(),
)

Expand All @@ -59,14 +66,14 @@ def _get(self, obj: type[T]) -> type[T] | None:
return dep.object if dep else None


def _wrap(obj: type[T], *args, **kwargs) -> type[T]:
def _wrap(obj: type[T], *args: Any, **kwargs: Any) -> type[T]:
if not inspect.isclass(obj):
partial = functools.wraps(obj)(functools.partial(obj, *args, **kwargs))
return cast(type[T], partial)

_instance = None
_instance: T | None = None

def new(_):
def new(_: Any) -> T:
nonlocal _instance

if _instance is not None:
Expand Down
28 changes: 15 additions & 13 deletions src/magic_di/_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
Any,
Callable,
Iterable,
Iterator,
Self,
TypeVar,
cast,
get_origin,
Expand Down Expand Up @@ -65,10 +67,10 @@ def is_forcefully_marked_as_injectable(cls: Any) -> bool:
class DependencyInjector:
def __init__(
self,
bindings: dict | None = None,
bindings: dict[type, type] | None = None,
logger: logging.Logger = logger,
):
self.bindings: dict = bindings or {}
self.bindings = bindings or {}
self.logger: logging.Logger = logger

self._deps = SingletonDependencyContainer()
Expand All @@ -88,14 +90,14 @@ def inject(self, obj: Callable[..., T]) -> Callable[..., T]:
"""
obj = self._unwrap_type_hint(obj) # type: ignore[arg-type]

if dep := self._deps.get(obj): # type: ignore[arg-type]
if dep := self._deps.get(obj):
return dep

signature = self.inspect(obj)

clients: dict[str, object] = {}
for name, dep in signature.deps.items():
clients[name] = self.inject(dep)() # type: ignore[misc]
clients[name] = self.inject(dep)()

if signature.injector_arg is not None:
clients[signature.injector_arg] = self
Expand All @@ -107,7 +109,7 @@ def inject(self, obj: Callable[..., T]) -> Callable[..., T]:

def inspect(self, obj: AnyObject) -> Signature[AnyObject]:
try:
hints = get_type_hints(obj)
hints: dict[str, type[Any]] = get_type_hints(obj)
hints_with_extras = get_type_hints(obj, include_extras=True)

if not hints:
Expand Down Expand Up @@ -138,7 +140,7 @@ def inspect(self, obj: AnyObject) -> Signature[AnyObject]:

return signature

async def connect(self):
async def connect(self) -> None:
"""
Connect all injected dependencies
"""
Expand All @@ -153,7 +155,7 @@ async def connect(self):
self.logger.debug("Connecting %s...", cls.__name__)
await instance.__connect__()

async def disconnect(self):
async def disconnect(self) -> None:
"""
Disconnect all injected dependencies
"""
Expand All @@ -164,11 +166,11 @@ async def disconnect(self):
except Exception:
self.logger.exception("Failed to disconnect %s", cls.__name__)

async def __aenter__(self):
async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, *args, **kwargs):
async def __aexit__(self, *args: object, **kwargs: Any) -> None:
await self.disconnect()

def iter_deps(self) -> Iterable[ConnectableProtocol]:
Expand All @@ -194,9 +196,9 @@ def lazy_inject(self, obj: Callable[..., T]) -> Callable[..., T]:
self._postponed.append(obj) # type: ignore[arg-type]
# incompatible type "Callable[..., T]"; expected "Callable[..., T]"

injected = None
injected: T | None = None

def inject():
def inject() -> T:
nonlocal injected

if injected is not None:
Expand All @@ -207,7 +209,7 @@ def inject():

return cast(type[T], inject)

def bind(self, bindings: dict[type, type]):
def bind(self, bindings: dict[type, type]) -> None:
"""
Bind new bindings to the injector.
Expand All @@ -234,7 +236,7 @@ def bind(self, bindings: dict[type, type]):
self.bindings = self.bindings | bindings

@contextmanager
def override(self, bindings: dict[type, type]):
def override(self, bindings: dict[type, type]) -> Iterator[None]:
"""
Temporarily override the bindings and dependencies of the injector.
Expand Down
6 changes: 3 additions & 3 deletions src/magic_di/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
LegacyUnionType = type(Union[object, None])

try:
from types import UnionType # type: ignore[import-error]
from types import UnionType # type: ignore[import-error,unused-ignore]
except ImportError:
UnionType = LegacyUnionType # type: ignore[misc, assignment]

Expand Down Expand Up @@ -79,10 +79,10 @@ def is_injectable(cls: Any) -> bool:
return safe_is_subclass(cls, ConnectableProtocol)


def get_type_hints(obj: Any, *, include_extras=False) -> dict[str, type]:
def get_type_hints(obj: Any, *, include_extras: bool = False) -> dict[str, type]:
try:
if is_class(obj):
return _get_type_hints(obj.__init__, include_extras=include_extras) # type: ignore[misc]
return _get_type_hints(obj.__init__, include_extras=include_extras)

return _get_type_hints(obj, include_extras=include_extras)
except TypeError:
Expand Down
7 changes: 4 additions & 3 deletions src/magic_di/celery/_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable

from celery import signals
from celery.loaders.app import AppLoader # type: ignore[import]
from celery.loaders.app import AppLoader # type: ignore[import-untyped]

from magic_di import DependencyInjector
from magic_di.celery._async_utils import EventLoop, EventLoopGetter, run_in_event_loop

Expand All @@ -32,8 +33,8 @@ def get_celery_loader(
_injector = injector or DependencyInjector()
_event_loop_getter = event_loop_getter or EventLoopGetter()

class CeleryLoader(AppLoader): # type: ignore[no-any-unimported]
def __init__(self, *args, **kwargs):
class CeleryLoader(AppLoader): # type: ignore[no-any-unimported, misc]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.injector = _injector
self.loaded = False
Expand Down
14 changes: 8 additions & 6 deletions src/magic_di/celery/_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, cast, get_type_hints

from celery.app.task import Task

from magic_di import Connectable, DependencyInjector
from magic_di.celery._async_utils import EventLoop, run_in_event_loop
from magic_di.celery._loader import InjectedCeleryLoaderProtocol
Expand All @@ -14,7 +15,7 @@ class BaseCeleryConnectableDeps(Connectable): ...


class InjectableCeleryTaskMetaclass(type):
def __new__(cls, name: str, bases: tuple, dct: dict) -> type:
def __new__(cls, name: str, bases: tuple[type, ...], dct: dict[str, Any]) -> type:
run = dct.get("run")
run_wrapper = dct.get("run_wrapper")

Expand All @@ -32,7 +33,7 @@ def __new__(cls, name: str, bases: tuple, dct: dict) -> type:
return super().__new__(cls, name, bases, dct)


class InjectableCeleryTask(Task, Connectable, metaclass=InjectableCeleryTaskMetaclass):
class InjectableCeleryTask(Task, Connectable, metaclass=InjectableCeleryTaskMetaclass): # type: ignore[type-arg]
__annotations__ = {}

def __init__(
Expand Down Expand Up @@ -83,9 +84,10 @@ def load(self) -> None:
InjectedCeleryLoaderProtocol,
self.app.loader,
)
return loader.on_worker_process_init()
loader.on_worker_process_init()
return

return self.app.loader.on_worker_process_init() # type: ignore[attr-defined]
self.app.loader.on_worker_process_init() # type: ignore[attr-defined]

@property
def loaded(self) -> bool:
Expand All @@ -110,9 +112,9 @@ def get_event_loop(self) -> EventLoop | None:
return None

@staticmethod
def run_wrapper(orig_run: Callable) -> Callable:
def run_wrapper(orig_run: Callable[..., Any]) -> Callable[..., Any]:
@wraps(orig_run)
def runner(self: InjectableCeleryTask, *args, **kwargs) -> Any:
def runner(self: InjectableCeleryTask, *args: Any, **kwargs: Any) -> Any:
if not self.loaded:
self.load()

Expand Down
4 changes: 2 additions & 2 deletions src/magic_di/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class InjectorError(Exception): ...


class InjectionError(InjectorError):
def __init__(self, obj: Callable, signature: Signature):
def __init__(self, obj: Callable[..., Any], signature: Signature[Any]):
self.obj = obj
self.signature = signature

Expand Down Expand Up @@ -64,7 +64,7 @@ def _build_error_message(self) -> str:
return f"Failed to inspect {self.obj}. \n{object_location}\nSee the exception above"


def _get_object_source_location(obj: Callable) -> str:
def _get_object_source_location(obj: Callable[..., Any]) -> str:
try:
_, obj_line_number = inspect.getsourcelines(obj)
source_file = inspect.getsourcefile(obj)
Expand Down
2 changes: 1 addition & 1 deletion src/magic_di/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
"""

from ._app import inject_app
from ._provide import Provide, Provider
from ._provide import Provide, Provider # type: ignore[attr-defined]

__all__ = ("inject_app", "Provide", "Provider")
12 changes: 7 additions & 5 deletions src/magic_di/fastapi/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TYPE_CHECKING,
Annotated,
Any,
AsyncIterator,
Callable,
Iterator,
Protocol,
Expand All @@ -14,6 +15,7 @@
)

from fastapi.params import Depends

from magic_di._injector import DependencyInjector

if TYPE_CHECKING:
Expand Down Expand Up @@ -67,7 +69,7 @@ def inject_app(
"""
injector = injector or DependencyInjector()

def collect_deps():
def collect_deps() -> None:
_collect_dependencies(injector, app.router)

app.state.dependency_injector = injector
Expand All @@ -80,12 +82,12 @@ def collect_deps():
return app


def _inject_app_with_lifespan(app: FastAPI, collect_deps_fn: Callable) -> None:
def _inject_app_with_lifespan(app: FastAPI, collect_deps_fn: Callable[[], None]) -> None:
app_router: routing.APIRouter = app.router
app_lifespan = app_router.lifespan_context

@asynccontextmanager
async def injector_lifespan(app: FastAPI):
async def injector_lifespan(app: FastAPI) -> AsyncIterator[None]:
collect_deps_fn()

injector = app.state.dependency_injector
Expand All @@ -99,7 +101,7 @@ async def injector_lifespan(app: FastAPI):
app_router.lifespan_context = injector_lifespan


def _inject_app_with_events(app: FastAPI, collect_deps_fn: Callable) -> None:
def _inject_app_with_events(app: FastAPI, collect_deps_fn: Callable[[], None]) -> None:
app.on_event("startup")(collect_deps_fn)
app.on_event("startup")(app.state.dependency_injector.connect)
app.on_event("shutdown")(app.state.dependency_injector.disconnect)
Expand Down Expand Up @@ -141,7 +143,7 @@ def _inspect_and_lazy_inject(obj: object, injector: DependencyInjector) -> None:
injector.lazy_inject(dependency)


def _find_fastapi_dependencies(dependency: Callable) -> Iterator[Callable]:
def _find_fastapi_dependencies(dependency: Callable[..., Any]) -> Iterator[Callable[..., Any]]:
"""
Recursively finds all FastAPI dependencies.
It looks for FastAPI's Depends() in default arguments and in type annotations.
Expand Down
1 change: 1 addition & 0 deletions src/magic_di/fastapi/_provide.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Annotated, TypeVar

from fastapi import Depends, Request

from magic_di._injector import Injectable

if TYPE_CHECKING:
Expand Down
Loading

0 comments on commit 2bb8df8

Please sign in to comment.