diff --git a/src/ostorlab/agent/mixins/agent_mq_mixin.py b/src/ostorlab/agent/mixins/agent_mq_mixin.py index 5fca41087..fdf34735f 100644 --- a/src/ostorlab/agent/mixins/agent_mq_mixin.py +++ b/src/ostorlab/agent/mixins/agent_mq_mixin.py @@ -9,6 +9,7 @@ from typing import List, Optional import aio_pika +import tenacity logger = logging.getLogger(__name__) @@ -50,15 +51,19 @@ def __init__( self._get_channel, max_size=64, loop=self._loop ) - async def _get_connection(self) -> aio_pika.Connection: - return await aio_pika.connect_robust(url=self._url, loop=self._loop) + async def _get_connection(self) -> aio_pika.abc.AbstractRobustConnection: + return await aio_pika.connect_robust( + url=self._url, loop=self._loop, fail_fast=False + ) async def _get_channel(self) -> aio_pika.Channel: async with self._connection_pool.acquire() as connection: channel: aio_pika.Channel = await connection.channel() return channel - async def _get_exchange(self, channel: aio_pika.Channel) -> aio_pika.Exchange: + async def _get_exchange( + self, channel: aio_pika.abc.AbstractChannel + ) -> aio_pika.abc.AbstractExchange: return await channel.declare_exchange( self._topic, type=aio_pika.ExchangeType.TOPIC, @@ -76,6 +81,11 @@ async def mq_init(self, delete_queue_first: bool = False) -> None: channel = await connection.channel() await self._declare_mq_queue(channel, delete_queue_first) + @tenacity.retry( + retry=tenacity.retry_if_exception_type( + aio_pika.exceptions.ChannelInvalidStateError + ), + ) async def mq_run(self, delete_queue_first: bool = False) -> None: """Use a channel to declare the queue, set the listener on the selectors and consume the received messaged. Args: @@ -91,7 +101,9 @@ async def mq_run(self, delete_queue_first: bool = False) -> None: await self._queue.consume(self._mq_process_message, no_ack=False) async def _declare_mq_queue( - self, channel: aio_pika.Channel, delete_queue_first: bool = False + self, + channel: aio_pika.abc.AbstractRobustChannel, + delete_queue_first: bool = False, ) -> None: """Declare the MQ queue on a given channel. The queue is durable, re-declaring the queue will return the same queue @@ -122,7 +134,14 @@ async def _declare_mq_queue( for k in self._keys: await self._queue.bind(exchange, k) - async def _mq_process_message(self, message: aio_pika.IncomingMessage) -> None: + @tenacity.retry( + retry=tenacity.retry_if_exception_type( + aio_pika.exceptions.ChannelInvalidStateError + ), + ) + async def _mq_process_message( + self, message: aio_pika.abc.AbstractIncomingMessage + ) -> None: """Consumes the MQ messages and calls the process message callback.""" logger.debug("incoming pika message received") async with message.process(requeue=True, reject_on_redelivered=True): @@ -151,6 +170,11 @@ async def async_mq_send_message( ) await exchange.publish(routing_key=key, message=pika_message) + @tenacity.retry( + retry=tenacity.retry_if_exception_type( + aio_pika.exceptions.ChannelInvalidStateError + ), + ) def mq_send_message( self, key: str, message: bytes, message_priority: Optional[int] = None ) -> None: