diff --git a/backend/decky_loader/localplatform/localsocket.py b/backend/decky_loader/localplatform/localsocket.py index 7fb4fef5..74e65406 100644 --- a/backend/decky_loader/localplatform/localsocket.py +++ b/backend/decky_loader/localplatform/localsocket.py @@ -7,14 +7,14 @@ BUFFER_LIMIT = 2 ** 20 # 1 MiB class UnixSocket: - def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]): + def __init__(self): ''' on_new_message takes 1 string argument. It's return value gets used, if not None, to write data to the socket. Method should be async ''' self.socket_addr = f"/tmp/plugin_socket_{time.time()}" - self.on_new_message = on_new_message + self.on_new_message = None self.socket = None self.reader = None self.writer = None @@ -22,8 +22,9 @@ def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]): self.open_lock = asyncio.Lock() self.active = True - async def setup_server(self): + async def setup_server(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]): try: + self.on_new_message = on_new_message self.socket = await asyncio.start_unix_server(self._listen_for_method_call, path=self.socket_addr, limit=BUFFER_LIMIT) except asyncio.CancelledError: await self.close_socket_connection() @@ -114,7 +115,7 @@ async def write_single_line_server(self, message: str): async def _listen_for_method_call(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): self.server_writer = writer - while True: + while self.active and self.on_new_message: def _(task: asyncio.Task[str|None]): res = task.result() @@ -125,18 +126,19 @@ def _(task: asyncio.Task[str|None]): asyncio.create_task(self.on_new_message(line)).add_done_callback(_) class PortSocket (UnixSocket): - def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]): + def __init__(self): ''' on_new_message takes 1 string argument. It's return value gets used, if not None, to write data to the socket. Method should be async ''' - super().__init__(on_new_message) + super().__init__() self.host = "127.0.0.1" self.port = random.sample(range(40000, 60000), 1)[0] - async def setup_server(self): + async def setup_server(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]): try: + self.on_new_message = on_new_message self.socket = await asyncio.start_server(self._listen_for_method_call, host=self.host, port=self.port, limit=BUFFER_LIMIT) except asyncio.CancelledError: await self.close_socket_connection() diff --git a/backend/decky_loader/plugin/plugin.py b/backend/decky_loader/plugin/plugin.py index ed6d7da2..050a16a9 100644 --- a/backend/decky_loader/plugin/plugin.py +++ b/backend/decky_loader/plugin/plugin.py @@ -44,8 +44,7 @@ def __init__(self, file: str, plugin_directory: str, plugin_path: str, emit_call self.sandboxed_plugin = SandboxedPlugin(self.name, self.passive, self.flags, self.file, self.plugin_directory, self.plugin_path, self.version, self.author, self.api_version) self.proc: Process | None = None - # TODO: Maybe make LocalSocket not require on_new_message to make this cleaner - self._socket = LocalSocket(self.sandboxed_plugin.on_new_message) + self._socket = LocalSocket() self._listener_task: Task[Any] self._method_call_requests: Dict[str, MethodCallRequest] = {} @@ -147,7 +146,6 @@ async def kill_if_still_running(self): start_time = time() sigtermed = False while self.proc and self.proc.is_alive(): - await sleep(0.1) elapsed_time = time() - start_time if elapsed_time >= 2 and not sigtermed: sigtermed = True @@ -156,6 +154,7 @@ async def kill_if_still_running(self): elif elapsed_time >= 5: self.log.warn(f"Plugin {self.name} still alive 5 seconds after stop request! Sending SIGKILL!") self.terminate(True) + await sleep(0.1) def terminate(self, kill: bool = False): if self.proc and self.proc.is_alive():