Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

975 use aio pika pool #1401

Merged
merged 6 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 64 additions & 41 deletions faststream/rabbit/broker/broker.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
import logging
from inspect import Parameter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Optional,
Type,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Type, Union, cast
from urllib.parse import urlparse

from aio_pika import connect_robust
from aio_pika.pool import Pool
from typing_extensions import Annotated, Doc, override

from faststream.__about__ import SERVICE_NAME
Expand All @@ -21,11 +13,7 @@
from faststream.rabbit.broker.logging import RabbitLoggingBroker
from faststream.rabbit.broker.registrator import RabbitRegistrator
from faststream.rabbit.publisher.producer import AioPikaFastProducer
from faststream.rabbit.schemas import (
RABBIT_REPLY,
RabbitExchange,
RabbitQueue,
)
from faststream.rabbit.schemas import RABBIT_REPLY, RabbitExchange, RabbitQueue
from faststream.rabbit.security import parse_security
from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber
from faststream.rabbit.utils import RabbitDeclarer, build_url
Expand All @@ -47,15 +35,35 @@
from yarl import URL

from faststream.asyncapi import schema as asyncapi
from faststream.broker.types import (
BrokerMiddleware,
CustomCallable,
)
from faststream.broker.types import BrokerMiddleware, CustomCallable
from faststream.rabbit.types import AioPikaSendableMessage
from faststream.security import BaseSecurity
from faststream.types import AnyDict, Decorator, LoggerProto


async def get_connection(
url: str,
timeout: "TimeoutType",
ssl_context: Optional["SSLContext"],
) -> "RobustConnection":
return cast(
"RobustConnection",
await connect_robust(
url,
timeout=timeout,
ssl_context=ssl_context,
),
)


async def get_channel(connection_pool: "Pool[RobustConnection]") -> "RobustChannel":
async with connection_pool.acquire() as connection:
return cast(
"RobustChannel",
await connection.channel(),
)


class RabbitBroker(
RabbitRegistrator,
RabbitLoggingBroker,
Expand All @@ -66,7 +74,8 @@ class RabbitBroker(
_producer: Optional["AioPikaFastProducer"]

declarer: Optional[RabbitDeclarer]
_channel: Optional["RobustChannel"]
_channel_pool: Optional["Pool[RobustChannel]"]
_connection_pool: Optional["Pool[RobustConnection]"]

def __init__(
self,
Expand Down Expand Up @@ -249,6 +258,8 @@ def __init__(
self.app_id = app_id

self._channel = None
self._channel_pool = None
self._connection_pool = None
self.declarer = None

@property
Expand Down Expand Up @@ -346,23 +357,34 @@ async def _connect( # type: ignore[override]
*,
timeout: "TimeoutType",
ssl_context: Optional["SSLContext"],
max_connection_pool_size: int = 1,
max_channel_pool_size: int = 10,
) -> "RobustConnection":
connection = cast(
"RobustConnection",
await connect_robust(
url,
timeout=timeout,
ssl_context=ssl_context,
),
)
if self._connection_pool is None:
self._connection_pool = Pool(
lambda: get_connection(
url=url,
timeout=timeout,
ssl_context=ssl_context,
),
max_size=max_connection_pool_size,
)

if self._channel is None: # pragma: no branch
max_consumers = self._max_consumers
channel = self._channel = cast(
"RobustChannel",
await connection.channel(),
if self._channel_pool is None:
assert self._connection_pool is not None
self._channel_pool = cast(
Pool["RobustChannel"],
Pool(
lambda: get_channel(
cast("Pool[RobustConnection]", self._connection_pool)
),
max_size=max_channel_pool_size,
),
)

async with self._channel_pool.acquire() as channel:
max_consumers = self._max_consumers

declarer = self.declarer = RabbitDeclarer(channel)
await declarer.declare_queue(RABBIT_REPLY)

Expand All @@ -380,26 +402,27 @@ async def _connect( # type: ignore[override]
self._log(f"Set max consumers to {max_consumers}", extra=c)
await channel.set_qos(prefetch_count=int(max_consumers))

return connection
return cast("RobustConnection", channel._connection)

async def _close(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional["TracebackType"] = None,
) -> None:
if self._channel is not None:
if not self._channel.is_closed:
await self._channel.close()
if self._channel_pool is not None:
if not self._channel_pool.is_closed:
await self._channel_pool.close()
self._channel_pool = None

self._channel = None
if self._connection_pool is not None:
if not self._connection_pool.is_closed:
await self._connection_pool.close()
self._connection_pool = None

self.declarer = None
self._producer = None

if self._connection is not None:
await self._connection.close()

await super()._close(exc_type, exc_val, exc_tb)

async def start(self) -> None:
Expand Down
9 changes: 3 additions & 6 deletions faststream/rabbit/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import aiormq
from aio_pika.message import IncomingMessage
from aio_pika.pool import Pool
from pamqp import commands as spec
from pamqp.header import ContentHeader
from typing_extensions import override
Expand All @@ -12,11 +13,7 @@
from faststream.rabbit.parser import AioPikaParser
from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher
from faststream.rabbit.publisher.producer import AioPikaFastProducer
from faststream.rabbit.schemas import (
ExchangeType,
RabbitExchange,
RabbitQueue,
)
from faststream.rabbit.schemas import ExchangeType, RabbitExchange, RabbitQueue
from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber
from faststream.testing.broker import TestBroker, call_handler

Expand All @@ -34,7 +31,7 @@ class TestRabbitBroker(TestBroker[RabbitBroker]):

@classmethod
def _patch_test_broker(cls, broker: RabbitBroker) -> None:
broker._channel = AsyncMock()
broker._channel_pool = AsyncMock(Pool)
broker.declarer = AsyncMock()
super()._patch_test_broker(broker)

Expand Down
4 changes: 3 additions & 1 deletion tests/brokers/rabbit/specific/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@
async def test_set_max():
broker = RabbitBroker(logger=None, max_consumers=10)
await broker.start()
assert broker._channel._prefetch_count == 10
assert broker._channel_pool
async with broker._channel_pool.acquire() as channel:
assert channel._prefetch_count == 10
await broker.close()
Loading