From 11d70bf6c0d945c364d813e7234604ea9d0465ec Mon Sep 17 00:00:00 2001 From: Francis Santos Date: Thu, 18 Jul 2024 18:45:27 +0200 Subject: [PATCH 1/2] feat: init concurrent rpc for rabbit broker --- .secrets.baseline | 4 +- CODE_OF_CONDUCT.md | 2 +- docs/docs/en/release.md | 31 ++++++- faststream/rabbit/broker/connection.py | 4 +- faststream/rabbit/publisher/producer.py | 111 ++++++++++++++++-------- tests/brokers/rabbit/test_rpc.py | 83 ++++++++++++++++++ tests/conftest.py | 8 ++ 7 files changed, 199 insertions(+), 44 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index 3fd4156bb0..59625efc63 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -128,7 +128,7 @@ "filename": "docs/docs/en/release.md", "hashed_secret": "35675e68f4b5af7b995d9205ad0fc43842f16450", "is_verified": false, - "line_number": 1325, + "line_number": 1376, "is_secret": false } ], @@ -163,5 +163,5 @@ } ] }, - "generated_at": "2024-06-10T09:56:52Z" + "generated_at": "2024-07-18T16:41:25Z" } diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 35445a790b..6b6ac2f90b 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -6,7 +6,7 @@ We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, +identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. diff --git a/docs/docs/en/release.md b/docs/docs/en/release.md index feadf44839..5a48192f49 100644 --- a/docs/docs/en/release.md +++ b/docs/docs/en/release.md @@ -12,6 +12,35 @@ hide: --- # Release Notes +## 0.5.14 + +### What's Changed +* Update Release Notes for 0.5.13 by @faststream-release-notes-updater in [#1548](https://github.com/airtai/faststream/pull/1548){.external-link target="_blank"} +* Add allow_auto_create_topics to make automatic topic creation configurable by [@kumaranvpl](https://github.com/kumaranvpl){.external-link target="_blank"} in [#1556](https://github.com/airtai/faststream/pull/1556){.external-link target="_blank"} + + +**Full Changelog**: [#0.5.13...0.5.14](https://github.com/airtai/faststream/compare/0.5.13...0.5.14){.external-link target="_blank"} + +## 0.5.13 + +### What's Changed + +* feat: nats filter JS subscription support by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1519](https://github.com/airtai/faststream/pull/1519){.external-link target="_blank"} +* fix: correct RabbitExchange processing by OTEL in broker.publish case by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1521](https://github.com/airtai/faststream/pull/1521){.external-link target="_blank"} +* fix: correct Nats ObjectStorage get file behavior inside watch subscriber by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1523](https://github.com/airtai/faststream/pull/1523){.external-link target="_blank"} +* Resolve Issue 1386, Add rpc_prefix by [@aKardasz](https://github.com/aKardasz){.external-link target="_blank"} in [#1484](https://github.com/airtai/faststream/pull/1484){.external-link target="_blank"} +* fix: correct spans linking in batches case by [@draincoder](https://github.com/draincoder){.external-link target="_blank"} in 
[#1532](https://github.com/airtai/faststream/pull/1532){.external-link target="_blank"} +* fix (#1539): correct anyio.create_memory_object_stream annotation by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1541](https://github.com/airtai/faststream/pull/1541){.external-link target="_blank"} +* fix: correct publish_coverage CI by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1536](https://github.com/airtai/faststream/pull/1536){.external-link target="_blank"} +* Add NatsBroker.new_inbox() by [@maxalbert](https://github.com/maxalbert){.external-link target="_blank"} in [#1543](https://github.com/airtai/faststream/pull/1543){.external-link target="_blank"} +* fix (#1544): correct Redis message nack & reject signature by [@Lancetnik](https://github.com/Lancetnik){.external-link target="_blank"} in [#1546](https://github.com/airtai/faststream/pull/1546){.external-link target="_blank"} + +### New Contributors +* [@aKardasz](https://github.com/aKardasz){.external-link target="_blank"} made their first contribution in [#1484](https://github.com/airtai/faststream/pull/1484){.external-link target="_blank"} +* [@maxalbert](https://github.com/maxalbert){.external-link target="_blank"} made their first contribution in [#1543](https://github.com/airtai/faststream/pull/1543){.external-link target="_blank"} + +**Full Changelog**: [#0.5.12...0.5.13](https://github.com/airtai/faststream/compare/0.5.12...0.5.13){.external-link target="_blank"} + ## 0.5.12 ### What's Changed @@ -324,7 +353,7 @@ You can find more information about it in the official [**aiokafka** doc](https: `pattern` option was added too, but it is still experimental and does not support `Path` -3. [`Path`](https://faststream.airt.ai/latest/nats/message/#subject-pattern-access) feature performance was increased. Also, `Path` is suitable for NATS `PullSub` batch subscribtion as well now. +3. [`Path`](https://faststream.airt.ai/latest/nats/message/#subject-pattern-access) feature performance was increased. Also, `Path` is suitable for NATS `PullSub` batch subscription as well now. 
```python from faststream import NatsBroker, PullSub diff --git a/faststream/rabbit/broker/connection.py b/faststream/rabbit/broker/connection.py index b332eb8a42..636b2eb8a7 100644 --- a/faststream/rabbit/broker/connection.py +++ b/faststream/rabbit/broker/connection.py @@ -27,7 +27,7 @@ def __init__( publisher_confirms: bool, on_return_raises: bool, ) -> None: - self._connection_pool: "Pool[RobustConnection]" = Pool( + self._connection_pool: Pool[RobustConnection] = Pool( lambda: connect_robust( url=url, timeout=timeout, @@ -36,7 +36,7 @@ def __init__( max_size=connection_pool_size, ) - self._channel_pool: "Pool[RobustChannel]" = Pool( + self._channel_pool: Pool[RobustChannel] = Pool( lambda: self._get_channel( channel_number=channel_number, publisher_confirms=publisher_confirms, diff --git a/faststream/rabbit/publisher/producer.py b/faststream/rabbit/publisher/producer.py index 4d1a6acd23..38092c7d11 100644 --- a/faststream/rabbit/publisher/producer.py +++ b/faststream/rabbit/publisher/producer.py @@ -1,7 +1,11 @@ +from asyncio import Future +from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, AsyncContextManager, + ClassVar, + Dict, Optional, Type, Union, @@ -9,7 +13,6 @@ ) import anyio -from aio_pika.abc import AbstractIncomingMessage from typing_extensions import override from faststream.broker.publisher.proto import ProducerProto @@ -17,6 +20,7 @@ from faststream.exceptions import WRONG_PUBLISH_ARGS from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.schemas import RABBIT_REPLY, RabbitExchange +from faststream.utils.classes import Singleton from faststream.utils.functions import fake_context, timeout_scope if TYPE_CHECKING: @@ -25,7 +29,6 @@ import aiormq from aio_pika import IncomingMessage, RobustChannel, RobustQueue from aio_pika.abc import DateType, HeadersType, TimeoutType - from anyio.streams.memory import MemoryObjectReceiveStream from faststream.broker.types import ( AsyncCallable, @@ -50,8 +53,7 @@ def __init__( decoder: Optional["CustomCallable"], ) -> None: self.declarer = declarer - - self._rpc_lock = anyio.Lock() + self.rpc_manager = _RPCManager(declarer=declarer) default_parser = AioPikaParser() self._parser = resolve_custom_func(parser, default_parser.parse_message) @@ -85,25 +87,23 @@ async def publish( # type: ignore[override] app_id: Optional[str] = None, ) -> Optional[Any]: """Publish a message to a RabbitMQ queue.""" - context: AsyncContextManager[ - Optional[MemoryObjectReceiveStream[IncomingMessage]] - ] - channel: Optional["RobustChannel"] + context: AsyncContextManager[Optional[RobustChannel]] + channel: Optional[RobustChannel] + response: Optional[Future[IncomingMessage]] if rpc: if reply_to is not None: raise WRONG_PUBLISH_ARGS - rmq_queue = await self.declarer.declare_queue(RABBIT_REPLY) - channel = cast("RobustChannel", rmq_queue.channel) - context = _RPCCallback(self._rpc_lock, rmq_queue) - reply_to = RABBIT_REPLY.name + context = await self.rpc_manager(correlation_id=correlation_id) + response = self.rpc_manager.result + reply_to = self.rpc_manager.queue.name else: - channel = None + response = None context = fake_context() - async with context as response_queue: + async with context as channel: r = await self._publish( message=message, exchange=exchange, @@ -127,13 +127,13 @@ async def publish( # type: ignore[override] channel=channel, ) - if response_queue is None: + if response is None: return r else: msg: Optional[IncomingMessage] = None with timeout_scope(rpc_timeout, raise_timeout): - msg = await 
response_queue.receive() + msg = await response if msg: # pragma: no branch return await self._decoder(await self._parser(msg)) @@ -197,31 +197,62 @@ async def _publish( ) -class _RPCCallback: +class _RPCManager(Singleton, AsyncContextManager["RobustChannel"]): """A class provides an RPC lock.""" - def __init__(self, lock: "anyio.Lock", callback_queue: "RobustQueue") -> None: - self.lock = lock - self.queue = callback_queue + # Singleton entities (all async tasks share the same context) + __lock = anyio.Lock() + __rpc_messages: ClassVar[Dict[str, Future]] = {} # type: ignore[type-arg] + __consumer_tags: ClassVar[Dict["RobustChannel", str]] = {} - async def __aenter__(self) -> "MemoryObjectReceiveStream[IncomingMessage]": - ( - send_response_stream, - receive_response_stream, - ) = anyio.create_memory_object_stream[AbstractIncomingMessage]( - max_buffer_size=1 - ) - await self.lock.acquire() + # Context variables (async tasks have their own context) + __current_correlation_id: ContextVar[str] = ContextVar("rpc_correlation_id") + __current_queue: ContextVar["RobustQueue"] = ContextVar("rpc_queue") + __current_result: ContextVar[Future] = ContextVar("rpc_result") # type: ignore[type-arg] - self.consumer_tag = await self.queue.consume( - callback=send_response_stream.send, - no_ack=True, - ) + def __init__(self, declarer: "RabbitDeclarer") -> None: + self.declarer = declarer - return cast( - "MemoryObjectReceiveStream[IncomingMessage]", - receive_response_stream, - ) + @property + def queue(self) -> "RobustQueue": + return self.__current_queue.get() + + @property + def correlation_id(self) -> str: + return self.__current_correlation_id.get() + + @property + def result(self) -> Future: # type: ignore[type-arg] + return self.__current_result.get() + + async def __rpc_callback(self, msg: "IncomingMessage") -> None: + """A callback function to handle RPC messages.""" + if msg.correlation_id in self.__rpc_messages: + self.__rpc_messages[msg.correlation_id].set_result(msg) + + async def __call__(self, correlation_id: str) -> "_RPCManager": + """Sets the current RPC context.""" + async with self.__lock: + if correlation_id in self.__rpc_messages: + raise RuntimeError("The correlation ID is already in use.") + self.__current_result.set(Future()) + self.__rpc_messages[correlation_id] = self.result + + self.__current_queue.set(await self.declarer.declare_queue(RABBIT_REPLY)) + self.__current_correlation_id.set(correlation_id) + + return self + + async def __aenter__(self) -> "RobustChannel": + async with self.__lock: + if self.queue.channel not in self.__consumer_tags: + consumer_tag = await self.queue.consume( + callback=self.__rpc_callback, # type: ignore[arg-type] + no_ack=True, + ) + self.__consumer_tags[self.queue.channel] = consumer_tag # type: ignore[index] + + return cast("RobustChannel", self.queue.channel) async def __aexit__( self, @@ -229,5 +260,9 @@ async def __aexit__( exc_val: Optional[BaseException] = None, exc_tb: Optional["TracebackType"] = None, ) -> None: - self.lock.release() - await self.queue.cancel(self.consumer_tag) + self.__rpc_messages.pop(self.__current_correlation_id.get()) # type: ignore[call-overload] + if exc_tb: + async with self.__lock: + if self.queue.channel in self.__consumer_tags: + await self.queue.cancel(self.__consumer_tags[self.queue.channel]) # type: ignore[index] + self.__consumer_tags.pop(self.queue.channel) # type: ignore[call-overload] diff --git a/tests/brokers/rabbit/test_rpc.py b/tests/brokers/rabbit/test_rpc.py index d0bd80cab7..709ecdb3c4 100644 --- 
a/tests/brokers/rabbit/test_rpc.py +++ b/tests/brokers/rabbit/test_rpc.py @@ -1,6 +1,11 @@ +import asyncio +import uuid + +import anyio import pytest from faststream.rabbit import RabbitBroker +from faststream.rabbit.publisher.producer import _RPCManager from tests.brokers.base.rpc import BrokerRPCTestcase, ReplyAndConsumeForbidden @@ -8,3 +13,81 @@ class TestRPC(BrokerRPCTestcase, ReplyAndConsumeForbidden): def get_broker(self, apply_types: bool = False) -> RabbitBroker: return RabbitBroker(apply_types=apply_types) + + @pytest.mark.asyncio() + async def test_rpc_with_concurrency(self, queue: str): + rpc_broker = self.get_broker() + + @rpc_broker.subscriber(queue) + async def m(m): # pragma: no cover + await asyncio.sleep(1) + return m + + async with self.patch_broker(rpc_broker) as br: + await br.start() + + with anyio.fail_after(3): + results = await asyncio.gather( + *[ + br.publish( + f"hello {i}", + queue, + rpc=True, + ) + for i in range(10) + ] + ) + + for i, r in enumerate(results): + assert r == f"hello {i}" + + +class TestRPCManager: + @pytest.mark.asyncio() + async def test_context_variables_per_concurrent_task(self): + rpc_broker = RabbitBroker() + rpc_manager = _RPCManager(declarer=rpc_broker.declarer) + results = set() + correlation_ids = set() + channels = set() + + async def run_operation(): + context = await rpc_manager(correlation_id=uuid.uuid4().hex) + async with context: + results.add(rpc_manager.result) + correlation_ids.add(rpc_manager.correlation_id) + channels.add(rpc_manager.queue.channel) + + await rpc_broker.start() + await asyncio.gather(*[run_operation() for _ in range(10)]) + assert len(results) == 10 + assert len(correlation_ids) == 10 + assert len(channels) == 1 + + @pytest.mark.asyncio() + async def test_one_queue_per_channel(self): + rpc_broker = RabbitBroker(max_channel_pool_size=10) + rpc_manager = _RPCManager(declarer=rpc_broker.declarer) + channels = set() + + async def run_operation(): + context = await rpc_manager(correlation_id=uuid.uuid4().hex) + async with context: + channels.add(rpc_manager.queue) + + await rpc_broker.start() + await asyncio.gather(*[run_operation() for _ in range(10)]) + assert len(channels) == 10 + + @pytest.mark.asyncio() + async def test_clean_up_after_exception(self): + rpc_broker = RabbitBroker() + rpc_manager = _RPCManager(declarer=rpc_broker.declarer) + + await rpc_broker.start() + with pytest.raises(ValueError): # noqa: PT011 + async with await rpc_manager(correlation_id=uuid.uuid4().hex): + raise ValueError("test") + + assert len(rpc_manager._RPCManager__rpc_messages) == 0 + assert not rpc_manager.queue._consumers diff --git a/tests/conftest.py b/tests/conftest.py index 92778c660a..1caa0b07ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -62,3 +62,11 @@ def context(): @pytest.fixture() def kafka_basic_project(): return "docs.docs_src.kafka.basic.basic:app" + + +@pytest.fixture(scope="session", autouse=True) +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() From 6406019861e887023e677582aa8886bb4235001e Mon Sep 17 00:00:00 2001 From: Francis Santos Date: Fri, 19 Jul 2024 12:03:24 +0200 Subject: [PATCH 2/2] ref: use a single channel for rpc --- faststream/rabbit/broker/broker.py | 2 +- faststream/rabbit/helpers/declarer.py | 48 ++++++-- faststream/rabbit/publisher/producer.py | 141 ++++++++++-------------- faststream/utils/functions.py | 9 ++ tests/brokers/rabbit/test_rpc.py | 
119 ++++++++++++-------- 5 files changed, 182 insertions(+), 137 deletions(-) diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index 013914b9d5..2f54a0f5bd 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -201,7 +201,7 @@ def __init__( max_channel_pool_size: Annotated[ int, Doc("Max channel pool size"), - ] = 1, + ] = 2, # NOTE: because we're sharing channels between consumers and producers ) -> None: security_args = parse_security(security) diff --git a/faststream/rabbit/helpers/declarer.py b/faststream/rabbit/helpers/declarer.py index a5f38f8deb..1302a02035 100644 --- a/faststream/rabbit/helpers/declarer.py +++ b/faststream/rabbit/helpers/declarer.py @@ -1,5 +1,5 @@ -from contextlib import AsyncExitStack -from typing import TYPE_CHECKING, Dict, Optional, cast +from contextlib import AsyncExitStack, asynccontextmanager +from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, Tuple, cast if TYPE_CHECKING: import aio_pika @@ -12,8 +12,13 @@ class RabbitDeclarer: """An utility class to declare RabbitMQ queues and exchanges.""" __connection_manager: "ConnectionManager" - __queues: Dict["RabbitQueue", "aio_pika.RobustQueue"] - __exchanges: Dict["RabbitExchange", "aio_pika.RobustExchange"] + __queues: Dict[ + Tuple[Optional["aio_pika.RobustChannel"], "RabbitQueue"], "aio_pika.RobustQueue" + ] + __exchanges: Dict[ + Tuple[Optional["aio_pika.RobustChannel"], "RabbitExchange"], + "aio_pika.RobustExchange", + ] def __init__(self, connection_manager: "ConnectionManager") -> None: self.__connection_manager = connection_manager @@ -28,14 +33,18 @@ async def declare_queue( channel: Optional["aio_pika.RobustChannel"] = None, ) -> "aio_pika.RobustQueue": """Declare a queue.""" - if (queue_obj := self.__queues.get(queue)) is None: + # NOTE: It would return the queue linked to another channel if it was already declared + # unless the channel is part of the key + if (queue_obj := self.__queues.get((channel, queue))) is None: async with AsyncExitStack() as stack: if channel is None: channel = await stack.enter_async_context( self.__connection_manager.acquire_channel() ) + if (channel, queue) in self.__queues: + return self.__queues[(channel, queue)] - self.__queues[queue] = queue_obj = cast( + self.__queues[(channel, queue)] = queue_obj = cast( "aio_pika.RobustQueue", await channel.declare_queue( name=queue.name, @@ -59,7 +68,9 @@ async def declare_exchange( channel: Optional["aio_pika.RobustChannel"] = None, ) -> "aio_pika.RobustExchange": """Declare an exchange, parent exchanges and bind them each other.""" - if exch := self.__exchanges.get(exchange): + # NOTE: It would return the queue linked to another channel if it was already declared + # unless the channel is part of the key + if exch := self.__exchanges.get((channel, exchange)): return exch async with AsyncExitStack() as stack: @@ -67,12 +78,14 @@ async def declare_exchange( channel = await stack.enter_async_context( self.__connection_manager.acquire_channel() ) + if (channel, exchange) in self.__exchanges: + return self.__exchanges[(channel, exchange)] if not exchange.name: return channel.default_exchange else: - self.__exchanges[exchange] = exch = cast( + self.__exchanges[(channel, exchange)] = exch = cast( "aio_pika.RobustExchange", await channel.declare_exchange( name=exchange.name, @@ -102,3 +115,22 @@ async def declare_exchange( ) return exch # type: ignore[return-value] + + @asynccontextmanager + async def declare_queue_scope( + self, + queue: 
"RabbitQueue", + passive: bool = False, + *, + channel: Optional["aio_pika.RobustChannel"] = None, + ) -> AsyncGenerator["aio_pika.RobustQueue", None]: + """Declare a queue and return it with a context manager.""" + async with AsyncExitStack() as stack: + if channel is None: + channel = await stack.enter_async_context( + self.__connection_manager.acquire_channel() + ) + + yield await self.declare_queue( + queue=queue, passive=passive, channel=channel + ) diff --git a/faststream/rabbit/publisher/producer.py b/faststream/rabbit/publisher/producer.py index 38092c7d11..5f3976e98a 100644 --- a/faststream/rabbit/publisher/producer.py +++ b/faststream/rabbit/publisher/producer.py @@ -1,18 +1,17 @@ -from asyncio import Future -from contextvars import ContextVar +from contextlib import asynccontextmanager from typing import ( TYPE_CHECKING, Any, AsyncContextManager, - ClassVar, - Dict, + AsyncGenerator, Optional, - Type, + Tuple, Union, cast, ) import anyio +from aio_pika.abc import AbstractIncomingMessage from typing_extensions import override from faststream.broker.publisher.proto import ProducerProto @@ -21,14 +20,16 @@ from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.schemas import RABBIT_REPLY, RabbitExchange from faststream.utils.classes import Singleton -from faststream.utils.functions import fake_context, timeout_scope +from faststream.utils.functions import ( + fake_context_yielding, + timeout_scope, +) if TYPE_CHECKING: - from types import TracebackType - import aiormq - from aio_pika import IncomingMessage, RobustChannel, RobustQueue + from aio_pika import IncomingMessage, RobustChannel from aio_pika.abc import DateType, HeadersType, TimeoutType + from anyio.streams.memory import MemoryObjectReceiveStream from faststream.broker.types import ( AsyncCallable, @@ -87,23 +88,26 @@ async def publish( # type: ignore[override] app_id: Optional[str] = None, ) -> Optional[Any]: """Publish a message to a RabbitMQ queue.""" - context: AsyncContextManager[Optional[RobustChannel]] + context: AsyncContextManager[ + Union[ + Tuple[MemoryObjectReceiveStream[IncomingMessage], RobustChannel], + Tuple[None, None], + ] + ] channel: Optional[RobustChannel] - response: Optional[Future[IncomingMessage]] + response_queue: Optional[MemoryObjectReceiveStream[IncomingMessage]] if rpc: if reply_to is not None: raise WRONG_PUBLISH_ARGS - context = await self.rpc_manager(correlation_id=correlation_id) - response = self.rpc_manager.result - reply_to = self.rpc_manager.queue.name + context = self.rpc_manager() + reply_to = RABBIT_REPLY.name else: - response = None - context = fake_context() + context = fake_context_yielding(with_yield=(None, None)) - async with context as channel: + async with context as (response_queue, channel): r = await self._publish( message=message, exchange=exchange, @@ -127,13 +131,13 @@ async def publish( # type: ignore[override] channel=channel, ) - if response is None: + if response_queue is None: return r else: msg: Optional[IncomingMessage] = None with timeout_scope(rpc_timeout, raise_timeout): - msg = await response + msg = await response_queue.receive() if msg: # pragma: no branch return await self._decoder(await self._parser(msg)) @@ -197,72 +201,45 @@ async def _publish( ) -class _RPCManager(Singleton, AsyncContextManager["RobustChannel"]): - """A class provides an RPC lock.""" - - # Singleton entities (all async tasks share the same context) - __lock = anyio.Lock() - __rpc_messages: ClassVar[Dict[str, Future]] = {} # type: ignore[type-arg] - __consumer_tags: 
ClassVar[Dict["RobustChannel", str]] = {} - - # Context variables (async tasks have their own context) - __current_correlation_id: ContextVar[str] = ContextVar("rpc_correlation_id") - __current_queue: ContextVar["RobustQueue"] = ContextVar("rpc_queue") - __current_result: ContextVar[Future] = ContextVar("rpc_result") # type: ignore[type-arg] +class _RPCManager(Singleton): + """A class that provides an RPC lock.""" def __init__(self, declarer: "RabbitDeclarer") -> None: self.declarer = declarer - @property - def queue(self) -> "RobustQueue": - return self.__current_queue.get() - - @property - def correlation_id(self) -> str: - return self.__current_correlation_id.get() - - @property - def result(self) -> Future: # type: ignore[type-arg] - return self.__current_result.get() - - async def __rpc_callback(self, msg: "IncomingMessage") -> None: - """A callback function to handle RPC messages.""" - if msg.correlation_id in self.__rpc_messages: - self.__rpc_messages[msg.correlation_id].set_result(msg) - - async def __call__(self, correlation_id: str) -> "_RPCManager": - """Sets the current RPC context.""" - async with self.__lock: - if correlation_id in self.__rpc_messages: - raise RuntimeError("The correlation ID is already in use.") - self.__current_result.set(Future()) - self.__rpc_messages[correlation_id] = self.result - - self.__current_queue.set(await self.declarer.declare_queue(RABBIT_REPLY)) - self.__current_correlation_id.set(correlation_id) - - return self - - async def __aenter__(self) -> "RobustChannel": - async with self.__lock: - if self.queue.channel not in self.__consumer_tags: - consumer_tag = await self.queue.consume( - callback=self.__rpc_callback, # type: ignore[arg-type] + @asynccontextmanager + async def __call__( + self, + ) -> AsyncGenerator[ + Tuple[ + "MemoryObjectReceiveStream[IncomingMessage]", + "RobustChannel", + ], + None, + ]: + # NOTE: this allows us to make sure the channel is only used by a single + # RPC call at a time, however, if the channel pool is used for both consuming + # and producing, they will be blocked by each other + async with self.declarer.declare_queue_scope(RABBIT_REPLY) as queue: + consumer_tag = None + try: + ( + send_response_stream, + receive_response_stream, + ) = anyio.create_memory_object_stream[AbstractIncomingMessage]( + max_buffer_size=1 + ) + consumer_tag = await queue.consume( + callback=send_response_stream.send, # type: ignore[arg-type] no_ack=True, ) - self.__consumer_tags[self.queue.channel] = consumer_tag # type: ignore[index] - - return cast("RobustChannel", self.queue.channel) - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]] = None, - exc_val: Optional[BaseException] = None, - exc_tb: Optional["TracebackType"] = None, - ) -> None: - self.__rpc_messages.pop(self.__current_correlation_id.get()) # type: ignore[call-overload] - if exc_tb: - async with self.__lock: - if self.queue.channel in self.__consumer_tags: - await self.queue.cancel(self.__consumer_tags[self.queue.channel]) # type: ignore[index] - self.__consumer_tags.pop(self.queue.channel) # type: ignore[call-overload] + yield ( + cast( + "MemoryObjectReceiveStream[IncomingMessage]", + receive_response_stream, + ), + cast("RobustChannel", queue.channel), + ) + finally: + if consumer_tag is not None: + await queue.cancel(consumer_tag) # type: ignore[index] diff --git a/faststream/utils/functions.py b/faststream/utils/functions.py index 81b1b06db9..81ff8ce02e 100644 --- a/faststream/utils/functions.py +++ b/faststream/utils/functions.py @@ -70,6 
+70,15 @@ async def fake_context(*args: Any, **kwargs: Any) -> AsyncIterator[None]: yield None +@asynccontextmanager +async def fake_context_yielding( + *args: Any, + with_yield: F_Return = None, # type: ignore[assignment] + **kwargs: Any, +) -> AsyncIterator[F_Return]: + yield with_yield + + @contextmanager def sync_fake_context(*args: Any, **kwargs: Any) -> Iterator[None]: yield None diff --git a/tests/brokers/rabbit/test_rpc.py b/tests/brokers/rabbit/test_rpc.py index 709ecdb3c4..a0e7cebd09 100644 --- a/tests/brokers/rabbit/test_rpc.py +++ b/tests/brokers/rabbit/test_rpc.py @@ -1,5 +1,4 @@ import asyncio -import uuid import anyio import pytest @@ -11,22 +10,26 @@ @pytest.mark.rabbit() class TestRPC(BrokerRPCTestcase, ReplyAndConsumeForbidden): - def get_broker(self, apply_types: bool = False) -> RabbitBroker: - return RabbitBroker(apply_types=apply_types) + def get_broker( + self, apply_types: bool = False, max_channel_pool_size: int = 2 + ) -> RabbitBroker: + return RabbitBroker( + apply_types=apply_types, max_channel_pool_size=max_channel_pool_size + ) @pytest.mark.asyncio() async def test_rpc_with_concurrency(self, queue: str): - rpc_broker = self.get_broker() + rpc_broker = self.get_broker(max_channel_pool_size=20) @rpc_broker.subscriber(queue) async def m(m): # pragma: no cover - await asyncio.sleep(1) + await asyncio.sleep(0.1) return m async with self.patch_broker(rpc_broker) as br: await br.start() - with anyio.fail_after(3): + with anyio.fail_after(1): results = await asyncio.gather( *[ br.publish( @@ -41,53 +44,77 @@ async def m(m): # pragma: no cover for i, r in enumerate(results): assert r == f"hello {i}" - -class TestRPCManager: @pytest.mark.asyncio() - async def test_context_variables_per_concurrent_task(self): - rpc_broker = RabbitBroker() - rpc_manager = _RPCManager(declarer=rpc_broker.declarer) - results = set() - correlation_ids = set() - channels = set() + async def test_rpc_with_concurrency_equal_consumers_channels(self, queue: str): + rpc_broker = self.get_broker(max_channel_pool_size=9) - async def run_operation(): - context = await rpc_manager(correlation_id=uuid.uuid4().hex) - async with context: - results.add(rpc_manager.result) - correlation_ids.add(rpc_manager.correlation_id) - channels.add(rpc_manager.queue.channel) - - await rpc_broker.start() - await asyncio.gather(*[run_operation() for _ in range(10)]) - assert len(results) == 10 - assert len(correlation_ids) == 10 - assert len(channels) == 1 + @rpc_broker.subscriber(queue) + async def m(m): # pragma: no cover + await asyncio.sleep(0.1) + return m + + async with self.patch_broker(rpc_broker) as br: + await br.start() + + with anyio.fail_after(1): + results = await asyncio.gather( + *[ + br.publish( + f"hello {i}", + queue, + rpc=True, + ) + for i in range(10) + ] + ) + + for i, r in enumerate(results): + assert r == f"hello {i}" @pytest.mark.asyncio() - async def test_one_queue_per_channel(self): - rpc_broker = RabbitBroker(max_channel_pool_size=10) - rpc_manager = _RPCManager(declarer=rpc_broker.declarer) - channels = set() + async def test_rpc_recovers_after_timeout(self, queue: str): + rpc_broker = self.get_broker() - async def run_operation(): - context = await rpc_manager(correlation_id=uuid.uuid4().hex) - async with context: - channels.add(rpc_manager.queue) + @rpc_broker.subscriber(queue) + async def m(m): # pragma: no cover + await anyio.sleep(0.1) + return m + + async with self.patch_broker(rpc_broker) as br: + await br.start() + + with pytest.raises(TimeoutError): # pragma: no branch + await 
br.publish( + "hello", + queue, + rpc=True, + rpc_timeout=0, + raise_timeout=True, + ) + assert ( + await br.publish( + "hello", + queue, + rpc=True, + ) + ) == "hello" - await rpc_broker.start() - await asyncio.gather(*[run_operation() for _ in range(10)]) - assert len(channels) == 10 +class TestRPCManager: @pytest.mark.asyncio() - async def test_clean_up_after_exception(self): - rpc_broker = RabbitBroker() + async def test_context_variables_per_concurrent_task(self): + rpc_broker = RabbitBroker(max_channel_pool_size=10) rpc_manager = _RPCManager(declarer=rpc_broker.declarer) + receive_streams = set() + channels = set() - await rpc_broker.start() - with pytest.raises(ValueError): # noqa: PT011 - async with await rpc_manager(correlation_id=uuid.uuid4().hex): - raise ValueError("test") - - assert len(rpc_manager._RPCManager__rpc_messages) == 0 - assert not rpc_manager.queue._consumers + async def run_operation(): + async with rpc_manager() as (receive_stream, channel): + receive_streams.add(receive_stream) + channels.add(channel) + await asyncio.sleep(0.1) + + async with rpc_broker: + await asyncio.gather(*[run_operation() for _ in range(10)]) + assert len(receive_streams) == 10 + assert len(channels) == 10
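
The concurrency behavior these tests exercise can also be driven from plain application code. Below is a minimal sketch assuming a local RabbitMQ at the default URL; the queue name `rpc-demo`, the handler, and the payloads are illustrative, while `max_channel_pool_size`, `rpc=True`, and the `async with broker` / `broker.start()` flow mirror what the tests in this patch series use.

```python
import asyncio

from faststream.rabbit import RabbitBroker

# A larger channel pool, as in the concurrency test, so that several in-flight
# RPC calls do not have to share a channel. URL and queue name are placeholders.
broker = RabbitBroker("amqp://guest:guest@localhost:5672/", max_channel_pool_size=20)


@broker.subscriber("rpc-demo")
async def echo(msg: str) -> str:
    await asyncio.sleep(0.1)  # simulate some work before replying
    return msg


async def main() -> None:
    async with broker:
        await broker.start()
        # Ten RPC requests in flight at once; each call waits only for its own reply.
        replies = await asyncio.gather(
            *[broker.publish(f"hello {i}", "rpc-demo", rpc=True) for i in range(10)]
        )
        assert replies == [f"hello {i}" for i in range(10)]


if __name__ == "__main__":
    asyncio.run(main())
```

Per the note added to `broker.py` in the second patch, channels are shared between consumers and producers, so with a small pool concurrent RPC calls and subscribers can end up blocking one another on the same channel; that is why the tests raise `max_channel_pool_size` when firing many RPC calls at once.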
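For background, `_RPCManager` in the second patch wraps the classic AMQP request/reply shape: take a channel-scoped reply queue, consume its messages into an in-memory stream, publish with `reply_to` set, and await the response. The sketch below shows that same shape in plain `aio_pika`, not FastStream's implementation: it declares a throwaway exclusive queue instead of the predeclared `RABBIT_REPLY` queue, and uses an `asyncio.Future` where FastStream uses an anyio memory stream. The URL, routing key, timeout, and function name are assumptions for illustration only.

```python
import asyncio
import uuid

import aio_pika
from aio_pika.abc import AbstractIncomingMessage


async def rpc_call(payload: bytes, routing_key: str = "rpc_queue") -> bytes:
    connection = await aio_pika.connect_robust("amqp://guest:guest@localhost/")
    async with connection:
        channel = await connection.channel()

        # A reply queue tied to this channel, playing the role of declare_queue_scope().
        callback_queue = await channel.declare_queue(exclusive=True, auto_delete=True)

        loop = asyncio.get_running_loop()
        future: "asyncio.Future[AbstractIncomingMessage]" = loop.create_future()

        async def on_reply(message: AbstractIncomingMessage) -> None:
            if not future.done():
                future.set_result(message)

        consumer_tag = await callback_queue.consume(on_reply, no_ack=True)
        try:
            await channel.default_exchange.publish(
                aio_pika.Message(
                    body=payload,
                    correlation_id=uuid.uuid4().hex,
                    reply_to=callback_queue.name,
                ),
                routing_key=routing_key,
            )
            reply = await asyncio.wait_for(future, timeout=30)
            return reply.body
        finally:
            # Mirror _RPCManager's cleanup: always cancel the temporary consumer.
            await callback_queue.cancel(consumer_tag)
```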