Skip to content

Commit

Permalink
V2 (#49)
Browse files Browse the repository at this point in the history
* refactorings

* fix logging

* remove uneeded kwargs from ipc_client

* handle disconnects gracefully

* update docstrings

* missed some docstrings

* don't close loop when connection fails
  • Loading branch information
circuitsacul authored Sep 8, 2022
1 parent 3c43fc8 commit fa7ad46
Show file tree
Hide file tree
Showing 12 changed files with 334 additions and 372 deletions.
2 changes: 1 addition & 1 deletion examples/broadcast_exec/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def on_message(self, event: GuildMessageCreateEvent) -> None:
return
if event.content.startswith("!exec"):
await self.cluster.ipc.send_command(
self.cluster.ipc.cluster_uids,
self.cluster.ipc.clusters,
"exec_code",
{"code": event.content[6:]},
)
Expand Down
10 changes: 9 additions & 1 deletion hikari_clusters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,19 @@

from importlib.metadata import version

from hikari.internal.ux import init_logging

from . import close_codes, commands, events, exceptions, payload
from .base_client import BaseClient
from .brain import Brain
from .cluster import Cluster, ClusterLauncher
from .info_classes import ClusterInfo, ServerInfo
from .info_classes import BaseInfo, BrainInfo, ClusterInfo, ServerInfo
from .ipc_client import IpcClient
from .ipc_server import IpcServer
from .server import Server

init_logging("INFO", True, False)

__version__ = version(__name__)

__all__ = (
Expand All @@ -45,8 +50,11 @@
"Cluster",
"ClusterLauncher",
"Server",
"BaseClient",
"ClusterInfo",
"ServerInfo",
"BrainInfo",
"BaseInfo",
"payload",
"events",
"commands",
Expand Down
113 changes: 113 additions & 0 deletions hikari_clusters/base_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import asyncio
import logging
import pathlib

from websockets.exceptions import ConnectionClosed

from .info_classes import BaseInfo
from .ipc_client import IpcClient
from .task_manager import TaskManager

_LOG = logging.getLogger(__name__)


class BaseClient:
"""The base client, which contains an IpcClient.
Parameters
----------
ipc_uri : str
The URI of the brain.
token : str
The token for the IPC server.
reconnect : bool
Whether to automatically reconnect if the connection
is lost. Defaults to True.
certificate_path : pathlib.Path | str | None
The path to your certificate, which allos for secure
connection over the IPC. Defaults to None.
"""

def __init__(
self,
ipc_uri: str,
token: str,
reconnect: bool = True,
certificate_path: pathlib.Path | str | None = None,
):
if isinstance(certificate_path, str):
certificate_path = pathlib.Path(certificate_path)

self.tasks = TaskManager()
self.ipc = IpcClient(
uri=ipc_uri,
token=token,
reconnect=reconnect,
certificate_path=certificate_path,
)

self.stop_future: asyncio.Future[None] | None = None

def get_info(self) -> BaseInfo:
"""Get the info class for this client.
Returns:
BaseInfo: The info class.
"""

raise NotImplementedError

async def start(self) -> None:
"""Start the client.
Connects to the IPC server and begins sending out this clients
info.
"""

if self.stop_future is None:
self.stop_future = asyncio.Future()

await self.ipc.start()

self.tasks.create_task(self._broadcast_info_loop())

async def join(self) -> None:
"""Wait until the client begins exiting."""

assert self.stop_future and self.ipc.stop_future

await asyncio.wait(
[self.stop_future, self.ipc.stop_future],
return_when=asyncio.FIRST_COMPLETED,
)

async def close(self) -> None:
"""Shut down the client."""

self.ipc.stop()
await self.ipc.close()

self.tasks.cancel_all()
await self.tasks.wait_for_all()

def stop(self) -> None:
"""Tell the client to stop."""

assert self.stop_future
self.stop_future.set_result(None)

async def _broadcast_info_loop(self) -> None:
while True:
await self.ipc.wait_until_ready()
assert self.ipc.uid
try:
await self.ipc.send_event(
self.ipc.client_uids,
"set_info_class",
self.get_info().asdict(),
)
except ConnectionClosed:
_LOG.error("Failed to send client info.", exc_info=True)
await asyncio.sleep(1)
96 changes: 43 additions & 53 deletions hikari_clusters/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@
import signal
from typing import Any

from . import log, payload
from hikari_clusters.base_client import BaseClient
from hikari_clusters.info_classes import BrainInfo

from . import payload
from .events import EventGroup
from .ipc_client import IpcClient
from .ipc_server import IpcServer
from .task_manager import TaskManager

__all__ = ("Brain",)

LOG = log.Logger("Brain")


class Brain:
class Brain(BaseClient):
"""The brain of the bot.
Allows for comunication between clusters and servers,
Expand Down Expand Up @@ -73,29 +73,30 @@ def __init__(
shards_per_cluster: int,
certificate_path: pathlib.Path | str | None = None,
) -> None:
self.tasks = TaskManager(LOG)
certificate_path = (
pathlib.Path(certificate_path)
if isinstance(certificate_path, str)
else certificate_path
)

super().__init__(
IpcClient.get_uri(host, port, certificate_path is not None),
token,
True,
certificate_path,
)

self.total_servers = total_servers
self.cluster_per_server = clusters_per_server
self.shards_per_cluster = shards_per_cluster

if isinstance(certificate_path, str):
certificate_path = pathlib.Path(certificate_path)

self.server = IpcServer(
host, port, token, certificate_path=certificate_path
)
self.ipc = IpcClient(
IpcClient.get_uri(host, port, certificate_path is not None),
token,
LOG,
certificate_path=certificate_path,
cmd_kwargs={"brain": self},
event_kwargs={"brain": self},
)
self.ipc.events.include(_E)

self.stop_future: asyncio.Future[None] | None = None
self.ipc.commands.cmd_kwargs["brain"] = self
self.ipc.events.event_kwargs["brain"] = self
self.ipc.events.include(_E)

self._waiting_for: tuple[int, int] | None = None

Expand Down Expand Up @@ -128,7 +129,7 @@ def waiting_for(self) -> tuple[int, int] | None:
if self._waiting_for is not None:
server_uid, smallest_shard = self._waiting_for
if (
server_uid not in self.ipc.server_uids
server_uid not in self.ipc.servers
or smallest_shard in self.ipc.all_shards()
):
# `server_uid not in self.ipc.server_uids`
Expand All @@ -148,6 +149,11 @@ def waiting_for(self) -> tuple[int, int] | None:
def waiting_for(self, value: tuple[int, int] | None) -> None:
self._waiting_for = value

def get_info(self) -> BrainInfo:
# <<<docstring from superclass>>>
assert self.ipc.uid
return BrainInfo(uid=self.ipc.uid)

def run(self) -> None:
"""Run the brain, wait for the brain to stop, then cleanup."""

Expand All @@ -161,42 +167,26 @@ def sigstop(*args: Any, **kwargs: Any) -> None:
loop.run_until_complete(self.close())

async def start(self) -> None:
"""Start the brain.
Returns as soon as all tasks have started.
"""
# <<<docstring from superclass>>>
self.stop_future = asyncio.Future()
self.tasks.create_task(self._send_brain_uid_loop())
self.tasks.create_task(self._main_loop())
await self.server.start()
await self.ipc.start()

async def join(self) -> None:
"""Wait for the brain to stop."""

assert self.stop_future
await self.stop_future
await self.server.start()
await super().start()
self.tasks.create_task(self._main_loop())

async def close(self) -> None:
"""Shut the brain down."""
# <<<docstring from superclass>>>
self.ipc.stop()
await self.ipc.close()

self.server.stop()
await self.server.close()

self.ipc.stop()
await self.ipc.close()

self.tasks.cancel_all()
await self.tasks.wait_for_all()

def stop(self) -> None:
"""Tell the brain to stop."""

assert self.stop_future
self.stop_future.set_result(None)

def _get_next_cluster_to_launch(self) -> tuple[int, list[int]] | None:
if len(self.ipc.server_uids) == 0:
if len(self.ipc.servers) == 0:
return None

if not all(c.ready for c in self.ipc.clusters.values()):
Expand All @@ -219,14 +209,6 @@ def _get_next_cluster_to_launch(self) -> tuple[int, list[int]] | None:

return s.uid, list(shards_to_launch)[: self.shards_per_cluster]

async def _send_brain_uid_loop(self) -> None:
while True:
await self.ipc.wait_until_ready()
await self.ipc.send_event(
self.ipc.client_uids, "set_brain_uid", {"uid": self.ipc.uid}
)
await asyncio.sleep(1)

async def _main_loop(self) -> None:
await self.ipc.wait_until_ready()
while True:
Expand Down Expand Up @@ -257,5 +239,13 @@ async def brain_stop(pl: payload.EVENT, brain: Brain) -> None:

@_E.add("shutdown")
async def shutdown(pl: payload.EVENT, brain: Brain) -> None:
await brain.ipc.send_event(brain.ipc.server_uids, "server_stop")
await brain.ipc.send_event(brain.ipc.servers.keys(), "server_stop")
brain.stop()


@_E.add("cluster_died")
async def cluster_died(pl: payload.EVENT, brain: Brain) -> None:
assert pl.data.data is not None
shard_id = pl.data.data["smallest_shard_id"]
if brain._waiting_for is not None and brain._waiting_for[1] == shard_id:
brain.waiting_for = None
Loading

0 comments on commit fa7ad46

Please sign in to comment.