Skip to content

Commit

Permalink
Merge branch 'main' into 66-add-timedelta-as-a-params_dict-workflow-d…
Browse files Browse the repository at this point in the history
…efinition-parameter
  • Loading branch information
lfse-slafleur committed Sep 10, 2024
2 parents 41073fb + 38715d2 commit 74d86fb
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 9 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ classifiers = [

dependencies = [
"aio-pika ~= 9.4.2",
"omotes-sdk-protocol ~= 0.1.3",
"omotes-sdk-protocol ~= 0.1.4",
"pamqp ~= 3.3.0",
"celery ~= 5.3.6",
"typing-extensions ~= 4.11.0",
"streamcapture ~= 1.2.4",
Expand Down
102 changes: 98 additions & 4 deletions src/omotes_sdk/internal/common/broker_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from functools import partial
import threading
from types import TracebackType
from typing import Callable, Optional, Dict, Type, TypedDict
from typing import Callable, Optional, Dict, Type, TypedDict, cast
from datetime import timedelta

from aio_pika import connect_robust, Message, DeliveryMode
from aio_pika.abc import (
Expand All @@ -17,6 +18,8 @@
AbstractIncomingMessage,
AbstractExchange,
)
from aio_pika.exceptions import ChannelClosed
from pamqp.common import Arguments

from omotes_sdk.config import RabbitMQConfig

Expand Down Expand Up @@ -113,6 +116,55 @@ def to_argument(self) -> AioPikaQueueTypeArguments:
return result


@dataclass()
class QueueMessageTTLArguments():
"""Construct additional time-to-live arguments when declaring a queue."""

queue_ttl: Optional[timedelta] = None
"""Expires and deletes the queue after a period of time when it is not used.
The timedelta must be convertible into a positive integer.
Ref: https://www.rabbitmq.com/docs/ttl#queue-ttl"""
message_ttl: Optional[timedelta] = None
"""Expires and deletes the message within the queue after the defined TTL.
The timedelta must be convertible into a non-negative integer.
Ref: https://www.rabbitmq.com/docs/ttl#per-queue-message-ttl"""
dead_letter_routing_key: Optional[str] = None
"""When specified, the expired message is republished to the designated dead letter queue.
If not set, the message's own routing key is used.
Ref: https://www.rabbitmq.com/docs/dlx#routing"""
dead_letter_exchange: Optional[str] = None
"""Dead letter exchange name.
Ref: https://www.rabbitmq.com/docs/dlx"""

def to_argument(self) -> Arguments:
"""Convert the time-to-live variables to the aio-pika `declare_queue` keyword arguments.
:return: The time-to-live keyword arguments in AMQP method arguments data type.
"""
arguments: Arguments = {}
# Ensure this is not None to avoid typecheck error.
arguments = cast(dict, arguments)

if self.queue_ttl is not None:
if self.queue_ttl <= timedelta(0):
raise ValueError("queue_ttl must be a positive value, "
+ f"{self.queue_ttl} received.")
arguments["x-expires"] = int(self.queue_ttl.total_seconds() * 1000)
if self.message_ttl is not None:
if self.message_ttl < timedelta(0):
raise ValueError("message_ttl can not be a negative value, "
+ f"{self.message_ttl} received.")
if self.queue_ttl is not None and self.message_ttl > self.queue_ttl:
# Raise an error as it serves no purpose.
raise ValueError("message_ttl shall be smaller or equal to queue_ttl.")
arguments["x-message-ttl"] = int(self.message_ttl.total_seconds() * 1000)
if self.dead_letter_routing_key is not None:
arguments["x-dead-letter-routing-key"] = str(self.dead_letter_routing_key)
if self.dead_letter_exchange is not None:
arguments["x-dead-letter-exchange"] = str(self.dead_letter_exchange)
return arguments


class BrokerInterface(threading.Thread):
"""Interface to RabbitMQ using aiopika."""

Expand Down Expand Up @@ -222,6 +274,7 @@ async def _declare_queue(
queue_type: AMQPQueueType,
bind_to_routing_key: Optional[str] = None,
exchange_name: Optional[str] = None,
queue_message_ttl: Optional[QueueMessageTTLArguments] = None
) -> AbstractQueue:
"""Declare an AMQP queue.
Expand All @@ -231,15 +284,26 @@ async def _declare_queue(
key of the queue name. If none, the queue is only bound to the name of the queue.
If not none, then the exchange_name must be set as well.
:param exchange_name: Name of the exchange on which the messages will be published.
:param queue_message_ttl: Additional arguments to specify queue or message TTL.
"""
if bind_to_routing_key is not None and exchange_name is None:
raise RuntimeError(
f"Routing key for binding was set to {bind_to_routing_key} but no "
f"exchange name was provided."
)

logger.info("Declaring queue %s as %s", queue_name, queue_type)
queue = await self._channel.declare_queue(queue_name, **queue_type.to_argument())
if queue_message_ttl is not None:
ttl_arguments = queue_message_ttl.to_argument()
else:
ttl_arguments = None

logger.info("Declaring queue %s as %s with arguments as %s",
queue_name,
queue_type,
ttl_arguments)
queue = await self._channel.declare_queue(queue_name,
**queue_type.to_argument(),
arguments=ttl_arguments)

if exchange_name is not None:
if exchange_name not in self._exchanges:
Expand All @@ -260,6 +324,7 @@ async def _declare_queue_and_add_subscription(
bind_to_routing_key: Optional[str] = None,
exchange_name: Optional[str] = None,
delete_after_messages: Optional[int] = None,
queue_message_ttl: Optional[QueueMessageTTLArguments] = None
) -> None:
"""Declare an AMQP queue and subscribe to the messages.
Expand All @@ -273,6 +338,7 @@ async def _declare_queue_and_add_subscription(
:param exchange_name: Name of the exchange on which the messages will be published.
:param delete_after_messages: Delete the subscription & queue after this limit of messages
have been successfully processed.
:param queue_message_ttl: Additional arguments to specify queue or message TTL.
"""
if queue_name in self._queue_subscription_consumer_by_name:
logger.error(
Expand All @@ -282,7 +348,7 @@ async def _declare_queue_and_add_subscription(
raise RuntimeError(f"Queue subscription for {queue_name} already exists.")

queue = await self._declare_queue(
queue_name, queue_type, bind_to_routing_key, exchange_name
queue_name, queue_type, bind_to_routing_key, exchange_name, queue_message_ttl
)

queue_consumer = QueueSubscriptionConsumer(
Expand All @@ -296,6 +362,19 @@ async def _declare_queue_and_add_subscription(
)
self._queue_subscription_tasks[queue_name] = queue_subscription_task

async def _queue_exists(self, queue_name: str) -> bool:
"""Check if the queue exists.
:param queue_name: Name of the queue to be checked.
"""
try:
await self._channel.get_queue(queue_name, ensure=True)
logger.info("The %s queue exists", queue_name)
return True
except ChannelClosed as err:
logger.warning(err)
return False

async def _remove_queue_subscription(self, queue_name: str) -> None:
"""Remove subscription from queue and delete the queue if one exists.
Expand Down Expand Up @@ -393,6 +472,7 @@ def declare_queue(
queue_type: AMQPQueueType,
bind_to_routing_key: Optional[str] = None,
exchange_name: Optional[str] = None,
queue_message_ttl: Optional[QueueMessageTTLArguments] = None
) -> None:
"""Declare an AMQP queue.
Expand All @@ -402,13 +482,15 @@ def declare_queue(
key of the queue name. If none, the queue is only bound to the name of the queue.
If not none, then the exchange_name must be set as well.
:param exchange_name: Name of the exchange on which the messages will be published.
:param queue_message_ttl: Additional arguments to specify queue or message TTL.
"""
asyncio.run_coroutine_threadsafe(
self._declare_queue(
queue_name=queue_name,
queue_type=queue_type,
bind_to_routing_key=bind_to_routing_key,
exchange_name=exchange_name,
queue_message_ttl=queue_message_ttl,
),
self._loop,
).result()
Expand All @@ -421,6 +503,7 @@ def declare_queue_and_add_subscription(
bind_to_routing_key: Optional[str] = None,
exchange_name: Optional[str] = None,
delete_after_messages: Optional[int] = None,
queue_message_ttl: Optional[QueueMessageTTLArguments] = None
) -> None:
"""Declare an AMQP queue and subscribe to the messages.
Expand All @@ -433,6 +516,7 @@ def declare_queue_and_add_subscription(
:param exchange_name: Name of the exchange on which the messages will be published.
:param delete_after_messages: Delete the subscription & queue after this limit of messages
have been successfully processed.
:param queue_message_ttl: Additional arguments to specify queue or message TTL.
"""
asyncio.run_coroutine_threadsafe(
self._declare_queue_and_add_subscription(
Expand All @@ -442,10 +526,20 @@ def declare_queue_and_add_subscription(
bind_to_routing_key=bind_to_routing_key,
exchange_name=exchange_name,
delete_after_messages=delete_after_messages,
queue_message_ttl=queue_message_ttl,
),
self._loop,
).result()

def queue_exists(self, queue_name: str) -> bool:
"""Check if the queue exists.
:param queue_name: Name of the queue to be checked.
"""
return asyncio.run_coroutine_threadsafe(
self._queue_exists(queue_name=queue_name), self._loop
).result()

def remove_queue_subscription(self, queue_name: str) -> None:
"""Remove subscription from queue and delete the queue if one exists.
Expand Down
84 changes: 80 additions & 4 deletions src/omotes_sdk/omotes_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from datetime import timedelta
from typing import Callable, Optional, Union

from omotes_sdk.internal.common.broker_interface import BrokerInterface, AMQPQueueType
from omotes_sdk.internal.common.broker_interface import (
BrokerInterface,
AMQPQueueType,
QueueMessageTTLArguments
)
from omotes_sdk.config import RabbitMQConfig
from omotes_sdk_protocol.job_pb2 import (
JobResult,
Expand Down Expand Up @@ -104,6 +108,9 @@ class OmotesInterface:
"""How long the SDK should wait for the first reply when requesting the current workflow
definitions from the orchestrator."""

JOB_RESULT_MESSAGE_TTL: timedelta = timedelta(hours=48)
"""Default value of job result message TTL."""

def __init__(
self,
rabbitmq_config: RabbitMQConfig,
Expand Down Expand Up @@ -174,6 +181,8 @@ def connect_to_submitted_job(
callback_on_progress_update: Optional[Callable[[Job, JobProgressUpdate], None]],
callback_on_status_update: Optional[Callable[[Job, JobStatusUpdate], None]],
auto_disconnect_on_result: bool,
auto_dead_letter_after_ttl: Optional[timedelta] = JOB_RESULT_MESSAGE_TTL,
reconnect: bool = True
) -> None:
"""(Re)connect to the running job.
Expand All @@ -187,14 +196,68 @@ def connect_to_submitted_job(
:param auto_disconnect_on_result: Remove/disconnect from all queues pertaining to this job
once the result is received and handled without exceptions through
`callback_on_finished`.
:param auto_dead_letter_after_ttl: When erroneous situations occur (e.g. client is offline),
the job result message (if available) will be dead lettered after the given TTL,
and all queues of this job will be removed subsequently. Default to 48 hours if unset.
Set to `None` to turn off auto dead letter and clean up, but be aware this may lead to
messages and queues to be stored in RabbitMQ indefinitely
(which uses up memory & disk space).
:param reconnect: When True, first check the job queues status and raise an error if not
exist. Default to True.
"""
job_results_queue_name = OmotesQueueNames.job_results_queue_name(job.id)
job_progress_queue_name = OmotesQueueNames.job_progress_queue_name(job.id)
job_status_queue_name = OmotesQueueNames.job_status_queue_name(job.id)

if reconnect:
logger.info("Reconnect to the submitted job %s is set to True. "
+ "Checking job queues status...", job.id)
if not self.broker_if.queue_exists(job_results_queue_name):
raise RuntimeError(
f"The {job_results_queue_name} queue does not exist or is removed. "
"Abort reconnecting to the queue."
)
if (callback_on_progress_update
and not self.broker_if.queue_exists(job_progress_queue_name)):
raise RuntimeError(
f"The {job_progress_queue_name} queue does not exist or is removed. "
"Abort reconnecting to the queue."
)
if (callback_on_status_update
and not self.broker_if.queue_exists(job_status_queue_name)):
raise RuntimeError(
f"The {job_status_queue_name} queue does not exist or is removed. "
"Abort reconnecting to the queue."
)

if auto_disconnect_on_result:
logger.info("Connecting to update for job %s with auto disconnect on result", job.id)
auto_disconnect_handler = self._autodelete_progres_status_queues_on_result
else:
logger.info("Connecting to update for job %s and expect manual disconnect", job.id)
auto_disconnect_handler = None

# TODO: handle reconnection after the message is dead lettered but queue still exists.

if auto_dead_letter_after_ttl is not None:
message_ttl = auto_dead_letter_after_ttl
queue_ttl = auto_dead_letter_after_ttl * 2
logger.info("Auto dead letter and cleanup on error after TTL is set. "
+ "The leftover job result message will be dead lettered after %s, "
+ "and leftover job queues will be discarded after %s.",
message_ttl, queue_ttl)
job_result_queue_message_ttl = QueueMessageTTLArguments(
queue_ttl=queue_ttl,
message_ttl=message_ttl,
dead_letter_routing_key=OmotesQueueNames.job_result_dead_letter_queue_name(),
dead_letter_exchange=OmotesQueueNames.omotes_exchange_name())
job_progress_status_queue_ttl = QueueMessageTTLArguments(queue_ttl=queue_ttl)
else:
logger.info("Auto dead letter and cleanup on error after TTL is not set. "
+ "Manual cleanup on leftover job queues and messages might be required.")
job_result_queue_message_ttl = None
job_progress_status_queue_ttl = None

callback_handler = JobSubmissionCallbackHandler(
job,
callback_on_finished,
Expand All @@ -204,25 +267,28 @@ def connect_to_submitted_job(
)

self.broker_if.declare_queue_and_add_subscription(
queue_name=OmotesQueueNames.job_results_queue_name(job.id),
queue_name=job_results_queue_name,
callback_on_message=callback_handler.callback_on_finished_wrapped,
queue_type=AMQPQueueType.DURABLE,
exchange_name=OmotesQueueNames.omotes_exchange_name(),
delete_after_messages=1,
queue_message_ttl=job_result_queue_message_ttl
)
if callback_on_progress_update:
self.broker_if.declare_queue_and_add_subscription(
queue_name=OmotesQueueNames.job_progress_queue_name(job.id),
queue_name=job_progress_queue_name,
callback_on_message=callback_handler.callback_on_progress_update_wrapped,
queue_type=AMQPQueueType.DURABLE,
exchange_name=OmotesQueueNames.omotes_exchange_name(),
queue_message_ttl=job_progress_status_queue_ttl
)
if callback_on_status_update:
self.broker_if.declare_queue_and_add_subscription(
queue_name=OmotesQueueNames.job_status_queue_name(job.id),
queue_name=job_status_queue_name,
callback_on_message=callback_handler.callback_on_status_update_wrapped,
queue_type=AMQPQueueType.DURABLE,
exchange_name=OmotesQueueNames.omotes_exchange_name(),
queue_message_ttl=job_progress_status_queue_ttl
)

def submit_job(
Expand All @@ -235,6 +301,7 @@ def submit_job(
callback_on_progress_update: Optional[Callable[[Job, JobProgressUpdate], None]],
callback_on_status_update: Optional[Callable[[Job, JobStatusUpdate], None]],
auto_disconnect_on_result: bool,
auto_dead_letter_after_ttl: Optional[timedelta] = JOB_RESULT_MESSAGE_TTL
) -> Job:
"""Submit a new job and connect to progress and status updates and the job result.
Expand All @@ -252,6 +319,12 @@ def submit_job(
:param auto_disconnect_on_result: Remove/disconnect from all queues pertaining to this job
once the result is received and handled without exceptions through
`callback_on_finished`.
:param auto_dead_letter_after_ttl: When erroneous situations occur (e.g. client is offline),
the job result message (if available) will be dead lettered after the given TTL,
and all queues of this job will be removed subsequently. Default to 48 hours if unset.
Set to `None` to turn off auto dead letter and clean up, but be aware this may lead to
messages and queues to be stored in RabbitMQ indefinitely
(which uses up memory & disk space).
:raises UnknownWorkflowException: If `workflow_type` is unknown as a possible workflow in
this interface.
:return: The job handle which is created. This object needs to be saved persistently by the
Expand All @@ -263,13 +336,16 @@ def submit_job(
raise UnknownWorkflowException()

job = Job(id=uuid.uuid4(), workflow_type=workflow_type)
reconnect = False
logger.info("Submitting job %s", job.id)
self.connect_to_submitted_job(
job,
callback_on_finished,
callback_on_progress_update,
callback_on_status_update,
auto_disconnect_on_result,
auto_dead_letter_after_ttl,
reconnect
)

if job_timeout is not None:
Expand Down
Loading

0 comments on commit 74d86fb

Please sign in to comment.