Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle connection exceptions with MQ #526

Merged
merged 13 commits into from
Nov 9, 2023
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ exclude =
# Add here agents requirements (semicolon/line-separated)
agent =
Werkzeug==2.3.7
aio-pika==6.8.1
aio-pika
flask
importlib-metadata; python_version<"3.8"
jsonschema>=4.4.0
Expand Down
37 changes: 28 additions & 9 deletions src/ostorlab/agent/mixins/agent_mq_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import List, Optional

import aio_pika
import tenacity

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,15 +51,19 @@ def __init__(
self._get_channel, max_size=64, loop=self._loop
)

async def _get_connection(self) -> aio_pika.Connection:
return await aio_pika.connect_robust(url=self._url, loop=self._loop)
async def _get_connection(self) -> aio_pika.abc.AbstractRobustConnection:
return await aio_pika.connect_robust(
url=self._url, loop=self._loop, fail_fast=False
)

async def _get_channel(self) -> aio_pika.Channel:
async with self._connection_pool.acquire() as connection:
channel: aio_pika.Channel = await connection.channel()
return channel

async def _get_exchange(self, channel: aio_pika.Channel) -> aio_pika.Exchange:
async def _get_exchange(
self, channel: aio_pika.abc.AbstractChannel
) -> aio_pika.abc.AbstractExchange:
return await channel.declare_exchange(
self._topic,
type=aio_pika.ExchangeType.TOPIC,
Expand Down Expand Up @@ -91,7 +96,9 @@ async def mq_run(self, delete_queue_first: bool = False) -> None:
await self._queue.consume(self._mq_process_message, no_ack=False)

async def _declare_mq_queue(
self, channel: aio_pika.Channel, delete_queue_first: bool = False
self,
channel: aio_pika.abc.AbstractRobustChannel,
delete_queue_first: bool = False,
) -> None:
"""Declare the MQ queue on a given channel.
The queue is durable, re-declaring the queue will return the same queue
Expand Down Expand Up @@ -122,13 +129,21 @@ async def _declare_mq_queue(
for k in self._keys:
await self._queue.bind(exchange, k)

async def _mq_process_message(self, message: aio_pika.IncomingMessage) -> None:
async def _mq_process_message(
self, message: aio_pika.abc.AbstractIncomingMessage
) -> None:
"""Consumes the MQ messages and calls the process message callback."""
logger.debug("incoming pika message received")
async with message.process(requeue=True, reject_on_redelivered=True):
await self._loop.run_in_executor(
self._executor, self.process_message, message.routing_key, message.body
)
try:
async with message.process(requeue=True, reject_on_redelivered=True):
await self._loop.run_in_executor(
self._executor,
self.process_message,
message.routing_key,
message.body,
)
except aio_pika.exceptions.ChannelInvalidStateError:
await self.mq_run()
benyissa marked this conversation as resolved.
Show resolved Hide resolved

def process_message(self, selector: str, message: bytes) -> None:
"""Callback to implement to process the MQ messages received."""
Expand All @@ -151,6 +166,9 @@ async def async_mq_send_message(
)
await exchange.publish(routing_key=key, message=pika_message)

@tenacity.retry(
retry=tenacity.retry_if_exception_type(aio_pika.exceptions.ConnectionClosed),
)
def mq_send_message(
self, key: str, message: bytes, message_priority: Optional[int] = None
) -> None:
Expand All @@ -161,6 +179,7 @@ def mq_send_message(
message_priority: the priority to use for the message default is 0.
"""
logger.debug("sending %s to %s", message, key)

if not self._loop.is_running():
self._loop.run_until_complete(
self.async_mq_send_message(key, message, message_priority)
Expand Down
27 changes: 23 additions & 4 deletions tests/agent/mixins/agent_mq_mixin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
class Agent(agent_mq_mixin.AgentMQMixin):
"""Helper class to test MQ implementation of send and process messages."""

def __init__(self, name="test1", keys=("a.#",)):
url = "amqp://guest:guest@localhost:5672/"
def __init__(
self, name="test1", keys=("a.#",), url="amqp://guest:guest@localhost:5672/"
):
topic = "test_topic"
super().__init__(name=name, keys=keys, url=url, topic=topic)
self.stub = None
Expand All @@ -24,8 +25,10 @@ def process_message(self, selector, message):
self.stub(message)

@classmethod
def create(cls, stub, name="test1", keys=("a.#",)):
instance = cls(name=name, keys=keys)
def create(
cls, stub, name="test1", keys=("a.#",), url="amqp://guest:guest@localhost:5672/"
):
instance = cls(name=name, keys=keys, url=url)
instance.stub = stub
return instance

Expand All @@ -44,6 +47,22 @@ async def testClient_whenMessageIsSent_processMessageIsCalled(mocker, mq_service
assert stub.call_count == 1


@pytest.mark.asyncio
async def testConnection_whenConnectionException_reconnectIsCalled(mocker):
stub = mocker.stub(name="test1")
client = Agent.create(
stub, name="test1", keys=["d.#"], url="amqp://wrong:wrong@localhost:5672/"
)
task = asyncio.create_task(client.mq_init())

try:
await asyncio.wait_for(task, timeout=10)
benyissa marked this conversation as resolved.
Show resolved Hide resolved
except asyncio.TimeoutError:
pass
benyissa marked this conversation as resolved.
Show resolved Hide resolved

assert task.done() is True


@pytest.mark.skip(reason="Needs debugging why MQ is not resending the message")
@pytest.mark.asyncio
@pytest.mark.docker
Expand Down
Loading