Skip to content

Commit

Permalink
JS -> Python WS now functional
Browse files Browse the repository at this point in the history
  • Loading branch information
AAGaming00 authored and marios8543 committed Nov 13, 2023
1 parent 05b41b3 commit 1921e7e
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 74 deletions.
16 changes: 11 additions & 5 deletions backend/decky_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def on_modified(self, event: DirModifiedEvent | FileModifiedEvent):
self.maybe_reload(src_path)

class Loader:
def __init__(self, server_instance: PluginManager, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None:
def __init__(self, server_instance: PluginManager, ws: WSRouter, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None:
self.loop = loop
self.logger = getLogger("Loader")
self.plugin_path = plugin_path
Expand All @@ -88,10 +88,7 @@ def __init__(self, server_instance: PluginManager, plugin_path: str, loop: Abstr
self.observer.start()
self.loop.create_task(self.enable_reload_wait())

self.ws = WSRouter()

server_instance.web_app.add_routes([
web.get("/ws", self.ws.handle),
server_instance.add_routes([
web.get("/frontend/{path:.*}", self.handle_frontend_assets),
web.get("/locales/{path:.*}", self.handle_frontend_locales),
web.get("/plugins", self.get_plugins),
Expand All @@ -101,6 +98,15 @@ def __init__(self, server_instance: PluginManager, plugin_path: str, loop: Abstr
web.post("/plugins/{plugin_name}/reload", self.handle_backend_reload_request)
])

ws.add_route("test", self.test_method)

async def test_method():
await sleep(2)

return {
"test data": True
}

async def enable_reload_wait(self):
if self.live_reload:
await sleep(10)
Expand Down
4 changes: 3 additions & 1 deletion backend/decky_loader/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Change PyInstaller files permissions
import sys
from typing import Dict
from wsrouter import WSRouter
from .localplatform.localplatform import (chmod, chown, service_stop, service_start,
ON_WINDOWS, get_log_level, get_live_reload,
get_server_port, get_server_host, get_chown_plugin_path,
Expand Down Expand Up @@ -63,7 +64,8 @@ def __init__(self, loop: AbstractEventLoop) -> None:
allow_credentials=True
)
})
self.plugin_loader = Loader(self, plugin_path, self.loop, get_live_reload())
self.ws = WSRouter(self.loop, self.web_app)
self.plugin_loader = Loader(self, self.ws, plugin_path, self.loop, get_live_reload())
self.settings = SettingsManager("loader", path.join(get_privileged_path(), "settings"))
self.plugin_browser = PluginBrowser(plugin_path, self.plugin_loader.plugins, self.plugin_loader, self.settings)
self.utilities = Utilities(self)
Expand Down
6 changes: 5 additions & 1 deletion backend/decky_loader/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def __init__(self, context: PluginManager) -> None:
web.post("/methods/{method_name}", self._handle_server_method_call)
])

async def _handle_server_method_call(self, request: web.Request):
context.ws.add_route("utilities/ping", self.ping)
context.ws.add_route("utilities/settings/get", self.get_setting)
context.ws.add_route("utilities/settings/set", self.set_setting)

async def _handle_server_method_call(self, request):
method_name = request.match_info["method_name"]
try:
args = await request.json()
Expand Down
85 changes: 57 additions & 28 deletions backend/decky_loader/wsrouter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from logging import getLogger

from asyncio import Future
from asyncio import AbstractEventLoop, Future

from aiohttp import web, WSMsgType
from aiohttp import WSMsgType
from aiohttp.web import Application, WebSocketResponse, Request, Response, get

from enum import Enum

from typing import Dict, Any, Callable
from typing import Dict

from traceback import format_exc

from helpers import get_csrf_token

class MessageType(Enum):
# Call-reply
CALL = 0
Expand All @@ -23,20 +26,34 @@ class MessageType(Enum):
# see wsrouter.ts for typings

class WSRouter:
def __init__(self) -> None:
def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None:
self.loop = loop
self.ws = None
self.req_id = 0
self.routes = {}
self.running_calls: Dict[int, Future] = {}
# self.subscriptions: Dict[str, Callable[[Any]]] = {}
self.logger = getLogger("WSRouter")

async def add_route(self, name, route):
server_instance.add_routes([
get("/ws", self.handle)
])

async def write(self, dta: Dict[str, any]):
await self.ws.send_json(dta)

def add_route(self, name: str, route):
self.routes[name] = route

async def handle(self, request):
def remove_route(self, name: str):
del self.routes[name]

async def handle(self, request: Request):
# Auth is a query param as JS WebSocket doesn't support headers
if request.rel_url.query["auth"] != get_csrf_token():
return Response(text='Forbidden', status='403')
self.logger.debug('Websocket connection starting')
ws = web.WebSocketResponse()
ws = WebSocketResponse()
await ws.prepare(request)
self.logger.debug('Websocket connection ready')

Expand All @@ -58,29 +75,29 @@ async def handle(self, request):
# TODO DO NOT RELY ON THIS!
break
else:
match data.type:
case MessageType.CALL:
data = msg.json()
match data["type"]:
case MessageType.CALL.value:
# do stuff with the message
data = msg.json()
if self.routes[data.route]:
if self.routes[data["route"]]:
try:
res = await self.routes[data.route](*data.args)
await self.write({"type": MessageType.REPLY, "id": data.id, "result": res})
self.logger.debug(f"Started PY call {data.route} ID {data.id}")
res = await self.routes[data["route"]](*data["args"])
await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res})
self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}')
except:
await self.write({"type": MessageType.ERROR, "id": data.id, "error": format_exc()})
await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()})
else:
await self.write({"type": MessageType.ERROR, "id": data.id, "error": "Route does not exist."})
case MessageType.REPLY:
if self.running_calls[data.id]:
self.running_calls[data.id].set_result(data.result)
del self.running_calls[data.id]
self.logger.debug(f"Resolved JS call {data.id} with value {str(data.result)}")
case MessageType.ERROR:
if self.running_calls[data.id]:
self.running_calls[data.id].set_exception(data.error)
del self.running_calls[data.id]
self.logger.debug(f"Errored JS call {data.id} with error {data.error}")
await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route does not exist."})
case MessageType.REPLY.value:
if self.running_calls[data["id"]]:
self.running_calls[data["id"]].set_result(data["result"])
del self.running_calls[data["id"]]
self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}')
case MessageType.ERROR.value:
if self.running_calls[data["id"]]:
self.running_calls[data["id"]].set_exception(data["error"])
del self.running_calls[data["id"]]
self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}')

