diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index 6cb357fef7..b6e3f09a4d 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -1,18 +1,10 @@ import logging from inspect import Parameter -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - Optional, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Type, Union, cast from urllib.parse import urlparse from aio_pika import connect_robust +from aio_pika.pool import Pool from typing_extensions import Annotated, Doc, override from faststream.__about__ import SERVICE_NAME @@ -22,11 +14,7 @@ from faststream.rabbit.broker.registrator import RabbitRegistrator from faststream.rabbit.helpers.declarer import RabbitDeclarer from faststream.rabbit.publisher.producer import AioPikaFastProducer -from faststream.rabbit.schemas import ( - RABBIT_REPLY, - RabbitExchange, - RabbitQueue, -) +from faststream.rabbit.schemas import RABBIT_REPLY, RabbitExchange, RabbitQueue from faststream.rabbit.security import parse_security from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber from faststream.rabbit.utils import build_url @@ -48,15 +36,35 @@ from yarl import URL from faststream.asyncapi import schema as asyncapi - from faststream.broker.types import ( - BrokerMiddleware, - CustomCallable, - ) + from faststream.broker.types import BrokerMiddleware, CustomCallable from faststream.rabbit.types import AioPikaSendableMessage from faststream.security import BaseSecurity from faststream.types import AnyDict, Decorator, LoggerProto +async def get_connection( + url: str, + timeout: "TimeoutType", + ssl_context: Optional["SSLContext"], +) -> "RobustConnection": + return cast( + "RobustConnection", + await connect_robust( + url, + timeout=timeout, + ssl_context=ssl_context, + ), + ) + + +async def get_channel(connection_pool: "Pool[RobustConnection]") -> "RobustChannel": + async with connection_pool.acquire() as connection: + return cast( + "RobustChannel", + await connection.channel(), + ) + + class RabbitBroker( RabbitRegistrator, RabbitLoggingBroker, @@ -67,7 +75,8 @@ class RabbitBroker( _producer: Optional["AioPikaFastProducer"] declarer: Optional[RabbitDeclarer] - _channel: Optional["RobustChannel"] + _channel_pool: Optional["Pool[RobustChannel]"] + _connection_pool: Optional["Pool[RobustConnection]"] def __init__( self, @@ -213,6 +222,14 @@ def __init__( Iterable["Decorator"], Doc("Any custom decorator to apply to wrapped functions."), ] = (), + max_connection_pool_size: Annotated[ + int, + Doc("Max connection pool size"), + ] = 1, + max_channel_pool_size: Annotated[ + int, + Doc("Max channel pool size"), + ] = 1, ) -> None: security_args = parse_security(security) @@ -234,6 +251,8 @@ def __init__( # respect ascynapi_url argument scheme builded_asyncapi_url = urlparse(asyncapi_url) self.virtual_host = builded_asyncapi_url.path + self.max_connection_pool_size = max_connection_pool_size + self.max_channel_pool_size = max_channel_pool_size if protocol is None: protocol = builded_asyncapi_url.scheme @@ -274,6 +293,8 @@ def __init__( self.app_id = app_id self._channel = None + self._channel_pool = None + self._connection_pool = None self.declarer = None @property @@ -406,26 +427,31 @@ async def _connect( # type: ignore[override] publisher_confirms: bool, on_return_raises: bool, ) -> "RobustConnection": - connection = cast( - "RobustConnection", - await connect_robust( - url, - timeout=timeout, - ssl_context=ssl_context, - ), - ) + if self._connection_pool is None: + self._connection_pool = Pool( + lambda: get_connection( + url=url, + timeout=timeout, + ssl_context=ssl_context, + ), + max_size=self.max_connection_pool_size, + ) - if self._channel is None: # pragma: no branch - max_consumers = self._max_consumers - channel = self._channel = cast( - "RobustChannel", - await connection.channel( - channel_number=channel_number, - publisher_confirms=publisher_confirms, - on_return_raises=on_return_raises, + if self._channel_pool is None: + assert self._connection_pool is not None + self._channel_pool = cast( + Pool["RobustChannel"], + Pool( + lambda: get_channel( + cast("Pool[RobustConnection]", self._connection_pool) + ), + max_size=self.max_channel_pool_size, ), ) + async with self._channel_pool.acquire() as channel: + max_consumers = self._max_consumers + declarer = self.declarer = RabbitDeclarer(channel) await declarer.declare_queue(RABBIT_REPLY) @@ -444,7 +470,7 @@ async def _connect( # type: ignore[override] self._log(f"Set max consumers to {max_consumers}", extra=c) await channel.set_qos(prefetch_count=int(max_consumers)) - return connection + return cast("RobustConnection", channel._connection) async def _close( self, @@ -452,18 +478,19 @@ async def _close( exc_val: Optional[BaseException] = None, exc_tb: Optional["TracebackType"] = None, ) -> None: - if self._channel is not None: - if not self._channel.is_closed: - await self._channel.close() + if self._channel_pool is not None: + if not self._channel_pool.is_closed: + await self._channel_pool.close() + self._channel_pool = None - self._channel = None + if self._connection_pool is not None: + if not self._connection_pool.is_closed: + await self._connection_pool.close() + self._connection_pool = None self.declarer = None self._producer = None - if self._connection is not None: - await self._connection.close() - await super()._close(exc_type, exc_val, exc_tb) async def start(self) -> None: diff --git a/faststream/rabbit/fastapi/router.py b/faststream/rabbit/fastapi/router.py index d0445badfb..dad339387b 100644 --- a/faststream/rabbit/fastapi/router.py +++ b/faststream/rabbit/fastapi/router.py @@ -1,18 +1,7 @@ import logging from inspect import Parameter -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Sequence, - Type, - Union, - cast, -) +from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List, + Optional, Sequence, Type, Union, cast) from fastapi.datastructures import Default from fastapi.routing import APIRoute @@ -26,10 +15,7 @@ from faststream.broker.utils import default_filter from faststream.rabbit.broker.broker import RabbitBroker as RB from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher -from faststream.rabbit.schemas import ( - RabbitExchange, - RabbitQueue, -) +from faststream.rabbit.schemas import RabbitExchange, RabbitQueue from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber if TYPE_CHECKING: @@ -45,13 +31,9 @@ from yarl import URL from faststream.asyncapi import schema as asyncapi - from faststream.broker.types import ( - BrokerMiddleware, - CustomCallable, - Filter, - PublisherMiddleware, - SubscriberMiddleware, - ) + from faststream.broker.types import (BrokerMiddleware, CustomCallable, + Filter, PublisherMiddleware, + SubscriberMiddleware) from faststream.rabbit.message import RabbitMessage from faststream.rabbit.schemas.reply import ReplyConfig from faststream.security import BaseSecurity @@ -414,6 +396,14 @@ def __init__( """ ), ] = Default(generate_unique_id), + max_connection_pool_size: Annotated[ + int, + Doc("Max connection pool size"), + ] = 1, + max_channel_pool_size: Annotated[ + int, + Doc("Max channel pool size"), + ] = 1, ) -> None: super().__init__( url, @@ -424,6 +414,8 @@ def __init__( client_properties=client_properties, timeout=timeout, max_consumers=max_consumers, + max_connection_pool_size=max_connection_pool_size, + max_channel_pool_size=max_channel_pool_size, app_id=app_id, graceful_timeout=graceful_timeout, decoder=decoder, diff --git a/faststream/rabbit/testing.py b/faststream/rabbit/testing.py index 3d3a274418..db5cb9d5d4 100644 --- a/faststream/rabbit/testing.py +++ b/faststream/rabbit/testing.py @@ -3,6 +3,7 @@ import aiormq from aio_pika.message import IncomingMessage +from aio_pika.pool import Pool from pamqp import commands as spec from pamqp.header import ContentHeader from typing_extensions import override @@ -13,11 +14,7 @@ from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher from faststream.rabbit.publisher.producer import AioPikaFastProducer -from faststream.rabbit.schemas import ( - ExchangeType, - RabbitExchange, - RabbitQueue, -) +from faststream.rabbit.schemas import ExchangeType, RabbitExchange, RabbitQueue from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber from faststream.testing.broker import TestBroker, call_handler @@ -35,7 +32,7 @@ class TestRabbitBroker(TestBroker[RabbitBroker]): @classmethod def _patch_test_broker(cls, broker: RabbitBroker) -> None: - broker._channel = AsyncMock() + broker._channel_pool = AsyncMock(Pool) broker.declarer = AsyncMock() super()._patch_test_broker(broker) diff --git a/tests/brokers/rabbit/specific/test_init.py b/tests/brokers/rabbit/specific/test_init.py index e87b71b466..38d156a7e9 100644 --- a/tests/brokers/rabbit/specific/test_init.py +++ b/tests/brokers/rabbit/specific/test_init.py @@ -8,5 +8,7 @@ async def test_set_max(): broker = RabbitBroker(logger=None, max_consumers=10) await broker.start() - assert broker._channel._prefetch_count == 10 + assert broker._channel_pool + async with broker._channel_pool.acquire() as channel: + assert channel._prefetch_count == 10 await broker.close()