diff --git a/ipykernel/inprocess/socket.py b/ipykernel/inprocess/socket.py index 5a2e0008..d14d0850 100644 --- a/ipykernel/inprocess/socket.py +++ b/ipykernel/inprocess/socket.py @@ -65,4 +65,8 @@ async def poll(self, timeout=0): return statistics.current_buffer_used != 0 def close(self): - pass + if self.is_shell: + self.in_send_stream.close() + self.in_receive_stream.close() + self.out_send_stream.close() + self.out_receive_stream.close() diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index 19334212..02a0e22a 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -16,14 +16,16 @@ from binascii import b2a_hex from collections import defaultdict, deque from io import StringIO, TextIOBase -from threading import Event, Thread, local +from threading import local from typing import Any, Callable import zmq import zmq_anyio -from anyio import create_task_group, run, sleep, to_thread +from anyio import sleep from jupyter_client.session import extract_header +from .thread import BaseThread + # ----------------------------------------------------------------------------- # Globals # ----------------------------------------------------------------------------- @@ -38,38 +40,6 @@ # ----------------------------------------------------------------------------- -class _IOPubThread(Thread): - """A thread for a IOPub.""" - - def __init__(self, tasks, **kwargs): - """Initialize the thread.""" - super().__init__(name="IOPub", **kwargs) - self._tasks = tasks - self.pydev_do_not_trace = True - self.is_pydev_daemon_thread = True - self.daemon = True - self.__stop = Event() - - def run(self): - """Run the thread.""" - self.name = "IOPub" - run(self._main) - - async def _main(self): - async with create_task_group() as self._task_group: - for task in self._tasks: - self._task_group.start_soon(task) - await to_thread.run_sync(self.__stop.wait) - self._task_group.cancel_scope.cancel() - - def stop(self): - """Stop the thread. - - This method is threadsafe. - """ - self.__stop.set() - - class IOPubThread: """An object for sending IOPub messages in a background thread @@ -109,7 +79,9 @@ def __init__(self, socket: zmq_anyio.Socket, pipe=False): tasks = [self._handle_event, self._run_event_pipe_gc, self.socket.start] if pipe: tasks.append(self._handle_pipe_msgs) - self.thread = _IOPubThread(tasks) + self.thread = BaseThread(name="IOPub", daemon=True) + for task in tasks: + self.thread.start_soon(task) def _setup_event_pipe(self): """Create the PULL socket listening for events that should fire in this thread.""" @@ -179,7 +151,7 @@ async def _handle_event(self): event_f = self._events.popleft() event_f() except Exception: - if self.thread.__stop.is_set(): + if self.thread.stopped.is_set(): return raise @@ -211,7 +183,7 @@ async def _handle_pipe_msgs(self): while True: await self._handle_pipe_msg() except Exception: - if self.thread.__stop.is_set(): + if self.thread.stopped.is_set(): return raise diff --git a/ipykernel/thread.py b/ipykernel/thread.py index df8fa412..4c9edf86 100644 --- a/ipykernel/thread.py +++ b/ipykernel/thread.py @@ -3,7 +3,7 @@ from collections.abc import Awaitable from queue import Queue -from threading import Thread +from threading import Event, Thread from typing import Callable from anyio import create_task_group, run, to_thread @@ -18,6 +18,8 @@ class BaseThread(Thread): def __init__(self, **kwargs): """Initialize the thread.""" super().__init__(**kwargs) + self.started = Event() + self.stopped = Event() self.pydev_do_not_trace = True self.is_pydev_daemon_thread = True self._tasks: Queue[Callable[[], Awaitable[None]] | None] = Queue() @@ -31,6 +33,7 @@ def run(self) -> None: async def _main(self) -> None: async with create_task_group() as tg: + self.started.set() while True: task = await to_thread.run_sync(self._tasks.get) if task is None: @@ -44,3 +47,4 @@ def stop(self) -> None: This method is threadsafe. """ self._tasks.put(None) + self.stopped.set()