From af6f92f60dabe62f980dbb74f192185b0b35a38f Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 4 Dec 2024 21:36:12 +0300 Subject: [PATCH] lint: fix mypy a bit --- faststream/_internal/_compat.py | 1 + faststream/_internal/cli/main.py | 28 ++++--- faststream/_internal/fastapi/router.py | 2 +- faststream/_internal/proto.py | 28 ++++++- faststream/_internal/publisher/proto.py | 34 ++++---- faststream/_internal/publisher/specified.py | 10 +-- faststream/_internal/publisher/usecase.py | 21 +++-- .../_internal/state/logger/logger_proxy.py | 4 + faststream/_internal/subscriber/call_item.py | 6 +- .../{call_wrapper/call.py => call_wrapper.py} | 7 +- .../subscriber/call_wrapper/__init__.py | 0 .../subscriber/call_wrapper/proto.py | 79 ------------------- faststream/_internal/subscriber/mixins.py | 2 +- faststream/_internal/subscriber/proto.py | 11 +-- faststream/_internal/subscriber/specified.py | 8 +- faststream/_internal/subscriber/usecase.py | 79 ++++++++++++++----- faststream/_internal/subscriber/utils.py | 10 ++- faststream/_internal/types.py | 9 ++- faststream/_internal/utils/data.py | 5 +- faststream/_internal/utils/functions.py | 20 ++--- faststream/asgi/app.py | 2 + faststream/confluent/client.py | 22 ++++-- faststream/confluent/subscriber/usecase.py | 17 ++-- faststream/confluent/testing.py | 2 +- faststream/kafka/testing.py | 2 +- faststream/redis/subscriber/usecase.py | 22 ++---- faststream/specification/proto/endpoint.py | 7 +- tests/cli/test_run_asgi.py | 15 +++- 28 files changed, 217 insertions(+), 236 deletions(-) rename faststream/_internal/subscriber/{call_wrapper/call.py => call_wrapper.py} (97%) delete mode 100644 faststream/_internal/subscriber/call_wrapper/__init__.py delete mode 100644 faststream/_internal/subscriber/call_wrapper/proto.py diff --git a/faststream/_internal/_compat.py b/faststream/_internal/_compat.py index ba38326ac9..b445e336b7 100644 --- a/faststream/_internal/_compat.py +++ b/faststream/_internal/_compat.py @@ -33,6 +33,7 @@ __all__ = ( "HAS_TYPER", "PYDANTIC_V2", + "BaseModel", "CoreSchema", "EmailStr", "GetJsonSchemaHandler", diff --git a/faststream/_internal/cli/main.py b/faststream/_internal/cli/main.py index 95fb8037c6..17ccbd61f7 100644 --- a/faststream/_internal/cli/main.py +++ b/faststream/_internal/cli/main.py @@ -2,13 +2,14 @@ import sys import warnings from contextlib import suppress -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, cast import anyio import typer from faststream import FastStream from faststream.__about__ import __version__ +from faststream._internal._compat import json_loads from faststream._internal.application import Application from faststream._internal.cli.docs import asyncapi_app from faststream._internal.cli.utils.imports import import_from_string @@ -122,6 +123,7 @@ def run( # Should be imported after sys.path changes module_path, app_obj = import_from_string(app, is_factory=is_factory) + app_obj = cast(Application, app_obj) args = (app, extra, is_factory, casted_log_level) @@ -160,7 +162,7 @@ def run( workers=workers, ).run() else: - args[1]["workers"] = workers + args[1]["workers"] = str(workers) _run(*args) else: @@ -181,6 +183,7 @@ def _run( ) -> None: """Runs the specified application.""" _, app_obj = import_from_string(app, is_factory=is_factory) + app_obj = cast(Application, app_obj) _run_imported_app( app_obj, extra_options=extra_options, @@ -235,7 +238,7 @@ def publish( ), message: str = typer.Argument( ..., - help="Message to be published.", + help="JSON Message string to publish.", ), rpc: bool = typer.Option( False, @@ -255,9 +258,9 @@ def publish( """ app, extra = parse_cli_args(app, *ctx.args) - extra["message"] = message - if "timeout" in extra: - extra["timeout"] = float(extra["timeout"]) + publish_extra: AnyDict = extra.copy() + if "timeout" in publish_extra: + publish_extra["timeout"] = float(publish_extra["timeout"]) try: _, app_obj = import_from_string(app, is_factory=is_factory) @@ -269,7 +272,7 @@ def publish( raise ValueError(msg) app_obj._setup() - result = anyio.run(publish_message, app_obj.broker, rpc, extra) + result = anyio.run(publish_message, app_obj.broker, rpc, message, publish_extra) if rpc: typer.echo(result) @@ -282,13 +285,18 @@ def publish( async def publish_message( broker: "BrokerUsecase[Any, Any]", rpc: bool, + message: str, extra: "AnyDict", ) -> Any: + with suppress(Exception): + message = json_loads(message) + try: async with broker: if rpc: - return await broker.request(**extra) - return await broker.publish(**extra) + return await broker.request(message, **extra) # type: ignore[call-arg] + return await broker.publish(message, **extra) # type: ignore[call-arg] + except Exception as e: - typer.echo(f"Error when broker was publishing: {e}") + typer.echo(f"Error when broker was publishing: {e!r}") sys.exit(1) diff --git a/faststream/_internal/fastapi/router.py b/faststream/_internal/fastapi/router.py index e0bb4b2d9b..7d99dae680 100644 --- a/faststream/_internal/fastapi/router.py +++ b/faststream/_internal/fastapi/router.py @@ -51,7 +51,7 @@ from faststream._internal.broker.broker import BrokerUsecase from faststream._internal.proto import NameRequired from faststream._internal.publisher.proto import PublisherProto - from faststream._internal.subscriber.call_wrapper.call import HandlerCallWrapper + from faststream._internal.subscriber.call_wrapper import HandlerCallWrapper from faststream._internal.types import BrokerMiddleware from faststream.message import StreamMessage from faststream.specification.base.specification import Specification diff --git a/faststream/_internal/proto.py b/faststream/_internal/proto.py index 615dec872b..9eb8ed33aa 100644 --- a/faststream/_internal/proto.py +++ b/faststream/_internal/proto.py @@ -1,8 +1,32 @@ from abc import abstractmethod -from typing import Any, Optional, Protocol, TypeVar, Union, overload +from typing import Any, Callable, Optional, Protocol, TypeVar, Union, overload +from faststream._internal.subscriber.call_wrapper import ( + HandlerCallWrapper, + ensure_call_wrapper, +) +from faststream._internal.types import ( + MsgType, + P_HandlerParams, + T_HandlerReturn, +) -class Endpoint(Protocol): + +class EndpointWrapper(Protocol[MsgType]): + def __call__( + self, + func: Union[ + Callable[P_HandlerParams, T_HandlerReturn], + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + ], + ) -> HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]: + handler: HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn] = ( + ensure_call_wrapper(func) + ) + return handler + + +class Endpoint(EndpointWrapper[MsgType]): @abstractmethod def add_prefix(self, prefix: str) -> None: ... diff --git a/faststream/_internal/publisher/proto.py b/faststream/_internal/publisher/proto.py index 93e83efe94..f83be6188c 100644 --- a/faststream/_internal/publisher/proto.py +++ b/faststream/_internal/publisher/proto.py @@ -1,11 +1,16 @@ from abc import abstractmethod from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Protocol - -from typing_extensions import override +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Protocol, +) from faststream._internal.proto import Endpoint -from faststream._internal.types import MsgType +from faststream._internal.types import ( + MsgType, +) from faststream.response.response import PublishCommand if TYPE_CHECKING: @@ -14,9 +19,7 @@ from faststream._internal.types import ( AsyncCallable, BrokerMiddleware, - P_HandlerParams, PublisherMiddleware, - T_HandlerReturn, ) from faststream.response.response import PublishCommand @@ -88,28 +91,23 @@ async def request( class PublisherProto( - Endpoint, + Endpoint[MsgType], BasePublisherProto, - Generic[MsgType], ): _broker_middlewares: Sequence["BrokerMiddleware[MsgType]"] _middlewares: Sequence["PublisherMiddleware"] - _producer: Optional["ProducerProto"] + + @property + @abstractmethod + def _producer(self) -> "ProducerProto": ... @abstractmethod def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: ... - @override @abstractmethod - def _setup( # type: ignore[override] + def _setup( self, *, - producer: Optional["ProducerProto"], state: "Pointer[BrokerState]", + producer: "ProducerProto", ) -> None: ... - - @abstractmethod - def __call__( - self, - func: "Callable[P_HandlerParams, T_HandlerReturn]", - ) -> "Callable[P_HandlerParams, T_HandlerReturn]": ... diff --git a/faststream/_internal/publisher/specified.py b/faststream/_internal/publisher/specified.py index a6e34a163b..db9c64a974 100644 --- a/faststream/_internal/publisher/specified.py +++ b/faststream/_internal/publisher/specified.py @@ -17,10 +17,10 @@ if TYPE_CHECKING: from faststream._internal.basic_types import AnyCallable, AnyDict from faststream._internal.state import BrokerState, Pointer - from faststream._internal.subscriber.call_wrapper.call import HandlerCallWrapper + from faststream._internal.subscriber.call_wrapper import HandlerCallWrapper -class SpecificationPublisher(EndpointSpecification[PublisherSpec]): +class SpecificationPublisher(EndpointSpecification[MsgType, PublisherSpec]): """A base class for publishers in an asynchronous API.""" _state: "Pointer[BrokerState]" # should be set in next parent @@ -44,9 +44,9 @@ def __call__( "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", ], ) -> "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]": - func = super().__call__(func) - self.calls.append(func._original_call) - return func + handler = super().__call__(func) + self.calls.append(handler._original_call) + return handler def get_payloads(self) -> list[tuple["AnyDict", str]]: payloads: list[tuple[AnyDict, str]] = [] diff --git a/faststream/_internal/publisher/usecase.py b/faststream/_internal/publisher/usecase.py index 08094ca56f..662430fbbb 100644 --- a/faststream/_internal/publisher/usecase.py +++ b/faststream/_internal/publisher/usecase.py @@ -1,4 +1,4 @@ -from collections.abc import Awaitable, Iterable +from collections.abc import Awaitable, Iterable, Sequence from functools import partial from itertools import chain from typing import ( @@ -15,9 +15,8 @@ from faststream._internal.publisher.proto import PublisherProto from faststream._internal.state import BrokerState, EmptyBrokerState, Pointer from faststream._internal.state.producer import ProducerUnset -from faststream._internal.subscriber.call_wrapper.call import ( +from faststream._internal.subscriber.call_wrapper import ( HandlerCallWrapper, - ensure_call_wrapper, ) from faststream._internal.subscriber.utils import process_msg from faststream._internal.types import ( @@ -42,8 +41,8 @@ class PublisherUsecase(PublisherProto[MsgType]): def __init__( self, *, - broker_middlewares: Iterable["BrokerMiddleware[MsgType]"], - middlewares: Iterable["PublisherMiddleware"], + broker_middlewares: Sequence["BrokerMiddleware[MsgType]"], + middlewares: Sequence["PublisherMiddleware"], ) -> None: self.middlewares = middlewares self._broker_middlewares = broker_middlewares @@ -65,7 +64,7 @@ def _producer(self) -> "ProducerProto": return self.__producer or self._state.get().producer @override - def _setup( # type: ignore[override] + def _setup( self, *, state: "Pointer[BrokerState]", @@ -97,9 +96,7 @@ def __call__( ], ) -> HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]: """Decorate user's function by current publisher.""" - handler: HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn] = ( - ensure_call_wrapper(func) - ) + handler = super().__call__(func) handler._publishers.append(self) return handler @@ -125,7 +122,7 @@ async def _basic_publish( ): pub = partial(pub_m, pub) - await pub(cmd) + return await pub(cmd) async def _basic_request( self, @@ -163,7 +160,7 @@ async def _basic_publish_batch( cmd: "PublishCommand", *, _extra_middlewares: Iterable["PublisherMiddleware"], - ) -> Optional[Any]: + ) -> Any: pub = self._producer.publish_batch context = self._state.get().di_state.context @@ -180,4 +177,4 @@ async def _basic_publish_batch( ): pub = partial(pub_m, pub) - await pub(cmd) + return await pub(cmd) diff --git a/faststream/_internal/state/logger/logger_proxy.py b/faststream/_internal/state/logger/logger_proxy.py index 690a42c6dd..0693fd184e 100644 --- a/faststream/_internal/state/logger/logger_proxy.py +++ b/faststream/_internal/state/logger/logger_proxy.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from collections.abc import Mapping from typing import Any, Optional @@ -8,6 +9,7 @@ class LoggerObject(LoggerProto): logger: Optional["LoggerProto"] + @abstractmethod def __bool__(self) -> bool: ... @@ -73,6 +75,8 @@ class RealLoggerObject(LoggerObject): or in default logger case (.params_storage.DefaultLoggerStorage). """ + logger: "LoggerProto" + def __init__(self, logger: "LoggerProto") -> None: self.logger = logger diff --git a/faststream/_internal/subscriber/call_item.py b/faststream/_internal/subscriber/call_item.py index 3ab28c58c7..bd8c7a4e49 100644 --- a/faststream/_internal/subscriber/call_item.py +++ b/faststream/_internal/subscriber/call_item.py @@ -21,7 +21,7 @@ from faststream._internal.basic_types import AsyncFuncAny, Decorator from faststream._internal.state import BrokerState, Pointer - from faststream._internal.subscriber.call_wrapper.call import HandlerCallWrapper + from faststream._internal.subscriber.call_wrapper import HandlerCallWrapper from faststream._internal.types import ( AsyncCallable, AsyncFilter, @@ -128,8 +128,8 @@ async def is_suitable( if not (parser := cast(Optional["AsyncCallable"], self.item_parser)) or not ( decoder := cast(Optional["AsyncCallable"], self.item_decoder) ): - msg = "You should setup `HandlerItem` at first." - raise SetupError(msg) + error_msg = "You should setup `HandlerItem` at first." + raise SetupError(error_msg) message = cache[parser] = cast( "StreamMessage[MsgType]", diff --git a/faststream/_internal/subscriber/call_wrapper/call.py b/faststream/_internal/subscriber/call_wrapper.py similarity index 97% rename from faststream/_internal/subscriber/call_wrapper/call.py rename to faststream/_internal/subscriber/call_wrapper.py index 14d081b52f..dfe1b45dad 100644 --- a/faststream/_internal/subscriber/call_wrapper/call.py +++ b/faststream/_internal/subscriber/call_wrapper.py @@ -24,7 +24,6 @@ if TYPE_CHECKING: from fast_depends.dependencies import Dependant - from fast_depends.use import InjectWrapper from faststream._internal.basic_types import Decorator from faststream._internal.publisher.proto import PublisherProto @@ -88,7 +87,7 @@ def __call__( async def call_wrapped( self, message: "StreamMessage[MsgType]", - ) -> Awaitable[Any]: + ) -> Any: """Calls the wrapped function with the given message.""" assert self._wrapped_call, "You should use `set_wrapped` first" # nosec B101 if self.is_test: @@ -145,7 +144,7 @@ def refresh(self, with_mock: bool = False) -> None: def set_wrapped( self, *, - dependencies: Iterable["Dependant"], + dependencies: Sequence["Dependant"], _call_decorators: Iterable["Decorator"], state: "DIState", ) -> Optional["CallModel"]: @@ -166,7 +165,7 @@ def set_wrapped( ) if state.use_fastdepends: - wrapper: InjectWrapper[Any, Any] = inject( + wrapper = inject( func=None, context__=state.context, ) diff --git a/faststream/_internal/subscriber/call_wrapper/__init__.py b/faststream/_internal/subscriber/call_wrapper/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/faststream/_internal/subscriber/call_wrapper/proto.py b/faststream/_internal/subscriber/call_wrapper/proto.py deleted file mode 100644 index fdaf8eb812..0000000000 --- a/faststream/_internal/subscriber/call_wrapper/proto.py +++ /dev/null @@ -1,79 +0,0 @@ -from collections.abc import Iterable, Sequence -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Optional, - Protocol, - Union, - overload, -) - -from faststream._internal.types import ( - CustomCallable, - Filter, - MsgType, - P_HandlerParams, - SubscriberMiddleware, - T_HandlerReturn, -) - -if TYPE_CHECKING: - from fast_depends.dependencies import Dependant - - from .call import HandlerCallWrapper - - -class WrapperProto(Protocol[MsgType]): - """Annotation class to represent @subscriber return type.""" - - @overload - def __call__( - self, - func: None = None, - *, - filter: Optional["Filter[Any]"] = None, - parser: Optional["CustomCallable"] = None, - decoder: Optional["CustomCallable"] = None, - middlewares: Sequence["SubscriberMiddleware[Any]"] = (), - dependencies: Iterable["Dependant"] = (), - ) -> Callable[ - [Callable[P_HandlerParams, T_HandlerReturn]], - "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", - ]: ... - - @overload - def __call__( - self, - func: Union[ - Callable[P_HandlerParams, T_HandlerReturn], - "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", - ], - *, - filter: Optional["Filter[Any]"] = None, - parser: Optional["CustomCallable"] = None, - decoder: Optional["CustomCallable"] = None, - middlewares: Sequence["SubscriberMiddleware[Any]"] = (), - dependencies: Iterable["Dependant"] = (), - ) -> "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]": ... - - def __call__( - self, - func: Union[ - Callable[P_HandlerParams, T_HandlerReturn], - "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", - None, - ] = None, - *, - filter: Optional["Filter[Any]"] = None, - parser: Optional["CustomCallable"] = None, - decoder: Optional["CustomCallable"] = None, - middlewares: Sequence["SubscriberMiddleware[Any]"] = (), - dependencies: Iterable["Dependant"] = (), - ) -> Union[ - "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", - Callable[ - [Callable[P_HandlerParams, T_HandlerReturn]], - "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", - ], - ]: ... diff --git a/faststream/_internal/subscriber/mixins.py b/faststream/_internal/subscriber/mixins.py index c76887b757..af96a87a41 100644 --- a/faststream/_internal/subscriber/mixins.py +++ b/faststream/_internal/subscriber/mixins.py @@ -27,7 +27,7 @@ async def close(self) -> None: if not task.done(): task.cancel() - self.tasks = [] + self.tasks.clear() class ConcurrentMixin(TasksMixin): diff --git a/faststream/_internal/subscriber/proto.py b/faststream/_internal/subscriber/proto.py index 1e8a7ce988..cb24b32295 100644 --- a/faststream/_internal/subscriber/proto.py +++ b/faststream/_internal/subscriber/proto.py @@ -2,10 +2,9 @@ from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Any, Optional -from typing_extensions import Self, override +from typing_extensions import Self from faststream._internal.proto import Endpoint -from faststream._internal.subscriber.call_wrapper.proto import WrapperProto from faststream._internal.types import MsgType if TYPE_CHECKING: @@ -28,10 +27,7 @@ from .call_item import HandlerItem -class SubscriberProto( - Endpoint, - WrapperProto[MsgType], -): +class SubscriberProto(Endpoint[MsgType]): calls: list["HandlerItem[MsgType]"] running: bool @@ -49,9 +45,8 @@ def get_log_context( /, ) -> dict[str, str]: ... - @override @abstractmethod - def _setup( # type: ignore[override] + def _setup( self, *, extra_context: "AnyDict", diff --git a/faststream/_internal/subscriber/specified.py b/faststream/_internal/subscriber/specified.py index 3af87b590c..50c36efbf6 100644 --- a/faststream/_internal/subscriber/specified.py +++ b/faststream/_internal/subscriber/specified.py @@ -4,6 +4,7 @@ Optional, ) +from faststream._internal.types import MsgType from faststream.exceptions import SetupError from faststream.specification.asyncapi.message import parse_handler_params from faststream.specification.asyncapi.utils import to_camelcase @@ -12,16 +13,11 @@ if TYPE_CHECKING: from faststream._internal.basic_types import AnyDict - from faststream._internal.types import ( - MsgType, - ) from .call_item import HandlerItem -class SpecificationSubscriber( - EndpointSpecification[SubscriberSpec], -): +class SpecificationSubscriber(EndpointSpecification[MsgType, SubscriberSpec]): calls: list["HandlerItem[MsgType]"] def __init__( diff --git a/faststream/_internal/subscriber/usecase.py b/faststream/_internal/subscriber/usecase.py index ba7d28e6c9..26ba99000f 100644 --- a/faststream/_internal/subscriber/usecase.py +++ b/faststream/_internal/subscriber/usecase.py @@ -7,15 +7,12 @@ Any, Callable, Optional, + Union, ) -from typing_extensions import Self, override +from typing_extensions import Self, overload, override from faststream._internal.subscriber.call_item import HandlerItem -from faststream._internal.subscriber.call_wrapper.call import ( - HandlerCallWrapper, - ensure_call_wrapper, -) from faststream._internal.subscriber.proto import SubscriberProto from faststream._internal.subscriber.utils import ( MultiLock, @@ -42,8 +39,10 @@ BasePublisherProto, ) from faststream._internal.state import BrokerState, Pointer + from faststream._internal.subscriber.call_wrapper import HandlerCallWrapper from faststream._internal.types import ( AsyncCallable, + AsyncFilter, BrokerMiddleware, CustomCallable, Filter, @@ -124,7 +123,7 @@ def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: self._broker_middlewares = (*self._broker_middlewares, middleware) @override - def _setup( # type: ignore[override] + def _setup( self, *, extra_context: "AnyDict", @@ -196,35 +195,75 @@ def add_call( ) return self + @overload + def __call__( + self, + func: None = None, + *, + filter: "Filter[StreamMessage[MsgType]]" = default_filter, + parser: Optional["CustomCallable"] = None, + decoder: Optional["CustomCallable"] = None, + middlewares: Sequence["SubscriberMiddleware[Any]"] = (), + dependencies: Iterable["Dependant"] = (), + ) -> Callable[ + [Callable[P_HandlerParams, T_HandlerReturn]], + "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", + ]: ... + + @overload + def __call__( + self, + func: Union[ + Callable[P_HandlerParams, T_HandlerReturn], + "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", + ], + *, + filter: "Filter[StreamMessage[MsgType]]" = default_filter, + parser: Optional["CustomCallable"] = None, + decoder: Optional["CustomCallable"] = None, + middlewares: Sequence["SubscriberMiddleware[Any]"] = (), + dependencies: Iterable["Dependant"] = (), + ) -> "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]": ... + @override def __call__( self, - func: Optional[Callable[P_HandlerParams, T_HandlerReturn]] = None, + func: Union[ + Callable[P_HandlerParams, T_HandlerReturn], + "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", + None, + ] = None, *, - filter: "Filter[Any]" = default_filter, + filter: "Filter[StreamMessage[MsgType]]" = default_filter, parser: Optional["CustomCallable"] = None, decoder: Optional["CustomCallable"] = None, middlewares: Sequence["SubscriberMiddleware[Any]"] = (), dependencies: Iterable["Dependant"] = (), - ) -> Any: + ) -> Union[ + "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", + Callable[ + [Callable[P_HandlerParams, T_HandlerReturn]], + "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", + ], + ]: if (options := self._call_options) is None: msg = ( "You can't create subscriber directly. Please, use `add_call` at first." ) - raise SetupError( - msg, - ) + raise SetupError(msg) total_deps = (*options.dependencies, *dependencies) total_middlewares = (*options.middlewares, *middlewares) - async_filter = to_async(filter) + async_filter: AsyncFilter[StreamMessage[MsgType]] = to_async(filter) def real_wrapper( - func: Callable[P_HandlerParams, T_HandlerReturn], + func: Union[ + Callable[P_HandlerParams, T_HandlerReturn], + "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", + ], ) -> "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]": - handler: HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn] = ( - ensure_call_wrapper(func) - ) + handler = super(SubscriberUsecase, self).__call__(func) + self.calls.append( HandlerItem[MsgType]( handler=handler, @@ -342,13 +381,13 @@ async def process_message(self, msg: MsgType) -> "Response": if parsing_error: raise parsing_error - msg = f"There is no suitable handler for {msg=}" - raise SubscriberNotFound(msg) + error_msg = f"There is no suitable handler for {msg=}" + raise SubscriberNotFound(error_msg) # An error was raised and processed by some middleware return ensure_response(None) - def __build__middlewares_stack(self) -> tuple["BaseMiddleware", ...]: + def __build__middlewares_stack(self) -> tuple["BrokerMiddleware[MsgType]", ...]: logger_state = self._state.get().logger_state if self.ack_policy is AckPolicy.DO_NOTHING: diff --git a/faststream/_internal/subscriber/utils.py b/faststream/_internal/subscriber/utils.py index 31f2c5358d..9195087e2f 100644 --- a/faststream/_internal/subscriber/utils.py +++ b/faststream/_internal/subscriber/utils.py @@ -56,8 +56,8 @@ async def process_msg( parsed_msg.set_decoder(decoder) return await return_msg(parsed_msg) - msg = "unreachable" - raise AssertionError(msg) + error_msg = "unreachable" + raise AssertionError(error_msg) async def default_filter(msg: "StreamMessage[Any]") -> bool: @@ -66,7 +66,11 @@ async def default_filter(msg: "StreamMessage[Any]") -> bool: class MultiLock: - """A class representing a multi lock.""" + """A class representing a multi lock. + + This lock can be acquired multiple times. + `wait_release` method waits for all locks will be released. + """ def __init__(self) -> None: """Initialize a new instance of the class.""" diff --git a/faststream/_internal/types.py b/faststream/_internal/types.py index ea1ecd3dbf..b90aad9cd1 100644 --- a/faststream/_internal/types.py +++ b/faststream/_internal/types.py @@ -17,6 +17,7 @@ from faststream.response.response import PublishCommand MsgType = TypeVar("MsgType") +Msg_contra = TypeVar("Msg_contra", contravariant=True) StreamMsg = TypeVar("StreamMsg", bound=StreamMessage[Any]) ConnectionType = TypeVar("ConnectionType") @@ -63,12 +64,12 @@ ] -class BrokerMiddleware(Protocol[MsgType]): +class BrokerMiddleware(Protocol[Msg_contra]): """Middleware builder interface.""" def __call__( self, - msg: Optional[MsgType], + msg: Optional[Msg_contra], /, *, context: ContextRepo, @@ -86,6 +87,6 @@ class PublisherMiddleware(Protocol): def __call__( self, - call_next: Callable[[PublishCommand], Awaitable[PublishCommand]], - msg: PublishCommand, + call_next: Callable[[PublishCommand], Awaitable[Any]], + cmd: PublishCommand, ) -> Any: ... diff --git a/faststream/_internal/utils/data.py b/faststream/_internal/utils/data.py index 98e3729fac..8f8a133636 100644 --- a/faststream/_internal/utils/data.py +++ b/faststream/_internal/utils/data.py @@ -20,4 +20,7 @@ def filter_by_dict( else: extra_data[k] = v - return typed_dict(out_data), extra_data + return ( + typed_dict(out_data), # type: ignore[call-arg] + extra_data, + ) diff --git a/faststream/_internal/utils/functions.py b/faststream/_internal/utils/functions.py index be90e6f0a2..d81201fcf0 100644 --- a/faststream/_internal/utils/functions.py +++ b/faststream/_internal/utils/functions.py @@ -1,16 +1,15 @@ from collections.abc import AsyncIterator, Awaitable, Iterator -from contextlib import AbstractContextManager, asynccontextmanager, contextmanager +from contextlib import asynccontextmanager, contextmanager from functools import wraps from typing import ( Any, Callable, - Optional, TypeVar, Union, + cast, overload, ) -import anyio from fast_depends.core import CallModel from fast_depends.utils import ( is_coroutine_callable, @@ -25,7 +24,6 @@ "call_or_await", "drop_response_type", "fake_context", - "timeout_scope", "to_async", ) @@ -53,7 +51,9 @@ def to_async( ) -> Callable[F_Spec, Awaitable[F_Return]]: """Converts a synchronous function to an asynchronous function.""" if is_coroutine_callable(func): - return func + return cast(Callable[F_Spec, Awaitable[F_Return]], func) + + func = cast(Callable[F_Spec, F_Return], func) @wraps(func) async def to_async_wrapper(*args: F_Spec.args, **kwargs: F_Spec.kwargs) -> F_Return: @@ -63,16 +63,6 @@ async def to_async_wrapper(*args: F_Spec.args, **kwargs: F_Spec.kwargs) -> F_Ret return to_async_wrapper -def timeout_scope( - timeout: Optional[float] = 30, - raise_timeout: bool = False, -) -> AbstractContextManager[anyio.CancelScope]: - scope: Callable[[Optional[float]], AbstractContextManager[anyio.CancelScope]] - scope = anyio.fail_after if raise_timeout else anyio.move_on_after - - return scope(timeout) - - @asynccontextmanager async def fake_context(*args: Any, **kwargs: Any) -> AsyncIterator[None]: yield None diff --git a/faststream/asgi/app.py b/faststream/asgi/app.py index 02fe0cfb78..c815454401 100644 --- a/faststream/asgi/app.py +++ b/faststream/asgi/app.py @@ -183,6 +183,8 @@ def load(self) -> "ASGIApp": elif port is not None: bindings.append(f"127.0.0.1:{port}") + run_extra_options["workers"] = int(run_extra_options.pop("workers", 1)) + bind = run_extra_options.get("bind") if isinstance(bind, list): bindings.extend(bind) diff --git a/faststream/confluent/client.py b/faststream/confluent/client.py index bd1736ddb5..3c1cc2b8cf 100644 --- a/faststream/confluent/client.py +++ b/faststream/confluent/client.py @@ -311,6 +311,9 @@ def __init__( self.config = final_config self.consumer = Consumer(final_config) + # We can't read and close consumer concurrently + self._lock = asyncio.Lock() + @property def topics_to_create(self) -> list[str]: return list({*self.topics, *(p.topic for p in self.partitions)}) @@ -368,12 +371,14 @@ async def stop(self) -> None: exc_info=e, ) - # Wrap calls to async to make method cancelable by timeout - await call_or_await(self.consumer.close) + async with self._lock: + # Wrap calls to async to make method cancelable by timeout + await call_or_await(self.consumer.close) async def getone(self, timeout: float = 0.1) -> Optional[Message]: """Consumes a single message from Kafka.""" - msg = await call_or_await(self.consumer.poll, timeout) + async with self._lock: + msg = await call_or_await(self.consumer.poll, timeout) return check_msg_error(msg) async def getmany( @@ -382,11 +387,12 @@ async def getmany( max_records: Optional[int] = 10, ) -> tuple[Message, ...]: """Consumes a batch of messages from Kafka and groups them by topic and partition.""" - raw_messages: list[Optional[Message]] = await call_or_await( - self.consumer.consume, # type: ignore[arg-type] - num_messages=max_records or 10, - timeout=timeout, - ) + async with self._lock: + raw_messages: list[Optional[Message]] = await call_or_await( + self.consumer.consume, # type: ignore[arg-type] + num_messages=max_records or 10, + timeout=timeout, + ) return tuple(x for x in map(check_msg_error, raw_messages) if x is not None) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 1c46aec7fb..5d44941a22 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -1,5 +1,4 @@ -import asyncio -from abc import ABC, abstractmethod +from abc import abstractmethod from collections.abc import Iterable, Sequence from typing import ( TYPE_CHECKING, @@ -12,6 +11,7 @@ from confluent_kafka import KafkaException, Message from typing_extensions import override +from faststream._internal.subscriber.mixins import TasksMixin from faststream._internal.subscriber.usecase import SubscriberUsecase from faststream._internal.subscriber.utils import process_msg from faststream._internal.types import MsgType @@ -35,7 +35,7 @@ from faststream.message import StreamMessage -class LogicSubscriber(ABC, SubscriberUsecase[MsgType]): +class LogicSubscriber(TasksMixin, SubscriberUsecase[MsgType]): """A class to handle logic for consuming messages from Kafka.""" topics: Sequence[str] @@ -45,7 +45,6 @@ class LogicSubscriber(ABC, SubscriberUsecase[MsgType]): consumer: Optional["AsyncConfluentConsumer"] parser: AsyncConfluentParser - task: Optional["asyncio.Task[None]"] client_id: Optional[str] def __init__( @@ -81,7 +80,6 @@ def __init__( self.partitions = partitions self.consumer = None - self.task = None self.polling_interval = polling_interval # Setup it later @@ -130,19 +128,14 @@ async def start(self) -> None: await super().start() if self.calls: - self.task = asyncio.create_task(self._consume()) + self.add_task(self._consume()) async def close(self) -> None: - await super().close() - if self.consumer is not None: await self.consumer.stop() self.consumer = None - if self.task is not None and not self.task.done(): - self.task.cancel() - - self.task = None + await super().close() @override async def get_one( diff --git a/faststream/confluent/testing.py b/faststream/confluent/testing.py index 92676d6e7a..5f4d8711a8 100644 --- a/faststream/confluent/testing.py +++ b/faststream/confluent/testing.py @@ -54,7 +54,7 @@ async def _fake_connect( # type: ignore[override] @staticmethod def create_publisher_fake_subscriber( broker: KafkaBroker, - publisher: "SpecificationPublisher[Any]", + publisher: "SpecificationPublisher[Any, Any]", ) -> tuple["LogicSubscriber[Any]", bool]: sub: Optional[LogicSubscriber[Any]] = None for handler in broker._subscribers: diff --git a/faststream/kafka/testing.py b/faststream/kafka/testing.py index 96e0614183..fc442f9ee6 100755 --- a/faststream/kafka/testing.py +++ b/faststream/kafka/testing.py @@ -56,7 +56,7 @@ async def _fake_connect( # type: ignore[override] @staticmethod def create_publisher_fake_subscriber( broker: KafkaBroker, - publisher: "SpecificationPublisher[Any]", + publisher: "SpecificationPublisher[Any, Any]", ) -> tuple["LogicSubscriber[Any]", bool]: sub: Optional[LogicSubscriber[Any]] = None for handler in broker._subscribers: diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index c54559782a..f1c9bb880f 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -1,4 +1,3 @@ -import asyncio import math from abc import abstractmethod from collections.abc import Awaitable, Iterable, Sequence @@ -19,6 +18,7 @@ from redis.exceptions import ResponseError from typing_extensions import TypeAlias, override +from faststream._internal.subscriber.mixins import TasksMixin from faststream._internal.subscriber.usecase import SubscriberUsecase from faststream._internal.subscriber.utils import process_msg from faststream.middlewares import AckPolicy @@ -61,7 +61,7 @@ Offset: TypeAlias = bytes -class LogicSubscriber(SubscriberUsecase[UnifyRedisDict]): +class LogicSubscriber(TasksMixin, SubscriberUsecase[UnifyRedisDict]): """A class to represent a Redis handler.""" _client: Optional["Redis[bytes]"] @@ -88,7 +88,6 @@ def __init__( ) self._client = None - self.task: Optional[asyncio.Task[None]] = None @override def _setup( # type: ignore[override] @@ -128,7 +127,7 @@ async def start( self, *args: Any, ) -> None: - if self.task: + if self.tasks: return await super().start() @@ -136,9 +135,7 @@ async def start( start_signal = anyio.Event() if self.calls: - self.task = asyncio.create_task( - self._consume(*args, start_signal=start_signal), - ) + self.add_task(self._consume(*args, start_signal=start_signal)) with anyio.fail_after(3.0): await start_signal.wait() @@ -171,13 +168,6 @@ async def _consume(self, *args: Any, start_signal: anyio.Event) -> None: async def _get_msgs(self, *args: Any) -> None: raise NotImplementedError - async def close(self) -> None: - await super().close() - - if self.task is not None and not self.task.done(): - self.task.cancel() - self.task = None - @staticmethod def build_log_context( message: Optional["BrokerStreamMessage[Any]"], @@ -351,7 +341,7 @@ async def _consume( # type: ignore[override] @override async def start(self) -> None: - if self.task: + if self.tasks: return assert self._client, "You should setup subscriber at first." # nosec B101 @@ -523,7 +513,7 @@ def get_log_context( @override async def start(self) -> None: - if self.task: + if self.tasks: return assert self._client, "You should setup subscriber at first." # nosec B101 diff --git a/faststream/specification/proto/endpoint.py b/faststream/specification/proto/endpoint.py index 380acb1071..b0991d43f8 100644 --- a/faststream/specification/proto/endpoint.py +++ b/faststream/specification/proto/endpoint.py @@ -1,10 +1,13 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, Generic, Optional, TypeVar +from faststream._internal.proto import EndpointWrapper +from faststream._internal.types import MsgType + T = TypeVar("T") -class EndpointSpecification(ABC, Generic[T]): +class EndpointSpecification(EndpointWrapper[MsgType], Generic[MsgType, T]): """A class representing an asynchronous API operation: Pub or Sub.""" title_: Optional[str] diff --git a/tests/cli/test_run_asgi.py b/tests/cli/test_run_asgi.py index 49825f932b..5920e74d1c 100644 --- a/tests/cli/test_run_asgi.py +++ b/tests/cli/test_run_asgi.py @@ -34,8 +34,15 @@ def test_run_as_asgi(runner: CliRunner) -> None: assert result.exit_code == 0 -@pytest.mark.parametrize("workers", (pytest.param(1), pytest.param(2), pytest.param(5))) -def test_run_as_asgi_with_workers(runner: CliRunner, workers: int) -> None: +@pytest.mark.parametrize( + "workers", + ( + pytest.param("1"), + pytest.param("2"), + pytest.param("5"), + ), +) +def test_run_as_asgi_with_workers(runner: CliRunner, workers: str) -> None: app = AsgiFastStream(AsyncMock()) app.run = AsyncMock() @@ -53,10 +60,10 @@ def test_run_as_asgi_with_workers(runner: CliRunner, workers: int) -> None: "--port", "8000", "-w", - str(workers), + workers, ], ) - extra = {"workers": workers} if workers > 1 else {} + extra = {"workers": workers} if int(workers) > 1 else {} app.run.assert_awaited_once_with( logging.INFO,