From 908a7eb7d38299dbbdb51cdda170ecffe068c354 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Mon, 8 May 2023 11:05:21 +0200 Subject: [PATCH 1/2] Add stop methods for subscriber and publisher, plus some refactoring --- posttroll/ns.py | 2 +- posttroll/publisher.py | 63 +++++++++++++++++++------------- posttroll/subscriber.py | 67 +++++++++++++++++++--------------- posttroll/tests/test_pubsub.py | 53 ++++++++++----------------- 4 files changed, 95 insertions(+), 90 deletions(-) diff --git a/posttroll/ns.py b/posttroll/ns.py index 5baf54a..12a1bd0 100644 --- a/posttroll/ns.py +++ b/posttroll/ns.py @@ -113,7 +113,7 @@ def get_active_address(name, arec): return Message("/oper/ns", "info", "") -class NameServer(object): +class NameServer: """The name server.""" def __init__(self, max_age=timedelta(minutes=10), multicast_enabled=True, restrict_to_localhost=False): diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 05ffd21..befa5e3 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -95,53 +95,60 @@ def __init__(self, address, name="", min_port=None, max_port=None): """Bind the publisher class to a port.""" self.name = name self.destination = address - self.publish = get_context().socket(zmq.PUB) - _set_tcp_keepalive(self.publish) - + self.publish_socket = None # Limit port range or use the defaults when no port is defined # by the user - min_port = min_port or int(config.get('pub_min_port', 49152)) - max_port = max_port or int(config.get('pub_max_port', 65536)) + self.min_port = min_port or int(config.get('pub_min_port', 49152)) + self.max_port = max_port or int(config.get('pub_max_port', 65536)) + self.port_number = None + + # Initialize no heartbeat + self._heartbeat = None + self._pub_lock = Lock() + + + + def start(self): + """Start the publisher. + """ + self.publish_socket = get_context().socket(zmq.PUB) + _set_tcp_keepalive(self.publish_socket) + + self.bind() + LOGGER.info("publisher started on port %s", str(self.port_number)) + return self + def bind(self): # Check for port 0 (random port) u__ = urlsplit(self.destination) port = u__.port if port == 0: dest = urlunsplit((u__.scheme, u__.hostname, u__.path, u__.query, u__.fragment)) - self.port_number = self.publish.bind_to_random_port( + self.port_number = self.publish_socket.bind_to_random_port( dest, - min_port=min_port, - max_port=max_port) + min_port=self.min_port, + max_port=self.max_port) netloc = u__.hostname + ":" + str(self.port_number) self.destination = urlunsplit((u__.scheme, netloc, u__.path, u__.query, u__.fragment)) else: - self.publish.bind(self.destination) + self.publish_socket.bind(self.destination) self.port_number = port - LOGGER.info("publisher started on port %s", str(self.port_number)) - - # Initialize no heartbeat - self._heartbeat = None - self._pub_lock = Lock() - - def start(self): - """Start the publisher. - - Actually just returns *self*, but needed for consistent use from context manager. - """ - return self - def send(self, msg): """Send the given message.""" with self._pub_lock: - self.publish.send_string(msg) + self.publish_socket.send_string(msg) def stop(self): """Stop the publisher.""" - self.publish.setsockopt(zmq.LINGER, 1) - self.publish.close() + self.publish_socket.setsockopt(zmq.LINGER, 1) + self.publish_socket.close() + + def close(self): + """Alias for stop.""" + self.stop() def heartbeat(self, min_interval=0): """Send a heartbeat ... but only if *min_interval* seconds has passed since last beat.""" @@ -210,7 +217,7 @@ def start(self): pub_addr = _get_publish_address(self._port) self._publisher = self._publisher_class(pub_addr, self._name, min_port=self.min_port, - max_port=self.max_port) + max_port=self.max_port).start() LOGGER.debug("entering publish %s", str(self._publisher.destination)) addr = _get_publish_address(self._publisher.port_number, str(get_own_ip())) self._broadcaster = sendaddressservice(self._name, addr, @@ -233,6 +240,10 @@ def stop(self): self._broadcaster.stop() self._broadcaster = None + def close(self): + """Alias for stop.""" + self.stop() + def _get_publish_address(port, ip_address="*"): return "tcp://" + ip_address + ":" + str(port) diff --git a/posttroll/subscriber.py b/posttroll/subscriber.py index c63bec5..7ea5d80 100644 --- a/posttroll/subscriber.py +++ b/posttroll/subscriber.py @@ -75,7 +75,6 @@ def __init__(self, addresses, topics='', message_filter=None, translate=False): self.sub_addr = {} self.addr_sub = {} - self.poller = None self._hooks = [] self._hooks_cb = {} @@ -85,7 +84,7 @@ def __init__(self, addresses, topics='', message_filter=None, translate=False): self.update(addresses) - self._loop = True + self._loop = None def add(self, address, topics=None): """Add *address* to the subscribing list for *topics*. @@ -94,21 +93,25 @@ def add(self, address, topics=None): """ with self._lock: if address in self.addresses: - return False + return topics = self._magickfy_topics(topics) or self._topics LOGGER.info("Subscriber adding address %s with topics %s", str(address), str(topics)) - subscriber = get_context().socket(SUB) - _set_tcp_keepalive(subscriber) - for t__ in topics: - subscriber.setsockopt_string(SUBSCRIBE, str(t__)) - subscriber.connect(address) + subscriber = self._add_sub_socket(address, topics) self.sub_addr[subscriber] = address self.addr_sub[address] = subscriber - if self.poller: - self.poller.register(subscriber, POLLIN) - return True + + def _add_sub_socket(self, address, topics): + subscriber = get_context().socket(SUB) + _set_tcp_keepalive(subscriber) + for t__ in topics: + subscriber.setsockopt_string(SUBSCRIBE, str(t__)) + subscriber.connect(address) + + if self.poller: + self.poller.register(subscriber, POLLIN) + return subscriber def remove(self, address): """Remove *address* from the subscribing list for *topics*.""" @@ -116,26 +119,29 @@ def remove(self, address): try: subscriber = self.addr_sub[address] except KeyError: - return False + return LOGGER.info("Subscriber removing address %s", str(address)) - if self.poller: - self.poller.unregister(subscriber) del self.addr_sub[address] del self.sub_addr[subscriber] - subscriber.close() - return True + self._remove_sub_socket(subscriber) + + def _remove_sub_socket(self, subscriber): + if self.poller: + self.poller.unregister(subscriber) + subscriber.close() def update(self, addresses): """Update with a set of addresses.""" if isinstance(addresses, str): addresses = [addresses, ] - s0_, s1_ = set(self.addresses), set(addresses) - sr_, sa_ = s0_.difference(s1_), s1_.difference(s0_) - for a__ in sr_: - self.remove(a__) - for a__ in sa_: - self.add(a__) - return bool(sr_ or sa_) + current_addresses, new_addresses = set(self.addresses), set(addresses) + addresses_to_remove = current_addresses.difference(new_addresses) + addresses_to_add = new_addresses.difference(current_addresses) + for addr in addresses_to_remove: + self.remove(addr) + for addr in addresses_to_add: + self.add(addr) + return bool(addresses_to_remove or addresses_to_add) def add_hook_sub(self, address, topics, callback): """Specify a SUB *callback* in the same stream (thread) as the main receive loop. @@ -146,12 +152,10 @@ def add_hook_sub(self, address, topics, callback): Good for operations, which is required to be done in the same thread as the main recieve loop (e.q operations on the underlying sockets). """ + topics = self._magickfy_topics(topics) LOGGER.info("Subscriber adding SUB hook %s for topics %s", str(address), str(topics)) - socket = get_context().socket(SUB) - for t__ in self._magickfy_topics(topics): - socket.setsockopt_string(SUBSCRIBE, str(t__)) - socket.connect(address) + socket = self._add_sub_socket(address, topics) self._add_hook(socket, callback) def add_hook_pull(self, address, callback): @@ -163,14 +167,15 @@ def add_hook_pull(self, address, callback): LOGGER.info("Subscriber adding PULL hook %s", str(address)) socket = get_context().socket(PULL) socket.connect(address) + if self.poller: + self.poller.register(socket, POLLIN) self._add_hook(socket, callback) def _add_hook(self, socket, callback): """Add a generic hook. The passed socket has to be "receive only".""" self._hooks.append(socket) self._hooks_cb[socket] = callback - if self.poller: - self.poller.register(socket, POLLIN) + @property def addresses(self): @@ -346,6 +351,10 @@ def stop(self): self._subscriber.close() self._subscriber = None + def close(self): + """Alias for stop.""" + return self.stop() + class Subscribe: """Subscriber context. diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 67bdd47..2da79b9 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -63,23 +63,23 @@ def test_pub_addresses(self): with Publish(str("data_provider"), 0, ["this_data"], broadcast_interval=0.1): time.sleep(.3) res = get_pub_addresses(["this_data"], timeout=.5) - self.assertEqual(len(res), 1) + assert len(res) == 1 expected = {u'status': True, u'service': [u'data_provider', u'this_data'], u'name': u'address'} for key, val in expected.items(): - self.assertEqual(res[0][key], val) - self.assertTrue("receive_time" in res[0]) - self.assertTrue("URI" in res[0]) + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] res = get_pub_addresses([str("data_provider")]) - self.assertEqual(len(res), 1) + assert len(res) == 1 expected = {u'status': True, u'service': [u'data_provider', u'this_data'], u'name': u'address'} for key, val in expected.items(): - self.assertEqual(res[0][key], val) - self.assertTrue("receive_time" in res[0]) - self.assertTrue("URI" in res[0]) + assert res[0][key] == val + assert "receive_time" in res[0] + assert "URI" in res[0] def test_pub_sub_ctx(self): """Test publish and subscribe.""" @@ -95,10 +95,10 @@ def test_pub_sub_ctx(self): time.sleep(1) msg = next(sub.recv(2)) if msg is not None: - self.assertEqual(str(msg), str(message)) + assert str(msg) == str(message) tested = True sub.close() - self.assertTrue(tested) + assert tested def test_pub_sub_add_rm(self): """Test adding and removing publishers.""" @@ -107,21 +107,21 @@ def test_pub_sub_add_rm(self): time.sleep(4) with Subscribe("this_data", "counter", True) as sub: - self.assertEqual(len(sub.sub_addr), 0) + assert len(sub.sub_addr) == 0 with Publish("data_provider", 0, ["this_data"]): time.sleep(4) next(sub.recv(2)) - self.assertEqual(len(sub.sub_addr), 1) + assert len(sub.sub_addr) == 1 time.sleep(3) for msg in sub.recv(2): if msg is None: break time.sleep(3) - self.assertEqual(len(sub.sub_addr), 0) + assert len(sub.sub_addr) == 0 with Publish("data_provider_2", 0, ["another_data"]): time.sleep(4) next(sub.recv(2)) - self.assertEqual(len(sub.sub_addr), 0) + assert len(sub.sub_addr) == 0 sub.close() @@ -246,7 +246,7 @@ def test_pub_suber(self): from posttroll.subscriber import Subscriber pub_address = "tcp://" + str(get_own_ip()) + ":0" - pub = Publisher(pub_address) + pub = Publisher(pub_address).start() addr = pub_address[:-1] + str(pub.port_number) sub = Subscriber([addr], '/counter') tested = False @@ -257,9 +257,9 @@ def test_pub_suber(self): msg = next(sub.recv(2)) if msg is not None: - self.assertEqual(str(msg), str(message)) + assert str(msg) == str(message) tested = True - self.assertTrue(tested) + assert tested pub.stop() def test_pub_sub_ctx_no_nameserver(self): @@ -669,7 +669,7 @@ def test_publisher_tcp_keepalive(tcp_keepalive_settings): get_context.return_value.socket.return_value = socket from posttroll.publisher import Publisher - _ = Publisher("tcp://127.0.0.1:9000") + _ = Publisher("tcp://127.0.0.1:9000").start() _assert_tcp_keepalive(socket) @@ -681,7 +681,7 @@ def test_publisher_tcp_keepalive_not_set(tcp_keepalive_no_settings): get_context.return_value.socket.return_value = socket from posttroll.publisher import Publisher - _ = Publisher("tcp://127.0.0.1:9000") + _ = Publisher("tcp://127.0.0.1:9000").start() _assert_no_tcp_keepalive(socket) @@ -720,18 +720,3 @@ def _assert_tcp_keepalive(socket): def _assert_no_tcp_keepalive(socket): assert "TCP_KEEPALIVE" not in str(socket.setsockopt.mock_calls) - - -def suite(): - """Collect the test suite for publisher and subsciber tests.""" - loader = unittest.TestLoader() - mysuite = unittest.TestSuite() - mysuite.addTest(loader.loadTestsFromTestCase(TestPubSub)) - mysuite.addTest(loader.loadTestsFromTestCase(TestNS)) - mysuite.addTest(loader.loadTestsFromTestCase(TestNSWithoutMulticasting)) - mysuite.addTest(loader.loadTestsFromTestCase(TestListenerContainer)) - mysuite.addTest(loader.loadTestsFromTestCase(TestPub)) - mysuite.addTest(loader.loadTestsFromTestCase(TestAddressReceiver)) - mysuite.addTest(loader.loadTestsFromTestCase(TestPublisherDictConfig)) - - return mysuite From ddfa066de1c876111c8682f2f43d93fc5326b997 Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Mon, 8 May 2023 11:06:40 +0200 Subject: [PATCH 2/2] Add two testing utilities --- posttroll/testing.py | 24 ++++++++++++++++++++++++ posttroll/tests/test_testing.py | 11 +++++++++++ 2 files changed, 35 insertions(+) create mode 100644 posttroll/testing.py create mode 100644 posttroll/tests/test_testing.py diff --git a/posttroll/testing.py b/posttroll/testing.py new file mode 100644 index 0000000..501aa53 --- /dev/null +++ b/posttroll/testing.py @@ -0,0 +1,24 @@ +"""Testing utilities.""" +from contextlib import contextmanager + +@contextmanager +def patched_subscriber_recv(messages): + """Patch the Subscriber object to return given messages.""" + from unittest import mock + with mock.patch("posttroll.subscriber.Subscriber.recv", mock.Mock(return_value=messages)): + yield + +@contextmanager +def patched_publisher(): + """Patch the Subscriber object to return given messages.""" + from unittest import mock + published = [] + + def fake_send(self, message): + published.append(message) + + def noop(self, *args, **kwargs): + pass + + with mock.patch.multiple("posttroll.publisher.Publisher", send=fake_send, start=noop, stop=noop): + yield published diff --git a/posttroll/tests/test_testing.py b/posttroll/tests/test_testing.py new file mode 100644 index 0000000..c8909ff --- /dev/null +++ b/posttroll/tests/test_testing.py @@ -0,0 +1,11 @@ +from posttroll.testing import patched_publisher + +def test_fake_publisher(): + from posttroll.publisher import create_publisher_from_dict_config + + with patched_publisher() as messages: + pub = create_publisher_from_dict_config(dict(port=1979, nameservers=False)) + pub.start() + pub.send("bla") + pub.stop() + assert "bla" in messages