Skip to content

Commit

Permalink
Allow passing host and port to CluClient
Browse files Browse the repository at this point in the history
  • Loading branch information
albireox committed Sep 12, 2024
1 parent 64c13af commit bb47477
Showing 1 changed file with 38 additions and 13 deletions.
51 changes: 38 additions & 13 deletions src/lvmopstools/clu.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,42 +46,64 @@ class CluClient:
during the life of the worker. The singleton can be cleared by calling
`.clear`.
The host and port for the connection can be passed on initialisation. Otherwise
it will use the values in the environment variables ``RABBITMQ_HOST`` and
``RABBITMQ_PORT`` or the default values in the configuration file.
"""

__initialised: bool = False
__instance: CluClient | None = None

def __new__(cls):
if cls.__instance is None:
def __new__(cls, host: str | None = None, port: int | None = None):
if (
cls.__instance is None
or (host is not None and cls.__instance.host != host)
or (port is not None and cls.__instance.port != port)
):
cls.clear()

cls.__instance = super(CluClient, cls).__new__(cls)
cls.__instance.__initialised = False

return cls.__instance

def __init__(self):
def __init__(self, host: str | None = None, port: int | None = None):
if self.__initialised is True:
# Bail out if we are returning a singleton instance
# which is already initialised.
return

host: str = os.environ.get("RABBITMQ_HOST", config["rabbitmq.host"])
port: int = int(os.environ.get("RABBITMQ_PORT", config["rabbitmq.port"]))
host_default = os.environ.get("RABBITMQ_HOST", config["rabbitmq.host"])
port_default = int(os.environ.get("RABBITMQ_PORT", config["rabbitmq.port"]))

self.host: str = host or host_default
self.port: int = port or port_default

self.client = AMQPClient(host=host, port=port)
self.client = AMQPClient(host=self.host, port=self.port)
self.__initialised = True

self._lock = asyncio.Lock()

def is_connected(self):
"""Is the client connected?"""

connection = self.client.connection
connected = connection.connection and not connection.connection.is_closed
channel_closed = hasattr(connection, "channel") and connection.channel.is_closed

if not connected or channel_closed:
return False

return True

async def __aenter__(self):
# Small delay to allow the event loop to update the
# connection status if needed.
await asyncio.sleep(0.05)

async with self._lock:
connection = self.client.connection
connected = connection.connection and not connection.connection.is_closed
closed = hasattr(connection, "channel") and connection.channel.is_closed

if not connected or closed:
print("reconnecting")
if not self.is_connected():
await self.client.start()

return self.client
Expand All @@ -90,7 +112,7 @@ async def __aexit__(self, exc_type, exc, tb):
pass

async def __anext__(self):
if not self.client.is_connected():
if not self.is_connected():
await self.client.start()

return self.client
Expand All @@ -99,6 +121,9 @@ async def __anext__(self):
def clear(cls):
"""Clears the current instance."""

if cls.__instance and cls.__instance.is_connected():
asyncio.create_task(cls.__instance.client.stop())

cls.__instance = None
cls.__initialised = False

Expand Down

0 comments on commit bb47477

Please sign in to comment.