Skip to content

Commit

Permalink
Merge pull request #517 from Ostorlab/fix_mq_connection_errors
Browse files Browse the repository at this point in the history
Catch CONNECTION_EXCEPTIONS with  MQ
  • Loading branch information
3asm authored Nov 2, 2023
2 parents b0260a0 + 18d7547 commit ebd8f78
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ exclude =
# Add here agents requirements (semicolon/line-separated)
agent =
Werkzeug==2.3.7
aio-pika==6.8.1
aio-pika
flask
importlib-metadata; python_version<"3.8"
jsonschema>=4.4.0
Expand Down
29 changes: 24 additions & 5 deletions src/ostorlab/agent/mixins/agent_mq_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import List, Optional

import aio_pika
import tenacity

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,15 +51,19 @@ def __init__(
self._get_channel, max_size=64, loop=self._loop
)

async def _get_connection(self) -> aio_pika.Connection:
return await aio_pika.connect_robust(url=self._url, loop=self._loop)
async def _get_connection(self) -> aio_pika.abc.AbstractRobustConnection:
return await aio_pika.connect_robust(
url=self._url, loop=self._loop, fail_fast=False
)

async def _get_channel(self) -> aio_pika.Channel:
async with self._connection_pool.acquire() as connection:
channel: aio_pika.Channel = await connection.channel()
return channel

async def _get_exchange(self, channel: aio_pika.Channel) -> aio_pika.Exchange:
async def _get_exchange(
self, channel: aio_pika.abc.AbstractChannel
) -> aio_pika.abc.AbstractExchange:
return await channel.declare_exchange(
self._topic,
type=aio_pika.ExchangeType.TOPIC,
Expand All @@ -76,6 +81,11 @@ async def mq_init(self, delete_queue_first: bool = False) -> None:
channel = await connection.channel()
await self._declare_mq_queue(channel, delete_queue_first)

@tenacity.retry(
retry=tenacity.retry_if_exception_type(
aio_pika.exceptions.ChannelInvalidStateError
),
)
async def mq_run(self, delete_queue_first: bool = False) -> None:
"""Use a channel to declare the queue, set the listener on the selectors and consume the received messaged.
Args:
Expand All @@ -91,7 +101,9 @@ async def mq_run(self, delete_queue_first: bool = False) -> None:
await self._queue.consume(self._mq_process_message, no_ack=False)

async def _declare_mq_queue(
self, channel: aio_pika.Channel, delete_queue_first: bool = False
self,
channel: aio_pika.abc.AbstractRobustChannel,
delete_queue_first: bool = False,
) -> None:
"""Declare the MQ queue on a given channel.
The queue is durable, re-declaring the queue will return the same queue
Expand Down Expand Up @@ -122,7 +134,9 @@ async def _declare_mq_queue(
for k in self._keys:
await self._queue.bind(exchange, k)

async def _mq_process_message(self, message: aio_pika.IncomingMessage) -> None:
async def _mq_process_message(
self, message: aio_pika.abc.AbstractIncomingMessage
) -> None:
"""Consumes the MQ messages and calls the process message callback."""
logger.debug("incoming pika message received")
async with message.process(requeue=True, reject_on_redelivered=True):
Expand Down Expand Up @@ -151,6 +165,11 @@ async def async_mq_send_message(
)
await exchange.publish(routing_key=key, message=pika_message)

@tenacity.retry(
retry=tenacity.retry_if_exception_type(
aio_pika.exceptions.ChannelInvalidStateError
),
)
def mq_send_message(
self, key: str, message: bytes, message_priority: Optional[int] = None
) -> None:
Expand Down
25 changes: 21 additions & 4 deletions tests/agent/mixins/agent_mq_mixin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
class Agent(agent_mq_mixin.AgentMQMixin):
"""Helper class to test MQ implementation of send and process messages."""

def __init__(self, name="test1", keys=("a.#",)):
url = "amqp://guest:guest@localhost:5672/"
def __init__(
self, name="test1", keys=("a.#",), url="amqp://guest:guest@localhost:5672/"
):
topic = "test_topic"
super().__init__(name=name, keys=keys, url=url, topic=topic)
self.stub = None
Expand All @@ -24,8 +25,10 @@ def process_message(self, selector, message):
self.stub(message)

@classmethod
def create(cls, stub, name="test1", keys=("a.#",)):
instance = cls(name=name, keys=keys)
def create(
cls, stub, name="test1", keys=("a.#",), url="amqp://guest:guest@localhost:5672/"
):
instance = cls(name=name, keys=keys, url=url)
instance.stub = stub
return instance

Expand All @@ -44,6 +47,20 @@ async def testClient_whenMessageIsSent_processMessageIsCalled(mocker, mq_service
assert stub.call_count == 1


@pytest.mark.asyncio
async def testConnection_whenConnectionException_reconnectIsCalled(mocker):
stub = mocker.stub(name="test1")
client = Agent.create(
stub, name="test1", keys=["d.#"], url="amqp://wrong:wrong@localhost:5672/"
)
task = asyncio.create_task(client.mq_init())
try:
await asyncio.wait_for(task, timeout=10)
except asyncio.TimeoutError:
pass
assert task.done() is True


@pytest.mark.skip(reason="Needs debugging why MQ is not resending the message")
@pytest.mark.asyncio
@pytest.mark.docker
Expand Down

0 comments on commit ebd8f78

Please sign in to comment.