Skip to content

Commit

Permalink
Merge pull request #1401 from Focadecombate/975-use-aio-pika-pool
Browse files Browse the repository at this point in the history
975 use aio pika pool
  • Loading branch information
Lancetnik authored Jun 12, 2024
2 parents bb6997f + 6257cfa commit d12abef
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 75 deletions.
115 changes: 71 additions & 44 deletions faststream/rabbit/broker/broker.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
import logging
from inspect import Parameter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Optional,
Type,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Type, Union, cast
from urllib.parse import urlparse

from aio_pika import connect_robust
from aio_pika.pool import Pool
from typing_extensions import Annotated, Doc, override

from faststream.__about__ import SERVICE_NAME
Expand All @@ -22,11 +14,7 @@
from faststream.rabbit.broker.registrator import RabbitRegistrator
from faststream.rabbit.helpers.declarer import RabbitDeclarer
from faststream.rabbit.publisher.producer import AioPikaFastProducer
from faststream.rabbit.schemas import (
RABBIT_REPLY,
RabbitExchange,
RabbitQueue,
)
from faststream.rabbit.schemas import RABBIT_REPLY, RabbitExchange, RabbitQueue
from faststream.rabbit.security import parse_security
from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber
from faststream.rabbit.utils import build_url
Expand All @@ -48,15 +36,35 @@
from yarl import URL

from faststream.asyncapi import schema as asyncapi
from faststream.broker.types import (
BrokerMiddleware,
CustomCallable,
)
from faststream.broker.types import BrokerMiddleware, CustomCallable
from faststream.rabbit.types import AioPikaSendableMessage
from faststream.security import BaseSecurity
from faststream.types import AnyDict, Decorator, LoggerProto


async def get_connection(
url: str,
timeout: "TimeoutType",
ssl_context: Optional["SSLContext"],
) -> "RobustConnection":
return cast(
"RobustConnection",
await connect_robust(
url,
timeout=timeout,
ssl_context=ssl_context,
),
)


async def get_channel(connection_pool: "Pool[RobustConnection]") -> "RobustChannel":
async with connection_pool.acquire() as connection:
return cast(
"RobustChannel",
await connection.channel(),
)


