From 1921e7ec56b39d4a6af111a75cf74c495800abc6 Mon Sep 17 00:00:00 2001 From: AAGaming Date: Mon, 10 Jul 2023 18:41:56 -0400 Subject: [PATCH] JS -> Python WS now functional --- backend/decky_loader/loader.py | 16 ++++-- backend/decky_loader/main.py | 4 +- backend/decky_loader/utilities.py | 6 ++- backend/decky_loader/wsrouter.py | 85 +++++++++++++++++++++---------- frontend/src/plugin-loader.tsx | 9 +++- frontend/src/utils/settings.ts | 22 ++------ frontend/src/wsrouter.ts | 45 +++++++++------- 7 files changed, 113 insertions(+), 74 deletions(-) diff --git a/backend/decky_loader/loader.py b/backend/decky_loader/loader.py index 85b5de582..8721ea05d 100644 --- a/backend/decky_loader/loader.py +++ b/backend/decky_loader/loader.py @@ -70,7 +70,7 @@ def on_modified(self, event: DirModifiedEvent | FileModifiedEvent): self.maybe_reload(src_path) class Loader: - def __init__(self, server_instance: PluginManager, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None: + def __init__(self, server_instance: PluginManager, ws: WSRouter, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None: self.loop = loop self.logger = getLogger("Loader") self.plugin_path = plugin_path @@ -88,10 +88,7 @@ def __init__(self, server_instance: PluginManager, plugin_path: str, loop: Abstr self.observer.start() self.loop.create_task(self.enable_reload_wait()) - self.ws = WSRouter() - - server_instance.web_app.add_routes([ - web.get("/ws", self.ws.handle), + server_instance.add_routes([ web.get("/frontend/{path:.*}", self.handle_frontend_assets), web.get("/locales/{path:.*}", self.handle_frontend_locales), web.get("/plugins", self.get_plugins), @@ -101,6 +98,15 @@ def __init__(self, server_instance: PluginManager, plugin_path: str, loop: Abstr web.post("/plugins/{plugin_name}/reload", self.handle_backend_reload_request) ]) + ws.add_route("test", self.test_method) + + async def test_method(): + await sleep(2) + + return { + "test data": True + } + async def enable_reload_wait(self): if self.live_reload: await sleep(10) diff --git a/backend/decky_loader/main.py b/backend/decky_loader/main.py index fae305747..e33f0a9b7 100644 --- a/backend/decky_loader/main.py +++ b/backend/decky_loader/main.py @@ -1,6 +1,7 @@ # Change PyInstaller files permissions import sys from typing import Dict +from wsrouter import WSRouter from .localplatform.localplatform import (chmod, chown, service_stop, service_start, ON_WINDOWS, get_log_level, get_live_reload, get_server_port, get_server_host, get_chown_plugin_path, @@ -63,7 +64,8 @@ def __init__(self, loop: AbstractEventLoop) -> None: allow_credentials=True ) }) - self.plugin_loader = Loader(self, plugin_path, self.loop, get_live_reload()) + self.ws = WSRouter(self.loop, self.web_app) + self.plugin_loader = Loader(self, self.ws, plugin_path, self.loop, get_live_reload()) self.settings = SettingsManager("loader", path.join(get_privileged_path(), "settings")) self.plugin_browser = PluginBrowser(plugin_path, self.plugin_loader.plugins, self.plugin_loader, self.settings) self.utilities = Utilities(self) diff --git a/backend/decky_loader/utilities.py b/backend/decky_loader/utilities.py index f04ed3718..20280c248 100644 --- a/backend/decky_loader/utilities.py +++ b/backend/decky_loader/utilities.py @@ -63,7 +63,11 @@ def __init__(self, context: PluginManager) -> None: web.post("/methods/{method_name}", self._handle_server_method_call) ]) - async def _handle_server_method_call(self, request: web.Request): + context.ws.add_route("utilities/ping", self.ping) + context.ws.add_route("utilities/settings/get", self.get_setting) + context.ws.add_route("utilities/settings/set", self.set_setting) + + async def _handle_server_method_call(self, request): method_name = request.match_info["method_name"] try: args = await request.json() diff --git a/backend/decky_loader/wsrouter.py b/backend/decky_loader/wsrouter.py index 9c8fe4246..2b4c3a3bf 100644 --- a/backend/decky_loader/wsrouter.py +++ b/backend/decky_loader/wsrouter.py @@ -1,15 +1,18 @@ from logging import getLogger -from asyncio import Future +from asyncio import AbstractEventLoop, Future -from aiohttp import web, WSMsgType +from aiohttp import WSMsgType +from aiohttp.web import Application, WebSocketResponse, Request, Response, get from enum import Enum -from typing import Dict, Any, Callable +from typing import Dict from traceback import format_exc +from helpers import get_csrf_token + class MessageType(Enum): # Call-reply CALL = 0 @@ -23,7 +26,8 @@ class MessageType(Enum): # see wsrouter.ts for typings class WSRouter: - def __init__(self) -> None: + def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None: + self.loop = loop self.ws = None self.req_id = 0 self.routes = {} @@ -31,12 +35,25 @@ def __init__(self) -> None: # self.subscriptions: Dict[str, Callable[[Any]]] = {} self.logger = getLogger("WSRouter") - async def add_route(self, name, route): + server_instance.add_routes([ + get("/ws", self.handle) + ]) + + async def write(self, dta: Dict[str, any]): + await self.ws.send_json(dta) + + def add_route(self, name: str, route): self.routes[name] = route - async def handle(self, request): + def remove_route(self, name: str): + del self.routes[name] + + async def handle(self, request: Request): + # Auth is a query param as JS WebSocket doesn't support headers + if request.rel_url.query["auth"] != get_csrf_token(): + return Response(text='Forbidden', status='403') self.logger.debug('Websocket connection starting') - ws = web.WebSocketResponse() + ws = WebSocketResponse() await ws.prepare(request) self.logger.debug('Websocket connection ready') @@ -58,29 +75,29 @@ async def handle(self, request): # TODO DO NOT RELY ON THIS! break else: - match data.type: - case MessageType.CALL: + data = msg.json() + match data["type"]: + case MessageType.CALL.value: # do stuff with the message - data = msg.json() - if self.routes[data.route]: + if self.routes[data["route"]]: try: - res = await self.routes[data.route](*data.args) - await self.write({"type": MessageType.REPLY, "id": data.id, "result": res}) - self.logger.debug(f"Started PY call {data.route} ID {data.id}") + res = await self.routes[data["route"]](*data["args"]) + await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res}) + self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}') except: - await self.write({"type": MessageType.ERROR, "id": data.id, "error": format_exc()}) + await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()}) else: - await self.write({"type": MessageType.ERROR, "id": data.id, "error": "Route does not exist."}) - case MessageType.REPLY: - if self.running_calls[data.id]: - self.running_calls[data.id].set_result(data.result) - del self.running_calls[data.id] - self.logger.debug(f"Resolved JS call {data.id} with value {str(data.result)}") - case MessageType.ERROR: - if self.running_calls[data.id]: - self.running_calls[data.id].set_exception(data.error) - del self.running_calls[data.id] - self.logger.debug(f"Errored JS call {data.id} with error {data.error}") + await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route does not exist."}) + case MessageType.REPLY.value: + if self.running_calls[data["id"]]: + self.running_calls[data["id"]].set_result(data["result"]) + del self.running_calls[data["id"]] + self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}') + case MessageType.ERROR.value: + if self.running_calls[data["id"]]: + self.running_calls[data["id"]].set_exception(data["error"]) + del self.running_calls[data["id"]] + self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}') case _: self.logger.error("Unknown message type", data) @@ -94,5 +111,17 @@ async def handle(self, request): self.logger.debug('Websocket connection closed') return ws - async def write(self, dta: Dict[str, any]): - await self.ws.send_json(dta) \ No newline at end of file + async def call(self, route: str, *args): + future = Future() + + self.req_id += 1 + + id = self.req_id + + self.running_calls[id] = future + + self.logger.debug(f'Calling JS method {route} with args {str(args)}') + + self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id }) + + return await future diff --git a/frontend/src/plugin-loader.tsx b/frontend/src/plugin-loader.tsx index e5f69f1f2..865920166 100644 --- a/frontend/src/plugin-loader.tsx +++ b/frontend/src/plugin-loader.tsx @@ -34,6 +34,7 @@ import Toaster from './toaster'; import { VerInfo, callUpdaterMethod } from './updater'; import { getSetting, setSetting } from './utils/settings'; import TranslationHelper, { TranslationClass } from './utils/TranslationHelper'; +import { WSRouter } from './wsrouter'; const StorePage = lazy(() => import('./components/store/Store')); const SettingsPage = lazy(() => import('./components/settings')); @@ -48,6 +49,8 @@ class PluginLoader extends Logger { public toaster: Toaster = new Toaster(); private deckyState: DeckyState = new DeckyState(); + public ws: WSRouter = new WSRouter(); + public hiddenPluginsService = new HiddenPluginsService(this.deckyState); public notificationService = new NotificationService(this.deckyState); @@ -102,9 +105,11 @@ class PluginLoader extends Logger { initFilepickerPatches(); - this.getUserInfo(); + this.ws.connect().then(() => { + this.getUserInfo(); - this.updateVersion(); + this.updateVersion(); + }); } public async getUserInfo() { diff --git a/frontend/src/utils/settings.ts b/frontend/src/utils/settings.ts index cadfe9359..d390d7ba9 100644 --- a/frontend/src/utils/settings.ts +++ b/frontend/src/utils/settings.ts @@ -1,24 +1,8 @@ -interface GetSettingArgs { - key: string; - default: T; -} - -interface SetSettingArgs { - key: string; - value: T; -} - export async function getSetting(key: string, def: T): Promise { - const res = (await window.DeckyPluginLoader.callServerMethod('get_setting', { - key, - default: def, - } as GetSettingArgs)) as { result: T }; - return res.result; + const res = await window.DeckyPluginLoader.ws.call<[string, T], T>('utilities/settings/get', key, def); + return res; } export async function setSetting(key: string, value: T): Promise { - await window.DeckyPluginLoader.callServerMethod('set_setting', { - key, - value, - } as SetSettingArgs); + await window.DeckyPluginLoader.ws.call<[string, T], void>('utilities/settings/set', key, value); } diff --git a/frontend/src/wsrouter.ts b/frontend/src/wsrouter.ts index b64375689..3a36b5b08 100644 --- a/frontend/src/wsrouter.ts +++ b/frontend/src/wsrouter.ts @@ -41,10 +41,11 @@ interface PromiseResolver { promise: Promise; } -class WSRouter extends Logger { +export class WSRouter extends Logger { routes: Map any> = new Map(); runningCalls: Map> = new Map(); ws?: WebSocket; + connectPromise?: Promise; // Used to map results and errors to calls reqId: number = 0; constructor() { @@ -52,30 +53,35 @@ class WSRouter extends Logger { } connect() { - this.ws = new WebSocket('ws://127.0.0.1:1337/ws'); - - this.ws.addEventListener('message', this.onMessage.bind(this)); - this.ws.addEventListener('close', this.onError.bind(this)); - this.ws.addEventListener('message', this.onError.bind(this)); + return (this.connectPromise = new Promise((resolve) => { + // Auth is a query param as JS WebSocket doesn't support headers + this.ws = new WebSocket(`ws://127.0.0.1:1337/ws?auth=${window.deckyAuthToken}`); + + this.ws.addEventListener('open', () => { + this.debug('WS Connected'); + resolve(); + delete this.connectPromise; + }); + this.ws.addEventListener('message', this.onMessage.bind(this)); + this.ws.addEventListener('close', this.onError.bind(this)); + // this.ws.addEventListener('error', this.onError.bind(this)); + })); } createPromiseResolver(): PromiseResolver { - let resolver: PromiseResolver; + let resolver: Partial> = {}; const promise = new Promise((resolve, reject) => { - resolver = { - promise, - resolve, - reject, - }; - this.debug('Created new PromiseResolver'); + resolver.resolve = resolve; + resolver.reject = reject; }); - this.debug('Returning new PromiseResolver'); + resolver.promise = promise; // The promise will always run first // @ts-expect-error 2454 return resolver; } - write(data: Message) { + async write(data: Message) { + if (this.connectPromise) await this.connectPromise; this.ws?.send(JSON.stringify(data)); } @@ -129,9 +135,9 @@ class WSRouter extends Logger { } catch (e) { this.error('Error parsing WebSocket message', e); } - this.call<[number, number], string>('methodName', 1, 2); } + // this.call<[number, number], string>('methodName', 1, 2); call(route: string, ...args: Args): Promise { const resolver = this.createPromiseResolver(); @@ -139,12 +145,15 @@ class WSRouter extends Logger { this.runningCalls.set(id, resolver); + this.debug(`Calling PY method ${route} with args`, args); + this.write({ type: MessageType.CALL, route, args, id }); return resolver.promise; } - onError(error: any) { - this.error('WS ERROR', error); + async onError(error: any) { + this.error('WS DISCONNECTED', error); + await this.connect(); } }