diff --git a/backend/decky_loader/loader.py b/backend/decky_loader/loader.py index fa4949f2a..7f81777f8 100644 --- a/backend/decky_loader/loader.py +++ b/backend/decky_loader/loader.py @@ -88,7 +88,7 @@ def __init__(self, server_instance: PluginManager, ws: WSRouter, plugin_path: st self.observer.start() self.loop.create_task(self.enable_reload_wait()) - server_instance.add_routes([ + server_instance.web_app.add_routes([ web.get("/frontend/{path:.*}", self.handle_frontend_assets), web.get("/locales/{path:.*}", self.handle_frontend_locales), web.get("/plugins", self.get_plugins), diff --git a/backend/decky_loader/utilities.py b/backend/decky_loader/utilities.py index 2eea63ea7..0e3e9fb0a 100644 --- a/backend/decky_loader/utilities.py +++ b/backend/decky_loader/utilities.py @@ -166,7 +166,7 @@ async def inject_css_into_tab(self, tab: str, style: str) -> str: style.textContent = `{style}`; }})() """, False) - + assert result is not None # TODO remove this once it has proper typings if "exceptionDetails" in result["result"]: raise result["result"]["exceptionDetails"] @@ -233,7 +233,7 @@ async def filepicker_ls(self, folders.append({"file": file, "filest": filest, "is_dir": True}) elif include_files: # Handle requested extensions if present - if len(include_ext) == 0 or 'all_files' in include_ext \ + if include_ext == None or len(include_ext) == 0 or 'all_files' in include_ext \ or splitext(file.name)[1].lstrip('.') in include_ext: if (is_hidden and include_hidden) or not is_hidden: files.append({"file": file, "filest": filest, "is_dir": False}) diff --git a/backend/decky_loader/wsrouter.py b/backend/decky_loader/wsrouter.py index 7a7b59c92..28e3e9258 100644 --- a/backend/decky_loader/wsrouter.py +++ b/backend/decky_loader/wsrouter.py @@ -1,37 +1,51 @@ from logging import getLogger -from asyncio import AbstractEventLoop, Future +from asyncio import AbstractEventLoop, Future, create_task -from aiohttp import WSMsgType +from aiohttp import WSMsgType, WSMessage from aiohttp.web import Application, WebSocketResponse, Request, Response, get -from enum import Enum +from enum import IntEnum -from typing import Dict +from typing import Callable, Dict, Any, cast, TypeVar, Type +from dataclasses import dataclass from traceback import format_exc from helpers import get_csrf_token -class MessageType(Enum): - # Call-reply +class MessageType(IntEnum): + ERROR = -1 + # Call-reply, Frontend -> Backend CALL = 0 REPLY = 1 - ERROR = 2 - # # Pub/sub - # SUBSCRIBE = 3 - # UNSUBSCRIBE = 4 - # PUBLISH = 5 + # Pub/Sub, Backend -> Frontend + EVENT = 3 + +# WSMessage with slightly better typings +class WSMessageExtra(WSMessage): + data: Any + type: WSMsgType +@dataclass +class Message: + data: Any + type: MessageType + +# @dataclass +# class CallMessage # see wsrouter.ts for typings +DataType = TypeVar("DataType") + +Route = Callable[..., Future[Any]] + class WSRouter: def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None: self.loop = loop - self.ws = None - self.req_id = 0 - self.routes = {} - self.running_calls: Dict[int, Future] = {} + self.ws: WebSocketResponse | None + self.instance_id = 0 + self.routes: Dict[str, Route] = {} # self.subscriptions: Dict[str, Callable[[Any]]] = {} self.logger = getLogger("WSRouter") @@ -39,22 +53,38 @@ def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> Non get("/ws", self.handle) ]) - async def write(self, dta: Dict[str, any]): - await self.ws.send_json(dta) + async def write(self, data: Dict[str, Any]): + if self.ws != None: + await self.ws.send_json(data) + else: + self.logger.warn("Dropping message as there is no connected socket: %s", data) - def add_route(self, name: str, route): + def add_route(self, name: str, route: Route): self.routes[name] = route def remove_route(self, name: str): del self.routes[name] + async def _call_route(self, route: str, args: ..., call_id: int): + instance_id = self.instance_id + res = await self.routes[route](*args) + if instance_id != self.instance_id: + try: + self.logger.warn("Ignoring %s reply from stale instance %d with args %s and response %s", route, instance_id, args, res) + except: + self.logger.warn("Ignoring %s reply from stale instance %d (failed to log event data)", route, instance_id) + finally: + return + await self.write({"type": MessageType.REPLY.value, "id": call_id, "result": res}) + async def handle(self, request: Request): # Auth is a query param as JS WebSocket doesn't support headers if request.rel_url.query["auth"] != get_csrf_token(): - return Response(text='Forbidden', status='403') + return Response(text='Forbidden', status=403) self.logger.debug('Websocket connection starting') ws = WebSocketResponse() await ws.prepare(request) + self.instance_id += 1 self.logger.debug('Websocket connection ready') if self.ws != None: @@ -68,6 +98,8 @@ async def handle(self, request: Request): try: async for msg in ws: + msg = cast(WSMessageExtra, msg) + self.logger.debug(msg) if msg.type == WSMsgType.TEXT: self.logger.debug(msg.data) @@ -81,25 +113,13 @@ async def handle(self, request: Request): # do stuff with the message if data["route"] in self.routes: try: - res = await self.routes[data["route"]](*data["args"]) - await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res}) self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}') + create_task(self._call_route(data["route"], data["args"], data["id"])) except: - await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()}) + create_task(self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()})) else: # Dunno why but fstring doesnt work here - await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route " + data["route"] + " does not exist."}) - case MessageType.REPLY.value: - if self.running_calls[data["id"]]: - self.running_calls[data["id"]].set_result(data["result"]) - del self.running_calls[data["id"]] - self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}') - case MessageType.ERROR.value: - if self.running_calls[data["id"]]: - self.running_calls[data["id"]].set_exception(data["error"]) - del self.running_calls[data["id"]] - self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}') - + create_task(self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route " + data["route"] + " does not exist."})) case _: self.logger.error("Unknown message type", data) finally: @@ -112,17 +132,7 @@ async def handle(self, request: Request): self.logger.debug('Websocket connection closed') return ws - async def call(self, route: str, *args): - future = Future() - - self.req_id += 1 - - id = self.req_id - - self.running_calls[id] = future - - self.logger.debug(f'Calling JS method {route} with args {str(args)}') - - self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id }) + async def emit(self, event: str, data: DataType | None = None, data_type: Type[DataType] = Any): + self.logger.debug('Firing frontend event %s with args %s', data) - return await future + await self.write({ "type": MessageType.EVENT.value, "event": event, "data": data }) \ No newline at end of file diff --git a/frontend/src/wsrouter.ts b/frontend/src/wsrouter.ts index e50b06a70..e1d766c3b 100644 --- a/frontend/src/wsrouter.ts +++ b/frontend/src/wsrouter.ts @@ -7,14 +7,12 @@ declare global { } enum MessageType { - // Call-reply - CALL, - REPLY, - ERROR, - // Pub/sub - // SUBSCRIBE, - // UNSUBSCRIBE, - // PUBLISH + ERROR = -1, + // Call-reply, Frontend -> Backend + CALL = 0, + REPLY = 1, + // Pub/Sub, Backend -> Frontend + EVENT = 3, } interface CallMessage {