Skip to content

Commit

Permalink
Cleanup LocalSocket on_new_message
Browse files Browse the repository at this point in the history
  • Loading branch information
suchmememanyskill committed Sep 1, 2024
1 parent 850253f commit 3f51a93
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
16 changes: 9 additions & 7 deletions backend/decky_loader/localplatform/localsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,24 @@
BUFFER_LIMIT = 2 ** 20 # 1 MiB

class UnixSocket:
def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]):
def __init__(self):
'''
on_new_message takes 1 string argument.
It's return value gets used, if not None, to write data to the socket.
Method should be async
'''
self.socket_addr = f"/tmp/plugin_socket_{time.time()}"
self.on_new_message = on_new_message
self.on_new_message = None
self.socket = None
self.reader = None
self.writer = None
self.server_writer = None
self.open_lock = asyncio.Lock()
self.active = True

async def setup_server(self):
async def setup_server(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]):
try:
self.on_new_message = on_new_message
self.socket = await asyncio.start_unix_server(self._listen_for_method_call, path=self.socket_addr, limit=BUFFER_LIMIT)
except asyncio.CancelledError:
await self.close_socket_connection()
Expand Down Expand Up @@ -114,7 +115,7 @@ async def write_single_line_server(self, message: str):

async def _listen_for_method_call(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
self.server_writer = writer
while True:
while self.active and self.on_new_message:

def _(task: asyncio.Task[str|None]):
res = task.result()
Expand All @@ -125,18 +126,19 @@ def _(task: asyncio.Task[str|None]):
asyncio.create_task(self.on_new_message(line)).add_done_callback(_)

class PortSocket (UnixSocket):
def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]):
def __init__(self):
'''
on_new_message takes 1 string argument.
It's return value gets used, if not None, to write data to the socket.
Method should be async
'''
super().__init__(on_new_message)
super().__init__()
self.host = "127.0.0.1"
self.port = random.sample(range(40000, 60000), 1)[0]

async def setup_server(self):
async def setup_server(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]):
try:
self.on_new_message = on_new_message
self.socket = await asyncio.start_server(self._listen_for_method_call, host=self.host, port=self.port, limit=BUFFER_LIMIT)
except asyncio.CancelledError:
await self.close_socket_connection()
Expand Down
5 changes: 2 additions & 3 deletions backend/decky_loader/plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def __init__(self, file: str, plugin_directory: str, plugin_path: str, emit_call

self.sandboxed_plugin = SandboxedPlugin(self.name, self.passive, self.flags, self.file, self.plugin_directory, self.plugin_path, self.version, self.author, self.api_version)
self.proc: Process | None = None
# TODO: Maybe make LocalSocket not require on_new_message to make this cleaner
self._socket = LocalSocket(self.sandboxed_plugin.on_new_message)
self._socket = LocalSocket()
self._listener_task: Task[Any]
self._method_call_requests: Dict[str, MethodCallRequest] = {}

Expand Down Expand Up @@ -147,7 +146,6 @@ async def kill_if_still_running(self):
start_time = time()
sigtermed = False
while self.proc and self.proc.is_alive():
await sleep(0.1)
elapsed_time = time() - start_time
if elapsed_time >= 2 and not sigtermed:
sigtermed = True
Expand All @@ -156,6 +154,7 @@ async def kill_if_still_running(self):
elif elapsed_time >= 5:
self.log.warn(f"Plugin {self.name} still alive 5 seconds after stop request! Sending SIGKILL!")
self.terminate(True)
await sleep(0.1)

def terminate(self, kill: bool = False):
if self.proc and self.proc.is_alive():
Expand Down

0 comments on commit 3f51a93

Please sign in to comment.