Skip to content

Commit

Permalink
Merge pull request #53 from mraspaud/feature-testing-utilities
Browse files Browse the repository at this point in the history
Add close methods for subscriber and publisher and testing utilities
  • Loading branch information
mraspaud authored May 10, 2023
2 parents 14d84af + ddfa066 commit b210496
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 90 deletions.
2 changes: 1 addition & 1 deletion posttroll/ns.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_active_address(name, arec):
return Message("/oper/ns", "info", "")


class NameServer(object):
class NameServer:
"""The name server."""

def __init__(self, max_age=timedelta(minutes=10), multicast_enabled=True, restrict_to_localhost=False):
Expand Down
63 changes: 37 additions & 26 deletions posttroll/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,53 +95,60 @@ def __init__(self, address, name="", min_port=None, max_port=None):
"""Bind the publisher class to a port."""
self.name = name
self.destination = address
self.publish = get_context().socket(zmq.PUB)
_set_tcp_keepalive(self.publish)

self.publish_socket = None
# Limit port range or use the defaults when no port is defined
# by the user
min_port = min_port or int(config.get('pub_min_port', 49152))
max_port = max_port or int(config.get('pub_max_port', 65536))
self.min_port = min_port or int(config.get('pub_min_port', 49152))
self.max_port = max_port or int(config.get('pub_max_port', 65536))
self.port_number = None

# Initialize no heartbeat
self._heartbeat = None
self._pub_lock = Lock()



def start(self):
"""Start the publisher.
"""
self.publish_socket = get_context().socket(zmq.PUB)
_set_tcp_keepalive(self.publish_socket)

self.bind()
LOGGER.info("publisher started on port %s", str(self.port_number))
return self

def bind(self):
# Check for port 0 (random port)
u__ = urlsplit(self.destination)
port = u__.port
if port == 0:
dest = urlunsplit((u__.scheme, u__.hostname,
u__.path, u__.query, u__.fragment))
self.port_number = self.publish.bind_to_random_port(
self.port_number = self.publish_socket.bind_to_random_port(
dest,
min_port=min_port,
max_port=max_port)
min_port=self.min_port,
max_port=self.max_port)
netloc = u__.hostname + ":" + str(self.port_number)
self.destination = urlunsplit((u__.scheme, netloc, u__.path,
u__.query, u__.fragment))
else:
self.publish.bind(self.destination)
self.publish_socket.bind(self.destination)
self.port_number = port

LOGGER.info("publisher started on port %s", str(self.port_number))

# Initialize no heartbeat
self._heartbeat = None
self._pub_lock = Lock()

def start(self):
"""Start the publisher.
Actually just returns *self*, but needed for consistent use from context manager.
"""
return self

def send(self, msg):
"""Send the given message."""
with self._pub_lock:
self.publish.send_string(msg)
self.publish_socket.send_string(msg)

def stop(self):
"""Stop the publisher."""
self.publish.setsockopt(zmq.LINGER, 1)
self.publish.close()
self.publish_socket.setsockopt(zmq.LINGER, 1)
self.publish_socket.close()

def close(self):
"""Alias for stop."""
self.stop()

