Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

975 use aio pika pool #1401

Merged
merged 6 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 71 additions & 44 deletions faststream/rabbit/broker/broker.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
import logging
from inspect import Parameter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Optional,
Type,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Type, Union, cast
from urllib.parse import urlparse

from aio_pika import connect_robust
from aio_pika.pool import Pool
from typing_extensions import Annotated, Doc, override

from faststream.__about__ import SERVICE_NAME
Expand All @@ -22,11 +14,7 @@
from faststream.rabbit.broker.registrator import RabbitRegistrator
from faststream.rabbit.helpers.declarer import RabbitDeclarer
from faststream.rabbit.publisher.producer import AioPikaFastProducer
from faststream.rabbit.schemas import (
RABBIT_REPLY,
RabbitExchange,
RabbitQueue,
)
from faststream.rabbit.schemas import RABBIT_REPLY, RabbitExchange, RabbitQueue
from faststream.rabbit.security import parse_security
from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber
from faststream.rabbit.utils import build_url
Expand All @@ -48,15 +36,35 @@
from yarl import URL

from faststream.asyncapi import schema as asyncapi
from faststream.broker.types import (
BrokerMiddleware,
CustomCallable,
)
from faststream.broker.types import BrokerMiddleware, CustomCallable
from faststream.rabbit.types import AioPikaSendableMessage
from faststream.security import BaseSecurity
from faststream.types import AnyDict, Decorator, LoggerProto


async def get_connection(
url: str,
timeout: "TimeoutType",
ssl_context: Optional["SSLContext"],
) -> "RobustConnection":
return cast(
"RobustConnection",
await connect_robust(
url,
timeout=timeout,
ssl_context=ssl_context,
),
)


async def get_channel(connection_pool: "Pool[RobustConnection]") -> "RobustChannel":
async with connection_pool.acquire() as connection:
return cast(
"RobustChannel",
await connection.channel(),
)


