Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shutdown timeouts #695

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion backend/decky_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,15 @@ async def shutdown_plugins(self):
async def enable_reload_wait(self):
if self.live_reload:
await sleep(10)
if self.watcher:
if self.watcher and self.live_reload:
self.logger.info("Hot reload enabled")
self.watcher.disabled = False

async def disable_reload(self):
if self.watcher:
self.watcher.disabled = True
self.live_reload = False

async def handle_frontend_assets(self, request: web.Request):
file = Path(__file__).parent.joinpath("static").joinpath(request.match_info["path"])
return web.FileResponse(file, headers={"Cache-Control": "no-cache"})
Expand Down
23 changes: 14 additions & 9 deletions backend/decky_loader/localplatform/localsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,24 @@
BUFFER_LIMIT = 2 ** 20 # 1 MiB

class UnixSocket:
def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]):
def __init__(self):
'''
on_new_message takes 1 string argument.
It's return value gets used, if not None, to write data to the socket.
Method should be async
'''
self.socket_addr = f"/tmp/plugin_socket_{time.time()}"
self.on_new_message = on_new_message
self.on_new_message = None
self.socket = None
self.reader = None
self.writer = None
self.server_writer = None
self.open_lock = asyncio.Lock()
self.active = True

async def setup_server(self):
async def setup_server(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]):
try:
self.on_new_message = on_new_message
self.socket = await asyncio.start_unix_server(self._listen_for_method_call, path=self.socket_addr, limit=BUFFER_LIMIT)
except asyncio.CancelledError:
await self.close_socket_connection()
Expand Down Expand Up @@ -58,6 +60,8 @@ async def close_socket_connection(self):
if self.socket:
self.socket.close()
await self.socket.wait_closed()

self.active = False

async def read_single_line(self) -> str|None:
reader, _ = await self.get_socket_connection()
Expand All @@ -81,7 +85,7 @@ async def write_single_line(self, message : str):

async def _read_single_line(self, reader: asyncio.StreamReader) -> str:
line = bytearray()
while True:
while self.active:
try:
line.extend(await reader.readuntil())
except asyncio.LimitOverrunError:
Expand All @@ -91,7 +95,7 @@ async def _read_single_line(self, reader: asyncio.StreamReader) -> str:
line.extend(err.partial)
break
except asyncio.CancelledError:
break
raise
else:
break

Expand All @@ -111,7 +115,7 @@ async def write_single_line_server(self, message: str):

async def _listen_for_method_call(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
self.server_writer = writer
while True:
while self.active and self.on_new_message:

def _(task: asyncio.Task[str|None]):
res = task.result()
Expand All @@ -122,18 +126,19 @@ def _(task: asyncio.Task[str|None]):
asyncio.create_task(self.on_new_message(line)).add_done_callback(_)

class PortSocket (UnixSocket):
def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]):
def __init__(self):
'''
on_new_message takes 1 string argument.
It's return value gets used, if not None, to write data to the socket.
Method should be async
'''
super().__init__(on_new_message)
super().__init__()
self.host = "127.0.0.1"
self.port = random.sample(range(40000, 60000), 1)[0]

async def setup_server(self):
async def setup_server(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]):
try:
self.on_new_message = on_new_message
self.socket = await asyncio.start_server(self._listen_for_method_call, host=self.host, port=self.port, limit=BUFFER_LIMIT)
except asyncio.CancelledError:
await self.close_socket_connection()
Expand Down
8 changes: 7 additions & 1 deletion backend/decky_loader/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ async def startup(_: Application):
self.web_app.add_routes([static("/static", path.join(path.dirname(__file__), 'static'))])

