diff --git a/faststream/__about__.py b/faststream/__about__.py
index efa54afcc2..7aaf590027 100644
--- a/faststream/__about__.py
+++ b/faststream/__about__.py
@@ -1,6 +1,6 @@
 """Simple and fast framework to create message brokers based microservices."""
 
-__version__ = "0.5.5"
+__version__ = "0.5.6"
 
 SERVICE_NAME = f"faststream-{__version__}"
diff --git a/faststream/broker/core/abc.py b/faststream/broker/core/abc.py
index 1a49e26843..e89f2d5144 100644
--- a/faststream/broker/core/abc.py
+++ b/faststream/broker/core/abc.py
@@ -46,6 +46,19 @@ def __init__(
         self._parser = parser
         self._decoder = decoder
 
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
+        """Append BrokerMiddleware to the end of the middlewares list.
+
+        The appended middleware becomes the innermost of the already registered ones.
+        """
+        self._middlewares = (*self._middlewares, middleware)
+
+        for sub in self._subscribers.values():
+            sub.add_middleware(middleware)
+
+        for pub in self._publishers.values():
+            pub.add_middleware(middleware)
+
     @abstractmethod
     def subscriber(
         self,
diff --git a/faststream/broker/message.py b/faststream/broker/message.py
index a692f12d4f..beec9fe555 100644
--- a/faststream/broker/message.py
+++ b/faststream/broker/message.py
@@ -6,6 +6,7 @@
     TYPE_CHECKING,
     Any,
     Generic,
+    List,
     Optional,
     Sequence,
     Tuple,
@@ -38,6 +39,7 @@ class StreamMessage(Generic[MsgType]):
 
     body: Union[bytes, Any]
     headers: "AnyDict" = field(default_factory=dict)
+    batch_headers: List["AnyDict"] = field(default_factory=list)
     path: "AnyDict" = field(default_factory=dict)
 
     content_type: Optional[str] = None
diff --git a/faststream/broker/publisher/proto.py b/faststream/broker/publisher/proto.py
index 2233739252..747b29b048 100644
--- a/faststream/broker/publisher/proto.py
+++ b/faststream/broker/publisher/proto.py
@@ -56,6 +56,9 @@ class PublisherProto(
     _middlewares: Iterable["PublisherMiddleware"]
     _producer: Optional["ProducerProto"]
 
+    @abstractmethod
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: ...
+
     @staticmethod
     @abstractmethod
     def create() -> "PublisherProto[MsgType]":
diff --git a/faststream/broker/publisher/usecase.py b/faststream/broker/publisher/usecase.py
index 23e8c5586e..46bb96ef2a 100644
--- a/faststream/broker/publisher/usecase.py
+++ b/faststream/broker/publisher/usecase.py
@@ -19,7 +19,12 @@
 from faststream.asyncapi.message import get_response_schema
 from faststream.asyncapi.utils import to_camelcase
 from faststream.broker.publisher.proto import PublisherProto
-from faststream.broker.types import MsgType, P_HandlerParams, T_HandlerReturn
+from faststream.broker.types import (
+    BrokerMiddleware,
+    MsgType,
+    P_HandlerParams,
+    T_HandlerReturn,
+)
 from faststream.broker.wrapper.call import HandlerCallWrapper
 
 if TYPE_CHECKING:
@@ -87,6 +92,9 @@ def __init__(
         self.include_in_schema = include_in_schema
         self.schema_ = schema_
 
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
+        self._broker_middlewares = (*self._broker_middlewares, middleware)
+
     @override
     def setup(  # type: ignore[override]
         self,
diff --git a/faststream/broker/subscriber/proto.py b/faststream/broker/subscriber/proto.py
index 534c795b95..fa19428fde 100644
--- a/faststream/broker/subscriber/proto.py
+++ b/faststream/broker/subscriber/proto.py
@@ -35,6 +35,9 @@ class SubscriberProto(
     _broker_middlewares: Iterable["BrokerMiddleware[MsgType]"]
     _producer: Optional["ProducerProto"]
 
+    @abstractmethod
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None: ...
+
     @staticmethod
     @abstractmethod
     def create() -> "SubscriberProto[MsgType]":
diff --git a/faststream/broker/subscriber/usecase.py b/faststream/broker/subscriber/usecase.py
index 82e6ebce8c..5d0dd886dd 100644
--- a/faststream/broker/subscriber/usecase.py
+++ b/faststream/broker/subscriber/usecase.py
@@ -131,6 +131,9 @@ def __init__(
         self.description_ = description_
         self.include_in_schema = include_in_schema
 
+    def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
+        self._broker_middlewares = (*self._broker_middlewares, middleware)
+
     @override
     def setup(  # type: ignore[override]
         self,
diff --git a/faststream/cli/docs/app.py b/faststream/cli/docs/app.py
index 450abdb061..c8066a8b9e 100644
--- a/faststream/cli/docs/app.py
+++ b/faststream/cli/docs/app.py
@@ -45,8 +45,7 @@ def serve(
         ),
     ),
     is_factory: bool = typer.Option(
-        False,
-        "--factory", help="Treat APP as an application factory"
+        False, "--factory", help="Treat APP as an application factory"
     ),
 ) -> None:
     """Serve project AsyncAPI schema."""
@@ -110,7 +109,8 @@ def gen(
     ),
     is_factory: bool = typer.Option(
         False,
-        "--factory", help="Treat APP as an application factory"
+        "--factory",
+        help="Treat APP as an application factory",
     ),
 ) -> None:
     """Generate project AsyncAPI schema."""
diff --git a/faststream/cli/main.py b/faststream/cli/main.py
index 3ed7afa3e6..bbbe99aa33 100644
--- a/faststream/cli/main.py
+++ b/faststream/cli/main.py
@@ -211,7 +211,9 @@ def publish(
     rpc: bool = typer.Option(False, help="Enable RPC mode and system output"),
     is_factory: bool = typer.Option(
         False,
-        "--factory", help="Treat APP as an application factory"
+        "--factory",
+        is_flag=True,
+        help="Treat APP as an application factory",
     ),
 ) -> None:
     """Publish a message using the specified broker in a FastStream application.
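Usage sketch for the new `Broker.add_middleware` API above — a minimal example, assuming a Kafka broker (the broker class, topic name, and middleware are illustrative, not part of this diff):

```python
from faststream import BaseMiddleware
from faststream.kafka import KafkaBroker

broker = KafkaBroker()


@broker.subscriber("in-topic")  # registered before the middleware
async def handler(msg: str) -> None:
    ...


class LogMiddleware(BaseMiddleware):
    async def on_receive(self) -> None:
        print(self.msg)  # raw incoming message
        return await super().on_receive()


# applies to `handler` and to any subscriber/publisher created later;
# it is appended last, so it acts as the innermost middleware
broker.add_middleware(LogMiddleware)
```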
diff --git a/faststream/confluent/parser.py b/faststream/confluent/parser.py
index a4858247ac..8541ceb4f0 100644
--- a/faststream/confluent/parser.py
+++ b/faststream/confluent/parser.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
 
 from faststream.broker.message import decode_message, gen_cor_id
 from faststream.confluent.message import FAKE_CONSUMER, KafkaMessage
@@ -20,18 +20,14 @@ async def parse_message(
         message: "Message",
     ) -> "StreamMessage[Message]":
         """Parses a Kafka message."""
-        headers = {}
-        if message.headers() is not None:
-            for i, j in message.headers():  # type: ignore[union-attr]
-                if isinstance(j, str):
-                    headers[i] = j
-                else:
-                    headers[i] = j.decode()
+        headers = _parse_msg_headers(message.headers())
+
         body = message.value()
         offset = message.offset()
         _, timestamp = message.timestamp()
 
         handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_")
+
         return KafkaMessage(
             body=body,
             headers=headers,
@@ -49,28 +45,29 @@ async def parse_message_batch(
         message: Tuple["Message", ...],
     ) -> "StreamMessage[Tuple[Message, ...]]":
         """Parses a batch of messages from a Kafka consumer."""
+        body: List[Any] = []
+        batch_headers: List[Dict[str, str]] = []
+
         first = message[0]
         last = message[-1]
 
-        headers = {}
-        if first.headers() is not None:
-            for i, j in first.headers():  # type: ignore[union-attr]
-                if isinstance(j, str):
-                    headers[i] = j
-                else:
-                    headers[i] = j.decode()
-        body = [m.value() for m in message]
-        first_offset = first.offset()
-        last_offset = last.offset()
+        for m in message:
+            body.append(m.value())
+            batch_headers.append(_parse_msg_headers(m.headers()))
+
+        headers = next(iter(batch_headers), {})
+
         _, first_timestamp = first.timestamp()
 
         handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_")
+
         return KafkaMessage(
             body=body,
             headers=headers,
+            batch_headers=batch_headers,
             reply_to=headers.get("reply_to", ""),
             content_type=headers.get("content-type"),
-            message_id=f"{first_offset}-{last_offset}-{first_timestamp}",
+            message_id=f"{first.offset()}-{last.offset()}-{first_timestamp}",
             correlation_id=headers.get("correlation_id", gen_cor_id()),
             raw_message=message,
             consumer=getattr(handler, "consumer", None) or FAKE_CONSUMER,
@@ -91,3 +88,9 @@ async def decode_message_batch(
     ) -> "DecodedMessage":
         """Decode a batch of messages."""
         return [decode_message(await cls.parse_message(m)) for m in msg.raw_message]
+
+
+def _parse_msg_headers(
+    headers: Optional[Sequence[Tuple[str, Union[bytes, str]]]],
+) -> Dict[str, str]:
+    return {i: j if isinstance(j, str) else j.decode() for i, j in headers or ()}
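The `batch_headers` field added above surfaces per-message headers to batch subscribers, which the batch parsers populate. A minimal sketch of handler-side usage (topic name is illustrative):

```python
from typing import List

from faststream.kafka import KafkaBroker
from faststream.kafka.annotations import KafkaMessage

broker = KafkaBroker()


@broker.subscriber("in-topic", batch=True)
async def handler(body: List[str], msg: KafkaMessage) -> None:
    # msg.headers keeps the first message's headers (backward compatible),
    # while msg.batch_headers holds one dict per message in the batch
    for headers in msg.batch_headers:
        print(headers.get("correlation_id"))
```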
diff --git a/faststream/kafka/parser.py b/faststream/kafka/parser.py
index c99bc31c33..8487eb3d0b 100644
--- a/faststream/kafka/parser.py
+++ b/faststream/kafka/parser.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from faststream.broker.message import decode_message, gen_cor_id
 from faststream.kafka.message import FAKE_CONSUMER, KafkaMessage
@@ -39,13 +39,24 @@ async def parse_message_batch(
         message: Tuple["ConsumerRecord", ...],
     ) -> "StreamMessage[Tuple[ConsumerRecord, ...]]":
         """Parses a batch of messages from a Kafka consumer."""
+        body: List[Any] = []
+        batch_headers: List[Dict[str, str]] = []
+
         first = message[0]
         last = message[-1]
-        headers = {i: j.decode() for i, j in first.headers}
+
+        for m in message:
+            body.append(m.value)
+            batch_headers.append({i: j.decode() for i, j in m.headers})
+
+        headers = next(iter(batch_headers), {})
+
         handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_")
+
         return KafkaMessage(
-            body=[m.value for m in message],
+            body=body,
             headers=headers,
+            batch_headers=batch_headers,
             reply_to=headers.get("reply_to", ""),
             content_type=headers.get("content-type"),
             message_id=f"{first.offset}-{last.offset}-{first.timestamp}",
diff --git a/faststream/nats/broker/broker.py b/faststream/nats/broker/broker.py
index 2c6265ac1c..2ccbe47bad 100644
--- a/faststream/nats/broker/broker.py
+++ b/faststream/nats/broker/broker.py
@@ -623,12 +623,12 @@ async def start(self) -> None:
                 )
 
             except BadRequestError as e:
-                old_config = (await self.stream.stream_info(stream.name)).config
-
                 if (
                     e.description
                     == "stream name already in use with a different configuration"
                 ):
+                    old_config = (await self.stream.stream_info(stream.name)).config
+
                     self._log(str(e), logging.WARNING, log_context)
                     await self.stream.update_stream(
                         config=stream.config,
diff --git a/faststream/nats/parser.py b/faststream/nats/parser.py
index d843f13f99..940ae70426 100644
--- a/faststream/nats/parser.py
+++ b/faststream/nats/parser.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 
 from faststream.broker.message import StreamMessage, decode_message, gen_cor_id
 from faststream.nats.message import NatsBatchMessage, NatsMessage
@@ -102,15 +102,27 @@ async def parse_batch(
         self,
         message: List["Msg"],
     ) -> "StreamMessage[List[Msg]]":
-        if first_msg := next(iter(message), None):
-            path = self.get_path(first_msg.subject)
+        body: List[bytes] = []
+        batch_headers: List[Dict[str, str]] = []
+
+        if message:
+            path = self.get_path(message[0].subject)
+
+            for m in message:
+                batch_headers.append(m.headers or {})
+                body.append(m.data)
+
         else:
             path = None
 
+        headers = next(iter(batch_headers), {})
+
         return NatsBatchMessage(
             raw_message=message,
-            body=[m.data for m in message],
+            body=body,
             path=path or {},
+            headers=headers,
+            batch_headers=batch_headers,
         )
 
     async def decode_batch(
diff --git a/faststream/rabbit/__init__.py b/faststream/rabbit/__init__.py
index 11ca1a9373..7c05cb70c8 100644
--- a/faststream/rabbit/__init__.py
+++ b/faststream/rabbit/__init__.py
@@ -21,5 +21,6 @@
     "ReplyConfig",
     "RabbitExchange",
     "RabbitQueue",
+    # Annotations
     "RabbitMessage",
 )
diff --git a/faststream/rabbit/annotations.py b/faststream/rabbit/annotations.py
index bfb78c6af9..f32654d2cc 100644
--- a/faststream/rabbit/annotations.py
+++ b/faststream/rabbit/annotations.py
@@ -1,3 +1,4 @@
+from aio_pika import RobustChannel, RobustConnection
 from typing_extensions import Annotated
 
 from faststream.annotations import ContextRepo, Logger, NoCast
@@ -13,8 +14,20 @@
     "RabbitMessage",
     "RabbitBroker",
     "RabbitProducer",
+    "Channel",
+    "Connection",
 )
 
 RabbitMessage = Annotated[RM, Context("message")]
 RabbitBroker = Annotated[RB, Context("broker")]
 RabbitProducer = Annotated[AioPikaFastProducer, Context("broker._producer")]
+
+Channel = Annotated[RobustChannel, Context("broker._channel")]
+Connection = Annotated[RobustConnection, Context("broker._connection")]
+
+# NOTE: transaction is not for the public usage yet
+# async def _get_transaction(connection: Connection) -> RabbitTransaction:
+#     async with connection.channel(publisher_confirms=False) as channel:
+#         yield channel.transaction()
+
+# Transaction = Annotated[RabbitTransaction, Depends(_get_transaction)]
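The new `Channel` and `Connection` annotations resolve the broker's underlying `aio_pika` objects from the context. A minimal handler sketch (queue name and the side work are illustrative):

```python
from faststream.rabbit import RabbitBroker
from faststream.rabbit.annotations import Channel, Connection

broker = RabbitBroker()


@broker.subscriber("in-queue")
async def handler(
    msg: str,
    channel: Channel,        # the broker's RobustChannel
    connection: Connection,  # the broker's RobustConnection
) -> None:
    # e.g. open an extra channel for side work
    async with connection.channel() as extra_channel:
        print(channel.number, extra_channel.number)
```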
diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py
index fd4ca30d84..b0bc98c42f 100644
--- a/faststream/rabbit/broker/broker.py
+++ b/faststream/rabbit/broker/broker.py
@@ -100,6 +100,26 @@ def __init__(
             "TimeoutType",
             Doc("Connection establishement timeout."),
         ] = None,
+        # channel args
+        channel_number: Annotated[
+            Optional[int],
+            Doc("Specify the channel number explicitly."),
+        ] = None,
+        publisher_confirms: Annotated[
+            bool,
+            Doc(
+                "If `True`, the `publish` method will "
+                "return `bool` type after the publish is complete. "
+                "Otherwise it will return `None`."
+            ),
+        ] = True,
+        on_return_raises: Annotated[
+            bool,
+            Doc(
+                "Raise an :class:`aio_pika.exceptions.DeliveryError` "
+                "when a mandatory message is returned."
+            ),
+        ] = False,
         # broker args
         max_consumers: Annotated[
             Optional[int],
@@ -220,6 +240,10 @@ def __init__(
             url=str(amqp_url),
             ssl_context=security_args.get("ssl_context"),
             timeout=timeout,
+            # channel args
+            channel_number=channel_number,
+            publisher_confirms=publisher_confirms,
+            on_return_raises=on_return_raises,
             # Basic args
             graceful_timeout=graceful_timeout,
             dependencies=dependencies,
@@ -303,6 +327,26 @@ async def connect(  # type: ignore[override]
             "TimeoutType",
             Doc("Connection establishement timeout."),
         ] = None,
+        # channel args
+        channel_number: Annotated[
+            Union[int, None, object],
+            Doc("Specify the channel number explicitly."),
+        ] = Parameter.empty,
+        publisher_confirms: Annotated[
+            Union[bool, object],
+            Doc(
+                "If `True`, the `publish` method will "
+                "return `bool` type after the publish is complete. "
+                "Otherwise it will return `None`."
+            ),
+        ] = Parameter.empty,
+        on_return_raises: Annotated[
+            Union[bool, object],
+            Doc(
+                "Raise an :class:`aio_pika.exceptions.DeliveryError` "
+                "when a mandatory message is returned."
+            ),
+        ] = Parameter.empty,
     ) -> "RobustConnection":
         """Connect broker object to RabbitMQ.
 
@@ -310,6 +354,15 @@
         """
         kwargs: AnyDict = {}
 
+        if channel_number is not Parameter.empty:
+            kwargs["channel_number"] = channel_number
+
+        if publisher_confirms is not Parameter.empty:
+            kwargs["publisher_confirms"] = publisher_confirms
+
+        if on_return_raises is not Parameter.empty:
+            kwargs["on_return_raises"] = on_return_raises
+
         if timeout:
             kwargs["timeout"] = timeout
@@ -346,6 +399,9 @@ async def _connect(  # type: ignore[override]
         *,
         timeout: "TimeoutType",
         ssl_context: Optional["SSLContext"],
+        channel_number: Optional[int],
+        publisher_confirms: bool,
+        on_return_raises: bool,
     ) -> "RobustConnection":
         connection = cast(
             "RobustConnection",
@@ -360,7 +416,11 @@ async def _connect(  # type: ignore[override]
             max_consumers = self._max_consumers
         channel = self._channel = cast(
             "RobustChannel",
-            await connection.channel(),
+            await connection.channel(
+                channel_number=channel_number,
+                publisher_confirms=publisher_confirms,
+                on_return_raises=on_return_raises,
+            ),
         )
 
         declarer = self.declarer = RabbitDeclarer(channel)
diff --git a/faststream/rabbit/fastapi/router.py b/faststream/rabbit/fastapi/router.py
index 4cc90b25d9..6d13beabae 100644
--- a/faststream/rabbit/fastapi/router.py
+++ b/faststream/rabbit/fastapi/router.py
@@ -96,6 +96,26 @@ def __init__(
             "TimeoutType",
             Doc("Connection establishement timeout."),
         ] = None,
+        # channel args
+        channel_number: Annotated[
+            Optional[int],
+            Doc("Specify the channel number explicitly."),
+        ] = None,
+        publisher_confirms: Annotated[
+            bool,
+            Doc(
+                "If `True`, the `publish` method will "
+                "return `bool` type after the publish is complete. "
+                "Otherwise it will return `None`."
+            ),
+        ] = True,
+        on_return_raises: Annotated[
+            bool,
+            Doc(
+                "Raise an :class:`aio_pika.exceptions.DeliveryError` "
+                "when a mandatory message is returned."
+            ),
+        ] = False,
         # broker args
         max_consumers: Annotated[
             Optional[int],
@@ -408,6 +428,9 @@ def __init__(
             graceful_timeout=graceful_timeout,
             decoder=decoder,
             parser=parser,
+            channel_number=channel_number,
+            publisher_confirms=publisher_confirms,
+            on_return_raises=on_return_raises,
             middlewares=middlewares,
             security=security,
             asyncapi_url=asyncapi_url,
diff --git a/faststream/rabbit/schemas/queue.py b/faststream/rabbit/schemas/queue.py
index b63685d1a5..a9bccf013d 100644
--- a/faststream/rabbit/schemas/queue.py
+++ b/faststream/rabbit/schemas/queue.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from typing import TYPE_CHECKING, Optional
 
 from typing_extensions import Annotated, Doc
@@ -115,3 +116,13 @@ def __init__(
         self.auto_delete = auto_delete
         self.arguments = arguments
         self.timeout = timeout
+
+    def add_prefix(self, prefix: str) -> "RabbitQueue":
+        new_q: RabbitQueue = deepcopy(self)
+
+        new_q.name = "".join((prefix, new_q.name))
+
+        if new_q.routing_key:
+            new_q.routing_key = "".join((prefix, new_q.routing_key))
+
+        return new_q
diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py
index aecac22384..d2ca4480a2 100644
--- a/faststream/rabbit/subscriber/usecase.py
+++ b/faststream/rabbit/subscriber/usecase.py
@@ -1,4 +1,3 @@
-from copy import deepcopy
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -223,6 +222,4 @@ def get_log_context(
 
     def add_prefix(self, prefix: str) -> None:
         """Include Subscriber in router."""
-        new_q = deepcopy(self.queue)
-        new_q.name = "".join((prefix, new_q.name))
-        self.queue = new_q
+        self.queue = self.queue.add_prefix(prefix)
message["data"]), - {"content-type": ContentTypes.json}, + dump_json(body), + { + **first_msg_headers, + "content-type": ContentTypes.json.value, + }, + batch_headers, ) @@ -193,27 +215,43 @@ class RedisStreamParser(SimpleParser): msg_class = RedisStreamMessage @classmethod - def _parse_data(cls, message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]: + def _parse_data( + cls, message: Mapping[str, Any] + ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]: data = message["data"] - return RawMessage.parse(data.get(bDATA_KEY) or dump_json(data)) + return (*RawMessage.parse(data.get(bDATA_KEY) or dump_json(data)), []) class RedisBatchStreamParser(SimpleParser): msg_class = RedisBatchStreamMessage @staticmethod - def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]: + def _parse_data( + message: Mapping[str, Any], + ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]: + body: List[Any] = [] + batch_headers: List["AnyDict"] = [] + + for x in message["data"]: + msg_data, msg_headers = _decode_batch_body_item(x.get(bDATA_KEY, x)) + body.append(msg_data) + batch_headers.append(msg_headers) + + first_msg_headers = next(iter(batch_headers), {}) + return ( - dump_json( - _decode_batch_body_item(x.get(bDATA_KEY, x)) for x in message["data"] - ), - {"content-type": ContentTypes.json}, + dump_json(body), + { + **first_msg_headers, + "content-type": ContentTypes.json.value, + }, + batch_headers, ) -def _decode_batch_body_item(msg_content: bytes) -> Any: - msg_body, _ = RawMessage.parse(msg_content) +def _decode_batch_body_item(msg_content: bytes) -> Tuple[Any, "AnyDict"]: + msg_body, headers = RawMessage.parse(msg_content) try: - return json_loads(msg_body) + return json_loads(msg_body), headers except Exception: - return msg_body + return msg_body, headers diff --git a/faststream/types.py b/faststream/types.py index 9f12fb9d57..681a7a3b18 100644 --- a/faststream/types.py +++ b/faststream/types.py @@ -63,22 +63,16 @@ class StandardDataclass(Protocol): """Protocol to check type is dataclass.""" __dataclass_fields__: ClassVar[Dict[str, Any]] - __dataclass_params__: ClassVar[Any] - __post_init__: ClassVar[Callable[..., None]] - - def __init__(self, *args: object, **kwargs: object) -> None: - """Interface method.""" - ... 
 
 
 BaseSendableMessage: TypeAlias = Union[
     JsonDecodable,
     Decimal,
     datetime,
-    None,
     StandardDataclass,
     SendableTable,
     SendableArray,
+    None,
 ]
 
 try:
diff --git a/pyproject.toml b/pyproject.toml
index c4f419a0a6..505e0de0dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,9 +76,9 @@ redis = ["redis>=5.0.0,<6.0.0"]
 # dev dependencies
 devdocs = [
     "mkdocs-material==9.5.21",
-    "mkdocs-static-i18n==1.2.2",
+    "mkdocs-static-i18n==1.2.3",
     "mdx-include==1.4.2",
-    "mkdocstrings[python]==0.25.0",
+    "mkdocstrings[python]==0.25.1",
     "mkdocs-literate-nav==0.6.1",
     "mkdocs-git-revision-date-localized-plugin==1.2.5",
     "mike==2.1.1",  # versioning
@@ -106,14 +106,14 @@ types = [
 
 lint = [
     "faststream[types]",
-    "ruff==0.4.3",
+    "ruff==0.4.4",
     "bandit==1.7.8",
     "semgrep==1.70.0",
     "codespell==2.2.6",
 ]
 
 test-core = [
-    "coverage[toml]==7.4.4",
+    "coverage[toml]==7.5.1",
     "pytest==8.2.0",
     "pytest-asyncio==0.23.6",
     "dirty-equals==0.7.1.post0",
@@ -133,7 +133,7 @@ dev = [
     "faststream[rabbit,kafka,confluent,nats,redis,lint,testing,devdocs]",
     "pre-commit==3.5.0; python_version < '3.9'",
     "pre-commit==3.7.0; python_version >= '3.9'",
-    "detect-secrets==1.4.0",
+    "detect-secrets==1.5.0",
 ]
 
 [project.urls]
diff --git a/tests/asyncapi/rabbit/test_router.py b/tests/asyncapi/rabbit/test_router.py
index b878eac005..386f4960f5 100644
--- a/tests/asyncapi/rabbit/test_router.py
+++ b/tests/asyncapi/rabbit/test_router.py
@@ -63,7 +63,7 @@ async def handle(msg): ...
                 "subscribe": {
                     "bindings": {
                         "amqp": {
-                            "cc": "key",
+                            "cc": "test_key",
                             "ack": True,
                             "bindingVersion": "0.2.0",
                         }
@@ -91,7 +91,7 @@ async def handle(msg): ...
             },
         },
     }
-    )
+    ), schema
 
 
 class TestRouterArguments(ArgumentsTestcase):
diff --git a/tests/brokers/base/middlewares.py b/tests/brokers/base/middlewares.py
index 4f89f08411..7ed74522d8 100644
--- a/tests/brokers/base/middlewares.py
+++ b/tests/brokers/base/middlewares.py
@@ -270,6 +270,59 @@ async def handler(m):
         mock.start.assert_called_once()
         mock.end.assert_called_once()
 
+    async def test_add_global_middleware(
+        self,
+        event: asyncio.Event,
+        queue: str,
+        mock: Mock,
+        raw_broker,
+    ):
+        class mid(BaseMiddleware):  # noqa: N801
+            async def on_receive(self):
+                mock.start(self.msg)
+                return await super().on_receive()
+
+            async def after_processed(self, exc_type, exc_val, exc_tb):
+                mock.end()
+                return await super().after_processed(exc_type, exc_val, exc_tb)
+
+        broker = self.broker_class()
+
+        # already registered subscriber
+        @broker.subscriber(queue, **self.subscriber_kwargs)
+        async def handler(m):
+            event.set()
+            return ""
+
+        # should affect both the already registered and the new subscriber
+        broker.add_middleware(mid)
+
+        event2 = asyncio.Event()
+
+        # new subscriber
+        @broker.subscriber(f"{queue}1", **self.subscriber_kwargs)
+        async def handler2(m):
+            event2.set()
+            return ""
+
+        broker = self.patch_broker(raw_broker, broker)
+
+        async with broker:
+            await broker.start()
+            await asyncio.wait(
+                (
+                    asyncio.create_task(broker.publish("", queue)),
+                    asyncio.create_task(broker.publish("", f"{queue}1")),
+                    asyncio.create_task(event.wait()),
+                    asyncio.create_task(event2.wait()),
+                ),
+                timeout=self.timeout,
+            )
+
+        assert event.is_set()
+        assert mock.start.call_count == 2
+        assert mock.end.call_count == 2
+
     async def test_patch_publish(self, queue: str, mock: Mock, event, raw_broker):
         class Mid(BaseMiddleware):
             async def on_publish(self, msg: str, *args, **kwargs) -> str:
diff --git a/tests/brokers/base/publish.py b/tests/brokers/base/publish.py
index 2ed026c9f7..4deb2a50ae 100644
--- a/tests/brokers/base/publish.py
+++ b/tests/brokers/base/publish.py
@@ -1,4 +1,5 @@
 import asyncio
+from dataclasses import asdict, dataclass
 from datetime import datetime
 from typing import Any, ClassVar, Dict, List, Tuple
 from unittest.mock import Mock
@@ -7,7 +8,7 @@
 import pytest
 from pydantic import BaseModel
 
-from faststream._compat import model_to_json
+from faststream._compat import dump_json, model_to_json
 from faststream.annotations import Logger
 from faststream.broker.core.usecase import BrokerUsecase
 
@@ -16,6 +17,11 @@ class SimpleModel(BaseModel):
     r: str
 
 
+@dataclass
+class SimpleDataclass:
+    r: str
+
+
 now = datetime.now()
 
 
@@ -55,6 +61,12 @@ def pub_broker(self, full_broker):
             1.0,
             id="float->float",
         ),
+        pytest.param(
+            1,
+            float,
+            1.0,
+            id="int->float",
+        ),
         pytest.param(
             False,
             bool,
@@ -103,6 +115,30 @@ def pub_broker(self, full_broker):
             SimpleModel(r="hello!"),
             id="dict->model",
         ),
+        pytest.param(
+            dump_json(asdict(SimpleDataclass(r="hello!"))),
+            SimpleDataclass,
+            SimpleDataclass(r="hello!"),
+            id="bytes->dataclass",
+        ),
+        pytest.param(
+            SimpleDataclass(r="hello!"),
+            SimpleDataclass,
+            SimpleDataclass(r="hello!"),
+            id="dataclass->dataclass",
+        ),
+        pytest.param(
+            SimpleDataclass(r="hello!"),
+            dict,
+            {"r": "hello!"},
+            id="dataclass->dict",
+        ),
+        pytest.param(
+            {"r": "hello!"},
+            SimpleDataclass,
+            SimpleDataclass(r="hello!"),
+            id="dict->dataclass",
+        ),
     ),
 )
 async def test_serialize(
diff --git a/tests/brokers/confluent/test_consume.py b/tests/brokers/confluent/test_consume.py
index fb612d66d0..2c471c6e73 100644
--- a/tests/brokers/confluent/test_consume.py
+++ b/tests/brokers/confluent/test_consume.py
@@ -39,6 +39,38 @@ async def handler(msg):
 
         assert [{1, "hi"}] == [set(r.result()) for r in result]
 
+    @pytest.mark.asyncio()
+    async def test_consume_batch_headers(
+        self, mock, event: asyncio.Event, queue: str, full_broker: KafkaBroker
+    ):
+        @full_broker.subscriber(queue, batch=True, **self.subscriber_kwargs)
+        def subscriber(m, msg: KafkaMessage):
+            check = all(
+                (
+                    msg.headers,
+                    [msg.headers] == msg.batch_headers,
+                    msg.headers.get("custom") == "1",
+                )
+            )
+            mock(check)
+            event.set()
+
+        async with full_broker:
+            await full_broker.start()
+
+            await asyncio.wait(
+                (
+                    asyncio.create_task(
+                        full_broker.publish("", queue, headers={"custom": "1"})
+                    ),
+                    asyncio.create_task(event.wait()),
+                ),
+                timeout=self.timeout,
+            )
+
+        assert event.is_set()
+        mock.assert_called_once_with(True)
+
     @pytest.mark.asyncio()
     @pytest.mark.slow()
     async def test_consume_ack(
diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py
index a50c06d8c4..82c3a7d0b8 100644
--- a/tests/brokers/kafka/test_consume.py
+++ b/tests/brokers/kafka/test_consume.py
@@ -33,6 +33,38 @@ async def handler(msg):
 
         assert [{1, "hi"}] == [set(r.result()) for r in result]
 
+    @pytest.mark.asyncio()
+    async def test_consume_batch_headers(
+        self, mock, event: asyncio.Event, queue: str, full_broker: KafkaBroker
+    ):
+        @full_broker.subscriber(queue, batch=True)
+        def subscriber(m, msg: KafkaMessage):
+            check = all(
+                (
+                    msg.headers,
+                    [msg.headers] == msg.batch_headers,
+                    msg.headers.get("custom") == "1",
+                )
+            )
+            mock(check)
+            event.set()
+
+        async with full_broker:
+            await full_broker.start()
+
+            await asyncio.wait(
+                (
+                    asyncio.create_task(
+                        full_broker.publish("", queue, headers={"custom": "1"})
+                    ),
+                    asyncio.create_task(event.wait()),
+                ),
+                timeout=3,
+            )
+
+        assert event.is_set()
+        mock.assert_called_once_with(True)
+
     @pytest.mark.asyncio()
     @pytest.mark.slow()
     async def test_consume_ack(
diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py
index 23a1576287..1e7997526e 100644
--- a/tests/brokers/nats/test_consume.py
+++ b/tests/brokers/nats/test_consume.py
@@ -96,7 +96,6 @@ def subscriber(m):
         assert event.is_set()
         mock.assert_called_once_with([b"hello"])
 
-    @pytest.mark.asyncio()
     async def test_consume_ack(
         self,
         queue: str,
@@ -127,7 +126,6 @@ async def handler(msg: NatsMessage):
 
         assert event.is_set()
 
-    @pytest.mark.asyncio()
     async def test_consume_ack_manual(
         self,
         queue: str,
@@ -159,7 +157,6 @@ async def handler(msg: NatsMessage):
 
         assert event.is_set()
 
-    @pytest.mark.asyncio()
    async def test_consume_ack_raise(
         self,
         queue: str,
@@ -191,7 +188,6 @@ async def handler(msg: NatsMessage):
 
         assert event.is_set()
 
-    @pytest.mark.asyncio()
     async def test_nack(
         self,
         queue: str,
@@ -223,7 +219,6 @@ async def handler(msg: NatsMessage):
 
         assert event.is_set()
 
-    @pytest.mark.asyncio()
     async def test_consume_no_ack(
         self, queue: str, full_broker: NatsBroker, event: asyncio.Event
     ):
@@ -248,3 +243,41 @@ async def handler(msg: NatsMessage):
         m.mock.assert_not_called()
 
         assert event.is_set()
+
+    async def test_consume_batch_headers(
+        self,
+        queue: str,
+        full_broker: NatsBroker,
+        stream: JStream,
+        event: asyncio.Event,
+        mock,
+    ):
+        @full_broker.subscriber(
+            queue,
+            stream=stream,
+            pull_sub=PullSub(1, batch=True),
+        )
+        def subscriber(m, msg: NatsMessage):
+            check = all(
+                (
+                    msg.headers,
+                    [msg.headers] == msg.batch_headers,
+                    msg.headers.get("custom") == "1",
+                )
+            )
+            mock(check)
+            event.set()
+
+        await full_broker.start()
+        await asyncio.wait(
+            (
+                asyncio.create_task(
+                    full_broker.publish("", queue, headers={"custom": "1"})
+                ),
+                asyncio.create_task(event.wait()),
+            ),
+            timeout=3,
+        )
+
+        assert event.is_set()
+        mock.assert_called_once_with(True)
diff --git a/tests/brokers/rabbit/test_router.py b/tests/brokers/rabbit/test_router.py
index 50f81a636d..ac14d3372d 100644
--- a/tests/brokers/rabbit/test_router.py
+++ b/tests/brokers/rabbit/test_router.py
@@ -139,6 +139,39 @@ def subscriber(m):
 
         assert event.is_set()
 
+    async def test_queue_obj_with_routing_key(
+        self,
+        router: RabbitRouter,
+        broker: RabbitBroker,
+        queue: str,
+        event: asyncio.Event,
+    ):
+        router.prefix = "test/"
+
+        r_queue = RabbitQueue("useless", routing_key=f"{queue}1")
+        exchange = RabbitExchange(f"{queue}exch")
+
+        @router.subscriber(r_queue, exchange=exchange)
+        def subscriber(m):
+            event.set()
+
+        broker.include_router(router)
+
+        async with broker:
+            await broker.start()
+
+            await asyncio.wait(
+                (
+                    asyncio.create_task(
+                        broker.publish("hello", f"test/{queue}1", exchange=exchange)
+                    ),
+                    asyncio.create_task(event.wait()),
+                ),
+                timeout=3,
+            )
+
+        assert event.is_set()
+
     async def test_delayed_handlers_with_queue(
         self,
         event: asyncio.Event,
diff --git a/tests/brokers/redis/test_consume.py b/tests/brokers/redis/test_consume.py
index 176fd0965f..8ddad852c8 100644
--- a/tests/brokers/redis/test_consume.py
+++ b/tests/brokers/redis/test_consume.py
@@ -138,30 +138,70 @@ async def handler(msg):
         mock.assert_called_once_with(b"hello")
 
     @pytest.mark.slow()
-    async def test_consume_list_batch_with_one(self, queue: str, broker: RedisBroker):
-        msgs_queue = asyncio.Queue(maxsize=1)
-
-        @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=1))
+    async def test_consume_list_batch_with_one(
+        self, event: asyncio.Event, mock, queue: str, broker: RedisBroker
+    ):
+        @broker.subscriber(
+            list=ListSub(queue, batch=True, max_records=1, polling_interval=0.01)
+        )
         async def handler(msg):
-            await msgs_queue.put(msg)
+            mock(msg)
+            event.set()
 
         async with broker:
             await broker.start()
 
-            await broker.publish("hi", list=queue)
-
-            result, _ = await asyncio.wait(
-                (asyncio.create_task(msgs_queue.get()),),
+            await asyncio.wait(
+                (
+                    asyncio.create_task(broker.publish("hi", list=queue)),
+                    asyncio.create_task(event.wait()),
+                ),
                 timeout=3,
             )
 
-        assert ["hi"] == [r.result()[0] for r in result]
+        assert event.is_set()
+        mock.assert_called_once_with(["hi"])
+
+    @pytest.mark.slow()
+    async def test_consume_list_batch_headers(
+        self,
+        queue: str,
+        full_broker: RedisBroker,
+        event: asyncio.Event,
+        mock,
+    ):
+        @full_broker.subscriber(list=ListSub(queue, batch=True, polling_interval=0.01))
+        def subscriber(m, msg: RedisMessage):
+            check = all(
+                (
+                    msg.headers,
+                    msg.headers["correlation_id"]
+                    == msg.batch_headers[0]["correlation_id"],
+                    msg.headers.get("custom") == "1",
+                )
+            )
+            mock(check)
+            event.set()
+
+        await full_broker.start()
+        await asyncio.wait(
+            (
+                asyncio.create_task(
+                    full_broker.publish("", list=queue, headers={"custom": "1"})
+                ),
+                asyncio.create_task(event.wait()),
+            ),
+            timeout=3,
+        )
+
+        assert event.is_set()
+        mock.assert_called_once_with(True)
 
     @pytest.mark.slow()
     async def test_consume_list_batch(self, queue: str, broker: RedisBroker):
         msgs_queue = asyncio.Queue(maxsize=1)
 
-        @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=1))
+        @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=0.01))
         async def handler(msg):
             await msgs_queue.put(msg)
 
@@ -189,7 +229,7 @@ def __hash__(self):
 
         msgs_queue = asyncio.Queue(maxsize=1)
 
-        @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=1))
+        @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=0.01))
         async def handler(msg: List[Data]):
             await msgs_queue.put(msg)
 
@@ -210,7 +250,7 @@ async def handler(msg: List[Data]):
     async def test_consume_list_batch_native(self, queue: str, broker: RedisBroker):
         msgs_queue = asyncio.Queue(maxsize=1)
 
-        @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=1))
+        @broker.subscriber(list=ListSub(queue, batch=True, polling_interval=0.01))
         async def handler(msg):
             await msgs_queue.put(msg)
 
@@ -238,7 +278,7 @@ async def test_consume_stream(
         mock: MagicMock,
         queue,
     ):
-        @broker.subscriber(stream=StreamSub(queue, polling_interval=3000))
+        @broker.subscriber(stream=StreamSub(queue, polling_interval=10))
         async def handler(msg):
             mock(msg)
             event.set()
@@ -264,7 +304,7 @@ async def test_consume_stream_native(
         mock: MagicMock,
         queue,
     ):
-        @broker.subscriber(stream=StreamSub(queue, polling_interval=3000))
+        @broker.subscriber(stream=StreamSub(queue, polling_interval=10))
         async def handler(msg):
             mock(msg)
             event.set()
@@ -292,7 +332,7 @@ async def test_consume_stream_batch(
         mock: MagicMock,
         queue,
     ):
-        @broker.subscriber(stream=StreamSub(queue, polling_interval=3000, batch=True))
+        @broker.subscriber(stream=StreamSub(queue, polling_interval=10, batch=True))
         async def handler(msg):
             mock(msg)
             event.set()
@@ -310,6 +350,43 @@ async def handler(msg):
 
         mock.assert_called_once_with(["hello"])
 
+    @pytest.mark.slow()
+    async def test_consume_stream_batch_headers(
+        self,
+        queue: str,
+        full_broker: RedisBroker,
+        event: asyncio.Event,
+        mock,
+    ):
+        @full_broker.subscriber(
+            stream=StreamSub(queue, polling_interval=10, batch=True)
+        )
+        def subscriber(m, msg: RedisMessage):
+            check = all(
+                (
+                    msg.headers,
+                    msg.headers["correlation_id"]
+                    == msg.batch_headers[0]["correlation_id"],
+                    msg.headers.get("custom") == "1",
+                )
+            )
+            mock(check)
+            event.set()
+
+        await full_broker.start()
+        await asyncio.wait(
+            (
+                asyncio.create_task(
+                    full_broker.publish("", stream=queue, headers={"custom": "1"})
+                ),
+                asyncio.create_task(event.wait()),
+            ),
+            timeout=3,
+        )
+
+        assert event.is_set()
+        mock.assert_called_once_with(True)
+
     @pytest.mark.slow()
     async def test_consume_stream_batch_complex(
         self,
@@ -323,7 +400,7 @@ class Data(BaseModel):
 
         msgs_queue = asyncio.Queue(maxsize=1)
 
-        @broker.subscriber(stream=StreamSub(queue, polling_interval=3000, batch=True))
+        @broker.subscriber(stream=StreamSub(queue, polling_interval=10, batch=True))
         async def handler(msg: List[Data]):
             await msgs_queue.put(msg)
 
@@ -348,7 +425,7 @@ async def test_consume_stream_batch_native(
         mock: MagicMock,
         queue,
     ):
-        @broker.subscriber(stream=StreamSub(queue, polling_interval=3000, batch=True))
+        @broker.subscriber(stream=StreamSub(queue, polling_interval=10, batch=True))
         async def handler(msg):
             mock(msg)
             event.set()
diff --git a/tests/brokers/redis/test_fastapi.py b/tests/brokers/redis/test_fastapi.py
index c61f88614d..36e95d1a29 100644
--- a/tests/brokers/redis/test_fastapi.py
+++ b/tests/brokers/redis/test_fastapi.py
@@ -86,7 +86,7 @@ async def test_consume_stream(
 ):
     router = RedisRouter()
 
-    @router.subscriber(stream=StreamSub(queue, polling_interval=3000))
+    @router.subscriber(stream=StreamSub(queue, polling_interval=10))
     async def handler(msg):
         mock(msg)
         event.set()
@@ -114,7 +114,7 @@ async def test_consume_stream_batch(
 ):
     router = RedisRouter()
 
-    @router.subscriber(stream=StreamSub(queue, polling_interval=3000, batch=True))
+    @router.subscriber(stream=StreamSub(queue, polling_interval=10, batch=True))
     async def handler(msg: List[str]):
         mock(msg)
         event.set()