diff --git a/examples/broadcast_exec/server.py b/examples/broadcast_exec/server.py index 0ee77a5..5a76ede 100644 --- a/examples/broadcast_exec/server.py +++ b/examples/broadcast_exec/server.py @@ -44,7 +44,7 @@ async def on_message(self, event: GuildMessageCreateEvent) -> None: return if event.content.startswith("!exec"): await self.cluster.ipc.send_command( - self.cluster.ipc.cluster_uids, + self.cluster.ipc.clusters, "exec_code", {"code": event.content[6:]}, ) diff --git a/hikari_clusters/__init__.py b/hikari_clusters/__init__.py index 43fe510..6ec6e65 100644 --- a/hikari_clusters/__init__.py +++ b/hikari_clusters/__init__.py @@ -28,14 +28,19 @@ from importlib.metadata import version +from hikari.internal.ux import init_logging + from . import close_codes, commands, events, exceptions, payload +from .base_client import BaseClient from .brain import Brain from .cluster import Cluster, ClusterLauncher -from .info_classes import ClusterInfo, ServerInfo +from .info_classes import BaseInfo, BrainInfo, ClusterInfo, ServerInfo from .ipc_client import IpcClient from .ipc_server import IpcServer from .server import Server +init_logging("INFO", True, False) + __version__ = version(__name__) __all__ = ( @@ -45,8 +50,11 @@ "Cluster", "ClusterLauncher", "Server", + "BaseClient", "ClusterInfo", "ServerInfo", + "BrainInfo", + "BaseInfo", "payload", "events", "commands", diff --git a/hikari_clusters/base_client.py b/hikari_clusters/base_client.py new file mode 100644 index 0000000..a144bf4 --- /dev/null +++ b/hikari_clusters/base_client.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import asyncio +import logging +import pathlib + +from websockets.exceptions import ConnectionClosed + +from .info_classes import BaseInfo +from .ipc_client import IpcClient +from .task_manager import TaskManager + +_LOG = logging.getLogger(__name__) + + +class BaseClient: + """The base client, which contains an IpcClient. 
+ + Parameters + ---------- + ipc_uri : str + The URI of the brain. + token : str + The token for the IPC server. + reconnect : bool + Whether to automatically reconnect if the connection + is lost. Defaults to True. + certificate_path : pathlib.Path | str | None + The path to your certificate, which allows for secure + connection over the IPC. Defaults to None. + """ + + def __init__( + self, + ipc_uri: str, + token: str, + reconnect: bool = True, + certificate_path: pathlib.Path | str | None = None, + ): + if isinstance(certificate_path, str): + certificate_path = pathlib.Path(certificate_path) + + self.tasks = TaskManager() + self.ipc = IpcClient( + uri=ipc_uri, + token=token, + reconnect=reconnect, + certificate_path=certificate_path, + ) + + self.stop_future: asyncio.Future[None] | None = None + + def get_info(self) -> BaseInfo: + """Get the info class for this client. + + Returns: + BaseInfo: The info class. + """ + + raise NotImplementedError + + async def start(self) -> None: + """Start the client. + + Connects to the IPC server and begins sending out this client's + info. 
+ """ + + if self.stop_future is None: + self.stop_future = asyncio.Future() + + await self.ipc.start() + + self.tasks.create_task(self._broadcast_info_loop()) + + async def join(self) -> None: + """Wait until the client begins exiting.""" + + assert self.stop_future and self.ipc.stop_future + + await asyncio.wait( + [self.stop_future, self.ipc.stop_future], + return_when=asyncio.FIRST_COMPLETED, + ) + + async def close(self) -> None: + """Shut down the client.""" + + self.ipc.stop() + await self.ipc.close() + + self.tasks.cancel_all() + await self.tasks.wait_for_all() + + def stop(self) -> None: + """Tell the client to stop.""" + + assert self.stop_future + self.stop_future.set_result(None) + + async def _broadcast_info_loop(self) -> None: + while True: + await self.ipc.wait_until_ready() + assert self.ipc.uid + try: + await self.ipc.send_event( + self.ipc.client_uids, + "set_info_class", + self.get_info().asdict(), + ) + except ConnectionClosed: + _LOG.error("Failed to send client info.", exc_info=True) + await asyncio.sleep(1) diff --git a/hikari_clusters/brain.py b/hikari_clusters/brain.py index 0693775..84a4d13 100644 --- a/hikari_clusters/brain.py +++ b/hikari_clusters/brain.py @@ -27,18 +27,18 @@ import signal from typing import Any -from . import log, payload +from hikari_clusters.base_client import BaseClient +from hikari_clusters.info_classes import BrainInfo + +from . import payload from .events import EventGroup from .ipc_client import IpcClient from .ipc_server import IpcServer -from .task_manager import TaskManager __all__ = ("Brain",) -LOG = log.Logger("Brain") - -class Brain: +class Brain(BaseClient): """The brain of the bot. 
Allows for comunication between clusters and servers, @@ -73,29 +73,30 @@ def __init__( shards_per_cluster: int, certificate_path: pathlib.Path | str | None = None, ) -> None: - self.tasks = TaskManager(LOG) + certificate_path = ( + pathlib.Path(certificate_path) + if isinstance(certificate_path, str) + else certificate_path + ) + + super().__init__( + IpcClient.get_uri(host, port, certificate_path is not None), + token, + True, + certificate_path, + ) self.total_servers = total_servers self.cluster_per_server = clusters_per_server self.shards_per_cluster = shards_per_cluster - if isinstance(certificate_path, str): - certificate_path = pathlib.Path(certificate_path) - self.server = IpcServer( host, port, token, certificate_path=certificate_path ) - self.ipc = IpcClient( - IpcClient.get_uri(host, port, certificate_path is not None), - token, - LOG, - certificate_path=certificate_path, - cmd_kwargs={"brain": self}, - event_kwargs={"brain": self}, - ) - self.ipc.events.include(_E) - self.stop_future: asyncio.Future[None] | None = None + self.ipc.commands.cmd_kwargs["brain"] = self + self.ipc.events.event_kwargs["brain"] = self + self.ipc.events.include(_E) self._waiting_for: tuple[int, int] | None = None @@ -128,7 +129,7 @@ def waiting_for(self) -> tuple[int, int] | None: if self._waiting_for is not None: server_uid, smallest_shard = self._waiting_for if ( - server_uid not in self.ipc.server_uids + server_uid not in self.ipc.servers or smallest_shard in self.ipc.all_shards() ): # `server_uid not in self.ipc.server_uids` @@ -148,6 +149,11 @@ def waiting_for(self) -> tuple[int, int] | None: def waiting_for(self, value: tuple[int, int] | None) -> None: self._waiting_for = value + def get_info(self) -> BrainInfo: + # <<>> + assert self.ipc.uid + return BrainInfo(uid=self.ipc.uid) + def run(self) -> None: """Run the brain, wait for the brain to stop, then cleanup.""" @@ -161,42 +167,26 @@ def sigstop(*args: Any, **kwargs: Any) -> None: loop.run_until_complete(self.close()) 
async def start(self) -> None: - """Start the brain. - - Returns as soon as all tasks have started. - """ + # <<>> self.stop_future = asyncio.Future() - self.tasks.create_task(self._send_brain_uid_loop()) - self.tasks.create_task(self._main_loop()) - await self.server.start() - await self.ipc.start() - - async def join(self) -> None: - """Wait for the brain to stop.""" - assert self.stop_future - await self.stop_future + await self.server.start() + await super().start() + self.tasks.create_task(self._main_loop()) async def close(self) -> None: - """Shut the brain down.""" + # <<>> + self.ipc.stop() + await self.ipc.close() self.server.stop() await self.server.close() - self.ipc.stop() - await self.ipc.close() - self.tasks.cancel_all() await self.tasks.wait_for_all() - def stop(self) -> None: - """Tell the brain to stop.""" - - assert self.stop_future - self.stop_future.set_result(None) - def _get_next_cluster_to_launch(self) -> tuple[int, list[int]] | None: - if len(self.ipc.server_uids) == 0: + if len(self.ipc.servers) == 0: return None if not all(c.ready for c in self.ipc.clusters.values()): @@ -219,14 +209,6 @@ def _get_next_cluster_to_launch(self) -> tuple[int, list[int]] | None: return s.uid, list(shards_to_launch)[: self.shards_per_cluster] - async def _send_brain_uid_loop(self) -> None: - while True: - await self.ipc.wait_until_ready() - await self.ipc.send_event( - self.ipc.client_uids, "set_brain_uid", {"uid": self.ipc.uid} - ) - await asyncio.sleep(1) - async def _main_loop(self) -> None: await self.ipc.wait_until_ready() while True: @@ -257,5 +239,13 @@ async def brain_stop(pl: payload.EVENT, brain: Brain) -> None: @_E.add("shutdown") async def shutdown(pl: payload.EVENT, brain: Brain) -> None: - await brain.ipc.send_event(brain.ipc.server_uids, "server_stop") + await brain.ipc.send_event(brain.ipc.servers.keys(), "server_stop") brain.stop() + + +@_E.add("cluster_died") +async def cluster_died(pl: payload.EVENT, brain: Brain) -> None: + assert 
pl.data.data is not None + shard_id = pl.data.data["smallest_shard_id"] + if brain._waiting_for is not None and brain._waiting_for[1] == shard_id: + brain.waiting_for = None diff --git a/hikari_clusters/cluster.py b/hikari_clusters/cluster.py index b1d6a7e..a8a12ba 100644 --- a/hikari_clusters/cluster.py +++ b/hikari_clusters/cluster.py @@ -25,22 +25,19 @@ import asyncio import pathlib import signal -from dataclasses import asdict from typing import Any, Type from hikari import GatewayBot -from websockets.exceptions import ConnectionClosed -from . import log, payload +from . import payload +from .base_client import BaseClient from .events import EventGroup from .info_classes import ClusterInfo -from .ipc_client import IpcClient -from .task_manager import TaskManager __all__ = ("Cluster", "ClusterLauncher") -class Cluster: +class Cluster(BaseClient): """A subclass of :class:`~hikari.GatewayBot` designed for use with hikari-clusters. @@ -79,23 +76,23 @@ def __init__( self._shard_count = shard_count - self.logger = log.Logger(f"Cluster {self.cluster_id}") - self.ipc = IpcClient( + super().__init__( ipc_uri, ipc_token, - self.logger, reconnect=False, - cmd_kwargs={"cluster": self}, - event_kwargs={"cluster": self}, certificate_path=certificate_path, ) - self.ipc.events.include(_E) - self.__tasks = TaskManager(self.logger) - self.stop_future: asyncio.Future[None] | None = None + self.ipc.events.include(_E) self.bot.cluster = self # type: ignore + def get_info(self) -> ClusterInfo: + assert self.ipc.uid + return ClusterInfo( + self.ipc.uid, self.server_uid, self.shard_ids, self.ready + ) + @property def cluster_id(self) -> int: """The id of this cluster. @@ -121,69 +118,19 @@ def shard_count(self) -> int: return self._shard_count async def start(self, **kwargs: Any) -> None: - """Start the IPC and then the bot. 
- - Returns once all shards are ready.""" - - self.stop_future = asyncio.Future() - - await self.ipc.start() - - self.__tasks.create_task(self._broadcast_cluster_info_loop()) + # <<>> + await super().start() kwargs["shard_count"] = self.shard_count kwargs["shard_ids"] = self.shard_ids await self.bot.start(**kwargs) - async def join(self) -> None: - """Wait for the bot to close, and then return. - - Does not ask the bot to close. Use :meth:`~Cluster.stop` to tell - the bot to stop.""" - - assert self.stop_future and self.ipc.stop_future - - await asyncio.wait( - [self.stop_future, self.ipc.stop_future], - return_when=asyncio.FIRST_COMPLETED, - ) - async def close(self) -> None: + # <<>> await self.bot.close() - self.ipc.stop() - await self.ipc.close() - - self.__tasks.cancel_all() - await self.__tasks.wait_for_all() - - def stop(self) -> None: - """Tell the bot and IPC to close.""" - - assert self.stop_future - self.stop_future.set_result(None) - - async def _broadcast_cluster_info_loop(self) -> None: - while True: - await self.ipc.wait_until_ready() - assert self.ipc.uid - try: - await self.ipc.send_event( - self.ipc.client_uids, - "set_cluster_info", - asdict( - ClusterInfo( - self.ipc.uid, - self.server_uid, - self.shard_ids, - self.ready, - ) - ), - ) - except ConnectionClosed: - return - await asyncio.sleep(1) + await super().close() class ClusterLauncher: diff --git a/hikari_clusters/events.py b/hikari_clusters/events.py index 1613d13..504c3df 100644 --- a/hikari_clusters/events.py +++ b/hikari_clusters/events.py @@ -22,14 +22,17 @@ from __future__ import annotations +import logging import traceback from typing import Any, Awaitable, Callable -from . import log, payload +from . 
import payload __all__ = ("EventHandler", "EventGroup", "IPC_EVENT") IPC_EVENT = Callable[..., Awaitable[None]] +_LOG = logging.getLogger(__name__) +_LOG.setLevel(logging.INFO) class EventHandler: @@ -37,18 +40,13 @@ class EventHandler: Parameters ---------- - logger : :class:`~log.Logger` - The logger to use. event_kwargs : dict[str, Any], optional Extra kwargs to pass to event functions, default to None. """ - def __init__( - self, logger: log.Logger, event_kwargs: dict[str, Any] | None = None - ) -> None: + def __init__(self, event_kwargs: dict[str, Any] | None = None) -> None: self.events: dict[str, list[IPC_EVENT]] = {} self.event_kwargs = event_kwargs or {} - self.logger = logger async def handle_event(self, pl: payload.EVENT) -> None: """Handle an event. @@ -74,8 +72,8 @@ async def handle_event(self, pl: payload.EVENT) -> None: try: await func(pl, **kwargs) except Exception: - print("Ignoring Exception in handle_event:") - self.logger.error(traceback.format_exc()) + _LOG.error("Ignoring Exception in handle_event:") + _LOG.error(traceback.format_exc()) def include(self, group: EventGroup) -> None: """Add the events from an :class:`~EventGroup` to this diff --git a/hikari_clusters/info_classes.py b/hikari_clusters/info_classes.py index e71edb4..29b3673 100644 --- a/hikari_clusters/info_classes.py +++ b/hikari_clusters/info_classes.py @@ -22,13 +22,36 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import asdict, dataclass +from typing import Any -__all__ = ("ServerInfo", "ClusterInfo") +__all__ = ("ServerInfo", "ClusterInfo", "BaseInfo", "BrainInfo") + + +class BaseInfo: + uid: int + _next_info_class_id: int = 0 + _info_classes: dict[int, type[BaseInfo]] = {} + _info_class_id: int + + def __init_subclass__(cls) -> None: + cls._info_class_id = BaseInfo._next_info_class_id + BaseInfo._next_info_class_id += 1 + BaseInfo._info_classes[cls._info_class_id] = cls + + def asdict(self) -> dict[str, Any]: + dct = asdict(self) + 
dct["_info_class_id"] = self._info_class_id + return dct + + @staticmethod + def fromdict(data: dict[str, Any]) -> BaseInfo: + cls = BaseInfo._info_classes[data.pop("_info_class_id")] + return cls(**data) @dataclass -class ServerInfo: +class ServerInfo(BaseInfo): """A representation of a :class:`~server.Server`.""" uid: int @@ -38,7 +61,7 @@ class ServerInfo: @dataclass -class ClusterInfo: +class ClusterInfo(BaseInfo): """A representation of a :class:`~cluster.Cluster`.""" uid: int @@ -69,3 +92,9 @@ def get_cluster_id(shard_id: int, shards_per_cluster: int) -> int: Assumes that all the shard ids of a cluster are adjacent.""" return shard_id // shards_per_cluster + + +@dataclass +class BrainInfo(BaseInfo): + uid: int + """The ipc uid of the brain.""" diff --git a/hikari_clusters/ipc_client.py b/hikari_clusters/ipc_client.py index 5a20b1d..3fb9ed2 100644 --- a/hikari_clusters/ipc_client.py +++ b/hikari_clusters/ipc_client.py @@ -24,21 +24,26 @@ import asyncio import json +import logging import pathlib import ssl -from typing import Any, Iterable +from typing import Any, Iterable, TypeVar, cast from websockets.exceptions import ConnectionClosed, ConnectionClosedOK from websockets.legacy import client -from . import close_codes, exceptions, log, payload +from . import close_codes, exceptions, payload from .callbacks import CallbackHandler, NoResponse from .commands import CommandHandler from .events import EventGroup, EventHandler -from .info_classes import ClusterInfo, ServerInfo +from .info_classes import BaseInfo, BrainInfo, ClusterInfo, ServerInfo from .ipc_base import IpcBase from .task_manager import TaskManager +_BI_T = TypeVar("_BI_T", bound=BaseInfo) +_LOG = logging.getLogger(__name__) +_LOG.setLevel(logging.INFO) + __all__ = ("IpcClient",) @@ -51,17 +56,9 @@ class IpcClient(IpcBase): The uri of the ipc server. token : str The token required by the ipc server. - logger : :class:`~log.Logger` - The logger used by the clients parent. 
reconnect : bool Whether or not to try to reconnect after disconnection, by default True. - cmd_kwargs : dict[str, Any], optional - Command arguments to pass to :class:`~commands.CommandHandler`, by - default None. - event_kwargs : dict[str, Any], optional - Event arguments to pass to :class:`~events.EventHandler`, by - default None. certificate_path : pathlib.Path, optional Required for secure (wss) connections, by default None. """ @@ -70,23 +67,14 @@ def __init__( self, uri: str, token: str, - logger: log.Logger, reconnect: bool = True, - cmd_kwargs: dict[str, Any] | None = None, - event_kwargs: dict[str, Any] | None = None, certificate_path: pathlib.Path | None = None, ) -> None: - self.logger = logger - self.tasks = TaskManager(logger) - - cmd_kwargs = cmd_kwargs or {} - event_kwargs = event_kwargs or {} - cmd_kwargs["_ipc_client"] = self - event_kwargs["_ipc_client"] = self + self.tasks = TaskManager() self.callbacks = CallbackHandler(self) - self.commands = CommandHandler(self, cmd_kwargs) - self.events = EventHandler(logger, event_kwargs) + self.commands = CommandHandler(self, {"_ipc_client": self}) + self.events = EventHandler({"_ipc_client": self}) self.events.include(_E) @@ -94,6 +82,7 @@ def __init__( self.token = token self.reconnect = reconnect + self.certificate_path = certificate_path self.ssl_context: ssl.SSLContext | None if certificate_path is not None: self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) @@ -103,14 +92,7 @@ def __init__( self.client_uids: set[int] = set() """A set of all IPC uids representing every connected client.""" - self.brain_uid: int | None = None - """The IPC uid of the brain's client.""" - self.servers: dict[int, ServerInfo] = {} - """Maps the IPC uids for each server to its ServerInfo class.""" - self.clusters: dict[int, ClusterInfo] = {} - """Maps the IPC uids for each cluster to its ClusterInfo class.""" - self.clusters_by_cluster_id: dict[int, ClusterInfo] = {} - """Maps the cluster id for each cluster to its 
ClusterInfo class.""" + self.clients: dict[int, dict[int, BaseInfo]] = {} self._ws: client.WebSocketClientProtocol | None = None self.uid: int | None = None @@ -119,16 +101,32 @@ def __init__( self.ready_future: asyncio.Future[None] | None = None @property - def server_uids(self) -> set[int]: - """A set of all IPC uids representing every connected server.""" + def servers(self) -> dict[int, ServerInfo]: + """Shorthand for IpcClient.get_clients(ServerInfo)""" + + return self.get_clients(ServerInfo) + + @property + def clusters(self) -> dict[int, ClusterInfo]: + """Shorthand for IpcClient.get_clients(ClusterInfo)""" - return set(self.servers.keys()) + return self.get_clients(ClusterInfo) @property - def cluster_uids(self) -> set[int]: - """A set of all IPC uids representing every connected cluster.""" + def brain(self) -> BrainInfo | None: + """The IPC UID of the brain.""" - return set(self.clusters.keys()) + brains = self.get_clients(BrainInfo) + if not brains: + return None + elif len(brains) > 1: + _LOG.warning("More than one brain connected.") + return brains[max(brains.keys())] + + def get_clients(self, client: type[_BI_T]) -> dict[int, _BI_T]: + return cast( + "dict[int, _BI_T]", self.clients.get(client._info_class_id, {}) + ) def all_shards(self) -> set[int]: """Get all shard ids. 
@@ -138,10 +136,11 @@ def all_shards(self) -> set[int]: all_shards: set[int] = set() for c in self.clusters.values(): + servers = self.servers if not ( c.ready - and c.server_uid in self.servers - and c.uid in self.servers[c.server_uid].cluster_uids + and c.server_uid in servers + and c.uid in servers[c.server_uid].cluster_uids ): continue all_shards.update(c.shard_ids) @@ -290,14 +289,14 @@ async def send_command( return cb.resps async def _start(self) -> None: - self.logger.debug("Attempting connection to IPC...") + _LOG.debug("Attempting connection to IPC...") assert self.ready_future async for ws in client.connect(self.uri, ssl=self.ssl_context): reconnect = self.reconnect await self._handshake(ws) self._ws = ws self.ready_future.set_result(None) - self.logger.info(f"Connected successfully as {self.uid}.") + _LOG.info(f"Connected successfully as {self.uid}.") try: await self._recv_loop(ws) except ConnectionClosedOK: @@ -310,25 +309,23 @@ async def _start(self) -> None: finally: self._ws = None self.ready_future = asyncio.Future() - self.logger.info("Disconnected.") + _LOG.info("Disconnected.") self.client_uids.clear() - self.clusters_by_cluster_id.clear() - self.clusters.clear() - self.servers.clear() + self.clients.clear() if reconnect: - self.logger.info("Attempting reconnection...") + _LOG.info("Attempting reconnection...") else: return async def _handshake(self, ws: client.WebSocketClientProtocol) -> None: - self.logger.debug("Attempting handshake...") + _LOG.debug("Attempting handshake...") await ws.send(json.dumps({"token": self.token})) data: dict[str, Any] = json.loads(await ws.recv()) self.uid = data["uid"] self.client_uids = set(data["client_uids"]) - self.logger.debug(f"Handshake successful, uid {self.uid}") + _LOG.debug(f"Handshake successful, uid {self.uid}") def _update_clients(self, client_uids: set[int]) -> None: if self.client_uids.difference(client_uids): @@ -343,11 +340,10 @@ def _update_clients(self, client_uids: set[int]) -> None: if cid in 
self.client_uids: continue del self.clusters[cid] - del self.clusters_by_cluster_id[c.cluster_id] async def _recv_loop(self, ws: client.WebSocketClientProtocol) -> None: async for msg in ws: - self.logger.debug(f"Received message: {msg!s}") + _LOG.debug(f"Received message: {msg!s}") data: dict[str, Any] = json.loads(msg) if data.get("internal", False): self._update_clients(set(data["client_uids"])) @@ -378,32 +374,9 @@ async def _raw_send(self, msg: str) -> None: _E = EventGroup() -@_E.add("set_brain_uid") -async def set_brain_uid(pl: payload.EVENT, _ipc_client: IpcClient) -> None: - assert pl.data.data is not None - uid = pl.data.data["uid"] - _ipc_client.logger.debug(f"Setting brain uid to {uid}.") - _ipc_client.brain_uid = pl.data.data["uid"] - - -@_E.add("set_cluster_info") -async def update_cluster_info( - pl: payload.EVENT, _ipc_client: IpcClient -) -> None: - assert pl.data.data is not None - cinfo = ClusterInfo(**pl.data.data) - _ipc_client.logger.debug( - f"Updating info for Cluster {cinfo.cluster_id} ({cinfo.uid})" - ) - _ipc_client.clusters_by_cluster_id[cinfo.cluster_id] = cinfo - _ipc_client.clusters[cinfo.uid] = cinfo - - -@_E.add("set_server_info") -async def update_server_info( - pl: payload.EVENT, _ipc_client: IpcClient -) -> None: +@_E.add("set_info_class") +async def set_info_class(pl: payload.EVENT, _ipc_client: IpcClient) -> None: assert pl.data.data is not None - sinfo = ServerInfo(**pl.data.data) - _ipc_client.logger.debug(f"Updating info for Server {sinfo.uid}") - _ipc_client.servers[sinfo.uid] = sinfo + info = BaseInfo.fromdict(pl.data.data) + _LOG.debug("Setting info class {info}.") + _ipc_client.clients.setdefault(info._info_class_id, {})[info.uid] = info diff --git a/hikari_clusters/ipc_server.py b/hikari_clusters/ipc_server.py index 2de536e..1d80fdc 100644 --- a/hikari_clusters/ipc_server.py +++ b/hikari_clusters/ipc_server.py @@ -24,6 +24,7 @@ import asyncio import json +import logging import pathlib import ssl import traceback @@ 
-32,13 +33,14 @@ from websockets.exceptions import ConnectionClosedOK from websockets.legacy import server -from . import close_codes, log, payload +from . import close_codes, payload from .ipc_base import IpcBase from .task_manager import TaskManager __all__ = ("IpcServer",) -LOG = log.Logger("Ipc Server") +_LOG = logging.getLogger(__name__) +_LOG.setLevel(logging.INFO) class IpcServer(IpcBase): @@ -69,7 +71,7 @@ def __init__( token: str, certificate_path: pathlib.Path | None = None, ) -> None: - self.tasks = TaskManager(LOG) + self.tasks = TaskManager() self.host = host self.port = port @@ -111,19 +113,19 @@ async def _serve( self, ws: server.WebSocketServerProtocol, path: str ) -> None: uid: int | None = None - LOG.debug("Client connected.") + _LOG.debug("Client connected.") try: uid = await self._handshake(ws) if uid is None: return self.clients[uid] = ws - LOG.info(f"Client connected as {uid}") + _LOG.info(f"Client connected as {uid}") try: while True: msg = await ws.recv() - LOG.debug(f"Received message: {msg!s}") + _LOG.debug(f"Received message: {msg!s}") pl = payload.deserialize_payload(json.loads(msg)) await self._dispatch(pl.recipients, msg) except ConnectionClosedOK: @@ -132,31 +134,31 @@ async def _serve( del self.clients[uid] except Exception: - LOG.error(f"Exception in handler for client {uid}:") - LOG.error(traceback.format_exc()) + _LOG.error(f"Exception in handler for client {uid}:") + _LOG.error(traceback.format_exc()) - LOG.info(f"Client {uid} disconnected.") + _LOG.info(f"Client {uid} disconnected.") async def _start(self) -> None: - LOG.debug("Server starting up...") + _LOG.debug("Server starting up...") assert self.ready_future assert self.stop_future async with server.serve( self._serve, self.host, self.port, ssl=self.ssl_context ): - LOG.debug("Server started.") + _LOG.debug("Server started.") self.ready_future.set_result(None) await self.stop_future - LOG.debug("Stopping...") - LOG.debug("Server exited.") + _LOG.debug("Stopping...") + 
_LOG.debug("Server exited.") async def _handshake( self, ws: server.WebSocketServerProtocol ) -> int | None: - LOG.debug("Attempting handshake.") + _LOG.debug("Attempting handshake.") req: dict[str, Any] = json.loads(await ws.recv()) if req.get("token") != self.token: - LOG.debug("Received invalid token.") + _LOG.debug("Received invalid token.") await ws.close(close_codes.INVALID_TOKEN, "Invalid Token") return None @@ -164,7 +166,7 @@ async def _handshake( await ws.send( json.dumps({"uid": uid, "client_uids": list(self.clients.keys())}) ) - LOG.debug(f"Handshake successful, uid {uid}") + _LOG.debug(f"Handshake successful, uid {uid}") return uid async def _send_client_uids_loop(self) -> None: @@ -183,4 +185,4 @@ async def _dispatch(self, to: Iterable[int], msg: str | bytes) -> None: try: await client.send(msg) except Exception: - LOG.error(traceback.format_exc()) + _LOG.error(traceback.format_exc()) diff --git a/hikari_clusters/log.py b/hikari_clusters/log.py deleted file mode 100644 index d92e4c6..0000000 --- a/hikari_clusters/log.py +++ /dev/null @@ -1,71 +0,0 @@ -# MIT License -# -# Copyright (c) 2021 TrigonDev -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from __future__ import annotations - -import logging -from typing import Any - -__all__ = ("FATAL", "ERROR", "WARN", "INFO", "DEBUG", "Logger") - -FATAL = logging.FATAL -ERROR = logging.ERROR -WARN = logging.WARN -INFO = logging.INFO -DEBUG = logging.DEBUG - - -class _LoggingHandler(logging.Handler): - def __init__(self, name: str) -> None: - super().__init__() - self.fmt = logging.Formatter(logging.BASIC_FORMAT) - self.limit = 50 - self.name = name - - def format(self, record: logging.LogRecord) -> str: - record.message = record.getMessage() - return f"{self.name} {record.levelname}: {record.message}" - - def emit(self, record: logging.LogRecord) -> None: - print(self.format(record)) - - -class Logger(logging.Logger): - """A logging.Logger that forces the handlers - to be the same at all times.""" - - def __init__(self, name: str) -> None: - super().__init__(name, INFO) - self.hdlr = _LoggingHandler(name) - - @property # type: ignore - def handlers(self) -> list[_LoggingHandler]: # type: ignore - return [self.hdlr] - - @handlers.setter - def handlers(self, other: Any) -> None: - # take that stupid things that try to erase my logs - pass - - -# ignore info messages from websockets -logging.getLogger("websockets.client").setLevel(logging.FATAL) diff --git a/hikari_clusters/server.py b/hikari_clusters/server.py index 110ed9d..23f2396 100644 --- a/hikari_clusters/server.py +++ b/hikari_clusters/server.py @@ -23,33 +23,30 @@ from __future__ import annotations import asyncio -import contextlib +import logging import multiprocessing import pathlib import signal -from dataclasses import asdict from typing import TYPE_CHECKING, Any -from websockets.exceptions import ConnectionClosed - from 
hikari_clusters import payload from hikari_clusters.info_classes import ClusterInfo, ServerInfo -from . import log +from .base_client import BaseClient from .commands import CommandGroup from .events import EventGroup from .ipc_client import IpcClient -from .task_manager import TaskManager if TYPE_CHECKING: from .cluster import ClusterLauncher __all__ = ("Server",) -LOG = log.Logger("Server") +_LOG = logging.getLogger(__name__) +_LOG.setLevel(logging.INFO) -class Server: +class Server(BaseClient): """A group of clusters. Parameters @@ -74,20 +71,18 @@ def __init__( cluster_launcher: ClusterLauncher, certificate_path: pathlib.Path | str | None = None, ) -> None: - self.tasks = TaskManager(LOG) - if isinstance(certificate_path, str): certificate_path = pathlib.Path(certificate_path) - self.ipc = IpcClient( + super().__init__( IpcClient.get_uri(host, port, certificate_path is not None), token, - LOG, - cmd_kwargs={"server": self}, - event_kwargs={"server": self}, - certificate_path=certificate_path, + True, + certificate_path, ) - self.certificate_path = certificate_path + + self.ipc.commands.cmd_kwargs["server"] = self + self.ipc.events.event_kwargs["server"] = self self.ipc.commands.include(_C) self.ipc.events.include(_E) @@ -96,8 +91,6 @@ def __init__( self.cluster_launcher = cluster_launcher - self.stop_future: asyncio.Future[None] | None = None - @property def clusters(self) -> list[ClusterInfo]: """A list of :class:`~info_classes.ClusterInfo` @@ -109,6 +102,10 @@ def clusters(self) -> list[ClusterInfo]: if c.server_uid == self.ipc.uid ] + def get_info(self) -> ServerInfo: + assert self.ipc.uid + return ServerInfo(self.ipc.uid, [c.uid for c in self.clusters]) + def run(self) -> None: """Run the server, wait for the server to stop, and then shutdown.""" @@ -122,54 +119,30 @@ def sigstop(*args: Any, **kwargs: Any) -> None: loop.run_until_complete(self.close()) async def start(self) -> None: - """Start the server. - - Returns as soon as all tasks are completed. 
Returning does not mean - that the server is ready.""" - - self.stop_future = asyncio.Future() - self.tasks.create_task(self._broadcast_server_info_loop()) - await self.ipc.start() - - async def join(self) -> None: - """Wait for the server to stop.""" + # <<>> + await super().start() - assert self.stop_future and self.ipc.stop_future - await asyncio.wait( - [self.stop_future, self.ipc.stop_future], - return_when=asyncio.FIRST_COMPLETED, - ) - - async def close(self) -> None: - """Shutdown the server and all clusters that belong to this server.""" - - self.ipc.stop() - await self.ipc.close() - - self.tasks.cancel_all() - await self.tasks.wait_for_all() + self.tasks.create_task(self._loop_cleanup_processes()) - def stop(self) -> None: - """Tell the server to stop.""" - - assert self.stop_future - self.stop_future.set_result(None) - - async def _broadcast_server_info_loop(self) -> None: + async def _loop_cleanup_processes(self) -> None: while True: + await asyncio.sleep(5) await self.ipc.wait_until_ready() - assert self.ipc.uid - with contextlib.suppress(ConnectionClosed): - await self.ipc.send_event( - self.ipc.client_uids, - "set_server_info", - asdict( - ServerInfo( - self.ipc.uid, [c.uid for c in self.clusters] - ) - ), - ) - await asyncio.sleep(1) + if (brain := self.ipc.brain) is None: + continue + + dead_procs: list[int] = [] + for smallest_shard_id, proc in self.cluster_processes.items(): + if not proc.is_alive(): + await self.ipc.send_event( + [brain.uid], + "cluster_died", + {"smallest_shard_id": smallest_shard_id}, + ) + dead_procs.append(smallest_shard_id) + + for shard_id in dead_procs: + del self.cluster_processes[shard_id] _C = CommandGroup() @@ -178,7 +151,7 @@ async def _broadcast_server_info_loop(self) -> None: @_C.add("launch_cluster") async def start_cluster(pl: payload.COMMAND, server: Server) -> None: assert pl.data.data is not None - LOG.info(f"Launching Cluster with shard_ids {pl.data.data['shard_ids']}") + _LOG.info(f"Launching Cluster with 
shard_ids {pl.data.data['shard_ids']}") p = multiprocessing.Process( target=server.cluster_launcher.launch_cluster, kwargs={ @@ -187,7 +160,7 @@ async def start_cluster(pl: payload.COMMAND, server: Server) -> None: "shard_ids": pl.data.data["shard_ids"], "shard_count": pl.data.data["shard_count"], "server_uid": server.ipc.uid, - "certificate_path": server.certificate_path, + "certificate_path": server.ipc.certificate_path, }, ) p.start() diff --git a/hikari_clusters/task_manager.py b/hikari_clusters/task_manager.py index 435380e..510dd1f 100644 --- a/hikari_clusters/task_manager.py +++ b/hikari_clusters/task_manager.py @@ -23,14 +23,15 @@ from __future__ import annotations import asyncio +import logging import traceback from typing import Any, Coroutine, Generator, Iterable, Type, TypeVar -from . import log - __all__ = ("TaskManager",) _T = TypeVar("_T") +_LOG = logging.getLogger(__name__) +_LOG.setLevel(logging.INFO) class _TaskWrapper: @@ -47,8 +48,7 @@ def __init__( class TaskManager: """Makes asyncio.Task managements slightly easier.""" - def __init__(self, logger: log.Logger) -> None: - self.logger = logger + def __init__(self) -> None: self._tasks: dict[int, _TaskWrapper] = {} self._curr_tid = 0 @@ -101,8 +101,8 @@ def callback(task: asyncio.Task[Any]) -> None: pass except Exception as e: if not isinstance(e, tuple(ignored_exceptions)): - self.logger.error("Exception in task callback:") - self.logger.error(traceback.format_exc()) + _LOG.error("Exception in task callback:") + _LOG.error(traceback.format_exc()) finally: self._tasks.pop(tid, None)