From 016ed6e998de25c3a2d5caf119b4489c281b3ba5 Mon Sep 17 00:00:00 2001 From: Sims <38142618+suchmememanyskill@users.noreply.github.com> Date: Sun, 1 Sep 2024 20:15:49 +0200 Subject: [PATCH] Fix shutdown timeouts (#695) Co-authored-by: AAGaming --- backend/decky_loader/loader.py | 7 ++- .../decky_loader/localplatform/localsocket.py | 23 +++++--- backend/decky_loader/main.py | 8 ++- backend/decky_loader/plugin/plugin.py | 56 +++++++++++-------- .../decky_loader/plugin/sandboxed_plugin.py | 12 +--- dist/install_prerelease.sh | 3 + dist/install_release.sh | 3 + dist/plugin_loader-prerelease.service | 1 + dist/plugin_loader-release.service | 1 + 9 files changed, 70 insertions(+), 44 deletions(-) diff --git a/backend/decky_loader/loader.py b/backend/decky_loader/loader.py index fcd363464..6a324e238 100644 --- a/backend/decky_loader/loader.py +++ b/backend/decky_loader/loader.py @@ -104,10 +104,15 @@ async def shutdown_plugins(self): async def enable_reload_wait(self): if self.live_reload: await sleep(10) - if self.watcher: + if self.watcher and self.live_reload: self.logger.info("Hot reload enabled") self.watcher.disabled = False + async def disable_reload(self): + if self.watcher: + self.watcher.disabled = True + self.live_reload = False + async def handle_frontend_assets(self, request: web.Request): file = Path(__file__).parent.joinpath("static").joinpath(request.match_info["path"]) return web.FileResponse(file, headers={"Cache-Control": "no-cache"}) diff --git a/backend/decky_loader/localplatform/localsocket.py b/backend/decky_loader/localplatform/localsocket.py index b25b275a5..74e654069 100644 --- a/backend/decky_loader/localplatform/localsocket.py +++ b/backend/decky_loader/localplatform/localsocket.py @@ -7,22 +7,24 @@ BUFFER_LIMIT = 2 ** 20 # 1 MiB class UnixSocket: - def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]): + def __init__(self): ''' on_new_message takes 1 string argument. It's return value gets used, if not None, to write data to the socket. Method should be async ''' self.socket_addr = f"/tmp/plugin_socket_{time.time()}" - self.on_new_message = on_new_message + self.on_new_message = None self.socket = None self.reader = None self.writer = None self.server_writer = None self.open_lock = asyncio.Lock() + self.active = True - async def setup_server(self): + async def setup_server(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]): try: + self.on_new_message = on_new_message self.socket = await asyncio.start_unix_server(self._listen_for_method_call, path=self.socket_addr, limit=BUFFER_LIMIT) except asyncio.CancelledError: await self.close_socket_connection() @@ -58,6 +60,8 @@ async def close_socket_connection(self): if self.socket: self.socket.close() await self.socket.wait_closed() + + self.active = False async def read_single_line(self) -> str|None: reader, _ = await self.get_socket_connection() @@ -81,7 +85,7 @@ async def write_single_line(self, message : str): async def _read_single_line(self, reader: asyncio.StreamReader) -> str: line = bytearray() - while True: + while self.active: try: line.extend(await reader.readuntil()) except asyncio.LimitOverrunError: @@ -91,7 +95,7 @@ async def _read_single_line(self, reader: asyncio.StreamReader) -> str: line.extend(err.partial) break except asyncio.CancelledError: - break + raise else: break @@ -111,7 +115,7 @@ async def write_single_line_server(self, message: str): async def _listen_for_method_call(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): self.server_writer = writer - while True: + while self.active and self.on_new_message: def _(task: asyncio.Task[str|None]): res = task.result() @@ -122,18 +126,19 @@ def _(task: asyncio.Task[str|None]): asyncio.create_task(self.on_new_message(line)).add_done_callback(_) class PortSocket (UnixSocket): - def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]): + def __init__(self): ''' on_new_message takes 1 string argument. It's return value gets used, if not None, to write data to the socket. Method should be async ''' - super().__init__(on_new_message) + super().__init__() self.host = "127.0.0.1" self.port = random.sample(range(40000, 60000), 1)[0] - async def setup_server(self): + async def setup_server(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]): try: + self.on_new_message = on_new_message self.socket = await asyncio.start_server(self._listen_for_method_call, host=self.host, port=self.port, limit=BUFFER_LIMIT) except asyncio.CancelledError: await self.close_socket_connection() diff --git a/backend/decky_loader/main.py b/backend/decky_loader/main.py index c268b387e..983d3dca2 100644 --- a/backend/decky_loader/main.py +++ b/backend/decky_loader/main.py @@ -101,6 +101,8 @@ async def startup(_: Application): self.web_app.add_routes([static("/static", path.join(path.dirname(__file__), 'static'))]) async def handle_crash(self): + if not self.reinject: + return new_time = time() if (new_time - self.last_webhelper_exit < 60): self.webhelper_crash_count += 1 @@ -118,9 +120,13 @@ async def handle_crash(self): async def shutdown(self, _: Application): try: logger.info(f"Shutting down...") + logger.info("Disabling reload...") + await self.plugin_loader.disable_reload() + logger.info("Killing plugins...") await self.plugin_loader.shutdown_plugins() - await self.ws.disconnect() + logger.info("Disconnecting from WS...") self.reinject = False + await self.ws.disconnect() if self.js_ctx_tab: await self.js_ctx_tab.close_websocket() self.js_ctx_tab = None diff --git a/backend/decky_loader/plugin/plugin.py b/backend/decky_loader/plugin/plugin.py index a9a9ce292..e22eca3f5 100644 --- a/backend/decky_loader/plugin/plugin.py +++ b/backend/decky_loader/plugin/plugin.py @@ -1,8 +1,10 @@ -from asyncio import CancelledError, Task, create_task, sleep +from asyncio import CancelledError, Task, create_task, sleep, get_event_loop, wait from json import dumps, load, loads from logging import getLogger from os import path from multiprocessing import Process +from time import time +from traceback import format_exc from .sandboxed_plugin import SandboxedPlugin from .messages import MethodCallRequest, SocketMessageType @@ -42,8 +44,7 @@ def __init__(self, file: str, plugin_directory: str, plugin_path: str, emit_call self.sandboxed_plugin = SandboxedPlugin(self.name, self.passive, self.flags, self.file, self.plugin_directory, self.plugin_path, self.version, self.author, self.api_version) self.proc: Process | None = None - # TODO: Maybe make LocalSocket not require on_new_message to make this cleaner - self._socket = LocalSocket(self.sandboxed_plugin.on_new_message) + self._socket = LocalSocket() self._listener_task: Task[Any] self._method_call_requests: Dict[str, MethodCallRequest] = {} @@ -65,7 +66,7 @@ def __str__(self) -> str: return self.name async def _response_listener(self): - while True: + while self._socket.active: try: line = await self._socket.read_single_line() if line != None: @@ -115,29 +116,40 @@ def start(self): return self async def stop(self, uninstall: bool = False): - self.log.info(f"Stopping plugin {self.name}") - if self.passive: - return - if hasattr(self, "_socket"): - await self._socket.write_single_line(dumps({ "stop": True, "uninstall": uninstall }, ensure_ascii=False)) - await self._socket.close_socket_connection() - if self.proc: - self.proc.join() - await self.kill_if_still_running() - if hasattr(self, "_listener_task"): - self._listener_task.cancel() + try: + start_time = time() + if self.passive: + return + + _, pending = await wait([ + create_task(self._socket.write_single_line(dumps({ "stop": True, "uninstall": uninstall }, ensure_ascii=False))) + ], timeout=1) + + if hasattr(self, "_listener_task"): + self._listener_task.cancel() + + await self.kill_if_still_running() + + for pending_task in pending: + pending_task.cancel() + + self.log.info(f"Plugin {self.name} has been stopped in {time() - start_time:.1f}s") + except Exception as e: + self.log.error(f"Error during shutdown for plugin {self.name}: {str(e)}\n{format_exc()}") async def kill_if_still_running(self): - time = 0 + start_time = time() + sigtermed = False while self.proc and self.proc.is_alive(): - await sleep(0.1) - time += 1 - if time == 100: - self.log.warn(f"Plugin {self.name} still alive 10 seconds after stop request! Sending SIGTERM!") + elapsed_time = time() - start_time + if elapsed_time >= 5 and not sigtermed: + sigtermed = True + self.log.warn(f"Plugin {self.name} still alive 5 seconds after stop request! Sending SIGTERM!") self.terminate() - elif time == 200: - self.log.warn(f"Plugin {self.name} still alive 20 seconds after stop request! Sending SIGKILL!") + elif elapsed_time >= 10: + self.log.warn(f"Plugin {self.name} still alive 10 seconds after stop request! Sending SIGKILL!") self.terminate(True) + await sleep(0.1) def terminate(self, kill: bool = False): if self.proc and self.proc.is_alive(): diff --git a/backend/decky_loader/plugin/sandboxed_plugin.py b/backend/decky_loader/plugin/sandboxed_plugin.py index 93691a446..23575900f 100644 --- a/backend/decky_loader/plugin/sandboxed_plugin.py +++ b/backend/decky_loader/plugin/sandboxed_plugin.py @@ -1,6 +1,5 @@ import sys from os import path, environ -from signal import SIG_IGN, SIGINT, SIGTERM, getsignal, signal from importlib.util import module_from_spec, spec_from_file_location from json import dumps, loads from logging import getLogger @@ -19,8 +18,6 @@ DataType = TypeVar("DataType") -original_term_handler = getsignal(SIGTERM) - class SandboxedPlugin: def __init__(self, name: str, @@ -48,11 +45,6 @@ def initialize(self, socket: LocalSocket): self._socket = socket try: - # Ignore signals meant for parent Process - # TODO SURELY there's a better way to do this. - signal(SIGINT, SIG_IGN) - signal(SIGTERM, SIG_IGN) - setproctitle(f"{self.name} ({self.file})") setthreadtitle(self.name) @@ -120,7 +112,7 @@ async def emit(event: str, *args: Any) -> None: get_event_loop().create_task(self.Plugin._main()) else: get_event_loop().create_task(self.Plugin._main(self.Plugin)) - get_event_loop().create_task(socket.setup_server()) + get_event_loop().create_task(socket.setup_server(self.on_new_message)) except: self.log.error("Failed to start " + self.name + "!\n" + format_exc()) sys.exit(0) @@ -167,8 +159,6 @@ async def on_new_message(self, message : str) -> str|None: data = loads(message) if "stop" in data: - # Incase the loader needs to terminate our process soon - signal(SIGTERM, original_term_handler) self.log.info(f"Calling Loader unload function for {self.name}.") await self._unload() diff --git a/dist/install_prerelease.sh b/dist/install_prerelease.sh index 950c25aae..9e5ce9cc3 100644 --- a/dist/install_prerelease.sh +++ b/dist/install_prerelease.sh @@ -34,10 +34,13 @@ curl -L https://raw.githubusercontent.com/SteamDeckHomebrew/decky-loader/main/di cat > "${HOMEBREW_FOLDER}/services/plugin_loader-backup.service" <<- EOM [Unit] Description=SteamDeck Plugin Loader +After=network.target [Service] Type=simple User=root Restart=always +KillMode=process +TimeoutStopSec=45 ExecStart=${HOMEBREW_FOLDER}/services/PluginLoader WorkingDirectory=${HOMEBREW_FOLDER}/services Environment=UNPRIVILEGED_PATH=${HOMEBREW_FOLDER} diff --git a/dist/install_release.sh b/dist/install_release.sh index 46a478675..61f85488f 100644 --- a/dist/install_release.sh +++ b/dist/install_release.sh @@ -34,10 +34,13 @@ curl -L https://raw.githubusercontent.com/SteamDeckHomebrew/decky-loader/main/di cat > "${HOMEBREW_FOLDER}/services/plugin_loader-backup.service" <<- EOM [Unit] Description=SteamDeck Plugin Loader +After=network.target [Service] Type=simple User=root Restart=always +KillMode=process +TimeoutStopSec=45 ExecStart=${HOMEBREW_FOLDER}/services/PluginLoader WorkingDirectory=${HOMEBREW_FOLDER}/services Environment=UNPRIVILEGED_PATH=${HOMEBREW_FOLDER} diff --git a/dist/plugin_loader-prerelease.service b/dist/plugin_loader-prerelease.service index 594925dd4..78970909d 100644 --- a/dist/plugin_loader-prerelease.service +++ b/dist/plugin_loader-prerelease.service @@ -5,6 +5,7 @@ After=network.target Type=simple User=root Restart=always +KillMode=process TimeoutStopSec=45 ExecStart=${HOMEBREW_FOLDER}/services/PluginLoader WorkingDirectory=${HOMEBREW_FOLDER}/services diff --git a/dist/plugin_loader-release.service b/dist/plugin_loader-release.service index 6f94d4e10..d8f69dcaa 100644 --- a/dist/plugin_loader-release.service +++ b/dist/plugin_loader-release.service @@ -5,6 +5,7 @@ After=network.target Type=simple User=root Restart=always +KillMode=process TimeoutStopSec=45 ExecStart=${HOMEBREW_FOLDER}/services/PluginLoader WorkingDirectory=${HOMEBREW_FOLDER}/services