Skip to content

Commit

Permalink
be more lenient with the destination for events/commands (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
circuitsacul authored Sep 8, 2022
1 parent fa7ad46 commit 037238a
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 20 deletions.
2 changes: 1 addition & 1 deletion hikari_clusters/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ async def _main_loop(self) -> None:
server_uid, shards = to_launch

await self.ipc.send_command(
[server_uid],
server_uid,
"launch_cluster",
{"shard_ids": shards, "shard_count": self.total_shards},
)
Expand Down
3 changes: 3 additions & 0 deletions hikari_clusters/info_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def fromdict(data: dict[str, Any]) -> BaseInfo:
cls = BaseInfo._info_classes[data.pop("_info_class_id")]
return cls(**data)

def __int__(self) -> int:
return self.uid


@dataclass
class ServerInfo(BaseInfo):
Expand Down
45 changes: 26 additions & 19 deletions hikari_clusters/ipc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import logging
import pathlib
import ssl
from typing import Any, Iterable, TypeVar, cast
from typing import Any, Iterable, TypeVar, Union, cast

from websockets.exceptions import ConnectionClosed, ConnectionClosedOK
from websockets.legacy import client
Expand All @@ -46,6 +46,12 @@

__all__ = ("IpcClient",)

_TO = Union[Iterable[Union[BaseInfo, int]], BaseInfo, int]


def _parse_to(to: _TO) -> Iterable[int]:
return map(int, to) if isinstance(to, Iterable) else [int(to)]


class IpcClient(IpcBase):
"""A connection to a :class:`~ipc_server.IpcServer`.
Expand Down Expand Up @@ -186,59 +192,57 @@ def _stop(*args: Any, **kwargs: Any) -> None:

self.tasks.create_task(self._start()).add_done_callback(_stop)

async def send_not_found_response(
self, to: Iterable[int], callback: int
) -> None:
async def send_not_found_response(self, to: _TO, callback: int) -> None:
"""Respond to a command saying that the command was not found.
Parameters
----------
to : Iterable[int]
to : Iterable[int | BaseInfo]
The clients to send the response to.
callback : int
The command callback (:attr:`~payload.Command.callback`)
"""

await self._send(to, payload.ResponseNotFound(callback))
await self._send(_parse_to(to), payload.ResponseNotFound(callback))

async def send_ok_response(
self, to: Iterable[int], callback: int, data: payload.DATA = None
self, to: _TO, callback: int, data: payload.DATA = None
) -> None:
"""Respond that the command *function* finished without any problems.
Does not necessarily mean that the command itself finished correctly.
Parameters
----------
to : Iterable[int]
to : Iterable[int | BaseInfo]
The clients to send the response to.
callback : int
The command callback (:attr:`~payload.Command.callback`)
data : payload.DATA, optional
The data to send with the response, by default None
"""

await self._send(to, payload.ResponseOk(callback, data))
await self._send(_parse_to(to), payload.ResponseOk(callback, data))

async def send_tb_response(
self, to: Iterable[int], callback: int, tb: str
) -> None:
async def send_tb_response(self, to: _TO, callback: int, tb: str) -> None:
"""Respond that the command function raised an exception.
Parameters
----------
to : Iterable[int]
to : Iterable[int | BaseInfo]
The clients to send the response to.
callback : int
The command callback (:attr:`~payload.Command.callback`)
tb : str
The exception traceback.
"""

await self._send(to, payload.ResponseTraceback(callback, tb))
await self._send(
_parse_to(to), payload.ResponseTraceback(callback, tb)
)

async def send_event(
self, to: Iterable[int], name: str, data: payload.DATA = None
self, to: _TO, name: str, data: payload.DATA = None
) -> None:
"""Dispatch an event.
Expand All @@ -254,11 +258,11 @@ async def send_event(
The data to send with the event, by default None
"""

await self._send(to, payload.Event(name, data))
await self._send(_parse_to(to), payload.Event(name, data))

async def send_command(
self,
to: Iterable[int],
to: _TO,
name: str,
data: payload.DATA = None,
timeout: float = 3.0,
Expand All @@ -267,7 +271,7 @@ async def send_command(
Parameters
----------
to : Iterable[int]
to : Iterable[int | BaseInfo]
The clients to send the command to.
name : str
The name of the command.
Expand All @@ -283,6 +287,7 @@ async def send_command(
if any.
"""

to = _parse_to(to)
with self.callbacks.callback(to) as cb:
await self._send(to, payload.Command(name, cb.key, data))
await cb.wait(timeout)
Expand Down Expand Up @@ -363,7 +368,9 @@ async def _send(
self, to: Iterable[int], pl_data: payload.PAYLOAD_DATA
) -> None:
assert self.uid is not None
pl = payload.Payload(pl_data.opcode, self.uid, list(to), pl_data)
pl = payload.Payload(
pl_data.opcode, self.uid, list(map(int, to)), pl_data
)
await self._raw_send(json.dumps(pl.serialize()))

async def _raw_send(self, msg: str) -> None:
Expand Down

0 comments on commit 037238a

Please sign in to comment.