diff --git a/faststream/__about__.py b/faststream/__about__.py
index efa54afcc2..7aaf590027 100644
--- a/faststream/__about__.py
+++ b/faststream/__about__.py
@@ -1,6 +1,6 @@
 """Simple and fast framework to create message brokers based microservices."""
 
-__version__ = "0.5.5"
+__version__ = "0.5.6"
 
 SERVICE_NAME = f"faststream-{__version__}"
 
diff --git a/faststream/broker/core/abc.py b/faststream/broker/core/abc.py
index 1a49e26843..e89f2d5144 100644
--- a/faststream/broker/core/abc.py
+++ b/faststream/broker/core/abc.py
@@ -46,6 +46,19 @@ def __init__(
         self._parser = parser
         self._decoder = decoder
 
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
+        """Append BrokerMiddleware to the end of the middlewares list.
+
+        The current middleware will be used as the innermost of the already existing ones.
+        """
+        self._middlewares = (*self._middlewares, middleware)
+
+        for sub in self._subscribers.values():
+            sub.add_middleware(middleware)
+
+        for pub in self._publishers.values():
+            pub.add_middleware(middleware)
+
     @abstractmethod
     def subscriber(
         self,
diff --git a/faststream/broker/publisher/proto.py b/faststream/broker/publisher/proto.py
index 2233739252..4a916ed6a0 100644
--- a/faststream/broker/publisher/proto.py
+++ b/faststream/broker/publisher/proto.py
@@ -56,6 +56,10 @@ class PublisherProto(
     _middlewares: Iterable["PublisherMiddleware"]
     _producer: Optional["ProducerProto"]
 
+    @abstractmethod
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
+        ...
+
     @staticmethod
     @abstractmethod
     def create() -> "PublisherProto[MsgType]":
diff --git a/faststream/broker/publisher/usecase.py b/faststream/broker/publisher/usecase.py
index 23e8c5586e..46bb96ef2a 100644
--- a/faststream/broker/publisher/usecase.py
+++ b/faststream/broker/publisher/usecase.py
@@ -19,7 +19,12 @@
 from faststream.asyncapi.message import get_response_schema
 from faststream.asyncapi.utils import to_camelcase
 from faststream.broker.publisher.proto import PublisherProto
-from faststream.broker.types import MsgType, P_HandlerParams, T_HandlerReturn
+from faststream.broker.types import (
+    BrokerMiddleware,
+    MsgType,
+    P_HandlerParams,
+    T_HandlerReturn,
+)
 from faststream.broker.wrapper.call import HandlerCallWrapper
 
 if TYPE_CHECKING:
@@ -87,6 +92,9 @@ def __init__(
         self.include_in_schema = include_in_schema
         self.schema_ = schema_
 
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
+        self._broker_middlewares = (*self._broker_middlewares, middleware)
+
     @override
     def setup(  # type: ignore[override]
         self,
diff --git a/faststream/broker/subscriber/proto.py b/faststream/broker/subscriber/proto.py
index 534c795b95..bb0ae5fb2d 100644
--- a/faststream/broker/subscriber/proto.py
+++ b/faststream/broker/subscriber/proto.py
@@ -35,6 +35,10 @@ class SubscriberProto(
     _broker_middlewares: Iterable["BrokerMiddleware[MsgType]"]
     _producer: Optional["ProducerProto"]
 
+    @abstractmethod
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
+        ...
+
     @staticmethod
     @abstractmethod
     def create() -> "SubscriberProto[MsgType]":
diff --git a/faststream/broker/subscriber/usecase.py b/faststream/broker/subscriber/usecase.py
index 82e6ebce8c..5d0dd886dd 100644
--- a/faststream/broker/subscriber/usecase.py
+++ b/faststream/broker/subscriber/usecase.py
@@ -131,6 +131,9 @@ def __init__(
         self.description_ = description_
         self.include_in_schema = include_in_schema
 
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
+        self._broker_middlewares = (*self._broker_middlewares, middleware)
+
     @override
     def setup(  # type: ignore[override]
         self,
diff --git a/faststream/nats/parser.py b/faststream/nats/parser.py
index d843f13f99..db4caba54b 100644
--- a/faststream/nats/parser.py
+++ b/faststream/nats/parser.py
@@ -104,13 +104,17 @@ async def parse_batch(
     ) -> "StreamMessage[List[Msg]]":
         if first_msg := next(iter(message), None):
             path = self.get_path(first_msg.subject)
+            headers = first_msg.headers
+
         else:
             path = None
+            headers = None
 
         return NatsBatchMessage(
             raw_message=message,
             body=[m.data for m in message],
             path=path or {},
+            headers=headers or {},
         )
 
     async def decode_batch(
diff --git a/faststream/rabbit/__init__.py b/faststream/rabbit/__init__.py
index 11ca1a9373..7c05cb70c8 100644
--- a/faststream/rabbit/__init__.py
+++ b/faststream/rabbit/__init__.py
@@ -21,5 +21,6 @@
     "ReplyConfig",
     "RabbitExchange",
     "RabbitQueue",
+    # Annotations
     "RabbitMessage",
 )
diff --git a/faststream/rabbit/annotations.py b/faststream/rabbit/annotations.py
index bfb78c6af9..f32654d2cc 100644
--- a/faststream/rabbit/annotations.py
+++ b/faststream/rabbit/annotations.py
@@ -1,3 +1,4 @@
+from aio_pika import RobustChannel, RobustConnection
 from typing_extensions import Annotated
 
 from faststream.annotations import ContextRepo, Logger, NoCast
@@ -13,8 +14,20 @@
     "RabbitMessage",
     "RabbitBroker",
     "RabbitProducer",
+    "Channel",
+    "Connection",
 )
 
 RabbitMessage = Annotated[RM, Context("message")]
 RabbitBroker = Annotated[RB, Context("broker")]
 RabbitProducer = Annotated[AioPikaFastProducer, Context("broker._producer")]
+
+Channel = Annotated[RobustChannel, Context("broker._channel")]
+Connection = Annotated[RobustConnection, Context("broker._connection")]
+
+# NOTE: transaction is not for public usage yet
+# async def _get_transaction(connection: Connection) -> RabbitTransaction:
+#     async with connection.channel(publisher_confirms=False) as channel:
+#         yield channel.transaction()
+
+# Transaction = Annotated[RabbitTransaction, Depends(_get_transaction)]
diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py
index fd4ca30d84..30a6cce711 100644
--- a/faststream/rabbit/broker/broker.py
+++ b/faststream/rabbit/broker/broker.py
@@ -100,6 +100,26 @@ def __init__(
             "TimeoutType",
             Doc("Connection establishement timeout."),
         ] = None,
+        # channel args
+        channel_number: Annotated[
+            Optional[int],
+            Doc("Specify the channel number explicitly."),
+        ] = None,
+        publisher_confirms: Annotated[
+            bool,
+            Doc(
+                "If `True`, the `publish` method will "
+                "return `bool` type after the publish is complete. "
+                "Otherwise it returns `None`."
+            ),
+        ] = True,
+        on_return_raises: Annotated[
+            bool,
+            Doc(
+                "Raise an :class:`aio_pika.exceptions.DeliveryError` "
+                "when a mandatory message is returned."
+            ),
+        ] = False,
         # broker args
         max_consumers: Annotated[
             Optional[int],
@@ -220,6 +240,10 @@ def __init__(
             url=str(amqp_url),
             ssl_context=security_args.get("ssl_context"),
             timeout=timeout,
+            # channel args
+            channel_number=channel_number,
+            publisher_confirms=publisher_confirms,
+            on_return_raises=on_return_raises,
             # Basic args
             graceful_timeout=graceful_timeout,
             dependencies=dependencies,
@@ -303,6 +327,26 @@ async def connect(  # type: ignore[override]
             "TimeoutType",
             Doc("Connection establishement timeout."),
         ] = None,
+        # channel args
+        channel_number: Annotated[
+            Union[int, None, object],
+            Doc("Specify the channel number explicitly."),
+        ] = Parameter.empty,
+        publisher_confirms: Annotated[
+            Union[bool, object],
+            Doc(
+                "If `True`, the `publish` method will "
+                "return `bool` type after the publish is complete. "
+                "Otherwise it returns `None`."
+            ),
+        ] = Parameter.empty,
+        on_return_raises: Annotated[
+            Union[bool, object],
+            Doc(
+                "Raise an :class:`aio_pika.exceptions.DeliveryError` "
+                "when a mandatory message is returned."
+            ),
+        ] = Parameter.empty,
     ) -> "RobustConnection":
         """Connect broker object to RabbitMQ.
 
@@ -310,6 +354,15 @@
         """
         kwargs: AnyDict = {}
 
+        if channel_number is not Parameter.empty:
+            kwargs["channel_number"] = channel_number
+
+        if publisher_confirms is not Parameter.empty:
+            kwargs["publisher_confirms"] = publisher_confirms
+
+        if on_return_raises is not Parameter.empty:
+            kwargs["on_return_raises"] = on_return_raises
+
         if timeout:
             kwargs["timeout"] = timeout
 
@@ -346,6 +399,9 @@ async def _connect(  # type: ignore[override]
         *,
         timeout: "TimeoutType",
         ssl_context: Optional["SSLContext"],
+        channel_number: Optional[int],
+        publisher_confirms: bool,
+        on_return_raises: bool,
     ) -> "RobustConnection":
         connection = cast(
             "RobustConnection",
@@ -360,7 +416,11 @@ async def _connect(  # type: ignore[override]
             max_consumers = self._max_consumers
             channel = self._channel = cast(
                 "RobustChannel",
-                await connection.channel(),
+                await connection.channel(
+                    channel_number=channel_number,
+                    publisher_confirms=publisher_confirms,
+                    on_return_raises=on_return_raises,
+                ),
             )
 
             declarer = self.declarer = RabbitDeclarer(channel)
diff --git a/faststream/rabbit/fastapi/router.py b/faststream/rabbit/fastapi/router.py
index 4cc90b25d9..fcbc27320d 100644
--- a/faststream/rabbit/fastapi/router.py
+++ b/faststream/rabbit/fastapi/router.py
@@ -96,6 +96,26 @@ def __init__(
             "TimeoutType",
             Doc("Connection establishement timeout."),
         ] = None,
+        # channel args
+        channel_number: Annotated[
+            Optional[int],
+            Doc("Specify the channel number explicitly."),
+        ] = None,
+        publisher_confirms: Annotated[
+            bool,
+            Doc(
+                "If `True`, the `publish` method will "
+                "return `bool` type after the publish is complete. "
+                "Otherwise it returns `None`."
+            ),
+        ] = True,
+        on_return_raises: Annotated[
+            bool,
+            Doc(
+                "Raise an :class:`aio_pika.exceptions.DeliveryError` "
+                "when a mandatory message is returned."
+            ),
+        ] = False,
         # broker args
         max_consumers: Annotated[
             Optional[int],
@@ -408,6 +428,9 @@ def __init__(
             graceful_timeout=graceful_timeout,
             decoder=decoder,
             parser=parser,
+            channel_number=channel_number,
+            publisher_confirms=publisher_confirms,
+            on_return_raises=on_return_raises,
             middlewares=middlewares,
             security=security,
             asyncapi_url=asyncapi_url,
diff --git a/faststream/rabbit/schemas/queue.py b/faststream/rabbit/schemas/queue.py
index b63685d1a5..a9bccf013d 100644
--- a/faststream/rabbit/schemas/queue.py
+++ b/faststream/rabbit/schemas/queue.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from typing import TYPE_CHECKING, Optional
 
 from typing_extensions import Annotated, Doc
@@ -115,3 +116,13 @@ def __init__(
         self.auto_delete = auto_delete
         self.arguments = arguments
         self.timeout = timeout
+
+    def add_prefix(self, prefix: str) -> "RabbitQueue":
+        new_q: RabbitQueue = deepcopy(self)
+
+        new_q.name = "".join((prefix, new_q.name))
+
+        if new_q.routing_key:
+            new_q.routing_key = "".join((prefix, new_q.routing_key))
+
+        return new_q
diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py
index aecac22384..d2ca4480a2 100644
--- a/faststream/rabbit/subscriber/usecase.py
+++ b/faststream/rabbit/subscriber/usecase.py
@@ -1,4 +1,3 @@
-from copy import deepcopy
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -223,6 +222,4 @@ def get_log_context(
 
     def add_prefix(self, prefix: str) -> None:
         """Include Subscriber in router."""
-        new_q = deepcopy(self.queue)
-        new_q.name = "".join((prefix, new_q.name))
-        self.queue = new_q
+        self.queue = self.queue.add_prefix(prefix)
diff --git a/faststream/redis/parser.py b/faststream/redis/parser.py
index d47dae603d..8daa0f92cf 100644
--- a/faststream/redis/parser.py
+++ b/faststream/redis/parser.py
@@ -1,6 +1,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Dict,
     Mapping,
     Optional,
     Sequence,
@@ -183,9 +184,13 @@ class RedisBatchListParser(SimpleParser):
 
     @staticmethod
     def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
+        data = [_decode_batch_body_item(x) for x in message["data"]]
         return (
-            dump_json(_decode_batch_body_item(x) for x in message["data"]),
-            {"content-type": ContentTypes.json},
+            dump_json(i[0] for i in data),
+            {
+                **data[0][1],
+                "content-type": ContentTypes.json,
+            },
         )
 
 
@@ -203,17 +208,19 @@ class RedisBatchStreamParser(SimpleParser):
 
     @staticmethod
     def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
+        data = [_decode_batch_body_item(x.get(bDATA_KEY, x)) for x in message["data"]]
         return (
-            dump_json(
-                _decode_batch_body_item(x.get(bDATA_KEY, x)) for x in message["data"]
-            ),
-            {"content-type": ContentTypes.json},
+            dump_json(i[0] for i in data),
+            {
+                **data[0][1],
+                "content-type": ContentTypes.json,
+            },
         )
 
 
-def _decode_batch_body_item(msg_content: bytes) -> Any:
-    msg_body, _ = RawMessage.parse(msg_content)
+def _decode_batch_body_item(msg_content: bytes) -> Tuple[Any, Dict[str, str]]:
+    msg_body, headers = RawMessage.parse(msg_content)
     try:
-        return json_loads(msg_body)
+        return json_loads(msg_body), headers
     except Exception:
-        return msg_body
+        return msg_body, headers
diff --git a/tests/brokers/base/middlewares.py b/tests/brokers/base/middlewares.py
index 4f89f08411..21b4931af5 100644
--- a/tests/brokers/base/middlewares.py
+++ b/tests/brokers/base/middlewares.py
@@ -270,6 +270,55 @@ async def handler(m):
         mock.start.assert_called_once()
         mock.end.assert_called_once()
 
+    async def test_add_global_middleware(
+        self, event: asyncio.Event, queue: str, mock: Mock, raw_broker,
+    ):
+        class mid(BaseMiddleware):  # noqa: N801
+            async def on_receive(self):
+                mock.start(self.msg)
+                return await super().on_receive()
+
+            async def after_processed(self, exc_type, exc_val, exc_tb):
+                mock.end()
+                return await super().after_processed(exc_type, exc_val, exc_tb)
+
+        broker = self.broker_class()
+
+        # already registered subscriber
+        @broker.subscriber(queue, **self.subscriber_kwargs)
+        async def handler(m):
+            event.set()
+            return ""
+
+        # should affect both the already registered and the new subscriber
+        broker.add_middleware(mid)
+
+        event2 = asyncio.Event()
+
+        # new subscriber
+        @broker.subscriber(f"{queue}1", **self.subscriber_kwargs)
+        async def handler2(m):
+            event2.set()
+            return ""
+
+        broker = self.patch_broker(raw_broker, broker)
+
+        async with broker:
+            await broker.start()
+            await asyncio.wait(
+                (
+                    asyncio.create_task(broker.publish("", queue)),
+                    asyncio.create_task(broker.publish("", f"{queue}1")),
+                    asyncio.create_task(event.wait()),
+                    asyncio.create_task(event2.wait()),
+                ),
+                timeout=self.timeout,
+            )
+
+        assert event.is_set()
+        assert mock.start.call_count == 2
+        assert mock.end.call_count == 2
+
     async def test_patch_publish(self, queue: str, mock: Mock, event, raw_broker):
         class Mid(BaseMiddleware):
             async def on_publish(self, msg: str, *args, **kwargs) -> str:
diff --git a/tests/brokers/rabbit/test_router.py b/tests/brokers/rabbit/test_router.py
index 50f81a636d..ac14d3372d 100644
--- a/tests/brokers/rabbit/test_router.py
+++ b/tests/brokers/rabbit/test_router.py
@@ -139,6 +139,39 @@ def subscriber(m):
 
         assert event.is_set()
 
+    async def test_queue_obj_with_routing_key(
+        self,
+        router: RabbitRouter,
+        broker: RabbitBroker,
+        queue: str,
+        event: asyncio.Event,
+    ):
+        router.prefix = "test/"
+
+        r_queue = RabbitQueue("useless", routing_key=f"{queue}1")
+        exchange = RabbitExchange(f"{queue}exch")
+
+        @router.subscriber(r_queue, exchange=exchange)
+        def subscriber(m):
+            event.set()
+
+        broker.include_router(router)
+
+        async with broker:
+            await broker.start()
+
+            await asyncio.wait(
+                (
+                    asyncio.create_task(
+                        broker.publish("hello", f"test/{queue}1", exchange=exchange)
+                    ),
+                    asyncio.create_task(event.wait()),
+                ),
+                timeout=3,
+            )
+
+        assert event.is_set()
+
     async def test_delayed_handlers_with_queue(
         self,
         event: asyncio.Event,
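
Usage sketch (not part of the patch): the two headline changes above are the new `add_middleware()` broker method and the channel arguments now accepted by `RabbitBroker`. The `LoggingMiddleware` class and the concrete argument values below are illustrative assumptions; `BaseMiddleware`, `RabbitBroker`, `add_middleware`, `channel_number`, `publisher_confirms`, and `on_return_raises` are the names introduced or used by the diff.

from faststream import BaseMiddleware
from faststream.rabbit import RabbitBroker


class LoggingMiddleware(BaseMiddleware):
    """Illustrative middleware: log every consumed message."""

    async def on_receive(self):
        print("received:", self.msg)
        return await super().on_receive()


# The new channel args are passed through to aio_pika's connection.channel().
broker = RabbitBroker(
    "amqp://guest:guest@localhost:5672/",
    channel_number=1,
    publisher_confirms=True,
    on_return_raises=False,
)

# Appended as the innermost middleware; per the change above it applies to
# already registered subscribers/publishers as well as ones added later.
broker.add_middleware(LoggingMiddleware)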
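
A second small sketch, this time of the new `RabbitQueue.add_prefix()` helper that the router-prefix handling and `test_queue_obj_with_routing_key` rely on. The queue name and routing key are made up; the behaviour shown (prefixing both fields on a deep copy) is what the diff implements.

from faststream.rabbit import RabbitQueue

queue = RabbitQueue("orders", routing_key="orders.created")
prefixed = queue.add_prefix("test/")

assert prefixed.name == "test/orders"
assert prefixed.routing_key == "test/orders.created"
assert queue.name == "orders"  # the original object is left untouched (deepcopy)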