diff --git a/src/lvmopstools/clu.py b/src/lvmopstools/clu.py index cacab94..5452b60 100644 --- a/src/lvmopstools/clu.py +++ b/src/lvmopstools/clu.py @@ -46,42 +46,64 @@ class CluClient: during the life of the worker. The singleton can be cleared by calling `.clear`. + The host and port for the connection can be passed on initialisation. Otherwise + it will use the values in the environment variables ``RABBITMQ_HOST`` and + ``RABBITMQ_PORT`` or the default values in the configuration file. + """ __initialised: bool = False __instance: CluClient | None = None - def __new__(cls): - if cls.__instance is None: + def __new__(cls, host: str | None = None, port: int | None = None): + if ( + cls.__instance is None + or (host is not None and cls.__instance.host != host) + or (port is not None and cls.__instance.port != port) + ): + cls.clear() + cls.__instance = super(CluClient, cls).__new__(cls) cls.__instance.__initialised = False return cls.__instance - def __init__(self): + def __init__(self, host: str | None = None, port: int | None = None): if self.__initialised is True: + # Bail out if we are returning a singleton instance + # which is already initialised. return - host: str = os.environ.get("RABBITMQ_HOST", config["rabbitmq.host"]) - port: int = int(os.environ.get("RABBITMQ_PORT", config["rabbitmq.port"])) + host_default = os.environ.get("RABBITMQ_HOST", config["rabbitmq.host"]) + port_default = int(os.environ.get("RABBITMQ_PORT", config["rabbitmq.port"])) + + self.host: str = host or host_default + self.port: int = port or port_default - self.client = AMQPClient(host=host, port=port) + self.client = AMQPClient(host=self.host, port=self.port) self.__initialised = True self._lock = asyncio.Lock() + def is_connected(self): + """Is the client connected?""" + + connection = self.client.connection + connected = connection.connection and not connection.connection.is_closed + channel_closed = hasattr(connection, "channel") and connection.channel.is_closed + + if not connected or channel_closed: + return False + + return True + async def __aenter__(self): # Small delay to allow the event loop to update the # connection status if needed. await asyncio.sleep(0.05) async with self._lock: - connection = self.client.connection - connected = connection.connection and not connection.connection.is_closed - closed = hasattr(connection, "channel") and connection.channel.is_closed - - if not connected or closed: - print("reconnecting") + if not self.is_connected(): await self.client.start() return self.client @@ -90,7 +112,7 @@ async def __aexit__(self, exc_type, exc, tb): pass async def __anext__(self): - if not self.client.is_connected(): + if not self.is_connected(): await self.client.start() return self.client @@ -99,6 +121,9 @@ async def __anext__(self): def clear(cls): """Clears the current instance.""" + if cls.__instance and cls.__instance.is_connected(): + asyncio.create_task(cls.__instance.client.stop()) + cls.__instance = None cls.__initialised = False