Skip to content

Commit

Permalink
more progress on websockets
Browse files Browse the repository at this point in the history
  • Loading branch information
AAGaming00 committed Nov 10, 2023
1 parent 6bb56cb commit be9f8bc
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 58 deletions.
2 changes: 1 addition & 1 deletion backend/src/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, server_instance: PluginManager, ws: WSRouter, plugin_path: st
self.observer.start()
self.loop.create_task(self.enable_reload_wait())

server_instance.add_routes([
server_instance.web_app.add_routes([
web.get("/frontend/{path:.*}", self.handle_frontend_assets),
web.get("/locales/{path:.*}", self.handle_frontend_locales),
web.get("/plugins", self.get_plugins),
Expand Down
4 changes: 2 additions & 2 deletions backend/src/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async def inject_css_into_tab(self, tab: str, style: str) -> str:
style.textContent = `{style}`;
}})()
""", False)

assert result is not None # TODO remove this once it has proper typings
if "exceptionDetails" in result["result"]:
raise result["result"]["exceptionDetails"]

Expand Down Expand Up @@ -233,7 +233,7 @@ async def filepicker_ls(self,
folders.append({"file": file, "filest": filest, "is_dir": True})
elif include_files:
# Handle requested extensions if present
if len(include_ext) == 0 or 'all_files' in include_ext \
if include_ext == None or len(include_ext) == 0 or 'all_files' in include_ext \
or splitext(file.name)[1].lstrip('.') in include_ext:
if (is_hidden and include_hidden) or not is_hidden:
files.append({"file": file, "filest": filest, "is_dir": False})
Expand Down
104 changes: 57 additions & 47 deletions backend/src/wsrouter.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,90 @@
from logging import getLogger

from asyncio import AbstractEventLoop, Future
from asyncio import AbstractEventLoop, Future, create_task

from aiohttp import WSMsgType
from aiohttp import WSMsgType, WSMessage
from aiohttp.web import Application, WebSocketResponse, Request, Response, get

from enum import Enum
from enum import IntEnum

from typing import Dict
from typing import Callable, Dict, Any, cast, TypeVar, Type
from dataclasses import dataclass

from traceback import format_exc

from helpers import get_csrf_token

class MessageType(Enum):
# Call-reply
class MessageType(IntEnum):
ERROR = -1
# Call-reply, Frontend -> Backend
CALL = 0
REPLY = 1
ERROR = 2
# # Pub/sub
# SUBSCRIBE = 3
# UNSUBSCRIBE = 4
# PUBLISH = 5
# Pub/Sub, Backend -> Frontend
EVENT = 3

# WSMessage with slightly better typings
class WSMessageExtra(WSMessage):
data: Any
type: WSMsgType
@dataclass
class Message:
data: Any
type: MessageType

# @dataclass
# class CallMessage

# see wsrouter.ts for typings

DataType = TypeVar("DataType")

Route = Callable[..., Future[Any]]

class WSRouter:
def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None:
self.loop = loop
self.ws = None
self.req_id = 0
self.routes = {}
self.running_calls: Dict[int, Future] = {}
self.ws: WebSocketResponse | None
self.instance_id = 0
self.routes: Dict[str, Route] = {}
# self.subscriptions: Dict[str, Callable[[Any]]] = {}
self.logger = getLogger("WSRouter")

server_instance.add_routes([
get("/ws", self.handle)
])

async def write(self, dta: Dict[str, any]):
await self.ws.send_json(dta)
async def write(self, data: Dict[str, Any]):
if self.ws != None:
await self.ws.send_json(data)
else:
self.logger.warn("Dropping message as there is no connected socket: %s", data)

def add_route(self, name: str, route):
def add_route(self, name: str, route: Route):
self.routes[name] = route

def remove_route(self, name: str):
del self.routes[name]

async def _call_route(self, route: str, args: ..., call_id: int):
instance_id = self.instance_id
res = await self.routes[route](*args)
if instance_id != self.instance_id:
try:
self.logger.warn("Ignoring %s reply from stale instance %d with args %s and response %s", route, instance_id, args, res)
except:
self.logger.warn("Ignoring %s reply from stale instance %d (failed to log event data)", route, instance_id)
finally:
return
await self.write({"type": MessageType.REPLY.value, "id": call_id, "result": res})

async def handle(self, request: Request):
# Auth is a query param as JS WebSocket doesn't support headers
if request.rel_url.query["auth"] != get_csrf_token():
return Response(text='Forbidden', status='403')
return Response(text='Forbidden', status=403)
self.logger.debug('Websocket connection starting')
ws = WebSocketResponse()
await ws.prepare(request)
self.instance_id += 1
self.logger.debug('Websocket connection ready')

if self.ws != None:
Expand All @@ -68,6 +98,8 @@ async def handle(self, request: Request):

try:
async for msg in ws:
msg = cast(WSMessageExtra, msg)

self.logger.debug(msg)
if msg.type == WSMsgType.TEXT:
self.logger.debug(msg.data)
Expand All @@ -81,25 +113,13 @@ async def handle(self, request: Request):
# do stuff with the message
if data["route"] in self.routes:
try:
res = await self.routes[data["route"]](*data["args"])
await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res})
self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}')
create_task(self._call_route(data["route"], data["args"], data["id"]))
except:
await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()})
create_task(self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()}))
else:
# Dunno why but fstring doesnt work here
await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route " + data["route"] + " does not exist."})
case MessageType.REPLY.value:
if self.running_calls[data["id"]]:
self.running_calls[data["id"]].set_result(data["result"])
del self.running_calls[data["id"]]
self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}')
case MessageType.ERROR.value:
if self.running_calls[data["id"]]:
self.running_calls[data["id"]].set_exception(data["error"])
del self.running_calls[data["id"]]
self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}')

create_task(self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route " + data["route"] + " does not exist."}))
case _:
self.logger.error("Unknown message type", data)
finally:
Expand All @@ -112,17 +132,7 @@ async def handle(self, request: Request):
self.logger.debug('Websocket connection closed')
return ws

async def call(self, route: str, *args):
future = Future()

self.req_id += 1

id = self.req_id

self.running_calls[id] = future

self.logger.debug(f'Calling JS method {route} with args {str(args)}')

self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id })
async def emit(self, event: str, data: DataType | None = None, data_type: Type[DataType] = Any):
self.logger.debug('Firing frontend event %s with args %s', data)

return await future
await self.write({ "type": MessageType.EVENT.value, "event": event, "data": data })
14 changes: 6 additions & 8 deletions frontend/src/wsrouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ declare global {
}

enum MessageType {
// Call-reply
CALL,
REPLY,
ERROR,
// Pub/sub
// SUBSCRIBE,
// UNSUBSCRIBE,
// PUBLISH
ERROR = -1,
// Call-reply, Frontend -> Backend
CALL = 0,
REPLY = 1,
// Pub/Sub, Backend -> Frontend
EVENT = 3,
}

interface CallMessage {
Expand Down

0 comments on commit be9f8bc

Please sign in to comment.