case _:
self.logger.error("Unknown message type", data)
Expand All @@ -94,5 +111,17 @@ async def handle(self, request):
self.logger.debug('Websocket connection closed')
return ws

async def write(self, dta: Dict[str, any]):
await self.ws.send_json(dta)
async def call(self, route: str, *args):
future = Future()

self.req_id += 1

id = self.req_id

self.running_calls[id] = future

self.logger.debug(f'Calling JS method {route} with args {str(args)}')

self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id })

return await future
9 changes: 7 additions & 2 deletions frontend/src/plugin-loader.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import Toaster from './toaster';
import { VerInfo, callUpdaterMethod } from './updater';
import { getSetting, setSetting } from './utils/settings';
import TranslationHelper, { TranslationClass } from './utils/TranslationHelper';
import { WSRouter } from './wsrouter';

const StorePage = lazy(() => import('./components/store/Store'));
const SettingsPage = lazy(() => import('./components/settings'));
Expand All @@ -48,6 +49,8 @@ class PluginLoader extends Logger {
public toaster: Toaster = new Toaster();
private deckyState: DeckyState = new DeckyState();

public ws: WSRouter = new WSRouter();

public hiddenPluginsService = new HiddenPluginsService(this.deckyState);
public notificationService = new NotificationService(this.deckyState);

Expand Down Expand Up @@ -102,9 +105,11 @@ class PluginLoader extends Logger {

initFilepickerPatches();

this.getUserInfo();
this.ws.connect().then(() => {
this.getUserInfo();

this.updateVersion();
this.updateVersion();
});
}

public async getUserInfo() {
Expand Down
22 changes: 3 additions & 19 deletions frontend/src/utils/settings.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,8 @@
interface GetSettingArgs<T> {
key: string;
default: T;
}

interface SetSettingArgs<T> {
key: string;
value: T;
}

export async function getSetting<T>(key: string, def: T): Promise<T> {
const res = (await window.DeckyPluginLoader.callServerMethod('get_setting', {
key,
default: def,
} as GetSettingArgs<T>)) as { result: T };
return res.result;
const res = await window.DeckyPluginLoader.ws.call<[string, T], T>('utilities/settings/get', key, def);
return res;
}

export async function setSetting<T>(key: string, value: T): Promise<void> {
await window.DeckyPluginLoader.callServerMethod('set_setting', {
key,
value,
} as SetSettingArgs<T>);
await window.DeckyPluginLoader.ws.call<[string, T], void>('utilities/settings/set', key, value);
}
45 changes: 27 additions & 18 deletions frontend/src/wsrouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,41 +41,47 @@ interface PromiseResolver<T> {
promise: Promise<T>;
}

class WSRouter extends Logger {
export class WSRouter extends Logger {
routes: Map<string, (...args: any) => any> = new Map();
runningCalls: Map<number, PromiseResolver<any>> = new Map();
ws?: WebSocket;
connectPromise?: Promise<void>;
// Used to map results and errors to calls
reqId: number = 0;
constructor() {
super('WSRouter');
}

connect() {
this.ws = new WebSocket('ws://127.0.0.1:1337/ws');

this.ws.addEventListener('message', this.onMessage.bind(this));
this.ws.addEventListener('close', this.onError.bind(this));
this.ws.addEventListener('message', this.onError.bind(this));
return (this.connectPromise = new Promise<void>((resolve) => {
// Auth is a query param as JS WebSocket doesn't support headers
this.ws = new WebSocket(`ws://127.0.0.1:1337/ws?auth=${window.deckyAuthToken}`);

this.ws.addEventListener('open', () => {
this.debug('WS Connected');
resolve();
delete this.connectPromise;
});
this.ws.addEventListener('message', this.onMessage.bind(this));
this.ws.addEventListener('close', this.onError.bind(this));
// this.ws.addEventListener('error', this.onError.bind(this));
}));
}

createPromiseResolver<T>(): PromiseResolver<T> {
let resolver: PromiseResolver<T>;
let resolver: Partial<PromiseResolver<T>> = {};
const promise = new Promise<T>((resolve, reject) => {
resolver = {
promise,
resolve,
reject,
};
this.debug('Created new PromiseResolver');
resolver.resolve = resolve;
resolver.reject = reject;
});
this.debug('Returning new PromiseResolver');
resolver.promise = promise;
// The promise will always run first
// @ts-expect-error 2454
return resolver;
}

write(data: Message) {
async write(data: Message) {
if (this.connectPromise) await this.connectPromise;
this.ws?.send(JSON.stringify(data));
}

Expand Down Expand Up @@ -129,22 +135,25 @@ class WSRouter extends Logger {
} catch (e) {
this.error('Error parsing WebSocket message', e);
}
this.call<[number, number], string>('methodName', 1, 2);
}

// this.call<[number, number], string>('methodName', 1, 2);
call<Args extends any[] = any[], Return = void>(route: string, ...args: Args): Promise<Return> {
const resolver = this.createPromiseResolver<Return>();

const id = ++this.reqId;

this.runningCalls.set(id, resolver);

this.debug(`Calling PY method ${route} with args`, args);

this.write({ type: MessageType.CALL, route, args, id });

return resolver.promise;
}

onError(error: any) {
this.error('WS ERROR', error);
async onError(error: any) {
this.error('WS DISCONNECTED', error);
await this.connect();
}
}

0 comments on commit 1921e7e

Please sign in to comment.