Skip to content

Commit

Permalink
Merge pull request #526 from Ostorlab/refix_mq_connection_errors
Browse files Browse the repository at this point in the history
Handle connection exceptions with MQ
  • Loading branch information
benyissa authored Nov 9, 2023
2 parents d37ad79 + 9f6022c commit fc2713b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 14 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ exclude =
# Add here agents requirements (semicolon/line-separated)
agent =
Werkzeug==2.3.7
aio-pika==6.8.1
aio-pika
flask
importlib-metadata; python_version<"3.8"
jsonschema>=4.4.0
Expand Down
38 changes: 29 additions & 9 deletions src/ostorlab/agent/mixins/agent_mq_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import List, Optional

import aio_pika
import tenacity

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,15 +51,19 @@ def __init__(
self._get_channel, max_size=64, loop=self._loop
)

async def _get_connection(self) -> aio_pika.Connection:
return await aio_pika.connect_robust(url=self._url, loop=self._loop)
async def _get_connection(self) -> aio_pika.abc.AbstractRobustConnection:
return await aio_pika.connect_robust(
url=self._url, loop=self._loop, fail_fast=False
)

async def _get_channel(self) -> aio_pika.Channel:
async with self._connection_pool.acquire() as connection:
channel: aio_pika.Channel = await connection.channel()
return channel

async def _get_exchange(self, channel: aio_pika.Channel) -> aio_pika.Exchange:
async def _get_exchange(
self, channel: aio_pika.abc.AbstractChannel
) -> aio_pika.abc.AbstractExchange:
return await channel.declare_exchange(
self._topic,
type=aio_pika.ExchangeType.TOPIC,
Expand Down Expand Up @@ -91,7 +96,9 @@ async def mq_run(self, delete_queue_first: bool = False) -> None:
await self._queue.consume(self._mq_process_message, no_ack=False)

async def _declare_mq_queue(
self, channel: aio_pika.Channel, delete_queue_first: bool = False
self,
channel: aio_pika.abc.AbstractRobustChannel,
delete_queue_first: bool = False,
) -> None:
"""Declare the MQ queue on a given channel.
The queue is durable, re-declaring the queue will return the same queue
Expand Down Expand Up @@ -122,13 +129,22 @@ async def _declare_mq_queue(
for k in self._keys:
await self._queue.bind(exchange, k)

async def _mq_process_message(self, message: aio_pika.IncomingMessage) -> None:
async def _mq_process_message(
self, message: aio_pika.abc.AbstractIncomingMessage
) -> None:
"""Consumes the MQ messages and calls the process message callback."""
logger.debug("incoming pika message received")
async with message.process(requeue=True, reject_on_redelivered=True):
await self._loop.run_in_executor(
self._executor, self.process_message, message.routing_key, message.body
)
try:
async with message.process(requeue=True, reject_on_redelivered=True):
await self._loop.run_in_executor(
self._executor,
self.process_message,
message.routing_key,
message.body,
)
except aio_pika.exceptions.ChannelInvalidStateError:
logger.warning("The channel is closed unexpectedly.")
await self.mq_run()

def process_message(self, selector: str, message: bytes) -> None:
"""Callback to implement to process the MQ messages received."""
Expand All @@ -151,6 +167,9 @@ async def async_mq_send_message(
)
await exchange.publish(routing_key=key, message=pika_message)

@tenacity.retry(
retry=tenacity.retry_if_exception_type(aio_pika.exceptions.ConnectionClosed),
)
def mq_send_message(
self, key: str, message: bytes, message_priority: Optional[int] = None
) -> None:
Expand All @@ -161,6 +180,7 @@ def mq_send_message(
message_priority: the priority to use for the message default is 0.
"""
logger.debug("sending %s to %s", message, key)

if not self._loop.is_running():
self._loop.run_until_complete(
self.async_mq_send_message(key, message, message_priority)
Expand Down
27 changes: 23 additions & 4 deletions tests/agent/mixins/agent_mq_mixin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
class Agent(agent_mq_mixin.AgentMQMixin):
"""Helper class to test MQ implementation of send and process messages."""

def __init__(self, name="test1", keys=("a.#",)):
url = "amqp://guest:guest@localhost:5672/"
def __init__(
self, name="test1", keys=("a.#",), url="amqp://guest:guest@localhost:5672/"
):
topic = "test_topic"
super().__init__(name=name, keys=keys, url=url, topic=topic)
self.stub = None
Expand All @@ -24,8 +25,10 @@ def process_message(self, selector, message):
self.stub(message)

@classmethod
def create(cls, stub, name="test1", keys=("a.#",)):
instance = cls(name=name, keys=keys)
def create(
cls, stub, name="test1", keys=("a.#",), url="amqp://guest:guest@localhost:5672/"
):
instance = cls(name=name, keys=keys, url=url)
instance.stub = stub
return instance

Expand All @@ -44,6 +47,22 @@ async def testClient_whenMessageIsSent_processMessageIsCalled(mocker, mq_service
assert stub.call_count == 1


@pytest.mark.asyncio
async def testConnection_whenConnectionException_reconnectIsCalled(mocker):
stub = mocker.stub(name="test1")
client = Agent.create(
stub, name="test1", keys=["d.#"], url="amqp://wrong:wrong@localhost:5672/"
)
task = asyncio.create_task(client.mq_init())

try:
await asyncio.wait_for(task, timeout=10)
except asyncio.TimeoutError:
pass

assert task.done() is True


@pytest.mark.skip(reason="Needs debugging why MQ is not resending the message")
@pytest.mark.asyncio
@pytest.mark.docker
Expand Down

0 comments on commit fc2713b

Please sign in to comment.