def heartbeat(self, min_interval=0):
"""Send a heartbeat ... but only if *min_interval* seconds has passed since last beat."""
Expand Down Expand Up @@ -210,7 +217,7 @@ def start(self):
pub_addr = _get_publish_address(self._port)
self._publisher = self._publisher_class(pub_addr, self._name,
min_port=self.min_port,
max_port=self.max_port)
max_port=self.max_port).start()
LOGGER.debug("entering publish %s", str(self._publisher.destination))
addr = _get_publish_address(self._publisher.port_number, str(get_own_ip()))
self._broadcaster = sendaddressservice(self._name, addr,
Expand All @@ -233,6 +240,10 @@ def stop(self):
self._broadcaster.stop()
self._broadcaster = None

def close(self):
"""Alias for stop."""
self.stop()


def _get_publish_address(port, ip_address="*"):
return "tcp://" + ip_address + ":" + str(port)
Expand Down
67 changes: 38 additions & 29 deletions posttroll/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def __init__(self, addresses, topics='', message_filter=None, translate=False):

self.sub_addr = {}
self.addr_sub = {}
self.poller = None

self._hooks = []
self._hooks_cb = {}
Expand All @@ -85,7 +84,7 @@ def __init__(self, addresses, topics='', message_filter=None, translate=False):

self.update(addresses)

self._loop = True
self._loop = None

def add(self, address, topics=None):
"""Add *address* to the subscribing list for *topics*.
Expand All @@ -94,48 +93,55 @@ def add(self, address, topics=None):
"""
with self._lock:
if address in self.addresses:
return False
return

topics = self._magickfy_topics(topics) or self._topics
LOGGER.info("Subscriber adding address %s with topics %s",
str(address), str(topics))
subscriber = get_context().socket(SUB)
_set_tcp_keepalive(subscriber)
for t__ in topics:
subscriber.setsockopt_string(SUBSCRIBE, str(t__))
subscriber.connect(address)
subscriber = self._add_sub_socket(address, topics)
self.sub_addr[subscriber] = address
self.addr_sub[address] = subscriber
if self.poller:
self.poller.register(subscriber, POLLIN)
return True

def _add_sub_socket(self, address, topics):
subscriber = get_context().socket(SUB)
_set_tcp_keepalive(subscriber)
for t__ in topics:
subscriber.setsockopt_string(SUBSCRIBE, str(t__))
subscriber.connect(address)

if self.poller:
self.poller.register(subscriber, POLLIN)
return subscriber

def remove(self, address):
"""Remove *address* from the subscribing list for *topics*."""
with self._lock:
try:
subscriber = self.addr_sub[address]
except KeyError:
return False
return
LOGGER.info("Subscriber removing address %s", str(address))
if self.poller:
self.poller.unregister(subscriber)
del self.addr_sub[address]
del self.sub_addr[subscriber]
subscriber.close()
return True
self._remove_sub_socket(subscriber)

def _remove_sub_socket(self, subscriber):
if self.poller:
self.poller.unregister(subscriber)
subscriber.close()

def update(self, addresses):
"""Update with a set of addresses."""
if isinstance(addresses, str):
addresses = [addresses, ]
s0_, s1_ = set(self.addresses), set(addresses)
sr_, sa_ = s0_.difference(s1_), s1_.difference(s0_)
for a__ in sr_:
self.remove(a__)
for a__ in sa_:
self.add(a__)
return bool(sr_ or sa_)
current_addresses, new_addresses = set(self.addresses), set(addresses)
addresses_to_remove = current_addresses.difference(new_addresses)
addresses_to_add = new_addresses.difference(current_addresses)
for addr in addresses_to_remove:
self.remove(addr)
for addr in addresses_to_add:
self.add(addr)
return bool(addresses_to_remove or addresses_to_add)

def add_hook_sub(self, address, topics, callback):
"""Specify a SUB *callback* in the same stream (thread) as the main receive loop.
Expand All @@ -146,12 +152,10 @@ def add_hook_sub(self, address, topics, callback):
Good for operations, which is required to be done in the same thread as
the main recieve loop (e.q operations on the underlying sockets).
"""
topics = self._magickfy_topics(topics)
LOGGER.info("Subscriber adding SUB hook %s for topics %s",
str(address), str(topics))
socket = get_context().socket(SUB)
for t__ in self._magickfy_topics(topics):
socket.setsockopt_string(SUBSCRIBE, str(t__))
socket.connect(address)
socket = self._add_sub_socket(address, topics)
self._add_hook(socket, callback)

def add_hook_pull(self, address, callback):
Expand All @@ -163,14 +167,15 @@ def add_hook_pull(self, address, callback):
LOGGER.info("Subscriber adding PULL hook %s", str(address))
socket = get_context().socket(PULL)
socket.connect(address)
if self.poller:
self.poller.register(socket, POLLIN)
self._add_hook(socket, callback)

def _add_hook(self, socket, callback):
"""Add a generic hook. The passed socket has to be "receive only"."""
self._hooks.append(socket)
self._hooks_cb[socket] = callback
if self.poller:
self.poller.register(socket, POLLIN)


@property
def addresses(self):
Expand Down Expand Up @@ -346,6 +351,10 @@ def stop(self):
self._subscriber.close()
self._subscriber = None

def close(self):
"""Alias for stop."""
return self.stop()


class Subscribe:
"""Subscriber context.
Expand Down
24 changes: 24 additions & 0 deletions posttroll/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Testing utilities."""
from contextlib import contextmanager

@contextmanager
def patched_subscriber_recv(messages):
"""Patch the Subscriber object to return given messages."""
from unittest import mock
with mock.patch("posttroll.subscriber.Subscriber.recv", mock.Mock(return_value=messages)):
yield

@contextmanager
def patched_publisher():
"""Patch the Subscriber object to return given messages."""
from unittest import mock
published = []

def fake_send(self, message):
published.append(message)

def noop(self, *args, **kwargs):
pass

with mock.patch.multiple("posttroll.publisher.Publisher", send=fake_send, start=noop, stop=noop):
yield published
Loading

0 comments on commit b210496

Please sign in to comment.