diff --git a/src/magic_di/_container.py b/src/magic_di/_container.py index 8506f45..ddc8843 100644 --- a/src/magic_di/_container.py +++ b/src/magic_di/_container.py @@ -2,12 +2,13 @@ import functools import inspect -from collections.abc import Iterable from dataclasses import dataclass from threading import Lock from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast if TYPE_CHECKING: + from collections.abc import Iterable + from magic_di import ConnectableProtocol T = TypeVar("T") diff --git a/src/magic_di/_injector.py b/src/magic_di/_injector.py index 322e0d0..59878e0 100644 --- a/src/magic_di/_injector.py +++ b/src/magic_di/_injector.py @@ -2,7 +2,6 @@ import inspect import logging -from collections.abc import Callable, Iterable, Iterator from contextlib import contextmanager from threading import Lock from typing import ( @@ -25,6 +24,8 @@ from magic_di.exceptions import InjectionError, InspectionError if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Iterator + from magic_di._connectable import ConnectableProtocol # flag to use in typing.Annotated diff --git a/src/magic_di/_utils.py b/src/magic_di/_utils.py index a4b8e56..63825f5 100644 --- a/src/magic_di/_utils.py +++ b/src/magic_di/_utils.py @@ -1,16 +1,16 @@ from __future__ import annotations -from typing import Any, TypeVar, Union, cast, get_args +from typing import Any, TypeVar, cast, get_args from typing import get_type_hints as _get_type_hints from magic_di import ConnectableProtocol -LegacyUnionType = type(Union[object, None]) +LegacyUnionType = type(object | None) try: from types import UnionType # type: ignore[import-error,unused-ignore] except ImportError: - UnionType = LegacyUnionType # type: ignore[misc, assignment] + UnionType = LegacyUnionType # type: ignore[misc] T = TypeVar("T") @@ -40,7 +40,7 @@ def get_cls_from_optional(cls: T) -> T: Returns: T: Extracted class """ - if not isinstance(cls, (UnionType, LegacyUnionType)): + if not isinstance(cls, UnionType | LegacyUnionType): return cls args = get_args(cls) diff --git a/src/magic_di/celery/_async_utils.py b/src/magic_di/celery/_async_utils.py index d835c1e..1a7fc22 100644 --- a/src/magic_di/celery/_async_utils.py +++ b/src/magic_di/celery/_async_utils.py @@ -2,9 +2,11 @@ import asyncio import threading -from collections.abc import Coroutine from dataclasses import dataclass -from typing import Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar + +if TYPE_CHECKING: + from collections.abc import Coroutine R = TypeVar("R") diff --git a/src/magic_di/celery/_loader.py b/src/magic_di/celery/_loader.py index 5257002..68e95a4 100644 --- a/src/magic_di/celery/_loader.py +++ b/src/magic_di/celery/_loader.py @@ -3,7 +3,6 @@ import asyncio import os import threading -from collections.abc import Callable from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from celery import signals @@ -13,6 +12,8 @@ from magic_di.celery._async_utils import EventLoop, EventLoopGetter, run_in_event_loop if TYPE_CHECKING: + from collections.abc import Callable + from celery.loaders.base import BaseLoader diff --git a/src/magic_di/celery/_task.py b/src/magic_di/celery/_task.py index 62c7389..1ec1d3f 100644 --- a/src/magic_di/celery/_task.py +++ b/src/magic_di/celery/_task.py @@ -1,9 +1,8 @@ from __future__ import annotations import inspect -from collections.abc import Callable from functools import wraps -from typing import Any, cast, get_type_hints +from typing import TYPE_CHECKING, Any, cast, get_type_hints from celery.app.task import Task @@ -11,6 +10,9 @@ from magic_di.celery._async_utils import EventLoop, run_in_event_loop from magic_di.celery._loader import InjectedCeleryLoaderProtocol +if TYPE_CHECKING: + from collections.abc import Callable + class BaseCeleryConnectableDeps(Connectable): ... diff --git a/src/magic_di/fastapi/_app.py b/src/magic_di/fastapi/_app.py index a8ff81e..cd20d92 100644 --- a/src/magic_di/fastapi/_app.py +++ b/src/magic_di/fastapi/_app.py @@ -1,7 +1,6 @@ from __future__ import annotations import inspect -from collections.abc import AsyncIterator, Callable, Iterator from contextlib import asynccontextmanager from typing import ( TYPE_CHECKING, @@ -17,6 +16,8 @@ from magic_di._injector import DependencyInjector if TYPE_CHECKING: + from collections.abc import AsyncIterator, Callable, Iterator + from fastapi import FastAPI, routing diff --git a/src/magic_di/utils.py b/src/magic_di/utils.py index f893efa..72c481f 100644 --- a/src/magic_di/utils.py +++ b/src/magic_di/utils.py @@ -2,13 +2,12 @@ import asyncio import inspect -from collections.abc import Awaitable from typing import TYPE_CHECKING, TypeVar, overload from magic_di import DependencyInjector if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Awaitable, Callable T = TypeVar("T")