async def handle_crash(self):
if not self.reinject:
return
new_time = time()
if (new_time - self.last_webhelper_exit < 60):
self.webhelper_crash_count += 1
Expand All @@ -118,9 +120,13 @@ async def handle_crash(self):
async def shutdown(self, _: Application):
try:
logger.info(f"Shutting down...")
logger.info("Disabling reload...")
await self.plugin_loader.disable_reload()
logger.info("Killing plugins...")
await self.plugin_loader.shutdown_plugins()
await self.ws.disconnect()
logger.info("Disconnecting from WS...")
self.reinject = False
await self.ws.disconnect()
if self.js_ctx_tab:
await self.js_ctx_tab.close_websocket()
self.js_ctx_tab = None
Expand Down
56 changes: 34 additions & 22 deletions backend/decky_loader/plugin/plugin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from asyncio import CancelledError, Task, create_task, sleep
from asyncio import CancelledError, Task, create_task, sleep, get_event_loop, wait
from json import dumps, load, loads
from logging import getLogger
from os import path
from multiprocessing import Process
from time import time
from traceback import format_exc

from .sandboxed_plugin import SandboxedPlugin
from .messages import MethodCallRequest, SocketMessageType
Expand Down Expand Up @@ -42,8 +44,7 @@ def __init__(self, file: str, plugin_directory: str, plugin_path: str, emit_call

self.sandboxed_plugin = SandboxedPlugin(self.name, self.passive, self.flags, self.file, self.plugin_directory, self.plugin_path, self.version, self.author, self.api_version)
self.proc: Process | None = None
# TODO: Maybe make LocalSocket not require on_new_message to make this cleaner
self._socket = LocalSocket(self.sandboxed_plugin.on_new_message)
self._socket = LocalSocket()
self._listener_task: Task[Any]
self._method_call_requests: Dict[str, MethodCallRequest] = {}

Expand All @@ -65,7 +66,7 @@ def __str__(self) -> str:
return self.name

async def _response_listener(self):
while True:
while self._socket.active:
try:
line = await self._socket.read_single_line()
if line != None:
Expand Down Expand Up @@ -115,29 +116,40 @@ def start(self):
return self

async def stop(self, uninstall: bool = False):
self.log.info(f"Stopping plugin {self.name}")
if self.passive:
return
if hasattr(self, "_socket"):
await self._socket.write_single_line(dumps({ "stop": True, "uninstall": uninstall }, ensure_ascii=False))
await self._socket.close_socket_connection()
if self.proc:
self.proc.join()
await self.kill_if_still_running()
if hasattr(self, "_listener_task"):
self._listener_task.cancel()
try:
start_time = time()
if self.passive:
return

_, pending = await wait([
create_task(self._socket.write_single_line(dumps({ "stop": True, "uninstall": uninstall }, ensure_ascii=False)))
], timeout=1)

if hasattr(self, "_listener_task"):
self._listener_task.cancel()

await self.kill_if_still_running()

for pending_task in pending:
pending_task.cancel()

self.log.info(f"Plugin {self.name} has been stopped in {time() - start_time:.1f}s")
except Exception as e:
self.log.error(f"Error during shutdown for plugin {self.name}: {str(e)}\n{format_exc()}")

async def kill_if_still_running(self):
time = 0
start_time = time()
sigtermed = False
while self.proc and self.proc.is_alive():
await sleep(0.1)
time += 1
if time == 100:
self.log.warn(f"Plugin {self.name} still alive 10 seconds after stop request! Sending SIGTERM!")
elapsed_time = time() - start_time
if elapsed_time >= 5 and not sigtermed:
sigtermed = True
self.log.warn(f"Plugin {self.name} still alive 5 seconds after stop request! Sending SIGTERM!")
self.terminate()
elif time == 200:
self.log.warn(f"Plugin {self.name} still alive 20 seconds after stop request! Sending SIGKILL!")
elif elapsed_time >= 10:
self.log.warn(f"Plugin {self.name} still alive 10 seconds after stop request! Sending SIGKILL!")
self.terminate(True)
await sleep(0.1)

def terminate(self, kill: bool = False):
if self.proc and self.proc.is_alive():
Expand Down
12 changes: 1 addition & 11 deletions backend/decky_loader/plugin/sandboxed_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import sys
from os import path, environ
from signal import SIG_IGN, SIGINT, SIGTERM, getsignal, signal
from importlib.util import module_from_spec, spec_from_file_location
from json import dumps, loads
from logging import getLogger
Expand All @@ -19,8 +18,6 @@

DataType = TypeVar("DataType")

original_term_handler = getsignal(SIGTERM)

class SandboxedPlugin:
def __init__(self,
name: str,
Expand Down Expand Up @@ -48,11 +45,6 @@ def initialize(self, socket: LocalSocket):
self._socket = socket

try:
# Ignore signals meant for parent Process
# TODO SURELY there's a better way to do this.
signal(SIGINT, SIG_IGN)
signal(SIGTERM, SIG_IGN)

setproctitle(f"{self.name} ({self.file})")
setthreadtitle(self.name)

Expand Down Expand Up @@ -120,7 +112,7 @@ async def emit(event: str, *args: Any) -> None:
get_event_loop().create_task(self.Plugin._main())
else:
get_event_loop().create_task(self.Plugin._main(self.Plugin))
get_event_loop().create_task(socket.setup_server())
get_event_loop().create_task(socket.setup_server(self.on_new_message))
except:
self.log.error("Failed to start " + self.name + "!\n" + format_exc())
sys.exit(0)
Expand Down Expand Up @@ -167,8 +159,6 @@ async def on_new_message(self, message : str) -> str|None:
data = loads(message)

if "stop" in data:
# Incase the loader needs to terminate our process soon
signal(SIGTERM, original_term_handler)
self.log.info(f"Calling Loader unload function for {self.name}.")
await self._unload()

Expand Down
3 changes: 3 additions & 0 deletions dist/install_prerelease.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ curl -L https://raw.githubusercontent.com/SteamDeckHomebrew/decky-loader/main/di
cat > "${HOMEBREW_FOLDER}/services/plugin_loader-backup.service" <<- EOM
[Unit]
Description=SteamDeck Plugin Loader
After=network.target
[Service]
Type=simple
User=root
Restart=always
KillMode=process
TimeoutStopSec=45
ExecStart=${HOMEBREW_FOLDER}/services/PluginLoader
WorkingDirectory=${HOMEBREW_FOLDER}/services
Environment=UNPRIVILEGED_PATH=${HOMEBREW_FOLDER}
Expand Down
3 changes: 3 additions & 0 deletions dist/install_release.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ curl -L https://raw.githubusercontent.com/SteamDeckHomebrew/decky-loader/main/di
cat > "${HOMEBREW_FOLDER}/services/plugin_loader-backup.service" <<- EOM
[Unit]
Description=SteamDeck Plugin Loader
After=network.target
[Service]
Type=simple
User=root
Restart=always
KillMode=process
TimeoutStopSec=45
ExecStart=${HOMEBREW_FOLDER}/services/PluginLoader
WorkingDirectory=${HOMEBREW_FOLDER}/services
Environment=UNPRIVILEGED_PATH=${HOMEBREW_FOLDER}
Expand Down
1 change: 1 addition & 0 deletions dist/plugin_loader-prerelease.service
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ After=network.target
Type=simple
User=root
Restart=always
KillMode=process
TimeoutStopSec=45
ExecStart=${HOMEBREW_FOLDER}/services/PluginLoader
WorkingDirectory=${HOMEBREW_FOLDER}/services
Expand Down
1 change: 1 addition & 0 deletions dist/plugin_loader-release.service
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ After=network.target
Type=simple
User=root
Restart=always
KillMode=process
TimeoutStopSec=45
ExecStart=${HOMEBREW_FOLDER}/services/PluginLoader
WorkingDirectory=${HOMEBREW_FOLDER}/services
Expand Down
Loading