class RabbitBroker(
RabbitRegistrator,
RabbitLoggingBroker,
Expand All @@ -67,7 +75,8 @@ class RabbitBroker(
_producer: Optional["AioPikaFastProducer"]

declarer: Optional[RabbitDeclarer]
_channel: Optional["RobustChannel"]
_channel_pool: Optional["Pool[RobustChannel]"]
_connection_pool: Optional["Pool[RobustConnection]"]

def __init__(
self,
Expand Down Expand Up @@ -213,6 +222,14 @@ def __init__(
Iterable["Decorator"],
Doc("Any custom decorator to apply to wrapped functions."),
] = (),
max_connection_pool_size: Annotated[
int,
Doc("Max connection pool size"),
] = 1,
max_channel_pool_size: Annotated[
int,
Doc("Max channel pool size"),
] = 1,
) -> None:
security_args = parse_security(security)

Expand All @@ -234,6 +251,8 @@ def __init__(
# respect ascynapi_url argument scheme
builded_asyncapi_url = urlparse(asyncapi_url)
self.virtual_host = builded_asyncapi_url.path
self.max_connection_pool_size = max_connection_pool_size
self.max_channel_pool_size = max_channel_pool_size
if protocol is None:
protocol = builded_asyncapi_url.scheme

Expand Down Expand Up @@ -274,6 +293,8 @@ def __init__(
self.app_id = app_id

self._channel = None
self._channel_pool = None
self._connection_pool = None
self.declarer = None

@property
Expand Down Expand Up @@ -406,26 +427,31 @@ async def _connect( # type: ignore[override]
publisher_confirms: bool,
on_return_raises: bool,
) -> "RobustConnection":
connection = cast(
"RobustConnection",
await connect_robust(
url,
timeout=timeout,
ssl_context=ssl_context,
),
)
if self._connection_pool is None:
self._connection_pool = Pool(
lambda: get_connection(
url=url,
timeout=timeout,
ssl_context=ssl_context,
),
max_size=self.max_connection_pool_size,
)

if self._channel is None: # pragma: no branch
max_consumers = self._max_consumers
channel = self._channel = cast(
"RobustChannel",
await connection.channel(
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
if self._channel_pool is None:
assert self._connection_pool is not None
self._channel_pool = cast(
Pool["RobustChannel"],
Pool(
lambda: get_channel(
cast("Pool[RobustConnection]", self._connection_pool)
),
max_size=self.max_channel_pool_size,
),
)

async with self._channel_pool.acquire() as channel:
max_consumers = self._max_consumers

declarer = self.declarer = RabbitDeclarer(channel)
await declarer.declare_queue(RABBIT_REPLY)

Expand All @@ -444,26 +470,27 @@ async def _connect( # type: ignore[override]
self._log(f"Set max consumers to {max_consumers}", extra=c)
await channel.set_qos(prefetch_count=int(max_consumers))

return connection
return cast("RobustConnection", channel._connection)

async def _close(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional["TracebackType"] = None,
) -> None:
if self._channel is not None:
if not self._channel.is_closed:
await self._channel.close()
if self._channel_pool is not None:
if not self._channel_pool.is_closed:
await self._channel_pool.close()
self._channel_pool = None

self._channel = None
if self._connection_pool is not None:
if not self._connection_pool.is_closed:
await self._connection_pool.close()
self._connection_pool = None

self.declarer = None
self._producer = None

if self._connection is not None:
await self._connection.close()

await super()._close(exc_type, exc_val, exc_tb)

async def start(self) -> None:
Expand Down
40 changes: 16 additions & 24 deletions faststream/rabbit/fastapi/router.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,7 @@
import logging
from inspect import Parameter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Type,
Union,
cast,
)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
Optional, Sequence, Type, Union, cast)

from fastapi.datastructures import Default
from fastapi.routing import APIRoute
Expand All @@ -26,10 +15,7 @@
from faststream.broker.utils import default_filter
from faststream.rabbit.broker.broker import RabbitBroker as RB
from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher
from faststream.rabbit.schemas import (
RabbitExchange,
RabbitQueue,
)
from faststream.rabbit.schemas import RabbitExchange, RabbitQueue
from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber

if TYPE_CHECKING:
Expand All @@ -45,13 +31,9 @@
from yarl import URL

from faststream.asyncapi import schema as asyncapi
from faststream.broker.types import (
BrokerMiddleware,
CustomCallable,
Filter,
PublisherMiddleware,
SubscriberMiddleware,
)
from faststream.broker.types import (BrokerMiddleware, CustomCallable,
Filter, PublisherMiddleware,
SubscriberMiddleware)
from faststream.rabbit.message import RabbitMessage
from faststream.rabbit.schemas.reply import ReplyConfig
from faststream.security import BaseSecurity
Expand Down Expand Up @@ -414,6 +396,14 @@ def __init__(
"""
),
] = Default(generate_unique_id),
max_connection_pool_size: Annotated[
int,
Doc("Max connection pool size"),
] = 1,
max_channel_pool_size: Annotated[
int,
Doc("Max channel pool size"),
] = 1,
) -> None:
super().__init__(
url,
Expand All @@ -424,6 +414,8 @@ def __init__(
client_properties=client_properties,
timeout=timeout,
max_consumers=max_consumers,
max_connection_pool_size=max_connection_pool_size,
max_channel_pool_size=max_channel_pool_size,
app_id=app_id,
graceful_timeout=graceful_timeout,
decoder=decoder,
Expand Down
9 changes: 3 additions & 6 deletions faststream/rabbit/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import aiormq
from aio_pika.message import IncomingMessage
from aio_pika.pool import Pool
from pamqp import commands as spec
from pamqp.header import ContentHeader
from typing_extensions import override
Expand All @@ -13,11 +14,7 @@
from faststream.rabbit.parser import AioPikaParser
from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher
from faststream.rabbit.publisher.producer import AioPikaFastProducer
from faststream.rabbit.schemas import (
ExchangeType,
RabbitExchange,
RabbitQueue,
)
from faststream.rabbit.schemas import ExchangeType, RabbitExchange, RabbitQueue
from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber
from faststream.testing.broker import TestBroker, call_handler

Expand All @@ -35,7 +32,7 @@ class TestRabbitBroker(TestBroker[RabbitBroker]):

@classmethod
def _patch_test_broker(cls, broker: RabbitBroker) -> None:
broker._channel = AsyncMock()
broker._channel_pool = AsyncMock(Pool)
broker.declarer = AsyncMock()
super()._patch_test_broker(broker)

Expand Down
4 changes: 3 additions & 1 deletion tests/brokers/rabbit/specific/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@
async def test_set_max():
broker = RabbitBroker(logger=None, max_consumers=10)
await broker.start()
assert broker._channel._prefetch_count == 10
assert broker._channel_pool
async with broker._channel_pool.acquire() as channel:
assert channel._prefetch_count == 10
await broker.close()
Loading