diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index 214fc18f..407420ab 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -1,21 +1,18 @@ -from typing import Any, Callable, Protocol +from typing import Any, Callable, Pattern, Protocol RpcSubscriber = Callable[['Communicator', Any], Any] BroadcastSubscriber = Callable[['Communicator', Any, Any, Any, Any], Any] -class Communicator(Protocol): - def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: - ... +class Communicator(Protocol): + def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: ... - def add_broadcast_subscriber(self, subscriber: BroadcastSubscriber, identifier=None) -> Any: - ... + def add_broadcast_subscriber( + self, subscriber: BroadcastSubscriber, subject_filter: str | Pattern[str] | None = None, identifier=None + ) -> Any: ... - def remove_rpc_subscriber(self, identifier): - ... + def remove_rpc_subscriber(self, identifier): ... - def remove_broadcast_subscriber(self, identifier): - ... + def remove_broadcast_subscriber(self, identifier): ... - def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: - ... + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 02dd123b..be7c5dd3 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -38,7 +38,6 @@ except ModuleNotFoundError: from contextvars import ContextVar -import kiwipy import yaml from . import events, exceptions, message, persistence, ports, process_states, utils @@ -313,9 +312,9 @@ def init(self) -> None: try: # filter out state change broadcasts - # TODO: pattern filter should be moved to add_broadcast_subscriber. - subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) - identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) + identifier = self._communicator.add_broadcast_subscriber( + self.broadcast_receive, subject_filter=re.compile(r'^(?!state_changed).*'), identifier=str(self.pid) + ) self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) diff --git a/src/plumpy/rmq/communications.py b/src/plumpy/rmq/communications.py index 9dbafbed..6d1f337c 100644 --- a/src/plumpy/rmq/communications.py +++ b/src/plumpy/rmq/communications.py @@ -131,10 +131,10 @@ def remove_task_subscriber(self, identifier: 'ID_TYPE') -> None: return self._communicator.remove_task_subscriber(identifier) def add_broadcast_subscriber( - self, subscriber: 'BroadcastSubscriber', identifier: Optional['ID_TYPE'] = None + self, subscriber: 'BroadcastSubscriber', subject_filter=None, identifier: Optional['ID_TYPE'] = None ) -> 'ID_TYPE': converted = convert_to_comm(subscriber, self._loop) - return self._communicator.add_broadcast_subscriber(converted, identifier) + return self._communicator.add_broadcast_subscriber(converted, subject_filter, identifier) def remove_broadcast_subscriber(self, identifier: 'ID_TYPE') -> None: return self._communicator.remove_broadcast_subscriber(identifier) diff --git a/tests/rmq/test_communications.py b/tests/rmq/test_communications.py index 63813bdc..00b7f1c6 100644 --- a/tests/rmq/test_communications.py +++ b/tests/rmq/test_communications.py @@ -56,7 +56,7 @@ def test_add_broadcast_subscriber(loop_communicator, subscriber): assert loop_communicator.add_broadcast_subscriber(subscriber) is not None identifier = 'identifier' - assert loop_communicator.add_broadcast_subscriber(subscriber, identifier) == identifier + assert loop_communicator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier def test_remove_broadcast_subscriber(loop_communicator, subscriber): diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index 26c9a852..80c1ac71 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -7,6 +7,7 @@ import tempfile import uuid +from kiwipy.rmq.communicator import kiwipy import pytest import shortuuid import yaml @@ -81,7 +82,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): assert result == BROADCAST @pytest.mark.asyncio - async def test_broadcast_filter(self, loop_communicator): + async def test_broadcast_filter(self, loop_communicator: kiwipy.Communicator): broadcast_future = asyncio.Future() def ignore_broadcast(_comm, body, sender, subject, correlation_id): @@ -90,7 +91,7 @@ def ignore_broadcast(_comm, body, sender, subject, correlation_id): def get_broadcast(_comm, body, sender, subject, correlation_id): broadcast_future.set_result(True) - loop_communicator.add_broadcast_subscriber(BroadcastFilter(ignore_broadcast, subject='other')) + loop_communicator.add_broadcast_subscriber(ignore_broadcast, subject_filter='other') loop_communicator.add_broadcast_subscriber(get_broadcast) loop_communicator.broadcast_send( **{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420}