Skip to content

Commit

Permalink
Support emit custom message (#561)
Browse files Browse the repository at this point in the history
* Support emit message

* Format message

* black format python

* Fix clientId

* Format fix
  • Loading branch information
oeway authored Jul 14, 2024
1 parent 788de8f commit 015378c
Show file tree
Hide file tree
Showing 14 changed files with 135 additions and 39 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
![License](https://img.shields.io/github/license/imjoy-team/imjoy-rpc.svg)
![Build ImJoy RPC](https://github.com/imjoy-team/imjoy-rpc/workflows/Build%20ImJoy%20RPC/badge.svg)
![PyPI](https://img.shields.io/pypi/v/imjoy-rpc.svg?style=popout)

# ImJoy RPC
Expand Down
2 changes: 1 addition & 1 deletion javascript/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion javascript/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "imjoy-rpc",
"version": "0.5.56",
"version": "0.5.57",
"description": "Remote procedure calls for ImJoy.",
"module": "index.js",
"types": "index.d.ts",
Expand Down
20 changes: 19 additions & 1 deletion javascript/src/hypha/rpc.js
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,24 @@ export class RPC extends MessageEmitter {
await message_cache.process(message_id, !!session_id);
}

emit(main_message, extra_data) {
assert(
typeof main_message === "object" && main_message.type,
"Invalid message, must be an object with a type field."
);
let message_package = msgpack_packb(main_message);
if (extra_data) {
const extra = msgpack_packb(extra_data);
message_package = new Uint8Array([...message_package, ...extra]);
}
const total_size = message_package.length;
if (total_size <= CHUNK_SIZE + 1024) {
return this._emit_message(message_package);
} else {
throw new Error("Message is too large to send in one go.");
}
}

_generate_remote_method(
encoded_method,
remote_parent,
Expand Down Expand Up @@ -858,7 +876,7 @@ export class RPC extends MessageEmitter {
const extra = msgpack_packb(extra_data);
message_package = new Uint8Array([...message_package, ...extra]);
}
let total_size = message_package.length;
const total_size = message_package.length;
if (total_size <= CHUNK_SIZE + 1024) {
self._emit_message(message_package).then(function() {
if (timer) {
Expand Down
18 changes: 18 additions & 0 deletions javascript/src/hypha/websocket-client.js
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,29 @@ export async function connectToServer(config) {
await connection.disconnect();
}

wm.config["client_id"] = clientId;
wm.export = _export;
wm.getPlugin = getPlugin;
wm.listPlugins = wm.listServices;
wm.disconnect = disconnect;
wm.registerCodec = rpc.register_codec.bind(rpc);

wm.emit = async function(message) {
assert(
message && typeof message === "object",
"message must be a dictionary"
);
assert("to" in message, "message must have a 'to' field");
assert("type" in message, "message must have a 'type' field");
assert(type !== "method", "message type cannot be 'method'");
return await rpc.emit(message);
};

wm.on = function(type, handler) {
assert(type !== "method", "message type cannot be 'method'");
rpc.on(type, handler);
};

if (config.webrtc) {
await registerRTCService(wm, clientId + "-rtc", config.webrtc_config);
}
Expand Down
2 changes: 1 addition & 1 deletion python/imjoy_rpc/VERSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "0.5.56"
"version": "0.5.57"
}
7 changes: 5 additions & 2 deletions python/imjoy_rpc/hypha/pyodide_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class LocalWebSocket {
}
"""


class PyodideWebsocketRPCConnection:
"""Represent a pyodide websocket RPC connection."""

Expand All @@ -117,7 +118,7 @@ def __init__(
self._websocket = None
self._handle_message = None
assert server_url and client_id

server_url = server_url + f"?client_id={client_id}"
if workspace is not None:
server_url += f"&workspace={workspace}"
Expand All @@ -139,7 +140,9 @@ async def open(self):
if self._server_url.startswith("wss://local-hypha-server:"):
js.console.log("Connecting to local websocket " + self._server_url)
LocalWebSocket = js.eval("(" + local_websocket_patch + ")")
self._websocket = LocalWebSocket.new(self._server_url, self._client_id, self._workspace)
self._websocket = LocalWebSocket.new(
self._server_url, self._client_id, self._workspace
)
else:
self._websocket = WebSocket.new(self._server_url)
self._websocket.binaryType = "arraybuffer"
Expand Down
16 changes: 15 additions & 1 deletion python/imjoy_rpc/hypha/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,18 @@ async def _send_chunks(self, package, target_id, session_id):
logger.info("All chunks sent (%d)", chunk_num)
await message_cache.process(message_id, bool(session_id))

def emit(self, main_message, extra_data=None):
"""Emit a message."""
assert isinstance(main_message, dict) and "type" in main_message
message_package = msgpack.packb(main_message)
if extra_data:
message_package = message_package + msgpack.packb(extra_data)
total_size = len(message_package)
if total_size <= CHUNK_SIZE + 1024:
return self.loop.create_task(self._emit_message(message_package))
else:
raise Exception("Message is too large to send in one go.")

def _generate_remote_method(
self,
encoded_method,
Expand Down Expand Up @@ -789,7 +801,9 @@ def pfunc(resolve, reject):
# However, if the args contains _rintf === true, we will not clear the session
clear_after_called = True
for arg in args:
if (isinstance(arg, dict) and arg.get("_rintf")) or (hasattr(arg, "_rintf") and arg._rintf == True):
if (isinstance(arg, dict) and arg.get("_rintf")) or (
hasattr(arg, "_rintf") and arg._rintf == True
):
clear_after_called = False
break
extra_data["promise"] = self._encode_promise(
Expand Down
4 changes: 3 additions & 1 deletion python/imjoy_rpc/hypha/sse_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import shortuuid

from .rpc import RPC
from .websocket_client import WebsocketRPCConnection

try:
import js # noqa: F401
Expand Down Expand Up @@ -95,7 +96,8 @@ async def open(self):
self._retry_count += 1
self._opening.set_exception(
Exception(
f"Failed to connect to {server_url.split('?')[0]} (retry {self._retry_count}/{MAX_RETRY}): {exp}"
f"Failed to connect to {server_url.split('?')[0]} "
f"(retry {self._retry_count}/{MAX_RETRY}): {exp}"
)
)
finally:
Expand Down
2 changes: 2 additions & 0 deletions python/imjoy_rpc/hypha/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class SyncHyphaServer:
"""A class to interact with the Hypha server synchronously."""

def __init__(self, sync_max_workers=2):
"""Initialize the SyncHyphaServer."""
self.loop = None
self.thread = None
self.server = None
Expand Down Expand Up @@ -212,6 +213,7 @@ def get_rtc_service(server, service_id, config=None):
print("Public services: #", len(services))

def hello(name):
"""Say hello."""
print("Hello " + name)
print("Current thread id: ", threading.get_ident(), threading.current_thread())
time.sleep(2)
Expand Down
32 changes: 25 additions & 7 deletions python/imjoy_rpc/hypha/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import secrets
import string
import traceback
import collections.abc
from functools import partial
from inspect import Parameter, Signature, signature
from inspect import Parameter, Signature
from types import BuiltinFunctionType, FunctionType
from typing import Any

Expand All @@ -24,11 +25,22 @@ def generate_password(length=50):
_hash_id = generate_password()


def recursive_hash(obj):
"""Generate a hash for nested dictionaries and lists."""
if isinstance(obj, collections.abc.Hashable) and not isinstance(obj, dotdict):
return hash(obj)
elif isinstance(obj, dict) or isinstance(obj, dotdict):
return hash(tuple(sorted((k, recursive_hash(v)) for k, v in obj.items())))
elif isinstance(obj, (list, tuple)):
return hash(tuple(recursive_hash(i) for i in obj))
else:
raise TypeError(f"Unsupported type: {type(obj)}")


class dotdict(dict): # pylint: disable=invalid-name
"""Access dictionary attributes with dot.notation."""

__getattr__ = dict.get
__setattr__ = dict.__setitem__
__getattr__ = dict.__getitem__
__delattr__ = dict.__delitem__

def __setattr__(self, name, value):
Expand All @@ -41,16 +53,23 @@ def __setattr__(self, name, value):

def __hash__(self):
"""Return the hash."""
if self.__rid__ and type(self.__rid__) is str:
if hasattr(self, "__rid__") and isinstance(self.__rid__, str):
return hash(self.__rid__ + _hash_id)

# FIXME: This does not address the issue of inner list
return hash(tuple(sorted(self.items())))
return recursive_hash(self)

def __deepcopy__(self, memo=None):
"""Make a deep copy."""
return dotdict(copy.deepcopy(dict(self), memo=memo))

def __getattribute__(self, name):
if name in self:
return self[name]
try:
return super().__getattribute__(name)
except AttributeError:
return None


def format_traceback(traceback_string):
"""Format traceback."""
Expand Down Expand Up @@ -313,7 +332,6 @@ def make_signature(func, name=None, sig=None, doc=None):
sig can be a Signature object or a string without 'def' such as
"foo(a, b=0)"
"""

if isinstance(sig, str):
# Parse signature string
func_name, sig = _str_to_signature(sig)
Expand Down
47 changes: 32 additions & 15 deletions python/imjoy_rpc/hypha/websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import inspect
import logging
import sys
import types

import msgpack
import shortuuid

from .rpc import RPC
from .utils import dotdict

try:
import js # noqa: F401
Expand Down Expand Up @@ -253,12 +253,28 @@ async def disconnect():
await rpc.disconnect()
await connection.disconnect()

wm.config = dotdict(wm.config)
wm.config["client_id"] = client_id
wm.export = export
wm.get_plugin = get_plugin
wm.list_plugins = wm.list_services
wm.disconnect = disconnect
wm.register_codec = rpc.register_codec

def emit_msg(message):
assert isinstance(message, dict), "message must be a dictionary"
assert "to" in message, "message must have a 'to' field"
assert "type" in message, "message must have a 'type' field"
assert message["type"] != "method", "message type cannot be 'method'"
return rpc.emit(message)

def on_msg(type, handler):
assert type != "method", "message type cannot be 'method'"
rpc.on(type, handler)

wm.emit = emit_msg
wm.on = on_msg

if config.get("webrtc", False):
from .webrtc_client import AIORTC_AVAILABLE, register_rtc_service

Expand All @@ -283,7 +299,8 @@ async def get_service(query, webrtc=None, webrtc_config=None):
if ":" in svc.id and "/" in svc.id and AIORTC_AVAILABLE:
client = svc.id.split(":")[0]
try:
# Assuming that the client registered a webrtc service with the client_id + "-rtc"
# Assuming that the client registered
# a webrtc service with the client_id + "-rtc"
peer = await get_rtc_service(
wm,
client + ":" + client.split("/")[1] + "-rtc",
Expand All @@ -310,9 +327,10 @@ async def get_service(query, webrtc=None, webrtc_config=None):
wm["getService"] = get_service
return wm


def setup_local_client(enable_execution=False, on_ready=None):
"""Set up a local client."""
fut = asyncio.Future()

async def message_handler(event):
data = event.data.to_py()
type = data.get("type")
Expand All @@ -333,14 +351,16 @@ async def message_handler(event):
print("server_url should start with https://local-hypha-server:")
return

server = await connect_to_server({
"server_url": server_url,
"workspace": workspace,
"client_id": client_id,
"token": token,
"method_timeout": method_timeout,
"name": name
})
server = await connect_to_server(
{
"server_url": server_url,
"workspace": workspace,
"client_id": client_id,
"token": token,
"method_timeout": method_timeout,
"name": name,
}
)

js.globalThis.api = server
try:
Expand All @@ -349,10 +369,7 @@ async def message_handler(event):
if on_ready:
await on_ready(server, config)
except Exception as e:
await server.update_client_info({
"id": client_id,
"error": str(e)
})
await server.update_client_info({"id": client_id, "error": str(e)})
fut.set_exception(e)
return
fut.set_result(server)
Expand Down
Loading

0 comments on commit 015378c

Please sign in to comment.