diff --git a/packages/service-library/src/servicelib/rabbitmq/__init__.py b/packages/service-library/src/servicelib/rabbitmq/__init__.py index 488ab99833c2..b0c2a66fea43 100644 --- a/packages/service-library/src/servicelib/rabbitmq/__init__.py +++ b/packages/service-library/src/servicelib/rabbitmq/__init__.py @@ -8,6 +8,7 @@ RPCNotInitializedError, RPCServerError, ) +from ._models import ConsumerTag, QueueName from ._rpc_router import RPCRouter from ._utils import is_rabbitmq_responsive, wait_till_rabbitmq_responsive @@ -23,6 +24,8 @@ "RPCRouter", "RPCServerError", "wait_till_rabbitmq_responsive", + "QueueName", + "ConsumerTag", ) # nopycln: file diff --git a/packages/service-library/src/servicelib/rabbitmq/_client.py b/packages/service-library/src/servicelib/rabbitmq/_client.py index dc70aa03ffaf..661f2c7f69b7 100644 --- a/packages/service-library/src/servicelib/rabbitmq/_client.py +++ b/packages/service-library/src/servicelib/rabbitmq/_client.py @@ -3,13 +3,14 @@ from dataclasses import dataclass, field from functools import partial from typing import Final +from uuid import uuid4 import aio_pika from pydantic import NonNegativeInt from ..logging_utils import log_catch, log_context from ._client_base import RabbitMQClientBase -from ._models import MessageHandler, RabbitMessage +from ._models import ConsumerTag, MessageHandler, QueueName, RabbitMessage from ._utils import ( RABBIT_QUEUE_MESSAGE_DEFAULT_TTL_MS, declare_queue, @@ -139,7 +140,7 @@ async def _get_channel(self) -> aio_pika.abc.AbstractChannel: return channel async def _get_consumer_tag(self, exchange_name) -> str: - return f"{get_rabbitmq_client_unique_name(self.client_name)}_{exchange_name}" + return f"{get_rabbitmq_client_unique_name(self.client_name)}_{exchange_name}_{uuid4()}" async def subscribe( self, @@ -151,7 +152,7 @@ async def subscribe( message_ttl: NonNegativeInt = RABBIT_QUEUE_MESSAGE_DEFAULT_TTL_MS, unexpected_error_retry_delay_s: float = _DEFAULT_UNEXPECTED_ERROR_RETRY_DELAY_S, unexpected_error_max_attempts: int = _DEFAULT_UNEXPECTED_ERROR_MAX_ATTEMPTS, - ) -> str: + ) -> tuple[QueueName, ConsumerTag]: """subscribe to exchange_name calling ``message_handler`` for every incoming message - exclusive_queue: True means that every instance of this application will receive the incoming messages @@ -238,14 +239,14 @@ async def subscribe( ) await delayed_queue.bind(delayed_exchange) - _consumer_tag = await self._get_consumer_tag(exchange_name) + consumer_tag = await self._get_consumer_tag(exchange_name) await queue.consume( partial(_on_message, message_handler, unexpected_error_max_attempts), exclusive=exclusive_queue, - consumer_tag=_consumer_tag, + consumer_tag=consumer_tag, ) output: str = queue.name - return output + return output, consumer_tag async def add_topics( self, @@ -300,13 +301,16 @@ async def remove_topics( async def unsubscribe( self, - queue_name: str, + queue_name: QueueName, ) -> None: - assert self._channel_pool # nosec - async with self._channel_pool.acquire() as channel: - queue = await channel.get_queue(queue_name) - # NOTE: we force delete here - await queue.delete(if_unused=False, if_empty=False) + """This will delete the queue if there are no consumers left""" + assert self._connection_pool # nosec + if not self._connection_pool.is_closed: + assert self._channel_pool # nosec + async with self._channel_pool.acquire() as channel: + queue = await channel.get_queue(queue_name) + # NOTE: we force delete here + await queue.delete(if_unused=False, if_empty=False) async def publish(self, exchange_name: str, message: RabbitMessage) -> None: """publish message in the exchange exchange_name. @@ -333,10 +337,14 @@ async def publish(self, exchange_name: str, message: RabbitMessage) -> None: routing_key=message.routing_key() or "", ) - async def unsubscribe_consumer(self, exchange_name: str): - assert self._channel_pool # nosec - async with self._channel_pool.acquire() as channel: - queue_name = exchange_name - queue = await channel.get_queue(queue_name) - _consumer_tag = await self._get_consumer_tag(exchange_name) - await queue.cancel(_consumer_tag) + async def unsubscribe_consumer( + self, queue_name: QueueName, consumer_tag: ConsumerTag + ) -> None: + """This will only remove the consumers without deleting the queue""" + assert self._connection_pool # nosec + if not self._connection_pool.is_closed: + assert self._channel_pool # nosec + async with self._channel_pool.acquire() as channel: + assert isinstance(channel, aio_pika.RobustChannel) # nosec + queue = await channel.get_queue(queue_name) + await queue.cancel(consumer_tag) diff --git a/packages/service-library/src/servicelib/rabbitmq/_models.py b/packages/service-library/src/servicelib/rabbitmq/_models.py index c76800a4d8a8..beda88fc1cc0 100644 --- a/packages/service-library/src/servicelib/rabbitmq/_models.py +++ b/packages/service-library/src/servicelib/rabbitmq/_models.py @@ -1,5 +1,5 @@ from collections.abc import Awaitable, Callable -from typing import Any, Protocol +from typing import Any, Protocol, TypeAlias from models_library.basic_types import ConstrainedStr from models_library.rabbitmq_basic_types import ( @@ -11,6 +11,9 @@ MessageHandler = Callable[[Any], Awaitable[bool]] +QueueName: TypeAlias = str +ConsumerTag: TypeAlias = str + class RabbitMessage(Protocol): def body(self) -> bytes: