diff --git a/setup.cfg b/setup.cfg index a76dc86d9..b2f6504c2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -81,7 +81,7 @@ exclude = # Add here agents requirements (semicolon/line-separated) agent = Werkzeug==2.3.7 - aio-pika==6.8.1 + aio-pika flask importlib-metadata; python_version<"3.8" jsonschema>=4.4.0 diff --git a/src/ostorlab/agent/mixins/agent_mq_mixin.py b/src/ostorlab/agent/mixins/agent_mq_mixin.py index 5fca41087..eccb2908d 100644 --- a/src/ostorlab/agent/mixins/agent_mq_mixin.py +++ b/src/ostorlab/agent/mixins/agent_mq_mixin.py @@ -9,6 +9,7 @@ from typing import List, Optional import aio_pika +import tenacity logger = logging.getLogger(__name__) @@ -50,15 +51,19 @@ def __init__( self._get_channel, max_size=64, loop=self._loop ) - async def _get_connection(self) -> aio_pika.Connection: - return await aio_pika.connect_robust(url=self._url, loop=self._loop) + async def _get_connection(self) -> aio_pika.abc.AbstractRobustConnection: + return await aio_pika.connect_robust( + url=self._url, loop=self._loop, fail_fast=False + ) async def _get_channel(self) -> aio_pika.Channel: async with self._connection_pool.acquire() as connection: channel: aio_pika.Channel = await connection.channel() return channel - async def _get_exchange(self, channel: aio_pika.Channel) -> aio_pika.Exchange: + async def _get_exchange( + self, channel: aio_pika.abc.AbstractChannel + ) -> aio_pika.abc.AbstractExchange: return await channel.declare_exchange( self._topic, type=aio_pika.ExchangeType.TOPIC, @@ -91,7 +96,9 @@ async def mq_run(self, delete_queue_first: bool = False) -> None: await self._queue.consume(self._mq_process_message, no_ack=False) async def _declare_mq_queue( - self, channel: aio_pika.Channel, delete_queue_first: bool = False + self, + channel: aio_pika.abc.AbstractRobustChannel, + delete_queue_first: bool = False, ) -> None: """Declare the MQ queue on a given channel. The queue is durable, re-declaring the queue will return the same queue @@ -122,13 +129,22 @@ async def _declare_mq_queue( for k in self._keys: await self._queue.bind(exchange, k) - async def _mq_process_message(self, message: aio_pika.IncomingMessage) -> None: + async def _mq_process_message( + self, message: aio_pika.abc.AbstractIncomingMessage + ) -> None: """Consumes the MQ messages and calls the process message callback.""" logger.debug("incoming pika message received") - async with message.process(requeue=True, reject_on_redelivered=True): - await self._loop.run_in_executor( - self._executor, self.process_message, message.routing_key, message.body - ) + try: + async with message.process(requeue=True, reject_on_redelivered=True): + await self._loop.run_in_executor( + self._executor, + self.process_message, + message.routing_key, + message.body, + ) + except aio_pika.exceptions.ChannelInvalidStateError: + logger.warning("The channel is closed unexpectedly.") + await self.mq_run() def process_message(self, selector: str, message: bytes) -> None: """Callback to implement to process the MQ messages received.""" @@ -151,6 +167,9 @@ async def async_mq_send_message( ) await exchange.publish(routing_key=key, message=pika_message) + @tenacity.retry( + retry=tenacity.retry_if_exception_type(aio_pika.exceptions.ConnectionClosed), + ) def mq_send_message( self, key: str, message: bytes, message_priority: Optional[int] = None ) -> None: @@ -161,6 +180,7 @@ def mq_send_message( message_priority: the priority to use for the message default is 0. """ logger.debug("sending %s to %s", message, key) + if not self._loop.is_running(): self._loop.run_until_complete( self.async_mq_send_message(key, message, message_priority) diff --git a/tests/agent/mixins/agent_mq_mixin_test.py b/tests/agent/mixins/agent_mq_mixin_test.py index 25982a02a..0319b2fb9 100644 --- a/tests/agent/mixins/agent_mq_mixin_test.py +++ b/tests/agent/mixins/agent_mq_mixin_test.py @@ -12,8 +12,9 @@ class Agent(agent_mq_mixin.AgentMQMixin): """Helper class to test MQ implementation of send and process messages.""" - def __init__(self, name="test1", keys=("a.#",)): - url = "amqp://guest:guest@localhost:5672/" + def __init__( + self, name="test1", keys=("a.#",), url="amqp://guest:guest@localhost:5672/" + ): topic = "test_topic" super().__init__(name=name, keys=keys, url=url, topic=topic) self.stub = None @@ -24,8 +25,10 @@ def process_message(self, selector, message): self.stub(message) @classmethod - def create(cls, stub, name="test1", keys=("a.#",)): - instance = cls(name=name, keys=keys) + def create( + cls, stub, name="test1", keys=("a.#",), url="amqp://guest:guest@localhost:5672/" + ): + instance = cls(name=name, keys=keys, url=url) instance.stub = stub return instance @@ -44,6 +47,22 @@ async def testClient_whenMessageIsSent_processMessageIsCalled(mocker, mq_service assert stub.call_count == 1 +@pytest.mark.asyncio +async def testConnection_whenConnectionException_reconnectIsCalled(mocker): + stub = mocker.stub(name="test1") + client = Agent.create( + stub, name="test1", keys=["d.#"], url="amqp://wrong:wrong@localhost:5672/" + ) + task = asyncio.create_task(client.mq_init()) + + try: + await asyncio.wait_for(task, timeout=10) + except asyncio.TimeoutError: + pass + + assert task.done() is True + + @pytest.mark.skip(reason="Needs debugging why MQ is not resending the message") @pytest.mark.asyncio @pytest.mark.docker