Skip to content

Commit

Permalink
ensure unsubscribe consumer is only unsubscribing the right consumer …
Browse files Browse the repository at this point in the history
…and allow multiple consumers
  • Loading branch information
sanderegg committed Nov 26, 2024
1 parent 865a757 commit bf472ba
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 20 deletions.
3 changes: 3 additions & 0 deletions packages/service-library/src/servicelib/rabbitmq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
RPCNotInitializedError,
RPCServerError,
)
from ._models import ConsumerTag, QueueName
from ._rpc_router import RPCRouter
from ._utils import is_rabbitmq_responsive, wait_till_rabbitmq_responsive

Expand All @@ -23,6 +24,8 @@
"RPCRouter",
"RPCServerError",
"wait_till_rabbitmq_responsive",
"QueueName",
"ConsumerTag",
)

# nopycln: file
46 changes: 27 additions & 19 deletions packages/service-library/src/servicelib/rabbitmq/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from dataclasses import dataclass, field
from functools import partial
from typing import Final
from uuid import uuid4

import aio_pika
from pydantic import NonNegativeInt

from ..logging_utils import log_catch, log_context
from ._client_base import RabbitMQClientBase
from ._models import MessageHandler, RabbitMessage
from ._models import ConsumerTag, MessageHandler, QueueName, RabbitMessage
from ._utils import (
RABBIT_QUEUE_MESSAGE_DEFAULT_TTL_MS,
declare_queue,
Expand Down Expand Up @@ -139,7 +140,7 @@ async def _get_channel(self) -> aio_pika.abc.AbstractChannel:
return channel

async def _get_consumer_tag(self, exchange_name) -> str:
return f"{get_rabbitmq_client_unique_name(self.client_name)}_{exchange_name}"
return f"{get_rabbitmq_client_unique_name(self.client_name)}_{exchange_name}_{uuid4()}"

async def subscribe(
self,
Expand All @@ -151,7 +152,7 @@ async def subscribe(
message_ttl: NonNegativeInt = RABBIT_QUEUE_MESSAGE_DEFAULT_TTL_MS,
unexpected_error_retry_delay_s: float = _DEFAULT_UNEXPECTED_ERROR_RETRY_DELAY_S,
unexpected_error_max_attempts: int = _DEFAULT_UNEXPECTED_ERROR_MAX_ATTEMPTS,
) -> str:
) -> tuple[QueueName, ConsumerTag]:
"""subscribe to exchange_name calling ``message_handler`` for every incoming message
- exclusive_queue: True means that every instance of this application will
receive the incoming messages
Expand Down Expand Up @@ -238,14 +239,14 @@ async def subscribe(
)
await delayed_queue.bind(delayed_exchange)

_consumer_tag = await self._get_consumer_tag(exchange_name)
consumer_tag = await self._get_consumer_tag(exchange_name)
await queue.consume(
partial(_on_message, message_handler, unexpected_error_max_attempts),
exclusive=exclusive_queue,
consumer_tag=_consumer_tag,
consumer_tag=consumer_tag,
)
output: str = queue.name
return output
return output, consumer_tag

async def add_topics(
self,
Expand Down Expand Up @@ -300,13 +301,16 @@ async def remove_topics(

async def unsubscribe(
self,
queue_name: str,
queue_name: QueueName,
) -> None:
assert self._channel_pool # nosec
async with self._channel_pool.acquire() as channel:
queue = await channel.get_queue(queue_name)
# NOTE: we force delete here
await queue.delete(if_unused=False, if_empty=False)
"""This will delete the queue if there are no consumers left"""
assert self._connection_pool # nosec
if not self._connection_pool.is_closed:
assert self._channel_pool # nosec
async with self._channel_pool.acquire() as channel:
queue = await channel.get_queue(queue_name)
# NOTE: we force delete here
await queue.delete(if_unused=False, if_empty=False)

async def publish(self, exchange_name: str, message: RabbitMessage) -> None:
"""publish message in the exchange exchange_name.
Expand All @@ -333,10 +337,14 @@ async def publish(self, exchange_name: str, message: RabbitMessage) -> None:
routing_key=message.routing_key() or "",
)

async def unsubscribe_consumer(self, exchange_name: str):
assert self._channel_pool # nosec
async with self._channel_pool.acquire() as channel:
queue_name = exchange_name
queue = await channel.get_queue(queue_name)
_consumer_tag = await self._get_consumer_tag(exchange_name)
await queue.cancel(_consumer_tag)
async def unsubscribe_consumer(
self, queue_name: QueueName, consumer_tag: ConsumerTag
) -> None:
"""This will only remove the consumers without deleting the queue"""
assert self._connection_pool # nosec
if not self._connection_pool.is_closed:
assert self._channel_pool # nosec
async with self._channel_pool.acquire() as channel:
assert isinstance(channel, aio_pika.RobustChannel) # nosec
queue = await channel.get_queue(queue_name)
await queue.cancel(consumer_tag)
5 changes: 4 additions & 1 deletion packages/service-library/src/servicelib/rabbitmq/_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Awaitable, Callable
from typing import Any, Protocol
from typing import Any, Protocol, TypeAlias

from models_library.basic_types import ConstrainedStr
from models_library.rabbitmq_basic_types import (
Expand All @@ -11,6 +11,9 @@

MessageHandler = Callable[[Any], Awaitable[bool]]

QueueName: TypeAlias = str
ConsumerTag: TypeAlias = str


class RabbitMessage(Protocol):
def body(self) -> bytes:
Expand Down

0 comments on commit bf472ba

Please sign in to comment.