From de79a6572da454335a0182b77c90c36fbb69c3ba Mon Sep 17 00:00:00 2001 From: Gustavo Santos Date: Thu, 25 Apr 2024 16:08:23 -0300 Subject: [PATCH 1/5] feat: use aio pika pool for connections and channels --- faststream/rabbit/broker/broker.py | 110 ++++++++++++++++++----------- faststream/rabbit/testing.py | 9 +-- 2 files changed, 72 insertions(+), 47 deletions(-) diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index fd4ca30d84..efbecf736a 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -1,18 +1,10 @@ import logging from inspect import Parameter -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - Optional, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Type, Union, cast from urllib.parse import urlparse from aio_pika import connect_robust +from aio_pika.pool import Pool from typing_extensions import Annotated, Doc, override from faststream.__about__ import SERVICE_NAME @@ -21,11 +13,7 @@ from faststream.rabbit.broker.logging import RabbitLoggingBroker from faststream.rabbit.broker.registrator import RabbitRegistrator from faststream.rabbit.publisher.producer import AioPikaFastProducer -from faststream.rabbit.schemas import ( - RABBIT_REPLY, - RabbitExchange, - RabbitQueue, -) +from faststream.rabbit.schemas import RABBIT_REPLY, RabbitExchange, RabbitQueue from faststream.rabbit.security import parse_security from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber from faststream.rabbit.utils import RabbitDeclarer, build_url @@ -47,15 +35,35 @@ from yarl import URL from faststream.asyncapi import schema as asyncapi - from faststream.broker.types import ( - BrokerMiddleware, - CustomCallable, - ) + from faststream.broker.types import BrokerMiddleware, CustomCallable from faststream.rabbit.types import AioPikaSendableMessage from faststream.security import BaseSecurity from faststream.types import AnyDict, Decorator, LoggerProto +async def get_connection( + url: str, + timeout: "TimeoutType", + ssl_context: Optional["SSLContext"], +) -> "RobustConnection": + return cast( + "RobustConnection", + await connect_robust( + url, + timeout=timeout, + ssl_context=ssl_context, + ), + ) + + +async def get_channel(connection_pool: "Pool[RobustConnection]") -> "RobustChannel": + async with connection_pool.acquire() as connection: + return cast( + "RobustChannel", + await connection.channel(), + ) + + class RabbitBroker( RabbitRegistrator, RabbitLoggingBroker, @@ -66,7 +74,8 @@ class RabbitBroker( _producer: Optional["AioPikaFastProducer"] declarer: Optional[RabbitDeclarer] - _channel: Optional["RobustChannel"] + _channel_pool: Optional["Pool[RobustChannel]"] + _connection_pool: Optional["Pool[RobustConnection]"] def __init__( self, @@ -249,6 +258,8 @@ def __init__( self.app_id = app_id self._channel = None + self._channel_pool = None + self._connection_pool = None self.declarer = None @property @@ -346,23 +357,34 @@ async def _connect( # type: ignore[override] *, timeout: "TimeoutType", ssl_context: Optional["SSLContext"], + max_connection_pool_size: int = 1, + max_channel_pool_size: int = 10, ) -> "RobustConnection": - connection = cast( - "RobustConnection", - await connect_robust( - url, - timeout=timeout, - ssl_context=ssl_context, - ), - ) + if self._connection_pool is None: + self._connection_pool = Pool( + lambda: get_connection( + url=url, + timeout=timeout, + ssl_context=ssl_context, + ), + max_size=max_connection_pool_size, + ) - if self._channel is None: # pragma: no branch - max_consumers = self._max_consumers - channel = self._channel = cast( - "RobustChannel", - await connection.channel(), + if self._channel_pool is None: + assert self._connection_pool is not None + self._channel_pool = cast( + Pool["RobustChannel"], + Pool( + lambda: get_channel( + cast("Pool[RobustConnection]", self._connection_pool) + ), + max_size=max_channel_pool_size, + ), ) + async with self._channel_pool.acquire() as channel: + max_consumers = self._max_consumers + declarer = self.declarer = RabbitDeclarer(channel) await declarer.declare_queue(RABBIT_REPLY) @@ -380,7 +402,7 @@ async def _connect( # type: ignore[override] self._log(f"Set max consumers to {max_consumers}", extra=c) await channel.set_qos(prefetch_count=int(max_consumers)) - return connection + return cast("RobustConnection", channel._connection) async def _close( self, @@ -388,18 +410,24 @@ async def _close( exc_val: Optional[BaseException] = None, exc_tb: Optional["TracebackType"] = None, ) -> None: - if self._channel is not None: - if not self._channel.is_closed: - await self._channel.close() + if self._channel_pool is not None: + if not self._channel_pool.is_closed: + await self._channel_pool.close() + self._channel_pool = None + + if self._connection is not None: + if not self._connection.is_closed: + await self._connection.close() + self._connection = None - self._channel = None + if self._connection_pool is not None: + if not self._connection_pool.is_closed: + await self._connection_pool.close() + self._connection_pool = None self.declarer = None self._producer = None - if self._connection is not None: - await self._connection.close() - await super()._close(exc_type, exc_val, exc_tb) async def start(self) -> None: diff --git a/faststream/rabbit/testing.py b/faststream/rabbit/testing.py index e425ed02d6..13a877407d 100644 --- a/faststream/rabbit/testing.py +++ b/faststream/rabbit/testing.py @@ -3,6 +3,7 @@ import aiormq from aio_pika.message import IncomingMessage +from aio_pika.pool import Pool from pamqp import commands as spec from pamqp.header import ContentHeader from typing_extensions import override @@ -12,11 +13,7 @@ from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher from faststream.rabbit.publisher.producer import AioPikaFastProducer -from faststream.rabbit.schemas import ( - ExchangeType, - RabbitExchange, - RabbitQueue, -) +from faststream.rabbit.schemas import ExchangeType, RabbitExchange, RabbitQueue from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber from faststream.testing.broker import TestBroker, call_handler @@ -34,7 +31,7 @@ class TestRabbitBroker(TestBroker[RabbitBroker]): @classmethod def _patch_test_broker(cls, broker: RabbitBroker) -> None: - broker._channel = AsyncMock() + broker._channel_pool = AsyncMock(Pool) broker.declarer = AsyncMock() super()._patch_test_broker(broker) From 38a1a15f1f0a6b84a8ea7620c77b756ec53354d0 Mon Sep 17 00:00:00 2001 From: Gustavo Santos Date: Thu, 25 Apr 2024 16:08:50 -0300 Subject: [PATCH 2/5] test: change case to use the pool instead of _channel --- tests/brokers/rabbit/specific/test_init.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/brokers/rabbit/specific/test_init.py b/tests/brokers/rabbit/specific/test_init.py index e87b71b466..38d156a7e9 100644 --- a/tests/brokers/rabbit/specific/test_init.py +++ b/tests/brokers/rabbit/specific/test_init.py @@ -8,5 +8,7 @@ async def test_set_max(): broker = RabbitBroker(logger=None, max_consumers=10) await broker.start() - assert broker._channel._prefetch_count == 10 + assert broker._channel_pool + async with broker._channel_pool.acquire() as channel: + assert channel._prefetch_count == 10 await broker.close() From 6b0fc2f3428482bcaf3d7b92b42200fc53f349f9 Mon Sep 17 00:00:00 2001 From: Gustavo Santos Date: Thu, 25 Apr 2024 16:20:25 -0300 Subject: [PATCH 3/5] chore: remove not needed _connection.close --- faststream/rabbit/broker/broker.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index efbecf736a..30f80ff73c 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -415,11 +415,6 @@ async def _close( await self._channel_pool.close() self._channel_pool = None - if self._connection is not None: - if not self._connection.is_closed: - await self._connection.close() - self._connection = None - if self._connection_pool is not None: if not self._connection_pool.is_closed: await self._connection_pool.close() From e193dbcd20a8b53158eb7a84a8b57f52f32af59e Mon Sep 17 00:00:00 2001 From: Gustavo Santos Date: Tue, 30 Apr 2024 17:12:26 -0300 Subject: [PATCH 4/5] chore: change default max channel pool to 1 --- faststream/rabbit/broker/broker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index 30f80ff73c..b5b5b47a52 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -358,7 +358,7 @@ async def _connect( # type: ignore[override] timeout: "TimeoutType", ssl_context: Optional["SSLContext"], max_connection_pool_size: int = 1, - max_channel_pool_size: int = 10, + max_channel_pool_size: int = 1, ) -> "RobustConnection": if self._connection_pool is None: self._connection_pool = Pool( From b688c37822b45a4b96c8b5ec2b01eeb2c32f9124 Mon Sep 17 00:00:00 2001 From: Gustavo Santos Date: Tue, 21 May 2024 16:30:46 -0300 Subject: [PATCH 5/5] feat: add arguments with docs --- faststream/rabbit/broker/broker.py | 16 +++++++++--- faststream/rabbit/fastapi/router.py | 40 ++++++++++++----------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index b5b5b47a52..6d1d8419e7 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -201,6 +201,14 @@ def __init__( Iterable["Decorator"], Doc("Any custom decorator to apply to wrapped functions."), ] = (), + max_connection_pool_size: Annotated[ + int, + Doc("Max connection pool size"), + ] = 1, + max_channel_pool_size: Annotated[ + int, + Doc("Max channel pool size"), + ] = 1, ) -> None: security_args = parse_security(security) @@ -222,6 +230,8 @@ def __init__( # respect ascynapi_url argument scheme builded_asyncapi_url = urlparse(asyncapi_url) self.virtual_host = builded_asyncapi_url.path + self.max_connection_pool_size = max_connection_pool_size + self.max_channel_pool_size = max_channel_pool_size if protocol is None: protocol = builded_asyncapi_url.scheme @@ -357,8 +367,6 @@ async def _connect( # type: ignore[override] *, timeout: "TimeoutType", ssl_context: Optional["SSLContext"], - max_connection_pool_size: int = 1, - max_channel_pool_size: int = 1, ) -> "RobustConnection": if self._connection_pool is None: self._connection_pool = Pool( @@ -367,7 +375,7 @@ async def _connect( # type: ignore[override] timeout=timeout, ssl_context=ssl_context, ), - max_size=max_connection_pool_size, + max_size=self.max_connection_pool_size, ) if self._channel_pool is None: @@ -378,7 +386,7 @@ async def _connect( # type: ignore[override] lambda: get_channel( cast("Pool[RobustConnection]", self._connection_pool) ), - max_size=max_channel_pool_size, + max_size=self.max_channel_pool_size, ), ) diff --git a/faststream/rabbit/fastapi/router.py b/faststream/rabbit/fastapi/router.py index 4cc90b25d9..c45a80aeba 100644 --- a/faststream/rabbit/fastapi/router.py +++ b/faststream/rabbit/fastapi/router.py @@ -1,18 +1,7 @@ import logging from inspect import Parameter -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Sequence, - Type, - Union, - cast, -) +from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List, + Optional, Sequence, Type, Union, cast) from fastapi.datastructures import Default from fastapi.routing import APIRoute @@ -26,10 +15,7 @@ from faststream.broker.utils import default_filter from faststream.rabbit.broker.broker import RabbitBroker as RB from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher -from faststream.rabbit.schemas import ( - RabbitExchange, - RabbitQueue, -) +from faststream.rabbit.schemas import RabbitExchange, RabbitQueue from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber if TYPE_CHECKING: @@ -45,13 +31,9 @@ from yarl import URL from faststream.asyncapi import schema as asyncapi - from faststream.broker.types import ( - BrokerMiddleware, - CustomCallable, - Filter, - PublisherMiddleware, - SubscriberMiddleware, - ) + from faststream.broker.types import (BrokerMiddleware, CustomCallable, + Filter, PublisherMiddleware, + SubscriberMiddleware) from faststream.rabbit.message import RabbitMessage from faststream.rabbit.schemas.reply import ReplyConfig from faststream.security import BaseSecurity @@ -394,6 +376,14 @@ def __init__( """ ), ] = Default(generate_unique_id), + max_connection_pool_size: Annotated[ + int, + Doc("Max connection pool size"), + ] = 1, + max_channel_pool_size: Annotated[ + int, + Doc("Max channel pool size"), + ] = 1, ) -> None: super().__init__( url, @@ -404,6 +394,8 @@ def __init__( client_properties=client_properties, timeout=timeout, max_consumers=max_consumers, + max_connection_pool_size=max_connection_pool_size, + max_channel_pool_size=max_channel_pool_size, app_id=app_id, graceful_timeout=graceful_timeout, decoder=decoder,