From 90777a56b0bbd25e6c348276bd99277cd94bed74 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 10 Jan 2024 20:53:45 +0300 Subject: [PATCH 01/87] refactor: new subscriber logic in NATS --- faststream/app.py | 215 ++------ faststream/broker/core/abc.py | 75 +-- faststream/broker/core/asynchronous.py | 57 +-- faststream/broker/fastapi/route.py | 7 - faststream/broker/fastapi/router.py | 32 -- faststream/broker/handler.py | 479 +++++++++++------- faststream/broker/middlewares.py | 5 +- faststream/broker/parsers.py | 24 +- faststream/broker/publisher.py | 3 +- faststream/broker/router.py | 4 - faststream/broker/test.py | 20 +- faststream/broker/types.py | 1 + faststream/broker/utils.py | 2 - faststream/broker/wrapper.py | 13 - faststream/kafka/broker.py | 12 - faststream/kafka/handler.py | 4 +- faststream/nats/broker.py | 91 +--- faststream/nats/broker.pyi | 8 +- faststream/nats/fastapi.pyi | 1 + faststream/nats/handler.py | 96 +++- faststream/nats/parser.py | 20 +- faststream/nats/router.pyi | 1 + faststream/nats/shared/router.pyi | 2 + faststream/rabbit/broker.py | 12 - faststream/rabbit/handler.py | 4 +- faststream/redis/handler.py | 4 +- faststream/utils/classes.py | 10 +- faststream/utils/context/repository.py | 4 +- pyproject.toml | 2 +- tests/asyncapi/base/arguments.py | 2 +- tests/brokers/base/consume.py | 4 +- tests/brokers/base/rpc.py | 4 +- tests/cli/test_app.py | 2 +- .../getting_started/routers/test_delay.py | 8 +- .../router/test_delay_registration.py | 2 +- tests/marks.py | 2 +- 36 files changed, 550 insertions(+), 682 deletions(-) diff --git a/faststream/app.py b/faststream/app.py index c52aee2ef7..950379acdc 100644 --- a/faststream/app.py +++ b/faststream/app.py @@ -1,5 +1,4 @@ import logging -from abc import ABC from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union import anyio @@ -20,7 +19,7 @@ from faststream.broker.core.asynchronous import BrokerAsyncUsecase from faststream.cli.supervisors.utils import HANDLED_SIGNALS from faststream.log import logger -from faststream.types import AnyCallable, AnyDict, AsyncFunc, Lifespan, SettingField +from faststream.types import AnyDict, AsyncFunc, Lifespan, SettingField from faststream.utils import apply_types, context from faststream.utils.functions import drop_response_type, fake_context, to_async @@ -28,149 +27,7 @@ T_HookReturn = TypeVar("T_HookReturn") -class ABCApp(ABC): - """A class representing an ABC App. - - Attributes: - _on_startup_calling : List of callable functions to be called on startup - _after_startup_calling : List of callable functions to be called after startup - _on_shutdown_calling : List of callable functions to be called on shutdown - _after_shutdown_calling : List of callable functions to be called after shutdown - broker : Optional broker object - logger : Optional logger object - title : Title of the app - version : Version of the app - description : Description of the app - terms_of_service : Optional terms of service URL - license : Optional license information - contact : Optional contact information - identifier : Optional identifier - asyncapi_tags : Optional list of tags - external_docs : Optional external documentation - - Methods: - set_broker : Set the broker object - on_startup : Add a hook to be run before the broker is connected - on_shutdown : Add a hook to be run before the broker is disconnected - after_startup : Add a hook to be run after the broker is connected - after_shutdown : Add a hook to be run after the broker is disconnected - _log : Log a message at a specified - """ - - _on_startup_calling: List[AnyCallable] - _after_startup_calling: List[AnyCallable] - _on_shutdown_calling: List[AnyCallable] - _after_shutdown_calling: List[AnyCallable] - - def __init__( - self, - broker: Optional[BrokerAsyncUsecase[Any, Any]] = None, - logger: Optional[logging.Logger] = logger, - # AsyncAPI information - title: str = "FastStream", - version: str = "0.1.0", - description: str = "", - terms_of_service: Optional[AnyHttpUrl] = None, - license: Optional[Union[License, LicenseDict, AnyDict]] = None, - contact: Optional[Union[Contact, ContactDict, AnyDict]] = None, - identifier: Optional[str] = None, - tags: Optional[Sequence[Union[Tag, TagDict, AnyDict]]] = None, - external_docs: Optional[Union[ExternalDocs, ExternalDocsDict, AnyDict]] = None, - ) -> None: - """Initialize an instance of the class. - - Args: - broker: An optional instance of the BrokerAsyncUsecase class. - logger: An optional instance of the logging.Logger class. - title: A string representing the title of the AsyncAPI. - version: A string representing the version of the AsyncAPI. - description: A string representing the description of the AsyncAPI. - terms_of_service: An optional URL representing the terms of service of the AsyncAPI. - license: An optional instance of the License class. - contact: An optional instance of the Contact class. - identifier: An optional string representing the identifier of the AsyncAPI. - tags: An optional sequence of Tag instances. - external_docs: An optional instance of the ExternalDocs class. - """ - self.broker = broker - self.logger = logger - self.context = context - context.set_global("app", self) - - self._on_startup_calling = [] - self._after_startup_calling = [] - self._on_shutdown_calling = [] - self._after_shutdown_calling = [] - - # AsyncAPI information - self.title = title - self.version = version - self.description = description - self.terms_of_service = terms_of_service - self.license = license - self.contact = contact - self.identifier = identifier - self.asyncapi_tags = tags - self.external_docs = external_docs - - def set_broker(self, broker: BrokerAsyncUsecase[Any, Any]) -> None: - """Set already existed App object broker. - - Useful then you create/init broker in `on_startup` hook. - """ - self.broker = broker - - def on_startup( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running BEFORE broker connected. - - This hook also takes an extra CLI options as a kwargs. - """ - self._on_startup_calling.append(apply_types(func)) - return func - - def on_shutdown( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running BEFORE broker disconnected.""" - self._on_shutdown_calling.append(apply_types(func)) - return func - - def after_startup( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running AFTER broker connected.""" - self._after_startup_calling.append(apply_types(func)) - return func - - def after_shutdown( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running AFTER broker disconnected.""" - self._after_shutdown_calling.append(apply_types(func)) - return func - - def _log(self, level: int, message: str) -> None: - """Logs a message with the specified log level. - - Args: - level (int): The log level. - message (str): The message to be logged. - - Returns: - None - - """ - if self.logger is not None: - self.logger.log(level, message) - - -class FastStream(ABCApp): +class FastStream: """A class representing a FastStream application. Attributes: @@ -232,19 +89,15 @@ def __init__( tags: application tags - for AsyncAPI docs external_docs: application external docs - for AsyncAPI docs """ - super().__init__( - broker=broker, - logger=logger, - title=title, - version=version, - description=description, - terms_of_service=terms_of_service, - license=license, - contact=contact, - identifier=identifier, - tags=tags, - external_docs=external_docs, - ) + self.broker = broker + self.logger = logger + self.context = context + context.set_global("app", self) + + self._on_startup_calling = [] + self._after_startup_calling = [] + self._on_shutdown_calling = [] + self._after_shutdown_calling = [] self.lifespan_context = ( apply_types( @@ -255,6 +108,24 @@ def __init__( else fake_context ) + # AsyncAPI information + self.title = title + self.version = version + self.description = description + self.terms_of_service = terms_of_service + self.license = license + self.contact = contact + self.identifier = identifier + self.asyncapi_tags = tags + self.external_docs = external_docs + + def set_broker(self, broker: BrokerAsyncUsecase[Any, Any]) -> None: + """Set already existed App object broker. + + Useful then you create/init broker in `on_startup` hook. + """ + self.broker = broker + def on_startup( self, func: Callable[P_HookParams, T_HookReturn], @@ -269,7 +140,7 @@ def on_startup( Returns: Async version of the func argument """ - super().on_startup(to_async(func)) + self._on_startup_calling.append(apply_types(to_async(func))) return func def on_shutdown( @@ -284,7 +155,7 @@ def on_shutdown( Returns: Async version of the func argument """ - super().on_shutdown(to_async(func)) + self._on_shutdown_calling.append(apply_types(to_async(func))) return func def after_startup( @@ -299,7 +170,7 @@ def after_startup( Returns: Async version of the func argument """ - super().after_startup(to_async(func)) + self._after_startup_calling.append(apply_types(to_async(func))) return func def after_shutdown( @@ -314,7 +185,7 @@ def after_shutdown( Returns: Async version of the func argument """ - super().after_shutdown(to_async(func)) + self._after_shutdown_calling.append(apply_types(to_async(func))) return func async def run( @@ -366,6 +237,8 @@ async def _start( async def _stop(self, log_level: int = logging.INFO) -> None: """Stop the application gracefully. + Blocking method (waits for SIGINT/SIGTERM). + Args: log_level (int): log level for logging messages (default: logging.INFO) @@ -398,6 +271,11 @@ async def _startup(self, **run_extra_options: SettingField) -> None: await func() async def _shutdown(self) -> None: + """Executes shutdown tasks. + + Returns: + None + """ for func in self._on_shutdown_calling: await func() @@ -406,3 +284,16 @@ async def _shutdown(self) -> None: for func in self._after_shutdown_calling: await func() + + def _log(self, level: int, message: str) -> None: + """Logs a message with the specified log level. + + Args: + level (int): The log level. + message (str): The message to be logged. + + Returns: + None + """ + if self.logger is not None: + self.logger.log(level, message) diff --git a/faststream/broker/core/abc.py b/faststream/broker/core/abc.py index fabe32b3cd..97439d586a 100644 --- a/faststream/broker/core/abc.py +++ b/faststream/broker/core/abc.py @@ -6,7 +6,6 @@ from types import TracebackType from typing import ( Any, - AsyncContextManager, Awaitable, Callable, Generic, @@ -43,7 +42,6 @@ MsgType, P_HandlerParams, T_HandlerReturn, - WrappedReturn, ) from faststream.broker.utils import ( change_logger_handlers, @@ -172,7 +170,7 @@ def __init__( midd_args: Sequence[Callable[[MsgType], BaseMiddleware]] = ( middlewares or empty_middleware ) - self.middlewares = [CriticalLogMiddleware(logger, log_level), *midd_args] + self.middlewares = (CriticalLogMiddleware(logger, log_level), *midd_args) self.dependencies = dependencies self._connection_args = (url, *args) @@ -280,15 +278,25 @@ def _wrap_handler( NotImplementedError: If silent animals are not supported. """ + final_extra_deps = tuple(chain(extra_dependencies, self.dependencies)) + build_dep = cast( - Callable[[Callable[F_Spec, F_Return]], CallModel[F_Spec, F_Return]], - _get_dependant or partial(build_call_model, cast=self._is_validate), + Callable[ + [Callable[F_Spec, F_Return]], + CallModel[F_Spec, F_Return], + ], + _get_dependant + or partial( + build_call_model, + cast=self._is_validate, + ), ) if isinstance(func, HandlerCallWrapper): handler_call, func = func, func._original_call if handler_call._wrapped_call is not None: return handler_call, build_dep(func) + else: handler_call = HandlerCallWrapper(func) @@ -296,27 +304,28 @@ def _wrap_handler( dependant = build_dep(f) - extra = [ - build_dep(d.dependency) - for d in chain(extra_dependencies, self.dependencies) - ] - + extra = [build_dep(d.dependency) for d in final_extra_deps] extend_dependencies(extra, dependant) if getattr(dependant, "flat_params", None) is None: # handle FastAPI Dependant dependant = _patch_fastapi_dependant(dependant) + params = () + + else: + params = set( + chain( + dependant.flat_params.keys(), + *(d.flat_params.keys() for d in extra), + ) + ) if self._is_apply_types and not _raw: - f = apply_types(None, cast=self._is_validate)(f, dependant) # type: ignore[arg-type,assignment] + f = apply_types(None)(f, dependant) # type: ignore[arg-type,assignment] decode_f = self._wrap_decode_message( func=f, _raw=_raw, - params=set( - chain( - dependant.flat_params.keys(), *(d.flat_params.keys() for d in extra) - ) - ), + params=params, ) process_f = self._process_message( @@ -338,9 +347,6 @@ def _abc_start(self) -> None: if not self.started: self.started = True - for h in self.handlers.values(): - h.global_middlewares = (*self.middlewares, *h.global_middlewares) - if self.logger is not None: change_logger_handlers(self.logger, self.fmt) @@ -359,7 +365,6 @@ def _abc_close( Returns: None - """ self.started = False @@ -381,34 +386,9 @@ def _abc__close( Note: This is an abstract method and must be implemented by subclasses. - """ self._connection = None - @abstractmethod - def _process_message( - self, - func: Callable[[StreamMessage[MsgType]], Awaitable[T_HandlerReturn]], - watcher: Callable[..., AsyncContextManager[None]], - **kwargs: Any, - ) -> Callable[[StreamMessage[MsgType]], Awaitable[WrappedReturn[T_HandlerReturn]]]: - """Processes a message using a given function and watcher. - - Args: - func: A callable that takes a StreamMessage of type MsgType and returns an Awaitable of type T_HandlerReturn. - watcher: An instance of BaseWatcher. - disable_watcher: Whether to use watcher context. - kwargs: Additional keyword arguments. - - Returns: - A callable that takes a StreamMessage of type MsgType and returns an Awaitable of type WrappedReturn[T_HandlerReturn]. - - Raises: - NotImplementedError: If the method is not implemented. - - """ - raise NotImplementedError() - @abstractmethod def subscriber( # type: ignore[return] self, @@ -457,7 +437,6 @@ def subscriber( # type: ignore[return] Raises: RuntimeWarning: If the broker is already running. - """ if self.started and not is_test_env(): # pragma: no cover warnings.warn( @@ -484,7 +463,6 @@ def publisher( Raises: NotImplementedError: If the method is not implemented. - """ self._publishers = {**self._publishers, key: publisher} return publisher @@ -508,7 +486,6 @@ def _wrap_decode_message( Raises: NotImplementedError: If the method is not implemented. - """ raise NotImplementedError() @@ -524,7 +501,6 @@ def extend_dependencies( Returns: The updated function or FastAPI dependency. - """ if isinstance(dependant, CallModel): dependant.extra_dependencies = (*dependant.extra_dependencies, *extra) @@ -543,7 +519,6 @@ def _patch_fastapi_dependant( Returns: The patched dependant. - """ params = dependant.query_params + dependant.body_params # type: ignore[attr-defined] diff --git a/faststream/broker/core/asynchronous.py b/faststream/broker/core/asynchronous.py index ab1519be7a..708e93ccdc 100644 --- a/faststream/broker/core/asynchronous.py +++ b/faststream/broker/core/asynchronous.py @@ -4,7 +4,6 @@ from types import TracebackType from typing import ( Any, - AsyncContextManager, Awaitable, Callable, Mapping, @@ -22,7 +21,7 @@ from typing_extensions import Self, override from faststream.broker.core.abc import BrokerUsecase -from faststream.broker.handler import AsyncHandler +from faststream.broker.handler import BaseHandler from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( @@ -35,7 +34,6 @@ MsgType, P_HandlerParams, T_HandlerReturn, - WrappedReturn, ) from faststream.broker.wrapper import HandlerCallWrapper from faststream.log import access_logger @@ -72,10 +70,9 @@ class BrokerAsyncUsecase(BrokerUsecase[MsgType, ConnectionType]): close(exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None, exec_tb: Optional[TracebackType] = None) : Close the connection to the broker. _process_message(func: Callable[[StreamMessage[MsgType]], Awaitable[T_HandlerReturn]], watcher: BaseWatcher) : Abstract method to process a message. publish(message: SendableMessage, *args: Any, reply_to: str = "", rpc: bool = False, rpc_timeout: Optional[float] - """ - handlers: Mapping[Any, AsyncHandler[MsgType]] + handlers: Mapping[Any, BaseHandler[MsgType]] middlewares: Sequence[Callable[[MsgType], BaseMiddleware]] _global_parser: Optional[AsyncCustomParser[MsgType, StreamMessage[MsgType]]] _global_decoder: Optional[AsyncCustomDecoder[StreamMessage[MsgType]]] @@ -85,8 +82,8 @@ async def start(self) -> None: """Start the broker async use case.""" super()._abc_start() for h in self.handlers.values(): - for f, _, _, _, _, _ in h.calls: - f.refresh(with_mock=False) + for f in h.calls: + f.handler.refresh(with_mock=False) await self.connect() @abstractmethod @@ -101,7 +98,6 @@ async def _connect(self, **kwargs: Any) -> ConnectionType: Raises: NotImplementedError: If the method is not implemented. - """ raise NotImplementedError() @@ -121,7 +117,6 @@ async def _close( Returns: None - """ super()._abc__close(exc_type, exc_val, exec_tb) @@ -143,7 +138,6 @@ async def close( Raises: NotImplementedError: If the method is not implemented. - """ super()._abc_close(exc_type, exc_val, exec_tb) @@ -153,34 +147,6 @@ async def close( if self._connection is not None: await self._close(exc_type, exc_val, exec_tb) - @override - @abstractmethod - def _process_message( - self, - func: Callable[[StreamMessage[MsgType]], Awaitable[T_HandlerReturn]], - watcher: Callable[..., AsyncContextManager[None]], - **kwargs: Any, - ) -> Callable[ - [StreamMessage[MsgType]], - Awaitable[WrappedReturn[T_HandlerReturn]], - ]: - """Process a message. - - Args: - func: A callable function that takes a StreamMessage and returns an Awaitable. - watcher: An instance of BaseWatcher. - disable_watcher: Whether to use watcher context. - kwargs: Additional keyword arguments. - - Returns: - A callable function that takes a StreamMessage and returns an Awaitable. - - Raises: - NotImplementedError: If the method is not implemented. - - """ - raise NotImplementedError() - @abstractmethod async def publish( self, @@ -208,7 +174,6 @@ async def publish( Raises: NotImplementedError: If the method is not implemented. - """ raise NotImplementedError() @@ -254,7 +219,6 @@ def subscriber( # type: ignore[override,return] Raises: NotImplementedError: If silent animals are not supported. - """ super().subscriber() @@ -288,7 +252,6 @@ def __init__( middlewares: Sequence of middlewares graceful_timeout: Graceful timeout **kwargs: Keyword arguments - """ super().__init__( *args, @@ -320,7 +283,6 @@ async def connect(self, *args: Any, **kwargs: Any) -> ConnectionType: Returns: The connection object. - """ if self._connection is None: _kwargs = self._resolve_connection_kwargs(*args, **kwargs) @@ -350,7 +312,6 @@ async def __aexit__( Overrides: This method overrides the __aexit__ method of the base class. - """ await self.close(exc_type, exc_val, exec_tb) @@ -370,10 +331,6 @@ def _wrap_decode_message( Returns: The wrapped function. - - Raises: - AssertionError: If the code reaches an unreachable state. - """ params_ln = len(params) @@ -386,12 +343,6 @@ async def decode_wrapper(message: StreamMessage[MsgType]) -> T_HandlerReturn: Returns: The return value of the handler function - - Raises: - AssertionError: If the code reaches an unreachable state - !!! note - - The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ if _raw is True: return await func(message) diff --git a/faststream/broker/fastapi/route.py b/faststream/broker/fastapi/route.py index 9b533fa4a1..f6564a41f4 100644 --- a/faststream/broker/fastapi/route.py +++ b/faststream/broker/fastapi/route.py @@ -162,7 +162,6 @@ def __init__( _headers: A dictionary to store the headers of the request. _body: A dictionary to store the body of the request. _query_params: A dictionary to store the query parameters of the request. - """ self._headers = headers self._body = body @@ -188,9 +187,6 @@ def get_session( Raises: AssertionError: If the dependant call is not defined. - - Note: - This function is used to create a session for handling requests. It takes a dependant object, which represents the session, and a dependency overrides provider, which allows for overriding dependencies. It returns a callable that takes a native message and returns an awaitable sendable message. The session is created based on the dependant object and the message passed to the callable. The session is then used to call the function obtained from the dependant object, and the result is returned. """ assert dependant.call # nosec B101 @@ -217,9 +213,6 @@ async def app(message: NativeMessage[Any]) -> SendableMessage: Raises: TypeError: If the body of the message is not a dictionary - !!! note - - The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ body = message.decoded_body diff --git a/faststream/broker/fastapi/router.py b/faststream/broker/fastapi/router.py index e182746125..be883bac64 100644 --- a/faststream/broker/fastapi/router.py +++ b/faststream/broker/fastapi/router.py @@ -68,7 +68,6 @@ class StreamRouter(APIRouter, Generic[MsgType]): asyncapi_router : create an APIRouter for AsyncAPI documentation include_router : include another router in the StreamRouter _setup_log_context : setup log context for the broker - """ broker_class: Type[BrokerAsyncUsecase[MsgType, Any]] @@ -138,7 +137,6 @@ def __init__( asyncapi_tags: Optional sequence of asyncapi tags for the class schema schema_url: Optional URL for the class schema **connection_kwars: Additional keyword arguments for the connection - """ assert ( # nosec B101 self.broker_class @@ -211,7 +209,6 @@ def add_api_mq_route( Returns: The handler call wrapper for the route. - """ route: StreamRoute[MsgType, P_HandlerParams, T_HandlerReturn] = StreamRoute( path, @@ -245,10 +242,6 @@ def subscriber( Returns: A callable decorator that adds the decorated function as an endpoint for the specified path. - - Raises: - NotImplementedError: If silent animals are not supported. - """ current_dependencies = self.dependencies.copy() if dependencies: @@ -264,9 +257,6 @@ def decorator( Returns: The decorated function. - !!! note - - The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ return self.add_api_mq_route( path, @@ -286,10 +276,6 @@ def wrap_lifespan(self, lifespan: Optional[Lifespan[Any]] = None) -> Lifespan[An Returns: The wrapped lifespan object. - - Raises: - NotImplementedError: If silent animals are not supported. - """ lifespan_context = lifespan if lifespan is not None else _DefaultLifespan(self) @@ -304,12 +290,6 @@ async def start_broker_lifespan( Yields: AsyncIterator[Mapping[str, Any]]: A mapping of context information during the lifespan of the broker. - - Raises: - NotImplementedError: If silent animals are not supported. - !!! note - - The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ from faststream.asyncapi.generate import get_app_schema @@ -361,7 +341,6 @@ def after_startup( Returns: A decorated function that takes an `AppType` argument and returns a mapping of strings to any type. - """ ... @@ -380,7 +359,6 @@ def after_startup( Note: This function can be used as a decorator for other functions. - """ ... @@ -396,7 +374,6 @@ def after_startup( Returns: A decorated function that will be executed after startup. - """ ... @@ -412,7 +389,6 @@ def after_startup( Returns: A decorated function that takes an `AppType` argument and returns an awaitable `None`. - """ ... @@ -437,7 +413,6 @@ def after_startup( Returns: The registered function. - """ self._after_startup_hooks.append(to_async(func)) # type: ignore return func @@ -457,7 +432,6 @@ def publisher( Returns: An instance of `BasePublisher` that can be used to publish messages to the specified queue. - """ return self.broker.publisher( queue, @@ -479,7 +453,6 @@ def asyncapi_router(self, schema_url: Optional[str]) -> Optional[APIRouter]: Notes: This function defines three nested functions: download_app_json_schema, download_app_yaml_schema, and serve_asyncapi_schema. These functions are used to handle different routes for serving the AsyncAPI schema and documentation. - """ if not self.include_in_schema or not schema_url: return None @@ -533,9 +506,6 @@ def serve_asyncapi_schema( Raises: AssertionError: If the schema is not available. - !!! note - - The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ assert ( # nosec B101 self.schema @@ -597,7 +567,6 @@ def include_router( deprecated (bool, optional): Whether the router is deprecated. Defaults to None. include_in_schema (bool, optional): Whether to include the router in the API schema. Defaults to True. generate_unique_id_function (Callable[[APIRoute], str], optional): The function to generate unique IDs for - """ if isinstance(router, StreamRouter): # pragma: no branch self._setup_log_context(self.broker, router.broker) @@ -637,6 +606,5 @@ def _setup_log_context( Raises: NotImplementedError: If the function is not implemented. - """ raise NotImplementedError() diff --git a/faststream/broker/handler.py b/faststream/broker/handler.py index c6f0623a3e..6e83642a3e 100644 --- a/faststream/broker/handler.py +++ b/faststream/broker/handler.py @@ -1,7 +1,10 @@ import asyncio from abc import abstractmethod from contextlib import AsyncExitStack, suppress +from dataclasses import dataclass +from functools import partial, wraps from inspect import unwrap +from logging import Logger from typing import ( TYPE_CHECKING, Any, @@ -10,6 +13,7 @@ Dict, Generic, List, + Mapping, Optional, Sequence, Tuple, @@ -18,15 +22,16 @@ ) import anyio -from fast_depends.core import CallModel +from fast_depends.core import CallModel, build_call_model +from fast_depends.dependencies import Depends from typing_extensions import Self, override from faststream._compat import IS_OPTIMIZED from faststream.asyncapi.base import AsyncAPIOperation from faststream.asyncapi.message import parse_handler_params from faststream.asyncapi.utils import to_camelcase -from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware +from faststream.broker.push_back_watcher import WatcherContext from faststream.broker.types import ( AsyncDecoder, AsyncParser, @@ -35,224 +40,291 @@ Filter, MsgType, P_HandlerParams, - SyncDecoder, - SyncParser, T_HandlerReturn, WrappedReturn, ) +from faststream.broker.utils import get_watcher, set_message_context from faststream.broker.wrapper import HandlerCallWrapper from faststream.exceptions import HandlerException, StopConsume from faststream.types import AnyDict, SendableMessage +from faststream.utils import apply_types from faststream.utils.context.repository import context -from faststream.utils.functions import to_async +from faststream.utils.functions import fake_context, to_async if TYPE_CHECKING: from contextvars import Token + from typing import Protocol, overload + + from faststream.broker.message import StreamMessage + + class WrapperProtocol(Generic[MsgType], Protocol): + @overload + def __call__( + self, + func: None = None, + *, + filter: Filter["StreamMessage[MsgType]"], + parser: CustomParser[MsgType, Any], + decoder: CustomDecoder["StreamMessage[MsgType]"], + middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), + dependencies: Sequence[Depends] = (), + ) -> Callable[ + [Callable[P_HandlerParams, T_HandlerReturn]], + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + ]: + ... + + @overload + def __call__( + self, + func: Callable[P_HandlerParams, T_HandlerReturn] = None, + *, + filter: Filter["StreamMessage[MsgType]"], + parser: CustomParser[MsgType, Any], + decoder: CustomDecoder["StreamMessage[MsgType]"], + middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), + dependencies: Sequence[Depends] = (), + ) -> HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]: + ... + + def __call__( + self, + func: Optional[Callable[P_HandlerParams, T_HandlerReturn]] = None, + *, + filter: Filter["StreamMessage[MsgType]"], + parser: CustomParser[MsgType, Any], + decoder: CustomDecoder["StreamMessage[MsgType]"], + middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), + dependencies: Sequence[Depends] = (), + ) -> Union[ + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + Callable[ + [Callable[P_HandlerParams, T_HandlerReturn]], + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + ], + ]: + ... + + +@dataclass(slots=True) +class HandlerItem(Generic[MsgType]): + """A class representing handler overloaded item.""" + + handler: HandlerCallWrapper[MsgType, Any, SendableMessage] + filter: Callable[["StreamMessage[MsgType]"], Awaitable[bool]] + parser: AsyncParser[MsgType, Any] + decoder: AsyncDecoder["StreamMessage[MsgType]"] + middlewares: Sequence[Callable[[Any], BaseMiddleware]] + dependant: CallModel[Any, SendableMessage] + @property + def call_name(self) -> str: + """Returns the name of the original call.""" + if self.handler is None: + return "" -class BaseHandler(AsyncAPIOperation, Generic[MsgType]): - """A base handler class for asynchronous API operations. + caller = unwrap(self.handler._original_call) + name = getattr(caller, "__name__", str(caller)) + return name - Attributes: - calls : List of tuples representing handler calls, filters, parsers, decoders, middlewares, and dependants. - global_middlewares : Sequence of global middlewares. + @property + def description(self) -> Optional[str]: + """Returns the description of original call.""" + if self.handler is None: + return None - Methods: - __init__ : Initializes the BaseHandler object. - name : Returns the name of the handler. - call_name : Returns the name of the handler call. - description : Returns the description of the handler. - consume : Abstract method to consume a message. + caller = unwrap(self.handler._original_call) + description = getattr(caller, "__doc__", None) + return description - Note: This class inherits from AsyncAPIOperation and is a generic class with type parameter MsgType. - """ +class BaseHandler(AsyncAPIOperation, Generic[MsgType]): + """A class representing an asynchronous handler. - calls: Union[ - List[ - Tuple[ - HandlerCallWrapper[MsgType, Any, SendableMessage], # handler - Callable[[StreamMessage[MsgType]], bool], # filter - SyncParser[MsgType, StreamMessage[MsgType]], # parser - SyncDecoder[StreamMessage[MsgType]], # decoder - Sequence[Callable[[Any], BaseMiddleware]], # middlewares - CallModel[Any, SendableMessage], # dependant - ] - ], - List[ - Tuple[ - HandlerCallWrapper[MsgType, Any, SendableMessage], # handler - Callable[[StreamMessage[MsgType]], Awaitable[bool]], # filter - AsyncParser[MsgType, StreamMessage[MsgType]], # parser - AsyncDecoder[StreamMessage[MsgType]], # decoder - Sequence[Callable[[Any], BaseMiddleware]], # middlewares - CallModel[Any, SendableMessage], # dependant - ] - ], - ] + Methods: + add_call : adds a new call to the list of calls + consume : consumes a message and returns a sendable message + start : starts the handler + close : closes the handler + """ - global_middlewares: Sequence[Callable[[Any], BaseMiddleware]] + calls: List[HandlerItem[MsgType]] def __init__( self, *, - log_context_builder: Callable[[StreamMessage[Any]], Dict[str, str]], - description: Optional[str] = None, - title: Optional[str] = None, - include_in_schema: bool = True, + log_context_builder: Callable[["StreamMessage[Any]"], Dict[str, str]], + middlewares: Sequence[Callable[[MsgType], BaseMiddleware]], + logger: Optional[Logger], + description: Optional[str], + title: Optional[str], + include_in_schema: bool, + graceful_timeout: Optional[float], ) -> None: - """Initialize a new instance of the class. - - Args: - log_context_builder: A callable that builds the log context. - description: Optional description of the instance. - title: Optional title of the instance. - include_in_schema: Whether to include the instance in the schema. - - """ - self.calls = [] # type: ignore[assignment] - self.global_middlewares = [] + """Initialize a new instance of the class.""" + self.calls = [] + self.middlewares = middlewares self.log_context_builder = log_context_builder + self.logger = logger self.running = False + self.lock = MultiLock() + self.graceful_timeout = graceful_timeout + # AsyncAPI information self._description = description self._title = title - self.include_in_schema = include_in_schema + super().__init__(include_in_schema=include_in_schema) - @property - def call_name(self) -> str: - """Returns the name of the handler call.""" - caller = unwrap(self.calls[0][0]._original_call) - name = getattr(caller, "__name__", str(caller)) - return to_camelcase(name) + def add_call( + self, + filter_: Filter["StreamMessage[MsgType]"], + parser_: CustomParser[MsgType, Any], + decoder_: CustomDecoder["StreamMessage[MsgType]"], + middlewares_: Sequence[Callable[[Any], BaseMiddleware]], + dependencies_: Sequence[Depends], + **wrap_kwargs: Any, + ) -> "WrapperProtocol[MsgType]": + def wrapper( + func: Optional[Callable[P_HandlerParams, T_HandlerReturn]] = None, + *, + filter: Filter["StreamMessage[MsgType]"] = filter_, + parser: CustomParser[MsgType, Any] = parser_, + decoder: CustomDecoder["StreamMessage[MsgType]"] = decoder_, + middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), + dependencies: Sequence[Depends] = (), + ) -> Union[ + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + Callable[ + [Callable[P_HandlerParams, T_HandlerReturn]], + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + ], + ]: + def real_wrapper( + func: Callable[P_HandlerParams, T_HandlerReturn], + ) -> HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]: + handler, dependant = self.wrap_handler( + func=func, + dependencies=(*dependencies_, *dependencies), + **wrap_kwargs, + ) + + self.calls.append( + HandlerItem( + handler=handler, + dependant=dependant, + filter=to_async(filter), + parser=to_async(parser), + decoder=to_async(decoder), + middlewares=(*middlewares_, *middlewares), + ) + ) - @property - def description(self) -> Optional[str]: - """Returns the description of the handler.""" - if not self.calls: # pragma: no cover - description = None + return handler - else: - caller = unwrap(self.calls[0][0]._original_call) - description = getattr(caller, "__doc__", None) + if func is None: + return real_wrapper - return self._description or description + else: + return real_wrapper(func) - @abstractmethod - def consume(self, msg: MsgType) -> SendableMessage: - """Consume a message. + return wrapper - Args: - msg: The message to be consumed. + def wrap_handler( + self, + *, + func: Callable[P_HandlerParams, T_HandlerReturn], + no_ack: bool, + is_validate: bool, + dependencies: Sequence[Depends], + raw: bool, + retry: int, + **process_kwargs: Any, + ) -> Tuple[ + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + CallModel[P_HandlerParams, T_HandlerReturn], + ]: + build_dep = partial( + build_call_model, + cast=is_validate, + extra_dependencies=dependencies, + ) - Returns: - The sendable message. + if isinstance(func, HandlerCallWrapper): + handler_call, func = func, func._original_call + if handler_call._wrapped_call is not None: + return handler_call, build_dep(func) - Raises: - NotImplementedError: If the method is not implemented. + else: + handler_call = HandlerCallWrapper(func) - """ - raise NotImplementedError() + f = to_async(func) + dependant = build_dep(f) - def get_payloads(self) -> List[Tuple[AnyDict, str]]: - """Get the payloads of the handler.""" - payloads: List[Tuple[AnyDict, str]] = [] + if not raw: + f = apply_types(None)(f, dependant) - for h, _, _, _, _, dep in self.calls: - body = parse_handler_params( - dep, prefix=f"{self._title or self.call_name}:Message" + f = self._wrap_decode_message( + func=f, + params_ln=len(dependant.flat_params), ) - payloads.append((body, to_camelcase(unwrap(h._original_call).__name__))) - - return payloads + f = self._process_message( + func=f, + watcher=( + partial(WatcherContext, watcher=get_watcher(self.logger, retry)) # type: ignore[arg-type] + if not no_ack + else fake_context + ), + **(process_kwargs or {}), + ) -class AsyncHandler(BaseHandler[MsgType]): - """A class representing an asynchronous handler. + f = set_message_context(f) + handler_call.set_wrapped(f) + return handler_call, dependant - Attributes: - calls : a list of tuples containing the following information: - - handler : the handler function - - filter : a callable that filters the stream message - - parser : an async parser for the message - - decoder : an async decoder for the message - - middlewares : a sequence of middlewares - - dependant : a call model for the handler + def _wrap_decode_message( + self, + func: Callable[..., Awaitable[T_HandlerReturn]], + params_ln: int, + ) -> Callable[ + ["StreamMessage[MsgType]"], + Awaitable[T_HandlerReturn], + ]: + """Wraps a function to decode a message and pass it as an argument to the wrapped function. - Methods: - add_call : adds a new call to the list of calls - consume : consumes a message and returns a sendable message - start : starts the handler - close : closes the handler + Args: + func: The function to be wrapped. + params_ln: The parameters number to be passed to the wrapped function. - """ + Returns: + The wrapped function. + """ - calls: List[ - Tuple[ - HandlerCallWrapper[MsgType, Any, SendableMessage], # handler - Callable[[StreamMessage[MsgType]], Awaitable[bool]], # filter - AsyncParser[MsgType, Any], # parser - AsyncDecoder[StreamMessage[MsgType]], # decoder - Sequence[Callable[[Any], BaseMiddleware]], # middlewares - CallModel[Any, SendableMessage], # dependant - ] - ] + @wraps(func) + async def decode_wrapper(message: "StreamMessage[MsgType]") -> T_HandlerReturn: + """A wrapper function to decode and handle a message. - def __init__( - self, - *, - log_context_builder: Callable[[StreamMessage[Any]], Dict[str, str]], - description: Optional[str] = None, - title: Optional[str] = None, - include_in_schema: bool = True, - graceful_timeout: Optional[float] = None, - ) -> None: - """Initialize a new instance of the class.""" - super().__init__( - log_context_builder=log_context_builder, - description=description, - title=title, - include_in_schema=include_in_schema, - ) - self.lock = MultiLock() - self.graceful_timeout = graceful_timeout + Args: + message : The message to be decoded and handled - def add_call( - self, - *, - handler: HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - parser: CustomParser[MsgType, Any], - decoder: CustomDecoder[Any], - dependant: CallModel[P_HandlerParams, T_HandlerReturn], - filter: Filter[StreamMessage[MsgType]], - middlewares: Optional[Sequence[Callable[[Any], BaseMiddleware]]], - ) -> None: - """Adds a call to the handler. + Returns: + The return value of the handler function + """ + msg = message.decoded_body - Args: - handler: The handler call wrapper. - parser: The custom parser. - decoder: The custom decoder. - dependant: The call model. - filter: The filter for stream messages. - middlewares: Optional sequence of middlewares. + if params_ln > 1: + if isinstance(msg, Mapping): + return await func(**msg) + elif isinstance(msg, Sequence): + return await func(*msg) + else: + return await func(msg) - Returns: - None + raise AssertionError("unreachable") - """ - self.calls.append( - ( - handler, - to_async(filter), - to_async(parser) if parser else None, # type: ignore[arg-type] - to_async(decoder) if decoder else None, # type: ignore[arg-type] - middlewares or (), - dependant, - ) - ) + return decode_wrapper @override async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[override] @@ -266,10 +338,6 @@ async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[overri Raises: StopConsume: If the consumption needs to be stopped. - - Raises: - Exception: If an error occurs during consumption. - """ result: Optional[WrappedReturn[SendableMessage]] = None result_msg: SendableMessage = None @@ -284,20 +352,20 @@ async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[overri stack.enter_context(context.scope("handler_", self)) gl_middlewares: List[BaseMiddleware] = [ - await stack.enter_async_context(m(msg)) for m in self.global_middlewares + await stack.enter_async_context(m(msg)) for m in self.middlewares ] logged = False processed = False - for handler, filter_, parser, decoder, middlewares, _ in self.calls: + for h in self.calls: local_middlewares: List[BaseMiddleware] = [ - await stack.enter_async_context(m(msg)) for m in middlewares + await stack.enter_async_context(m(msg)) for m in h.middlewares ] all_middlewares = gl_middlewares + local_middlewares # TODO: add parser & decoder caches - message = await parser(msg) + message = await h.parser(msg) if not logged: # pragma: no branch log_context_tag = context.set_local( @@ -305,10 +373,10 @@ async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[overri self.log_context_builder(message), ) - message.decoded_body = await decoder(message) + message.decoded_body = await h.decoder(message) message.processed = processed - if await filter_(message): + if await h.filter(message): assert ( # nosec B101 not processed ), "You can't process a message with multiple consumers" @@ -324,14 +392,14 @@ async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[overri result = await cast( Awaitable[Optional[WrappedReturn[SendableMessage]]], - handler.call_wrapped(message), + h.handler.call_wrapped(message), ) if result is not None: result_msg, pub_response = result # TODO: suppress all publishing errors and raise them after all publishers will be tried - for publisher in (pub_response, *handler._publishers): + for publisher in (pub_response, *h.handler._publishers): if publisher is not None: async with AsyncExitStack() as pub_stack: result_to_send = result_msg @@ -350,18 +418,18 @@ async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[overri except StopConsume: await self.close() - handler.trigger() + h.handler.trigger() except HandlerException as e: # pragma: no cover - handler.trigger() + h.handler.trigger() raise e except Exception as e: - handler.trigger(error=e) + h.handler.trigger(error=e) raise e else: - handler.trigger(result=result[0] if result else None) + h.handler.trigger(result=result[0] if result else None) message.processed = processed = True if IS_OPTIMIZED: # pragma: no cover break @@ -380,10 +448,48 @@ async def start(self) -> None: @abstractmethod async def close(self) -> None: - """Close the handler.""" + """Close the handler. + + Blocks loop up to graceful_timeout seconds. + """ self.running = False await self.lock.wait_release(self.graceful_timeout) + @property + def call_name(self) -> str: + """Returns the name of the handler call.""" + return to_camelcase(self.calls[0].call_name) + + @property + def description(self) -> Optional[str]: + """Returns the description of the handler.""" + if self._description: + return self._description + + if not self.calls: # pragma: no cover + return None + + else: + return self.calls[0].description + + def get_payloads(self) -> List[Tuple[AnyDict, str]]: + """Get the payloads of the handler.""" + payloads: List[Tuple[AnyDict, str]] = [] + + for h in self.calls: + body = parse_handler_params( + h.dependant, + prefix=f"{self._title or self.call_name}:Message", + ) + payloads.append( + ( + body, + to_camelcase(h.call_name), + ) + ) + + return payloads + class MultiLock: """A class representing a multi lock.""" @@ -414,7 +520,10 @@ def empty(self) -> bool: return self.queue.empty() async def wait_release(self, timeout: Optional[float] = None) -> None: - """Wait for the queue to be released.""" + """Wait for the queue to be released. + + Using for graceful shutdown. + """ if timeout: with anyio.move_on_after(timeout): await self.queue.join() diff --git a/faststream/broker/middlewares.py b/faststream/broker/middlewares.py index cb9bd3daa1..3497b54152 100644 --- a/faststream/broker/middlewares.py +++ b/faststream/broker/middlewares.py @@ -243,13 +243,11 @@ def __call__(self, msg: Any) -> Self: """ return self - async def on_consume(self, msg: DecodedMessage) -> DecodedMessage: + async def on_receive(self) -> DecodedMessage: if self.logger is not None: c = context.get_local("log_context") self.logger.log(self.log_level, "Received", extra=c) - return await super().on_consume(msg) - async def after_processed( self, exc_type: Optional[Type[BaseException]] = None, @@ -277,4 +275,5 @@ async def after_processed( ) self.logger.log(self.log_level, "Processed", extra=c) + return True diff --git a/faststream/broker/parsers.py b/faststream/broker/parsers.py index d2afca9f06..60f2feac4f 100644 --- a/faststream/broker/parsers.py +++ b/faststream/broker/parsers.py @@ -35,7 +35,6 @@ def decode_message(message: StreamMessage[Any]) -> DecodedMessage: Raises: JSONDecodeError: If the message body cannot be decoded as JSON. - """ body: Any = getattr(message, "body", message) m: DecodedMessage = body @@ -63,7 +62,6 @@ def encode_message( Returns: A tuple containing the encoded message as bytes and the content type of the message. - """ if msg is None: return b"", None @@ -93,7 +91,6 @@ def resolve_custom_func( Returns: A resolved function of type SyncDecoder - """ ... @@ -111,7 +108,6 @@ def resolve_custom_func( Returns: A resolved function of type SyncParser[MsgType]. - """ ... @@ -129,7 +125,6 @@ def resolve_custom_func( Returns: Resolved function. - """ ... @@ -147,7 +142,6 @@ def resolve_custom_func( Returns: Resolved function. - """ ... @@ -165,7 +159,6 @@ def resolve_custom_func( Returns: A decoder function. - """ ... @@ -183,17 +176,25 @@ def resolve_custom_func( Returns: Resolved function. - """ ... def resolve_custom_func( # type: ignore[misc] custom_func: Optional[ - Union[CustomDecoder[StreamMsg], CustomParser[MsgType, StreamMsg]] + Union[ + CustomDecoder[StreamMsg], + CustomParser[MsgType, StreamMsg], + ] ], - default_func: Union[Decoder[StreamMsg], Parser[MsgType, StreamMsg]], -) -> Union[Decoder[StreamMsg], Parser[MsgType, StreamMsg]]: + default_func: Union[ + Decoder[StreamMsg], + Parser[MsgType, StreamMsg], + ], +) -> Union[ + Decoder[StreamMsg], + Parser[MsgType, StreamMsg], +]: """Resolve a custom function. Args: @@ -202,7 +203,6 @@ def resolve_custom_func( # type: ignore[misc] Returns: The resolved function of type Decoder or Parser. - """ if custom_func is None: return default_func diff --git a/faststream/broker/publisher.py b/faststream/broker/publisher.py index 1913dd9f76..1003c377b7 100644 --- a/faststream/broker/publisher.py +++ b/faststream/broker/publisher.py @@ -76,7 +76,6 @@ def __call__( Raises: TypeError: If `func` is not callable. - """ handler_call: HandlerCallWrapper[ MsgType, P_HandlerParams, T_HandlerReturn @@ -89,6 +88,7 @@ def __call__( async def publish( self, message: SendableMessage, + *args: Any, correlation_id: Optional[str] = None, **kwargs: Any, ) -> Optional[SendableMessage]: @@ -104,7 +104,6 @@ async def publish( Raises: NotImplementedError: If the method is not implemented. - """ raise NotImplementedError() diff --git a/faststream/broker/router.py b/faststream/broker/router.py index fc3415a160..4f81e26485 100644 --- a/faststream/broker/router.py +++ b/faststream/broker/router.py @@ -247,7 +247,6 @@ def _wrap_subscriber( A callable object that wraps the decorated function This function is decorated with `@abstractmethod`, indicating that it is an abstract method and must be implemented by any subclass. - """ def router_subscriber_wrapper( @@ -260,9 +259,6 @@ def router_subscriber_wrapper( Returns: The wrapped function. - !!! note - - The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ wrapped_func: HandlerCallWrapper[ MsgType, P_HandlerParams, T_HandlerReturn diff --git a/faststream/broker/test.py b/faststream/broker/test.py index ea227612e7..244c58c0fb 100644 --- a/faststream/broker/test.py +++ b/faststream/broker/test.py @@ -11,7 +11,7 @@ from faststream.app import FastStream from faststream.broker.core.abc import BrokerUsecase from faststream.broker.core.asynchronous import BrokerAsyncUsecase -from faststream.broker.handler import AsyncHandler +from faststream.broker.handler import BaseHandler from faststream.broker.middlewares import CriticalLogMiddleware from faststream.broker.wrapper import HandlerCallWrapper from faststream.types import SendableMessage, SettingField @@ -197,10 +197,10 @@ def _fake_start(cls, broker: Broker, *args: Any, **kwargs: Any) -> None: if handler is not None: mock = MagicMock() p.set_test(mock=mock, with_fake=False) - for f, _, _, _, _, _ in handler.calls: - f.set_test() - assert f.mock # nosec B101 - f.mock.side_effect = mock + for h in handler.calls: + h.handler.set_test() + assert h.handler.mock # nosec B101 + h.handler.mock.side_effect = mock else: f = cls.create_publisher_fake_subscriber(broker, p) @@ -233,8 +233,8 @@ def _fake_close( for h in broker.handlers.values(): h.running = False - for f, _, _, _, _, _ in h.calls: - f.reset_test() + for h in h.calls: + h.handler.reset_test() @staticmethod @abstractmethod @@ -278,12 +278,12 @@ def patch_broker_calls(broker: BrokerUsecase[Any, Any]) -> None: broker._abc_start() for handler in broker.handlers.values(): - for f, _, _, _, _, _ in handler.calls: - f.set_test() + for h in handler.calls: + h.handler.set_test() async def call_handler( - handler: AsyncHandler[Any], + handler: BaseHandler[Any], message: Any, rpc: bool = False, rpc_timeout: Optional[float] = 30.0, diff --git a/faststream/broker/types.py b/faststream/broker/types.py index 151c0b8b77..16b0895ce3 100644 --- a/faststream/broker/types.py +++ b/faststream/broker/types.py @@ -80,6 +80,7 @@ class AsyncPublisherProtocol(Protocol): async def publish( self, message: SendableMessage, + *args: Any, correlation_id: Optional[str] = None, **kwargs: Any, ) -> Optional[SendableMessage]: diff --git a/faststream/broker/utils.py b/faststream/broker/utils.py index 4b22a07281..87bc32a2f5 100644 --- a/faststream/broker/utils.py +++ b/faststream/broker/utils.py @@ -75,7 +75,6 @@ def set_message_context( Returns: The function with the message context set. - """ @wraps(func) @@ -89,7 +88,6 @@ async def set_message_wrapper( Returns: The wrapped return value of the handler function. - """ with context.scope("message", message): return await func(message) diff --git a/faststream/broker/wrapper.py b/faststream/broker/wrapper.py index e8b22b399e..7b55385994 100644 --- a/faststream/broker/wrapper.py +++ b/faststream/broker/wrapper.py @@ -24,7 +24,6 @@ class FakePublisher: Methods: publish : asynchronously publishes a message with optional correlation ID and additional keyword arguments - """ def __init__(self, method: Callable[..., Awaitable[SendableMessage]]) -> None: @@ -32,7 +31,6 @@ def __init__(self, method: Callable[..., Awaitable[SendableMessage]]) -> None: Args: method: A callable that takes any number of arguments and returns an awaitable sendable message. - """ self.method = method @@ -51,7 +49,6 @@ async def publish( Returns: The published message. - """ return await self.method(message, correlation_id=correlation_id, **kwargs) @@ -73,7 +70,6 @@ class HandlerCallWrapper(Generic[MsgType, P_HandlerParams, T_HandlerReturn]): set_wrapped : Set the wrapped handler call call_wrapped : Call the wrapped handler wait_call : Wait for the handler call to complete - """ mock: Optional[MagicMock] @@ -110,7 +106,6 @@ def __new__( Note: If the "call" argument is already an instance of the class, it is returned as is. Otherwise, a new instance of the class is created using the superclass's __new__ method. - """ if isinstance(call, cls): return call @@ -131,8 +126,6 @@ def __init__( _wrapped_call: The wrapped handler function. _publishers: A list of publishers. mock: A MagicMock object. - __name__: The name of the handler function. - """ if not isinstance(call, HandlerCallWrapper): self._original_call = call @@ -185,11 +178,6 @@ def call_wrapped( Returns: The result of the wrapped function call. - - Raises: - AssertionError: If `set_wrapped` has not been called before calling this function. - AssertionError: If the broker has not been started before calling this function. - """ assert self._wrapped_call, "You should use `set_wrapped` first" # nosec B101 if self.is_test: @@ -208,7 +196,6 @@ async def wait_call(self, timeout: Optional[float] = None) -> None: Returns: None - """ assert ( # nosec B101 self.future is not None diff --git a/faststream/kafka/broker.py b/faststream/kafka/broker.py index dd47b15a3f..0a39c16a82 100644 --- a/faststream/kafka/broker.py +++ b/faststream/kafka/broker.py @@ -229,12 +229,6 @@ async def process_wrapper( Returns: WrappedReturn[T_HandlerReturn]: The wrapped return value. - - Raises: - AssertionError: If the code reaches an unreachable point. - !!! note - - The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ async with watcher(message): r = await func(message) @@ -433,12 +427,6 @@ def consumer_wrapper( Returns: The wrapped handler call. - - Raises: - NotImplementedError: If silent animals are not supported. - !!! note - - The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ handler_call, dependant = self._wrap_handler( func=func, diff --git a/faststream/kafka/handler.py b/faststream/kafka/handler.py index 87878648e9..fca5fa1ea5 100644 --- a/faststream/kafka/handler.py +++ b/faststream/kafka/handler.py @@ -9,7 +9,7 @@ from typing_extensions import Unpack, override from faststream.__about__ import __version__ -from faststream.broker.handler import AsyncHandler +from faststream.broker.handler import BaseHandler from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.parsers import resolve_custom_func @@ -26,7 +26,7 @@ from faststream.kafka.shared.schemas import ConsumerConnectionParams -class LogicHandler(AsyncHandler[ConsumerRecord]): +class LogicHandler(BaseHandler[ConsumerRecord]): """A class to handle logic for consuming messages from Kafka. Attributes: diff --git a/faststream/nats/broker.py b/faststream/nats/broker.py index c2f50a5975..6a149fc686 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker.py @@ -1,11 +1,10 @@ import logging import warnings -from functools import partial, wraps +from functools import partial from types import TracebackType from typing import ( + TYPE_CHECKING, Any, - AsyncContextManager, - Awaitable, Callable, Dict, List, @@ -32,18 +31,12 @@ from typing_extensions import TypeAlias, override from faststream.broker.core.asynchronous import BrokerAsyncUsecase, default_filter -from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( - AsyncPublisherProtocol, CustomDecoder, CustomParser, Filter, - P_HandlerParams, - T_HandlerReturn, - WrappedReturn, ) -from faststream.broker.wrapper import FakePublisher, HandlerCallWrapper from faststream.exceptions import NOT_CONNECTED_YET from faststream.nats.asyncapi import Handler, Publisher from faststream.nats.helpers import stream_builder @@ -59,6 +52,9 @@ Subject: TypeAlias = str +if TYPE_CHECKING: + from faststream.broker.handler import WrapperProtocol + class NatsBroker( NatsLoggingMixin, @@ -227,34 +223,6 @@ async def start(self) -> None: self._log(f"`{handler.call_name}` waiting for messages", extra=c) await handler.start(self.stream if is_js else self._connection) - def _process_message( - self, - func: Callable[[StreamMessage[Msg]], Awaitable[T_HandlerReturn]], - watcher: Callable[..., AsyncContextManager[None]], - **kwargs: Any, - ) -> Callable[ - [StreamMessage[Msg]], - Awaitable[WrappedReturn[T_HandlerReturn]], - ]: - @wraps(func) - async def process_wrapper( - message: StreamMessage[Msg], - ) -> WrappedReturn[T_HandlerReturn]: - async with watcher(message): - r = await func(message) - - pub_response: Optional[AsyncPublisherProtocol] - if message.reply_to: - pub_response = FakePublisher( - partial(self.publish, subject=message.reply_to) - ) - else: - pub_response = None - - return r, pub_response - - return process_wrapper - def _log_connection_broken( self, error_cb: Optional[ErrorCallback] = None, @@ -309,12 +277,13 @@ def subscriber( # type: ignore[override] inbox_prefix: bytes = api.INBOX_PREFIX, # custom ack_first: bool = False, + retry: bool = False, stream: Union[str, JStream, None] = None, # broker arguments dependencies: Sequence[Depends] = (), parser: Optional[CustomParser[Msg, NatsMessage]] = None, decoder: Optional[CustomDecoder[NatsMessage]] = None, - middlewares: Optional[Sequence[Callable[[Msg], BaseMiddleware]]] = None, + middlewares: Sequence[Callable[[Msg], BaseMiddleware]] = (), filter: Filter[NatsMessage] = default_filter, no_ack: bool = False, max_workers: int = 1, @@ -323,10 +292,7 @@ def subscriber( # type: ignore[override] description: Optional[str] = None, include_in_schema: bool = True, **original_kwargs: Any, - ) -> Callable[ - [Callable[P_HandlerParams, T_HandlerReturn]], - HandlerCallWrapper[Msg, P_HandlerParams, T_HandlerReturn], - ]: + ) -> "WrapperProtocol[Msg]": stream = stream_builder.stream(stream) if pull_sub is not None and stream is None: @@ -399,6 +365,8 @@ def subscriber( # type: ignore[override] include_in_schema=include_in_schema, graceful_timeout=self.graceful_timeout, max_workers=max_workers, + middlewares=self.middlewares, + logger=self.logger, log_context_builder=partial( self._get_log_context, stream=stream.name if stream else "", @@ -411,32 +379,18 @@ def subscriber( # type: ignore[override] if stream: stream.subjects.append(handler.subject) - def consumer_wrapper( - func: Callable[P_HandlerParams, T_HandlerReturn], - ) -> HandlerCallWrapper[ - Msg, - P_HandlerParams, - T_HandlerReturn, - ]: - handler_call, dependant = self._wrap_handler( - func, - extra_dependencies=dependencies, - no_ack=no_ack, - **original_kwargs, - ) - - handler.add_call( - handler=handler_call, - filter=filter, - middlewares=middlewares, - parser=parser or self._global_parser, - decoder=decoder or self._global_decoder, - dependant=dependant, - ) - - return handler_call - - return consumer_wrapper + return handler.add_call( + filter=filter, + parser=parser or self._global_parser, + decoder=decoder or self._global_decoder, + dependencies=(*self.dependencies, *dependencies), + middlewares=middlewares, + no_ack=no_ack, + is_validate=self._is_validate, + raw=not self._is_apply_types, + retry=retry, + producer=self, + ) @override def publisher( # type: ignore[override] @@ -488,6 +442,7 @@ async def publish( # type: ignore[override] if stream is None: assert self._producer, NOT_CONNECTED_YET # nosec B101 return await self._producer.publish(*args, **kwargs) + else: assert self._js_producer, NOT_CONNECTED_YET # nosec B101 return await self._js_producer.publish( diff --git a/faststream/nats/broker.pyi b/faststream/nats/broker.pyi index 6c999d7012..aa2ee045cb 100644 --- a/faststream/nats/broker.pyi +++ b/faststream/nats/broker.pyi @@ -34,17 +34,16 @@ from typing_extensions import override from faststream.asyncapi import schema as asyncapi from faststream.broker.core.asynchronous import BrokerAsyncUsecase, default_filter +from faststream.broker.handler import WrapperProtocol from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( CustomDecoder, CustomParser, Filter, - P_HandlerParams, T_HandlerReturn, WrappedReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.log import access_logger from faststream.nats.asyncapi import Handler, Publisher from faststream.nats.js_stream import JStream @@ -251,10 +250,7 @@ class NatsBroker( description: str | None = None, include_in_schema: bool = True, **__service_kwargs: Any, - ) -> Callable[ - [Callable[P_HandlerParams, T_HandlerReturn]], - HandlerCallWrapper[Msg, P_HandlerParams, T_HandlerReturn], - ]: ... + ) -> WrapperProtocol[Msg]: ... @override def publisher( # type: ignore[override] self, diff --git a/faststream/nats/fastapi.pyi b/faststream/nats/fastapi.pyi index ebdf8db6e0..e7663e107d 100644 --- a/faststream/nats/fastapi.pyi +++ b/faststream/nats/fastapi.pyi @@ -197,6 +197,7 @@ class NatsRouter(StreamRouter[Msg]): filter: Filter[NatsMessage] = default_filter, retry: bool = False, no_ack: bool = False, + max_workers: int = 1, # AsyncAPI information title: str | None = None, description: str | None = None, diff --git a/faststream/nats/handler.py b/faststream/nats/handler.py index edac61874c..f8388f5eba 100644 --- a/faststream/nats/handler.py +++ b/faststream/nats/handler.py @@ -1,30 +1,44 @@ import asyncio from contextlib import suppress -from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, Union, cast +from functools import partial, wraps +from logging import Logger +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + Awaitable, + Callable, + Dict, + Optional, + Sequence, + Union, + cast, +) import anyio from anyio.abc import TaskGroup, TaskStatus from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from fast_depends.core import CallModel +from fast_depends.dependencies import Depends from nats.aio.client import Client from nats.aio.msg import Msg from nats.aio.subscription import Subscription from nats.errors import TimeoutError from nats.js import JetStreamContext -from typing_extensions import Annotated, Doc, override +from typing_extensions import Annotated, Doc -from faststream.broker.handler import AsyncHandler +from faststream.broker.handler import BaseHandler from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.parsers import resolve_custom_func from faststream.broker.types import ( + AsyncPublisherProtocol, CustomDecoder, CustomParser, Filter, - P_HandlerParams, T_HandlerReturn, + WrappedReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper +from faststream.broker.wrapper import FakePublisher from faststream.nats.js_stream import JStream from faststream.nats.message import NatsMessage from faststream.nats.parser import JsParser, Parser @@ -32,8 +46,11 @@ from faststream.types import AnyDict, SendableMessage from faststream.utils.path import compile_path +if TYPE_CHECKING: + from faststream.broker.handler import WrapperProtocol + -class LogicNatsHandler(AsyncHandler[Msg]): +class LogicNatsHandler(BaseHandler[Msg]): """A class to represent a NATS handler.""" subscription: Union[ @@ -57,6 +74,9 @@ def __init__( Callable[[StreamMessage[Any]], Dict[str, str]], Doc("Function to create log extra data by message"), ], + logger: Annotated[ + Optional[Logger], Doc("Logger to use with process message Watcher") + ] = None, queue: Annotated[ str, Doc("NATS queue name"), @@ -84,6 +104,10 @@ def __init__( int, Doc("Process up to this parameter messages concurrently"), ] = 1, + middlewares: Annotated[ + Sequence[Callable[[Msg], BaseMiddleware]], + Doc("Global middleware to use `on_receive`, `after_processed`"), + ] = (), # AsyncAPI information description: Annotated[ Optional[str], @@ -118,7 +142,9 @@ def __init__( description=description, include_in_schema=include_in_schema, title=title, + middlewares=middlewares, graceful_timeout=graceful_timeout, + logger=logger, ) self.max_workers = max_workers @@ -133,25 +159,55 @@ def __init__( def add_call( self, *, - handler: HandlerCallWrapper[Msg, P_HandlerParams, T_HandlerReturn], - dependant: CallModel[P_HandlerParams, T_HandlerReturn], parser: Optional[CustomParser[Msg, NatsMessage]], decoder: Optional[CustomDecoder[NatsMessage]], filter: Filter[NatsMessage], - middlewares: Optional[Sequence[Callable[[Msg], BaseMiddleware]]], - ) -> None: + middlewares: Sequence[Callable[[Msg], BaseMiddleware]], + dependencies: Sequence[Depends], + **wrap_kwargs: Any, + ) -> "WrapperProtocol[Msg]": parser_ = Parser if self.stream is None else JsParser - super().add_call( - handler=handler, - parser=resolve_custom_func(parser, parser_.parse_message), - decoder=resolve_custom_func(decoder, parser_.decode_message), - filter=filter, # type: ignore[arg-type] - dependant=dependant, - middlewares=middlewares, + return super().add_call( + parser_=resolve_custom_func(parser, parser_.parse_message), + decoder_=resolve_custom_func(decoder, parser_.decode_message), + filter_=filter, + middlewares_=middlewares, + dependencies_=dependencies, + **wrap_kwargs, ) - @override - async def start( # type: ignore[override] + def _process_message( + self, + func: Callable[[NatsMessage], Awaitable[T_HandlerReturn]], + watcher: Callable[..., AsyncContextManager[None]], + producer: AsyncPublisherProtocol, + ) -> Callable[ + [NatsMessage], + Awaitable[WrappedReturn[T_HandlerReturn]], + ]: + @wraps(func) + async def process_wrapper( + message: NatsMessage, + ) -> WrappedReturn[T_HandlerReturn]: + async with watcher(message): + r = await func(message) + + pub_response: Optional[AsyncPublisherProtocol] + if message.reply_to: + pub_response = FakePublisher( + partial( + producer.publish, + subject=message.reply_to, + ) + ) + else: + pub_response = None + + return r, pub_response + + return process_wrapper + + async def start( self, connection: Annotated[ Union[Client, JetStreamContext], diff --git a/faststream/nats/parser.py b/faststream/nats/parser.py index f62653d4aa..643d71049d 100644 --- a/faststream/nats/parser.py +++ b/faststream/nats/parser.py @@ -26,18 +26,27 @@ def __init__(self, is_js: bool) -> None: @overload async def parse_message( - self, message: List[Msg], *, path: Optional[AnyDict] = None + self, + message: List[Msg], + *, + path: Optional[AnyDict] = None, ) -> StreamMessage[List[Msg]]: ... @overload async def parse_message( - self, message: Msg, *, path: Optional[AnyDict] = None + self, + message: Msg, + *, + path: Optional[AnyDict] = None, ) -> StreamMessage[Msg]: ... async def parse_message( - self, message: Union[Msg, List[Msg]], *, path: Optional[AnyDict] = None + self, + message: Union[Msg, List[Msg]], + *, + path: Optional[AnyDict] = None, ) -> Union[ StreamMessage[Msg], StreamMessage[List[Msg]], @@ -79,7 +88,10 @@ async def decode_message( StreamMessage[Msg], StreamMessage[List[Msg]], ], - ) -> Union[List[DecodedMessage], DecodedMessage]: + ) -> Union[ + DecodedMessage, + List[DecodedMessage], + ]: if isinstance(msg.raw_message, list): data: List[DecodedMessage] = [] diff --git a/faststream/nats/router.pyi b/faststream/nats/router.pyi index 5a6d8a8c6c..2cec360501 100644 --- a/faststream/nats/router.pyi +++ b/faststream/nats/router.pyi @@ -87,6 +87,7 @@ class NatsRouter(BaseRouter): filter: Filter[NatsMessage] = default_filter, retry: bool = False, no_ack: bool = False, + max_workers: int = 1, # AsyncAPI information title: str | None = None, description: str | None = None, diff --git a/faststream/nats/shared/router.pyi b/faststream/nats/shared/router.pyi index 48dba128bc..835c65252a 100644 --- a/faststream/nats/shared/router.pyi +++ b/faststream/nats/shared/router.pyi @@ -49,6 +49,8 @@ class NatsRoute: middlewares: Sequence[Callable[[Msg], BaseMiddleware]] | None = None, filter: Filter[NatsMessage] = default_filter, retry: bool = False, + no_ack: bool = False, + max_workers: int = 1, # AsyncAPI information title: str | None = None, description: str | None = None, diff --git a/faststream/rabbit/broker.py b/faststream/rabbit/broker.py index 101091bf7a..382c44f685 100644 --- a/faststream/rabbit/broker.py +++ b/faststream/rabbit/broker.py @@ -394,12 +394,6 @@ def consumer_wrapper( Returns: The wrapped consumer function. - - Raises: - NotImplementedError: If silent animals are not supported. - !!! note - - The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ handler_call, dependant = self._wrap_handler( func, @@ -546,12 +540,6 @@ async def process_wrapper( Returns: A tuple containing the return value of the handler function and an optional AsyncPublisherProtocol. - - Raises: - AssertionError: If the code reaches an unreachable point. - !!! note - - The above docstring is autogenerated by docstring-gen library (https://docstring-gen.airt.ai) """ async with watcher(message): r = await func(message) diff --git a/faststream/rabbit/handler.py b/faststream/rabbit/handler.py index fde7de48d5..1fa278635b 100644 --- a/faststream/rabbit/handler.py +++ b/faststream/rabbit/handler.py @@ -4,7 +4,7 @@ from fast_depends.core import CallModel from typing_extensions import override -from faststream.broker.handler import AsyncHandler +from faststream.broker.handler import BaseHandler from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.parsers import resolve_custom_func @@ -27,7 +27,7 @@ from faststream.types import AnyDict -class LogicHandler(AsyncHandler[aio_pika.IncomingMessage], BaseRMQInformation): +class LogicHandler(BaseHandler[aio_pika.IncomingMessage], BaseRMQInformation): """A class to handle logic for RabbitMQ message consumption. Attributes: diff --git a/faststream/redis/handler.py b/faststream/redis/handler.py index 3d185ccc3b..c5e4fdafb4 100644 --- a/faststream/redis/handler.py +++ b/faststream/redis/handler.py @@ -23,7 +23,7 @@ from typing_extensions import override from faststream._compat import json_loads -from faststream.broker.handler import AsyncHandler +from faststream.broker.handler import BaseHandler from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.parsers import resolve_custom_func @@ -43,7 +43,7 @@ from faststream.redis.schemas import INCORRECT_SETUP_MSG, ListSub, PubSub, StreamSub -class LogicRedisHandler(AsyncHandler[AnyRedisDict]): +class LogicRedisHandler(BaseHandler[AnyRedisDict]): """A class to represent a Redis handler.""" subscription: Optional[RPubSub] diff --git a/faststream/utils/classes.py b/faststream/utils/classes.py index 40481c22a3..316687eb0a 100644 --- a/faststream/utils/classes.py +++ b/faststream/utils/classes.py @@ -1,4 +1,6 @@ -from typing import Any, ClassVar +from typing import Any, ClassVar, Optional, cast + +from typing_extensions import Self class Singleton: @@ -12,9 +14,9 @@ class Singleton: _drop : sets the instance to None, allowing a new instance to be created """ - _instance: ClassVar[Any] = None + _instance: ClassVar[Optional[Self]] = None - def __new__(cls, *args: Any, **kwargs: Any) -> Any: + def __new__(cls, *args: Any, **kwargs: Any) -> Self: """Create a singleton instance of a class. Args: @@ -26,7 +28,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any: """ if cls._instance is None: cls._instance = super().__new__(cls) - return cls._instance + return cast(Self, cls._instance) @classmethod def _drop(cls) -> None: diff --git a/faststream/utils/context/repository.py b/faststream/utils/context/repository.py index 10989931bf..7c3ea2f171 100644 --- a/faststream/utils/context/repository.py +++ b/faststream/utils/context/repository.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from contextvars import ContextVar, Token from inspect import _empty -from typing import Any, Dict, Iterator, Mapping, cast +from typing import Any, Dict, Iterator, Mapping from faststream.types import AnyDict from faststream.utils.classes import Singleton @@ -168,4 +168,4 @@ def scope(self, key: str, value: Any) -> Iterator[None]: self.reset_local(key, token) -context: ContextRepo = cast(ContextRepo, ContextRepo()) # type: ignore[redundant-cast] +context = ContextRepo() diff --git a/pyproject.toml b/pyproject.toml index b9e6d7ad6b..d723084b5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dynamic = ["version"] dependencies = [ "anyio>=3.7.1,<5", - "fast-depends>=2.2.6,<3", + "fast-depends>=2.3.1,<2.4.0", "typer>=0.9,<1", "typing-extensions>=4.8.0", ] diff --git a/tests/asyncapi/base/arguments.py b/tests/asyncapi/base/arguments.py index 94e3dd706e..677273ac93 100644 --- a/tests/asyncapi/base/arguments.py +++ b/tests/asyncapi/base/arguments.py @@ -333,7 +333,7 @@ class User(pydantic.BaseModel): else: class Config: - schema_extra = {"examples": [{"name": "john", "id": 1}]} # noqa: RUF012 + schema_extra = {"examples": [{"name": "john", "id": 1}]} broker = self.broker_class() diff --git a/tests/brokers/base/consume.py b/tests/brokers/base/consume.py index efdf4982ff..f8f2127acf 100644 --- a/tests/brokers/base/consume.py +++ b/tests/brokers/base/consume.py @@ -12,8 +12,8 @@ @pytest.mark.asyncio() class BrokerConsumeTestcase: # noqa: D101 @pytest.fixture() - def consume_broker(self, broker: BrokerUsecase): - return broker + def consume_broker(self, full_broker: BrokerUsecase): + return full_broker async def test_consume( self, diff --git a/tests/brokers/base/rpc.py b/tests/brokers/base/rpc.py index cf9ac49175..5e075ff19a 100644 --- a/tests/brokers/base/rpc.py +++ b/tests/brokers/base/rpc.py @@ -10,8 +10,8 @@ class BrokerRPCTestcase: # noqa: D101 @pytest.fixture() - def rpc_broker(self, broker): - return broker + def rpc_broker(self, full_broker): + return full_broker @pytest.mark.asyncio() async def test_rpc(self, queue: str, rpc_broker: BrokerUsecase): diff --git a/tests/cli/test_app.py b/tests/cli/test_app.py index 8baf72b354..4f08481ae4 100644 --- a/tests/cli/test_app.py +++ b/tests/cli/test_app.py @@ -144,7 +144,7 @@ async def raises(): app.broker, "close", async_mock.broker_stopped ), patch.object( anyio, "open_signal_receiver", fake_open_signal_receiver - ), pytest.raises(ValueError): # noqa: PT011 + ), pytest.raises(ValueError): await app.run() diff --git a/tests/docs/getting_started/routers/test_delay.py b/tests/docs/getting_started/routers/test_delay.py index 5888e0103a..85bca26a2b 100644 --- a/tests/docs/getting_started/routers/test_delay.py +++ b/tests/docs/getting_started/routers/test_delay.py @@ -15,7 +15,7 @@ async def test_delay_router_kafka(): ) async with TestKafkaBroker(broker) as br, TestApp(app): - next(iter(br.handlers.values())).calls[0][0].mock.assert_called_once_with( + next(iter(br.handlers.values())).calls[0].handler.mock.assert_called_once_with( {"name": "John", "user_id": 1} ) @@ -28,7 +28,7 @@ async def test_delay_router_rabbit(): ) async with TestRabbitBroker(broker) as br, TestApp(app): - next(iter(br.handlers.values())).calls[0][0].mock.assert_called_once_with( + next(iter(br.handlers.values())).calls[0].handler.mock.assert_called_once_with( {"name": "John", "user_id": 1} ) @@ -41,7 +41,7 @@ async def test_delay_router_nats(): ) async with TestNatsBroker(broker) as br, TestApp(app): - next(iter(br.handlers.values())).calls[0][0].mock.assert_called_once_with( + next(iter(br.handlers.values())).calls[0].handler.mock.assert_called_once_with( {"name": "John", "user_id": 1} ) @@ -54,6 +54,6 @@ async def test_delay_router_redis(): ) async with TestRedisBroker(broker) as br, TestApp(app): - next(iter(br.handlers.values())).calls[0][0].mock.assert_called_once_with( + next(iter(br.handlers.values())).calls[0].handler.mock.assert_called_once_with( {"name": "John", "user_id": 1} ) diff --git a/tests/examples/router/test_delay_registration.py b/tests/examples/router/test_delay_registration.py index f8ca1ef0d3..5859804e49 100644 --- a/tests/examples/router/test_delay_registration.py +++ b/tests/examples/router/test_delay_registration.py @@ -6,7 +6,7 @@ @pytest.mark.asyncio() async def test_example(): - handle = broker.handlers["prefix_in"].calls[0][0] + handle = broker.handlers["prefix_in"].calls[0].handler async with TestKafkaBroker(broker), TestApp(app): await handle.wait_call(3) diff --git a/tests/marks.py b/tests/marks.py index b1e7f0e4c7..911467508d 100644 --- a/tests/marks.py +++ b/tests/marks.py @@ -12,4 +12,4 @@ pydanticV1 = pytest.mark.skipif(PYDANTIC_V2, reason="requires PydanticV2") # noqa: N816 -pydanticV2 = pytest.mark.skipif(not PYDANTIC_V2, reason="requires PydanticV1") # noqa: N816 +pydanticV2 = pytest.mark.skipif(not PYDANTIC_V2, reason="requires PydanticV1") From 13fcec00b388aee33c49e1b6e8c1741e91b9221a Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 10 Jan 2024 21:12:14 +0300 Subject: [PATCH 02/87] tests: fix raw tests --- faststream/broker/fastapi/route.py | 2 +- faststream/broker/handler.py | 15 ++++++++++----- faststream/broker/types.py | 4 ++-- faststream/broker/wrapper.py | 6 ++++-- faststream/nats/broker.py | 6 +++++- faststream/nats/handler.py | 3 ++- tests/brokers/base/consume.py | 4 ++-- tests/brokers/base/router.py | 1 - 8 files changed, 26 insertions(+), 15 deletions(-) diff --git a/faststream/broker/fastapi/route.py b/faststream/broker/fastapi/route.py index f6564a41f4..ff2b7d8217 100644 --- a/faststream/broker/fastapi/route.py +++ b/faststream/broker/fastapi/route.py @@ -113,7 +113,7 @@ def __init__( self.handler = broker.subscriber( path, *extra, - _raw=True, + raw=True, _get_dependant=lambda call: dependant, **handle_kwargs, )( diff --git a/faststream/broker/handler.py b/faststream/broker/handler.py index 6e83642a3e..42936651e8 100644 --- a/faststream/broker/handler.py +++ b/faststream/broker/handler.py @@ -22,6 +22,7 @@ ) import anyio +from fast_depends import inject from fast_depends.core import CallModel, build_call_model from fast_depends.dependencies import Depends from typing_extensions import Self, override @@ -47,7 +48,6 @@ from faststream.broker.wrapper import HandlerCallWrapper from faststream.exceptions import HandlerException, StopConsume from faststream.types import AnyDict, SendableMessage -from faststream.utils import apply_types from faststream.utils.context.repository import context from faststream.utils.functions import fake_context, to_async @@ -201,12 +201,15 @@ def wrapper( HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], ], ]: + total_deps = (*dependencies_, *dependencies) + total_middlewares = (*middlewares_, *middlewares) + def real_wrapper( func: Callable[P_HandlerParams, T_HandlerReturn], ) -> HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]: handler, dependant = self.wrap_handler( func=func, - dependencies=(*dependencies_, *dependencies), + dependencies=total_deps, **wrap_kwargs, ) @@ -217,7 +220,7 @@ def real_wrapper( filter=to_async(filter), parser=to_async(parser), decoder=to_async(decoder), - middlewares=(*middlewares_, *middlewares), + middlewares=total_middlewares, ) ) @@ -236,6 +239,7 @@ def wrap_handler( *, func: Callable[P_HandlerParams, T_HandlerReturn], no_ack: bool, + apply_types: bool, is_validate: bool, dependencies: Sequence[Depends], raw: bool, @@ -262,9 +266,10 @@ def wrap_handler( f = to_async(func) dependant = build_dep(f) - if not raw: - f = apply_types(None)(f, dependant) + if apply_types and not raw: + f = inject(None)(f, dependant) + if not raw: f = self._wrap_decode_message( func=f, params_ln=len(dependant.flat_params), diff --git a/faststream/broker/types.py b/faststream/broker/types.py index 16b0895ce3..0167e9f762 100644 --- a/faststream/broker/types.py +++ b/faststream/broker/types.py @@ -83,17 +83,17 @@ async def publish( *args: Any, correlation_id: Optional[str] = None, **kwargs: Any, - ) -> Optional[SendableMessage]: + ) -> Any: """Publishes a message asynchronously. Args: message: The message to be published. + *args: Additional positional arguments. correlation_id: Optional correlation ID for the message. **kwargs: Additional keyword arguments. Returns: The published message, or None if the message was not published. - """ ... diff --git a/faststream/broker/wrapper.py b/faststream/broker/wrapper.py index 7b55385994..db076ee913 100644 --- a/faststream/broker/wrapper.py +++ b/faststream/broker/wrapper.py @@ -37,20 +37,22 @@ def __init__(self, method: Callable[..., Awaitable[SendableMessage]]) -> None: async def publish( self, message: SendableMessage, + *args: Any, correlation_id: Optional[str] = None, **kwargs: Any, - ) -> Optional[SendableMessage]: + ) -> Any: """Publish a message. Args: message: The message to be published. + *args: Additinal positional arguments. correlation_id: Optional correlation ID for the message. **kwargs: Additional keyword arguments. Returns: The published message. """ - return await self.method(message, correlation_id=correlation_id, **kwargs) + return await self.method(message, *args, correlation_id=correlation_id, **kwargs) class HandlerCallWrapper(Generic[MsgType, P_HandlerParams, T_HandlerReturn]): diff --git a/faststream/nats/broker.py b/faststream/nats/broker.py index 6a149fc686..7bd91db9a7 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker.py @@ -13,6 +13,7 @@ Type, Union, ) +from fastapi import FastAPI import nats from fast_depends.dependencies import Depends @@ -278,6 +279,7 @@ def subscriber( # type: ignore[override] # custom ack_first: bool = False, retry: bool = False, + raw: bool = False, stream: Union[str, JStream, None] = None, # broker arguments dependencies: Sequence[Depends] = (), @@ -385,9 +387,11 @@ def subscriber( # type: ignore[override] decoder=decoder or self._global_decoder, dependencies=(*self.dependencies, *dependencies), middlewares=middlewares, + # wrapper kwargs no_ack=no_ack, is_validate=self._is_validate, - raw=not self._is_apply_types, + apply_types=self._is_apply_types, + raw=raw, retry=retry, producer=self, ) diff --git a/faststream/nats/handler.py b/faststream/nats/handler.py index f8388f5eba..1a6f3d5b94 100644 --- a/faststream/nats/handler.py +++ b/faststream/nats/handler.py @@ -43,6 +43,7 @@ from faststream.nats.message import NatsMessage from faststream.nats.parser import JsParser, Parser from faststream.nats.pull_sub import PullSub +from faststream.nats.producer import NatsFastProducer from faststream.types import AnyDict, SendableMessage from faststream.utils.path import compile_path @@ -180,7 +181,7 @@ def _process_message( self, func: Callable[[NatsMessage], Awaitable[T_HandlerReturn]], watcher: Callable[..., AsyncContextManager[None]], - producer: AsyncPublisherProtocol, + producer: NatsFastProducer, ) -> Callable[ [NatsMessage], Awaitable[WrappedReturn[T_HandlerReturn]], diff --git a/tests/brokers/base/consume.py b/tests/brokers/base/consume.py index f8f2127acf..efdf4982ff 100644 --- a/tests/brokers/base/consume.py +++ b/tests/brokers/base/consume.py @@ -12,8 +12,8 @@ @pytest.mark.asyncio() class BrokerConsumeTestcase: # noqa: D101 @pytest.fixture() - def consume_broker(self, full_broker: BrokerUsecase): - return full_broker + def consume_broker(self, broker: BrokerUsecase): + return broker async def test_consume( self, diff --git a/tests/brokers/base/router.py b/tests/brokers/base/router.py index 57f95b9912..78dcbd283d 100644 --- a/tests/brokers/base/router.py +++ b/tests/brokers/base/router.py @@ -297,7 +297,6 @@ async def dep1(s): mock.dep1() async def dep2(s): - mock.dep1.assert_called_once() mock.dep2() router = type(router)(dependencies=(Depends(dep1),)) From d82f518a6b2a8d7e41c1ace2801a1c6392774de2 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 10 Jan 2024 21:57:26 +0300 Subject: [PATCH 03/87] refactor: nats completed --- faststream/broker/core/abc.py | 7 ++- faststream/broker/fastapi/route.py | 5 +- faststream/broker/handler.py | 85 ++++++++++++++++++++++++------ faststream/broker/publisher.py | 1 + faststream/broker/test.py | 17 +----- faststream/broker/wrapper.py | 4 +- faststream/kafka/test.py | 2 +- faststream/nats/broker.py | 10 +--- faststream/nats/handler.py | 2 +- faststream/nats/test.py | 2 +- faststream/rabbit/test.py | 2 +- faststream/redis/test.py | 2 +- serve.py | 17 ++++++ 13 files changed, 105 insertions(+), 51 deletions(-) create mode 100644 serve.py diff --git a/faststream/broker/core/abc.py b/faststream/broker/core/abc.py index 97439d586a..a9b3c9756b 100644 --- a/faststream/broker/core/abc.py +++ b/faststream/broker/core/abc.py @@ -170,7 +170,12 @@ def __init__( midd_args: Sequence[Callable[[MsgType], BaseMiddleware]] = ( middlewares or empty_middleware ) - self.middlewares = (CriticalLogMiddleware(logger, log_level), *midd_args) + + if not is_test_env(): + self.middlewares = (CriticalLogMiddleware(logger, log_level), *midd_args) + else: + self.middlewares = midd_args + self.dependencies = dependencies self._connection_args = (url, *args) diff --git a/faststream/broker/fastapi/route.py b/faststream/broker/fastapi/route.py index ff2b7d8217..3218b9ddbe 100644 --- a/faststream/broker/fastapi/route.py +++ b/faststream/broker/fastapi/route.py @@ -114,7 +114,7 @@ def __init__( path, *extra, raw=True, - _get_dependant=lambda call: dependant, + get_dependant=lambda call: dependant, **handle_kwargs, )( handler # type: ignore[arg-type] @@ -210,9 +210,6 @@ async def app(message: NativeMessage[Any]) -> SendableMessage: Returns: The sendable message - - Raises: - TypeError: If the body of the message is not a dictionary """ body = message.decoded_body diff --git a/faststream/broker/handler.py b/faststream/broker/handler.py index 42936651e8..a03804a7de 100644 --- a/faststream/broker/handler.py +++ b/faststream/broker/handler.py @@ -25,9 +25,10 @@ from fast_depends import inject from fast_depends.core import CallModel, build_call_model from fast_depends.dependencies import Depends +from pydantic import create_model from typing_extensions import Self, override -from faststream._compat import IS_OPTIMIZED +from faststream._compat import IS_OPTIMIZED, PYDANTIC_V2 from faststream.asyncapi.base import AsyncAPIOperation from faststream.asyncapi.message import parse_handler_params from faststream.asyncapi.utils import to_camelcase @@ -47,7 +48,7 @@ from faststream.broker.utils import get_watcher, set_message_context from faststream.broker.wrapper import HandlerCallWrapper from faststream.exceptions import HandlerException, StopConsume -from faststream.types import AnyDict, SendableMessage +from faststream.types import AnyDict, F_Return, F_Spec, SendableMessage from faststream.utils.context.repository import context from faststream.utils.functions import fake_context, to_async @@ -58,6 +59,8 @@ from faststream.broker.message import StreamMessage class WrapperProtocol(Generic[MsgType], Protocol): + """Annotation class to represent @subsriber return type.""" + @overload def __call__( self, @@ -77,7 +80,7 @@ def __call__( @overload def __call__( self, - func: Callable[P_HandlerParams, T_HandlerReturn] = None, + func: Callable[P_HandlerParams, T_HandlerReturn], *, filter: Filter["StreamMessage[MsgType]"], parser: CustomParser[MsgType, Any], @@ -238,21 +241,29 @@ def wrap_handler( self, *, func: Callable[P_HandlerParams, T_HandlerReturn], - no_ack: bool, apply_types: bool, is_validate: bool, dependencies: Sequence[Depends], - raw: bool, - retry: int, + no_ack: bool = False, + raw: bool = False, + retry: Union[bool, int] = False, + get_dependant: Optional[Any] = None, **process_kwargs: Any, ) -> Tuple[ HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], CallModel[P_HandlerParams, T_HandlerReturn], ]: - build_dep = partial( - build_call_model, - cast=is_validate, - extra_dependencies=dependencies, + build_dep = build_dep = cast( + Callable[ + [Callable[F_Spec, F_Return]], + CallModel[F_Spec, F_Return], + ], + get_dependant + or partial( + build_call_model, + cast=is_validate, + extra_dependencies=dependencies, + ), ) if isinstance(func, HandlerCallWrapper): @@ -266,14 +277,20 @@ def wrap_handler( f = to_async(func) dependant = build_dep(f) - if apply_types and not raw: - f = inject(None)(f, dependant) + if getattr(dependant, "flat_params", None) is None: # FastAPI case + extra = [build_dep(d.dependency) for d in dependencies] + dependant.dependencies.extend(extra) + dependant = _patch_fastapi_dependant(dependant) - if not raw: - f = self._wrap_decode_message( - func=f, - params_ln=len(dependant.flat_params), - ) + else: + if apply_types and not raw: + f = inject(None)(f, dependant) + + if not raw: + f = self._wrap_decode_message( + func=f, + params_ln=len(dependant.flat_params), + ) f = self._process_message( func=f, @@ -532,3 +549,37 @@ async def wait_release(self, timeout: Optional[float] = None) -> None: if timeout: with anyio.move_on_after(timeout): await self.queue.join() + + +def _patch_fastapi_dependant( + dependant: CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]], +) -> CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]]: + """Patch FastAPI dependant. + + Args: + dependant: The dependant to be patched. + + Returns: + The patched dependant. + """ + params = dependant.query_params + dependant.body_params # type: ignore[attr-defined] + + for d in dependant.dependencies: + params.extend(d.query_params + d.body_params) # type: ignore[attr-defined] + + params_unique = {} + params_names = set() + for p in params: + if p.name not in params_names: + params_names.add(p.name) + info = p.field_info if PYDANTIC_V2 else p + params_unique[p.name] = (info.annotation, info.default) + + dependant.model = create_model( # type: ignore[call-overload] + getattr(dependant.call.__name__, "__name__", type(dependant.call).__name__), + **params_unique, + ) + dependant.custom_fields = {} + dependant.flat_params = params_unique # type: ignore[assignment,misc] + + return dependant diff --git a/faststream/broker/publisher.py b/faststream/broker/publisher.py index 1003c377b7..a3db1791f7 100644 --- a/faststream/broker/publisher.py +++ b/faststream/broker/publisher.py @@ -96,6 +96,7 @@ async def publish( Args: message: The message to be published. + *args: Additional positional arguments. correlation_id: Optional correlation ID for the message. **kwargs: Additional keyword arguments. diff --git a/faststream/broker/test.py b/faststream/broker/test.py index 244c58c0fb..6d917d451c 100644 --- a/faststream/broker/test.py +++ b/faststream/broker/test.py @@ -12,7 +12,6 @@ from faststream.broker.core.abc import BrokerUsecase from faststream.broker.core.asynchronous import BrokerAsyncUsecase from faststream.broker.handler import BaseHandler -from faststream.broker.middlewares import CriticalLogMiddleware from faststream.broker.wrapper import HandlerCallWrapper from faststream.types import SendableMessage, SettingField from faststream.utils.ast import is_contains_context_name @@ -221,11 +220,6 @@ def _fake_close( exc_val: Optional[BaseException] = None, exec_tb: Optional[TracebackType] = None, ) -> None: - broker.middlewares = [ - CriticalLogMiddleware(broker.logger, broker.log_level), - *broker.middlewares, - ] - for p in broker._publishers.values(): if p._fake_handler: p.reset_test() @@ -233,8 +227,8 @@ def _fake_close( for h in broker.handlers.values(): h.running = False - for h in h.calls: - h.handler.reset_test() + for call in h.calls: + call.handler.reset_test() @staticmethod @abstractmethod @@ -267,14 +261,7 @@ def patch_broker_calls(broker: BrokerUsecase[Any, Any]) -> None: Returns: None. - """ - broker.middlewares = tuple( - filter( # type: ignore[assignment] - lambda x: not isinstance(x, CriticalLogMiddleware), - broker.middlewares, - ) - ) broker._abc_start() for handler in broker.handlers.values(): diff --git a/faststream/broker/wrapper.py b/faststream/broker/wrapper.py index db076ee913..5235e503a7 100644 --- a/faststream/broker/wrapper.py +++ b/faststream/broker/wrapper.py @@ -52,7 +52,9 @@ async def publish( Returns: The published message. """ - return await self.method(message, *args, correlation_id=correlation_id, **kwargs) + return await self.method( + message, *args, correlation_id=correlation_id, **kwargs + ) class HandlerCallWrapper(Generic[MsgType, P_HandlerParams, T_HandlerReturn]): diff --git a/faststream/kafka/test.py b/faststream/kafka/test.py index 314f9100b0..8a1d0df548 100644 --- a/faststream/kafka/test.py +++ b/faststream/kafka/test.py @@ -35,7 +35,7 @@ def create_publisher_fake_subscriber( @broker.subscriber( # type: ignore[call-overload,misc] publisher.topic, batch=publisher.batch, - _raw=True, + raw=True, ) def f(msg: Any) -> None: pass diff --git a/faststream/nats/broker.py b/faststream/nats/broker.py index 7bd91db9a7..305566fe92 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker.py @@ -13,7 +13,6 @@ Type, Union, ) -from fastapi import FastAPI import nats from fast_depends.dependencies import Depends @@ -278,8 +277,6 @@ def subscriber( # type: ignore[override] inbox_prefix: bytes = api.INBOX_PREFIX, # custom ack_first: bool = False, - retry: bool = False, - raw: bool = False, stream: Union[str, JStream, None] = None, # broker arguments dependencies: Sequence[Depends] = (), @@ -287,13 +284,12 @@ def subscriber( # type: ignore[override] decoder: Optional[CustomDecoder[NatsMessage]] = None, middlewares: Sequence[Callable[[Msg], BaseMiddleware]] = (), filter: Filter[NatsMessage] = default_filter, - no_ack: bool = False, max_workers: int = 1, # AsyncAPI information title: Optional[str] = None, description: Optional[str] = None, include_in_schema: bool = True, - **original_kwargs: Any, + **wrapper_kwargs: Any, ) -> "WrapperProtocol[Msg]": stream = stream_builder.stream(stream) @@ -388,12 +384,10 @@ def subscriber( # type: ignore[override] dependencies=(*self.dependencies, *dependencies), middlewares=middlewares, # wrapper kwargs - no_ack=no_ack, is_validate=self._is_validate, apply_types=self._is_apply_types, - raw=raw, - retry=retry, producer=self, + **wrapper_kwargs, ) @override diff --git a/faststream/nats/handler.py b/faststream/nats/handler.py index 1a6f3d5b94..e201ce3e78 100644 --- a/faststream/nats/handler.py +++ b/faststream/nats/handler.py @@ -42,8 +42,8 @@ from faststream.nats.js_stream import JStream from faststream.nats.message import NatsMessage from faststream.nats.parser import JsParser, Parser -from faststream.nats.pull_sub import PullSub from faststream.nats.producer import NatsFastProducer +from faststream.nats.pull_sub import PullSub from faststream.types import AnyDict, SendableMessage from faststream.utils.path import compile_path diff --git a/faststream/nats/test.py b/faststream/nats/test.py index 50acf52d61..262186cba4 100644 --- a/faststream/nats/test.py +++ b/faststream/nats/test.py @@ -28,7 +28,7 @@ def create_publisher_fake_subscriber( broker: NatsBroker, publisher: Publisher, ) -> HandlerCallWrapper[Any, Any, Any]: - @broker.subscriber(publisher.subject, _raw=True) + @broker.subscriber(publisher.subject, raw=True) def f(msg: Any) -> None: pass diff --git a/faststream/rabbit/test.py b/faststream/rabbit/test.py index b8bd229d3b..260fa78a6b 100644 --- a/faststream/rabbit/test.py +++ b/faststream/rabbit/test.py @@ -51,7 +51,7 @@ def create_publisher_fake_subscriber( @broker.subscriber( queue=publisher.queue, exchange=publisher.exchange, - _raw=True, + raw=True, ) def f(msg: Any) -> None: pass diff --git a/faststream/redis/test.py b/faststream/redis/test.py index 7e5045ed6e..59515846dc 100644 --- a/faststream/redis/test.py +++ b/faststream/redis/test.py @@ -33,7 +33,7 @@ def create_publisher_fake_subscriber( channel=publisher.channel, list=publisher.list, stream=publisher.stream, - _raw=True, + raw=True, ) def f(msg: Any) -> None: pass diff --git a/serve.py b/serve.py new file mode 100644 index 0000000000..0b37927ce5 --- /dev/null +++ b/serve.py @@ -0,0 +1,17 @@ +from fastapi import FastAPI + +from faststream.nats.fastapi import Logger, NatsRouter + +router = NatsRouter() +app = FastAPI(lifespan=router.lifespan_context) +app.include_router(router) + + +@router.subscriber("test") +async def handler(msg, logger: Logger): + logger.info(msg) + + +@router.after_startup +async def t(app): + await router.broker.publish("test", "test") From 8a33552e315f95e85b3a1fc745a691bb0823aba0 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 10 Jan 2024 21:57:46 +0300 Subject: [PATCH 04/87] chore: remove useless --- serve.py | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 serve.py diff --git a/serve.py b/serve.py deleted file mode 100644 index 0b37927ce5..0000000000 --- a/serve.py +++ /dev/null @@ -1,17 +0,0 @@ -from fastapi import FastAPI - -from faststream.nats.fastapi import Logger, NatsRouter - -router = NatsRouter() -app = FastAPI(lifespan=router.lifespan_context) -app.include_router(router) - - -@router.subscriber("test") -async def handler(msg, logger: Logger): - logger.info(msg) - - -@router.after_startup -async def t(app): - await router.broker.publish("test", "test") From 069537921fe77a9875d92a3c497a9b9e2938f680 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 10 Jan 2024 22:02:01 +0300 Subject: [PATCH 05/87] lint: fix ruff warnings --- tests/asyncapi/base/arguments.py | 2 +- tests/cli/test_app.py | 2 +- tests/docs/getting_started/cli/test_kafka_context.py | 4 ++-- tests/docs/getting_started/cli/test_nats_context.py | 4 ++-- tests/docs/getting_started/cli/test_rabbit_context.py | 4 ++-- tests/docs/getting_started/cli/test_redis_context.py | 4 ++-- .../getting_started/config/test_settings_base_1.py | 4 ++-- .../getting_started/config/test_settings_base_2.py | 4 ++-- tests/docs/getting_started/config/test_settings_env.py | 4 ++-- tests/docs/getting_started/config/test_usage.py | 4 ++-- tests/docs/getting_started/lifespan/test_basic.py | 10 +++++----- tests/marks.py | 4 ++-- 12 files changed, 25 insertions(+), 25 deletions(-) diff --git a/tests/asyncapi/base/arguments.py b/tests/asyncapi/base/arguments.py index 677273ac93..94e3dd706e 100644 --- a/tests/asyncapi/base/arguments.py +++ b/tests/asyncapi/base/arguments.py @@ -333,7 +333,7 @@ class User(pydantic.BaseModel): else: class Config: - schema_extra = {"examples": [{"name": "john", "id": 1}]} + schema_extra = {"examples": [{"name": "john", "id": 1}]} # noqa: RUF012 broker = self.broker_class() diff --git a/tests/cli/test_app.py b/tests/cli/test_app.py index 4f08481ae4..fc02f3dad7 100644 --- a/tests/cli/test_app.py +++ b/tests/cli/test_app.py @@ -144,7 +144,7 @@ async def raises(): app.broker, "close", async_mock.broker_stopped ), patch.object( anyio, "open_signal_receiver", fake_open_signal_receiver - ), pytest.raises(ValueError): + ), pytest.raises(ValueError, match="Ooops!"): await app.run() diff --git a/tests/docs/getting_started/cli/test_kafka_context.py b/tests/docs/getting_started/cli/test_kafka_context.py index 467dd54b67..a46bf0d43f 100644 --- a/tests/docs/getting_started/cli/test_kafka_context.py +++ b/tests/docs/getting_started/cli/test_kafka_context.py @@ -2,11 +2,11 @@ from faststream import TestApp, context from faststream.kafka import TestKafkaBroker -from tests.marks import pydanticV2 +from tests.marks import pydantic_v2 from tests.mocks import mock_pydantic_settings_env -@pydanticV2 +@pydantic_v2 @pytest.mark.asyncio() async def test(): with mock_pydantic_settings_env({"host": "localhost"}): diff --git a/tests/docs/getting_started/cli/test_nats_context.py b/tests/docs/getting_started/cli/test_nats_context.py index c0d8a422e0..a55f760894 100644 --- a/tests/docs/getting_started/cli/test_nats_context.py +++ b/tests/docs/getting_started/cli/test_nats_context.py @@ -2,11 +2,11 @@ from faststream import TestApp, context from faststream.nats import TestNatsBroker -from tests.marks import pydanticV2 +from tests.marks import pydantic_v2 from tests.mocks import mock_pydantic_settings_env -@pydanticV2 +@pydantic_v2 @pytest.mark.asyncio() async def test(): with mock_pydantic_settings_env({"host": "localhost"}): diff --git a/tests/docs/getting_started/cli/test_rabbit_context.py b/tests/docs/getting_started/cli/test_rabbit_context.py index 4f57569d2c..c7b0f32bb4 100644 --- a/tests/docs/getting_started/cli/test_rabbit_context.py +++ b/tests/docs/getting_started/cli/test_rabbit_context.py @@ -2,11 +2,11 @@ from faststream import TestApp, context from faststream.rabbit import TestRabbitBroker -from tests.marks import pydanticV2 +from tests.marks import pydantic_v2 from tests.mocks import mock_pydantic_settings_env -@pydanticV2 +@pydantic_v2 @pytest.mark.asyncio() async def test(): with mock_pydantic_settings_env( diff --git a/tests/docs/getting_started/cli/test_redis_context.py b/tests/docs/getting_started/cli/test_redis_context.py index 33e46621c9..d9e35a459b 100644 --- a/tests/docs/getting_started/cli/test_redis_context.py +++ b/tests/docs/getting_started/cli/test_redis_context.py @@ -2,11 +2,11 @@ from faststream import TestApp, context from faststream.redis import TestRedisBroker -from tests.marks import pydanticV2 +from tests.marks import pydantic_v2 from tests.mocks import mock_pydantic_settings_env -@pydanticV2 +@pydantic_v2 @pytest.mark.asyncio() async def test(): with mock_pydantic_settings_env({"host": "redis://localhost:6380"}): diff --git a/tests/docs/getting_started/config/test_settings_base_1.py b/tests/docs/getting_started/config/test_settings_base_1.py index d186a40158..fd42ba6533 100644 --- a/tests/docs/getting_started/config/test_settings_base_1.py +++ b/tests/docs/getting_started/config/test_settings_base_1.py @@ -1,7 +1,7 @@ -from tests.marks import pydanticV1 +from tests.marks import pydantic_v1 -@pydanticV1 +@pydantic_v1 def test_exists_and_valid(): from docs.docs_src.getting_started.config.settings_base_1 import settings diff --git a/tests/docs/getting_started/config/test_settings_base_2.py b/tests/docs/getting_started/config/test_settings_base_2.py index 2704c0d186..780f73f278 100644 --- a/tests/docs/getting_started/config/test_settings_base_2.py +++ b/tests/docs/getting_started/config/test_settings_base_2.py @@ -1,8 +1,8 @@ -from tests.marks import pydanticV2 +from tests.marks import pydantic_v2 from tests.mocks import mock_pydantic_settings_env -@pydanticV2 +@pydantic_v2 def test_exists_and_valid(): with mock_pydantic_settings_env({"url": "localhost:9092"}): from docs.docs_src.getting_started.config.settings_base_2 import settings diff --git a/tests/docs/getting_started/config/test_settings_env.py b/tests/docs/getting_started/config/test_settings_env.py index 67c5fdf6c7..960485ed4c 100644 --- a/tests/docs/getting_started/config/test_settings_env.py +++ b/tests/docs/getting_started/config/test_settings_env.py @@ -1,8 +1,8 @@ -from tests.marks import pydanticV2 +from tests.marks import pydantic_v2 from tests.mocks import mock_pydantic_settings_env -@pydanticV2 +@pydantic_v2 def test_exists_and_valid(): with mock_pydantic_settings_env({"url": "localhost:9092"}): from docs.docs_src.getting_started.config.settings_env import settings diff --git a/tests/docs/getting_started/config/test_usage.py b/tests/docs/getting_started/config/test_usage.py index 4a9667c1d8..cdceaf9e8d 100644 --- a/tests/docs/getting_started/config/test_usage.py +++ b/tests/docs/getting_started/config/test_usage.py @@ -1,8 +1,8 @@ -from tests.marks import pydanticV2 +from tests.marks import pydantic_v2 from tests.mocks import mock_pydantic_settings_env -@pydanticV2 +@pydantic_v2 def test_exists_and_valid(): with mock_pydantic_settings_env({"url": "localhost:9092"}): from docs.docs_src.getting_started.config.usage import settings diff --git a/tests/docs/getting_started/lifespan/test_basic.py b/tests/docs/getting_started/lifespan/test_basic.py index fafdd17114..3d2dd82e1a 100644 --- a/tests/docs/getting_started/lifespan/test_basic.py +++ b/tests/docs/getting_started/lifespan/test_basic.py @@ -5,11 +5,11 @@ from faststream.nats import TestNatsBroker from faststream.rabbit import TestRabbitBroker from faststream.redis import TestRedisBroker -from tests.marks import pydanticV2 +from tests.marks import pydantic_v2 from tests.mocks import mock_pydantic_settings_env -@pydanticV2 +@pydantic_v2 @pytest.mark.asyncio() async def test_rabbit_basic_lifespan(): with mock_pydantic_settings_env({"host": "localhost"}): @@ -19,7 +19,7 @@ async def test_rabbit_basic_lifespan(): assert context.get("settings").host == "localhost" -@pydanticV2 +@pydantic_v2 @pytest.mark.asyncio() async def test_kafka_basic_lifespan(): with mock_pydantic_settings_env({"host": "localhost"}): @@ -29,7 +29,7 @@ async def test_kafka_basic_lifespan(): assert context.get("settings").host == "localhost" -@pydanticV2 +@pydantic_v2 @pytest.mark.asyncio() async def test_nats_basic_lifespan(): with mock_pydantic_settings_env({"host": "localhost"}): @@ -39,7 +39,7 @@ async def test_nats_basic_lifespan(): assert context.get("settings").host == "localhost" -@pydanticV2 +@pydantic_v2 @pytest.mark.asyncio() async def test_redis_basic_lifespan(): with mock_pydantic_settings_env({"host": "localhost"}): diff --git a/tests/marks.py b/tests/marks.py index 911467508d..4a41446988 100644 --- a/tests/marks.py +++ b/tests/marks.py @@ -10,6 +10,6 @@ sys.version_info < (3, 10), reason="requires python3.10+" ) -pydanticV1 = pytest.mark.skipif(PYDANTIC_V2, reason="requires PydanticV2") # noqa: N816 +pydantic_v1 = pytest.mark.skipif(PYDANTIC_V2, reason="requires PydanticV2") -pydanticV2 = pytest.mark.skipif(not PYDANTIC_V2, reason="requires PydanticV1") +pydantic_v2 = pytest.mark.skipif(not PYDANTIC_V2, reason="requires PydanticV1") From 609d24388a3eea9e2afb199fe739716de6ada746 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 10 Jan 2024 22:08:09 +0300 Subject: [PATCH 06/87] refactor: remove useless broker methods --- faststream/broker/core/abc.py | 200 +--------------------------------- 1 file changed, 2 insertions(+), 198 deletions(-) diff --git a/faststream/broker/core/abc.py b/faststream/broker/core/abc.py index a9b3c9756b..e4fa5a6c85 100644 --- a/faststream/broker/core/abc.py +++ b/faststream/broker/core/abc.py @@ -1,29 +1,20 @@ import logging import warnings from abc import ABC, abstractmethod -from functools import partial -from itertools import chain from types import TracebackType from typing import ( Any, - Awaitable, Callable, Generic, List, Mapping, Optional, Sequence, - Sized, - Tuple, Type, Union, - cast, ) -from fast_depends._compat import PYDANTIC_V2 -from fast_depends.core import CallModel, build_call_model from fast_depends.dependencies import Depends -from pydantic import create_model from faststream._compat import is_test_env from faststream.asyncapi import schema as asyncapi @@ -32,7 +23,6 @@ from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware, CriticalLogMiddleware from faststream.broker.publisher import BasePublisher -from faststream.broker.push_back_watcher import WatcherContext from faststream.broker.router import BrokerRouter from faststream.broker.types import ( ConnectionType, @@ -45,18 +35,14 @@ ) from faststream.broker.utils import ( change_logger_handlers, - get_watcher, - set_message_context, ) from faststream.broker.wrapper import HandlerCallWrapper from faststream.log import access_logger from faststream.security import BaseSecurity -from faststream.types import AnyDict, F_Return, F_Spec -from faststream.utils import apply_types, context +from faststream.types import AnyDict +from faststream.utils import context from faststream.utils.functions import ( - fake_context, get_function_positional_arguments, - to_async, ) @@ -93,7 +79,6 @@ class BrokerUsecase( subscriber : decorator to register a subscriber publisher : register a publisher _wrap_decode_message : wrap a message decoding function - """ logger: Optional[logging.Logger] @@ -205,7 +190,6 @@ def include_router(self, router: BrokerRouter[Any, MsgType]) -> None: Returns: None - """ for r in router._handlers: self.subscriber(*r.args, **r.kwargs)(r.call) @@ -220,7 +204,6 @@ def include_routers(self, *routers: BrokerRouter[Any, MsgType]) -> None: Returns: None - """ for r in routers: self.include_router(r) @@ -234,7 +217,6 @@ def _resolve_connection_kwargs(self, *args: Any, **kwargs: Any) -> AnyDict: Returns: A dictionary containing the resolved connection keyword arguments. - """ arguments = get_function_positional_arguments(self.__init__) # type: ignore init_kwargs = { @@ -248,106 +230,6 @@ def _resolve_connection_kwargs(self, *args: Any, **kwargs: Any) -> AnyDict: } return {**init_kwargs, **connect_kwargs} - def _wrap_handler( - self, - func: Union[ - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - Callable[P_HandlerParams, T_HandlerReturn], - ], - *, - retry: Union[bool, int] = False, - extra_dependencies: Sequence[Depends] = (), - no_ack: bool = False, - _raw: bool = False, - _get_dependant: Optional[Any] = None, - _process_kwargs: Optional[AnyDict] = None, - ) -> Tuple[ - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - CallModel[Any, Any], - ]: - """Wrap a handler function. - - Args: - func: The handler function to wrap. - retry: Whether to retry the handler function if it fails. Can be a boolean or an integer specifying the number of retries. - extra_dependencies: Additional dependencies for the handler function. - no_ack: Whether not to ack/nack/reject messages. - _raw: Whether to use the raw handler function. - _get_dependant: The dependant function to use. - **broker_log_context_kwargs: Additional keyword arguments for the broker log context. - - Returns: - A tuple containing the wrapped handler function and the call model. - - Raises: - NotImplementedError: If silent animals are not supported. - - """ - final_extra_deps = tuple(chain(extra_dependencies, self.dependencies)) - - build_dep = cast( - Callable[ - [Callable[F_Spec, F_Return]], - CallModel[F_Spec, F_Return], - ], - _get_dependant - or partial( - build_call_model, - cast=self._is_validate, - ), - ) - - if isinstance(func, HandlerCallWrapper): - handler_call, func = func, func._original_call - if handler_call._wrapped_call is not None: - return handler_call, build_dep(func) - - else: - handler_call = HandlerCallWrapper(func) - - f = to_async(func) - - dependant = build_dep(f) - - extra = [build_dep(d.dependency) for d in final_extra_deps] - extend_dependencies(extra, dependant) - - if getattr(dependant, "flat_params", None) is None: # handle FastAPI Dependant - dependant = _patch_fastapi_dependant(dependant) - params = () - - else: - params = set( - chain( - dependant.flat_params.keys(), - *(d.flat_params.keys() for d in extra), - ) - ) - - if self._is_apply_types and not _raw: - f = apply_types(None)(f, dependant) # type: ignore[arg-type,assignment] - - decode_f = self._wrap_decode_message( - func=f, - _raw=_raw, - params=params, - ) - - process_f = self._process_message( - func=decode_f, - watcher=( - partial(WatcherContext, watcher=get_watcher(self.logger, retry)) # type: ignore[arg-type] - if not no_ack - else fake_context - ), - **(_process_kwargs or {}), - ) - - process_f = set_message_context(process_f) - - handler_call.set_wrapped(process_f) - return handler_call, dependant - def _abc_start(self) -> None: if not self.started: self.started = True @@ -465,84 +347,6 @@ def publisher( Returns: The published publisher. - - Raises: - NotImplementedError: If the method is not implemented. """ self._publishers = {**self._publishers, key: publisher} return publisher - - @abstractmethod - def _wrap_decode_message( - self, - func: Callable[..., Awaitable[T_HandlerReturn]], - params: Sized = (), - _raw: bool = False, - ) -> Callable[[StreamMessage[MsgType]], Awaitable[T_HandlerReturn]]: - """Wrap a decoding message function. - - Args: - func: The function to wrap. - params: The parameters to pass to the function. - _raw: Whether to return the raw message or not. - - Returns: - The wrapped function. - - Raises: - NotImplementedError: If the method is not implemented. - """ - raise NotImplementedError() - - -def extend_dependencies( - extra: Sequence[CallModel[Any, Any]], dependant: CallModel[Any, Any] -) -> CallModel[Any, Any]: - """Extends the dependencies of a function or FastAPI dependency. - - Args: - extra: Additional dependencies to be added. - dependant: The function or FastAPI dependency whose dependencies will be extended. - - Returns: - The updated function or FastAPI dependency. - """ - if isinstance(dependant, CallModel): - dependant.extra_dependencies = (*dependant.extra_dependencies, *extra) - else: # FastAPI dependencies - dependant.dependencies.extend(extra) - return dependant - - -def _patch_fastapi_dependant( - dependant: CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]], -) -> CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]]: - """Patch FastAPI dependant. - - Args: - dependant: The dependant to be patched. - - Returns: - The patched dependant. - """ - params = dependant.query_params + dependant.body_params # type: ignore[attr-defined] - - for d in dependant.dependencies: - params.extend(d.query_params + d.body_params) # type: ignore[attr-defined] - - params_unique = {} - params_names = set() - for p in params: - if p.name not in params_names: - params_names.add(p.name) - info = p.field_info if PYDANTIC_V2 else p - params_unique[p.name] = (info.annotation, info.default) - - dependant.model = create_model( # type: ignore[call-overload] - getattr(dependant.call.__name__, "__name__", type(dependant.call).__name__), - **params_unique, - ) - dependant.custom_fields = {} - dependant.flat_params = params_unique # type: ignore[assignment,misc] - - return dependant From c223cc78be1198c1a5a844a2e8d431af4456e98d Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 10 Jan 2024 22:20:24 +0300 Subject: [PATCH 07/87] refactor: move wrap handler logic to mixin class --- faststream/broker/core/handler_wrapper.py | 263 ++++++++++++++++++++++ faststream/broker/handler.py | 215 +----------------- 2 files changed, 271 insertions(+), 207 deletions(-) create mode 100644 faststream/broker/core/handler_wrapper.py diff --git a/faststream/broker/core/handler_wrapper.py b/faststream/broker/core/handler_wrapper.py new file mode 100644 index 0000000000..ddcaf70675 --- /dev/null +++ b/faststream/broker/core/handler_wrapper.py @@ -0,0 +1,263 @@ +from functools import partial, wraps +from logging import Logger +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + Awaitable, + Callable, + Generic, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from fast_depends import inject +from fast_depends.core import CallModel, build_call_model +from fast_depends.dependencies import Depends +from pydantic import create_model + +from faststream._compat import PYDANTIC_V2 +from faststream.broker.middlewares import BaseMiddleware +from faststream.broker.push_back_watcher import WatcherContext +from faststream.broker.types import ( + CustomDecoder, + CustomParser, + Filter, + MsgType, + P_HandlerParams, + T_HandlerReturn, + WrappedReturn, +) +from faststream.broker.utils import get_watcher, set_message_context +from faststream.broker.wrapper import HandlerCallWrapper +from faststream.types import F_Return, F_Spec +from faststream.utils.functions import fake_context, to_async + +if TYPE_CHECKING: + from typing import Protocol, overload + + from faststream.broker.message import StreamMessage + + class WrapperProtocol(Generic[MsgType], Protocol): + """Annotation class to represent @subsriber return type.""" + + @overload + def __call__( + self, + func: None = None, + *, + filter: Filter["StreamMessage[MsgType]"], + parser: CustomParser[MsgType, Any], + decoder: CustomDecoder["StreamMessage[MsgType]"], + middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), + dependencies: Sequence[Depends] = (), + ) -> Callable[ + [Callable[P_HandlerParams, T_HandlerReturn]], + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + ]: + ... + + @overload + def __call__( + self, + func: Callable[P_HandlerParams, T_HandlerReturn], + *, + filter: Filter["StreamMessage[MsgType]"], + parser: CustomParser[MsgType, Any], + decoder: CustomDecoder["StreamMessage[MsgType]"], + middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), + dependencies: Sequence[Depends] = (), + ) -> HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]: + ... + + def __call__( + self, + func: Optional[Callable[P_HandlerParams, T_HandlerReturn]] = None, + *, + filter: Filter["StreamMessage[MsgType]"], + parser: CustomParser[MsgType, Any], + decoder: CustomDecoder["StreamMessage[MsgType]"], + middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), + dependencies: Sequence[Depends] = (), + ) -> Union[ + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + Callable[ + [Callable[P_HandlerParams, T_HandlerReturn]], + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + ], + ]: + ... + + +class WrapHandlerMixin(Generic[MsgType]): + """A class to patch original handle function.""" + + def __init__( + self, + *, + middlewares: Sequence[Callable[[MsgType], BaseMiddleware]], + ) -> None: + """Initialize a new instance of the class.""" + self.calls = [] + self.middlewares = middlewares + + def wrap_handler( + self, + *, + func: Callable[P_HandlerParams, T_HandlerReturn], + dependencies: Sequence[Depends], + logger: Optional[Logger], + apply_types: bool, + is_validate: bool, + no_ack: bool = False, + raw: bool = False, + retry: Union[bool, int] = False, + get_dependant: Optional[Any] = None, + **process_kwargs: Any, + ) -> Tuple[ + HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], + CallModel[P_HandlerParams, T_HandlerReturn], + ]: + build_dep = build_dep = cast( + Callable[ + [Callable[F_Spec, F_Return]], + CallModel[F_Spec, F_Return], + ], + get_dependant + or partial( + build_call_model, + cast=is_validate, + extra_dependencies=dependencies, + ), + ) + + if isinstance(func, HandlerCallWrapper): + handler_call, func = func, func._original_call + if handler_call._wrapped_call is not None: + return handler_call, build_dep(func) + + else: + handler_call = HandlerCallWrapper(func) + + f = to_async(func) + dependant = build_dep(f) + + if getattr(dependant, "flat_params", None) is None: # FastAPI case + extra = [build_dep(d.dependency) for d in dependencies] + dependant.dependencies.extend(extra) + dependant = _patch_fastapi_dependant(dependant) + + else: + if apply_types and not raw: + f = inject(None)(f, dependant) + + if not raw: + f = self._wrap_decode_message( + func=f, + params_ln=len(dependant.flat_params), + ) + + f = self._process_message( + func=f, + watcher=( + partial(WatcherContext, watcher=get_watcher(logger, retry)) # type: ignore[arg-type] + if not no_ack + else fake_context + ), + **(process_kwargs or {}), + ) + + f = set_message_context(f) + handler_call.set_wrapped(f) + return handler_call, dependant + + def _wrap_decode_message( + self, + func: Callable[..., Awaitable[T_HandlerReturn]], + params_ln: int, + ) -> Callable[ + ["StreamMessage[MsgType]"], + Awaitable[T_HandlerReturn], + ]: + """Wraps a function to decode a message and pass it as an argument to the wrapped function. + + Args: + func: The function to be wrapped. + params_ln: The parameters number to be passed to the wrapped function. + + Returns: + The wrapped function. + """ + + @wraps(func) + async def decode_wrapper(message: "StreamMessage[MsgType]") -> T_HandlerReturn: + """A wrapper function to decode and handle a message. + + Args: + message : The message to be decoded and handled + + Returns: + The return value of the handler function + """ + msg = message.decoded_body + + if params_ln > 1: + if isinstance(msg, Mapping): + return await func(**msg) + elif isinstance(msg, Sequence): + return await func(*msg) + else: + return await func(msg) + + raise AssertionError("unreachable") + + return decode_wrapper + + def _process_message( + self, + func: Callable[[MsgType], Awaitable[T_HandlerReturn]], + watcher: Callable[..., AsyncContextManager[None]], + **kwargs: Any, + ) -> Callable[ + ["StreamMessage[MsgType]"], + Awaitable[WrappedReturn[T_HandlerReturn]], + ]: + raise NotImplementedError() + + +def _patch_fastapi_dependant( + dependant: CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]], +) -> CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]]: + """Patch FastAPI dependant. + + Args: + dependant: The dependant to be patched. + + Returns: + The patched dependant. + """ + params = dependant.query_params + dependant.body_params # type: ignore[attr-defined] + + for d in dependant.dependencies: + params.extend(d.query_params + d.body_params) # type: ignore[attr-defined] + + params_unique = {} + params_names = set() + for p in params: + if p.name not in params_names: + params_names.add(p.name) + info = p.field_info if PYDANTIC_V2 else p + params_unique[p.name] = (info.annotation, info.default) + + dependant.model = create_model( # type: ignore[call-overload] + getattr(dependant.call.__name__, "__name__", type(dependant.call).__name__), + **params_unique, + ) + dependant.custom_fields = {} + dependant.flat_params = params_unique # type: ignore[assignment,misc] + + return dependant diff --git a/faststream/broker/handler.py b/faststream/broker/handler.py index a03804a7de..0e539aa967 100644 --- a/faststream/broker/handler.py +++ b/faststream/broker/handler.py @@ -2,7 +2,6 @@ from abc import abstractmethod from contextlib import AsyncExitStack, suppress from dataclasses import dataclass -from functools import partial, wraps from inspect import unwrap from logging import Logger from typing import ( @@ -13,7 +12,6 @@ Dict, Generic, List, - Mapping, Optional, Sequence, Tuple, @@ -22,18 +20,16 @@ ) import anyio -from fast_depends import inject -from fast_depends.core import CallModel, build_call_model +from fast_depends.core import CallModel from fast_depends.dependencies import Depends -from pydantic import create_model from typing_extensions import Self, override -from faststream._compat import IS_OPTIMIZED, PYDANTIC_V2 +from faststream._compat import IS_OPTIMIZED from faststream.asyncapi.base import AsyncAPIOperation from faststream.asyncapi.message import parse_handler_params from faststream.asyncapi.utils import to_camelcase +from faststream.broker.core.handler_wrapper import WrapHandlerMixin from faststream.broker.middlewares import BaseMiddleware -from faststream.broker.push_back_watcher import WatcherContext from faststream.broker.types import ( AsyncDecoder, AsyncParser, @@ -45,69 +41,18 @@ T_HandlerReturn, WrappedReturn, ) -from faststream.broker.utils import get_watcher, set_message_context from faststream.broker.wrapper import HandlerCallWrapper from faststream.exceptions import HandlerException, StopConsume -from faststream.types import AnyDict, F_Return, F_Spec, SendableMessage +from faststream.types import AnyDict, SendableMessage from faststream.utils.context.repository import context -from faststream.utils.functions import fake_context, to_async +from faststream.utils.functions import to_async if TYPE_CHECKING: from contextvars import Token - from typing import Protocol, overload + from faststream.broker.core.handler_wrapper import WrapperProtocol from faststream.broker.message import StreamMessage - class WrapperProtocol(Generic[MsgType], Protocol): - """Annotation class to represent @subsriber return type.""" - - @overload - def __call__( - self, - func: None = None, - *, - filter: Filter["StreamMessage[MsgType]"], - parser: CustomParser[MsgType, Any], - decoder: CustomDecoder["StreamMessage[MsgType]"], - middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), - dependencies: Sequence[Depends] = (), - ) -> Callable[ - [Callable[P_HandlerParams, T_HandlerReturn]], - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - ]: - ... - - @overload - def __call__( - self, - func: Callable[P_HandlerParams, T_HandlerReturn], - *, - filter: Filter["StreamMessage[MsgType]"], - parser: CustomParser[MsgType, Any], - decoder: CustomDecoder["StreamMessage[MsgType]"], - middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), - dependencies: Sequence[Depends] = (), - ) -> HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]: - ... - - def __call__( - self, - func: Optional[Callable[P_HandlerParams, T_HandlerReturn]] = None, - *, - filter: Filter["StreamMessage[MsgType]"], - parser: CustomParser[MsgType, Any], - decoder: CustomDecoder["StreamMessage[MsgType]"], - middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), - dependencies: Sequence[Depends] = (), - ) -> Union[ - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - Callable[ - [Callable[P_HandlerParams, T_HandlerReturn]], - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - ], - ]: - ... - @dataclass(slots=True) class HandlerItem(Generic[MsgType]): @@ -141,7 +86,7 @@ def description(self) -> Optional[str]: return description -class BaseHandler(AsyncAPIOperation, Generic[MsgType]): +class BaseHandler(AsyncAPIOperation, WrapHandlerMixin[MsgType]): """A class representing an asynchronous handler. Methods: @@ -213,6 +158,7 @@ def real_wrapper( handler, dependant = self.wrap_handler( func=func, dependencies=total_deps, + logger=self.logger, **wrap_kwargs, ) @@ -237,117 +183,6 @@ def real_wrapper( return wrapper - def wrap_handler( - self, - *, - func: Callable[P_HandlerParams, T_HandlerReturn], - apply_types: bool, - is_validate: bool, - dependencies: Sequence[Depends], - no_ack: bool = False, - raw: bool = False, - retry: Union[bool, int] = False, - get_dependant: Optional[Any] = None, - **process_kwargs: Any, - ) -> Tuple[ - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - CallModel[P_HandlerParams, T_HandlerReturn], - ]: - build_dep = build_dep = cast( - Callable[ - [Callable[F_Spec, F_Return]], - CallModel[F_Spec, F_Return], - ], - get_dependant - or partial( - build_call_model, - cast=is_validate, - extra_dependencies=dependencies, - ), - ) - - if isinstance(func, HandlerCallWrapper): - handler_call, func = func, func._original_call - if handler_call._wrapped_call is not None: - return handler_call, build_dep(func) - - else: - handler_call = HandlerCallWrapper(func) - - f = to_async(func) - dependant = build_dep(f) - - if getattr(dependant, "flat_params", None) is None: # FastAPI case - extra = [build_dep(d.dependency) for d in dependencies] - dependant.dependencies.extend(extra) - dependant = _patch_fastapi_dependant(dependant) - - else: - if apply_types and not raw: - f = inject(None)(f, dependant) - - if not raw: - f = self._wrap_decode_message( - func=f, - params_ln=len(dependant.flat_params), - ) - - f = self._process_message( - func=f, - watcher=( - partial(WatcherContext, watcher=get_watcher(self.logger, retry)) # type: ignore[arg-type] - if not no_ack - else fake_context - ), - **(process_kwargs or {}), - ) - - f = set_message_context(f) - handler_call.set_wrapped(f) - return handler_call, dependant - - def _wrap_decode_message( - self, - func: Callable[..., Awaitable[T_HandlerReturn]], - params_ln: int, - ) -> Callable[ - ["StreamMessage[MsgType]"], - Awaitable[T_HandlerReturn], - ]: - """Wraps a function to decode a message and pass it as an argument to the wrapped function. - - Args: - func: The function to be wrapped. - params_ln: The parameters number to be passed to the wrapped function. - - Returns: - The wrapped function. - """ - - @wraps(func) - async def decode_wrapper(message: "StreamMessage[MsgType]") -> T_HandlerReturn: - """A wrapper function to decode and handle a message. - - Args: - message : The message to be decoded and handled - - Returns: - The return value of the handler function - """ - msg = message.decoded_body - - if params_ln > 1: - if isinstance(msg, Mapping): - return await func(**msg) - elif isinstance(msg, Sequence): - return await func(*msg) - else: - return await func(msg) - - raise AssertionError("unreachable") - - return decode_wrapper - @override async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[override] """Consume a message asynchronously. @@ -549,37 +384,3 @@ async def wait_release(self, timeout: Optional[float] = None) -> None: if timeout: with anyio.move_on_after(timeout): await self.queue.join() - - -def _patch_fastapi_dependant( - dependant: CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]], -) -> CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]]: - """Patch FastAPI dependant. - - Args: - dependant: The dependant to be patched. - - Returns: - The patched dependant. - """ - params = dependant.query_params + dependant.body_params # type: ignore[attr-defined] - - for d in dependant.dependencies: - params.extend(d.query_params + d.body_params) # type: ignore[attr-defined] - - params_unique = {} - params_names = set() - for p in params: - if p.name not in params_names: - params_names.add(p.name) - info = p.field_info if PYDANTIC_V2 else p - params_unique[p.name] = (info.annotation, info.default) - - dependant.model = create_model( # type: ignore[call-overload] - getattr(dependant.call.__name__, "__name__", type(dependant.call).__name__), - **params_unique, - ) - dependant.custom_fields = {} - dependant.flat_params = params_unique # type: ignore[assignment,misc] - - return dependant From d28e60b2b087073cbac47b5ac4baa254249e6b17 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 10 Jan 2024 22:37:19 +0300 Subject: [PATCH 08/87] refactor: remove useless broker methods --- faststream/broker/core/abc.py | 186 +--------------- faststream/broker/core/asynchronous.py | 288 +++++++++++-------------- 2 files changed, 132 insertions(+), 342 deletions(-) diff --git a/faststream/broker/core/abc.py b/faststream/broker/core/abc.py index e4fa5a6c85..b3e17523e2 100644 --- a/faststream/broker/core/abc.py +++ b/faststream/broker/core/abc.py @@ -1,7 +1,5 @@ import logging -import warnings -from abc import ABC, abstractmethod -from types import TracebackType +from abc import ABC from typing import ( Any, Callable, @@ -10,7 +8,6 @@ Mapping, Optional, Sequence, - Type, Union, ) @@ -23,27 +20,15 @@ from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware, CriticalLogMiddleware from faststream.broker.publisher import BasePublisher -from faststream.broker.router import BrokerRouter from faststream.broker.types import ( ConnectionType, CustomDecoder, CustomParser, - Filter, MsgType, - P_HandlerParams, - T_HandlerReturn, ) -from faststream.broker.utils import ( - change_logger_handlers, -) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.log import access_logger from faststream.security import BaseSecurity -from faststream.types import AnyDict from faststream.utils import context -from faststream.utils.functions import ( - get_function_positional_arguments, -) class BrokerUsecase( @@ -181,172 +166,3 @@ def __init__( self.description = description self.tags = tags self.security = security - - def include_router(self, router: BrokerRouter[Any, MsgType]) -> None: - """Includes a router in the current object. - - Args: - router: The router to be included. - - Returns: - None - """ - for r in router._handlers: - self.subscriber(*r.args, **r.kwargs)(r.call) - - self._publishers = {**self._publishers, **router._publishers} - - def include_routers(self, *routers: BrokerRouter[Any, MsgType]) -> None: - """Includes routers in the current object. - - Args: - *routers: Variable length argument list of routers to include. - - Returns: - None - """ - for r in routers: - self.include_router(r) - - def _resolve_connection_kwargs(self, *args: Any, **kwargs: Any) -> AnyDict: - """Resolve connection keyword arguments. - - Args: - *args: Positional arguments passed to the function. - **kwargs: Keyword arguments passed to the function. - - Returns: - A dictionary containing the resolved connection keyword arguments. - """ - arguments = get_function_positional_arguments(self.__init__) # type: ignore - init_kwargs = { - **self._connection_kwargs, - **dict(zip(arguments, self._connection_args)), - } - - connect_kwargs = { - **kwargs, - **dict(zip(arguments, args)), - } - return {**init_kwargs, **connect_kwargs} - - def _abc_start(self) -> None: - if not self.started: - self.started = True - - if self.logger is not None: - change_logger_handlers(self.logger, self.fmt) - - def _abc_close( - self, - exc_type: Optional[Type[BaseException]] = None, - exc_val: Optional[BaseException] = None, - exec_tb: Optional[TracebackType] = None, - ) -> None: - """Closes the ABC. - - Args: - exc_type: The exception type - exc_val: The exception value - exec_tb: The traceback - - Returns: - None - """ - self.started = False - - def _abc__close( - self, - exc_type: Optional[Type[BaseException]] = None, - exc_val: Optional[BaseException] = None, - exec_tb: Optional[TracebackType] = None, - ) -> None: - """Closes the connection. - - Args: - exc_type: The type of the exception being handled (optional) - exc_val: The exception instance being handled (optional) - exec_tb: The traceback for the exception being handled (optional) - - Returns: - None - - Note: - This is an abstract method and must be implemented by subclasses. - """ - self._connection = None - - @abstractmethod - def subscriber( # type: ignore[return] - self, - *broker_args: Any, - retry: Union[bool, int] = False, - dependencies: Sequence[Depends] = (), - decoder: Optional[CustomDecoder[StreamMessage[MsgType]]] = None, - parser: Optional[CustomParser[MsgType, StreamMessage[MsgType]]] = None, - middlewares: Optional[ - Sequence[ - Callable[ - [StreamMessage[MsgType]], - BaseMiddleware, - ] - ] - ] = None, - filter: Filter[StreamMessage[MsgType]] = lambda m: not m.processed, - _raw: bool = False, - _get_dependant: Optional[Any] = None, - **broker_kwargs: Any, - ) -> Callable[ - [ - Union[ - Callable[P_HandlerParams, T_HandlerReturn], - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - ] - ], - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - ]: - """This is a function decorator for subscribing to a message broker. - - Args: - *broker_args: Positional arguments to be passed to the broker. - retry: Whether to retry the subscription if it fails. Can be a boolean or an integer specifying the number of retries. - dependencies: Sequence of dependencies to be injected into the handler function. - decoder: Custom decoder function to decode the message. - parser: Custom parser function to parse the decoded message. - middlewares: Sequence of middleware functions to be applied to the message. - filter: Filter function to filter the messages to be processed. - _raw: Whether to return the raw message instead of the processed message. - _get_dependant: Optional parameter to get the dependant object. - **broker_kwargs: Keyword arguments to be passed to the broker. - - Returns: - A callable object that can be used as a decorator for a handler function. - - Raises: - RuntimeWarning: If the broker is already running. - """ - if self.started and not is_test_env(): # pragma: no cover - warnings.warn( - "You are trying to register `handler` with already running broker\n" - "It has no effect until broker restarting.", - category=RuntimeWarning, - stacklevel=1, - ) - - @abstractmethod - def publisher( - self, - key: Any, - publisher: BasePublisher[MsgType], - ) -> BasePublisher[MsgType]: - """Publishes a publisher. - - Args: - key: The key associated with the publisher. - publisher: The publisher to be published. - - Returns: - The published publisher. - """ - self._publishers = {**self._publishers, key: publisher} - return publisher diff --git a/faststream/broker/core/asynchronous.py b/faststream/broker/core/asynchronous.py index 708e93ccdc..17da00dd99 100644 --- a/faststream/broker/core/asynchronous.py +++ b/faststream/broker/core/asynchronous.py @@ -1,29 +1,28 @@ import logging +import warnings from abc import abstractmethod -from functools import wraps from types import TracebackType from typing import ( Any, - Awaitable, Callable, Mapping, Optional, Sequence, - Sized, - Tuple, Type, Union, cast, ) -from fast_depends.core import CallModel from fast_depends.dependencies import Depends -from typing_extensions import Self, override +from typing_extensions import Self +from faststream._compat import is_test_env from faststream.broker.core.abc import BrokerUsecase from faststream.broker.handler import BaseHandler from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware +from faststream.broker.publisher import BasePublisher +from faststream.broker.router import BrokerRouter from faststream.broker.types import ( AsyncCustomDecoder, AsyncCustomParser, @@ -35,10 +34,11 @@ P_HandlerParams, T_HandlerReturn, ) +from faststream.broker.utils import change_logger_handlers from faststream.broker.wrapper import HandlerCallWrapper from faststream.log import access_logger from faststream.types import AnyDict, SendableMessage -from faststream.utils.functions import to_async +from faststream.utils.functions import get_function_positional_arguments, to_async async def default_filter(msg: StreamMessage[Any]) -> bool: @@ -77,48 +77,125 @@ class BrokerAsyncUsecase(BrokerUsecase[MsgType, ConnectionType]): _global_parser: Optional[AsyncCustomParser[MsgType, StreamMessage[MsgType]]] _global_decoder: Optional[AsyncCustomDecoder[StreamMessage[MsgType]]] + def include_router(self, router: BrokerRouter[Any, MsgType]) -> None: + """Includes a router in the current object. + + Args: + router: The router to be included. + + Returns: + None + """ + for r in router._handlers: + self.subscriber(*r.args, **r.kwargs)(r.call) + + self._publishers = {**self._publishers, **router._publishers} + + def include_routers(self, *routers: BrokerRouter[Any, MsgType]) -> None: + """Includes routers in the current object. + + Args: + *routers: Variable length argument list of routers to include. + + Returns: + None + """ + for r in routers: + self.include_router(r) + + async def __aenter__(self) -> Self: + """Enter the context manager.""" + await self.connect() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exec_tb: Optional[TracebackType], + ) -> None: + """Exit the context manager. + + Args: + exc_type: The type of the exception raised, if any. + exc_val: The exception raised, if any. + exec_tb: The traceback of the exception raised, if any. + + Returns: + None + + Overrides: + This method overrides the __aexit__ method of the base class. + """ + await self.close(exc_type, exc_val, exec_tb) + @abstractmethod async def start(self) -> None: """Start the broker async use case.""" - super()._abc_start() + self._abc_start() for h in self.handlers.values(): for f in h.calls: f.handler.refresh(with_mock=False) await self.connect() - @abstractmethod - async def _connect(self, **kwargs: Any) -> ConnectionType: - """Connect to a resource. + def _abc_start(self) -> None: + if not self.started: + self.started = True + + if self.logger is not None: + change_logger_handlers(self.logger, self.fmt) + + async def connect(self, *args: Any, **kwargs: Any) -> ConnectionType: + """Connect to a remote server. Args: - **kwargs: Additional keyword arguments for the connection. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. Returns: The connection object. + """ + if self._connection is None: + _kwargs = self._resolve_connection_kwargs(*args, **kwargs) + self._connection = await self._connect(**_kwargs) + return self._connection - Raises: - NotImplementedError: If the method is not implemented. + def _resolve_connection_kwargs(self, *args: Any, **kwargs: Any) -> AnyDict: + """Resolve connection keyword arguments. + + Args: + *args: Positional arguments passed to the function. + **kwargs: Keyword arguments passed to the function. + + Returns: + A dictionary containing the resolved connection keyword arguments. """ - raise NotImplementedError() + arguments = get_function_positional_arguments(self.__init__) # type: ignore + init_kwargs = { + **self._connection_kwargs, + **dict(zip(arguments, self._connection_args)), + } + + connect_kwargs = { + **kwargs, + **dict(zip(arguments, args)), + } + return {**init_kwargs, **connect_kwargs} @abstractmethod - async def _close( - self, - exc_type: Optional[Type[BaseException]] = None, - exc_val: Optional[BaseException] = None, - exec_tb: Optional[TracebackType] = None, - ) -> None: - """Close the object. + async def _connect(self, **kwargs: Any) -> ConnectionType: + """Connect to a resource. Args: - exc_type: Optional. The type of the exception. - exc_val: Optional. The exception value. - exec_tb: Optional. The traceback of the exception. + **kwargs: Additional keyword arguments for the connection. Returns: - None + The connection object. + + Raises: + NotImplementedError: If the method is not implemented. """ - super()._abc__close(exc_type, exc_val, exec_tb) + raise NotImplementedError() async def close( self, @@ -139,13 +216,13 @@ async def close( Raises: NotImplementedError: If the method is not implemented. """ - super()._abc_close(exc_type, exc_val, exec_tb) + self.started = False for h in self.handlers.values(): await h.close() if self._connection is not None: - await self._close(exc_type, exc_val, exec_tb) + self._connection = None @abstractmethod async def publish( @@ -177,7 +254,6 @@ async def publish( """ raise NotImplementedError() - @override @abstractmethod def subscriber( # type: ignore[override,return] self, @@ -220,7 +296,31 @@ def subscriber( # type: ignore[override,return] Raises: NotImplementedError: If silent animals are not supported. """ - super().subscriber() + if self.started and not is_test_env(): # pragma: no cover + warnings.warn( + "You are trying to register `handler` with already running broker\n" + "It has no effect until broker restarting.", + category=RuntimeWarning, + stacklevel=1, + ) + + @abstractmethod + def publisher( + self, + key: Any, + publisher: BasePublisher[MsgType], + ) -> BasePublisher[MsgType]: + """Publishes a publisher. + + Args: + key: The key associated with the publisher. + publisher: The publisher to be published. + + Returns: + The published publisher. + """ + self._publishers = {**self._publishers, key: publisher} + return publisher def __init__( self, @@ -273,129 +373,3 @@ def __init__( **kwargs, ) self.graceful_timeout = graceful_timeout - - async def connect(self, *args: Any, **kwargs: Any) -> ConnectionType: - """Connect to a remote server. - - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - The connection object. - """ - if self._connection is None: - _kwargs = self._resolve_connection_kwargs(*args, **kwargs) - self._connection = await self._connect(**_kwargs) - return self._connection - - async def __aenter__(self) -> Self: - """Enter the context manager.""" - await self.connect() - return self - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exec_tb: Optional[TracebackType], - ) -> None: - """Exit the context manager. - - Args: - exc_type: The type of the exception raised, if any. - exc_val: The exception raised, if any. - exec_tb: The traceback of the exception raised, if any. - - Returns: - None - - Overrides: - This method overrides the __aexit__ method of the base class. - """ - await self.close(exc_type, exc_val, exec_tb) - - @override - def _wrap_decode_message( - self, - func: Callable[..., Awaitable[T_HandlerReturn]], - params: Sized = (), - _raw: bool = False, - ) -> Callable[[StreamMessage[MsgType]], Awaitable[T_HandlerReturn]]: - """Wraps a function to decode a message and pass it as an argument to the wrapped function. - - Args: - func: The function to be wrapped. - params: The parameters to be passed to the wrapped function. - _raw: Whether to return the raw message or not. - - Returns: - The wrapped function. - """ - params_ln = len(params) - - @wraps(func) - async def decode_wrapper(message: StreamMessage[MsgType]) -> T_HandlerReturn: - """A wrapper function to decode and handle a message. - - Args: - message : The message to be decoded and handled - - Returns: - The return value of the handler function - """ - if _raw is True: - return await func(message) - - msg = message.decoded_body - if params_ln > 1: - if isinstance(msg, Mapping): - return await func(**msg) - elif isinstance(msg, Sequence): - return await func(*msg) - else: - return await func(msg) - - raise AssertionError("unreachable") - - return decode_wrapper - - @override - def _wrap_handler( - self, - func: Callable[P_HandlerParams, T_HandlerReturn], - *, - retry: Union[bool, int] = False, - extra_dependencies: Sequence[Depends] = (), - no_ack: bool = False, - _raw: bool = False, - _get_dependant: Optional[Any] = None, - _process_kwargs: Optional[AnyDict] = None, - ) -> Tuple[ - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - CallModel[P_HandlerParams, T_HandlerReturn], - ]: - """Wrap a handler function. - - Args: - func: The handler function to wrap. - retry: Whether to retry the handler function if it fails. Can be a boolean or an integer specifying the number of retries. - extra_dependencies: Additional dependencies to inject into the handler function. - no_ack: Whether not to ack/nack/reject messages. - _raw: Whether to return the raw response from the handler function. - _get_dependant: An optional object to use as the dependant for the handler function. - **broker_log_context_kwargs: Additional keyword arguments to pass to the broker log context. - - Returns: - A tuple containing the wrapped handler function and the call model. - - """ - return super()._wrap_handler( # type: ignore[return-value] - func, - retry=retry, - extra_dependencies=extra_dependencies, - no_ack=no_ack, - _raw=_raw, - _get_dependant=_get_dependant, - _process_kwargs=_process_kwargs, - ) From 3aa5d1d5e6100727b4530bde9fc16fd1f9beef03 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Thu, 11 Jan 2024 20:05:52 +0300 Subject: [PATCH 09/87] refactor: new handler consume logic --- docs/docs/SUMMARY.md | 2 +- .../core/asynchronous/BrokerAsyncUsecase.md | 2 +- faststream/app.py | 56 +-- faststream/broker/core/abc.py | 168 --------- .../core/{asynchronous.py => broker.py} | 305 +++++++++------- .../{wrapper.py => core/call_wrapper.py} | 53 +-- faststream/broker/{ => core}/handler.py | 340 ++++++++---------- ...er_wrapper.py => handler_wrapper_mixin.py} | 7 +- .../core/{mixins.py => logging_mixin.py} | 4 - faststream/broker/{ => core}/publisher.py | 48 ++- faststream/broker/fastapi/route.py | 8 +- faststream/broker/fastapi/router.py | 14 +- faststream/broker/message.py | 26 +- faststream/broker/middlewares.py | 4 +- faststream/broker/push_back_watcher.py | 40 +-- faststream/broker/router.py | 28 +- faststream/broker/schemas.py | 13 +- faststream/broker/test.py | 13 +- faststream/broker/types.py | 4 +- faststream/broker/utils.py | 81 +++-- faststream/kafka/broker.py | 13 +- faststream/kafka/broker.pyi | 6 +- faststream/kafka/fastapi.pyi | 4 +- faststream/kafka/handler.py | 4 +- faststream/kafka/router.pyi | 4 +- faststream/kafka/shared/logging.py | 2 +- faststream/kafka/shared/publisher.py | 2 +- faststream/kafka/shared/router.py | 2 +- faststream/kafka/shared/router.pyi | 2 +- faststream/kafka/test.py | 2 +- faststream/nats/broker.py | 8 +- faststream/nats/broker.pyi | 6 +- faststream/nats/fastapi.pyi | 4 +- faststream/nats/handler.py | 36 +- faststream/nats/publisher.py | 2 +- faststream/nats/router.pyi | 4 +- faststream/nats/shared/logging.py | 2 +- faststream/nats/shared/router.py | 2 +- faststream/nats/shared/router.pyi | 4 +- faststream/nats/test.py | 2 +- faststream/rabbit/broker.py | 15 +- faststream/rabbit/broker.pyi | 6 +- faststream/rabbit/fastapi.pyi | 4 +- faststream/rabbit/handler.py | 4 +- faststream/rabbit/router.pyi | 4 +- faststream/rabbit/shared/logging.py | 2 +- faststream/rabbit/shared/publisher.py | 2 +- faststream/rabbit/shared/router.py | 2 +- faststream/rabbit/shared/router.pyi | 2 +- faststream/rabbit/test.py | 2 +- faststream/redis/broker.py | 11 +- faststream/redis/broker.pyi | 6 +- faststream/redis/fastapi.py | 2 +- faststream/redis/fastapi.pyi | 4 +- faststream/redis/handler.py | 4 +- faststream/redis/publisher.py | 2 +- faststream/redis/router.pyi | 4 +- faststream/redis/shared/logging.py | 2 +- faststream/redis/shared/router.py | 2 +- faststream/redis/shared/router.pyi | 2 +- faststream/redis/test.py | 2 +- tests/asyncapi/base/arguments.py | 2 +- tests/asyncapi/base/fastapi.py | 2 +- tests/asyncapi/base/naming.py | 2 +- tests/asyncapi/base/publisher.py | 2 +- tests/asyncapi/base/router.py | 53 ++- tests/brokers/base/connection.py | 2 +- tests/brokers/base/consume.py | 2 +- tests/brokers/base/fastapi.py | 8 +- tests/brokers/base/middlewares.py | 2 +- tests/brokers/base/parser.py | 8 +- tests/brokers/base/publish.py | 2 +- tests/brokers/base/router.py | 36 +- tests/brokers/base/rpc.py | 2 +- tests/brokers/base/testclient.py | 2 +- tests/utils/test_handler_lock.py | 2 +- 76 files changed, 703 insertions(+), 835 deletions(-) delete mode 100644 faststream/broker/core/abc.py rename faststream/broker/core/{asynchronous.py => broker.py} (55%) rename faststream/broker/{wrapper.py => core/call_wrapper.py} (78%) rename faststream/broker/{ => core}/handler.py (55%) rename faststream/broker/core/{handler_wrapper.py => handler_wrapper_mixin.py} (98%) rename faststream/broker/core/{mixins.py => logging_mixin.py} (99%) rename faststream/broker/{ => core}/publisher.py (76%) diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index 2c331b1bce..db8a00420a 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -216,7 +216,7 @@ search: - [BrokerUsecase](api/faststream/broker/core/abc/BrokerUsecase.md) - [extend_dependencies](api/faststream/broker/core/abc/extend_dependencies.md) - asynchronous - - [BrokerAsyncUsecase](api/faststream/broker/core/asynchronous/BrokerAsyncUsecase.md) + - [BrokerUsecase](api/faststream/broker/core/asynchronous/BrokerUsecase.md) - [default_filter](api/faststream/broker/core/asynchronous/default_filter.md) - mixins - [LoggingMixin](api/faststream/broker/core/mixins/LoggingMixin.md) diff --git a/docs/docs/en/api/faststream/broker/core/asynchronous/BrokerAsyncUsecase.md b/docs/docs/en/api/faststream/broker/core/asynchronous/BrokerAsyncUsecase.md index f90acf5470..452a4442db 100644 --- a/docs/docs/en/api/faststream/broker/core/asynchronous/BrokerAsyncUsecase.md +++ b/docs/docs/en/api/faststream/broker/core/asynchronous/BrokerAsyncUsecase.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: faststream.broker.core.asynchronous.BrokerAsyncUsecase +::: faststream.broker.core.asynchronous.BrokerUsecase diff --git a/faststream/app.py b/faststream/app.py index 950379acdc..e778f43080 100644 --- a/faststream/app.py +++ b/faststream/app.py @@ -1,22 +1,20 @@ import logging -from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + TypeVar, + Union, +) import anyio -from pydantic import AnyHttpUrl from typing_extensions import ParamSpec from faststream._compat import ExceptionGroup -from faststream.asyncapi.schema import ( - Contact, - ContactDict, - ExternalDocs, - ExternalDocsDict, - License, - LicenseDict, - Tag, - TagDict, -) -from faststream.broker.core.asynchronous import BrokerAsyncUsecase from faststream.cli.supervisors.utils import HANDLED_SIGNALS from faststream.log import logger from faststream.types import AnyDict, AsyncFunc, Lifespan, SettingField @@ -27,6 +25,22 @@ T_HookReturn = TypeVar("T_HookReturn") +if TYPE_CHECKING: + from pydantic import AnyHttpUrl + + from faststream.asyncapi.schema import ( + Contact, + ContactDict, + ExternalDocs, + ExternalDocsDict, + License, + LicenseDict, + Tag, + TagDict, + ) + from faststream.broker.core.broker import BrokerUsecase + + class FastStream: """A class representing a FastStream application. @@ -57,19 +71,21 @@ class FastStream: def __init__( self, - broker: Optional[BrokerAsyncUsecase[Any, Any]] = None, + broker: Optional["BrokerUsecase[Any, Any]"] = None, logger: Optional[logging.Logger] = logger, lifespan: Optional[Lifespan] = None, # AsyncAPI args, title: str = "FastStream", version: str = "0.1.0", description: str = "", - terms_of_service: Optional[AnyHttpUrl] = None, - license: Optional[Union[License, LicenseDict, AnyDict]] = None, - contact: Optional[Union[Contact, ContactDict, AnyDict]] = None, + terms_of_service: Optional["AnyHttpUrl"] = None, + license: Optional[Union["License", "LicenseDict", AnyDict]] = None, + contact: Optional[Union["Contact", "ContactDict", AnyDict]] = None, identifier: Optional[str] = None, - tags: Optional[Sequence[Union[Tag, TagDict, AnyDict]]] = None, - external_docs: Optional[Union[ExternalDocs, ExternalDocsDict, AnyDict]] = None, + tags: Optional[Sequence[Union["Tag", "TagDict", AnyDict]]] = None, + external_docs: Optional[ + Union["ExternalDocs", "ExternalDocsDict", AnyDict] + ] = None, ) -> None: """Asynchronous FastStream Application class. @@ -119,7 +135,7 @@ def __init__( self.asyncapi_tags = tags self.external_docs = external_docs - def set_broker(self, broker: BrokerAsyncUsecase[Any, Any]) -> None: + def set_broker(self, broker: "BrokerUsecase[Any, Any]") -> None: """Set already existed App object broker. Useful then you create/init broker in `on_startup` hook. diff --git a/faststream/broker/core/abc.py b/faststream/broker/core/abc.py deleted file mode 100644 index b3e17523e2..0000000000 --- a/faststream/broker/core/abc.py +++ /dev/null @@ -1,168 +0,0 @@ -import logging -from abc import ABC -from typing import ( - Any, - Callable, - Generic, - List, - Mapping, - Optional, - Sequence, - Union, -) - -from fast_depends.dependencies import Depends - -from faststream._compat import is_test_env -from faststream.asyncapi import schema as asyncapi -from faststream.broker.core.mixins import LoggingMixin -from faststream.broker.handler import BaseHandler -from faststream.broker.message import StreamMessage -from faststream.broker.middlewares import BaseMiddleware, CriticalLogMiddleware -from faststream.broker.publisher import BasePublisher -from faststream.broker.types import ( - ConnectionType, - CustomDecoder, - CustomParser, - MsgType, -) -from faststream.log import access_logger -from faststream.security import BaseSecurity -from faststream.utils import context - - -class BrokerUsecase( - ABC, - Generic[MsgType, ConnectionType], - LoggingMixin, -): - """A class representing a broker use case. - - Attributes: - logger : optional logger object - log_level : log level - handlers : dictionary of handlers - _publishers : dictionary of publishers - dependencies : sequence of dependencies - started : boolean indicating if the broker has started - middlewares : sequence of middleware functions - _global_parser : optional custom parser object - _global_decoder : optional custom decoder object - _connection : optional connection object - _fmt : optional string format - - Methods: - __init__ : constructor method - include_router : include a router in the broker - include_routers : include multiple routers in the broker - _resolve_connection_kwargs : resolve connection kwargs - _wrap_handler : wrap a handler function - _abc_start : start the broker - _abc_close : close the broker - _abc__close : close the broker connection - _process_message : process a message - subscriber : decorator to register a subscriber - publisher : register a publisher - _wrap_decode_message : wrap a message decoding function - """ - - logger: Optional[logging.Logger] - log_level: int - handlers: Mapping[Any, BaseHandler[MsgType]] - _publishers: Mapping[Any, BasePublisher[MsgType]] - - dependencies: Sequence[Depends] - started: bool - middlewares: Sequence[Callable[[Any], BaseMiddleware]] - _global_parser: Optional[CustomParser[MsgType, StreamMessage[MsgType]]] - _global_decoder: Optional[CustomDecoder[StreamMessage[MsgType]]] - _connection: Optional[ConnectionType] - _fmt: Optional[str] - - def __init__( - self, - url: Union[str, List[str]], - *args: Any, - # AsyncAPI kwargs - protocol: Optional[str] = None, - protocol_version: Optional[str] = None, - description: Optional[str] = None, - tags: Optional[Sequence[Union[asyncapi.Tag, asyncapi.TagDict]]] = None, - asyncapi_url: Union[str, List[str], None] = None, - # broker kwargs - apply_types: bool = True, - validate: bool = True, - logger: Optional[logging.Logger] = access_logger, - log_level: int = logging.INFO, - log_fmt: Optional[str] = "%(asctime)s %(levelname)s - %(message)s", - dependencies: Sequence[Depends] = (), - middlewares: Optional[Sequence[Callable[[MsgType], BaseMiddleware]]] = None, - decoder: Optional[CustomDecoder[StreamMessage[MsgType]]] = None, - parser: Optional[CustomParser[MsgType, StreamMessage[MsgType]]] = None, - security: Optional[BaseSecurity] = None, - **kwargs: Any, - ) -> None: - """Initialize a broker. - - Args: - url: The URL or list of URLs to connect to. - *args: Additional arguments. - protocol: The protocol to use for the connection. - protocol_version: The version of the protocol. - description: A description of the broker. - tags: Tags associated with the broker. - asyncapi_url: The URL or list of URLs to the AsyncAPI schema. - apply_types: Whether to apply types to messages. - validate: Whether to cast types using Pydantic validation. - logger: The logger to use. - log_level: The log level to use. - log_fmt: The log format to use. - dependencies: Dependencies of the broker. - middlewares: Middlewares to use. - decoder: Custom decoder for messages. - parser: Custom parser for messages. - security: Security scheme to use. - **kwargs: Additional keyword arguments. - - """ - super().__init__( - logger=logger, - log_level=log_level, - log_fmt=log_fmt, - ) - - self._connection = None - self._is_apply_types = apply_types - self._is_validate = validate - self.handlers = {} - self._publishers = {} - empty_middleware: Sequence[Callable[[MsgType], BaseMiddleware]] = () - midd_args: Sequence[Callable[[MsgType], BaseMiddleware]] = ( - middlewares or empty_middleware - ) - - if not is_test_env(): - self.middlewares = (CriticalLogMiddleware(logger, log_level), *midd_args) - else: - self.middlewares = midd_args - - self.dependencies = dependencies - - self._connection_args = (url, *args) - self._connection_kwargs = kwargs - - self._global_parser = parser - self._global_decoder = decoder - - context.set_global("logger", logger) - context.set_global("broker", self) - - self.started = False - - # AsyncAPI information - self.url = asyncapi_url or url - self.protocol = protocol - self.protocol_version = protocol_version - self.description = description - self.tags = tags - self.security = security diff --git a/faststream/broker/core/asynchronous.py b/faststream/broker/core/broker.py similarity index 55% rename from faststream/broker/core/asynchronous.py rename to faststream/broker/core/broker.py index 17da00dd99..22a0807430 100644 --- a/faststream/broker/core/asynchronous.py +++ b/faststream/broker/core/broker.py @@ -1,10 +1,12 @@ import logging import warnings -from abc import abstractmethod -from types import TracebackType +from abc import ABC, abstractmethod from typing import ( + TYPE_CHECKING, Any, Callable, + Generic, + List, Mapping, Optional, Sequence, @@ -13,16 +15,11 @@ cast, ) -from fast_depends.dependencies import Depends -from typing_extensions import Self +from typing_extensions import Annotated, Doc from faststream._compat import is_test_env -from faststream.broker.core.abc import BrokerUsecase -from faststream.broker.handler import BaseHandler -from faststream.broker.message import StreamMessage -from faststream.broker.middlewares import BaseMiddleware -from faststream.broker.publisher import BasePublisher -from faststream.broker.router import BrokerRouter +from faststream.broker.core.logging_mixin import LoggingMixin +from faststream.broker.middlewares import CriticalLogMiddleware from faststream.broker.types import ( AsyncCustomDecoder, AsyncCustomParser, @@ -31,17 +28,30 @@ CustomParser, Filter, MsgType, - P_HandlerParams, - T_HandlerReturn, ) from faststream.broker.utils import change_logger_handlers -from faststream.broker.wrapper import HandlerCallWrapper from faststream.log import access_logger -from faststream.types import AnyDict, SendableMessage +from faststream.utils.context.repository import context from faststream.utils.functions import get_function_positional_arguments, to_async +if TYPE_CHECKING: + from types import TracebackType -async def default_filter(msg: StreamMessage[Any]) -> bool: + from fast_depends.dependencies import Depends + from typing_extensions import Self + + from faststream.asyncapi.schema import Tag, TagDict + from faststream.broker.core.handler import BaseHandler + from faststream.broker.core.handler_wrapper_mixin import WrapperProtocol + from faststream.broker.core.publisher import BasePublisher + from faststream.broker.message import StreamMessage + from faststream.broker.middlewares import BaseMiddleware + from faststream.broker.router import BrokerRouter + from faststream.security import BaseSecurity + from faststream.types import AnyDict, SendableMessage + + +async def default_filter(msg: "StreamMessage[Any]") -> bool: """A function to filter stream messages. Args: @@ -49,12 +59,15 @@ async def default_filter(msg: StreamMessage[Any]) -> bool: Returns: True if the message has not been processed, False otherwise - """ return not msg.processed -class BrokerAsyncUsecase(BrokerUsecase[MsgType, ConnectionType]): +class BrokerUsecase( + ABC, + Generic[MsgType, ConnectionType], + LoggingMixin, +): """A class representing a broker async use case. Attributes: @@ -62,22 +75,134 @@ class BrokerAsyncUsecase(BrokerUsecase[MsgType, ConnectionType]): middlewares : A sequence of middleware functions. _global_parser : An optional global parser for messages. _global_decoder : An optional global decoder for messages. - - Methods: - start() : Abstract method to start the broker async use case. - _connect(**kwargs: Any) : Abstract method to connect to the broker. - _close(exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None, exec_tb: Optional[TracebackType] = None) : Abstract method to close the connection to the broker. - close(exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None, exec_tb: Optional[TracebackType] = None) : Close the connection to the broker. - _process_message(func: Callable[[StreamMessage[MsgType]], Awaitable[T_HandlerReturn]], watcher: BaseWatcher) : Abstract method to process a message. - publish(message: SendableMessage, *args: Any, reply_to: str = "", rpc: bool = False, rpc_timeout: Optional[float] """ - handlers: Mapping[Any, BaseHandler[MsgType]] - middlewares: Sequence[Callable[[MsgType], BaseMiddleware]] - _global_parser: Optional[AsyncCustomParser[MsgType, StreamMessage[MsgType]]] - _global_decoder: Optional[AsyncCustomDecoder[StreamMessage[MsgType]]] + handlers: Mapping[Any, "BaseHandler[MsgType]"] + _publishers: Mapping[Any, "BasePublisher[MsgType]"] + + def __init__( + self, + url: Annotated[ + Union[str, List[str], None], + Doc("Broker address to connect"), + ], + *args: Any, + apply_types: Annotated[ + bool, + Doc("Whether to use FastDepends or not"), + ] = True, + validate: Annotated[ + bool, + Doc("Whether to cast types using Pydantic validation"), + ] = True, + logger: Annotated[ + Optional[logging.Logger], + Doc("Logger object for logging"), + ] = access_logger, + log_level: Annotated[ + int, + Doc("Log level for logging"), + ] = logging.INFO, + log_fmt: Annotated[ + Optional[str], + Doc("Log format for logging"), + ] = "%(asctime)s %(levelname)s - %(message)s", + decoder: Annotated[ + Optional[CustomDecoder["StreamMessage[MsgType]"]], + Doc("Custom decoder object"), + ] = None, + parser: Annotated[ + Optional[CustomParser[MsgType, "StreamMessage[MsgType]"]], + Doc("Custom parser object"), + ] = None, + dependencies: Annotated[ + Sequence["Depends"], + Doc("Dependencies to apply to all broker subscribers"), + ] = (), + middlewares: Annotated[ + Sequence[Callable[[MsgType], "BaseMiddleware"]], + Doc("Middlewares to apply to all broker publishers/subscribers"), + ] = (), + graceful_timeout: Annotated[ + Optional[float], + Doc("Graceful shutdown timeout"), + ] = None, + # AsyncAPI kwargs + protocol: Annotated[ + Optional[str], + Doc("AsyncAPI server protocol"), + ] = None, + protocol_version: Annotated[ + Optional[str], + Doc("AsyncAPI server protocol version"), + ] = None, + description: Annotated[ + Optional[str], + Doc("AsyncAPI server description"), + ] = None, + tags: Annotated[ + Optional[Sequence[Union["Tag", "TagDict"]]], + Doc("AsyncAPI server tags"), + ] = None, + asyncapi_url: Annotated[ + Union[str, List[str], None], + Doc("AsyncAPI hardcoded server addresses"), + ] = None, + security: Annotated[ + Optional["BaseSecurity"], + Doc( + "Security options to connect broker and generate AsyncAPI server security" + ), + ] = None, + **kwargs: Any, + ) -> None: + """Initialize the class.""" + super().__init__( + logger=logger, + log_level=log_level, + log_fmt=log_fmt, + ) + + context.set_global("logger", logger) + context.set_global("broker", self) + + self._connection_args = (url, *args) + self._connection_kwargs = kwargs + + self.running = False + self.graceful_timeout = graceful_timeout + self._connection = None + self._is_apply_types = apply_types + self._is_validate = validate + + self.handlers = {} + self._publishers = {} + + if not is_test_env(): + self.middlewares = (CriticalLogMiddleware(logger, log_level), *middlewares) + else: + self.middlewares = middlewares + + self.dependencies = dependencies + + self._global_parser = cast( + Optional[AsyncCustomParser[MsgType, "StreamMessage[MsgType]"]], + to_async(parser) if parser else None, + ) + self._global_decoder = cast( + Optional[AsyncCustomDecoder["StreamMessage[MsgType]"]], + to_async(decoder) if decoder else None, + ) + + # AsyncAPI information + self.url = asyncapi_url or url + self.protocol = protocol + self.protocol_version = protocol_version + self.description = description + self.tags = tags + self.security = security - def include_router(self, router: BrokerRouter[Any, MsgType]) -> None: + def include_router(self, router: "BrokerRouter[Any, MsgType]") -> None: """Includes a router in the current object. Args: @@ -91,7 +216,7 @@ def include_router(self, router: BrokerRouter[Any, MsgType]) -> None: self._publishers = {**self._publishers, **router._publishers} - def include_routers(self, *routers: BrokerRouter[Any, MsgType]) -> None: + def include_routers(self, *routers: "BrokerRouter[Any, MsgType]") -> None: """Includes routers in the current object. Args: @@ -103,7 +228,7 @@ def include_routers(self, *routers: BrokerRouter[Any, MsgType]) -> None: for r in routers: self.include_router(r) - async def __aenter__(self) -> Self: + async def __aenter__(self) -> "Self": """Enter the context manager.""" await self.connect() return self @@ -112,7 +237,7 @@ async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], - exec_tb: Optional[TracebackType], + exec_tb: Optional["TracebackType"], ) -> None: """Exit the context manager. @@ -139,8 +264,8 @@ async def start(self) -> None: await self.connect() def _abc_start(self) -> None: - if not self.started: - self.started = True + if not self.running: + self.running = True if self.logger is not None: change_logger_handlers(self.logger, self.fmt) @@ -160,7 +285,7 @@ async def connect(self, *args: Any, **kwargs: Any) -> ConnectionType: self._connection = await self._connect(**_kwargs) return self._connection - def _resolve_connection_kwargs(self, *args: Any, **kwargs: Any) -> AnyDict: + def _resolve_connection_kwargs(self, *args: Any, **kwargs: Any) -> "AnyDict": """Resolve connection keyword arguments. Args: @@ -170,7 +295,7 @@ def _resolve_connection_kwargs(self, *args: Any, **kwargs: Any) -> AnyDict: Returns: A dictionary containing the resolved connection keyword arguments. """ - arguments = get_function_positional_arguments(self.__init__) # type: ignore + arguments = get_function_positional_arguments(self.__init__) init_kwargs = { **self._connection_kwargs, **dict(zip(arguments, self._connection_args)), @@ -191,9 +316,6 @@ async def _connect(self, **kwargs: Any) -> ConnectionType: Returns: The connection object. - - Raises: - NotImplementedError: If the method is not implemented. """ raise NotImplementedError() @@ -201,7 +323,7 @@ async def close( self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None, - exec_tb: Optional[TracebackType] = None, + exec_tb: Optional["TracebackType"] = None, ) -> None: """Closes the object. @@ -212,11 +334,8 @@ async def close( Returns: None - - Raises: - NotImplementedError: If the method is not implemented. """ - self.started = False + self.running = False for h in self.handlers.values(): await h.close() @@ -227,14 +346,14 @@ async def close( @abstractmethod async def publish( self, - message: SendableMessage, + message: "SendableMessage", *args: Any, reply_to: str = "", rpc: bool = False, rpc_timeout: Optional[float] = None, raise_timeout: bool = False, **kwargs: Any, - ) -> Optional[SendableMessage]: + ) -> Any: """Publish a message. Args: @@ -255,27 +374,19 @@ async def publish( raise NotImplementedError() @abstractmethod - def subscriber( # type: ignore[override,return] + def subscriber( self, *broker_args: Any, + filter: Filter["StreamMessage[MsgType]"] = default_filter, + decoder: Optional[CustomDecoder["StreamMessage[MsgType]"]] = None, + parser: Optional[CustomParser[MsgType, "StreamMessage[MsgType]"]] = None, + dependencies: Sequence["Depends"] = (), + middlewares: Sequence[Callable[[MsgType], "BaseMiddleware"]] = (), + raw: bool = False, + no_ack: bool = False, retry: Union[bool, int] = False, - dependencies: Sequence[Depends] = (), - decoder: Optional[CustomDecoder[StreamMessage[MsgType]]] = None, - parser: Optional[CustomParser[MsgType, StreamMessage[MsgType]]] = None, - middlewares: Optional[Sequence[Callable[[MsgType], BaseMiddleware]]] = None, - filter: Filter[StreamMessage[MsgType]] = default_filter, - _raw: bool = False, - _get_dependant: Optional[Any] = None, **broker_kwargs: Any, - ) -> Callable[ - [ - Union[ - Callable[P_HandlerParams, T_HandlerReturn], - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - ] - ], - HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - ]: + ) -> "WrapperProtocol[MsgType]": """A function decorator for subscribing to a message broker. Args: @@ -286,17 +397,15 @@ def subscriber( # type: ignore[override,return] parser: Custom parser function for parsing the decoded message. middlewares: Sequence of middleware functions to be applied to the message. filter: Filter function for filtering the messages to be processed. - _raw: Whether to return the raw message instead of the processed result. - _get_dependant: Optional argument to get the dependant object. + no_ack: Disable FastStream acknowledgement behavior. + raw: Whether to return the raw message instead of the processed result. + get_dependant: Optional argument to get the dependant object. **broker_kwargs: Keyword arguments to be passed to the message broker. Returns: A callable decorator that wraps the decorated function and handles the subscription. - - Raises: - NotImplementedError: If silent animals are not supported. """ - if self.started and not is_test_env(): # pragma: no cover + if self.running and not is_test_env(): # pragma: no cover warnings.warn( "You are trying to register `handler` with already running broker\n" "It has no effect until broker restarting.", @@ -308,8 +417,8 @@ def subscriber( # type: ignore[override,return] def publisher( self, key: Any, - publisher: BasePublisher[MsgType], - ) -> BasePublisher[MsgType]: + publisher: "BasePublisher[MsgType]", + ) -> "BasePublisher[MsgType]": """Publishes a publisher. Args: @@ -321,55 +430,3 @@ def publisher( """ self._publishers = {**self._publishers, key: publisher} return publisher - - def __init__( - self, - *args: Any, - apply_types: bool = True, - validate: bool = True, - logger: Optional[logging.Logger] = access_logger, - log_level: int = logging.INFO, - log_fmt: Optional[str] = "%(asctime)s %(levelname)s - %(message)s", - dependencies: Sequence[Depends] = (), - decoder: Optional[CustomDecoder[StreamMessage[MsgType]]] = None, - parser: Optional[CustomParser[MsgType, StreamMessage[MsgType]]] = None, - middlewares: Optional[Sequence[Callable[[MsgType], BaseMiddleware]]] = None, - graceful_timeout: Optional[float] = None, - **kwargs: Any, - ) -> None: - """Initialize the class. - - Args: - *args: Variable length arguments - apply_types: Whether to apply types or not - validate: Whether to cast types using Pydantic validation. - logger: Logger object for logging - log_level: Log level for logging - log_fmt: Log format for logging - dependencies: Sequence of dependencies - decoder: Custom decoder object - parser: Custom parser object - middlewares: Sequence of middlewares - graceful_timeout: Graceful timeout - **kwargs: Keyword arguments - """ - super().__init__( - *args, - apply_types=apply_types, - validate=validate, - logger=logger, - log_level=log_level, - log_fmt=log_fmt, - dependencies=dependencies, - decoder=cast( - Optional[AsyncCustomDecoder[StreamMessage[MsgType]]], - to_async(decoder) if decoder else None, - ), - parser=cast( - Optional[AsyncCustomParser[MsgType, StreamMessage[MsgType]]], - to_async(parser) if parser else None, - ), - middlewares=middlewares, - **kwargs, - ) - self.graceful_timeout = graceful_timeout diff --git a/faststream/broker/wrapper.py b/faststream/broker/core/call_wrapper.py similarity index 78% rename from faststream/broker/wrapper.py rename to faststream/broker/core/call_wrapper.py index 5235e503a7..a1aa78989f 100644 --- a/faststream/broker/wrapper.py +++ b/faststream/broker/core/call_wrapper.py @@ -6,55 +6,13 @@ from faststream.broker.message import StreamMessage from faststream.broker.types import ( - AsyncPublisherProtocol, MsgType, P_HandlerParams, + PublisherProtocol, T_HandlerReturn, WrappedHandlerCall, WrappedReturn, ) -from faststream.types import SendableMessage - - -class FakePublisher: - """A class to represent a fake publisher. - - Attributes: - method : a callable method that takes arguments and returns an awaitable sendable message - - Methods: - publish : asynchronously publishes a message with optional correlation ID and additional keyword arguments - """ - - def __init__(self, method: Callable[..., Awaitable[SendableMessage]]) -> None: - """Initialize an object. - - Args: - method: A callable that takes any number of arguments and returns an awaitable sendable message. - """ - self.method = method - - async def publish( - self, - message: SendableMessage, - *args: Any, - correlation_id: Optional[str] = None, - **kwargs: Any, - ) -> Any: - """Publish a message. - - Args: - message: The message to be published. - *args: Additinal positional arguments. - correlation_id: Optional correlation ID for the message. - **kwargs: Additional keyword arguments. - - Returns: - The published message. - """ - return await self.method( - message, *args, correlation_id=correlation_id, **kwargs - ) class HandlerCallWrapper(Generic[MsgType, P_HandlerParams, T_HandlerReturn]): @@ -65,7 +23,7 @@ class HandlerCallWrapper(Generic[MsgType, P_HandlerParams, T_HandlerReturn]): _wrapped_call : WrappedHandlerCall object representing the wrapped handler call _original_call : original handler call - _publishers : list of AsyncPublisherProtocol objects + _publishers : list of PublisherProtocol objects Methods: __new__ : Create a new instance of the class @@ -82,7 +40,7 @@ class HandlerCallWrapper(Generic[MsgType, P_HandlerParams, T_HandlerReturn]): _wrapped_call: Optional[WrappedHandlerCall[MsgType, T_HandlerReturn]] _original_call: Callable[P_HandlerParams, T_HandlerReturn] - _publishers: List[AsyncPublisherProtocol] + _publishers: List[PublisherProtocol] __slots__ = ( "mock", @@ -226,9 +184,8 @@ def trigger( if not self.is_test: return - assert ( # nosec B101 - self.future is not None - ), "You can use this method only with TestClient" + if self.future is None: + raise ValueError("You can use this method only with TestClient") if self.future.done(): self.future = asyncio.Future() diff --git a/faststream/broker/handler.py b/faststream/broker/core/handler.py similarity index 55% rename from faststream/broker/handler.py rename to faststream/broker/core/handler.py index 0e539aa967..d673a18888 100644 --- a/faststream/broker/handler.py +++ b/faststream/broker/core/handler.py @@ -1,12 +1,12 @@ -import asyncio from abc import abstractmethod -from contextlib import AsyncExitStack, suppress +from contextlib import AsyncExitStack from dataclasses import dataclass from inspect import unwrap from logging import Logger from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, Awaitable, Callable, Dict, @@ -16,20 +16,13 @@ Sequence, Tuple, Union, - cast, ) -import anyio -from fast_depends.core import CallModel -from fast_depends.dependencies import Depends -from typing_extensions import Self, override - -from faststream._compat import IS_OPTIMIZED from faststream.asyncapi.base import AsyncAPIOperation from faststream.asyncapi.message import parse_handler_params from faststream.asyncapi.utils import to_camelcase -from faststream.broker.core.handler_wrapper import WrapHandlerMixin -from faststream.broker.middlewares import BaseMiddleware +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.core.handler_wrapper_mixin import WrapHandlerMixin from faststream.broker.types import ( AsyncDecoder, AsyncParser, @@ -38,20 +31,23 @@ Filter, MsgType, P_HandlerParams, + PublisherProtocol, T_HandlerReturn, WrappedReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper +from faststream.broker.utils import MultiLock from faststream.exceptions import HandlerException, StopConsume from faststream.types import AnyDict, SendableMessage from faststream.utils.context.repository import context from faststream.utils.functions import to_async if TYPE_CHECKING: - from contextvars import Token + from fast_depends.core import CallModel + from fast_depends.dependencies import Depends - from faststream.broker.core.handler_wrapper import WrapperProtocol + from faststream.broker.core.handler_wrapper_mixin import WrapperProtocol from faststream.broker.message import StreamMessage + from faststream.broker.middlewares import BaseMiddleware @dataclass(slots=True) @@ -62,8 +58,8 @@ class HandlerItem(Generic[MsgType]): filter: Callable[["StreamMessage[MsgType]"], Awaitable[bool]] parser: AsyncParser[MsgType, Any] decoder: AsyncDecoder["StreamMessage[MsgType]"] - middlewares: Sequence[Callable[[Any], BaseMiddleware]] - dependant: CallModel[Any, SendableMessage] + middlewares: Sequence[Callable[[Any], "BaseMiddleware"]] + dependant: "CallModel[Any, SendableMessage]" @property def call_name(self) -> str: @@ -85,6 +81,56 @@ def description(self) -> Optional[str]: description = getattr(caller, "__doc__", None) return description + async def call( + self, msg: MsgType, cache: Dict[Any, Any] + ) -> AsyncGenerator[ + Dict[str, str], + Optional["StreamMessage[MsgType]"], + ]: + message = cache[self.parser] = cache.get( + self.parser, + await self.parser(msg), + ) + message.decoded_body = cache[self.decoder] = cache.get( + self.decoder, + await self.decoder(message), + ) + + if await self.filter(message): + log_context = yield message + + result = None + async with AsyncExitStack() as consume_stack: + consume_stack.enter_context(context.scope("message", message)) + consume_stack.enter_context(context.scope("log_context", log_context)) + + for middleware in self.middlewares: + message.decoded_body = await consume_stack.enter_async_context( + middleware.consume_scope(message.decoded_body) + ) + + try: + result = await self.handler.call_wrapped(message) + + except StopConsume: + self.handler.trigger() + raise + + except HandlerException: + self.handler.trigger() + raise + + except Exception as e: + self.handler.trigger(error=e) + raise e + + else: + self.handler.trigger(result=result[0] if result else None) + yield result + + else: + yield None + class BaseHandler(AsyncAPIOperation, WrapHandlerMixin[MsgType]): """A class representing an asynchronous handler. @@ -102,7 +148,7 @@ def __init__( self, *, log_context_builder: Callable[["StreamMessage[Any]"], Dict[str, str]], - middlewares: Sequence[Callable[[MsgType], BaseMiddleware]], + middlewares: Sequence[Callable[[MsgType], "BaseMiddleware"]], logger: Optional[Logger], description: Optional[str], title: Optional[str], @@ -125,13 +171,62 @@ def __init__( self._title = title super().__init__(include_in_schema=include_in_schema) + @abstractmethod + async def start(self) -> None: + """Start the handler.""" + self.running = True + + @abstractmethod + async def close(self) -> None: + """Close the handler. + + Blocks loop up to graceful_timeout seconds. + """ + self.running = False + await self.lock.wait_release(self.graceful_timeout) + + @property + def call_name(self) -> str: + """Returns the name of the handler call.""" + return to_camelcase(self.calls[0].call_name) + + @property + def description(self) -> Optional[str]: + """Returns the description of the handler.""" + if self._description: + return self._description + + if not self.calls: # pragma: no cover + return None + + else: + return self.calls[0].description + + def get_payloads(self) -> List[Tuple[AnyDict, str]]: + """Get the payloads of the handler.""" + payloads: List[Tuple[AnyDict, str]] = [] + + for h in self.calls: + body = parse_handler_params( + h.dependant, + prefix=f"{self._title or self.call_name}:Message", + ) + payloads.append( + ( + body, + to_camelcase(h.call_name), + ) + ) + + return payloads + def add_call( self, filter_: Filter["StreamMessage[MsgType]"], parser_: CustomParser[MsgType, Any], decoder_: CustomDecoder["StreamMessage[MsgType]"], - middlewares_: Sequence[Callable[[Any], BaseMiddleware]], - dependencies_: Sequence[Depends], + middlewares_: Sequence[Callable[[Any], "BaseMiddleware"]], + dependencies_: Sequence["Depends"], **wrap_kwargs: Any, ) -> "WrapperProtocol[MsgType]": def wrapper( @@ -140,8 +235,8 @@ def wrapper( filter: Filter["StreamMessage[MsgType]"] = filter_, parser: CustomParser[MsgType, Any] = parser_, decoder: CustomDecoder["StreamMessage[MsgType]"] = decoder_, - middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), - dependencies: Sequence[Depends] = (), + middlewares: Sequence[Callable[[Any], "BaseMiddleware"]] = (), + dependencies: Sequence["Depends"] = (), ) -> Union[ HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], Callable[ @@ -150,7 +245,7 @@ def wrapper( ], ]: total_deps = (*dependencies_, *dependencies) - total_middlewares = (*middlewares_, *middlewares) + total_middlewares = (*self.middlewares, *middlewares_, *middlewares) def real_wrapper( func: Callable[P_HandlerParams, T_HandlerReturn], @@ -183,8 +278,7 @@ def real_wrapper( return wrapper - @override - async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[override] + async def consume(self, msg: MsgType) -> SendableMessage: """Consume a message asynchronously. Args: @@ -192,9 +286,6 @@ async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[overri Returns: The sendable message. - - Raises: - StopConsume: If the consumption needs to be stopped. """ result: Optional[WrappedReturn[SendableMessage]] = None result_msg: SendableMessage = None @@ -202,185 +293,54 @@ async def consume(self, msg: MsgType) -> SendableMessage: # type: ignore[overri if not self.running: return result_msg - log_context_tag: Optional["Token[Any]"] = None async with AsyncExitStack() as stack: stack.enter_context(self.lock) stack.enter_context(context.scope("handler_", self)) - gl_middlewares: List[BaseMiddleware] = [ - await stack.enter_async_context(m(msg)) for m in self.middlewares - ] + for m in self.middlewares: + await stack.enter_async_context(m(msg)) - logged = False + cache = {} processed = False for h in self.calls: - local_middlewares: List[BaseMiddleware] = [ - await stack.enter_async_context(m(msg)) for m in h.middlewares - ] + if processed: + break - all_middlewares = gl_middlewares + local_middlewares - - # TODO: add parser & decoder caches - message = await h.parser(msg) - - if not logged: # pragma: no branch - log_context_tag = context.set_local( - "log_context", - self.log_context_builder(message), - ) - - message.decoded_body = await h.decoder(message) - message.processed = processed - - if await h.filter(message): - assert ( # nosec B101 - not processed - ), "You can't process a message with multiple consumers" + caller = h.call(msg, cache) + if (message := await caller.__anext__()) is not None: + processed = True try: - async with AsyncExitStack() as consume_stack: - for m_consume in all_middlewares: - message.decoded_body = ( - await consume_stack.enter_async_context( - m_consume.consume_scope(message.decoded_body) - ) - ) - - result = await cast( - Awaitable[Optional[WrappedReturn[SendableMessage]]], - h.handler.call_wrapped(message), - ) - - if result is not None: - result_msg, pub_response = result - - # TODO: suppress all publishing errors and raise them after all publishers will be tried - for publisher in (pub_response, *h.handler._publishers): - if publisher is not None: - async with AsyncExitStack() as pub_stack: - result_to_send = result_msg - - for m_pub in all_middlewares: - result_to_send = ( - await pub_stack.enter_async_context( - m_pub.publish_scope(result_to_send) - ) - ) - - await publisher.publish( - message=result_to_send, - correlation_id=message.correlation_id, - ) - + await caller.asend(self.log_context_builder(message)) + except StopAsyncIteration as e: + result = e.value except StopConsume: await self.close() - h.handler.trigger() - - except HandlerException as e: # pragma: no cover - h.handler.trigger() - raise e - - except Exception as e: - h.handler.trigger(error=e) - raise e + return + + # TODO: suppress all publishing errors and raise them after all publishers will be tried + for publisher in ( + *self.make_response_publisher(message), + *h.handler._publishers, + ): + if publisher is not None: + async with AsyncExitStack() as pub_stack: + for m_pub in h.middlewares: + result = await pub_stack.enter_async_context( + m_pub.publish_scope(result) + ) - else: - h.handler.trigger(result=result[0] if result else None) - message.processed = processed = True - if IS_OPTIMIZED: # pragma: no cover - break + await publisher.publish( + message=result, + correlation_id=message.correlation_id, + ) assert not self.running or processed, "You have to consume message" # nosec B101 - if log_context_tag is not None: - context.reset_local("log_context", log_context_tag) - return result_msg - @abstractmethod - async def start(self) -> None: - """Start the handler.""" - self.running = True - - @abstractmethod - async def close(self) -> None: - """Close the handler. - - Blocks loop up to graceful_timeout seconds. - """ - self.running = False - await self.lock.wait_release(self.graceful_timeout) - - @property - def call_name(self) -> str: - """Returns the name of the handler call.""" - return to_camelcase(self.calls[0].call_name) - - @property - def description(self) -> Optional[str]: - """Returns the description of the handler.""" - if self._description: - return self._description - - if not self.calls: # pragma: no cover - return None - - else: - return self.calls[0].description - - def get_payloads(self) -> List[Tuple[AnyDict, str]]: - """Get the payloads of the handler.""" - payloads: List[Tuple[AnyDict, str]] = [] - - for h in self.calls: - body = parse_handler_params( - h.dependant, - prefix=f"{self._title or self.call_name}:Message", - ) - payloads.append( - ( - body, - to_camelcase(h.call_name), - ) - ) - - return payloads - - -class MultiLock: - """A class representing a multi lock.""" - - def __init__(self) -> None: - """Initialize a new instance of the class.""" - self.queue: "asyncio.Queue[None]" = asyncio.Queue() - - def __enter__(self) -> Self: - """Enter the context.""" - self.queue.put_nowait(None) - return self - - def __exit__(self, *args: Any, **kwargs: Any) -> None: - """Exit the context.""" - with suppress(asyncio.QueueEmpty, ValueError): - self.queue.get_nowait() - self.queue.task_done() - - @property - def qsize(self) -> int: - """Return the size of the queue.""" - return self.queue.qsize() - - @property - def empty(self) -> bool: - """Return whether the queue is empty.""" - return self.queue.empty() - - async def wait_release(self, timeout: Optional[float] = None) -> None: - """Wait for the queue to be released. - - Using for graceful shutdown. - """ - if timeout: - with anyio.move_on_after(timeout): - await self.queue.join() + def make_response_publisher( + self, message: "StreamMessage[MsgType]" + ) -> Sequence[PublisherProtocol]: + raise NotImplementedError() diff --git a/faststream/broker/core/handler_wrapper.py b/faststream/broker/core/handler_wrapper_mixin.py similarity index 98% rename from faststream/broker/core/handler_wrapper.py rename to faststream/broker/core/handler_wrapper_mixin.py index ddcaf70675..fd901991f6 100644 --- a/faststream/broker/core/handler_wrapper.py +++ b/faststream/broker/core/handler_wrapper_mixin.py @@ -21,6 +21,7 @@ from pydantic import create_model from faststream._compat import PYDANTIC_V2 +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.middlewares import BaseMiddleware from faststream.broker.push_back_watcher import WatcherContext from faststream.broker.types import ( @@ -32,8 +33,7 @@ T_HandlerReturn, WrappedReturn, ) -from faststream.broker.utils import get_watcher, set_message_context -from faststream.broker.wrapper import HandlerCallWrapper +from faststream.broker.utils import get_watcher from faststream.types import F_Return, F_Spec from faststream.utils.functions import fake_context, to_async @@ -113,8 +113,8 @@ def wrap_handler( logger: Optional[Logger], apply_types: bool, is_validate: bool, - no_ack: bool = False, raw: bool = False, + no_ack: bool = False, retry: Union[bool, int] = False, get_dependant: Optional[Any] = None, **process_kwargs: Any, @@ -171,7 +171,6 @@ def wrap_handler( **(process_kwargs or {}), ) - f = set_message_context(f) handler_call.set_wrapped(f) return handler_call, dependant diff --git a/faststream/broker/core/mixins.py b/faststream/broker/core/logging_mixin.py similarity index 99% rename from faststream/broker/core/mixins.py rename to faststream/broker/core/logging_mixin.py index d382eb6ad4..7a71caa936 100644 --- a/faststream/broker/core/mixins.py +++ b/faststream/broker/core/logging_mixin.py @@ -19,7 +19,6 @@ class LoggingMixin: fmt : getter method for _fmt attribute _get_log_context : returns a dictionary with log context information _log : logs a message with optional log level, extra data, and exception info - """ def __init__( @@ -41,7 +40,6 @@ def __init__( Returns: None - """ self.logger = logger self.log_level = log_level @@ -67,7 +65,6 @@ def _get_log_context( Returns: A dictionary containing the log context with the following keys: - message_id: The first 10 characters of the message_id if message is not None, otherwise an empty string - """ return { "message_id": message.message_id[: self._message_id_ln] if message else "", @@ -90,7 +87,6 @@ def _log( Returns: None - """ if self.logger is not None: self.logger.log( diff --git a/faststream/broker/publisher.py b/faststream/broker/core/publisher.py similarity index 76% rename from faststream/broker/publisher.py rename to faststream/broker/core/publisher.py index a3db1791f7..e59ce8358b 100644 --- a/faststream/broker/publisher.py +++ b/faststream/broker/core/publisher.py @@ -1,7 +1,7 @@ from abc import abstractmethod from dataclasses import dataclass, field from inspect import unwrap -from typing import Any, Callable, Generic, List, Optional, Tuple +from typing import Any, Awaitable, Callable, Generic, List, Optional, Tuple from unittest.mock import MagicMock from fast_depends._compat import create_model, get_config_base @@ -10,11 +10,55 @@ from faststream.asyncapi.base import AsyncAPIOperation from faststream.asyncapi.message import get_response_schema from faststream.asyncapi.utils import to_camelcase +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.types import MsgType, P_HandlerParams, T_HandlerReturn -from faststream.broker.wrapper import HandlerCallWrapper from faststream.types import AnyDict, SendableMessage +class FakePublisher: + """A class to represent a fake publisher. + + Attributes: + method : a callable method that takes arguments and returns an awaitable sendable message + + Methods: + publish : asynchronously publishes a message with optional correlation ID and additional keyword arguments + """ + + def __init__(self, method: Callable[..., Awaitable[SendableMessage]]) -> None: + """Initialize an object. + + Args: + method: A callable that takes any number of arguments and returns an awaitable sendable message. + """ + self.method = method + + async def publish( + self, + message: SendableMessage, + *args: Any, + correlation_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Publish a message. + + Args: + message: The message to be published. + *args: Additinal positional arguments. + correlation_id: Optional correlation ID for the message. + **kwargs: Additional keyword arguments. + + Returns: + The published message. + """ + return await self.method( + message, + *args, + correlation_id=correlation_id, + **kwargs, + ) + + @dataclass class BasePublisher(AsyncAPIOperation, Generic[MsgType]): """A base class for publishers in an asynchronous API. diff --git a/faststream/broker/fastapi/route.py b/faststream/broker/fastapi/route.py index 3218b9ddbe..141da5de7c 100644 --- a/faststream/broker/fastapi/route.py +++ b/faststream/broker/fastapi/route.py @@ -28,11 +28,11 @@ from starlette.routing import BaseRoute from faststream._compat import FASTAPI_V106, raise_fastapi_validation_error -from faststream.broker.core.asynchronous import BrokerAsyncUsecase +from faststream.broker.core.broker import BrokerUsecase +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.message import StreamMessage as NativeMessage from faststream.broker.schemas import NameRequired from faststream.broker.types import MsgType, P_HandlerParams, T_HandlerReturn -from faststream.broker.wrapper import HandlerCallWrapper from faststream.types import AnyDict, SendableMessage @@ -42,7 +42,7 @@ class StreamRoute(BaseRoute, Generic[MsgType, P_HandlerParams, T_HandlerReturn]) Attributes: handler : HandlerCallWrapper object representing the handler for the route path : path of the route - broker : BrokerAsyncUsecase object representing the broker for the route + broker : BrokerUsecase object representing the broker for the route dependant : Dependable object representing the dependencies for the route """ @@ -56,7 +56,7 @@ def __init__( Callable[P_HandlerParams, T_HandlerReturn], HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], ], - broker: BrokerAsyncUsecase[MsgType, Any], + broker: BrokerUsecase[MsgType, Any], dependencies: Sequence[params.Depends] = (), dependency_overrides_provider: Optional[Any] = None, **handle_kwargs: Any, diff --git a/faststream/broker/fastapi/router.py b/faststream/broker/fastapi/router.py index be883bac64..5d606dbc50 100644 --- a/faststream/broker/fastapi/router.py +++ b/faststream/broker/fastapi/router.py @@ -32,12 +32,12 @@ from faststream.asyncapi import schema as asyncapi from faststream.asyncapi.schema import Schema from faststream.asyncapi.site import get_asyncapi_html -from faststream.broker.core.asynchronous import BrokerAsyncUsecase +from faststream.broker.core.broker import BrokerUsecase +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.core.publisher import BasePublisher from faststream.broker.fastapi.route import StreamRoute -from faststream.broker.publisher import BasePublisher from faststream.broker.schemas import NameRequired from faststream.broker.types import MsgType, P_HandlerParams, T_HandlerReturn -from faststream.broker.wrapper import HandlerCallWrapper from faststream.types import AnyDict from faststream.utils.functions import to_async @@ -70,8 +70,8 @@ class StreamRouter(APIRouter, Generic[MsgType]): _setup_log_context : setup log context for the broker """ - broker_class: Type[BrokerAsyncUsecase[MsgType, Any]] - broker: BrokerAsyncUsecase[MsgType, Any] + broker_class: Type[BrokerUsecase[MsgType, Any]] + broker: BrokerUsecase[MsgType, Any] docs_router: Optional[APIRouter] _after_startup_hooks: List[ Callable[[AppType], Awaitable[Optional[Mapping[str, Any]]]] @@ -592,8 +592,8 @@ def include_router( @staticmethod @abstractmethod def _setup_log_context( - main_broker: BrokerAsyncUsecase[MsgType, Any], - including_broker: BrokerAsyncUsecase[MsgType, Any], + main_broker: BrokerUsecase[MsgType, Any], + including_broker: BrokerUsecase[MsgType, Any], ) -> None: """Set up log context. diff --git a/faststream/broker/message.py b/faststream/broker/message.py index 73de348f0a..e856a5c5f6 100644 --- a/faststream/broker/message.py +++ b/faststream/broker/message.py @@ -4,12 +4,12 @@ from faststream.types import AnyDict, DecodedMessage -Msg = TypeVar("Msg") +MsgType = TypeVar("MsgType") @dataclass -class ABCStreamMessage(Generic[Msg]): - """A generic class to represent a stream message. +class StreamMessage(Generic[MsgType]): + """Generic class to represent a stream message. Attributes: raw_message : the raw message @@ -21,10 +21,9 @@ class ABCStreamMessage(Generic[Msg]): message_id : the unique identifier of the message correlation_id : the correlation identifier of the message processed : a flag indicating whether the message has been processed or not - """ - raw_message: Msg + raw_message: "MsgType" body: Union[bytes, Any] decoded_body: Optional[DecodedMessage] = None @@ -41,23 +40,6 @@ class ABCStreamMessage(Generic[Msg]): processed: bool = field(default=False, init=False) committed: bool = field(default=False, init=False) - -class SyncStreamMessage(ABCStreamMessage[Msg]): - """A generic class to represent a stream message.""" - - def ack(self, **kwargs: Any) -> None: - self.committed = True - - def nack(self, **kwargs: Any) -> None: - self.committed = True - - def reject(self, **kwargs: Any) -> None: - self.committed = True - - -class StreamMessage(ABCStreamMessage[Msg]): - """A generic class to represent a stream message.""" - async def ack(self, **kwargs: Any) -> None: self.committed = True diff --git a/faststream/broker/middlewares.py b/faststream/broker/middlewares.py index 3497b54152..ca8f634720 100644 --- a/faststream/broker/middlewares.py +++ b/faststream/broker/middlewares.py @@ -243,11 +243,13 @@ def __call__(self, msg: Any) -> Self: """ return self - async def on_receive(self) -> DecodedMessage: + async def on_consume(self, msg: DecodedMessage) -> DecodedMessage: if self.logger is not None: c = context.get_local("log_context") self.logger.log(self.log_level, "Received", extra=c) + return await super().on_consume(msg) + async def after_processed( self, exc_type: Optional[Type[BaseException]] = None, diff --git a/faststream/broker/push_back_watcher.py b/faststream/broker/push_back_watcher.py index 5eaed0b4b3..548348c804 100644 --- a/faststream/broker/push_back_watcher.py +++ b/faststream/broker/push_back_watcher.py @@ -2,11 +2,9 @@ from collections import Counter from logging import Logger from types import TracebackType -from typing import Any, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Optional, Type from typing import Counter as CounterType -from faststream.broker.message import StreamMessage, SyncStreamMessage -from faststream.broker.types import MsgType from faststream.exceptions import ( AckMessage, HandlerException, @@ -16,6 +14,10 @@ ) from faststream.utils.functions import call_or_await +if TYPE_CHECKING: + from faststream.broker.message import StreamMessage + from faststream.broker.types import MsgType + class BaseWatcher(ABC): """A base class for a watcher. @@ -31,7 +33,6 @@ class BaseWatcher(ABC): add : add a message to the watcher is_max : check if the maximum number of tries has been reached for a message remove : remove a message from the watcher - """ max_tries: int @@ -46,10 +47,6 @@ def __init__( Args: max_tries: Maximum number of tries allowed logger: Optional logger object - - Raises: - NotImplementedError: If the method is not implemented in the subclass. - """ self.logger = logger self.max_tries = max_tries @@ -66,7 +63,6 @@ def add(self, message_id: str) -> None: Raises: NotImplementedError: If the method is not implemented - """ raise NotImplementedError() @@ -82,7 +78,6 @@ def is_max(self, message_id: str) -> bool: Raises: NotImplementedError: This method is meant to be overridden by subclasses. - """ raise NotImplementedError() @@ -98,7 +93,6 @@ def remove(self, message_id: str) -> None: Raises: NotImplementedError: If the method is not implemented - """ raise NotImplementedError() @@ -114,7 +108,6 @@ def add(self, message_id: str) -> None: Returns: None - """ pass @@ -125,8 +118,7 @@ def is_max(self, message_id: str) -> bool: message_id: ID of the message to check Returns: - True if the message is the maximum, False otherwise - + Always False """ return False @@ -138,7 +130,6 @@ def remove(self, message_id: str) -> None: Returns: None - """ pass @@ -154,7 +145,6 @@ def add(self, message_id: str) -> None: Returns: None - """ pass @@ -165,8 +155,7 @@ def is_max(self, message_id: str) -> bool: message_id: The ID of the message to check. Returns: - True if the given message ID is the maximum, False otherwise. - + Always True """ return True @@ -178,7 +167,6 @@ def remove(self, message_id: str) -> None: Returns: None - """ pass @@ -198,7 +186,6 @@ class CounterWatcher(BaseWatcher): add(self, message_id: str) -> None - adds a message to the counter is_max(self, message_id: str) -> bool - checks if the count of a message has reached the maximum tries remove(self, message: str) -> None - removes a message from the counter - """ memory: CounterType[str] @@ -213,7 +200,6 @@ def __init__( Args: max_tries (int): maximum number of tries logger (Optional[Logger]): logger object (default: None) - """ super().__init__(logger=logger, max_tries=max_tries) self.memory = Counter() @@ -226,7 +212,6 @@ def add(self, message_id: str) -> None: Returns: None - """ self.memory[message_id] += 1 @@ -238,7 +223,6 @@ def is_max(self, message_id: str) -> bool: Returns: True if the number of tries has exceeded the maximum allowed tries, False otherwise - """ is_max = self.memory[message_id] > self.max_tries if self.logger is not None: @@ -256,7 +240,6 @@ def remove(self, message: str) -> None: Returns: None - """ self.memory[message] = 0 self.memory += Counter() @@ -276,12 +259,11 @@ class WatcherContext: __ack : acknowledges the message __nack : negatively acknowledges the message __reject : rejects the message - """ def __init__( self, - message: Union[SyncStreamMessage[MsgType], StreamMessage[MsgType]], + message: "StreamMessage[MsgType]", watcher: BaseWatcher, **extra_ack_args: Any, ) -> None: @@ -289,14 +271,13 @@ def __init__( Args: watcher: An instance of BaseWatcher. - message: An instance of SyncStreamMessage or StreamMessage. + message: An instance of StreamMessage. **extra_ack_args: Additional arguments for acknowledgement. Attributes: watcher: An instance of BaseWatcher. - message: An instance of SyncStreamMessage or StreamMessage. + message: An instance of StreamMessage. extra_ack_args: Additional arguments for acknowledgement. - """ self.watcher = watcher self.message = message @@ -320,7 +301,6 @@ async def __aexit__( Returns: A boolean indicating whether the exit was successful or not. - """ if not exc_type: await self.__ack() diff --git a/faststream/broker/router.py b/faststream/broker/router.py index 4f81e26485..6673a9e162 100644 --- a/faststream/broker/router.py +++ b/faststream/broker/router.py @@ -14,8 +14,9 @@ from fast_depends.dependencies import Depends +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.core.publisher import BasePublisher from faststream.broker.message import StreamMessage -from faststream.broker.publisher import BasePublisher from faststream.broker.types import ( CustomDecoder, CustomParser, @@ -23,7 +24,6 @@ P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.types import AnyDict, SendableMessage PublisherKeyType = TypeVar("PublisherKeyType") @@ -41,7 +41,6 @@ class BrokerRoute(Generic[MsgType, T_HandlerReturn]): call : callable object representing the route *args : variable length arguments for the route **kwargs : variable length keyword arguments for the route - """ call: Callable[..., T_HandlerReturn] @@ -60,7 +59,6 @@ def __init__( call: A callable object. *args: Positional arguments to be passed to the callable object. **kwargs: Keyword arguments to be passed to the callable object. - """ self.call = call self.args = args @@ -84,7 +82,6 @@ class BrokerRouter(Generic[PublisherKeyType, MsgType]): publisher : abstract method to define a publisher include_router : method to include a router include_routers : method to include multiple routers - """ prefix: str @@ -106,7 +103,6 @@ def _get_publisher_key(publisher: BasePublisher[MsgType]) -> PublisherKeyType: Raises: NotImplementedError: This function is not implemented. - """ raise NotImplementedError() @@ -158,7 +154,6 @@ def __init__( parser (Optional[CustomParser[MsgType]]): Parser for the object. decoder (Optional[CustomDecoder[StreamMessage[MsgType]]]): Decoder for the object. include_in_schema (Optional[bool]): Whether to include the object in the schema. - """ self.prefix = prefix self.include_in_schema = include_in_schema @@ -208,7 +203,6 @@ def subscriber( Raises: NotImplementedError: If the function is not implemented - """ raise NotImplementedError() @@ -245,8 +239,6 @@ def _wrap_subscriber( Returns: A callable object that wraps the decorated function - - This function is decorated with `@abstractmethod`, indicating that it is an abstract method and must be implemented by any subclass. """ def router_subscriber_wrapper( @@ -270,11 +262,7 @@ def router_subscriber_wrapper( middlewares=(*(self._middlewares or ()), *(middlewares or ())), parser=parser or self._parser, decoder=decoder or self._decoder, - include_in_schema=( - include_in_schema - if self.include_in_schema is None - else self.include_in_schema - ), + include_in_schema=self.solve_include_in_schema(include_in_schema), **kwargs, ) self._handlers.append(route) @@ -282,6 +270,12 @@ def router_subscriber_wrapper( return router_subscriber_wrapper + def solve_include_in_schema(self, include_in_schema: bool) -> bool: + if self.include_in_schema is None: + return include_in_schema + else: + return self.include_in_schema + @abstractmethod def publisher( self, @@ -301,7 +295,6 @@ def publisher( Raises: NotImplementedError: If the method is not implemented - """ raise NotImplementedError() @@ -313,7 +306,6 @@ def include_router(self, router: "BrokerRouter[PublisherKeyType, MsgType]") -> N Returns: None - """ for h in router._handlers: self.subscriber(*h.args, **h.kwargs)(h.call) @@ -321,6 +313,7 @@ def include_router(self, router: "BrokerRouter[PublisherKeyType, MsgType]") -> N for p in router._publishers.values(): p = self._update_publisher_prefix(self.prefix, p) key = self._get_publisher_key(p) + p.include_in_schema = self.solve_include_in_schema(p.include_in_schema) self._publishers[key] = self._publishers.get(key, p) def include_routers( @@ -333,7 +326,6 @@ def include_routers( Returns: None - """ for r in routers: self.include_router(r) diff --git a/faststream/broker/schemas.py b/faststream/broker/schemas.py index b575179388..99911d645e 100644 --- a/faststream/broker/schemas.py +++ b/faststream/broker/schemas.py @@ -1,6 +1,6 @@ from typing import Any, Optional, Type, TypeVar, Union, overload -from pydantic import BaseModel, Field, Json +from pydantic import BaseModel, Field Cls = TypeVar("Cls") NameRequiredCls = TypeVar("NameRequiredCls", bound="NameRequired") @@ -114,14 +114,3 @@ def validate( if value is not None and isinstance(value, str): value = cls(value, **kwargs) return value - - -class RawDecoced(BaseModel): - """A class to represent a raw decoded message. - - Attributes: - message : the decoded message, which can be either a JSON object or a string - - """ - - message: Union[Json[Any], str] diff --git a/faststream/broker/test.py b/faststream/broker/test.py index 6d917d451c..114fab6384 100644 --- a/faststream/broker/test.py +++ b/faststream/broker/test.py @@ -9,15 +9,14 @@ from anyio.from_thread import start_blocking_portal from faststream.app import FastStream -from faststream.broker.core.abc import BrokerUsecase -from faststream.broker.core.asynchronous import BrokerAsyncUsecase -from faststream.broker.handler import BaseHandler -from faststream.broker.wrapper import HandlerCallWrapper +from faststream.broker.core.broker import BrokerUsecase +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.core.handler import BaseHandler from faststream.types import SendableMessage, SettingField from faststream.utils.ast import is_contains_context_name from faststream.utils.functions import timeout_scope -Broker = TypeVar("Broker", bound=BrokerAsyncUsecase[Any, Any]) +Broker = TypeVar("Broker", bound=BrokerUsecase[Any, Any]) class TestApp: @@ -33,7 +32,6 @@ class TestApp: __init__ : initializes the TestApp object __aenter__ : enters the asynchronous context and starts the FastStream application __aexit__ : exits the asynchronous context and stops the FastStream application - """ __test__ = False @@ -54,7 +52,6 @@ def __init__( Returns: None - """ self.app = app self._extra_options = run_extra_options or {} @@ -127,7 +124,6 @@ def __init__( broker: An instance of the Broker class. with_real: Whether to use a real broker. connect_only: Whether to only connect to the broker. - """ self.with_real = with_real self.broker = broker @@ -290,7 +286,6 @@ async def call_handler( Raises: TimeoutError: If the RPC times out and `raise_timeout` is True. - """ with timeout_scope(rpc_timeout, raise_timeout): result = await handler.consume(message) diff --git a/faststream/broker/types.py b/faststream/broker/types.py index 0167e9f762..9c2802ddf7 100644 --- a/faststream/broker/types.py +++ b/faststream/broker/types.py @@ -74,7 +74,7 @@ ) -class AsyncPublisherProtocol(Protocol): +class PublisherProtocol(Protocol): """A protocol for an asynchronous publisher.""" async def publish( @@ -98,7 +98,7 @@ async def publish( ... -WrappedReturn: TypeAlias = Tuple[T_HandlerReturn, Optional[AsyncPublisherProtocol]] +WrappedReturn: TypeAlias = Tuple[T_HandlerReturn, Optional[PublisherProtocol]] AsyncWrappedHandlerCall: TypeAlias = Callable[ [StreamMessage[MsgType]], diff --git a/faststream/broker/utils.py b/faststream/broker/utils.py index 87bc32a2f5..6c830f0b29 100644 --- a/faststream/broker/utils.py +++ b/faststream/broker/utils.py @@ -1,19 +1,24 @@ -import logging -from functools import wraps -from typing import Awaitable, Callable, Optional, Union +import asyncio +from contextlib import suppress +from typing import TYPE_CHECKING, Optional, Type, Union + +import anyio -from faststream.broker.message import StreamMessage from faststream.broker.push_back_watcher import ( BaseWatcher, CounterWatcher, EndlessWatcher, OneTryWatcher, ) -from faststream.broker.types import MsgType, T_HandlerReturn, WrappedReturn -from faststream.utils import context + +if TYPE_CHECKING: + from logging import Logger + from types import TracebackType + + from typing_extensions import Self -def change_logger_handlers(logger: logging.Logger, fmt: str) -> None: +def change_logger_handlers(logger: "Logger", fmt: str) -> None: """Change the formatter of the logger handlers. Args: @@ -36,7 +41,7 @@ def change_logger_handlers(logger: logging.Logger, fmt: str) -> None: def get_watcher( - logger: Optional[logging.Logger], + logger: Optional["Logger"], try_number: Union[bool, int] = True, ) -> BaseWatcher: """Get a watcher object based on the provided parameters. @@ -62,34 +67,44 @@ def get_watcher( return watcher -def set_message_context( - func: Callable[ - [StreamMessage[MsgType]], - Awaitable[WrappedReturn[T_HandlerReturn]], - ], -) -> Callable[[StreamMessage[MsgType]], Awaitable[WrappedReturn[T_HandlerReturn]]]: - """Sets the message context for a function. +class MultiLock: + """A class representing a multi lock.""" - Args: - func: The function to set the message context for. + def __init__(self) -> None: + """Initialize a new instance of the class.""" + self.queue: "asyncio.Queue[None]" = asyncio.Queue() - Returns: - The function with the message context set. - """ + def __enter__(self) -> "Self": + """Enter the context.""" + self.queue.put_nowait(None) + return self - @wraps(func) - async def set_message_wrapper( - message: StreamMessage[MsgType], - ) -> WrappedReturn[T_HandlerReturn]: - """Wraps a function that handles a stream message. + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exec_tb: Optional["TracebackType"], + ) -> None: + """Exit the context.""" + with suppress(asyncio.QueueEmpty, ValueError): + self.queue.get_nowait() + self.queue.task_done() - Args: - message: The stream message to be handled. + @property + def qsize(self) -> int: + """Return the size of the queue.""" + return self.queue.qsize() - Returns: - The wrapped return value of the handler function. - """ - with context.scope("message", message): - return await func(message) + @property + def empty(self) -> bool: + """Return whether the queue is empty.""" + return self.queue.empty() - return set_message_wrapper + async def wait_release(self, timeout: Optional[float] = None) -> None: + """Wait for the queue to be released. + + Using for graceful shutdown. + """ + if timeout: + with anyio.move_on_after(timeout): + await self.queue.join() diff --git a/faststream/kafka/broker.py b/faststream/kafka/broker.py index 0a39c16a82..6ce444ea51 100644 --- a/faststream/kafka/broker.py +++ b/faststream/kafka/broker.py @@ -23,19 +23,20 @@ from typing_extensions import override from faststream.__about__ import __version__ -from faststream.broker.core.asynchronous import BrokerAsyncUsecase, default_filter +from faststream.broker.core.broker import BrokerUsecase, default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.core.publisher import FakePublisher from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( - AsyncPublisherProtocol, CustomDecoder, CustomParser, Filter, P_HandlerParams, + PublisherProtocol, T_HandlerReturn, WrappedReturn, ) -from faststream.broker.wrapper import FakePublisher, HandlerCallWrapper from faststream.exceptions import NOT_CONNECTED_YET from faststream.kafka.asyncapi import Handler, Publisher from faststream.kafka.message import KafkaMessage @@ -50,11 +51,11 @@ class KafkaBroker( KafkaLoggingMixin, - BrokerAsyncUsecase[aiokafka.ConsumerRecord, ConsumerConnectionParams], + BrokerUsecase[aiokafka.ConsumerRecord, ConsumerConnectionParams], ): """KafkaBroker is a class for managing Kafka message consumption and publishing. - It extends BrokerAsyncUsecase to handle asynchronous operations. + It extends BrokerUsecase to handle asynchronous operations. Args: bootstrap_servers (Union[str, Iterable[str]]): Kafka bootstrap server(s). @@ -233,7 +234,7 @@ async def process_wrapper( async with watcher(message): r = await func(message) - pub_response: Optional[AsyncPublisherProtocol] + pub_response: Optional[PublisherProtocol] if message.reply_to: pub_response = FakePublisher( partial(self.publish, topic=message.reply_to) diff --git a/faststream/kafka/broker.pyi b/faststream/kafka/broker.pyi index 625276cf94..b3584a6198 100644 --- a/faststream/kafka/broker.pyi +++ b/faststream/kafka/broker.pyi @@ -23,7 +23,8 @@ from typing_extensions import override from faststream.__about__ import __version__ from faststream.asyncapi import schema as asyncapi -from faststream.broker.core.asynchronous import BrokerAsyncUsecase, default_filter +from faststream.broker.core.broker import BrokerUsecase, default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.security import BaseSecurity @@ -35,7 +36,6 @@ from faststream.broker.types import ( T_HandlerReturn, WrappedReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.kafka.asyncapi import Handler, Publisher from faststream.kafka.message import KafkaMessage from faststream.kafka.producer import AioKafkaFastProducer @@ -48,7 +48,7 @@ Partition = TypeVar("Partition") class KafkaBroker( KafkaLoggingMixin, - BrokerAsyncUsecase[aiokafka.ConsumerRecord, ConsumerConnectionParams], + BrokerUsecase[aiokafka.ConsumerRecord, ConsumerConnectionParams], ): handlers: dict[str, Handler] _publishers: dict[str, Publisher] diff --git a/faststream/kafka/fastapi.pyi b/faststream/kafka/fastapi.pyi index 8ffc35de03..c4bfce48b1 100644 --- a/faststream/kafka/fastapi.pyi +++ b/faststream/kafka/fastapi.pyi @@ -31,7 +31,8 @@ from typing_extensions import override from faststream.__about__ import __version__ from faststream.asyncapi import schema as asyncapi -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.fastapi.router import StreamRouter from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware @@ -43,7 +44,6 @@ from faststream.broker.types import ( P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.kafka.asyncapi import Publisher from faststream.kafka.broker import KafkaBroker from faststream.kafka.message import KafkaMessage diff --git a/faststream/kafka/handler.py b/faststream/kafka/handler.py index fca5fa1ea5..f497cff399 100644 --- a/faststream/kafka/handler.py +++ b/faststream/kafka/handler.py @@ -9,7 +9,8 @@ from typing_extensions import Unpack, override from faststream.__about__ import __version__ -from faststream.broker.handler import BaseHandler +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.core.handler import BaseHandler from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.parsers import resolve_custom_func @@ -20,7 +21,6 @@ P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.kafka.message import KafkaMessage from faststream.kafka.parser import AioKafkaParser from faststream.kafka.shared.schemas import ConsumerConnectionParams diff --git a/faststream/kafka/router.pyi b/faststream/kafka/router.pyi index b50eb4c4b6..992f030bf6 100644 --- a/faststream/kafka/router.pyi +++ b/faststream/kafka/router.pyi @@ -6,7 +6,8 @@ from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from typing_extensions import override -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.middlewares import BaseMiddleware from faststream.broker.router import BrokerRouter from faststream.broker.types import ( @@ -16,7 +17,6 @@ from faststream.broker.types import ( P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.kafka.asyncapi import Publisher from faststream.kafka.message import KafkaMessage from faststream.kafka.shared.router import KafkaRoute diff --git a/faststream/kafka/shared/logging.py b/faststream/kafka/shared/logging.py index 0ff05c044c..f1192f865e 100644 --- a/faststream/kafka/shared/logging.py +++ b/faststream/kafka/shared/logging.py @@ -4,7 +4,7 @@ from aiokafka import ConsumerRecord from typing_extensions import override -from faststream.broker.core.mixins import LoggingMixin +from faststream.broker.core.logging_mixin import LoggingMixin from faststream.broker.message import StreamMessage from faststream.log import access_logger from faststream.types import AnyDict diff --git a/faststream/kafka/shared/publisher.py b/faststream/kafka/shared/publisher.py index 3f5f52fe6e..ee072bd03d 100644 --- a/faststream/kafka/shared/publisher.py +++ b/faststream/kafka/shared/publisher.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Dict, Optional -from faststream.broker.publisher import BasePublisher +from faststream.broker.core.publisher import BasePublisher from faststream.broker.types import MsgType diff --git a/faststream/kafka/shared/router.py b/faststream/kafka/shared/router.py index 7bde2f4bb9..2c71449879 100644 --- a/faststream/kafka/shared/router.py +++ b/faststream/kafka/shared/router.py @@ -2,10 +2,10 @@ from aiokafka import ConsumerRecord +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.router import BrokerRoute as KafkaRoute from faststream.broker.router import BrokerRouter from faststream.broker.types import P_HandlerParams, T_HandlerReturn -from faststream.broker.wrapper import HandlerCallWrapper from faststream.types import SendableMessage __all__ = ( diff --git a/faststream/kafka/shared/router.pyi b/faststream/kafka/shared/router.pyi index 42b6100e81..1698cafd21 100644 --- a/faststream/kafka/shared/router.pyi +++ b/faststream/kafka/shared/router.pyi @@ -5,7 +5,7 @@ from fast_depends.dependencies import Depends from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import CustomDecoder, CustomParser, Filter, T_HandlerReturn diff --git a/faststream/kafka/test.py b/faststream/kafka/test.py index 8a1d0df548..1c7f248430 100644 --- a/faststream/kafka/test.py +++ b/faststream/kafka/test.py @@ -5,9 +5,9 @@ from aiokafka import ConsumerRecord from typing_extensions import override +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.parsers import encode_message from faststream.broker.test import TestBroker, call_handler -from faststream.broker.wrapper import HandlerCallWrapper from faststream.kafka.asyncapi import Publisher from faststream.kafka.broker import KafkaBroker from faststream.kafka.producer import AioKafkaFastProducer diff --git a/faststream/nats/broker.py b/faststream/nats/broker.py index 305566fe92..3512db0d5b 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker.py @@ -30,7 +30,7 @@ ) from typing_extensions import TypeAlias, override -from faststream.broker.core.asynchronous import BrokerAsyncUsecase, default_filter +from faststream.broker.core.broker import BrokerUsecase, default_filter from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( CustomDecoder, @@ -53,12 +53,12 @@ Subject: TypeAlias = str if TYPE_CHECKING: - from faststream.broker.handler import WrapperProtocol + from faststream.broker.core.handler import WrapperProtocol class NatsBroker( NatsLoggingMixin, - BrokerAsyncUsecase[Msg, Client], + BrokerUsecase[Msg, Client], ): """A class to represent a NATS broker.""" @@ -371,6 +371,7 @@ def subscriber( # type: ignore[override] subject=subject, queue=queue, ), + producer=self, ), ) @@ -386,7 +387,6 @@ def subscriber( # type: ignore[override] # wrapper kwargs is_validate=self._is_validate, apply_types=self._is_apply_types, - producer=self, **wrapper_kwargs, ) diff --git a/faststream/nats/broker.pyi b/faststream/nats/broker.pyi index aa2ee045cb..aab6b43fdf 100644 --- a/faststream/nats/broker.pyi +++ b/faststream/nats/broker.pyi @@ -33,8 +33,8 @@ from nats.js.client import JetStreamContext from typing_extensions import override from faststream.asyncapi import schema as asyncapi -from faststream.broker.core.asynchronous import BrokerAsyncUsecase, default_filter -from faststream.broker.handler import WrapperProtocol +from faststream.broker.core.broker import BrokerUsecase, default_filter +from faststream.broker.core.handler import WrapperProtocol from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( @@ -57,7 +57,7 @@ Subject = str class NatsBroker( NatsLoggingMixin, - BrokerAsyncUsecase[Msg, Client], + BrokerUsecase[Msg, Client], ): stream: JetStreamContext | None diff --git a/faststream/nats/fastapi.pyi b/faststream/nats/fastapi.pyi index e7663e107d..c4bb643c3d 100644 --- a/faststream/nats/fastapi.pyi +++ b/faststream/nats/fastapi.pyi @@ -38,7 +38,8 @@ from starlette.types import ASGIApp, AppType, Lifespan from typing_extensions import override from faststream.asyncapi import schema as asyncapi -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.fastapi.router import StreamRouter from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( @@ -48,7 +49,6 @@ from faststream.broker.types import ( P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.nats.asyncapi import Publisher from faststream.nats.broker import NatsBroker from faststream.nats.js_stream import JStream diff --git a/faststream/nats/handler.py b/faststream/nats/handler.py index e201ce3e78..990b741703 100644 --- a/faststream/nats/handler.py +++ b/faststream/nats/handler.py @@ -26,29 +26,27 @@ from nats.js import JetStreamContext from typing_extensions import Annotated, Doc -from faststream.broker.handler import BaseHandler +from faststream.broker.core.handler import BaseHandler +from faststream.broker.core.publisher import FakePublisher from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.parsers import resolve_custom_func from faststream.broker.types import ( - AsyncPublisherProtocol, CustomDecoder, CustomParser, Filter, T_HandlerReturn, WrappedReturn, ) -from faststream.broker.wrapper import FakePublisher from faststream.nats.js_stream import JStream from faststream.nats.message import NatsMessage from faststream.nats.parser import JsParser, Parser -from faststream.nats.producer import NatsFastProducer from faststream.nats.pull_sub import PullSub from faststream.types import AnyDict, SendableMessage from faststream.utils.path import compile_path if TYPE_CHECKING: - from faststream.broker.handler import WrapperProtocol + from faststream.broker.core.handler import WrapperProtocol class LogicNatsHandler(BaseHandler[Msg]): @@ -75,6 +73,7 @@ def __init__( Callable[[StreamMessage[Any]], Dict[str, str]], Doc("Function to create log extra data by message"), ], + producer, logger: Annotated[ Optional[Logger], Doc("Logger to use with process message Watcher") ] = None, @@ -133,6 +132,7 @@ def __init__( self.path_regex = reg self.queue = queue + self.producer = producer self.stream = stream self.pull_sub = pull_sub @@ -181,7 +181,6 @@ def _process_message( self, func: Callable[[NatsMessage], Awaitable[T_HandlerReturn]], watcher: Callable[..., AsyncContextManager[None]], - producer: NatsFastProducer, ) -> Callable[ [NatsMessage], Awaitable[WrappedReturn[T_HandlerReturn]], @@ -192,21 +191,22 @@ async def process_wrapper( ) -> WrappedReturn[T_HandlerReturn]: async with watcher(message): r = await func(message) + return r, None - pub_response: Optional[AsyncPublisherProtocol] - if message.reply_to: - pub_response = FakePublisher( - partial( - producer.publish, - subject=message.reply_to, - ) - ) - else: - pub_response = None + return process_wrapper - return r, pub_response + def make_response_publisher(self, message: NatsMessage) -> Sequence[FakePublisher]: + if message.reply_to: + return ( + FakePublisher( + partial( + self.producer.publish, + subject=message.reply_to, + ) + ), + ) - return process_wrapper + return () async def start( self, diff --git a/faststream/nats/publisher.py b/faststream/nats/publisher.py index 2a9430eb35..20d880c821 100644 --- a/faststream/nats/publisher.py +++ b/faststream/nats/publisher.py @@ -4,7 +4,7 @@ from nats.aio.msg import Msg from typing_extensions import override -from faststream.broker.publisher import BasePublisher +from faststream.broker.core.publisher import BasePublisher from faststream.exceptions import NOT_CONNECTED_YET from faststream.nats.js_stream import JStream from faststream.nats.producer import NatsFastProducer, NatsJSFastProducer diff --git a/faststream/nats/router.pyi b/faststream/nats/router.pyi index 2cec360501..0e8f3e2b56 100644 --- a/faststream/nats/router.pyi +++ b/faststream/nats/router.pyi @@ -5,7 +5,8 @@ from nats.aio.msg import Msg from nats.js import api from typing_extensions import override -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( CustomDecoder, @@ -14,7 +15,6 @@ from faststream.broker.types import ( P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.nats.asyncapi import Publisher from faststream.nats.js_stream import JStream from faststream.nats.message import NatsMessage diff --git a/faststream/nats/shared/logging.py b/faststream/nats/shared/logging.py index 7a586b2d45..46c04679b6 100644 --- a/faststream/nats/shared/logging.py +++ b/faststream/nats/shared/logging.py @@ -3,7 +3,7 @@ from typing_extensions import override -from faststream.broker.core.mixins import LoggingMixin +from faststream.broker.core.logging_mixin import LoggingMixin from faststream.broker.message import StreamMessage from faststream.log import access_logger from faststream.types import AnyDict diff --git a/faststream/nats/shared/router.py b/faststream/nats/shared/router.py index 4420368c4f..812f3e5f73 100644 --- a/faststream/nats/shared/router.py +++ b/faststream/nats/shared/router.py @@ -3,10 +3,10 @@ from nats.aio.msg import Msg from typing_extensions import override +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.router import BrokerRoute as NatsRoute from faststream.broker.router import BrokerRouter from faststream.broker.types import P_HandlerParams, T_HandlerReturn -from faststream.broker.wrapper import HandlerCallWrapper from faststream.types import SendableMessage __all__ = ( diff --git a/faststream/nats/shared/router.pyi b/faststream/nats/shared/router.pyi index 835c65252a..634d0ac826 100644 --- a/faststream/nats/shared/router.pyi +++ b/faststream/nats/shared/router.pyi @@ -6,7 +6,8 @@ from nats.aio.msg import Msg from nats.js import api from typing_extensions import override -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.middlewares import BaseMiddleware from faststream.broker.router import BrokerRouter from faststream.broker.types import ( @@ -16,7 +17,6 @@ from faststream.broker.types import ( P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.nats.js_stream import JStream from faststream.nats.message import NatsMessage diff --git a/faststream/nats/test.py b/faststream/nats/test.py index 262186cba4..96fd276fe8 100644 --- a/faststream/nats/test.py +++ b/faststream/nats/test.py @@ -5,9 +5,9 @@ from nats.aio.msg import Msg from typing_extensions import override +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.parsers import encode_message from faststream.broker.test import TestBroker, call_handler -from faststream.broker.wrapper import HandlerCallWrapper from faststream.nats.asyncapi import Handler, Publisher from faststream.nats.broker import NatsBroker from faststream.nats.producer import NatsFastProducer diff --git a/faststream/rabbit/broker.py b/faststream/rabbit/broker.py index 382c44f685..85b2a0c0e9 100644 --- a/faststream/rabbit/broker.py +++ b/faststream/rabbit/broker.py @@ -23,19 +23,20 @@ from yarl import URL from faststream._compat import model_to_dict -from faststream.broker.core.asynchronous import BrokerAsyncUsecase, default_filter +from faststream.broker.core.broker import BrokerUsecase, default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.core.publisher import FakePublisher from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( - AsyncPublisherProtocol, CustomDecoder, CustomParser, Filter, P_HandlerParams, + PublisherProtocol, T_HandlerReturn, WrappedReturn, ) -from faststream.broker.wrapper import FakePublisher, HandlerCallWrapper from faststream.exceptions import NOT_CONNECTED_YET from faststream.rabbit.asyncapi import Handler, Publisher from faststream.rabbit.helpers import RabbitDeclarer @@ -59,11 +60,11 @@ class RabbitBroker( RabbitLoggingMixin, - BrokerAsyncUsecase[aio_pika.IncomingMessage, aio_pika.RobustConnection], + BrokerUsecase[aio_pika.IncomingMessage, aio_pika.RobustConnection], ): """A RabbitMQ broker for FastAPI applications. - This class extends the base `BrokerAsyncUsecase` and provides asynchronous support for RabbitMQ as a message broker. + This class extends the base `BrokerUsecase` and provides asynchronous support for RabbitMQ as a message broker. Args: url (Union[str, URL, None], optional): The RabbitMQ connection URL. Defaults to "amqp://guest:guest@localhost:5672/". @@ -539,12 +540,12 @@ async def process_wrapper( message: The RabbitMessage to process. Returns: - A tuple containing the return value of the handler function and an optional AsyncPublisherProtocol. + A tuple containing the return value of the handler function and an optional PublisherProtocol. """ async with watcher(message): r = await func(message) - pub_response: Optional[AsyncPublisherProtocol] + pub_response: Optional[PublisherProtocol] if message.reply_to: pub_response = FakePublisher( partial( diff --git a/faststream/rabbit/broker.pyi b/faststream/rabbit/broker.pyi index 45891a3cf8..0633fd31ae 100644 --- a/faststream/rabbit/broker.pyi +++ b/faststream/rabbit/broker.pyi @@ -17,7 +17,8 @@ from typing_extensions import override from yarl import URL from faststream.asyncapi import schema as asyncapi -from faststream.broker.core.asynchronous import BrokerAsyncUsecase, default_filter +from faststream.broker.core.broker import BrokerUsecase, default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( @@ -28,7 +29,6 @@ from faststream.broker.types import ( T_HandlerReturn, WrappedReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.log import access_logger from faststream.rabbit.asyncapi import Handler, Publisher from faststream.rabbit.helpers import RabbitDeclarer @@ -43,7 +43,7 @@ from faststream.types import AnyDict, SendableMessage class RabbitBroker( RabbitLoggingMixin, - BrokerAsyncUsecase[aio_pika.IncomingMessage, aio_pika.RobustConnection], + BrokerUsecase[aio_pika.IncomingMessage, aio_pika.RobustConnection], ): handlers: dict[int, Handler] _publishers: dict[int, Publisher] diff --git a/faststream/rabbit/fastapi.pyi b/faststream/rabbit/fastapi.pyi index e78145b1dc..bb5f0f5340 100644 --- a/faststream/rabbit/fastapi.pyi +++ b/faststream/rabbit/fastapi.pyi @@ -20,7 +20,8 @@ from typing_extensions import override from yarl import URL from faststream.asyncapi import schema as asyncapi -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.fastapi.router import StreamRouter from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( @@ -30,7 +31,6 @@ from faststream.broker.types import ( P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.rabbit.asyncapi import Publisher from faststream.rabbit.broker import RabbitBroker from faststream.rabbit.message import RabbitMessage diff --git a/faststream/rabbit/handler.py b/faststream/rabbit/handler.py index 1fa278635b..48c30aae77 100644 --- a/faststream/rabbit/handler.py +++ b/faststream/rabbit/handler.py @@ -4,7 +4,8 @@ from fast_depends.core import CallModel from typing_extensions import override -from faststream.broker.handler import BaseHandler +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.core.handler import BaseHandler from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.parsers import resolve_custom_func @@ -15,7 +16,6 @@ P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.rabbit.helpers import RabbitDeclarer from faststream.rabbit.message import RabbitMessage from faststream.rabbit.parser import AioPikaParser diff --git a/faststream/rabbit/router.pyi b/faststream/rabbit/router.pyi index baa6aa3bf0..5419914314 100644 --- a/faststream/rabbit/router.pyi +++ b/faststream/rabbit/router.pyi @@ -4,7 +4,8 @@ import aio_pika from fast_depends.dependencies import Depends from typing_extensions import override -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.middlewares import BaseMiddleware from faststream.broker.router import BrokerRouter from faststream.broker.types import ( @@ -14,7 +15,6 @@ from faststream.broker.types import ( P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.rabbit.asyncapi import Publisher from faststream.rabbit.message import RabbitMessage from faststream.rabbit.shared.router import RabbitRoute diff --git a/faststream/rabbit/shared/logging.py b/faststream/rabbit/shared/logging.py index 21dc155d01..2bd7411f65 100644 --- a/faststream/rabbit/shared/logging.py +++ b/faststream/rabbit/shared/logging.py @@ -3,7 +3,7 @@ from typing_extensions import override -from faststream.broker.core.mixins import LoggingMixin +from faststream.broker.core.logging_mixin import LoggingMixin from faststream.broker.message import StreamMessage from faststream.log import access_logger from faststream.rabbit.shared.schemas import RabbitExchange, RabbitQueue diff --git a/faststream/rabbit/shared/publisher.py b/faststream/rabbit/shared/publisher.py index 7669f102cb..cc67aae310 100644 --- a/faststream/rabbit/shared/publisher.py +++ b/faststream/rabbit/shared/publisher.py @@ -4,7 +4,7 @@ from typing_extensions import TypeAlias -from faststream.broker.publisher import BasePublisher +from faststream.broker.core.publisher import BasePublisher from faststream.broker.types import MsgType from faststream.rabbit.shared.schemas import BaseRMQInformation from faststream.rabbit.shared.types import TimeoutType diff --git a/faststream/rabbit/shared/router.py b/faststream/rabbit/shared/router.py index 3471de1ade..6547285a64 100644 --- a/faststream/rabbit/shared/router.py +++ b/faststream/rabbit/shared/router.py @@ -4,10 +4,10 @@ from typing_extensions import override from faststream._compat import model_copy +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.router import BrokerRoute as RabbitRoute from faststream.broker.router import BrokerRouter from faststream.broker.types import P_HandlerParams, T_HandlerReturn -from faststream.broker.wrapper import HandlerCallWrapper from faststream.rabbit.shared.schemas import RabbitQueue from faststream.types import SendableMessage diff --git a/faststream/rabbit/shared/router.pyi b/faststream/rabbit/shared/router.pyi index 58307002cc..5dae0f23d3 100644 --- a/faststream/rabbit/shared/router.pyi +++ b/faststream/rabbit/shared/router.pyi @@ -3,7 +3,7 @@ from typing import Any, Callable, Sequence import aio_pika from fast_depends.dependencies import Depends -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( CustomDecoder, diff --git a/faststream/rabbit/test.py b/faststream/rabbit/test.py index 260fa78a6b..7609673602 100644 --- a/faststream/rabbit/test.py +++ b/faststream/rabbit/test.py @@ -8,8 +8,8 @@ from pamqp import commands as spec from pamqp.header import ContentHeader +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.test import TestBroker, call_handler -from faststream.broker.wrapper import HandlerCallWrapper from faststream.rabbit.asyncapi import Publisher from faststream.rabbit.broker import RabbitBroker from faststream.rabbit.parser import AioPikaParser diff --git a/faststream/redis/broker.py b/faststream/redis/broker.py index 45ab3d8634..80513151e4 100644 --- a/faststream/redis/broker.py +++ b/faststream/redis/broker.py @@ -19,19 +19,20 @@ from redis.exceptions import ResponseError from typing_extensions import TypeAlias, override -from faststream.broker.core.asynchronous import BrokerAsyncUsecase, default_filter +from faststream.broker.core.broker import BrokerUsecase, default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.core.publisher import FakePublisher from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( - AsyncPublisherProtocol, CustomDecoder, CustomParser, Filter, P_HandlerParams, + PublisherProtocol, T_HandlerReturn, WrappedReturn, ) -from faststream.broker.wrapper import FakePublisher, HandlerCallWrapper from faststream.exceptions import NOT_CONNECTED_YET from faststream.redis.asyncapi import Handler, Publisher from faststream.redis.message import AnyRedisDict, RedisMessage @@ -48,7 +49,7 @@ class RedisBroker( RedisLoggingMixin, - BrokerAsyncUsecase[AnyRedisDict, "Redis[bytes]"], + BrokerUsecase[AnyRedisDict, "Redis[bytes]"], ): """Redis broker.""" @@ -181,7 +182,7 @@ async def process_wrapper( ): r = await func(message) - pub_response: Optional[AsyncPublisherProtocol] + pub_response: Optional[PublisherProtocol] if message.reply_to: pub_response = FakePublisher( partial(self.publish, channel=message.reply_to) diff --git a/faststream/redis/broker.pyi b/faststream/redis/broker.pyi index d41f563a6e..5c9ec75ae1 100644 --- a/faststream/redis/broker.pyi +++ b/faststream/redis/broker.pyi @@ -15,7 +15,8 @@ from redis.asyncio.connection import BaseParser, Connection, DefaultParser, Enco from typing_extensions import TypeAlias, override from faststream.asyncapi import schema as asyncapi -from faststream.broker.core.asynchronous import BrokerAsyncUsecase, default_filter +from faststream.broker.core.broker import BrokerUsecase, default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( @@ -26,7 +27,6 @@ from faststream.broker.types import ( T_HandlerReturn, WrappedReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.log import access_logger from faststream.redis.asyncapi import Handler, Publisher from faststream.redis.message import AnyRedisDict, RedisMessage @@ -40,7 +40,7 @@ Channel: TypeAlias = str class RedisBroker( RedisLoggingMixin, - BrokerAsyncUsecase[AnyRedisDict, "Redis[bytes]"], + BrokerUsecase[AnyRedisDict, "Redis[bytes]"], ): url: str handlers: dict[int, Handler] diff --git a/faststream/redis/fastapi.py b/faststream/redis/fastapi.py index 21f982f46f..159a3030c0 100644 --- a/faststream/redis/fastapi.py +++ b/faststream/redis/fastapi.py @@ -4,10 +4,10 @@ from redis.asyncio.client import Redis as RedisClient from typing_extensions import Annotated, override +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.fastapi.context import Context, ContextRepo, Logger from faststream.broker.fastapi.router import StreamRouter from faststream.broker.types import P_HandlerParams, T_HandlerReturn -from faststream.broker.wrapper import HandlerCallWrapper from faststream.redis.broker import RedisBroker as RB from faststream.redis.message import AnyRedisDict from faststream.redis.message import RedisMessage as RM diff --git a/faststream/redis/fastapi.pyi b/faststream/redis/fastapi.pyi index 3f88e73939..6dd8f9e071 100644 --- a/faststream/redis/fastapi.pyi +++ b/faststream/redis/fastapi.pyi @@ -19,7 +19,8 @@ from starlette.types import ASGIApp, Lifespan from typing_extensions import TypeAlias, override from faststream.asyncapi import schema as asyncapi -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.fastapi.router import StreamRouter from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( @@ -29,7 +30,6 @@ from faststream.broker.types import ( P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.log import access_logger from faststream.redis.asyncapi import Publisher from faststream.redis.broker import RedisBroker diff --git a/faststream/redis/handler.py b/faststream/redis/handler.py index c5e4fdafb4..4c71a6ba5a 100644 --- a/faststream/redis/handler.py +++ b/faststream/redis/handler.py @@ -23,7 +23,8 @@ from typing_extensions import override from faststream._compat import json_loads -from faststream.broker.handler import BaseHandler +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.core.handler import BaseHandler from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.parsers import resolve_custom_func @@ -34,7 +35,6 @@ P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.redis.message import ( AnyRedisDict, RedisMessage, diff --git a/faststream/redis/publisher.py b/faststream/redis/publisher.py index ab753aa355..bf3cd01a0b 100644 --- a/faststream/redis/publisher.py +++ b/faststream/redis/publisher.py @@ -3,7 +3,7 @@ from typing_extensions import override -from faststream.broker.publisher import BasePublisher +from faststream.broker.core.publisher import BasePublisher from faststream.exceptions import NOT_CONNECTED_YET from faststream.redis.message import AnyRedisDict from faststream.redis.producer import RedisFastProducer diff --git a/faststream/redis/router.pyi b/faststream/redis/router.pyi index 54d3652285..df55524aa9 100644 --- a/faststream/redis/router.pyi +++ b/faststream/redis/router.pyi @@ -7,7 +7,8 @@ from typing import ( from fast_depends.dependencies import Depends from typing_extensions import override -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.middlewares import BaseMiddleware from faststream.broker.router import BrokerRouter from faststream.broker.types import ( @@ -17,7 +18,6 @@ from faststream.broker.types import ( P_HandlerParams, T_HandlerReturn, ) -from faststream.broker.wrapper import HandlerCallWrapper from faststream.redis.asyncapi import Publisher from faststream.redis.message import AnyRedisDict, RedisMessage from faststream.redis.schemas import ListSub, PubSub, StreamSub diff --git a/faststream/redis/shared/logging.py b/faststream/redis/shared/logging.py index eee25a75fb..4f0cefcf3d 100644 --- a/faststream/redis/shared/logging.py +++ b/faststream/redis/shared/logging.py @@ -3,7 +3,7 @@ from typing_extensions import override -from faststream.broker.core.mixins import LoggingMixin +from faststream.broker.core.logging_mixin import LoggingMixin from faststream.broker.message import StreamMessage from faststream.log import access_logger from faststream.types import AnyDict diff --git a/faststream/redis/shared/router.py b/faststream/redis/shared/router.py index 72d7bddb10..b28fedbc8a 100644 --- a/faststream/redis/shared/router.py +++ b/faststream/redis/shared/router.py @@ -3,10 +3,10 @@ from typing_extensions import TypeAlias, override from faststream._compat import model_copy +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.router import BrokerRoute as RedisRoute from faststream.broker.router import BrokerRouter from faststream.broker.types import P_HandlerParams, T_HandlerReturn -from faststream.broker.wrapper import HandlerCallWrapper from faststream.redis.message import AnyRedisDict from faststream.redis.schemas import ListSub, PubSub, StreamSub from faststream.types import SendableMessage diff --git a/faststream/redis/shared/router.pyi b/faststream/redis/shared/router.pyi index 72c3b3daeb..efe0b08c4f 100644 --- a/faststream/redis/shared/router.pyi +++ b/faststream/redis/shared/router.pyi @@ -7,7 +7,7 @@ from typing import ( from fast_depends.dependencies import Depends from typing_extensions import TypeAlias -from faststream.broker.core.asynchronous import default_filter +from faststream.broker.core.broker import default_filter from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( CustomDecoder, diff --git a/faststream/redis/test.py b/faststream/redis/test.py index 59515846dc..aad86b24ab 100644 --- a/faststream/redis/test.py +++ b/faststream/redis/test.py @@ -1,8 +1,8 @@ import re from typing import Any, Optional, Sequence, Union +from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.test import TestBroker, call_handler -from faststream.broker.wrapper import HandlerCallWrapper from faststream.redis.asyncapi import Handler, Publisher from faststream.redis.broker import RedisBroker from faststream.redis.message import AnyRedisDict diff --git a/tests/asyncapi/base/arguments.py b/tests/asyncapi/base/arguments.py index 94e3dd706e..eb27ebdb44 100644 --- a/tests/asyncapi/base/arguments.py +++ b/tests/asyncapi/base/arguments.py @@ -7,7 +7,7 @@ from faststream import Context, FastStream from faststream._compat import PYDANTIC_V2 from faststream.asyncapi.generate import get_app_schema -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase class FastAPICompatible: # noqa: D101 diff --git a/tests/asyncapi/base/fastapi.py b/tests/asyncapi/base/fastapi.py index 934f6551dd..af722cfc6a 100644 --- a/tests/asyncapi/base/fastapi.py +++ b/tests/asyncapi/base/fastapi.py @@ -6,7 +6,7 @@ from fastapi.testclient import TestClient from faststream.asyncapi.generate import get_app_schema -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase from faststream.broker.fastapi.router import StreamRouter from faststream.broker.types import MsgType diff --git a/tests/asyncapi/base/naming.py b/tests/asyncapi/base/naming.py index 34c35dc7e3..dd9f35f2a3 100644 --- a/tests/asyncapi/base/naming.py +++ b/tests/asyncapi/base/naming.py @@ -5,7 +5,7 @@ from faststream import FastStream from faststream.asyncapi.generate import get_app_schema -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase class BaseNaming: # noqa: D101 diff --git a/tests/asyncapi/base/publisher.py b/tests/asyncapi/base/publisher.py index 9fd693143c..1b4e6c14df 100644 --- a/tests/asyncapi/base/publisher.py +++ b/tests/asyncapi/base/publisher.py @@ -4,7 +4,7 @@ from faststream import FastStream from faststream.asyncapi.generate import get_app_schema -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase class PublisherTestcase: # noqa: D101 diff --git a/tests/asyncapi/base/router.py b/tests/asyncapi/base/router.py index 3696871a07..a2b09dbcc1 100644 --- a/tests/asyncapi/base/router.py +++ b/tests/asyncapi/base/router.py @@ -2,7 +2,7 @@ from faststream import FastStream from faststream.asyncapi.generate import get_app_schema -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase from faststream.broker.router import BrokerRoute, BrokerRouter @@ -43,3 +43,54 @@ async def handle(msg): schema = get_app_schema(FastStream(broker)) assert schema.channels == {} + + def test_respect_subrouter(self): + broker = self.broker_class() + router = self.router_class() + router2 = self.router_class(include_in_schema=False) + + @router2.subscriber("test") + @router2.publisher("test") + async def handle(msg): + ... + + router.include_router(router2) + broker.include_router(router) + + schema = get_app_schema(FastStream(broker)) + + assert schema.channels == {} + + def test_not_include_subrouter(self): + broker = self.broker_class() + router = self.router_class(include_in_schema=False) + router2 = self.router_class() + + @router2.subscriber("test") + @router2.publisher("test") + async def handle(msg): + ... + + router.include_router(router2) + broker.include_router(router) + + schema = get_app_schema(FastStream(broker)) + + assert schema.channels == {} + + def test_include_subrouter(self): + broker = self.broker_class() + router = self.router_class() + router2 = self.router_class() + + @router2.subscriber("test") + @router2.publisher("test") + async def handle(msg): + ... + + router.include_router(router2) + broker.include_router(router) + + schema = get_app_schema(FastStream(broker)) + + assert len(schema.channels) == 2 diff --git a/tests/brokers/base/connection.py b/tests/brokers/base/connection.py index 4764b8f5fc..8ca8f314d5 100644 --- a/tests/brokers/base/connection.py +++ b/tests/brokers/base/connection.py @@ -3,7 +3,7 @@ import pytest -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase class BrokerConnectionTestcase: # noqa: D101 diff --git a/tests/brokers/base/consume.py b/tests/brokers/base/consume.py index efdf4982ff..66a6c3bcc4 100644 --- a/tests/brokers/base/consume.py +++ b/tests/brokers/base/consume.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from faststream import Context, Depends -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase from faststream.exceptions import StopConsume diff --git a/tests/brokers/base/fastapi.py b/tests/brokers/base/fastapi.py index 80c792ac61..cef40e1857 100644 --- a/tests/brokers/base/fastapi.py +++ b/tests/brokers/base/fastapi.py @@ -9,17 +9,17 @@ from fastapi.testclient import TestClient from faststream import context -from faststream.broker.core.asynchronous import BrokerAsyncUsecase +from faststream.broker.core.broker import BrokerUsecase from faststream.broker.fastapi.context import Context from faststream.broker.fastapi.router import StreamRouter from faststream.types import AnyCallable -Broker = TypeVar("Broker", bound=BrokerAsyncUsecase) +Broker = TypeVar("Broker", bound=BrokerUsecase) @pytest.mark.asyncio() class FastAPITestcase: # noqa: D101 - router_class: Type[StreamRouter[BrokerAsyncUsecase]] + router_class: Type[StreamRouter[BrokerUsecase]] async def test_base_real(self, mock: Mock, queue: str, event: asyncio.Event): router = self.router_class() @@ -149,7 +149,7 @@ async def resp(msg): @pytest.mark.asyncio() class FastAPILocalTestcase: # noqa: D101 - router_class: Type[StreamRouter[BrokerAsyncUsecase]] + router_class: Type[StreamRouter[BrokerUsecase]] broker_test: Callable[[Broker], Broker] build_message: AnyCallable diff --git a/tests/brokers/base/middlewares.py b/tests/brokers/base/middlewares.py index 316df8f3ed..8f840256fc 100644 --- a/tests/brokers/base/middlewares.py +++ b/tests/brokers/base/middlewares.py @@ -4,7 +4,7 @@ import pytest -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase from faststream.broker.middlewares import BaseMiddleware diff --git a/tests/brokers/base/parser.py b/tests/brokers/base/parser.py index 615af53d6d..4c0a1b6f58 100644 --- a/tests/brokers/base/parser.py +++ b/tests/brokers/base/parser.py @@ -4,20 +4,20 @@ import pytest -from faststream.broker.core.asynchronous import BrokerAsyncUsecase +from faststream.broker.core.broker import BrokerUsecase @pytest.mark.asyncio() class LocalCustomParserTestcase: # noqa: D101 - broker_class: Type[BrokerAsyncUsecase] + broker_class: Type[BrokerUsecase] @pytest.fixture() def raw_broker(self): return None def patch_broker( - self, raw_broker: BrokerAsyncUsecase, broker: BrokerAsyncUsecase - ) -> BrokerAsyncUsecase: + self, raw_broker: BrokerUsecase, broker: BrokerUsecase + ) -> BrokerUsecase: return broker async def test_local_parser( diff --git a/tests/brokers/base/publish.py b/tests/brokers/base/publish.py index bcb247b84a..fcdf3ad59e 100644 --- a/tests/brokers/base/publish.py +++ b/tests/brokers/base/publish.py @@ -9,7 +9,7 @@ from faststream._compat import model_to_json from faststream.annotations import Logger -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase class SimpleModel(BaseModel): # noqa: D101 diff --git a/tests/brokers/base/router.py b/tests/brokers/base/router.py index 78dcbd283d..feeef4e7c6 100644 --- a/tests/brokers/base/router.py +++ b/tests/brokers/base/router.py @@ -5,7 +5,7 @@ import pytest from faststream import BaseMiddleware, Depends -from faststream.broker.core.asynchronous import BrokerAsyncUsecase +from faststream.broker.core.broker import BrokerUsecase from faststream.broker.router import BrokerRoute, BrokerRouter from faststream.types import AnyCallable from tests.brokers.base.middlewares import LocalMiddlewareTestcase @@ -17,9 +17,7 @@ class RouterTestcase(LocalMiddlewareTestcase, LocalCustomParserTestcase): # noq build_message: AnyCallable route_class: Type[BrokerRoute] - def patch_broker( - self, br: BrokerAsyncUsecase, router: BrokerRouter - ) -> BrokerAsyncUsecase: + def patch_broker(self, br: BrokerUsecase, router: BrokerRouter) -> BrokerUsecase: br.include_router(router) return br @@ -34,7 +32,7 @@ def raw_broker(self, pub_broker): async def test_empty_prefix( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, ): @@ -60,7 +58,7 @@ def subscriber(m): async def test_not_empty_prefix( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, ): @@ -88,7 +86,7 @@ def subscriber(m): async def test_empty_prefix_publisher( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, ): @@ -119,7 +117,7 @@ def response(m): async def test_not_empty_prefix_publisher( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, ): @@ -152,7 +150,7 @@ def response(m): async def test_manual_publisher( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, ): @@ -188,7 +186,7 @@ async def test_delayed_handlers( event: asyncio.Event, router: BrokerRouter, queue: str, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, ): def response(m): event.set() @@ -213,7 +211,7 @@ def response(m): async def test_nested_routers_sub( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, mock: Mock, @@ -249,7 +247,7 @@ def subscriber(m): async def test_nested_routers_pub( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, ): @@ -286,7 +284,7 @@ def response(m): async def test_router_dependencies( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, mock: Mock, @@ -325,7 +323,7 @@ def subscriber(s): async def test_router_middlewares( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, mock: Mock, @@ -365,7 +363,7 @@ def subscriber(s): async def test_router_parser( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, mock: Mock, @@ -407,7 +405,7 @@ def subscriber(s): async def test_router_parser_override( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, mock: Mock, @@ -465,7 +463,7 @@ def pub_broker(self, test_broker): async def test_publisher_mock( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, ): @@ -496,7 +494,7 @@ def subscriber(m): async def test_subscriber_mock( self, router: BrokerRouter, - pub_broker: BrokerAsyncUsecase, + pub_broker: BrokerUsecase, queue: str, event: asyncio.Event, ): @@ -522,7 +520,7 @@ def subscriber(m): subscriber.mock.assert_called_with("hello") async def test_manual_publisher_mock( - self, router: BrokerRouter, queue: str, pub_broker: BrokerAsyncUsecase + self, router: BrokerRouter, queue: str, pub_broker: BrokerUsecase ): publisher = router.publisher(queue + "resp") diff --git a/tests/brokers/base/rpc.py b/tests/brokers/base/rpc.py index 5e075ff19a..ade5c53856 100644 --- a/tests/brokers/base/rpc.py +++ b/tests/brokers/base/rpc.py @@ -4,7 +4,7 @@ import anyio import pytest -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase from faststream.utils.functions import timeout_scope diff --git a/tests/brokers/base/testclient.py b/tests/brokers/base/testclient.py index 6e2539eb73..7464278fb9 100644 --- a/tests/brokers/base/testclient.py +++ b/tests/brokers/base/testclient.py @@ -1,6 +1,6 @@ import pytest -from faststream.broker.core.abc import BrokerUsecase +from faststream.broker.core.broker import BrokerUsecase from faststream.types import AnyCallable from tests.brokers.base.consume import BrokerConsumeTestcase from tests.brokers.base.publish import BrokerPublishTestcase diff --git a/tests/utils/test_handler_lock.py b/tests/utils/test_handler_lock.py index 283d09a6a3..cc23514744 100644 --- a/tests/utils/test_handler_lock.py +++ b/tests/utils/test_handler_lock.py @@ -4,7 +4,7 @@ import pytest from anyio.abc import TaskStatus -from faststream.broker.handler import MultiLock +from faststream.broker.core.handler import MultiLock from tests.marks import python310 From 15cd9b6039508ad6b204f009b86302d3e50a3fbc Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Thu, 11 Jan 2024 21:49:33 +0300 Subject: [PATCH 10/87] refactore: new middlewares --- faststream/broker/core/broker.py | 2 +- faststream/broker/core/handler.py | 84 ++++++++++++------- .../broker/core/handler_wrapper_mixin.py | 39 ++++----- faststream/utils/context/repository.py | 80 +++++++++--------- tests/brokers/base/rpc.py | 4 +- 5 files changed, 113 insertions(+), 96 deletions(-) diff --git a/faststream/broker/core/broker.py b/faststream/broker/core/broker.py index 22a0807430..1fe3dd5550 100644 --- a/faststream/broker/core/broker.py +++ b/faststream/broker/core/broker.py @@ -381,7 +381,7 @@ def subscriber( decoder: Optional[CustomDecoder["StreamMessage[MsgType]"]] = None, parser: Optional[CustomParser[MsgType, "StreamMessage[MsgType]"]] = None, dependencies: Sequence["Depends"] = (), - middlewares: Sequence[Callable[[MsgType], "BaseMiddleware"]] = (), + middlewares: Sequence["BaseMiddleware"] = (), raw: bool = False, no_ack: bool = False, retry: Union[bool, int] = False, diff --git a/faststream/broker/core/handler.py b/faststream/broker/core/handler.py index d673a18888..78b6c92c4d 100644 --- a/faststream/broker/core/handler.py +++ b/faststream/broker/core/handler.py @@ -2,6 +2,7 @@ from contextlib import AsyncExitStack from dataclasses import dataclass from inspect import unwrap +from itertools import chain from logging import Logger from typing import ( TYPE_CHECKING, @@ -15,7 +16,9 @@ Optional, Sequence, Tuple, + Type, Union, + cast, ) from faststream.asyncapi.base import AsyncAPIOperation @@ -33,7 +36,6 @@ P_HandlerParams, PublisherProtocol, T_HandlerReturn, - WrappedReturn, ) from faststream.broker.utils import MultiLock from faststream.exceptions import HandlerException, StopConsume @@ -42,6 +44,8 @@ from faststream.utils.functions import to_async if TYPE_CHECKING: + from types import TracebackType + from fast_depends.core import CallModel from fast_depends.dependencies import Depends @@ -58,7 +62,7 @@ class HandlerItem(Generic[MsgType]): filter: Callable[["StreamMessage[MsgType]"], Awaitable[bool]] parser: AsyncParser[MsgType, Any] decoder: AsyncDecoder["StreamMessage[MsgType]"] - middlewares: Sequence[Callable[[Any], "BaseMiddleware"]] + middlewares: Sequence["BaseMiddleware"] dependant: "CallModel[Any, SendableMessage]" @property @@ -82,10 +86,13 @@ def description(self) -> Optional[str]: return description async def call( - self, msg: MsgType, cache: Dict[Any, Any] + self, + msg: MsgType, + cache: Dict[Any, Any], + extra_middlewares: Sequence["BaseMiddleware"], ) -> AsyncGenerator[ - Dict[str, str], - Optional["StreamMessage[MsgType]"], + Union["StreamMessage[MsgType]", None, SendableMessage], + None, ]: message = cache[self.parser] = cache.get( self.parser, @@ -97,14 +104,11 @@ async def call( ) if await self.filter(message): - log_context = yield message + yield message result = None async with AsyncExitStack() as consume_stack: - consume_stack.enter_context(context.scope("message", message)) - consume_stack.enter_context(context.scope("log_context", log_context)) - - for middleware in self.middlewares: + for middleware in chain(self.middlewares, extra_middlewares): message.decoded_body = await consume_stack.enter_async_context( middleware.consume_scope(message.decoded_body) ) @@ -225,7 +229,7 @@ def add_call( filter_: Filter["StreamMessage[MsgType]"], parser_: CustomParser[MsgType, Any], decoder_: CustomDecoder["StreamMessage[MsgType]"], - middlewares_: Sequence[Callable[[Any], "BaseMiddleware"]], + middlewares_: Sequence["BaseMiddleware"], dependencies_: Sequence["Depends"], **wrap_kwargs: Any, ) -> "WrapperProtocol[MsgType]": @@ -235,7 +239,7 @@ def wrapper( filter: Filter["StreamMessage[MsgType]"] = filter_, parser: CustomParser[MsgType, Any] = parser_, decoder: CustomDecoder["StreamMessage[MsgType]"] = decoder_, - middlewares: Sequence[Callable[[Any], "BaseMiddleware"]] = (), + middlewares: Sequence["BaseMiddleware"] = (), dependencies: Sequence["Depends"] = (), ) -> Union[ HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], @@ -245,7 +249,7 @@ def wrapper( ], ]: total_deps = (*dependencies_, *dependencies) - total_middlewares = (*self.middlewares, *middlewares_, *middlewares) + total_middlewares = (*middlewares_, *middlewares) def real_wrapper( func: Callable[P_HandlerParams, T_HandlerReturn], @@ -287,19 +291,22 @@ async def consume(self, msg: MsgType) -> SendableMessage: Returns: The sendable message. """ - result: Optional[WrappedReturn[SendableMessage]] = None + result: Optional[SendableMessage] = None result_msg: SendableMessage = None if not self.running: return result_msg + middlewares = [] async with AsyncExitStack() as stack: stack.enter_context(self.lock) stack.enter_context(context.scope("handler_", self)) for m in self.middlewares: - await stack.enter_async_context(m(msg)) + middleware = m(msg) + middlewares.append(middleware) + await middleware.__aenter__() cache = {} processed = False @@ -307,14 +314,29 @@ async def consume(self, msg: MsgType) -> SendableMessage: if processed: break - caller = h.call(msg, cache) + caller = h.call(msg, cache, middlewares) + + if ( + message := cast("StreamMessage[MsgType]", await caller.asend(None)) + ) is not None: + stack.enter_context(context.scope("message", message)) + stack.enter_context( + context.scope("log_context", self.log_context_builder(message)) + ) + + @stack.push_async_callback + async def close_middlewares( + exc_type: Optional[Type[BaseException]] = None, + exc_val: Optional[BaseException] = None, + exec_tb: Optional["TracebackType"] = None, + ) -> None: + for m in middlewares: + await m.__aexit__(exc_type, exc_val, exec_tb) - if (message := await caller.__anext__()) is not None: processed = True + try: - await caller.asend(self.log_context_builder(message)) - except StopAsyncIteration as e: - result = e.value + result = cast(SendableMessage, await caller.asend(None)) except StopConsume: await self.close() return @@ -324,18 +346,20 @@ async def consume(self, msg: MsgType) -> SendableMessage: *self.make_response_publisher(message), *h.handler._publishers, ): - if publisher is not None: - async with AsyncExitStack() as pub_stack: - for m_pub in h.middlewares: - result = await pub_stack.enter_async_context( - m_pub.publish_scope(result) - ) - - await publisher.publish( - message=result, - correlation_id=message.correlation_id, + async with AsyncExitStack() as pub_stack: + # TODO: need to test copy + result_to_send = result + + for m_pub in chain(middlewares, h.middlewares): + result_to_send = await pub_stack.enter_async_context( + m_pub.publish_scope(result_to_send) ) + await publisher.publish( + message=result_to_send, + correlation_id=message.correlation_id, + ) + assert not self.running or processed, "You have to consume message" # nosec B101 return result_msg diff --git a/faststream/broker/core/handler_wrapper_mixin.py b/faststream/broker/core/handler_wrapper_mixin.py index fd901991f6..bd124fba92 100644 --- a/faststream/broker/core/handler_wrapper_mixin.py +++ b/faststream/broker/core/handler_wrapper_mixin.py @@ -16,13 +16,11 @@ ) from fast_depends import inject -from fast_depends.core import CallModel, build_call_model -from fast_depends.dependencies import Depends +from fast_depends.core import build_call_model from pydantic import create_model from faststream._compat import PYDANTIC_V2 from faststream.broker.core.call_wrapper import HandlerCallWrapper -from faststream.broker.middlewares import BaseMiddleware from faststream.broker.push_back_watcher import WatcherContext from faststream.broker.types import ( CustomDecoder, @@ -40,7 +38,11 @@ if TYPE_CHECKING: from typing import Protocol, overload + from fast_depends.core import CallModel + from fast_depends.dependencies import Depends + from faststream.broker.message import StreamMessage + from faststream.broker.middlewares import BaseMiddleware class WrapperProtocol(Generic[MsgType], Protocol): """Annotation class to represent @subsriber return type.""" @@ -53,8 +55,8 @@ def __call__( filter: Filter["StreamMessage[MsgType]"], parser: CustomParser[MsgType, Any], decoder: CustomDecoder["StreamMessage[MsgType]"], - middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), - dependencies: Sequence[Depends] = (), + middlewares: Sequence["BaseMiddleware"] = (), + dependencies: Sequence["Depends"] = (), ) -> Callable[ [Callable[P_HandlerParams, T_HandlerReturn]], HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], @@ -69,8 +71,8 @@ def __call__( filter: Filter["StreamMessage[MsgType]"], parser: CustomParser[MsgType, Any], decoder: CustomDecoder["StreamMessage[MsgType]"], - middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), - dependencies: Sequence[Depends] = (), + middlewares: Sequence["BaseMiddleware"] = (), + dependencies: Sequence["Depends"] = (), ) -> HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]: ... @@ -81,8 +83,8 @@ def __call__( filter: Filter["StreamMessage[MsgType]"], parser: CustomParser[MsgType, Any], decoder: CustomDecoder["StreamMessage[MsgType]"], - middlewares: Sequence[Callable[[Any], BaseMiddleware]] = (), - dependencies: Sequence[Depends] = (), + middlewares: Sequence["BaseMiddleware"] = (), + dependencies: Sequence["Depends"] = (), ) -> Union[ HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], Callable[ @@ -96,20 +98,11 @@ def __call__( class WrapHandlerMixin(Generic[MsgType]): """A class to patch original handle function.""" - def __init__( - self, - *, - middlewares: Sequence[Callable[[MsgType], BaseMiddleware]], - ) -> None: - """Initialize a new instance of the class.""" - self.calls = [] - self.middlewares = middlewares - def wrap_handler( self, *, func: Callable[P_HandlerParams, T_HandlerReturn], - dependencies: Sequence[Depends], + dependencies: Sequence["Depends"], logger: Optional[Logger], apply_types: bool, is_validate: bool, @@ -120,12 +113,12 @@ def wrap_handler( **process_kwargs: Any, ) -> Tuple[ HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - CallModel[P_HandlerParams, T_HandlerReturn], + "CallModel[P_HandlerParams, T_HandlerReturn]", ]: build_dep = build_dep = cast( Callable[ [Callable[F_Spec, F_Return]], - CallModel[F_Spec, F_Return], + "CallModel[F_Spec, F_Return]", ], get_dependant or partial( @@ -229,8 +222,8 @@ def _process_message( def _patch_fastapi_dependant( - dependant: CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]], -) -> CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]]: + dependant: "CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]]", +) -> "CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]]": """Patch FastAPI dependant. Args: diff --git a/faststream/utils/context/repository.py b/faststream/utils/context/repository.py index 7c3ea2f171..f4c5f1ec24 100644 --- a/faststream/utils/context/repository.py +++ b/faststream/utils/context/repository.py @@ -25,6 +25,13 @@ def __init__(self) -> None: self._global_context = {"context": self} self._scope_context = {} + @property + def context(self) -> AnyDict: + return { + **self._global_context, + **{i: j.get() for i, j in self._scope_context.items()}, + } + def set_global(self, key: str, v: Any) -> None: """Sets a value in the global context. @@ -92,9 +99,25 @@ def get_local(self, key: str, default: Any = None) -> Any: else: return default - def clear(self) -> None: - self._global_context = {"context": self} - self._scope_context.clear() + @contextmanager + def scope(self, key: str, value: Any) -> Iterator[None]: + """Sets a local variable and yields control to the caller. After the caller is done, the local variable is reset. + + Args: + key: The key of the local variable + value: The value to set the local variable to + + Yields: + None + + Returns: + An iterator that yields None + """ + token = self.set_local(key, value) + try: + yield + finally: + self.reset_local(key, token) def get(self, key: str, default: Any = None) -> Any: """Get the value associated with a key. @@ -108,6 +131,17 @@ def get(self, key: str, default: Any = None) -> Any: """ return self._global_context.get(key, self.get_local(key, default)) + def __getattr__(self, __name: str) -> Any: + """This is a function that is part of a class. It is used to get an attribute value using the `__getattr__` method. + + Args: + __name: The name of the attribute to get. + + Returns: + The value of the attribute. + """ + return self.get(__name) + def resolve(self, argument: str) -> Any: """Resolve the context of an argument. @@ -129,43 +163,9 @@ def resolve(self, argument: str) -> Any: v = v[i] if isinstance(v, Mapping) else getattr(v, i) return v - def __getattr__(self, __name: str) -> Any: - """This is a function that is part of a class. It is used to get an attribute value using the `__getattr__` method. - - Args: - __name: The name of the attribute to get. - - Returns: - The value of the attribute. - """ - return self.get(__name) - - @property - def context(self) -> AnyDict: - return { - **self._global_context, - **{i: j.get() for i, j in self._scope_context.items()}, - } - - @contextmanager - def scope(self, key: str, value: Any) -> Iterator[None]: - """Sets a local variable and yields control to the caller. After the caller is done, the local variable is reset. - - Args: - key: The key of the local variable - value: The value to set the local variable to - - Yields: - None - - Returns: - An iterator that yields None - """ - token = self.set_local(key, value) - try: - yield - finally: - self.reset_local(key, token) + def clear(self) -> None: + self._global_context = {"context": self} + self._scope_context.clear() context = ContextRepo() diff --git a/tests/brokers/base/rpc.py b/tests/brokers/base/rpc.py index ade5c53856..3b6acfd341 100644 --- a/tests/brokers/base/rpc.py +++ b/tests/brokers/base/rpc.py @@ -10,8 +10,8 @@ class BrokerRPCTestcase: # noqa: D101 @pytest.fixture() - def rpc_broker(self, full_broker): - return full_broker + def rpc_broker(self, broker): + return broker @pytest.mark.asyncio() async def test_rpc(self, queue: str, rpc_broker: BrokerUsecase): From 0b9d4f4046fe615448243fd269a83030652bba8a Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Thu, 11 Jan 2024 22:22:50 +0300 Subject: [PATCH 11/87] fix #690: ack message after response publish --- faststream/broker/core/call_wrapper.py | 6 +-- faststream/broker/core/handler.py | 8 ++-- .../broker/core/handler_wrapper_mixin.py | 32 +------------ faststream/broker/types.py | 8 ++-- faststream/broker/utils.py | 18 ++++++- faststream/kafka/broker.py | 31 ++---------- faststream/kafka/broker.pyi | 3 +- faststream/nats/broker.py | 10 ++-- faststream/nats/broker.pyi | 14 ------ faststream/nats/handler.py | 48 +++++++++---------- faststream/rabbit/broker.py | 27 +---------- faststream/rabbit/broker.pyi | 16 ------- faststream/redis/broker.py | 8 +--- faststream/redis/broker.pyi | 13 ----- 14 files changed, 65 insertions(+), 177 deletions(-) diff --git a/faststream/broker/core/call_wrapper.py b/faststream/broker/core/call_wrapper.py index a1aa78989f..30f6afca30 100644 --- a/faststream/broker/core/call_wrapper.py +++ b/faststream/broker/core/call_wrapper.py @@ -11,8 +11,8 @@ PublisherProtocol, T_HandlerReturn, WrappedHandlerCall, - WrappedReturn, ) +from faststream.types import SendableMessage class HandlerCallWrapper(Generic[MsgType, P_HandlerParams, T_HandlerReturn]): @@ -130,8 +130,8 @@ def call_wrapped( self, message: StreamMessage[MsgType], ) -> Union[ - Optional[WrappedReturn[T_HandlerReturn]], - Awaitable[Optional[WrappedReturn[T_HandlerReturn]]], + Optional[SendableMessage], + Awaitable[Optional[SendableMessage]], ]: """Calls the wrapped function with the given message. diff --git a/faststream/broker/core/handler.py b/faststream/broker/core/handler.py index 78b6c92c4d..196b951d4d 100644 --- a/faststream/broker/core/handler.py +++ b/faststream/broker/core/handler.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from inspect import unwrap from itertools import chain -from logging import Logger from typing import ( TYPE_CHECKING, Any, + AsyncContextManager, AsyncGenerator, Awaitable, Callable, @@ -153,21 +153,21 @@ def __init__( *, log_context_builder: Callable[["StreamMessage[Any]"], Dict[str, str]], middlewares: Sequence[Callable[[MsgType], "BaseMiddleware"]], - logger: Optional[Logger], description: Optional[str], title: Optional[str], include_in_schema: bool, graceful_timeout: Optional[float], + watcher: Callable[..., AsyncContextManager[None]], ) -> None: """Initialize a new instance of the class.""" self.calls = [] self.middlewares = middlewares self.log_context_builder = log_context_builder - self.logger = logger self.running = False self.lock = MultiLock() + self.watcher = watcher self.graceful_timeout = graceful_timeout # AsyncAPI information @@ -257,7 +257,6 @@ def real_wrapper( handler, dependant = self.wrap_handler( func=func, dependencies=total_deps, - logger=self.logger, **wrap_kwargs, ) @@ -319,6 +318,7 @@ async def consume(self, msg: MsgType) -> SendableMessage: if ( message := cast("StreamMessage[MsgType]", await caller.asend(None)) ) is not None: + await stack.enter_async_context(self.watcher(message)) stack.enter_context(context.scope("message", message)) stack.enter_context( context.scope("log_context", self.log_context_builder(message)) diff --git a/faststream/broker/core/handler_wrapper_mixin.py b/faststream/broker/core/handler_wrapper_mixin.py index bd124fba92..cddaf24384 100644 --- a/faststream/broker/core/handler_wrapper_mixin.py +++ b/faststream/broker/core/handler_wrapper_mixin.py @@ -1,9 +1,7 @@ from functools import partial, wraps -from logging import Logger from typing import ( TYPE_CHECKING, Any, - AsyncContextManager, Awaitable, Callable, Generic, @@ -21,7 +19,6 @@ from faststream._compat import PYDANTIC_V2 from faststream.broker.core.call_wrapper import HandlerCallWrapper -from faststream.broker.push_back_watcher import WatcherContext from faststream.broker.types import ( CustomDecoder, CustomParser, @@ -29,11 +26,9 @@ MsgType, P_HandlerParams, T_HandlerReturn, - WrappedReturn, ) -from faststream.broker.utils import get_watcher from faststream.types import F_Return, F_Spec -from faststream.utils.functions import fake_context, to_async +from faststream.utils.functions import to_async if TYPE_CHECKING: from typing import Protocol, overload @@ -103,14 +98,10 @@ def wrap_handler( *, func: Callable[P_HandlerParams, T_HandlerReturn], dependencies: Sequence["Depends"], - logger: Optional[Logger], apply_types: bool, is_validate: bool, raw: bool = False, - no_ack: bool = False, - retry: Union[bool, int] = False, get_dependant: Optional[Any] = None, - **process_kwargs: Any, ) -> Tuple[ HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], "CallModel[P_HandlerParams, T_HandlerReturn]", @@ -154,16 +145,6 @@ def wrap_handler( params_ln=len(dependant.flat_params), ) - f = self._process_message( - func=f, - watcher=( - partial(WatcherContext, watcher=get_watcher(logger, retry)) # type: ignore[arg-type] - if not no_ack - else fake_context - ), - **(process_kwargs or {}), - ) - handler_call.set_wrapped(f) return handler_call, dependant @@ -209,17 +190,6 @@ async def decode_wrapper(message: "StreamMessage[MsgType]") -> T_HandlerReturn: return decode_wrapper - def _process_message( - self, - func: Callable[[MsgType], Awaitable[T_HandlerReturn]], - watcher: Callable[..., AsyncContextManager[None]], - **kwargs: Any, - ) -> Callable[ - ["StreamMessage[MsgType]"], - Awaitable[WrappedReturn[T_HandlerReturn]], - ]: - raise NotImplementedError() - def _patch_fastapi_dependant( dependant: "CallModel[P_HandlerParams, Awaitable[T_HandlerReturn]]", diff --git a/faststream/broker/types.py b/faststream/broker/types.py index 9c2802ddf7..17c29de769 100644 --- a/faststream/broker/types.py +++ b/faststream/broker/types.py @@ -1,4 +1,4 @@ -from typing import Any, Awaitable, Callable, Optional, Protocol, Tuple, TypeVar, Union +from typing import Any, Awaitable, Callable, Optional, Protocol, TypeVar, Union from typing_extensions import ParamSpec, TypeAlias @@ -98,15 +98,13 @@ async def publish( ... -WrappedReturn: TypeAlias = Tuple[T_HandlerReturn, Optional[PublisherProtocol]] - AsyncWrappedHandlerCall: TypeAlias = Callable[ [StreamMessage[MsgType]], - Awaitable[Optional[WrappedReturn[T_HandlerReturn]]], + Awaitable[Optional[T_HandlerReturn]], ] SyncWrappedHandlerCall: TypeAlias = Callable[ [StreamMessage[MsgType]], - Optional[WrappedReturn[T_HandlerReturn]], + Optional[T_HandlerReturn], ] WrappedHandlerCall: TypeAlias = Union[ AsyncWrappedHandlerCall[MsgType, T_HandlerReturn], diff --git a/faststream/broker/utils.py b/faststream/broker/utils.py index 6c830f0b29..a5454ab36a 100644 --- a/faststream/broker/utils.py +++ b/faststream/broker/utils.py @@ -1,6 +1,7 @@ import asyncio from contextlib import suppress -from typing import TYPE_CHECKING, Optional, Type, Union +from functools import partial +from typing import TYPE_CHECKING, AsyncContextManager, Callable, Optional, Type, Union import anyio @@ -9,7 +10,9 @@ CounterWatcher, EndlessWatcher, OneTryWatcher, + WatcherContext, ) +from faststream.utils.functions import fake_context if TYPE_CHECKING: from logging import Logger @@ -42,7 +45,7 @@ def change_logger_handlers(logger: "Logger", fmt: str) -> None: def get_watcher( logger: Optional["Logger"], - try_number: Union[bool, int] = True, + try_number: Union[bool, int], ) -> BaseWatcher: """Get a watcher object based on the provided parameters. @@ -67,6 +70,17 @@ def get_watcher( return watcher +def get_watcher_context( + logger: Optional["Logger"], + no_ack: bool, + retry: Union[bool, int], +) -> Callable[..., AsyncContextManager[None]]: + if no_ack: + return fake_context + else: + return partial(WatcherContext, watcher=get_watcher(logger, retry)) + + class MultiLock: """A class representing a multi lock.""" diff --git a/faststream/kafka/broker.py b/faststream/kafka/broker.py index 6ce444ea51..45b8eb3306 100644 --- a/faststream/kafka/broker.py +++ b/faststream/kafka/broker.py @@ -1,4 +1,4 @@ -from functools import partial, wraps +from functools import partial from types import TracebackType from typing import ( Any, @@ -35,7 +35,6 @@ P_HandlerParams, PublisherProtocol, T_HandlerReturn, - WrappedReturn, ) from faststream.exceptions import NOT_CONNECTED_YET from faststream.kafka.asyncapi import Handler, Publisher @@ -203,34 +202,10 @@ def _process_message( func: Callable[[KafkaMessage], Awaitable[T_HandlerReturn]], watcher: Callable[..., AsyncContextManager[None]], **kwargs: Any, - ) -> Callable[ - [KafkaMessage], - Awaitable[WrappedReturn[T_HandlerReturn]], - ]: - """Wrap a message processing function with a watcher and publisher. - - Args: - func (Callable[[KafkaMessage], Awaitable[T_HandlerReturn]]): The message processing function. - watcher (BaseWatcher): The message watcher. - disable_watcher: Whether to use watcher context. - **kwargs: Additional keyword arguments. - - Returns: - Callable[[KafkaMessage], Awaitable[WrappedReturn[T_HandlerReturn]]]: The wrapped message processing function. - """ - - @wraps(func) + ): async def process_wrapper( message: KafkaMessage, - ) -> WrappedReturn[T_HandlerReturn]: - """Asynchronously process a Kafka message and wrap the return value. - - Args: - message (KafkaMessage): The Kafka message to process. - - Returns: - WrappedReturn[T_HandlerReturn]: The wrapped return value. - """ + ): async with watcher(message): r = await func(message) diff --git a/faststream/kafka/broker.pyi b/faststream/kafka/broker.pyi index b3584a6198..abfed3c776 100644 --- a/faststream/kafka/broker.pyi +++ b/faststream/kafka/broker.pyi @@ -34,7 +34,6 @@ from faststream.broker.types import ( Filter, P_HandlerParams, T_HandlerReturn, - WrappedReturn, ) from faststream.kafka.asyncapi import Handler, Publisher from faststream.kafka.message import KafkaMessage @@ -181,7 +180,7 @@ class KafkaBroker( **kwargs: Any, ) -> Callable[ [StreamMessage[aiokafka.ConsumerRecord]], - Awaitable[WrappedReturn[T_HandlerReturn]], + Awaitable[T_HandlerReturn], ]: ... @override # type: ignore[override] @overload diff --git a/faststream/nats/broker.py b/faststream/nats/broker.py index 3512db0d5b..7cceb7874f 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker.py @@ -37,6 +37,7 @@ CustomParser, Filter, ) +from faststream.broker.utils import get_watcher_context from faststream.exceptions import NOT_CONNECTED_YET from faststream.nats.asyncapi import Handler, Publisher from faststream.nats.helpers import stream_builder @@ -285,6 +286,8 @@ def subscriber( # type: ignore[override] middlewares: Sequence[Callable[[Msg], BaseMiddleware]] = (), filter: Filter[NatsMessage] = default_filter, max_workers: int = 1, + retry: bool = False, + no_ack: bool = False, # AsyncAPI information title: Optional[str] = None, description: Optional[str] = None, @@ -358,20 +361,21 @@ def subscriber( # type: ignore[override] stream=stream, pull_sub=pull_sub, extra_options=extra_options, + max_workers=max_workers, + producer=self, + # base options title=title, description=description, include_in_schema=include_in_schema, graceful_timeout=self.graceful_timeout, - max_workers=max_workers, middlewares=self.middlewares, - logger=self.logger, log_context_builder=partial( self._get_log_context, stream=stream.name if stream else "", subject=subject, queue=queue, ), - producer=self, + watcher=get_watcher_context(self.logger, no_ack, retry), ), ) diff --git a/faststream/nats/broker.pyi b/faststream/nats/broker.pyi index aab6b43fdf..5208090026 100644 --- a/faststream/nats/broker.pyi +++ b/faststream/nats/broker.pyi @@ -3,8 +3,6 @@ import ssl from types import TracebackType from typing import ( Any, - AsyncContextManager, - Awaitable, Callable, Sequence, ) @@ -35,14 +33,11 @@ from typing_extensions import override from faststream.asyncapi import schema as asyncapi from faststream.broker.core.broker import BrokerUsecase, default_filter from faststream.broker.core.handler import WrapperProtocol -from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( CustomDecoder, CustomParser, Filter, - T_HandlerReturn, - WrappedReturn, ) from faststream.log import access_logger from faststream.nats.asyncapi import Handler, Publisher @@ -197,15 +192,6 @@ class NatsBroker( exec_tb: TracebackType | None = None, ) -> None: ... async def start(self) -> None: ... - def _process_message( - self, - func: Callable[[StreamMessage[Msg]], Awaitable[T_HandlerReturn]], - watcher: Callable[..., AsyncContextManager[None]], - **kwargs: Any, - ) -> Callable[ - [StreamMessage[Msg]], - Awaitable[WrappedReturn[T_HandlerReturn]], - ]: ... def _log_connection_broken( self, error_cb: ErrorCallback | None = None, diff --git a/faststream/nats/handler.py b/faststream/nats/handler.py index 990b741703..b8f538dd70 100644 --- a/faststream/nats/handler.py +++ b/faststream/nats/handler.py @@ -1,7 +1,6 @@ import asyncio from contextlib import suppress -from functools import partial, wraps -from logging import Logger +from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -35,8 +34,6 @@ CustomDecoder, CustomParser, Filter, - T_HandlerReturn, - WrappedReturn, ) from faststream.nats.js_stream import JStream from faststream.nats.message import NatsMessage @@ -73,10 +70,11 @@ def __init__( Callable[[StreamMessage[Any]], Dict[str, str]], Doc("Function to create log extra data by message"), ], + watcher: Annotated[ + Callable[..., AsyncContextManager[None]], + Doc("Watcher to ack message"), + ], producer, - logger: Annotated[ - Optional[Logger], Doc("Logger to use with process message Watcher") - ] = None, queue: Annotated[ str, Doc("NATS queue name"), @@ -145,7 +143,7 @@ def __init__( title=title, middlewares=middlewares, graceful_timeout=graceful_timeout, - logger=logger, + watcher=watcher, ) self.max_workers = max_workers @@ -177,23 +175,23 @@ def add_call( **wrap_kwargs, ) - def _process_message( - self, - func: Callable[[NatsMessage], Awaitable[T_HandlerReturn]], - watcher: Callable[..., AsyncContextManager[None]], - ) -> Callable[ - [NatsMessage], - Awaitable[WrappedReturn[T_HandlerReturn]], - ]: - @wraps(func) - async def process_wrapper( - message: NatsMessage, - ) -> WrappedReturn[T_HandlerReturn]: - async with watcher(message): - r = await func(message) - return r, None - - return process_wrapper + # def _process_message( + # self, + # func: Callable[[NatsMessage], Awaitable[T_HandlerReturn]], + # watcher: Callable[..., AsyncContextManager[None]], + # ) -> Callable[ + # [NatsMessage], + # Awaitable[WrappedReturn[T_HandlerReturn]], + # ]: + # @wraps(func) + # async def process_wrapper( + # message: NatsMessage, + # ) -> WrappedReturn[T_HandlerReturn]: + # async with watcher(message): + # r = await func(message) + # return r, None + + # return process_wrapper def make_response_publisher(self, message: NatsMessage) -> Sequence[FakePublisher]: if message.reply_to: diff --git a/faststream/rabbit/broker.py b/faststream/rabbit/broker.py index 85b2a0c0e9..42fba8c917 100644 --- a/faststream/rabbit/broker.py +++ b/faststream/rabbit/broker.py @@ -35,7 +35,6 @@ P_HandlerParams, PublisherProtocol, T_HandlerReturn, - WrappedReturn, ) from faststream.exceptions import NOT_CONNECTED_YET from faststream.rabbit.asyncapi import Handler, Publisher @@ -513,35 +512,13 @@ def _process_message( watcher: Callable[..., AsyncContextManager[None]], reply_config: Optional[ReplyConfig] = None, **kwargs: Any, - ) -> Callable[ - [StreamMessage[aio_pika.IncomingMessage]], - Awaitable[WrappedReturn[T_HandlerReturn]], - ]: - """Process a message using the provided handler function. - - Args: - func (Callable): The handler function for processing the message. - watcher (BaseWatcher): The message watcher for tracking message processing. - reply_config (Optional[ReplyConfig], optional): The reply configuration for the message. - **kwargs (Any): Additional keyword arguments. - - Returns: - Callable: A wrapper function for processing messages. - """ + ): reply_kwargs = {} if reply_config is None else model_to_dict(reply_config) @wraps(func) async def process_wrapper( message: RabbitMessage, - ) -> WrappedReturn[T_HandlerReturn]: - """Asynchronously process a message and wrap the return value. - - Args: - message: The RabbitMessage to process. - - Returns: - A tuple containing the return value of the handler function and an optional PublisherProtocol. - """ + ): async with watcher(message): r = await func(message) diff --git a/faststream/rabbit/broker.pyi b/faststream/rabbit/broker.pyi index 0633fd31ae..a5d53b1394 100644 --- a/faststream/rabbit/broker.pyi +++ b/faststream/rabbit/broker.pyi @@ -3,8 +3,6 @@ from ssl import SSLContext from types import TracebackType from typing import ( Any, - AsyncContextManager, - Awaitable, Callable, Sequence, ) @@ -19,7 +17,6 @@ from yarl import URL from faststream.asyncapi import schema as asyncapi from faststream.broker.core.broker import BrokerUsecase, default_filter from faststream.broker.core.call_wrapper import HandlerCallWrapper -from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( CustomDecoder, @@ -27,7 +24,6 @@ from faststream.broker.types import ( Filter, P_HandlerParams, T_HandlerReturn, - WrappedReturn, ) from faststream.log import access_logger from faststream.rabbit.asyncapi import Handler, Publisher @@ -216,18 +212,6 @@ class RabbitBroker( user_id: str | None = None, app_id: str | None = None, ) -> aiormq.abc.ConfirmationFrameType | SendableMessage: ... - def _process_message( - self, - func: Callable[ - [StreamMessage[aio_pika.IncomingMessage]], Awaitable[T_HandlerReturn] - ], - watcher: Callable[..., AsyncContextManager[None]], - reply_config: ReplyConfig | None = None, - **kwargs: Any, - ) -> Callable[ - [StreamMessage[aio_pika.IncomingMessage]], - Awaitable[WrappedReturn[T_HandlerReturn]], - ]: ... async def declare_queue( self, queue: RabbitQueue, diff --git a/faststream/redis/broker.py b/faststream/redis/broker.py index 80513151e4..ba66cf7aaa 100644 --- a/faststream/redis/broker.py +++ b/faststream/redis/broker.py @@ -31,7 +31,6 @@ P_HandlerParams, PublisherProtocol, T_HandlerReturn, - WrappedReturn, ) from faststream.exceptions import NOT_CONNECTED_YET from faststream.redis.asyncapi import Handler, Publisher @@ -168,14 +167,11 @@ def _process_message( func: Callable[[StreamMessage[Any]], Awaitable[T_HandlerReturn]], watcher: Callable[..., AsyncContextManager[None]], **kwargs: Any, - ) -> Callable[ - [StreamMessage[Any]], - Awaitable[WrappedReturn[T_HandlerReturn]], - ]: + ): @wraps(func) async def process_wrapper( message: StreamMessage[Any], - ) -> WrappedReturn[T_HandlerReturn]: + ): async with watcher( message, redis=self._connection, diff --git a/faststream/redis/broker.pyi b/faststream/redis/broker.pyi index 5c9ec75ae1..5c12ba40a6 100644 --- a/faststream/redis/broker.pyi +++ b/faststream/redis/broker.pyi @@ -2,8 +2,6 @@ import logging from types import TracebackType from typing import ( Any, - AsyncContextManager, - Awaitable, Callable, Mapping, Sequence, @@ -17,7 +15,6 @@ from typing_extensions import TypeAlias, override from faststream.asyncapi import schema as asyncapi from faststream.broker.core.broker import BrokerUsecase, default_filter from faststream.broker.core.call_wrapper import HandlerCallWrapper -from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( CustomDecoder, @@ -25,7 +22,6 @@ from faststream.broker.types import ( Filter, P_HandlerParams, T_HandlerReturn, - WrappedReturn, ) from faststream.log import access_logger from faststream.redis.asyncapi import Handler, Publisher @@ -148,15 +144,6 @@ class RedisBroker( exec_tb: TracebackType | None = None, ) -> None: ... async def start(self) -> None: ... - def _process_message( - self, - func: Callable[[StreamMessage[Any]], Awaitable[T_HandlerReturn]], - watcher: Callable[..., AsyncContextManager[None]], - **kwargs: Any, - ) -> Callable[ - [StreamMessage[Any]], - Awaitable[WrappedReturn[T_HandlerReturn]], - ]: ... @override def subscriber( # type: ignore[override] self, From 0b8c7c938744bbe0bc444914bea9fa4cec394a7f Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Thu, 11 Jan 2024 22:36:07 +0300 Subject: [PATCH 12/87] fix #840: broker.publish respects middlewares --- faststream/broker/core/broker.py | 3 ++- faststream/broker/core/handler.py | 2 +- faststream/nats/broker.py | 22 ++++++++++++++-------- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/faststream/broker/core/broker.py b/faststream/broker/core/broker.py index 1fe3dd5550..3badde9422 100644 --- a/faststream/broker/core/broker.py +++ b/faststream/broker/core/broker.py @@ -79,6 +79,7 @@ class BrokerUsecase( handlers: Mapping[Any, "BaseHandler[MsgType]"] _publishers: Mapping[Any, "BasePublisher[MsgType]"] + middlewares: Sequence[Callable[[Any], "BaseMiddleware"]] def __init__( self, @@ -120,7 +121,7 @@ def __init__( Doc("Dependencies to apply to all broker subscribers"), ] = (), middlewares: Annotated[ - Sequence[Callable[[MsgType], "BaseMiddleware"]], + Sequence[Callable[[Any], "BaseMiddleware"]], Doc("Middlewares to apply to all broker publishers/subscribers"), ] = (), graceful_timeout: Annotated[ diff --git a/faststream/broker/core/handler.py b/faststream/broker/core/handler.py index 196b951d4d..cd86803505 100644 --- a/faststream/broker/core/handler.py +++ b/faststream/broker/core/handler.py @@ -350,7 +350,7 @@ async def close_middlewares( # TODO: need to test copy result_to_send = result - for m_pub in chain(middlewares, h.middlewares): + for m_pub in middlewares: result_to_send = await pub_stack.enter_async_context( m_pub.publish_scope(result_to_send) ) diff --git a/faststream/nats/broker.py b/faststream/nats/broker.py index 7cceb7874f..ae3012ec02 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker.py @@ -1,5 +1,6 @@ import logging import warnings +from contextlib import AsyncExitStack from functools import partial from types import TracebackType from typing import ( @@ -48,7 +49,7 @@ from faststream.nats.security import parse_security from faststream.nats.shared.logging import NatsLoggingMixin from faststream.security import BaseSecurity -from faststream.types import AnyDict, DecodedMessage +from faststream.types import AnyDict, DecodedMessage, SendableMessage from faststream.utils.context.repository import context Subject: TypeAlias = str @@ -437,21 +438,26 @@ def publisher( # type: ignore[override] @override async def publish( # type: ignore[override] self, + message: SendableMessage, *args: Any, stream: Optional[str] = None, **kwargs: Any, ) -> Optional[DecodedMessage]: if stream is None: assert self._producer, NOT_CONNECTED_YET # nosec B101 - return await self._producer.publish(*args, **kwargs) - + publisher = self._producer else: assert self._js_producer, NOT_CONNECTED_YET # nosec B101 - return await self._js_producer.publish( - *args, - stream=stream, - **kwargs, # type: ignore[misc] - ) + publisher = self._js_producer + kwargs["stream"] = stream + + async with AsyncExitStack() as stack: + for m in self.middlewares: + message = await stack.enter_async_context( + m(None).publish_scope(message) + ) + + return await publisher.publish(message, *args, **kwargs) def __set_publisher_producer(self, publisher: Publisher) -> None: if publisher.stream is not None: From 2c4c28aece89f1e320985a39726057b1fa9b1a8e Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sat, 13 Jan 2024 19:29:20 +0300 Subject: [PATCH 13/87] fix: correct connection closing --- faststream/broker/core/broker.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/faststream/broker/core/broker.py b/faststream/broker/core/broker.py index 3badde9422..833a8fae98 100644 --- a/faststream/broker/core/broker.py +++ b/faststream/broker/core/broker.py @@ -342,7 +342,26 @@ async def close( await h.close() if self._connection is not None: - self._connection = None + await self._close(exc_type, exc_val, exec_tb) + + @abstractmethod + async def _close( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_val: Optional[BaseException] = None, + exec_tb: Optional["TracebackType"] = None, + ) -> None: + """Close the object. + + Args: + exc_type: Optional. The type of the exception. + exc_val: Optional. The exception value. + exec_tb: Optional. The traceback of the exception. + + Returns: + None + """ + self._connection = None @abstractmethod async def publish( From c12df70b5b13b8aefb99c380a47ae9c6310ba72b Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sat, 13 Jan 2024 19:56:06 +0300 Subject: [PATCH 14/87] fix: return published message in tests --- faststream/broker/core/handler.py | 28 ++++++++++++++-------------- tests/brokers/base/parser.py | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/faststream/broker/core/handler.py b/faststream/broker/core/handler.py index cd86803505..4e7d41bd92 100644 --- a/faststream/broker/core/handler.py +++ b/faststream/broker/core/handler.py @@ -336,27 +336,27 @@ async def close_middlewares( processed = True try: - result = cast(SendableMessage, await caller.asend(None)) + result_msg = cast(SendableMessage, await caller.asend(None)) except StopConsume: await self.close() return - # TODO: suppress all publishing errors and raise them after all publishers will be tried - for publisher in ( - *self.make_response_publisher(message), - *h.handler._publishers, - ): - async with AsyncExitStack() as pub_stack: - # TODO: need to test copy - result_to_send = result + async with AsyncExitStack() as pub_stack: + # TODO: need to test copy + result_msg = result_msg - for m_pub in middlewares: - result_to_send = await pub_stack.enter_async_context( - m_pub.publish_scope(result_to_send) - ) + for m_pub in middlewares: + result_msg = await pub_stack.enter_async_context( + m_pub.publish_scope(result_msg) + ) + # TODO: suppress all publishing errors and raise them after all publishers will be tried + for publisher in ( + *self.make_response_publisher(message), + *h.handler._publishers, + ): await publisher.publish( - message=result_to_send, + message=result_msg, correlation_id=message.correlation_id, ) diff --git a/tests/brokers/base/parser.py b/tests/brokers/base/parser.py index 4c0a1b6f58..44af434fc3 100644 --- a/tests/brokers/base/parser.py +++ b/tests/brokers/base/parser.py @@ -197,7 +197,7 @@ async def handle2(m): assert event.is_set() assert event2.is_set() - assert mock.call_count == 2 # instead 4 + assert mock.call_count == 1 class CustomParserTestcase(LocalCustomParserTestcase): # noqa: D101 From 1ee5d1e7452534142bee2801e7210b368cd38fb1 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sat, 13 Jan 2024 21:18:24 +0300 Subject: [PATCH 15/87] fix: correct middlewares --- faststream/broker/core/handler.py | 74 +++++++++++++++---------------- faststream/broker/middlewares.py | 10 ++--- tests/brokers/base/middlewares.py | 27 ++++++----- 3 files changed, 55 insertions(+), 56 deletions(-) diff --git a/faststream/broker/core/handler.py b/faststream/broker/core/handler.py index 4e7d41bd92..674fe99909 100644 --- a/faststream/broker/core/handler.py +++ b/faststream/broker/core/handler.py @@ -7,7 +7,6 @@ TYPE_CHECKING, Any, AsyncContextManager, - AsyncGenerator, Awaitable, Callable, Dict, @@ -85,15 +84,11 @@ def description(self) -> Optional[str]: description = getattr(caller, "__doc__", None) return description - async def call( + async def is_suitable( self, msg: MsgType, cache: Dict[Any, Any], - extra_middlewares: Sequence["BaseMiddleware"], - ) -> AsyncGenerator[ - Union["StreamMessage[MsgType]", None, SendableMessage], - None, - ]: + ) -> Optional["StreamMessage[MsgType]"]: message = cache[self.parser] = cache.get( self.parser, await self.parser(msg), @@ -104,36 +99,40 @@ async def call( ) if await self.filter(message): - yield message - - result = None - async with AsyncExitStack() as consume_stack: - for middleware in chain(self.middlewares, extra_middlewares): - message.decoded_body = await consume_stack.enter_async_context( - middleware.consume_scope(message.decoded_body) - ) + return message - try: - result = await self.handler.call_wrapped(message) + async def call( + self, + message: "StreamMessage[MsgType]", + extra_middlewares: Sequence["BaseMiddleware"], + ) -> Optional[SendableMessage]: + assert message.decoded_body + + result: SendableMessage = None + async with AsyncExitStack() as consume_stack: + for middleware in chain(self.middlewares, extra_middlewares): + message.decoded_body = await consume_stack.enter_async_context( + middleware.consume_scope(message.decoded_body) + ) - except StopConsume: - self.handler.trigger() - raise + try: + result = await self.handler.call_wrapped(message) - except HandlerException: - self.handler.trigger() - raise + except StopConsume: + self.handler.trigger() + raise - except Exception as e: - self.handler.trigger(error=e) - raise e + except HandlerException: + self.handler.trigger() + raise - else: - self.handler.trigger(result=result[0] if result else None) - yield result + except Exception as e: + self.handler.trigger(error=e) + raise e - else: - yield None + else: + self.handler.trigger(result=result) + return result class BaseHandler(AsyncAPIOperation, WrapHandlerMixin[MsgType]): @@ -290,8 +289,7 @@ async def consume(self, msg: MsgType) -> SendableMessage: Returns: The sendable message. """ - result: Optional[SendableMessage] = None - result_msg: SendableMessage = None + result_msg: Optional[SendableMessage] = None if not self.running: return result_msg @@ -313,10 +311,10 @@ async def consume(self, msg: MsgType) -> SendableMessage: if processed: break - caller = h.call(msg, cache, middlewares) - if ( - message := cast("StreamMessage[MsgType]", await caller.asend(None)) + message := cast( + "StreamMessage[MsgType]", await h.is_suitable(msg, cache) + ) ) is not None: await stack.enter_async_context(self.watcher(message)) stack.enter_context(context.scope("message", message)) @@ -336,7 +334,9 @@ async def close_middlewares( processed = True try: - result_msg = cast(SendableMessage, await caller.asend(None)) + result_msg = cast( + SendableMessage, await h.call(message, middlewares) + ) except StopConsume: await self.close() return diff --git a/faststream/broker/middlewares.py b/faststream/broker/middlewares.py index ca8f634720..6855f6c5d5 100644 --- a/faststream/broker/middlewares.py +++ b/faststream/broker/middlewares.py @@ -44,13 +44,9 @@ class BaseMiddleware: Asynchronous function to handle the after publish event. """ - def __init__(self, msg: Any) -> None: - """Initialize the class. - - Args: - msg: Any message to be stored. - """ - self.msg = msg + def __init__(self) -> None: + """Initialize the class.""" + pass async def on_receive(self) -> None: pass diff --git a/tests/brokers/base/middlewares.py b/tests/brokers/base/middlewares.py index 8f840256fc..5bb623d6aa 100644 --- a/tests/brokers/base/middlewares.py +++ b/tests/brokers/base/middlewares.py @@ -1,5 +1,5 @@ import asyncio -from typing import Type +from typing import Optional, Type from unittest.mock import Mock import pytest @@ -25,20 +25,21 @@ async def test_local_middleware( self, event: asyncio.Event, queue: str, mock: Mock, raw_broker ): class mid(BaseMiddleware): # noqa: N801 - async def on_receive(self): - mock.start(self.msg) - return await super().on_receive() + async def on_consume(self, msg): + mock.start(msg) + return await super().on_consume(msg) - async def after_processed(self, exc_type, exc_val, exec_tb): + async def after_consume(self, err: Optional[Exception]) -> None: mock.end() - return await super().after_processed(exc_type, exc_val, exec_tb) + event.set() + return await super().after_consume(err) broker = self.broker_class() - @broker.subscriber(queue, middlewares=(mid,)) + @broker.subscriber(queue, middlewares=(mid(),)) async def handler(m): - event.set() - return "" + mock.inner(m) + return "end" broker = self.patch_broker(raw_broker, broker) @@ -46,14 +47,16 @@ async def handler(m): await broker.start() await asyncio.wait( ( - asyncio.create_task(broker.publish("", queue)), + asyncio.create_task(broker.publish("start", queue)), asyncio.create_task(event.wait()), ), timeout=3, ) + mock.start.assert_called_once_with("start") + mock.inner.assert_called_once_with("start") + assert event.is_set() - mock.start.assert_called_once() mock.end.assert_called_once() async def test_local_middleware_not_shared_between_subscribers( @@ -236,7 +239,7 @@ async def after_processed(self, exc_type, exc_val, exec_tb): return await super().after_processed(exc_type, exc_val, exec_tb) broker = self.broker_class( - middlewares=(mid,), + middlewares=(mid(None),), ) @broker.subscriber(queue) From 7fab3f209e0f6c12e29c529cb6977916b7a27243 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sat, 13 Jan 2024 21:39:49 +0300 Subject: [PATCH 16/87] refactor: remove useless code --- faststream/broker/core/handler.py | 17 ++++------------- faststream/nats/broker.py | 2 +- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/faststream/broker/core/handler.py b/faststream/broker/core/handler.py index 674fe99909..4ac8f5a1bf 100644 --- a/faststream/broker/core/handler.py +++ b/faststream/broker/core/handler.py @@ -106,8 +106,6 @@ async def call( message: "StreamMessage[MsgType]", extra_middlewares: Sequence["BaseMiddleware"], ) -> Optional[SendableMessage]: - assert message.decoded_body - result: SendableMessage = None async with AsyncExitStack() as consume_stack: for middleware in chain(self.middlewares, extra_middlewares): @@ -306,14 +304,11 @@ async def consume(self, msg: MsgType) -> SendableMessage: await middleware.__aenter__() cache = {} - processed = False for h in self.calls: - if processed: - break - if ( message := cast( - "StreamMessage[MsgType]", await h.is_suitable(msg, cache) + "StreamMessage[MsgType]", + await h.is_suitable(msg, cache), ) ) is not None: await stack.enter_async_context(self.watcher(message)) @@ -331,8 +326,6 @@ async def close_middlewares( for m in middlewares: await m.__aexit__(exc_type, exc_val, exec_tb) - processed = True - try: result_msg = cast( SendableMessage, await h.call(message, middlewares) @@ -342,7 +335,6 @@ async def close_middlewares( return async with AsyncExitStack() as pub_stack: - # TODO: need to test copy result_msg = result_msg for m_pub in middlewares: @@ -355,14 +347,13 @@ async def close_middlewares( *self.make_response_publisher(message), *h.handler._publishers, ): + # add publishers middlewares await publisher.publish( message=result_msg, correlation_id=message.correlation_id, ) - assert not self.running or processed, "You have to consume message" # nosec B101 - - return result_msg + return result_msg def make_response_publisher( self, message: "StreamMessage[MsgType]" diff --git a/faststream/nats/broker.py b/faststream/nats/broker.py index ae3012ec02..14507283c9 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker.py @@ -454,7 +454,7 @@ async def publish( # type: ignore[override] async with AsyncExitStack() as stack: for m in self.middlewares: message = await stack.enter_async_context( - m(None).publish_scope(message) + m().publish_scope(message) ) return await publisher.publish(message, *args, **kwargs) From cc5c7fb47f08c89f5fed0eaf040c81f72f53f007 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sat, 13 Jan 2024 22:14:15 +0300 Subject: [PATCH 17/87] refactore: catch StopConsume at any level --- faststream/broker/core/handler.py | 98 +++++++++++++++---------------- faststream/nats/broker.py | 2 +- 2 files changed, 49 insertions(+), 51 deletions(-) diff --git a/faststream/broker/core/handler.py b/faststream/broker/core/handler.py index 4ac8f5a1bf..566499ef15 100644 --- a/faststream/broker/core/handler.py +++ b/faststream/broker/core/handler.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from contextlib import AsyncExitStack +from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass from inspect import unwrap from itertools import chain @@ -7,6 +7,7 @@ TYPE_CHECKING, Any, AsyncContextManager, + AsyncIterator, Awaitable, Callable, Dict, @@ -17,7 +18,6 @@ Tuple, Type, Union, - cast, ) from faststream.asyncapi.base import AsyncAPIOperation @@ -186,40 +186,12 @@ async def close(self) -> None: self.running = False await self.lock.wait_release(self.graceful_timeout) - @property - def call_name(self) -> str: - """Returns the name of the handler call.""" - return to_camelcase(self.calls[0].call_name) - - @property - def description(self) -> Optional[str]: - """Returns the description of the handler.""" - if self._description: - return self._description - - if not self.calls: # pragma: no cover - return None - - else: - return self.calls[0].description - - def get_payloads(self) -> List[Tuple[AnyDict, str]]: - """Get the payloads of the handler.""" - payloads: List[Tuple[AnyDict, str]] = [] - - for h in self.calls: - body = parse_handler_params( - h.dependant, - prefix=f"{self._title or self.call_name}:Message", - ) - payloads.append( - ( - body, - to_camelcase(h.call_name), - ) - ) - - return payloads + @asynccontextmanager + async def stop_scope(self) -> AsyncIterator[None]: + try: + yield + except StopConsume: + await self.close() def add_call( self, @@ -295,8 +267,8 @@ async def consume(self, msg: MsgType) -> SendableMessage: middlewares = [] async with AsyncExitStack() as stack: stack.enter_context(self.lock) - stack.enter_context(context.scope("handler_", self)) + await stack.enter_async_context(self.stop_scope()) for m in self.middlewares: middleware = m(msg) @@ -305,12 +277,7 @@ async def consume(self, msg: MsgType) -> SendableMessage: cache = {} for h in self.calls: - if ( - message := cast( - "StreamMessage[MsgType]", - await h.is_suitable(msg, cache), - ) - ) is not None: + if (message := await h.is_suitable(msg, cache)) is not None: await stack.enter_async_context(self.watcher(message)) stack.enter_context(context.scope("message", message)) stack.enter_context( @@ -326,13 +293,7 @@ async def close_middlewares( for m in middlewares: await m.__aexit__(exc_type, exc_val, exec_tb) - try: - result_msg = cast( - SendableMessage, await h.call(message, middlewares) - ) - except StopConsume: - await self.close() - return + result_msg = await h.call(message, middlewares) async with AsyncExitStack() as pub_stack: result_msg = result_msg @@ -359,3 +320,40 @@ def make_response_publisher( self, message: "StreamMessage[MsgType]" ) -> Sequence[PublisherProtocol]: raise NotImplementedError() + + # AsyncAPI methods + + @property + def call_name(self) -> str: + """Returns the name of the handler call.""" + return to_camelcase(self.calls[0].call_name) + + @property + def description(self) -> Optional[str]: + """Returns the description of the handler.""" + if self._description: + return self._description + + if not self.calls: # pragma: no cover + return None + + else: + return self.calls[0].description + + def get_payloads(self) -> List[Tuple[AnyDict, str]]: + """Get the payloads of the handler.""" + payloads: List[Tuple[AnyDict, str]] = [] + + for h in self.calls: + body = parse_handler_params( + h.dependant, + prefix=f"{self._title or self.call_name}:Message", + ) + payloads.append( + ( + body, + to_camelcase(h.call_name), + ) + ) + + return payloads diff --git a/faststream/nats/broker.py b/faststream/nats/broker.py index 14507283c9..ae3012ec02 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker.py @@ -454,7 +454,7 @@ async def publish( # type: ignore[override] async with AsyncExitStack() as stack: for m in self.middlewares: message = await stack.enter_async_context( - m().publish_scope(message) + m(None).publish_scope(message) ) return await publisher.publish(message, *args, **kwargs) From 1d4526a52efeca52c46e5ca4eb576e1819f2a6e3 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 14 Jan 2024 14:32:00 +0300 Subject: [PATCH 18/87] refactor: new subscriber middlewares --- faststream/broker/core/broker.py | 10 +- faststream/broker/core/call_wrapper.py | 4 +- faststream/broker/core/handler.py | 30 +++-- .../broker/core/handler_wrapper_mixin.py | 9 +- faststream/broker/middlewares.py | 4 +- faststream/broker/types.py | 19 ++- faststream/kafka/broker.py | 11 +- faststream/nats/broker.py | 6 +- faststream/nats/broker.pyi | 4 +- .../nats/{fastapi.py => fastapi/__init__.py} | 38 ++---- faststream/nats/fastapi/fastapi.py | 24 ++++ faststream/nats/{ => fastapi}/fastapi.pyi | 34 ++---- faststream/nats/handler.py | 113 ++++++++---------- faststream/nats/router.pyi | 5 +- 14 files changed, 154 insertions(+), 157 deletions(-) rename faststream/nats/{fastapi.py => fastapi/__init__.py} (58%) create mode 100644 faststream/nats/fastapi/fastapi.py rename faststream/nats/{ => fastapi}/fastapi.pyi (88%) diff --git a/faststream/broker/core/broker.py b/faststream/broker/core/broker.py index 833a8fae98..84fc5e616c 100644 --- a/faststream/broker/core/broker.py +++ b/faststream/broker/core/broker.py @@ -4,8 +4,8 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Generic, + Iterable, List, Mapping, Optional, @@ -45,8 +45,8 @@ from faststream.broker.core.handler_wrapper_mixin import WrapperProtocol from faststream.broker.core.publisher import BasePublisher from faststream.broker.message import StreamMessage - from faststream.broker.middlewares import BaseMiddleware from faststream.broker.router import BrokerRouter + from faststream.broker.types import BrokerMiddleware, SubscriberMiddleware from faststream.security import BaseSecurity from faststream.types import AnyDict, SendableMessage @@ -79,7 +79,7 @@ class BrokerUsecase( handlers: Mapping[Any, "BaseHandler[MsgType]"] _publishers: Mapping[Any, "BasePublisher[MsgType]"] - middlewares: Sequence[Callable[[Any], "BaseMiddleware"]] + middlewares: Sequence["BrokerMiddleware[MsgType]"] def __init__( self, @@ -121,7 +121,7 @@ def __init__( Doc("Dependencies to apply to all broker subscribers"), ] = (), middlewares: Annotated[ - Sequence[Callable[[Any], "BaseMiddleware"]], + Iterable["BrokerMiddleware[MsgType]"], Doc("Middlewares to apply to all broker publishers/subscribers"), ] = (), graceful_timeout: Annotated[ @@ -401,7 +401,7 @@ def subscriber( decoder: Optional[CustomDecoder["StreamMessage[MsgType]"]] = None, parser: Optional[CustomParser[MsgType, "StreamMessage[MsgType]"]] = None, dependencies: Sequence["Depends"] = (), - middlewares: Sequence["BaseMiddleware"] = (), + middlewares: Iterable["SubscriberMiddleware"] = (), raw: bool = False, no_ack: bool = False, retry: Union[bool, int] = False, diff --git a/faststream/broker/core/call_wrapper.py b/faststream/broker/core/call_wrapper.py index 30f6afca30..8e75783d5f 100644 --- a/faststream/broker/core/call_wrapper.py +++ b/faststream/broker/core/call_wrapper.py @@ -38,7 +38,7 @@ class HandlerCallWrapper(Generic[MsgType, P_HandlerParams, T_HandlerReturn]): future: Optional["asyncio.Future[Any]"] is_test: bool - _wrapped_call: Optional[WrappedHandlerCall[MsgType, T_HandlerReturn]] + _wrapped_call: Optional[WrappedHandlerCall[MsgType, SendableMessage]] _original_call: Callable[P_HandlerParams, T_HandlerReturn] _publishers: List[PublisherProtocol] @@ -116,7 +116,7 @@ def __call__( return self._original_call(*args, **kwargs) def set_wrapped( - self, wrapped: WrappedHandlerCall[MsgType, T_HandlerReturn] + self, wrapped: WrappedHandlerCall[MsgType, SendableMessage] ) -> None: """Set the wrapped handler call. diff --git a/faststream/broker/core/handler.py b/faststream/broker/core/handler.py index 566499ef15..15e36c9264 100644 --- a/faststream/broker/core/handler.py +++ b/faststream/broker/core/handler.py @@ -12,6 +12,7 @@ Callable, Dict, Generic, + Iterable, List, Optional, Sequence, @@ -51,6 +52,7 @@ from faststream.broker.core.handler_wrapper_mixin import WrapperProtocol from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware + from faststream.broker.types import BrokerMiddleware, SubscriberMiddleware @dataclass(slots=True) @@ -61,7 +63,7 @@ class HandlerItem(Generic[MsgType]): filter: Callable[["StreamMessage[MsgType]"], Awaitable[bool]] parser: AsyncParser[MsgType, Any] decoder: AsyncDecoder["StreamMessage[MsgType]"] - middlewares: Sequence["BaseMiddleware"] + middlewares: Iterable["SubscriberMiddleware"] dependant: "CallModel[Any, SendableMessage]" @property @@ -104,13 +106,13 @@ async def is_suitable( async def call( self, message: "StreamMessage[MsgType]", - extra_middlewares: Sequence["BaseMiddleware"], + extra_middlewares: Iterable["SubscriberMiddleware"], ) -> Optional[SendableMessage]: result: SendableMessage = None async with AsyncExitStack() as consume_stack: for middleware in chain(self.middlewares, extra_middlewares): message.decoded_body = await consume_stack.enter_async_context( - middleware.consume_scope(message.decoded_body) + middleware(message.decoded_body) ) try: @@ -149,7 +151,7 @@ def __init__( self, *, log_context_builder: Callable[["StreamMessage[Any]"], Dict[str, str]], - middlewares: Sequence[Callable[[MsgType], "BaseMiddleware"]], + middlewares: Iterable["BrokerMiddleware[MsgType]"], description: Optional[str], title: Optional[str], include_in_schema: bool, @@ -198,17 +200,18 @@ def add_call( filter_: Filter["StreamMessage[MsgType]"], parser_: CustomParser[MsgType, Any], decoder_: CustomDecoder["StreamMessage[MsgType]"], - middlewares_: Sequence["BaseMiddleware"], + middlewares_: Iterable["SubscriberMiddleware"], dependencies_: Sequence["Depends"], **wrap_kwargs: Any, ) -> "WrapperProtocol[MsgType]": + # TODO: should return SELF? def wrapper( func: Optional[Callable[P_HandlerParams, T_HandlerReturn]] = None, *, filter: Filter["StreamMessage[MsgType]"] = filter_, parser: CustomParser[MsgType, Any] = parser_, decoder: CustomDecoder["StreamMessage[MsgType]"] = decoder_, - middlewares: Sequence["BaseMiddleware"] = (), + middlewares: Iterable["SubscriberMiddleware"] = (), dependencies: Sequence["Depends"] = (), ) -> Union[ HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], @@ -264,12 +267,13 @@ async def consume(self, msg: MsgType) -> SendableMessage: if not self.running: return result_msg - middlewares = [] async with AsyncExitStack() as stack: stack.enter_context(self.lock) stack.enter_context(context.scope("handler_", self)) await stack.enter_async_context(self.stop_scope()) + # enter all middlewares + middlewares: List["BaseMiddleware"] = [] for m in self.middlewares: middleware = m(msg) middlewares.append(middleware) @@ -279,11 +283,15 @@ async def consume(self, msg: MsgType) -> SendableMessage: for h in self.calls: if (message := await h.is_suitable(msg, cache)) is not None: await stack.enter_async_context(self.watcher(message)) - stack.enter_context(context.scope("message", message)) stack.enter_context( - context.scope("log_context", self.log_context_builder(message)) + context.scope( + "log_context", + self.log_context_builder(message), + ) ) + stack.enter_context(context.scope("message", message)) + # middlewares should be exited before scope does @stack.push_async_callback async def close_middlewares( exc_type: Optional[Type[BaseException]] = None, @@ -293,7 +301,9 @@ async def close_middlewares( for m in middlewares: await m.__aexit__(exc_type, exc_val, exec_tb) - result_msg = await h.call(message, middlewares) + result_msg = await h.call( + message, (m.consume_scope for m in middlewares) + ) async with AsyncExitStack() as pub_stack: result_msg = result_msg diff --git a/faststream/broker/core/handler_wrapper_mixin.py b/faststream/broker/core/handler_wrapper_mixin.py index cddaf24384..dfcbfcd452 100644 --- a/faststream/broker/core/handler_wrapper_mixin.py +++ b/faststream/broker/core/handler_wrapper_mixin.py @@ -5,6 +5,7 @@ Awaitable, Callable, Generic, + Iterable, Mapping, Optional, Sequence, @@ -37,7 +38,7 @@ from fast_depends.dependencies import Depends from faststream.broker.message import StreamMessage - from faststream.broker.middlewares import BaseMiddleware + from faststream.broker.types import SubscriberMiddleware class WrapperProtocol(Generic[MsgType], Protocol): """Annotation class to represent @subsriber return type.""" @@ -50,7 +51,7 @@ def __call__( filter: Filter["StreamMessage[MsgType]"], parser: CustomParser[MsgType, Any], decoder: CustomDecoder["StreamMessage[MsgType]"], - middlewares: Sequence["BaseMiddleware"] = (), + middlewares: Iterable["SubscriberMiddleware"] = (), dependencies: Sequence["Depends"] = (), ) -> Callable[ [Callable[P_HandlerParams, T_HandlerReturn]], @@ -66,7 +67,7 @@ def __call__( filter: Filter["StreamMessage[MsgType]"], parser: CustomParser[MsgType, Any], decoder: CustomDecoder["StreamMessage[MsgType]"], - middlewares: Sequence["BaseMiddleware"] = (), + middlewares: Iterable["SubscriberMiddleware"] = (), dependencies: Sequence["Depends"] = (), ) -> HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]: ... @@ -78,7 +79,7 @@ def __call__( filter: Filter["StreamMessage[MsgType]"], parser: CustomParser[MsgType, Any], decoder: CustomDecoder["StreamMessage[MsgType]"], - middlewares: Sequence["BaseMiddleware"] = (), + middlewares: Iterable["SubscriberMiddleware"] = (), dependencies: Sequence["Depends"] = (), ) -> Union[ HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], diff --git a/faststream/broker/middlewares.py b/faststream/broker/middlewares.py index 6855f6c5d5..65174c8177 100644 --- a/faststream/broker/middlewares.py +++ b/faststream/broker/middlewares.py @@ -115,7 +115,9 @@ async def after_consume(self, err: Optional[Exception]) -> None: raise err @asynccontextmanager - async def consume_scope(self, msg: DecodedMessage) -> AsyncIterator[DecodedMessage]: + async def consume_scope( + self, msg: Optional[DecodedMessage] + ) -> AsyncIterator[Optional[DecodedMessage]]: """Asynchronously consumes a message and returns an asynchronous iterator of decoded messages. Args: diff --git a/faststream/broker/types.py b/faststream/broker/types.py index 17c29de769..98154b8260 100644 --- a/faststream/broker/types.py +++ b/faststream/broker/types.py @@ -1,8 +1,18 @@ -from typing import Any, Awaitable, Callable, Optional, Protocol, TypeVar, Union +from typing import ( + Any, + AsyncContextManager, + Awaitable, + Callable, + Optional, + Protocol, + TypeVar, + Union, +) from typing_extensions import ParamSpec, TypeAlias from faststream.broker.message import StreamMessage +from faststream.broker.middlewares import BaseMiddleware from faststream.types import DecodedMessage, SendableMessage Decoded = TypeVar("Decoded", bound=DecodedMessage) @@ -110,3 +120,10 @@ async def publish( AsyncWrappedHandlerCall[MsgType, T_HandlerReturn], SyncWrappedHandlerCall[MsgType, T_HandlerReturn], ] + + +BrokerMiddleware: TypeAlias = Callable[[MsgType], BaseMiddleware] +SubscriberMiddleware: TypeAlias = Callable[ + [Optional[DecodedMessage]], + AsyncContextManager[DecodedMessage], +] diff --git a/faststream/kafka/broker.py b/faststream/kafka/broker.py index 45b8eb3306..294be5d5a1 100644 --- a/faststream/kafka/broker.py +++ b/faststream/kafka/broker.py @@ -27,8 +27,8 @@ from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.core.publisher import FakePublisher from faststream.broker.message import StreamMessage -from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( + BrokerMiddleware, CustomDecoder, CustomParser, Filter, @@ -263,14 +263,7 @@ def subscriber( # type: ignore[override] ] ] = None, decoder: Optional[CustomDecoder] = None, - middlewares: Optional[ - Sequence[ - Callable[ - [aiokafka.ConsumerRecord], - BaseMiddleware, - ] - ] - ] = None, + middlewares: Iterable["BrokerMiddleware[aiokafka.ConsumerRecord]"] = (), filter: Union[ Filter[KafkaMessage], Filter[StreamMessage[Tuple[aiokafka.ConsumerRecord, ...]]], diff --git a/faststream/nats/broker.py b/faststream/nats/broker.py index ae3012ec02..46ecec1168 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker.py @@ -6,8 +6,8 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Dict, + Iterable, List, Optional, Sequence, @@ -32,11 +32,11 @@ from typing_extensions import TypeAlias, override from faststream.broker.core.broker import BrokerUsecase, default_filter -from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( CustomDecoder, CustomParser, Filter, + SubscriberMiddleware, ) from faststream.broker.utils import get_watcher_context from faststream.exceptions import NOT_CONNECTED_YET @@ -284,7 +284,7 @@ def subscriber( # type: ignore[override] dependencies: Sequence[Depends] = (), parser: Optional[CustomParser[Msg, NatsMessage]] = None, decoder: Optional[CustomDecoder[NatsMessage]] = None, - middlewares: Sequence[Callable[[Msg], BaseMiddleware]] = (), + middlewares: Iterable[SubscriberMiddleware] = (), filter: Filter[NatsMessage] = default_filter, max_workers: int = 1, retry: bool = False, diff --git a/faststream/nats/broker.pyi b/faststream/nats/broker.pyi index 5208090026..cd48b7d6d0 100644 --- a/faststream/nats/broker.pyi +++ b/faststream/nats/broker.pyi @@ -4,6 +4,7 @@ from types import TracebackType from typing import ( Any, Callable, + Iterable, Sequence, ) @@ -38,6 +39,7 @@ from faststream.broker.types import ( CustomDecoder, CustomParser, Filter, + SubscriberMiddleware, ) from faststream.log import access_logger from faststream.nats.asyncapi import Handler, Publisher @@ -226,7 +228,7 @@ class NatsBroker( dependencies: Sequence[Depends] = (), parser: CustomParser[Msg, NatsMessage] | None = None, decoder: CustomDecoder[NatsMessage] | None = None, - middlewares: Sequence[Callable[[Msg], BaseMiddleware]] | None = None, + middlewares: Iterable[SubscriberMiddleware] = (), filter: Filter[NatsMessage] = default_filter, retry: bool = False, no_ack: bool = False, diff --git a/faststream/nats/fastapi.py b/faststream/nats/fastapi/__init__.py similarity index 58% rename from faststream/nats/fastapi.py rename to faststream/nats/fastapi/__init__.py index 2bcb1a074c..bc75535e66 100644 --- a/faststream/nats/fastapi.py +++ b/faststream/nats/fastapi/__init__.py @@ -1,14 +1,20 @@ from nats.aio.client import Client as NatsClient -from nats.aio.msg import Msg from nats.js.client import JetStreamContext -from typing_extensions import Annotated, override +from typing_extensions import Annotated from faststream.broker.fastapi.context import Context, ContextRepo, Logger -from faststream.broker.fastapi.router import StreamRouter from faststream.nats.broker import NatsBroker as NB +from faststream.nats.fastapi.fastapi import NatsRouter from faststream.nats.message import NatsMessage as NM from faststream.nats.producer import NatsFastProducer, NatsJSFastProducer +NatsMessage = Annotated[NM, Context("message")] +NatsBroker = Annotated[NB, Context("broker")] +Client = Annotated[NatsClient, Context("broker._connection")] +JsClient = Annotated[JetStreamContext, Context("broker._stream")] +NatsProducer = Annotated[NatsFastProducer, Context("broker._producer")] +NatsJsProducer = Annotated[NatsJSFastProducer, Context("broker._js_producer")] + __all__ = ( "Context", "Logger", @@ -21,29 +27,3 @@ "NatsProducer", "NatsJsProducer", ) - -NatsMessage = Annotated[NM, Context("message")] -NatsBroker = Annotated[NB, Context("broker")] -Client = Annotated[NatsClient, Context("broker._connection")] -JsClient = Annotated[JetStreamContext, Context("broker._stream")] -NatsProducer = Annotated[NatsFastProducer, Context("broker._producer")] -NatsJsProducer = Annotated[NatsJSFastProducer, Context("broker._js_producer")] - - -class NatsRouter(StreamRouter[Msg]): - """A class to represent a NATS router.""" - - broker_class = NB - - @override - @staticmethod - def _setup_log_context( # type: ignore[override] - main_broker: NB, - including_broker: NB, - ) -> None: - for h in including_broker.handlers.values(): - main_broker._setup_log_context( - queue=h.queue, - subject=h.subject, - stream=h.stream.name if h.stream else None, - ) diff --git a/faststream/nats/fastapi/fastapi.py b/faststream/nats/fastapi/fastapi.py new file mode 100644 index 0000000000..3ec79a75c9 --- /dev/null +++ b/faststream/nats/fastapi/fastapi.py @@ -0,0 +1,24 @@ +from nats.aio.msg import Msg +from typing_extensions import override + +from faststream.broker.fastapi.router import StreamRouter +from faststream.nats.broker import NatsBroker + + +class NatsRouter(StreamRouter[Msg]): + """A class to represent a NATS router.""" + + broker_class = NatsBroker + + @override + @staticmethod + def _setup_log_context( # type: ignore[override] + main_broker: NatsBroker, + including_broker: NatsBroker, + ) -> None: + for h in including_broker.handlers.values(): + main_broker._setup_log_context( + queue=h.queue, + subject=h.subject, + stream=h.stream.name if h.stream else None, + ) diff --git a/faststream/nats/fastapi.pyi b/faststream/nats/fastapi/fastapi.pyi similarity index 88% rename from faststream/nats/fastapi.pyi rename to faststream/nats/fastapi/fastapi.pyi index c4bb643c3d..27e3c0b2f5 100644 --- a/faststream/nats/fastapi.pyi +++ b/faststream/nats/fastapi/fastapi.pyi @@ -4,9 +4,8 @@ from typing import ( Any, Awaitable, Callable, - Mapping, + Iterable, Sequence, - overload, ) from fast_depends.dependencies import Depends @@ -34,19 +33,20 @@ from nats.aio.msg import Msg from nats.js import api from starlette import routing from starlette.responses import JSONResponse, Response -from starlette.types import ASGIApp, AppType, Lifespan +from starlette.types import ASGIApp, Lifespan from typing_extensions import override from faststream.asyncapi import schema as asyncapi from faststream.broker.core.broker import default_filter from faststream.broker.core.call_wrapper import HandlerCallWrapper from faststream.broker.fastapi.router import StreamRouter -from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( + BrokerMiddleware, CustomDecoder, CustomParser, Filter, P_HandlerParams, + SubscriberMiddleware, T_HandlerReturn, ) from faststream.nats.asyncapi import Publisher @@ -105,7 +105,7 @@ class NatsRouter(StreamRouter[Msg]): graceful_timeout: float | None = None, decoder: CustomDecoder[NatsMessage] | None = None, parser: CustomParser[Msg, NatsMessage] | None = None, - middlewares: Sequence[Callable[[Msg], BaseMiddleware]] | None = None, + middlewares: Iterable[BrokerMiddleware[Msg]] = (), # AsyncAPI args asyncapi_url: str | list[str] | None = None, protocol: str = "nats", @@ -159,7 +159,7 @@ class NatsRouter(StreamRouter[Msg]): dependencies: Sequence[Depends] = (), parser: CustomParser[Msg, NatsMessage] | None = None, decoder: CustomDecoder[NatsMessage] | None = None, - middlewares: Sequence[Callable[[Msg], BaseMiddleware]] | None = None, + middlewares: Iterable[SubscriberMiddleware] = (), filter: Filter[NatsMessage] = default_filter, retry: bool = False, # AsyncAPI information @@ -193,7 +193,7 @@ class NatsRouter(StreamRouter[Msg]): dependencies: Sequence[Depends] = (), parser: CustomParser[Msg, NatsMessage] | None = None, decoder: CustomDecoder[NatsMessage] | None = None, - middlewares: Sequence[Callable[[Msg], BaseMiddleware]] | None = None, + middlewares: Iterable[SubscriberMiddleware] = (), filter: Filter[NatsMessage] = default_filter, retry: bool = False, no_ack: bool = False, @@ -223,23 +223,3 @@ class NatsRouter(StreamRouter[Msg]): schema: Any | None = None, include_in_schema: bool = True, ) -> Publisher: ... - @overload - def after_startup( - self, - func: Callable[[AppType], Mapping[str, Any]], - ) -> Callable[[AppType], Mapping[str, Any]]: ... - @overload - def after_startup( - self, - func: Callable[[AppType], Awaitable[Mapping[str, Any]]], - ) -> Callable[[AppType], Awaitable[Mapping[str, Any]]]: ... - @overload - def after_startup( - self, - func: Callable[[AppType], None], - ) -> Callable[[AppType], None]: ... - @overload - def after_startup( - self, - func: Callable[[AppType], Awaitable[None]], - ) -> Callable[[AppType], Awaitable[None]]: ... diff --git a/faststream/nats/handler.py b/faststream/nats/handler.py index b8f538dd70..93ecb744a7 100644 --- a/faststream/nats/handler.py +++ b/faststream/nats/handler.py @@ -8,6 +8,7 @@ Awaitable, Callable, Dict, + Iterable, Optional, Sequence, Union, @@ -15,50 +16,52 @@ ) import anyio -from anyio.abc import TaskGroup, TaskStatus -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from fast_depends.dependencies import Depends -from nats.aio.client import Client -from nats.aio.msg import Msg -from nats.aio.subscription import Subscription from nats.errors import TimeoutError -from nats.js import JetStreamContext from typing_extensions import Annotated, Doc from faststream.broker.core.handler import BaseHandler from faststream.broker.core.publisher import FakePublisher -from faststream.broker.message import StreamMessage -from faststream.broker.middlewares import BaseMiddleware from faststream.broker.parsers import resolve_custom_func -from faststream.broker.types import ( - CustomDecoder, - CustomParser, - Filter, -) -from faststream.nats.js_stream import JStream -from faststream.nats.message import NatsMessage from faststream.nats.parser import JsParser, Parser -from faststream.nats.pull_sub import PullSub from faststream.types import AnyDict, SendableMessage from faststream.utils.path import compile_path if TYPE_CHECKING: - from faststream.broker.core.handler import WrapperProtocol + from anyio.abc import TaskGroup, TaskStatus + from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + from fast_depends.dependencies import Depends + from nats.aio.client import Client + from nats.aio.msg import Msg + from nats.aio.subscription import Subscription + from nats.js import JetStreamContext - -class LogicNatsHandler(BaseHandler[Msg]): + from faststream.broker.core.handler import WrapperProtocol + from faststream.broker.message import StreamMessage + from faststream.broker.types import ( + BrokerMiddleware, + CustomDecoder, + CustomParser, + Filter, + SubscriberMiddleware, + ) + from faststream.nats.js_stream import JStream + from faststream.nats.message import NatsMessage + from faststream.nats.pull_sub import PullSub + + +class LogicNatsHandler(BaseHandler["Msg"]): """A class to represent a NATS handler.""" subscription: Union[ None, - Subscription, - JetStreamContext.PushSubscription, - JetStreamContext.PullSubscription, + "Subscription", + "JetStreamContext.PushSubscription", + "JetStreamContext.PullSubscription", ] - task_group: Optional[TaskGroup] + task_group: Optional["TaskGroup"] task: Optional["asyncio.Task[Any]"] - send_stream: MemoryObjectSendStream[Msg] - receive_stream: MemoryObjectReceiveStream[Msg] + send_stream: "MemoryObjectSendStream[Msg]" + receive_stream: "MemoryObjectReceiveStream[Msg]" def __init__( self, @@ -67,7 +70,7 @@ def __init__( Doc("NATS subject to subscribe"), ], log_context_builder: Annotated[ - Callable[[StreamMessage[Any]], Dict[str, str]], + Callable[["StreamMessage[Any]"], Dict[str, str]], Doc("Function to create log extra data by message"), ], watcher: Annotated[ @@ -80,11 +83,11 @@ def __init__( Doc("NATS queue name"), ] = "", stream: Annotated[ - Optional[JStream], + Optional["JStream"], Doc("NATS Stream object"), ] = None, pull_sub: Annotated[ - Optional[PullSub], + Optional["PullSub"], Doc("NATS Pull consumer parameters container"), ] = None, extra_options: Annotated[ @@ -103,7 +106,7 @@ def __init__( Doc("Process up to this parameter messages concurrently"), ] = 1, middlewares: Annotated[ - Sequence[Callable[[Msg], BaseMiddleware]], + Iterable["BrokerMiddleware[Msg]"], Doc("Global middleware to use `on_receive`, `after_processed`"), ] = (), # AsyncAPI information @@ -158,11 +161,11 @@ def __init__( def add_call( self, *, - parser: Optional[CustomParser[Msg, NatsMessage]], - decoder: Optional[CustomDecoder[NatsMessage]], - filter: Filter[NatsMessage], - middlewares: Sequence[Callable[[Msg], BaseMiddleware]], - dependencies: Sequence[Depends], + parser: Optional["CustomParser[Msg, NatsMessage]"], + decoder: Optional["CustomDecoder[NatsMessage]"], + filter: "Filter[NatsMessage]", + middlewares: Iterable["SubscriberMiddleware"], + dependencies: Sequence["Depends"], **wrap_kwargs: Any, ) -> "WrapperProtocol[Msg]": parser_ = Parser if self.stream is None else JsParser @@ -175,25 +178,9 @@ def add_call( **wrap_kwargs, ) - # def _process_message( - # self, - # func: Callable[[NatsMessage], Awaitable[T_HandlerReturn]], - # watcher: Callable[..., AsyncContextManager[None]], - # ) -> Callable[ - # [NatsMessage], - # Awaitable[WrappedReturn[T_HandlerReturn]], - # ]: - # @wraps(func) - # async def process_wrapper( - # message: NatsMessage, - # ) -> WrappedReturn[T_HandlerReturn]: - # async with watcher(message): - # r = await func(message) - # return r, None - - # return process_wrapper - - def make_response_publisher(self, message: NatsMessage) -> Sequence[FakePublisher]: + def make_response_publisher( + self, message: "NatsMessage" + ) -> Sequence[FakePublisher]: if message.reply_to: return ( FakePublisher( @@ -209,12 +196,12 @@ def make_response_publisher(self, message: NatsMessage) -> Sequence[FakePublishe async def start( self, connection: Annotated[ - Union[Client, JetStreamContext], + Union["Client", "JetStreamContext"], Doc("NATS client or JS Context object using to create subscription"), ], ) -> None: """Create NATS subscription and start consume task.""" - cb: Callable[[Msg], Awaitable[SendableMessage]] + cb: Callable[["Msg"], Awaitable[SendableMessage]] if self.max_workers > 1: self.task = asyncio.create_task(self._serve_consume_queue()) cb = self.__put_msg @@ -222,7 +209,7 @@ async def start( cb = self.consume if self.pull_sub is not None: - connection = cast(JetStreamContext, connection) + connection = cast("JetStreamContext", connection) if self.stream is None: raise ValueError("Pull subscriber can be used only with a stream") @@ -257,14 +244,14 @@ async def close(self) -> None: async def _consume_pull( self, - cb: Callable[[Msg], Awaitable[SendableMessage]], + cb: Callable[["Msg"], Awaitable[SendableMessage]], *, - task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + task_status: "TaskStatus[None]" = anyio.TASK_STATUS_IGNORED, ) -> None: """Endless task consuming messages using NATS Pull subscriber.""" assert self.pull_sub # nosec B101 - sub = cast(JetStreamContext.PullSubscription, self.subscription) + sub = cast("JetStreamContext.PullSubscription", self.subscription) task_status.started() @@ -287,7 +274,7 @@ async def _consume_pull( async def _serve_consume_queue( self, *, - task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + task_status: "TaskStatus[None]" = anyio.TASK_STATUS_IGNORED, ) -> None: """Endless task consuming messages from in-memory queue. @@ -301,13 +288,13 @@ async def _serve_consume_queue( async def __consume_msg( self, - msg: Msg, + msg: "Msg", ) -> None: """Proxy method to call `self.consume` with semaphore block.""" async with self.limiter: await self.consume(msg) - async def __put_msg(self, msg: Msg) -> None: + async def __put_msg(self, msg: "Msg") -> None: """Proxy method to put msg into in-memory queue with semaphore block.""" async with self.limiter: await self.send_stream.send(msg) diff --git a/faststream/nats/router.pyi b/faststream/nats/router.pyi index 0e8f3e2b56..4c89b56d03 100644 --- a/faststream/nats/router.pyi +++ b/faststream/nats/router.pyi @@ -1,4 +1,4 @@ -from typing import Any, Callable, Sequence +from typing import Any, Callable, Iterable, Sequence from fast_depends.dependencies import Depends from nats.aio.msg import Msg @@ -13,6 +13,7 @@ from faststream.broker.types import ( CustomParser, Filter, P_HandlerParams, + SubscriberMiddleware, T_HandlerReturn, ) from faststream.nats.asyncapi import Publisher @@ -83,7 +84,7 @@ class NatsRouter(BaseRouter): dependencies: Sequence[Depends] = (), parser: CustomParser[Msg, NatsMessage] | None = None, decoder: CustomDecoder[NatsMessage] | None = None, - middlewares: Sequence[Callable[[Msg], BaseMiddleware]] | None = None, + middlewares: Iterable[SubscriberMiddleware] = (), filter: Filter[NatsMessage] = default_filter, retry: bool = False, no_ack: bool = False, From fd486438b576781e94a952eacae14afd4b9463b7 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 14 Jan 2024 14:53:16 +0300 Subject: [PATCH 19/87] refactor: nats new structure --- faststream/nats/__init__.py | 6 +- faststream/nats/broker.py | 3 +- faststream/nats/broker.pyi | 8 +-- faststream/nats/handler.py | 3 +- faststream/nats/helpers.py | 2 +- faststream/nats/publisher.py | 2 +- faststream/nats/router.py | 46 +++++++++++-- faststream/nats/router.pyi | 50 ++++++++++++-- faststream/nats/schemas/__init__.py | 7 ++ faststream/nats/{ => schemas}/js_stream.py | 0 faststream/nats/{ => schemas}/js_stream.pyi | 0 faststream/nats/{ => schemas}/pull_sub.py | 0 faststream/nats/shared/router.py | 52 -------------- faststream/nats/shared/router.pyi | 76 --------------------- 14 files changed, 101 insertions(+), 154 deletions(-) create mode 100644 faststream/nats/schemas/__init__.py rename faststream/nats/{ => schemas}/js_stream.py (100%) rename faststream/nats/{ => schemas}/js_stream.pyi (100%) rename faststream/nats/{ => schemas}/pull_sub.py (100%) delete mode 100644 faststream/nats/shared/router.py delete mode 100644 faststream/nats/shared/router.pyi diff --git a/faststream/nats/__init__.py b/faststream/nats/__init__.py index 9763b2226f..fb51594fb0 100644 --- a/faststream/nats/__init__.py +++ b/faststream/nats/__init__.py @@ -16,10 +16,8 @@ from faststream.broker.test import TestApp from faststream.nats.annotations import NatsMessage from faststream.nats.broker import NatsBroker -from faststream.nats.js_stream import JStream -from faststream.nats.pull_sub import PullSub -from faststream.nats.router import NatsRouter -from faststream.nats.shared.router import NatsRoute +from faststream.nats.router import NatsRoute, NatsRouter +from faststream.nats.schemas import JStream, PullSub from faststream.nats.test import TestNatsBroker __all__ = ( diff --git a/faststream/nats/broker.py b/faststream/nats/broker.py index 46ecec1168..171d3c78a6 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker.py @@ -42,10 +42,9 @@ from faststream.exceptions import NOT_CONNECTED_YET from faststream.nats.asyncapi import Handler, Publisher from faststream.nats.helpers import stream_builder -from faststream.nats.js_stream import JStream from faststream.nats.message import NatsMessage from faststream.nats.producer import NatsFastProducer, NatsJSFastProducer -from faststream.nats.pull_sub import PullSub +from faststream.nats.schemas import JStream, PullSub from faststream.nats.security import parse_security from faststream.nats.shared.logging import NatsLoggingMixin from faststream.security import BaseSecurity diff --git a/faststream/nats/broker.pyi b/faststream/nats/broker.pyi index cd48b7d6d0..5fe1b9ce43 100644 --- a/faststream/nats/broker.pyi +++ b/faststream/nats/broker.pyi @@ -3,7 +3,6 @@ import ssl from types import TracebackType from typing import ( Any, - Callable, Iterable, Sequence, ) @@ -34,8 +33,8 @@ from typing_extensions import override from faststream.asyncapi import schema as asyncapi from faststream.broker.core.broker import BrokerUsecase, default_filter from faststream.broker.core.handler import WrapperProtocol -from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( + BrokerMiddleware, CustomDecoder, CustomParser, Filter, @@ -43,10 +42,9 @@ from faststream.broker.types import ( ) from faststream.log import access_logger from faststream.nats.asyncapi import Handler, Publisher -from faststream.nats.js_stream import JStream from faststream.nats.message import NatsMessage from faststream.nats.producer import NatsFastProducer, NatsJSFastProducer -from faststream.nats.pull_sub import PullSub +from faststream.nats.schemas import JStream, PullSub from faststream.nats.shared.logging import NatsLoggingMixin from faststream.types import DecodedMessage, SendableMessage @@ -104,7 +102,7 @@ class NatsBroker( dependencies: Sequence[Depends] = (), decoder: CustomDecoder[NatsMessage] | None = None, parser: CustomParser[Msg, NatsMessage] | None = None, - middlewares: Sequence[Callable[[Msg], BaseMiddleware]] | None = None, + middlewares: Iterable[BrokerMiddleware[Msg]] = (), # AsyncAPI args asyncapi_url: str | list[str] | None = None, protocol: str = "nats", diff --git a/faststream/nats/handler.py b/faststream/nats/handler.py index 93ecb744a7..197689280d 100644 --- a/faststream/nats/handler.py +++ b/faststream/nats/handler.py @@ -44,9 +44,8 @@ Filter, SubscriberMiddleware, ) - from faststream.nats.js_stream import JStream from faststream.nats.message import NatsMessage - from faststream.nats.pull_sub import PullSub + from faststream.nats.schemas import JStream, PullSub class LogicNatsHandler(BaseHandler["Msg"]): diff --git a/faststream/nats/helpers.py b/faststream/nats/helpers.py index b1ed65c2c4..9c6f578b14 100644 --- a/faststream/nats/helpers.py +++ b/faststream/nats/helpers.py @@ -1,6 +1,6 @@ from typing import Any, Dict, Optional, Union -from faststream.nats.js_stream import JStream +from faststream.nats.schemas.js_stream import JStream class StreamBuilder: diff --git a/faststream/nats/publisher.py b/faststream/nats/publisher.py index 20d880c821..6c66a18fa2 100644 --- a/faststream/nats/publisher.py +++ b/faststream/nats/publisher.py @@ -6,8 +6,8 @@ from faststream.broker.core.publisher import BasePublisher from faststream.exceptions import NOT_CONNECTED_YET -from faststream.nats.js_stream import JStream from faststream.nats.producer import NatsFastProducer, NatsJSFastProducer +from faststream.nats.schemas import JStream from faststream.types import AnyDict, DecodedMessage, SendableMessage diff --git a/faststream/nats/router.py b/faststream/nats/router.py index 59e280a6a4..f76cc0e92b 100644 --- a/faststream/nats/router.py +++ b/faststream/nats/router.py @@ -1,15 +1,53 @@ -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional, Sequence +from nats.aio.msg import Msg from typing_extensions import override +from faststream.broker.core.call_wrapper import HandlerCallWrapper +from faststream.broker.router import BrokerRoute as NatsRoute +from faststream.broker.router import BrokerRouter +from faststream.broker.types import P_HandlerParams, T_HandlerReturn from faststream.nats.asyncapi import Publisher -from faststream.nats.shared.router import NatsRouter as BaseRouter +from faststream.types import SendableMessage -class NatsRouter(BaseRouter): +class NatsRouter(BrokerRouter[str, Msg]): """A class to represent a NATS router.""" - _publishers: Dict[str, Publisher] # type: ignore[assignment] + _publishers: Dict[str, Publisher] + + def __init__( + self, + prefix: str = "", + handlers: Sequence[NatsRoute[Msg, SendableMessage]] = (), + **kwargs: Any, + ) -> None: + """Initialize the NATS router. + + Args: + prefix: The prefix. + handlers: The handlers. + **kwargs: The keyword arguments. + """ + for h in handlers: + if not (subj := h.kwargs.pop("subject", None)): + subj, h.args = h.args[0], h.args[1:] + h.args = (prefix + subj, *h.args) + super().__init__(prefix, handlers, **kwargs) + + @override + def subscriber( # type: ignore[override] + self, + subject: str, + **broker_kwargs: Any, + ) -> Callable[ + [Callable[P_HandlerParams, T_HandlerReturn]], + HandlerCallWrapper[Msg, P_HandlerParams, T_HandlerReturn], + ]: + return self._wrap_subscriber( + self.prefix + subject, + **broker_kwargs, + ) @override @staticmethod diff --git a/faststream/nats/router.pyi b/faststream/nats/router.pyi index 4c89b56d03..749dc03f08 100644 --- a/faststream/nats/router.pyi +++ b/faststream/nats/router.pyi @@ -7,8 +7,9 @@ from typing_extensions import override from faststream.broker.core.broker import default_filter from faststream.broker.core.call_wrapper import HandlerCallWrapper -from faststream.broker.middlewares import BaseMiddleware +from faststream.broker.router import BrokerRouter from faststream.broker.types import ( + BrokerMiddleware, CustomDecoder, CustomParser, Filter, @@ -17,13 +18,48 @@ from faststream.broker.types import ( T_HandlerReturn, ) from faststream.nats.asyncapi import Publisher -from faststream.nats.js_stream import JStream from faststream.nats.message import NatsMessage -from faststream.nats.pull_sub import PullSub -from faststream.nats.shared.router import NatsRoute -from faststream.nats.shared.router import NatsRouter as BaseRouter +from faststream.nats.schemas import JStream, PullSub -class NatsRouter(BaseRouter): +class NatsRoute: + """Delayed `NatsBroker.subscriber()` registration object.""" + + def __init__( + self, + call: Callable[..., T_HandlerReturn], + subject: str, + queue: str = "", + pending_msgs_limit: int | None = None, + pending_bytes_limit: int | None = None, + # Core arguments + max_msgs: int = 0, + ack_first: bool = False, + # JS arguments + stream: str | JStream | None = None, + durable: str | None = None, + config: api.ConsumerConfig | None = None, + ordered_consumer: bool = False, + idle_heartbeat: float | None = None, + flow_control: bool = False, + deliver_policy: api.DeliverPolicy | None = None, + headers_only: bool | None = None, + # broker arguments + dependencies: Sequence[Depends] = (), + parser: CustomParser[Msg, NatsMessage] | None = None, + decoder: CustomDecoder[NatsMessage] | None = None, + middlewares: Iterable[BrokerMiddleware[Msg]] = (), + filter: Filter[NatsMessage] = default_filter, + retry: bool = False, + no_ack: bool = False, + max_workers: int = 1, + # AsyncAPI information + title: str | None = None, + description: str | None = None, + include_in_schema: bool = True, + **__service_kwargs: Any, + ) -> None: ... + +class NatsRouter(BrokerRouter[str, Msg]): _publishers: dict[str, Publisher] # type: ignore[assignment] def __init__( @@ -32,7 +68,7 @@ class NatsRouter(BaseRouter): handlers: Sequence[NatsRoute] = (), *, dependencies: Sequence[Depends] = (), - middlewares: Sequence[Callable[[Msg], BaseMiddleware]] | None = None, + middlewares: Iterable[BrokerMiddleware[Msg]] = (), parser: CustomParser[Msg, NatsMessage] | None = None, decoder: CustomDecoder[NatsMessage] | None = None, include_in_schema: bool = True, diff --git a/faststream/nats/schemas/__init__.py b/faststream/nats/schemas/__init__.py new file mode 100644 index 0000000000..24ca18db99 --- /dev/null +++ b/faststream/nats/schemas/__init__.py @@ -0,0 +1,7 @@ +from faststream.nats.schemas.js_stream import JStream +from faststream.nats.schemas.pull_sub import PullSub + +__all__ = ( + "JStream", + "PullSub", +) diff --git a/faststream/nats/js_stream.py b/faststream/nats/schemas/js_stream.py similarity index 100% rename from faststream/nats/js_stream.py rename to faststream/nats/schemas/js_stream.py diff --git a/faststream/nats/js_stream.pyi b/faststream/nats/schemas/js_stream.pyi similarity index 100% rename from faststream/nats/js_stream.pyi rename to faststream/nats/schemas/js_stream.pyi diff --git a/faststream/nats/pull_sub.py b/faststream/nats/schemas/pull_sub.py similarity index 100% rename from faststream/nats/pull_sub.py rename to faststream/nats/schemas/pull_sub.py diff --git a/faststream/nats/shared/router.py b/faststream/nats/shared/router.py deleted file mode 100644 index 812f3e5f73..0000000000 --- a/faststream/nats/shared/router.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Any, Callable, Sequence - -from nats.aio.msg import Msg -from typing_extensions import override - -from faststream.broker.core.call_wrapper import HandlerCallWrapper -from faststream.broker.router import BrokerRoute as NatsRoute -from faststream.broker.router import BrokerRouter -from faststream.broker.types import P_HandlerParams, T_HandlerReturn -from faststream.types import SendableMessage - -__all__ = ( - "NatsRouter", - "NatsRoute", -) - - -class NatsRouter(BrokerRouter[str, Msg]): - """A class to represent a NATS router.""" - - def __init__( - self, - prefix: str = "", - handlers: Sequence[NatsRoute[Msg, SendableMessage]] = (), - **kwargs: Any, - ) -> None: - """Initialize the NATS router. - - Args: - prefix: The prefix. - handlers: The handlers. - **kwargs: The keyword arguments. - """ - for h in handlers: - if not (subj := h.kwargs.pop("subject", None)): - subj, h.args = h.args[0], h.args[1:] - h.args = (prefix + subj, *h.args) - super().__init__(prefix, handlers, **kwargs) - - @override - def subscriber( # type: ignore[override] - self, - subject: str, - **broker_kwargs: Any, - ) -> Callable[ - [Callable[P_HandlerParams, T_HandlerReturn]], - HandlerCallWrapper[Msg, P_HandlerParams, T_HandlerReturn], - ]: - return self._wrap_subscriber( - self.prefix + subject, - **broker_kwargs, - ) diff --git a/faststream/nats/shared/router.pyi b/faststream/nats/shared/router.pyi deleted file mode 100644 index 634d0ac826..0000000000 --- a/faststream/nats/shared/router.pyi +++ /dev/null @@ -1,76 +0,0 @@ -from abc import ABCMeta -from typing import Any, Callable, Sequence - -from fast_depends.dependencies import Depends -from nats.aio.msg import Msg -from nats.js import api -from typing_extensions import override - -from faststream.broker.core.broker import default_filter -from faststream.broker.core.call_wrapper import HandlerCallWrapper -from faststream.broker.middlewares import BaseMiddleware -from faststream.broker.router import BrokerRouter -from faststream.broker.types import ( - CustomDecoder, - CustomParser, - Filter, - P_HandlerParams, - T_HandlerReturn, -) -from faststream.nats.js_stream import JStream -from faststream.nats.message import NatsMessage - -class NatsRoute: - """Delayed `NatsBroker.subscriber()` registration object.""" - - def __init__( - self, - call: Callable[..., T_HandlerReturn], - subject: str, - queue: str = "", - pending_msgs_limit: int | None = None, - pending_bytes_limit: int | None = None, - # Core arguments - max_msgs: int = 0, - ack_first: bool = False, - # JS arguments - stream: str | JStream | None = None, - durable: str | None = None, - config: api.ConsumerConfig | None = None, - ordered_consumer: bool = False, - idle_heartbeat: float | None = None, - flow_control: bool = False, - deliver_policy: api.DeliverPolicy | None = None, - headers_only: bool | None = None, - # broker arguments - dependencies: Sequence[Depends] = (), - parser: CustomParser[Msg, NatsMessage] | None = None, - decoder: CustomDecoder[NatsMessage] | None = None, - middlewares: Sequence[Callable[[Msg], BaseMiddleware]] | None = None, - filter: Filter[NatsMessage] = default_filter, - retry: bool = False, - no_ack: bool = False, - max_workers: int = 1, - # AsyncAPI information - title: str | None = None, - description: str | None = None, - include_in_schema: bool = True, - **__service_kwargs: Any, - ) -> None: ... - -class NatsRouter(BrokerRouter[str, Msg], metaclass=ABCMeta): - def __init__( - self, - prefix: str = "", - handlers: Sequence[NatsRoute] = (), - **kwargs: Any, - ) -> None: ... - @override - def subscriber( # type: ignore[override] - self, - subject: str, - **broker_kwargs: Any, - ) -> Callable[ - [Callable[P_HandlerParams, T_HandlerReturn]], - HandlerCallWrapper[Msg, P_HandlerParams, T_HandlerReturn], - ]: ... From 3984422da2d7070b44a1550cc840ae844c38a60f Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 14 Jan 2024 14:57:35 +0300 Subject: [PATCH 20/87] refactor: nats new structure --- faststream/nats/broker/__init__.py | 3 +++ faststream/nats/{ => broker}/broker.py | 2 +- faststream/nats/{ => broker}/broker.pyi | 2 +- faststream/nats/{shared => broker}/logging.py | 0 faststream/nats/router/__init__.py | 6 ++++++ faststream/nats/{ => router}/router.py | 0 faststream/nats/{ => router}/router.pyi | 0 faststream/nats/shared/__init__.py | 0 8 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 faststream/nats/broker/__init__.py rename faststream/nats/{ => broker}/broker.py (99%) rename faststream/nats/{ => broker}/broker.pyi (99%) rename faststream/nats/{shared => broker}/logging.py (100%) create mode 100644 faststream/nats/router/__init__.py rename faststream/nats/{ => router}/router.py (100%) rename faststream/nats/{ => router}/router.pyi (100%) delete mode 100644 faststream/nats/shared/__init__.py diff --git a/faststream/nats/broker/__init__.py b/faststream/nats/broker/__init__.py new file mode 100644 index 0000000000..68408b4233 --- /dev/null +++ b/faststream/nats/broker/__init__.py @@ -0,0 +1,3 @@ +from faststream.nats.broker.broker import NatsBroker + +__all__ = ("NatsBroker",) diff --git a/faststream/nats/broker.py b/faststream/nats/broker/broker.py similarity index 99% rename from faststream/nats/broker.py rename to faststream/nats/broker/broker.py index 171d3c78a6..1065a0c4b3 100644 --- a/faststream/nats/broker.py +++ b/faststream/nats/broker/broker.py @@ -41,12 +41,12 @@ from faststream.broker.utils import get_watcher_context from faststream.exceptions import NOT_CONNECTED_YET from faststream.nats.asyncapi import Handler, Publisher +from faststream.nats.broker.logging import NatsLoggingMixin from faststream.nats.helpers import stream_builder from faststream.nats.message import NatsMessage from faststream.nats.producer import NatsFastProducer, NatsJSFastProducer from faststream.nats.schemas import JStream, PullSub from faststream.nats.security import parse_security -from faststream.nats.shared.logging import NatsLoggingMixin from faststream.security import BaseSecurity from faststream.types import AnyDict, DecodedMessage, SendableMessage from faststream.utils.context.repository import context diff --git a/faststream/nats/broker.pyi b/faststream/nats/broker/broker.pyi similarity index 99% rename from faststream/nats/broker.pyi rename to faststream/nats/broker/broker.pyi index 5fe1b9ce43..5a4d80a740 100644 --- a/faststream/nats/broker.pyi +++ b/faststream/nats/broker/broker.pyi @@ -42,10 +42,10 @@ from faststream.broker.types import ( ) from faststream.log import access_logger from faststream.nats.asyncapi import Handler, Publisher +from faststream.nats.broker.logging import NatsLoggingMixin from faststream.nats.message import NatsMessage from faststream.nats.producer import NatsFastProducer, NatsJSFastProducer from faststream.nats.schemas import JStream, PullSub -from faststream.nats.shared.logging import NatsLoggingMixin from faststream.types import DecodedMessage, SendableMessage Subject = str diff --git a/faststream/nats/shared/logging.py b/faststream/nats/broker/logging.py similarity index 100% rename from faststream/nats/shared/logging.py rename to faststream/nats/broker/logging.py diff --git a/faststream/nats/router/__init__.py b/faststream/nats/router/__init__.py new file mode 100644 index 0000000000..b59722cf5c --- /dev/null +++ b/faststream/nats/router/__init__.py @@ -0,0 +1,6 @@ +from faststream.nats.router.router import NatsRoute, NatsRouter + +__all__ = ( + "NatsRoute", + "NatsRouter", +) diff --git a/faststream/nats/router.py b/faststream/nats/router/router.py similarity index 100% rename from faststream/nats/router.py rename to faststream/nats/router/router.py diff --git a/faststream/nats/router.pyi b/faststream/nats/router/router.pyi similarity index 100% rename from faststream/nats/router.pyi rename to faststream/nats/router/router.pyi diff --git a/faststream/nats/shared/__init__.py b/faststream/nats/shared/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 From ca7d7d7fca5f4c297641d4f7327f97b386bdb9ad Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 14 Jan 2024 17:03:58 +0300 Subject: [PATCH 21/87] feat: publisher middlewares --- faststream/broker/core/handler.py | 36 +++++++-------- faststream/broker/core/publisher.py | 70 +++++++++++++++++++++++------ faststream/broker/middlewares.py | 21 +++++---- faststream/broker/types.py | 6 ++- faststream/nats/broker/broker.py | 15 ++++--- faststream/nats/handler.py | 29 +++++++----- faststream/nats/publisher.py | 2 +- 7 files changed, 119 insertions(+), 60 deletions(-) diff --git a/faststream/broker/core/handler.py b/faststream/broker/core/handler.py index 15e36c9264..cb1acb24bf 100644 --- a/faststream/broker/core/handler.py +++ b/faststream/broker/core/handler.py @@ -305,27 +305,27 @@ async def close_middlewares( message, (m.consume_scope for m in middlewares) ) - async with AsyncExitStack() as pub_stack: - result_msg = result_msg - - for m_pub in middlewares: - result_msg = await pub_stack.enter_async_context( - m_pub.publish_scope(result_msg) - ) - - # TODO: suppress all publishing errors and raise them after all publishers will be tried - for publisher in ( - *self.make_response_publisher(message), - *h.handler._publishers, - ): - # add publishers middlewares - await publisher.publish( - message=result_msg, - correlation_id=message.correlation_id, - ) + if publishers := ( + *self.make_response_publisher(message), + *h.handler._publishers, + ): + async with AsyncExitStack() as pub_stack: + for m_pub in middlewares: + result_msg = await pub_stack.enter_async_context( + m_pub.publish_scope(result_msg) + ) + + # TODO: suppress all publishing errors and raise them after all publishers will be tried + for p in publishers: + await p.publish( + message=result_msg, + correlation_id=message.correlation_id, + ) return result_msg + raise AssertionError(f"Where is not suitable handler for {msg=}") + def make_response_publisher( self, message: "StreamMessage[MsgType]" ) -> Sequence[PublisherProtocol]: diff --git a/faststream/broker/core/publisher.py b/faststream/broker/core/publisher.py index e59ce8358b..b8978dce0f 100644 --- a/faststream/broker/core/publisher.py +++ b/faststream/broker/core/publisher.py @@ -1,7 +1,18 @@ from abc import abstractmethod +from contextlib import AsyncExitStack from dataclasses import dataclass, field from inspect import unwrap -from typing import Any, Awaitable, Callable, Generic, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Iterable, + List, + Optional, + Tuple, +) from unittest.mock import MagicMock from fast_depends._compat import create_model, get_config_base @@ -14,6 +25,9 @@ from faststream.broker.types import MsgType, P_HandlerParams, T_HandlerReturn from faststream.types import AnyDict, SendableMessage +if TYPE_CHECKING: + from faststream.broker.types import PublisherMiddleware + class FakePublisher: """A class to represent a fake publisher. @@ -25,13 +39,18 @@ class FakePublisher: publish : asynchronously publishes a message with optional correlation ID and additional keyword arguments """ - def __init__(self, method: Callable[..., Awaitable[SendableMessage]]) -> None: + def __init__( + self, + method: Callable[..., Awaitable[SendableMessage]], + middlewares: Iterable["PublisherMiddleware"] = (), + ) -> None: """Initialize an object. Args: method: A callable that takes any number of arguments and returns an awaitable sendable message. """ self.method = method + self.middlewares = middlewares async def publish( self, @@ -51,12 +70,16 @@ async def publish( Returns: The published message. """ - return await self.method( - message, - *args, - correlation_id=correlation_id, - **kwargs, - ) + async with AsyncExitStack() as stack: + for m in self.middlewares: + message = await stack.enter_async_context(m(message)) + + return await self.method( + message, + *args, + correlation_id=correlation_id, + **kwargs, + ) @dataclass @@ -67,17 +90,13 @@ class BasePublisher(AsyncAPIOperation, Generic[MsgType]): title : optional title of the publisher _description : optional description of the publisher _fake_handler : boolean indicating if a fake handler is used - calls : list of callable objects + calls : list of callable objects to generate AsyncAPI mock : MagicMock object for mocking purposes Methods: description() : returns the description of the publisher __call__(func) : decorator to register a function as a handler for the publisher publish(message, correlation_id, **kwargs) : publishes a message with optional correlation ID - - Raises: - NotImplementedError: if the publish method is not implemented. - """ title: Optional[str] = field(default=None) @@ -87,6 +106,10 @@ class BasePublisher(AsyncAPIOperation, Generic[MsgType]): calls: List[Callable[..., Any]] = field( init=False, default_factory=list, repr=False ) + middlewares: Iterable["PublisherMiddleware"] = field( + default_factory=tuple, repr=False + ) + _fake_handler: bool = field(default=False, repr=False) mock: Optional[MagicMock] = field(init=False, default=None, repr=False) @@ -135,7 +158,26 @@ async def publish( *args: Any, correlation_id: Optional[str] = None, **kwargs: Any, - ) -> Optional[SendableMessage]: + ) -> Any: + async with AsyncExitStack() as stack: + for m in self.middlewares: + message = await stack.enter_async_context(m(message)) + + return await self._publish( + message, + *args, + correlation_id=correlation_id, + **kwargs, + ) + + @abstractmethod + async def _publish( + self, + message: SendableMessage, + *args: Any, + correlation_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: """Publish a message. Args: diff --git a/faststream/broker/middlewares.py b/faststream/broker/middlewares.py index 65174c8177..835bdbe148 100644 --- a/faststream/broker/middlewares.py +++ b/faststream/broker/middlewares.py @@ -44,7 +44,7 @@ class BaseMiddleware: Asynchronous function to handle the after publish event. """ - def __init__(self) -> None: + def __init__(self, msg: Optional[Any] = None) -> None: """Initialize the class.""" pass @@ -91,7 +91,9 @@ async def __aexit__( """ return await self.after_processed(exc_type, exc_val, exec_tb) - async def on_consume(self, msg: DecodedMessage) -> DecodedMessage: + async def on_consume( + self, msg: Optional[DecodedMessage] + ) -> Optional[DecodedMessage]: """Asynchronously consumes a message. Args: @@ -147,7 +149,7 @@ async def consume_scope( err = None await self.after_consume(err) - async def on_publish(self, msg: SendableMessage) -> SendableMessage: + async def on_publish(self, msg: Any) -> SendableMessage: """Asynchronously handle a publish event. Args: @@ -174,9 +176,7 @@ async def after_publish(self, err: Optional[Exception]) -> None: raise err @asynccontextmanager - async def publish_scope( - self, msg: SendableMessage - ) -> AsyncIterator[SendableMessage]: + async def publish_scope(self, msg: Any) -> AsyncIterator[SendableMessage]: """Publish a message and return an async iterator. Args: @@ -230,18 +230,17 @@ def __init__( self.logger = logger self.log_level = log_level - def __call__(self, msg: Any) -> Self: + def __call__(self, *args: Any) -> Self: """Call the object with a message. - Args: - msg: Any message to be passed to the object. - Returns: The object itself. """ return self - async def on_consume(self, msg: DecodedMessage) -> DecodedMessage: + async def on_consume( + self, msg: Optional[DecodedMessage] + ) -> Optional[DecodedMessage]: if self.logger is not None: c = context.get_local("log_context") self.logger.log(self.log_level, "Received", extra=c) diff --git a/faststream/broker/types.py b/faststream/broker/types.py index 98154b8260..731304a752 100644 --- a/faststream/broker/types.py +++ b/faststream/broker/types.py @@ -103,7 +103,7 @@ async def publish( **kwargs: Additional keyword arguments. Returns: - The published message, or None if the message was not published. + The response message or None. """ ... @@ -127,3 +127,7 @@ async def publish( [Optional[DecodedMessage]], AsyncContextManager[DecodedMessage], ] +PublisherMiddleware: TypeAlias = Callable[ + [Any], + AsyncContextManager[SendableMessage], +] diff --git a/faststream/nats/broker/broker.py b/faststream/nats/broker/broker.py index 1065a0c4b3..a69ad9e197 100644 --- a/faststream/nats/broker/broker.py +++ b/faststream/nats/broker/broker.py @@ -36,6 +36,7 @@ CustomDecoder, CustomParser, Filter, + PublisherMiddleware, SubscriberMiddleware, ) from faststream.broker.utils import get_watcher_context @@ -181,6 +182,7 @@ async def start(self) -> None: await super().start() assert self._connection # nosec B101 assert self.stream, "Broker should be started already" # nosec B101 + assert self._producer, "Broker should be started already" # nosec B101 for handler in self.handlers.values(): stream = handler.stream @@ -222,7 +224,9 @@ async def start(self) -> None: stream=stream.name if stream else "", ) self._log(f"`{handler.call_name}` waiting for messages", extra=c) - await handler.start(self.stream if is_js else self._connection) + await handler.start( + self.stream if is_js else self._connection, self._producer + ) def _log_connection_broken( self, @@ -362,7 +366,6 @@ def subscriber( # type: ignore[override] pull_sub=pull_sub, extra_options=extra_options, max_workers=max_workers, - producer=self, # base options title=title, description=description, @@ -404,6 +407,8 @@ def publisher( # type: ignore[override] # JS stream: Union[str, JStream, None] = None, timeout: Optional[float] = None, + # specific + middlewares: Iterable[PublisherMiddleware] = (), # AsyncAPI information title: Optional[str] = None, description: Optional[str] = None, @@ -423,6 +428,8 @@ def publisher( # type: ignore[override] # JS timeout=timeout, stream=stream, + # Specific + middlewares=middlewares, # AsyncAPI title=title, _description=description, @@ -452,9 +459,7 @@ async def publish( # type: ignore[override] async with AsyncExitStack() as stack: for m in self.middlewares: - message = await stack.enter_async_context( - m(None).publish_scope(message) - ) + message = await stack.enter_async_context(m().publish_scope(message)) return await publisher.publish(message, *args, **kwargs) diff --git a/faststream/nats/handler.py b/faststream/nats/handler.py index 197689280d..4e087c2ff0 100644 --- a/faststream/nats/handler.py +++ b/faststream/nats/handler.py @@ -9,6 +9,7 @@ Callable, Dict, Iterable, + List, Optional, Sequence, Union, @@ -42,6 +43,7 @@ CustomDecoder, CustomParser, Filter, + PublisherProtocol, SubscriberMiddleware, ) from faststream.nats.message import NatsMessage @@ -58,9 +60,10 @@ class LogicNatsHandler(BaseHandler["Msg"]): "JetStreamContext.PullSubscription", ] task_group: Optional["TaskGroup"] - task: Optional["asyncio.Task[Any]"] + tasks: List["asyncio.Task[Any]"] send_stream: "MemoryObjectSendStream[Msg]" receive_stream: "MemoryObjectReceiveStream[Msg]" + producer: Optional["PublisherProtocol"] def __init__( self, @@ -76,7 +79,6 @@ def __init__( Callable[..., AsyncContextManager[None]], Doc("Watcher to ack message"), ], - producer, queue: Annotated[ str, Doc("NATS queue name"), @@ -130,9 +132,7 @@ def __init__( ) self.subject = path self.path_regex = reg - self.queue = queue - self.producer = producer self.stream = stream self.pull_sub = pull_sub @@ -150,12 +150,13 @@ def __init__( self.max_workers = max_workers self.subscription = None + self.producer = None self.send_stream, self.receive_stream = anyio.create_memory_object_stream( max_buffer_size=max_workers ) self.limiter = anyio.Semaphore(max_workers) - self.task = None + self.tasks = [] def add_call( self, @@ -180,6 +181,8 @@ def add_call( def make_response_publisher( self, message: "NatsMessage" ) -> Sequence[FakePublisher]: + assert self.producer, "You should setup producer first" + if message.reply_to: return ( FakePublisher( @@ -198,11 +201,17 @@ async def start( Union["Client", "JetStreamContext"], Doc("NATS client or JS Context object using to create subscription"), ], + producer: Annotated[ + Optional["PublisherProtocol"], + Doc("Publisher to response RPC"), + ], ) -> None: """Create NATS subscription and start consume task.""" + self.producer = producer + cb: Callable[["Msg"], Awaitable[SendableMessage]] if self.max_workers > 1: - self.task = asyncio.create_task(self._serve_consume_queue()) + self.tasks.append(asyncio.create_task(self._serve_consume_queue())) cb = self.__put_msg else: cb = self.consume @@ -217,7 +226,7 @@ async def start( subject=self.subject, **self.extra_options, ) - self.task = asyncio.create_task(self._consume_pull(cb)) + self.tasks.append(asyncio.create_task(self._consume_pull(cb))) else: self.subscription = await connection.subscribe( @@ -237,9 +246,9 @@ async def close(self) -> None: await self.subscription.unsubscribe() self.subscription = None - if self.task is not None: - self.task.cancel() - self.task = None + for t in self.tasks: + t.cancel() + self.tasks = [] async def _consume_pull( self, diff --git a/faststream/nats/publisher.py b/faststream/nats/publisher.py index 6c66a18fa2..4b4292ec34 100644 --- a/faststream/nats/publisher.py +++ b/faststream/nats/publisher.py @@ -26,7 +26,7 @@ class LogicPublisher(BasePublisher[Msg]): ) @override - async def publish( # type: ignore[override] + async def _publish( # type: ignore[override] self, message: SendableMessage = "", reply_to: str = "", From d072a00b6231185b5d56cbc3446d3b4c879b1545 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 14 Jan 2024 17:06:30 +0300 Subject: [PATCH 22/87] lint: publisher middlewares hint --- faststream/nats/broker/broker.pyi | 3 +++ faststream/nats/fastapi/fastapi.pyi | 3 +++ faststream/nats/router/router.pyi | 7 +++++++ 3 files changed, 13 insertions(+) diff --git a/faststream/nats/broker/broker.pyi b/faststream/nats/broker/broker.pyi index 5a4d80a740..a5cb7c3c88 100644 --- a/faststream/nats/broker/broker.pyi +++ b/faststream/nats/broker/broker.pyi @@ -38,6 +38,7 @@ from faststream.broker.types import ( CustomDecoder, CustomParser, Filter, + PublisherMiddleware, SubscriberMiddleware, ) from faststream.log import access_logger @@ -247,6 +248,8 @@ class NatsBroker( # JS stream: str | JStream | None = None, timeout: float | None = None, + # specific + middlewares: Iterable[PublisherMiddleware] = (), # AsyncAPI information title: str | None = None, description: str | None = None, diff --git a/faststream/nats/fastapi/fastapi.pyi b/faststream/nats/fastapi/fastapi.pyi index 27e3c0b2f5..662c5fe270 100644 --- a/faststream/nats/fastapi/fastapi.pyi +++ b/faststream/nats/fastapi/fastapi.pyi @@ -46,6 +46,7 @@ from faststream.broker.types import ( CustomParser, Filter, P_HandlerParams, + PublisherMiddleware, SubscriberMiddleware, T_HandlerReturn, ) @@ -217,6 +218,8 @@ class NatsRouter(StreamRouter[Msg]): # JS stream: str | JStream | None = None, timeout: float | None = None, + # specific + middlewares: Iterable[PublisherMiddleware] = (), # AsyncAPI information title: str | None = None, description: str | None = None, diff --git a/faststream/nats/router/router.pyi b/faststream/nats/router/router.pyi index 749dc03f08..54363df72b 100644 --- a/faststream/nats/router/router.pyi +++ b/faststream/nats/router/router.pyi @@ -14,6 +14,7 @@ from faststream.broker.types import ( CustomParser, Filter, P_HandlerParams, + PublisherMiddleware, SubscriberMiddleware, T_HandlerReturn, ) @@ -87,7 +88,13 @@ class NatsRouter(BrokerRouter[str, Msg]): self, subject: str, headers: dict[str, str] | None = None, + # Core reply_to: str = "", + # JS + stream: str | JStream | None = None, + timeout: float | None = None, + # specific + middlewares: Iterable[PublisherMiddleware] = (), # AsyncAPI information title: str | None = None, description: str | None = None, From 827c31649e6c3f7ac5e8876feebec4f3eb2fb174 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Sun, 14 Jan 2024 21:34:00 +0300 Subject: [PATCH 23/87] tests: router middlewares & depends tests --- tests/brokers/base/router.py | 74 +++++++++--------------------------- 1 file changed, 19 insertions(+), 55 deletions(-) diff --git a/tests/brokers/base/router.py b/tests/brokers/base/router.py index feeef4e7c6..d8e3b60550 100644 --- a/tests/brokers/base/router.py +++ b/tests/brokers/base/router.py @@ -4,7 +4,7 @@ import pytest -from faststream import BaseMiddleware, Depends +from faststream import Depends from faststream.broker.core.broker import BrokerUsecase from faststream.broker.router import BrokerRoute, BrokerRouter from faststream.types import AnyCallable @@ -13,7 +13,10 @@ @pytest.mark.asyncio() -class RouterTestcase(LocalMiddlewareTestcase, LocalCustomParserTestcase): # noqa: D101 +class RouterTestcase( # noqa: D101 + LocalMiddlewareTestcase, + LocalCustomParserTestcase, +): build_message: AnyCallable route_class: Type[BrokerRoute] @@ -289,36 +292,17 @@ async def test_router_dependencies( event: asyncio.Event, mock: Mock, ): - pub_broker._is_apply_types = True + router = type(router)(dependencies=(Depends(lambda: 1),)) + router2 = type(router)(dependencies=(Depends(lambda: 2),)) - async def dep1(s): - mock.dep1() - - async def dep2(s): - mock.dep2() - - router = type(router)(dependencies=(Depends(dep1),)) - - @router.subscriber(queue, dependencies=(Depends(dep2),)) - def subscriber(s): - event.set() + @router2.subscriber(queue, dependencies=(Depends(lambda: 3),)) + def subscriber(): + ... + router.include_router(router2) pub_broker.include_routers(router) - async with pub_broker: - await pub_broker.start() - - await asyncio.wait( - ( - asyncio.create_task(pub_broker.publish("hello", queue)), - asyncio.create_task(event.wait()), - ), - timeout=3, - ) - - assert event.is_set() - mock.dep1.assert_called_once() - mock.dep2.assert_called_once() + assert len(list(pub_broker.handlers.values())[0].calls[0].dependant.extra_dependencies) == 3 async def test_router_middlewares( self, @@ -328,37 +312,17 @@ async def test_router_middlewares( event: asyncio.Event, mock: Mock, ): - class mid1(BaseMiddleware): # noqa: N801 - async def on_receive(self) -> None: - mock.mid1() + router = type(router)(middlewares=(1,)) + router2 = type(router)(middlewares=(2,)) - class mid2(BaseMiddleware): # noqa: N801 - async def on_receive(self) -> None: - mock.mid1.assert_called_once() - mock.mid2() - - router = type(router)(middlewares=(mid1,)) - - @router.subscriber(queue, middlewares=(mid2,)) - def subscriber(s): - event.set() + @router2.subscriber(queue, middlewares=(3,)) + def subscriber(): + ... + router.include_router(router2) pub_broker.include_routers(router) - async with pub_broker: - await pub_broker.start() - - await asyncio.wait( - ( - asyncio.create_task(pub_broker.publish("hello", queue)), - asyncio.create_task(event.wait()), - ), - timeout=3, - ) - - assert event.is_set() - mock.mid1.assert_called_once() - mock.mid2.assert_called_once() + assert list(pub_broker.handlers.values())[0].calls[0].middlewares == (1, 2, 3) async def test_router_parser( self, From 2c9896b908b19e302e117f0b4effe54bfa4abe37 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Tue, 16 Jan 2024 19:03:27 +0300 Subject: [PATCH 24/87] refactor: new logging logic --- faststream/app.py | 4 +- faststream/asyncapi/__init__.py | 6 +- faststream/asyncapi/generate.py | 1 - faststream/asyncapi/site.py | 13 ++- faststream/broker/core/broker.py | 45 ++++---- faststream/broker/core/handler.py | 17 ++- faststream/broker/core/logging_mixin.py | 42 +++----- faststream/broker/utils.py | 22 ---- faststream/cli/utils/logs.py | 1 - faststream/kafka/broker.py | 17 +-- faststream/kafka/shared/logging.py | 43 +------- faststream/log/__init__.py | 7 +- faststream/log/formatter.py | 134 +++--------------------- faststream/log/logging.py | 114 ++++++++++---------- faststream/nats/broker/broker.py | 42 +++----- faststream/nats/broker/logging.py | 43 ++++---- faststream/nats/handler.py | 36 +++++-- faststream/rabbit/shared/logging.py | 3 +- faststream/redis/shared/logging.py | 3 +- faststream/utils/context/repository.py | 3 +- 20 files changed, 216 insertions(+), 380 deletions(-) diff --git a/faststream/app.py b/faststream/app.py index e778f43080..070e000076 100644 --- a/faststream/app.py +++ b/faststream/app.py @@ -1,4 +1,5 @@ import logging +import logging.config from typing import ( TYPE_CHECKING, Any, @@ -16,7 +17,7 @@ from faststream._compat import ExceptionGroup from faststream.cli.supervisors.utils import HANDLED_SIGNALS -from faststream.log import logger +from faststream.log.logging import logger from faststream.types import AnyDict, AsyncFunc, Lifespan, SettingField from faststream.utils import apply_types, context from faststream.utils.functions import drop_response_type, fake_context, to_async @@ -108,7 +109,6 @@ def __init__( self.broker = broker self.logger = logger self.context = context - context.set_global("app", self) self._on_startup_calling = [] self._after_startup_calling = [] diff --git a/faststream/asyncapi/__init__.py b/faststream/asyncapi/__init__.py index adee11dca5..07827b8101 100644 --- a/faststream/asyncapi/__init__.py +++ b/faststream/asyncapi/__init__.py @@ -1,5 +1,9 @@ """AsyncAPI related functions.""" +from faststream.asyncapi.generate import get_app_schema from faststream.asyncapi.site import get_asyncapi_html -__all__ = ("get_asyncapi_html",) +__all__ = ( + "get_asyncapi_html", + "get_app_schema", +) diff --git a/faststream/asyncapi/generate.py b/faststream/asyncapi/generate.py index dc4da6cc0a..1b827af953 100644 --- a/faststream/asyncapi/generate.py +++ b/faststream/asyncapi/generate.py @@ -159,7 +159,6 @@ def get_app_broker_channels( Raises: AssertionError: If the app does not have a broker. - """ channels = {} assert app.broker # nosec B101 diff --git a/faststream/asyncapi/site.py b/faststream/asyncapi/site.py index 2286d2c26b..803218529c 100644 --- a/faststream/asyncapi/site.py +++ b/faststream/asyncapi/site.py @@ -17,6 +17,8 @@ def get_asyncapi_html( errors: bool = True, expand_message_examples: bool = True, title: str = "FastStream", + asyncapi_js_url: str = "https://unpkg.com/@asyncapi/react-component@1.0.0-next.47/browser/standalone/index.js", + asyncapi_css_url: str = "https://unpkg.com/@asyncapi/react-component@1.0.0-next.46/styles/default.min.css", ) -> str: """Generate HTML for displaying an AsyncAPI document. @@ -77,7 +79,11 @@ def get_asyncapi_html( - + """ + f""" + + """ + """