class RabbitBroker(
RabbitRegistrator,
RabbitLoggingBroker,
Expand All @@ -67,7 +75,8 @@ class RabbitBroker(
_producer: Optional["AioPikaFastProducer"]

declarer: Optional[RabbitDeclarer]
_channel: Optional["RobustChannel"]
_channel_pool: Optional["Pool[RobustChannel]"]
_connection_pool: Optional["Pool[RobustConnection]"]

def __init__(
self,
Expand Down Expand Up @@ -213,6 +222,14 @@ def __init__(
Iterable["Decorator"],
Doc("Any custom decorator to apply to wrapped functions."),
] = (),
max_connection_pool_size: Annotated[
int,
Doc("Max connection pool size"),
] = 1,
max_channel_pool_size: Annotated[
int,
Doc("Max channel pool size"),
] = 1,
) -> None:
security_args = parse_security(security)

Expand All @@ -234,6 +251,8 @@ def __init__(
# respect ascynapi_url argument scheme
builded_asyncapi_url = urlparse(asyncapi_url)
self.virtual_host = builded_asyncapi_url.path
self.max_connection_pool_size = max_connection_pool_size
self.max_channel_pool_size = max_channel_pool_size
if protocol is None:
protocol = builded_asyncapi_url.scheme

Expand Down Expand Up @@ -274,6 +293,8 @@ def __init__(
self.app_id = app_id

self._channel = None
self._channel_pool = None
self._connection_pool = None
self.declarer = None

@property
Expand Down Expand Up @@ -406,26 +427,31 @@ async def _connect( # type: ignore[override]
publisher_confirms: bool,
on_return_raises: bool,
) -> "RobustConnection":
connection = cast(
"RobustConnection",
await connect_robust(
url,
timeout=timeout,
ssl_context=ssl_context,
),
)
if self._connection_pool is None:
self._connection_pool = Pool(
lambda: get_connection(
url=url,
timeout=timeout,
ssl_context=ssl_context,
),
max_size=self.max_connection_pool_size,
)

if self._channel is None: # pragma: no branch
max_consumers = self._max_consumers
channel = self._channel = cast(
"RobustChannel",
await connection.channel(
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
if self._channel_pool is None:
assert self._connection_pool is not None
self._channel_pool = cast(
Pool["RobustChannel"],
Pool(
lambda: get_channel(
cast("Pool[RobustConnection]", self._connection_pool)
),
max_size=self.max_channel_pool_size,
),
)

async with self._channel_pool.acquire() as channel:
max_consumers = self._max_consumers

declarer = self.declarer = RabbitDeclarer(channel)
await declarer.declare_queue(RABBIT_REPLY)

Expand All @@ -444,26 +470,27 @@ async def _connect( # type: ignore[override]
self._log(f"Set max consumers to {max_consumers}", extra=c)
await channel.set_qos(prefetch_count=int(max_consumers))

return connection
return cast("RobustConnection", channel._connection)

async def _close(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional["TracebackType"] = None,
) -> None:
if self._channel is not None:
if not self._channel.is_closed:
await self._channel.close()
if self._channel_pool is not None:
if not self._channel_pool.is_closed:
await self._channel_pool.close()
self._channel_pool = None

self._channel = None
if self._connection_pool is not None:
if not self._connection_pool.is_closed:
await self._connection_pool.close()
self._connection_pool = None

self.declarer = None
self._producer = None

if self._connection is not None:
await self._connection.close()

await super()._close(exc_type, exc_val, exc_tb)

async def start(self) -> None:
Expand Down
40 changes: 16 additions & 24 deletions faststream/rabbit/fastapi/router.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,7 @@
import logging
from inspect import Parameter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Type,
Union,
cast,
)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
Optional, Sequence, Type, Union, cast)

from fastapi.datastructures import Default
from fastapi.routing import APIRoute
Expand All @@ -26,10 +15,7 @@
from faststream.broker.utils import default_filter
from faststream.rabbit.broker.broker import RabbitBroker as RB
from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher
from faststream.rabbit.schemas import (
RabbitExchange,
RabbitQueue,
)
from faststream.rabbit.schemas import RabbitExchange, RabbitQueue
from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber

if TYPE_CHECKING:
Expand All @@ -45,13 +31,9 @@
from yarl import URL

from faststream.asyncapi import schema as asyncapi
from faststream.broker.types import (
BrokerMiddleware,
CustomCallable,
Filter,
PublisherMiddleware,
SubscriberMiddleware,
)
from faststream.broker.types import (BrokerMiddleware, CustomCallable,
Filter, PublisherMiddleware,
SubscriberMiddleware)
from faststream.rabbit.message import RabbitMessage
from faststream.rabbit.schemas.reply import ReplyConfig
from faststream.security import BaseSecurity
Expand Down Expand Up @@ -414,6 +396,14 @@ def __init__(
"""
),
] = Default(generate_unique_id),
max_connection_pool_size: Annotated[
int,
Doc("Max connection pool size"),
] = 1,
max_channel_pool_size: Annotated[
int,
Doc("Max channel pool size"),
] = 1,
) -> None:
super().__init__(
url,
Expand All @@ -424,6 +414,8 @@ def __init__(
client_properties=client_properties,
timeout=timeout,
max_consumers=max_consumers,
max_connection_pool_size=max_connection_pool_size,
max_channel_pool_size=max_channel_pool_size,
app_id=app_id,
graceful_timeout=graceful_timeout,
decoder=decoder,
Expand Down
9 changes: 3 additions & 6 deletions faststream/rabbit/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import aiormq
from aio_pika.message import IncomingMessage
from aio_pika.pool import Pool
from pamqp import commands as spec
from pamqp.header import ContentHeader
from typing_extensions import override
Expand All @@ -13,11 +14,7 @@
from faststream.rabbit.parser import AioPikaParser
from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher
from faststream.rabbit.publisher.producer import AioPikaFastProducer
from faststream.rabbit.schemas import (
ExchangeType,
RabbitExchange,
RabbitQueue,
)
from faststream.rabbit.schemas import ExchangeType, RabbitExchange, RabbitQueue
from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber
from faststream.testing.broker import TestBroker, call_handler

Expand All @@ -35,7 +32,7 @@ class TestRabbitBroker(TestBroker[RabbitBroker]):

@classmethod
def _patch_test_broker(cls, broker: RabbitBroker) -> None:
broker._channel = AsyncMock()
broker._channel_pool = AsyncMock(Pool)
broker.declarer = AsyncMock()
super()._patch_test_broker(broker)

Expand Down
4 changes: 3 additions & 1 deletion tests/brokers/rabbit/specific/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@
async def test_set_max():
broker = RabbitBroker(logger=None, max_consumers=10)
await broker.start()
assert broker._channel._prefetch_count == 10
assert broker._channel_pool
async with broker._channel_pool.acquire() as channel:
assert channel._prefetch_count == 10
await broker.close()

0 comments on commit d12abef

Please sign in to comment.