diff --git a/examples/embedding/inprocess_terminal.py b/examples/embedding/inprocess_terminal.py index b644c94a..c951859e 100644 --- a/examples/embedding/inprocess_terminal.py +++ b/examples/embedding/inprocess_terminal.py @@ -1,8 +1,7 @@ """An in-process terminal example.""" import os -import sys -import tornado +from anyio import run from jupyter_console.ptshell import ZMQTerminalInteractiveShell from ipykernel.inprocess.manager import InProcessKernelManager @@ -13,46 +12,15 @@ def print_process_id(): print("Process ID is:", os.getpid()) -def init_asyncio_patch(): - """set default asyncio policy to be compatible with tornado - Tornado 6 (at least) is not compatible with the default - asyncio implementation on Windows - Pick the older SelectorEventLoopPolicy on Windows - if the known-incompatible default policy is in use. - do this as early as possible to make it a low priority and overridable - ref: https://github.com/tornadoweb/tornado/issues/2608 - FIXME: if/when tornado supports the defaults in asyncio, - remove and bump tornado requirement for py38 - """ - if ( - sys.platform.startswith("win") - and sys.version_info >= (3, 8) - and tornado.version_info < (6, 1) - ): - import asyncio - - try: - from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy - except ImportError: - pass - # not affected - else: - if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy: - # WindowsProactorEventLoopPolicy is not compatible with tornado 6 - # fallback to the pre-3.8 default of Selector - asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy()) - - -def main(): +async def main(): """The main function.""" print_process_id() # Create an in-process kernel # >>> print_process_id() # will print the same process ID as the main process - init_asyncio_patch() kernel_manager = InProcessKernelManager() - kernel_manager.start_kernel() + await kernel_manager.start_kernel() kernel = kernel_manager.kernel kernel.gui = "qt4" kernel.shell.push({"foo": 43, "print_process_id": print_process_id}) @@ -64,4 +32,4 @@ def main(): if __name__ == "__main__": - main() + run(main) diff --git a/ipykernel/control.py b/ipykernel/control.py index 0ee0fad0..a70377c0 100644 --- a/ipykernel/control.py +++ b/ipykernel/control.py @@ -1,7 +1,7 @@ """A thread for a control channel.""" -from threading import Thread +from threading import Event, Thread -from tornado.ioloop import IOLoop +from anyio import create_task_group, run, to_thread CONTROL_THREAD_NAME = "Control" @@ -12,21 +12,29 @@ class ControlThread(Thread): def __init__(self, **kwargs): """Initialize the thread.""" Thread.__init__(self, name=CONTROL_THREAD_NAME, **kwargs) - self.io_loop = IOLoop(make_current=False) self.pydev_do_not_trace = True self.is_pydev_daemon_thread = True + self.__stop = Event() + self._task = None + + def set_task(self, task): + self._task = task def run(self): """Run the thread.""" self.name = CONTROL_THREAD_NAME - try: - self.io_loop.start() - finally: - self.io_loop.close() + run(self._main) + + async def _main(self): + async with create_task_group() as tg: + if self._task is not None: + tg.start_soon(self._task) + await to_thread.run_sync(self.__stop.wait) + tg.cancel_scope.cancel() def stop(self): """Stop the thread. This method is threadsafe. """ - self.io_loop.add_callback(self.io_loop.stop) + self.__stop.set() diff --git a/ipykernel/debugger.py b/ipykernel/debugger.py index 90e4f888..509f0f26 100644 --- a/ipykernel/debugger.py +++ b/ipykernel/debugger.py @@ -3,12 +3,12 @@ import re import sys import typing as t +from math import inf import zmq +from anyio import Event, create_memory_object_stream from IPython.core.getipython import get_ipython from IPython.core.inputtransformer2 import leading_empty_lines -from tornado.locks import Event -from tornado.queues import Queue from zmq.utils import jsonapi try: @@ -116,7 +116,9 @@ def __init__(self, event_callback, log): self.tcp_buffer = "" self._reset_tcp_pos() self.event_callback = event_callback - self.message_queue: Queue[t.Any] = Queue() + self.message_send_stream, self.message_receive_stream = create_memory_object_stream( + max_buffer_size=inf + ) self.log = log def _reset_tcp_pos(self): @@ -135,7 +137,7 @@ def _put_message(self, raw_msg): else: self.log.debug("QUEUE - put message:") self.log.debug(msg) - self.message_queue.put_nowait(msg) + self.message_send_stream.send_nowait(msg) def put_tcp_frame(self, frame): """Put a tcp frame in the queue.""" @@ -186,25 +188,31 @@ def put_tcp_frame(self, frame): async def get_message(self): """Get a message from the queue.""" - return await self.message_queue.get() + return await self.message_receive_stream.receive() class DebugpyClient: """A client for debugpy.""" - def __init__(self, log, debugpy_stream, event_callback): + def __init__(self, log, debugpy_socket, event_callback): """Initialize the client.""" self.log = log - self.debugpy_stream = debugpy_stream + self.debugpy_socket = debugpy_socket self.event_callback = event_callback self.message_queue = DebugpyMessageQueue(self._forward_event, self.log) self.debugpy_host = "127.0.0.1" self.debugpy_port = -1 self.routing_id = None self.wait_for_attach = True - self.init_event = Event() + self._init_event = None self.init_event_seq = -1 + @property + def init_event(self): + if self._init_event is None: + self._init_event = Event() + return self._init_event + def _get_endpoint(self): host, port = self.get_host_port() return "tcp://" + host + ":" + str(port) @@ -215,9 +223,9 @@ def _forward_event(self, msg): self.init_event_seq = msg["seq"] self.event_callback(msg) - def _send_request(self, msg): + async def _send_request(self, msg): if self.routing_id is None: - self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID) + self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID) content = jsonapi.dumps( msg, default=json_default, @@ -232,7 +240,7 @@ def _send_request(self, msg): self.log.debug("DEBUGPYCLIENT:") self.log.debug(self.routing_id) self.log.debug(buf) - self.debugpy_stream.send_multipart((self.routing_id, buf)) + await self.debugpy_socket.send_multipart((self.routing_id, buf)) async def _wait_for_response(self): # Since events are never pushed to the message_queue @@ -250,7 +258,7 @@ async def _handle_init_sequence(self): "seq": int(self.init_event_seq) + 1, "command": "configurationDone", } - self._send_request(configurationDone) + await self._send_request(configurationDone) # 3] Waits for configurationDone response await self._wait_for_response() @@ -262,7 +270,7 @@ async def _handle_init_sequence(self): def get_host_port(self): """Get the host debugpy port.""" if self.debugpy_port == -1: - socket = self.debugpy_stream.socket + socket = self.debugpy_socket socket.bind_to_random_port("tcp://" + self.debugpy_host) self.endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode("utf-8") socket.unbind(self.endpoint) @@ -272,12 +280,12 @@ def get_host_port(self): def connect_tcp_socket(self): """Connect to the tcp socket.""" - self.debugpy_stream.socket.connect(self._get_endpoint()) - self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID) + self.debugpy_socket.connect(self._get_endpoint()) + self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID) def disconnect_tcp_socket(self): """Disconnect from the tcp socket.""" - self.debugpy_stream.socket.disconnect(self._get_endpoint()) + self.debugpy_socket.disconnect(self._get_endpoint()) self.routing_id = None self.init_event = Event() self.init_event_seq = -1 @@ -289,7 +297,7 @@ def receive_dap_frame(self, frame): async def send_dap_request(self, msg): """Send a dap request.""" - self._send_request(msg) + await self._send_request(msg) if self.wait_for_attach and msg["command"] == "attach": rep = await self._handle_init_sequence() self.wait_for_attach = False @@ -325,17 +333,19 @@ class Debugger: ] def __init__( - self, log, debugpy_stream, event_callback, shell_socket, session, just_my_code=True + self, log, debugpy_socket, event_callback, shell_socket, session, just_my_code=True ): """Initialize the debugger.""" self.log = log - self.debugpy_client = DebugpyClient(log, debugpy_stream, self._handle_event) + self.debugpy_client = DebugpyClient(log, debugpy_socket, self._handle_event) self.shell_socket = shell_socket self.session = session self.is_started = False self.event_callback = event_callback self.just_my_code = just_my_code - self.stopped_queue: Queue[t.Any] = Queue() + self.stopped_send_stream, self.stopped_receive_stream = create_memory_object_stream( + max_buffer_size=inf + ) self.started_debug_handlers = {} for msg_type in Debugger.started_debug_msg_types: @@ -360,7 +370,7 @@ def __init__( def _handle_event(self, msg): if msg["event"] == "stopped": if msg["body"]["allThreadsStopped"]: - self.stopped_queue.put_nowait(msg) + self.stopped_send_stream.send_nowait(msg) # Do not forward the event now, will be done in the handle_stopped_event return else: @@ -400,7 +410,7 @@ async def handle_stopped_event(self): """Handle a stopped event.""" # Wait for a stopped event message in the stopped queue # This message is used for triggering the 'threads' request - event = await self.stopped_queue.get() + event = await self.stopped_receive_stream.receive() req = {"seq": event["seq"] + 1, "type": "request", "command": "threads"} rep = await self._forward_message(req) for thread in rep["body"]["threads"]: @@ -412,7 +422,7 @@ async def handle_stopped_event(self): def tcp_client(self): return self.debugpy_client - def start(self): + async def start(self): """Start the debugger.""" if not self.debugpy_initialized: tmp_dir = get_tmp_directory() @@ -430,7 +440,12 @@ def start(self): (self.shell_socket.getsockopt(ROUTING_ID)), ) - ident, msg = self.session.recv(self.shell_socket, mode=0) + msg = await self.shell_socket.recv_multipart() + ident, msg = self.session.feed_identities(msg, copy=True) + try: + msg = self.session.deserialize(msg, content=True, copy=True) + except Exception: + self.log.error("Invalid message", exc_info=True) self.debugpy_initialized = msg["content"]["status"] == "ok" # Don't remove leading empty lines when debugging so the breakpoints are correctly positioned @@ -719,7 +734,7 @@ async def process_request(self, message): if self.is_started: self.log.info("The debugger has already started") else: - self.is_started = self.start() + self.is_started = await self.start() if self.is_started: self.log.info("The debugger has started") else: diff --git a/ipykernel/eventloops.py b/ipykernel/eventloops.py index ef54f410..08fd6730 100644 --- a/ipykernel/eventloops.py +++ b/ipykernel/eventloops.py @@ -388,13 +388,12 @@ def loop_asyncio(kernel): loop._should_close = False # type:ignore[attr-defined] # pause eventloop when there's an event on a zmq socket - def process_stream_events(stream): + def process_stream_events(socket): """fall back to main loop when there's a socket event""" - if stream.flush(limit=1): - loop.stop() + loop.stop() - notifier = partial(process_stream_events, kernel.shell_stream) - loop.add_reader(kernel.shell_stream.getsockopt(zmq.FD), notifier) + notifier = partial(process_stream_events, kernel.shell_socket) + loop.add_reader(kernel.shell_socket.getsockopt(zmq.FD), notifier) loop.call_soon(notifier) while True: diff --git a/ipykernel/inprocess/blocking.py b/ipykernel/inprocess/blocking.py index c598a44b..b5c421a7 100644 --- a/ipykernel/inprocess/blocking.py +++ b/ipykernel/inprocess/blocking.py @@ -80,10 +80,10 @@ class BlockingInProcessKernelClient(InProcessKernelClient): iopub_channel_class = Type(BlockingInProcessChannel) # type:ignore[arg-type] stdin_channel_class = Type(BlockingInProcessStdInChannel) # type:ignore[arg-type] - def wait_for_ready(self): + async def wait_for_ready(self): """Wait for kernel info reply on shell channel.""" while True: - self.kernel_info() + await self.kernel_info() try: msg = self.shell_channel.get_msg(block=True, timeout=1) except Empty: @@ -103,6 +103,5 @@ def wait_for_ready(self): while True: try: msg = self.iopub_channel.get_msg(block=True, timeout=0.2) - print(msg["msg_type"]) except Empty: break diff --git a/ipykernel/inprocess/client.py b/ipykernel/inprocess/client.py index d0ebfd22..c4072582 100644 --- a/ipykernel/inprocess/client.py +++ b/ipykernel/inprocess/client.py @@ -11,11 +11,9 @@ # Imports # ----------------------------------------------------------------------------- -import asyncio from jupyter_client.client import KernelClient from jupyter_client.clientabc import KernelClientABC -from jupyter_core.utils import run_sync # IPython imports from traitlets import Instance, Type, default @@ -104,7 +102,7 @@ def hb_channel(self): # Methods for sending specific messages # ------------------------------------- - def execute( + async def execute( self, code, silent=False, store_history=True, user_expressions=None, allow_stdin=None ): """Execute code on the client.""" @@ -118,19 +116,19 @@ def execute( allow_stdin=allow_stdin, ) msg = self.session.msg("execute_request", content) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] - def complete(self, code, cursor_pos=None): + async def complete(self, code, cursor_pos=None): """Get code completion.""" if cursor_pos is None: cursor_pos = len(code) content = dict(code=code, cursor_pos=cursor_pos) msg = self.session.msg("complete_request", content) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] - def inspect(self, code, cursor_pos=None, detail_level=0): + async def inspect(self, code, cursor_pos=None, detail_level=0): """Get code inspection.""" if cursor_pos is None: cursor_pos = len(code) @@ -140,14 +138,14 @@ def inspect(self, code, cursor_pos=None, detail_level=0): detail_level=detail_level, ) msg = self.session.msg("inspect_request", content) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] - def history(self, raw=True, output=False, hist_access_type="range", **kwds): + async def history(self, raw=True, output=False, hist_access_type="range", **kwds): """Get code history.""" content = dict(raw=raw, output=output, hist_access_type=hist_access_type, **kwds) msg = self.session.msg("history_request", content) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] def shutdown(self, restart=False): @@ -156,17 +154,17 @@ def shutdown(self, restart=False): msg = "Cannot shutdown in-process kernel" raise NotImplementedError(msg) - def kernel_info(self): + async def kernel_info(self): """Request kernel info.""" msg = self.session.msg("kernel_info_request") - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] - def comm_info(self, target_name=None): + async def comm_info(self, target_name=None): """Request a dictionary of valid comms and their targets.""" content = {} if target_name is None else dict(target_name=target_name) msg = self.session.msg("comm_info_request", content) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] def input(self, string): @@ -176,29 +174,21 @@ def input(self, string): raise RuntimeError(msg) self.kernel.raw_input_str = string - def is_complete(self, code): + async def is_complete(self, code): """Handle an is_complete request.""" msg = self.session.msg("is_complete_request", {"code": code}) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] - def _dispatch_to_kernel(self, msg): + async def _dispatch_to_kernel(self, msg): """Send a message to the kernel and handle a reply.""" kernel = self.kernel if kernel is None: - msg = "Cannot send request. No kernel exists." - raise RuntimeError(msg) + error_message = "Cannot send request. No kernel exists." + raise RuntimeError(error_message) - stream = kernel.shell_stream - self.session.send(stream, msg) - msg_parts = stream.recv_multipart() - if run_sync is not None: - dispatch_shell = run_sync(kernel.dispatch_shell) - dispatch_shell(msg_parts) - else: - loop = asyncio.get_event_loop() # type:ignore[unreachable] - loop.run_until_complete(kernel.dispatch_shell(msg_parts)) - idents, reply_msg = self.session.recv(stream, copy=False) + kernel.shell_socket.put(msg) + reply_msg = await kernel.shell_socket.get() self.shell_channel.call_handlers_later(reply_msg) def get_shell_msg(self, block=True, timeout=None): diff --git a/ipykernel/inprocess/ipkernel.py b/ipykernel/inprocess/ipkernel.py index 13b17217..087b323a 100644 --- a/ipykernel/inprocess/ipkernel.py +++ b/ipykernel/inprocess/ipkernel.py @@ -7,6 +7,8 @@ import sys from contextlib import contextmanager +from anyio import TASK_STATUS_IGNORED +from anyio.abc import TaskStatus from IPython.core.interactiveshell import InteractiveShellABC from traitlets import Any, Enum, Instance, List, Type, default @@ -47,11 +49,11 @@ class InProcessKernel(IPythonKernel): # Kernel interface # ------------------------------------------------------------------------- - shell_class = Type(allow_none=True) # type:ignore[assignment] - _underlying_iopub_socket = Instance(DummySocket, ()) + shell_class = Type(allow_none=True) + _underlying_iopub_socket = Instance(DummySocket, (False,)) iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment] - shell_stream = Instance(DummySocket, ()) # type:ignore[arg-type] + shell_socket = Instance(DummySocket, (True,)) @default("iopub_thread") def _default_iopub_thread(self): @@ -65,13 +67,13 @@ def _default_iopub_thread(self): def _default_iopub_socket(self): return self.iopub_thread.background_socket - stdin_socket = Instance(DummySocket, ()) # type:ignore[assignment] + stdin_socket = Instance(DummySocket, (False,)) # type:ignore[assignment] def __init__(self, **traits): """Initialize the kernel.""" super().__init__(**traits) - self._underlying_iopub_socket.observe(self._io_dispatch, names=["message_sent"]) + self._io_dispatch() if self.shell: self.shell.kernel = self @@ -80,10 +82,14 @@ async def execute_request(self, stream, ident, parent): with self._redirected_io(): await super().execute_request(stream, ident, parent) - def start(self): + async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: """Override registration of dispatchers for streams.""" if self.shell: self.shell.exit_now = False + await super().start(task_status=task_status) + + def stop(self): + super().stop() def _abort_queues(self): """The in-process kernel doesn't abort requests.""" @@ -132,13 +138,16 @@ def _redirected_io(self): # ------ Trait change handlers -------------------------------------------- - def _io_dispatch(self, change): + def _io_dispatch(self): """Called when a message is sent to the IO socket.""" assert self.iopub_socket.io_thread is not None assert self.session is not None - ident, msg = self.session.recv(self.iopub_socket.io_thread.socket, copy=False) - for frontend in self.frontends: - frontend.iopub_channel.call_handlers(msg) + + def callback(msg): + for frontend in self.frontends: + frontend.iopub_channel.call_handlers(msg) + + self.iopub_thread.socket.on_recv = callback # ------ Trait initializers ----------------------------------------------- @@ -148,7 +157,7 @@ def _default_log(self): @default("session") def _default_session(self): - from jupyter_client.session import Session + from .session import Session return Session(parent=self, key=INPROCESS_KEY) diff --git a/ipykernel/inprocess/manager.py b/ipykernel/inprocess/manager.py index 3a3f92c3..04c718c6 100644 --- a/ipykernel/inprocess/manager.py +++ b/ipykernel/inprocess/manager.py @@ -3,12 +3,14 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from anyio import TASK_STATUS_IGNORED +from anyio.abc import TaskStatus from jupyter_client.manager import KernelManager from jupyter_client.managerabc import KernelManagerABC -from jupyter_client.session import Session from traitlets import DottedObjectName, Instance, default from .constants import INPROCESS_KEY +from .session import Session class InProcessKernelManager(KernelManager): @@ -41,11 +43,12 @@ def _default_session(self): # Kernel management methods # -------------------------------------------------------------------------- - def start_kernel(self, **kwds): + async def start_kernel(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED, **kwds) -> None: """Start the kernel.""" from ipykernel.inprocess.ipkernel import InProcessKernel self.kernel = InProcessKernel(parent=self, session=self.session) + await self.kernel.start(task_status=task_status) def shutdown_kernel(self): """Shutdown the kernel.""" @@ -53,16 +56,19 @@ def shutdown_kernel(self): self.kernel.iopub_thread.stop() self._kill_kernel() - def restart_kernel(self, now=False, **kwds): + async def restart_kernel( + self, now=False, *, task_status: TaskStatus = TASK_STATUS_IGNORED, **kwds + ) -> None: """Restart the kernel.""" self.shutdown_kernel() - self.start_kernel(**kwds) + await self.start_kernel(task_status=task_status, **kwds) @property def has_kernel(self): return self.kernel is not None def _kill_kernel(self): + self.kernel.stop() self.kernel = None def interrupt_kernel(self): diff --git a/ipykernel/inprocess/session.py b/ipykernel/inprocess/session.py new file mode 100644 index 00000000..0eaed2c6 --- /dev/null +++ b/ipykernel/inprocess/session.py @@ -0,0 +1,41 @@ +from jupyter_client.session import Session as _Session + + +class Session(_Session): + async def recv(self, socket, copy=True): + return await socket.recv_multipart() + + def send( + self, + socket, + msg_or_type, + content=None, + parent=None, + ident=None, + buffers=None, + track=False, + header=None, + metadata=None, + ): + if isinstance(msg_or_type, str): + msg = self.msg( + msg_or_type, + content=content, + parent=parent, + header=header, + metadata=metadata, + ) + else: + # We got a Message or message dict, not a msg_type so don't + # build a new Message. + msg = msg_or_type + buffers = buffers or msg.get("buffers", []) + + socket.send_multipart(msg) + return msg + + def feed_identities(self, msg, copy=True): + return "", msg + + def deserialize(self, msg, content=True, copy=True): + return msg diff --git a/ipykernel/inprocess/socket.py b/ipykernel/inprocess/socket.py index 7e48789e..3e79297c 100644 --- a/ipykernel/inprocess/socket.py +++ b/ipykernel/inprocess/socket.py @@ -3,10 +3,11 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. -from queue import Queue +from math import inf import zmq -from traitlets import HasTraits, Instance, Int +from anyio import create_memory_object_stream +from traitlets import HasTraits, Instance # ----------------------------------------------------------------------------- # Dummy socket class @@ -14,29 +15,53 @@ class DummySocket(HasTraits): - """A dummy socket implementing (part of) the zmq.Socket interface.""" + """A dummy socket implementing (part of) the zmq.asyncio.Socket interface.""" - queue = Instance(Queue, ()) - message_sent = Int(0) # Should be an Event - context = Instance(zmq.Context) + context = Instance(zmq.asyncio.Context) def _context_default(self): - return zmq.Context() + return zmq.asyncio.Context() # ------------------------------------------------------------------------- # Socket interface # ------------------------------------------------------------------------- - def recv_multipart(self, flags=0, copy=True, track=False): + def __init__(self, is_shell, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_shell = is_shell + self.on_recv = None + if is_shell: + self.in_send_stream, self.in_receive_stream = create_memory_object_stream( + max_buffer_size=inf + ) + self.out_send_stream, self.out_receive_stream = create_memory_object_stream( + max_buffer_size=inf + ) + + def put(self, msg): + self.in_send_stream.send_nowait(msg) + + async def get(self): + msg = await self.out_receive_stream.receive() + return msg + + async def recv_multipart(self, flags=0, copy=True, track=False): """Recv a multipart message.""" - return self.queue.get_nowait() + msg = await self.in_receive_stream.receive() + return msg def send_multipart(self, msg_parts, flags=0, copy=True, track=False): """Send a multipart message.""" - msg_parts = list(map(zmq.Message, msg_parts)) - self.queue.put_nowait(msg_parts) - self.message_sent += 1 + if self.is_shell: + self.out_send_stream.send_nowait(msg_parts) + if self.on_recv is not None: + self.on_recv(msg_parts) def flush(self, timeout=1.0): """no-op to comply with stream API""" pass + + async def poll(self, timeout=0): + assert timeout == 0 + statistics = self.in_receive_stream.statistics() + return statistics.current_buffer_used != 0 diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index 0c8a2fa9..0cde676a 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -14,13 +14,12 @@ from binascii import b2a_hex from collections import deque from io import StringIO, TextIOBase -from threading import local +from threading import Event, Thread, local from typing import Any, Callable, Deque, Dict, Optional import zmq +from anyio import create_task_group, run, to_thread from jupyter_client.session import extract_header -from tornado.ioloop import IOLoop -from zmq.eventloop.zmqstream import ZMQStream # ----------------------------------------------------------------------------- # Globals @@ -36,6 +35,38 @@ # ----------------------------------------------------------------------------- +class _IOPubThread(Thread): + """A thread for a IOPub.""" + + def __init__(self, tasks, **kwargs): + """Initialize the thread.""" + Thread.__init__(self, name="IOPub", **kwargs) + self._tasks = tasks + self.pydev_do_not_trace = True + self.is_pydev_daemon_thread = True + self.daemon = True + self.__stop = Event() + + def run(self): + """Run the thread.""" + self.name = "Control" + run(self._main) + + async def _main(self): + async with create_task_group() as tg: + for task in self._tasks: + tg.start_soon(task) + await to_thread.run_sync(self.__stop.wait) + tg.cancel_scope.cancel() + + def stop(self): + """Stop the thread. + + This method is threadsafe. + """ + self.__stop.set() + + class IOPubThread: """An object for sending IOPub messages in a background thread @@ -57,11 +88,9 @@ def __init__(self, socket, pipe=False): piped from subprocesses. """ self.socket = socket - self._stopped = False self.background_socket = BackgroundSocket(self) self._master_pid = os.getpid() self._pipe_flag = pipe - self.io_loop = IOLoop(make_current=False) if pipe: self._setup_pipe_in() self._local = threading.local() @@ -71,48 +100,20 @@ def __init__(self, socket, pipe=False): self._event_pipe_gc_seconds: float = 10 self._event_pipe_gc_task: Optional[asyncio.Task] = None self._setup_event_pipe() - self.thread = threading.Thread(target=self._thread_main, name="IOPub") - self.thread.daemon = True - self.thread.pydev_do_not_trace = True # type:ignore[attr-defined] - self.thread.is_pydev_daemon_thread = True # type:ignore[attr-defined] - self.thread.name = "IOPub" - - def _thread_main(self): - """The inner loop that's actually run in a thread""" - - def _start_event_gc(): - self._event_pipe_gc_task = asyncio.ensure_future(self._run_event_pipe_gc()) - - self.io_loop.run_sync(_start_event_gc) - - if not self._stopped: - # avoid race if stop called before start thread gets here - # probably only comes up in tests - self.io_loop.start() - - if self._event_pipe_gc_task is not None: - # cancel gc task to avoid pending task warnings - async def _cancel(): - self._event_pipe_gc_task.cancel() # type:ignore[union-attr] - - if not self._stopped: - self.io_loop.run_sync(_cancel) - else: - self._event_pipe_gc_task.cancel() - - self.io_loop.close(all_fds=True) + tasks = [self._handle_event] + if pipe: + tasks.append(self._handle_pipe_msgs) + self.thread = _IOPubThread(tasks) def _setup_event_pipe(self): """Create the PULL socket listening for events that should fire in this thread.""" ctx = self.socket.context - pipe_in = ctx.socket(zmq.PULL) - pipe_in.linger = 0 + self._pipe_in0 = ctx.socket(zmq.PULL) + self._pipe_in0.linger = 0 _uuid = b2a_hex(os.urandom(16)).decode("ascii") iface = self._event_interface = "inproc://%s" % _uuid - pipe_in.bind(iface) - self._event_puller = ZMQStream(pipe_in, self.io_loop) - self._event_puller.on_recv(self._handle_event) + self._pipe_in0.bind(iface) async def _run_event_pipe_gc(self): """Task to run event pipe gc continuously""" @@ -141,7 +142,7 @@ def _event_pipe(self): event_pipe = self._local.event_pipe except AttributeError: # new thread, new event pipe - ctx = self.socket.context + ctx = zmq.Context(self.socket.context) event_pipe = ctx.socket(zmq.PUSH) event_pipe.linger = 0 event_pipe.connect(self._event_interface) @@ -153,7 +154,7 @@ def _event_pipe(self): self._event_pipes[threading.current_thread()] = event_pipe return event_pipe - def _handle_event(self, msg): + async def _handle_event(self): """Handle an event on the event pipe Content of the message is ignored. @@ -161,12 +162,19 @@ def _handle_event(self, msg): Whenever *an* event arrives on the event stream, *all* waiting events are processed in order. """ - # freeze event count so new writes don't extend the queue - # while we are processing - n_events = len(self._events) - for _ in range(n_events): - event_f = self._events.popleft() - event_f() + try: + while True: + await self._pipe_in0.recv() + # freeze event count so new writes don't extend the queue + # while we are processing + n_events = len(self._events) + for _ in range(n_events): + event_f = self._events.popleft() + event_f() + except Exception as e: + if self.thread.__stop.is_set(): + return + raise e def _setup_pipe_in(self): """setup listening pipe for IOPub from forked subprocesses""" @@ -175,11 +183,11 @@ def _setup_pipe_in(self): # use UUID to authenticate pipe messages self._pipe_uuid = os.urandom(16) - pipe_in = ctx.socket(zmq.PULL) - pipe_in.linger = 0 + self._pipe_in1 = ctx.socket(zmq.PULL) + self._pipe_in1.linger = 0 try: - self._pipe_port = pipe_in.bind_to_random_port("tcp://127.0.0.1") + self._pipe_port = self._pipe_in1.bind_to_random_port("tcp://127.0.0.1") except zmq.ZMQError as e: warnings.warn( "Couldn't bind IOPub Pipe to 127.0.0.1: %s" % e @@ -187,13 +195,22 @@ def _setup_pipe_in(self): stacklevel=2, ) self._pipe_flag = False - pipe_in.close() + self._pipe_in1.close() return - self._pipe_in = ZMQStream(pipe_in, self.io_loop) - self._pipe_in.on_recv(self._handle_pipe_msg) - def _handle_pipe_msg(self, msg): + async def _handle_pipe_msgs(self): + """handle pipe messages from a subprocess""" + try: + while True: + await self._handle_pipe_msg() + except Exception as e: + if self.thread.__stop.is_set(): + return + raise e + + async def _handle_pipe_msg(self, msg=None): """handle a pipe message from a subprocess""" + msg = msg or await self._pipe_in1.recv_multipart() if not self._pipe_flag or not self._is_master_process(): return if msg[0] != self._pipe_uuid: @@ -221,7 +238,6 @@ def _check_mp_mode(self): def start(self): """Start the IOPub thread""" - self.thread.name = "IOPub" self.thread.start() # make sure we don't prevent process exit # I'm not sure why setting daemon=True above isn't enough, but it doesn't appear to be. @@ -229,16 +245,9 @@ def start(self): def stop(self): """Stop the IOPub thread""" - self._stopped = True if not self.thread.is_alive(): return - self.io_loop.add_callback(self.io_loop.stop) - - self.thread.join(timeout=30) - if self.thread.is_alive(): - # avoid infinite hang if stop fails - msg = "IOPub thread did not terminate in 30 seconds" - raise TimeoutError(msg) + self.thread.stop() # close *all* event pipes, created in any thread # event pipes can only be used from other threads while self.thread.is_alive() # so after thread.join, this should be safe @@ -249,6 +258,9 @@ def close(self): """Close the IOPub thread.""" if self.closed: return + self._pipe_in0.close() + if self._pipe_flag: + self._pipe_in1.close() self.socket.close() self.socket = None @@ -435,6 +447,8 @@ def __init__( ) # This is necessary for compatibility with Python built-in streams self.session = session + self._has_thread = False + self.watch_fd_thread = None if not isinstance(pub_thread, IOPubThread): # Backward-compat: given socket, not thread. Wrap in a thread. warnings.warn( @@ -445,6 +459,7 @@ def __init__( ) pub_thread = IOPubThread(pub_thread) pub_thread.start() + self._has_thread = True self.pub_thread = pub_thread self.name = name self.topic = b"stream." + name.encode() @@ -452,7 +467,6 @@ def __init__( self._master_pid = os.getpid() self._flush_pending = False self._subprocess_flush_pending = False - self._io_loop = pub_thread.io_loop self._buffer_lock = threading.RLock() self._buffer = StringIO() self.echo = None @@ -532,13 +546,16 @@ def close(self): # thread won't wake unless there's something to read # writing something after _should_watch will not be echoed os.write(self._original_stdstream_fd, b'\0') - self.watch_fd_thread.join() + if self.watch_fd_thread is not None: + self.watch_fd_thread.join() # restore original FDs os.dup2(self._original_stdstream_copy, self._original_stdstream_fd) os.close(self._original_stdstream_copy) if self._exc: etype, value, tb = self._exc traceback.print_exception(etype, value, tb) + if self._has_thread: + self.pub_thread.stop() self.pub_thread = None @property @@ -555,10 +572,7 @@ def _schedule_flush(self): self._flush_pending = True # add_timeout has to be handed to the io thread via event pipe - def _schedule_in_thread(): - self._io_loop.call_later(self.flush_interval, self._flush) - - self.pub_thread.schedule(_schedule_in_thread) + self.pub_thread.schedule(self._flush) def flush(self): """trigger actual zmq send diff --git a/ipykernel/ipkernel.py b/ipykernel/ipkernel.py index 58821850..b05d350d 100644 --- a/ipykernel/ipkernel.py +++ b/ipykernel/ipkernel.py @@ -1,21 +1,20 @@ """The IPython kernel implementation""" -import asyncio import builtins import getpass import os -import signal import sys import threading import typing as t -from contextlib import contextmanager -from functools import partial +from dataclasses import dataclass import comm +import zmq.asyncio +from anyio import TASK_STATUS_IGNORED, create_task_group, to_thread +from anyio.abc import TaskStatus from IPython.core import release from IPython.utils.tokenutil import line_at_cursor, token_at_cursor from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat -from zmq.eventloop.zmqstream import ZMQStream from .comm.comm import BaseComm from .comm.manager import CommManager @@ -26,11 +25,6 @@ from .kernelbase import _accepts_cell_id from .zmqshell import ZMQInteractiveShell -try: - from IPython.core.interactiveshell import _asyncio_runner # type:ignore[attr-defined] -except ImportError: - _asyncio_runner = None # type:ignore[assignment] - try: from IPython.core.completer import provisionalcompleter as _provisionalcompleter from IPython.core.completer import rectify_completions as _rectify_completions @@ -78,7 +72,9 @@ class IPythonKernel(KernelBase): help="Set this flag to False to deactivate the use of experimental IPython completion APIs.", ).tag(config=True) - debugpy_stream = Instance(ZMQStream, allow_none=True) if _is_debugpy_available else None + debugpy_socket = ( + Instance(zmq.asyncio.Socket, allow_none=True) if _is_debugpy_available else None + ) user_module = Any() @@ -106,11 +102,13 @@ def __init__(self, **kwargs): """Initialize the kernel.""" super().__init__(**kwargs) + self.executing_blocking_code_in_main_shell = False + # Initialize the Debugger if _is_debugpy_available: self.debugger = Debugger( self.log, - self.debugpy_stream, + self.debugpy_socket, self._publish_debug_event, self.debug_shell_socket, self.session, @@ -197,12 +195,31 @@ def __init__(self, **kwargs): "file_extension": ".py", } - def dispatch_debugpy(self, msg): - if _is_debugpy_available: - # The first frame is the socket id, we can drop it - frame = msg[1].bytes.decode("utf-8") - self.log.debug("Debugpy received: %s", frame) - self.debugger.tcp_client.receive_dap_frame(frame) + async def process_debugpy(self): + async with create_task_group() as tg: + tg.start_soon(self.receive_debugpy_messages) + tg.start_soon(self.poll_stopped_queue) + await to_thread.run_sync(self.debugpy_stop.wait) + tg.cancel_scope.cancel() + + async def receive_debugpy_messages(self): + if not _is_debugpy_available: + return + + while True: + await self.receive_debugpy_message() + + async def receive_debugpy_message(self, msg=None): + if not _is_debugpy_available: + return + + if msg is None: + assert self.debugpy_socket is not None + msg = await self.debugpy_socket.recv_multipart() + # The first frame is the socket id, we can drop it + frame = msg[1].decode("utf-8") + self.log.debug("Debugpy received: %s", frame) + self.debugger.tcp_client.receive_dap_frame(frame) @property def banner(self): @@ -214,19 +231,21 @@ async def poll_stopped_queue(self): while True: await self.debugger.handle_stopped_event() - def start(self): + async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: """Start the kernel.""" if self.shell: self.shell.exit_now = False - if self.debugpy_stream is None: - self.log.warning("debugpy_stream undefined, debugging will not be enabled") + if self.debugpy_socket is None: + self.log.warning("debugpy_socket undefined, debugging will not be enabled") else: - self.debugpy_stream.on_recv(self.dispatch_debugpy, copy=False) - super().start() - if self.debugpy_stream: - asyncio.run_coroutine_threadsafe( - self.poll_stopped_queue(), self.control_thread.io_loop.asyncio_loop - ) + self.debugpy_stop = threading.Event() + self.control_tasks.append(self.process_debugpy) + await super().start(task_status=task_status) + + def stop(self): + super().stop() + if self.debugpy_socket is not None: + self.debugpy_stop.set() def set_parent(self, ident, parent, channel="shell"): """Overridden from parent to tell the display hook and output streams @@ -295,50 +314,6 @@ def execution_count(self, value): # execution counter. pass - @contextmanager - def _cancel_on_sigint(self, future): - """ContextManager for capturing SIGINT and cancelling a future - - SIGINT raises in the event loop when running async code, - but we want it to halt a coroutine. - - Ideally, it would raise KeyboardInterrupt, - but this turns it into a CancelledError. - At least it gets a decent traceback to the user. - """ - sigint_future: asyncio.Future[int] = asyncio.Future() - - # whichever future finishes first, - # cancel the other one - def cancel_unless_done(f, _ignored): - if f.cancelled() or f.done(): - return - f.cancel() - - # when sigint finishes, - # abort the coroutine with CancelledError - sigint_future.add_done_callback(partial(cancel_unless_done, future)) - # when the main future finishes, - # stop watching for SIGINT events - future.add_done_callback(partial(cancel_unless_done, sigint_future)) - - def handle_sigint(*args): - def set_sigint_result(): - if sigint_future.cancelled() or sigint_future.done(): - return - sigint_future.set_result(1) - - # use add_callback for thread safety - self.io_loop.add_callback(set_sigint_result) - - # set the custom sigint handler during this context - save_sigint = signal.signal(signal.SIGINT, handle_sigint) - try: - yield - finally: - # restore the previous sigint handler - signal.signal(signal.SIGINT, save_sigint) - async def do_execute( self, code, @@ -380,63 +355,69 @@ async def run_cell(*args, **kwargs): transformed_cell = code preprocessing_exc_tuple = sys.exc_info() - if ( - _asyncio_runner # type:ignore[truthy-bool] - and shell.loop_runner is _asyncio_runner - and asyncio.get_event_loop().is_running() - and should_run_async( - code, + kwargs = dict( + store_history=store_history, + silent=silent, + ) + if with_cell_id: + kwargs.update(cell_id=cell_id) + + if should_run_async( + code, + transformed_cell=transformed_cell, + preprocessing_exc_tuple=preprocessing_exc_tuple, + ): + kwargs.update( transformed_cell=transformed_cell, preprocessing_exc_tuple=preprocessing_exc_tuple, ) - ): - if with_cell_id: - coro = run_cell( - code, - store_history=store_history, - silent=silent, - transformed_cell=transformed_cell, - preprocessing_exc_tuple=preprocessing_exc_tuple, - cell_id=cell_id, - ) - else: - coro = run_cell( - code, - store_history=store_history, - silent=silent, - transformed_cell=transformed_cell, - preprocessing_exc_tuple=preprocessing_exc_tuple, - ) + coro = run_cell(code, **kwargs) + + @dataclass + class Execution: + interrupt: bool = False + result: t.Any = None + + async def run(execution: Execution) -> None: + execution.result = await coro + if not execution.interrupt: + self.shell_interrupt.put(False) + + res = None + try: + async with create_task_group() as tg: + execution = Execution() + self.shell_is_awaiting = True + tg.start_soon(run, execution) + execution.interrupt = await to_thread.run_sync(self.shell_interrupt.get) + self.shell_is_awaiting = False + if execution.interrupt: + tg.cancel_scope.cancel() + + res = execution.result + finally: + shell.events.trigger("post_execute") + if not silent: + shell.events.trigger("post_run_cell", res) - coro_future = asyncio.ensure_future(coro) - - with self._cancel_on_sigint(coro_future): - res = None - try: - res = await coro_future - finally: - shell.events.trigger("post_execute") - if not silent: - shell.events.trigger("post_run_cell", res) else: # runner isn't already running, # make synchronous call, # letting shell dispatch to loop runners - if with_cell_id: - res = shell.run_cell( - code, - store_history=store_history, - silent=silent, - cell_id=cell_id, - ) - else: - res = shell.run_cell(code, store_history=store_history, silent=silent) + self.shell_is_blocking = True + try: + res = shell.run_cell(code, **kwargs) + finally: + self.shell_is_blocking = False finally: self._restore_input() - err = res.error_before_exec if res.error_before_exec is not None else res.error_in_exec + if res is not None: + err = res.error_before_exec if res.error_before_exec is not None else res.error_in_exec + else: + err = KeyboardInterrupt() - if res.success: + if res is not None and res.success: reply_content["status"] = "ok" else: reply_content["status"] = "error" diff --git a/ipykernel/kernelapp.py b/ipykernel/kernelapp.py index de4682f8..adf03ac3 100644 --- a/ipykernel/kernelapp.py +++ b/ipykernel/kernelapp.py @@ -16,6 +16,8 @@ from typing import Optional import zmq +import zmq.asyncio +from anyio import create_task_group, run from IPython.core.application import ( # type:ignore[attr-defined] BaseIPythonApplication, base_aliases, @@ -27,7 +29,6 @@ from jupyter_client.connect import ConnectionFileMixin from jupyter_client.session import Session, session_aliases, session_flags from jupyter_core.paths import jupyter_runtime_dir -from tornado import ioloop from traitlets.traitlets import ( Any, Bool, @@ -41,7 +42,6 @@ ) from traitlets.utils import filefind from traitlets.utils.importstring import import_item -from zmq.eventloop.zmqstream import ZMQStream from .connect import get_connection_info, write_connection_file @@ -321,7 +321,7 @@ def init_sockets(self): """Create a context, a session, and the kernel sockets.""" self.log.info("Starting the kernel at pid: %i", os.getpid()) assert self.context is None, "init_sockets cannot be called twice!" - self.context = context = zmq.Context() + self.context = context = zmq.asyncio.Context() atexit.register(self.close) self.shell_socket = context.socket(zmq.ROUTER) @@ -329,7 +329,7 @@ def init_sockets(self): self.shell_port = self._bind_socket(self.shell_socket, self.shell_port) self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port) - self.stdin_socket = context.socket(zmq.ROUTER) + self.stdin_socket = zmq.Context(context).socket(zmq.ROUTER) self.stdin_socket.linger = 1000 self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port) self.log.debug("stdin ROUTER Channel on port: %i" % self.stdin_port) @@ -538,25 +538,27 @@ def register(signum, file=sys.__stderr__, all_threads=True, chain=False, **kwarg faulthandler.register = register + def sigint_handler(self, *args): + if self.kernel.shell_is_awaiting: + self.kernel.shell_interrupt.put(True) + elif self.kernel.shell_is_blocking: + raise KeyboardInterrupt + def init_signal(self): """Initialize the signal handler.""" - signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGINT, self.sigint_handler) def init_kernel(self): """Create the Kernel object itself""" - shell_stream = ZMQStream(self.shell_socket) - control_stream = ZMQStream(self.control_socket, self.control_thread.io_loop) - debugpy_stream = ZMQStream(self.debugpy_socket, self.control_thread.io_loop) - self.control_thread.start() - kernel_factory = self.kernel_class.instance # type:ignore[attr-defined] + kernel_factory = self.kernel_class.instance kernel = kernel_factory( parent=self, session=self.session, - control_stream=control_stream, - debugpy_stream=debugpy_stream, + control_socket=self.control_socket, + debugpy_socket=self.debugpy_socket, debug_shell_socket=self.debug_shell_socket, - shell_stream=shell_stream, + shell_socket=self.shell_socket, control_thread=self.control_thread, iopub_thread=self.iopub_thread, iopub_socket=self.iopub_socket, @@ -721,22 +723,18 @@ def start(self): return self.subapp.start() if self.poller is not None: self.poller.start() - self.kernel.start() - self.io_loop = ioloop.IOLoop.current() - if self.trio_loop: - from ipykernel.trio_runner import TrioRunner - - tr = TrioRunner() - tr.initialize(self.kernel, self.io_loop) - try: - tr.run() - except KeyboardInterrupt: - pass - else: - try: - self.io_loop.start() - except KeyboardInterrupt: - pass + backend = "trio" if self.trio_loop else "asyncio" + run(self.main, backend=backend) + + async def main(self): + async with create_task_group() as tg: + if self.kernel.eventloop: + tg.start_soon(self.kernel.enter_eventloop) + tg.start_soon(self.kernel.start) + + def stop(self): + """Stop the kernel, thread-safe.""" + self.kernel.stop() launch_new_instance = IPKernelApp.launch_instance diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index 6d06d4ab..d1ea79fc 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -4,12 +4,11 @@ # Distributed under the terms of the Modified BSD License. import asyncio -import concurrent.futures import inspect import itertools import logging import os -import socket +import queue import sys import threading import time @@ -17,8 +16,7 @@ import uuid import warnings from datetime import datetime -from functools import partial -from signal import SIGINT, SIGTERM, Signals, default_int_handler, signal +from signal import SIGINT, SIGTERM, Signals from .control import CONTROL_THREAD_NAME @@ -37,10 +35,10 @@ import psutil import zmq +from anyio import TASK_STATUS_IGNORED, create_task_group, sleep, to_thread +from anyio.abc import TaskStatus from IPython.core.error import StdinNotImplementedError from jupyter_client.session import Session -from tornado import ioloop -from tornado.queues import Queue, QueueEmpty from traitlets.config.configurable import SingletonConfigurable from traitlets.traitlets import ( Any, @@ -53,9 +51,7 @@ Set, Unicode, default, - observe, ) -from zmq.eventloop.zmqstream import ZMQStream from ipykernel.jsonutil import json_clean @@ -73,6 +69,8 @@ def _accepts_cell_id(meth): class Kernel(SingletonConfigurable): """The base kernel class.""" + _aborted_time: float + # --------------------------------------------------------------------------- # Kernel interface # --------------------------------------------------------------------------- @@ -82,58 +80,18 @@ class Kernel(SingletonConfigurable): processes: t.Dict[str, psutil.Process] = {} - @observe("eventloop") - def _update_eventloop(self, change): - """schedule call to eventloop from IOLoop""" - loop = ioloop.IOLoop.current() - if change.new is not None: - loop.add_callback(self.enter_eventloop) - session = Instance(Session, allow_none=True) profile_dir = Instance("IPython.core.profiledir.ProfileDir", allow_none=True) - shell_stream = Instance(ZMQStream, allow_none=True) - - shell_streams = List( - help="""Deprecated shell_streams alias. Use shell_stream - - .. versionchanged:: 6.0 - shell_streams is deprecated. Use shell_stream. - """ - ) + shell_socket = Instance(zmq.asyncio.Socket, allow_none=True) implementation: str implementation_version: str banner: str - @default("shell_streams") - def _shell_streams_default(self): # pragma: no cover - warnings.warn( - "Kernel.shell_streams is deprecated in ipykernel 6.0. Use Kernel.shell_stream", - DeprecationWarning, - stacklevel=2, - ) - if self.shell_stream is not None: - return [self.shell_stream] - else: - return [] + _is_test = Bool(False) - @observe("shell_streams") - def _shell_streams_changed(self, change): # pragma: no cover - warnings.warn( - "Kernel.shell_streams is deprecated in ipykernel 6.0. Use Kernel.shell_stream", - DeprecationWarning, - stacklevel=2, - ) - if len(change.new) > 1: - warnings.warn( - "Kernel only supports one shell stream. Additional streams will be ignored.", - RuntimeWarning, - stacklevel=2, - ) - if change.new: - self.shell_stream = change.new[0] - - control_stream = Instance(ZMQStream, allow_none=True) + control_socket = Instance(zmq.asyncio.Socket, allow_none=True) + control_tasks = List() debug_shell_socket = Any() @@ -275,83 +233,7 @@ def __init__(self, **kwargs): for msg_type in self.control_msg_types: self.control_handlers[msg_type] = getattr(self, msg_type) - self.control_queue: Queue[t.Any] = Queue() - - def dispatch_control(self, msg): - self.control_queue.put_nowait(msg) - - async def poll_control_queue(self): - while True: - msg = await self.control_queue.get() - # handle tracers from _flush_control_queue - if isinstance(msg, (concurrent.futures.Future, asyncio.Future)): - msg.set_result(None) - continue - await self.process_control(msg) - - async def _flush_control_queue(self): - """Flush the control queue, wait for processing of any pending messages""" - tracer_future: t.Union[concurrent.futures.Future[object], asyncio.Future[object]] - if self.control_thread: - control_loop = self.control_thread.io_loop - # concurrent.futures.Futures are threadsafe - # and can be used to await across threads - tracer_future = concurrent.futures.Future() - awaitable_future = asyncio.wrap_future(tracer_future) - else: - control_loop = self.io_loop - tracer_future = awaitable_future = asyncio.Future() - - def _flush(): - # control_stream.flush puts messages on the queue - if self.control_stream: - self.control_stream.flush() - # put Future on the queue after all of those, - # so we can wait for all queued messages to be processed - self.control_queue.put(tracer_future) - - control_loop.add_callback(_flush) - return awaitable_future - - async def process_control(self, msg): - """dispatch control requests""" - if not self.session: - return - idents, msg = self.session.feed_identities(msg, copy=False) - try: - msg = self.session.deserialize(msg, content=True, copy=False) - except Exception: - self.log.error("Invalid Control Message", exc_info=True) - return - - self.log.debug("Control received: %s", msg) - - # Set the parent message for side effects. - self.set_parent(idents, msg, channel="control") - self._publish_status("busy", "control") - - header = msg["header"] - msg_type = header["msg_type"] - - handler = self.control_handlers.get(msg_type, None) - if handler is None: - self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type) - else: - try: - result = handler(self.control_stream, idents, msg) - if inspect.isawaitable(result): - await result - except Exception: - self.log.error("Exception in control handler:", exc_info=True) - - sys.stdout.flush() - sys.stderr.flush() - self._publish_status("idle", "control") - # flush to ensure reply is sent - if self.control_stream: - self.control_stream.flush(zmq.POLLOUT) - - def should_handle(self, stream, msg, idents): + async def should_handle(self, socket, msg, idents): """Check whether a shell-channel message should be handled Allows subclasses to prevent handling of certain messages (e.g. aborted requests). @@ -360,21 +242,86 @@ def should_handle(self, stream, msg, idents): if msg_id in self.aborted: # is it safe to assume a msg_id will not be resubmitted? self.aborted.remove(msg_id) - self._send_abort_reply(stream, msg, idents) + await self._send_abort_reply(socket, msg, idents) return False return True - async def dispatch_shell(self, msg): - """dispatch shell requests""" - if not self.session: + def pre_handler_hook(self): + """Hook to execute before calling message handler""" + # ensure default_int_handler during handler call + + def post_handler_hook(self): + """Hook to execute after calling message handler""" + + async def enter_eventloop(self): + """enter eventloop""" + self.log.info("Entering eventloop %s", self.eventloop) + # record handle, so we can check when this changes + eventloop = self.eventloop + if eventloop is None: + self.log.info("Exiting as there is no eventloop") return - # flush control queue before handling shell requests - await self._flush_control_queue() - idents, msg = self.session.feed_identities(msg, copy=False) + async def advance_eventloop(): + # check if eventloop changed: + if self.eventloop is not eventloop: + self.log.info("exiting eventloop %s", eventloop) + return + self.log.debug("Advancing eventloop %s", eventloop) + try: + eventloop(self) + except KeyboardInterrupt: + # Ctrl-C shouldn't crash the kernel + self.log.error("KeyboardInterrupt caught in kernel") + pass + if self.eventloop is eventloop: + # schedule advance again + await schedule_next() + + async def schedule_next(): + """Schedule the next advance of the eventloop""" + # flush the eventloop every so often, + # giving us a chance to handle messages in the meantime + self.log.debug("Scheduling eventloop advance") + await sleep(0.001) + await advance_eventloop() + + # begin polling the eventloop + await schedule_next() + + _message_counter = Any( + help="""Monotonic counter of messages + """, + ) + + @default("_message_counter") + def _message_counter_default(self): + return itertools.count() + + async def shell_main(self): + async with create_task_group() as tg: + tg.start_soon(self.process_shell) + await to_thread.run_sync(self.shell_stop.wait) + tg.cancel_scope.cancel() + + async def process_shell(self): try: - msg = self.session.deserialize(msg, content=True, copy=False) - except Exception: + while True: + await self.process_shell_message() + except BaseException as e: + if self.shell_stop.is_set(): + return + raise e + + async def process_shell_message(self, msg=None): + no_msg = msg is None if self._is_test else not await self.shell_socket.poll(0) + + msg = msg or await self.shell_socket.recv_multipart() + received_time = time.monotonic() + idents, msg = self.session.feed_identities(msg, copy=True) + try: + msg = self.session.deserialize(msg, content=True, copy=True) + except BaseException: self.log.error("Invalid Message", exc_info=True) return @@ -386,13 +333,15 @@ async def dispatch_shell(self, msg): # Only abort execute requests if self._aborting and msg_type == "execute_request": - self._send_abort_reply(self.shell_stream, msg, idents) - self._publish_status("idle", "shell") - # flush to ensure reply is sent before - # handling the next request - if self.shell_stream: - self.shell_stream.flush(zmq.POLLOUT) - return + if not self.stop_on_error_timeout: + if no_msg: + self._aborting = False + elif received_time - self._aborted_time > self.stop_on_error_timeout: + self._aborting = False + if self._aborting: + await self._send_abort_reply(self.shell_socket, msg, idents) + self._publish_status("idle", "shell") + return # Print some info about this message and leave a '--->' marker, so it's # easier to trace visually the message chain when debugging. Each @@ -400,10 +349,10 @@ async def dispatch_shell(self, msg): self.log.debug("\n*** MESSAGE TYPE:%s***", msg_type) self.log.debug(" Content: %s\n --->\n ", msg["content"]) - if not self.should_handle(self.shell_stream, msg, idents): + if not await self.should_handle(self.shell_socket, msg, idents): return - handler = self.shell_handlers.get(msg_type, None) + handler = self.shell_handlers.get(msg_type) if handler is None: self.log.warning("Unknown message type: %r", msg_type) else: @@ -413,7 +362,7 @@ async def dispatch_shell(self, msg): except Exception: self.log.debug("Unable to signal in pre_handler_hook:", exc_info=True) try: - result = handler(self.shell_stream, idents, msg) + result = handler(self.shell_socket, idents, msg) if inspect.isawaitable(result): await result except Exception: @@ -430,147 +379,83 @@ async def dispatch_shell(self, msg): sys.stdout.flush() sys.stderr.flush() self._publish_status("idle", "shell") - # flush to ensure reply is sent before - # handling the next request - if self.shell_stream: - self.shell_stream.flush(zmq.POLLOUT) - - def pre_handler_hook(self): - """Hook to execute before calling message handler""" - # ensure default_int_handler during handler call - self.saved_sigint_handler = signal(SIGINT, default_int_handler) - def post_handler_hook(self): - """Hook to execute after calling message handler""" - signal(SIGINT, self.saved_sigint_handler) - - def enter_eventloop(self): - """enter eventloop""" - self.log.info("Entering eventloop %s", self.eventloop) - # record handle, so we can check when this changes - eventloop = self.eventloop - if eventloop is None: - self.log.info("Exiting as there is no eventloop") - return + async def control_main(self): + async with create_task_group() as tg: + for task in self.control_tasks: + tg.start_soon(task) + tg.start_soon(self.process_control) + await to_thread.run_sync(self.control_stop.wait) + tg.cancel_scope.cancel() - def advance_eventloop(): - # check if eventloop changed: - if self.eventloop is not eventloop: - self.log.info("exiting eventloop %s", eventloop) - return - if self.msg_queue.qsize(): - self.log.debug("Delaying eventloop due to waiting messages") - # still messages to process, make the eventloop wait - schedule_next() + async def process_control(self): + try: + while True: + await self.process_control_message() + except BaseException as e: + if self.control_stop.is_set(): return - self.log.debug("Advancing eventloop %s", eventloop) - try: - eventloop(self) - except KeyboardInterrupt: - # Ctrl-C shouldn't crash the kernel - self.log.error("KeyboardInterrupt caught in kernel") - pass - if self.eventloop is eventloop: - # schedule advance again - schedule_next() - - def schedule_next(): - """Schedule the next advance of the eventloop""" - # flush the eventloop every so often, - # giving us a chance to handle messages in the meantime - self.log.debug("Scheduling eventloop advance") - self.io_loop.call_later(0.001, advance_eventloop) - - # begin polling the eventloop - schedule_next() + raise e - async def do_one_iteration(self): - """Process a single shell message + async def process_control_message(self, msg=None): + msg = msg or await self.control_socket.recv_multipart() + idents, msg = self.session.feed_identities(msg, copy=True) + try: + msg = self.session.deserialize(msg, content=True, copy=True) + except Exception: + self.log.error("Invalid Control Message", exc_info=True) + return - Any pending control messages will be flushed as well + self.log.debug("Control received: %s", msg) - .. versionchanged:: 5 - This is now a coroutine - """ - # flush messages off of shell stream into the message queue - if self.shell_stream: - self.shell_stream.flush() - # process at most one shell message per iteration - await self.process_one(wait=False) + # Set the parent message for side effects. + self.set_parent(idents, msg, channel="control") + self._publish_status("busy", "control") - async def process_one(self, wait=True): - """Process one request + header = msg["header"] + msg_type = header["msg_type"] - Returns None if no message was handled. - """ - if wait: - t, dispatch, args = await self.msg_queue.get() + handler = self.control_handlers.get(msg_type, None) + if handler is None: + self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type) else: try: - t, dispatch, args = self.msg_queue.get_nowait() - except (asyncio.QueueEmpty, QueueEmpty): - return None - await dispatch(*args) - - async def dispatch_queue(self): - """Coroutine to preserve order of message handling - - Ensures that only one message is processing at a time, - even when the handler is async - """ - - while True: - try: - await self.process_one() + result = handler(self.control_socket, idents, msg) + if inspect.isawaitable(result): + await result except Exception: - self.log.exception("Error in message handler") + self.log.error("Exception in control handler:", exc_info=True) - _message_counter = Any( - help="""Monotonic counter of messages - """, - ) + sys.stdout.flush() + sys.stderr.flush() + self._publish_status("idle", "control") - @default("_message_counter") - def _message_counter_default(self): - return itertools.count() + async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: + """Process messages on shell and control channels""" + async with create_task_group() as tg: + self.control_stop = threading.Event() + if not self._is_test and self.control_socket is not None: + if self.control_thread: + self.control_thread.set_task(self.control_main) + self.control_thread.start() + else: + tg.start_soon(self.control_main) - def schedule_dispatch(self, dispatch, *args): - """schedule a message for dispatch""" - idx = next(self._message_counter) + self.shell_interrupt: queue.Queue[bool] = queue.Queue() + self.shell_is_awaiting = False + self.shell_is_blocking = False + self.shell_stop = threading.Event() + if not self._is_test and self.shell_socket is not None: + tg.start_soon(self.shell_main) - self.msg_queue.put_nowait( - ( - idx, - dispatch, - args, - ) - ) - # ensure the eventloop wakes up - self.io_loop.add_callback(lambda: None) - - def start(self): - """register dispatchers for streams""" - self.io_loop = ioloop.IOLoop.current() - self.msg_queue: Queue[t.Any] = Queue() - self.io_loop.add_callback(self.dispatch_queue) - - if self.control_stream: - self.control_stream.on_recv(self.dispatch_control, copy=False) - - control_loop = self.control_thread.io_loop if self.control_thread else self.io_loop - - asyncio.run_coroutine_threadsafe(self.poll_control_queue(), control_loop.asyncio_loop) - if self.shell_stream: - self.shell_stream.on_recv( - partial( - self.schedule_dispatch, - self.dispatch_shell, - ), - copy=False, - ) + # publish idle status + self._publish_status("starting", "shell") - # publish idle status - self._publish_status("starting", "shell") + task_status.started() + + def stop(self): + self.shell_stop.set() + self.control_stop.set() def record_ports(self, ports): """Record the ports that this kernel is using. @@ -658,7 +543,7 @@ def get_parent(self, channel=None): def send_response( self, - stream, + socket, msg_or_type, content=None, ident=None, @@ -679,7 +564,7 @@ def send_response( if not self.session: return return self.session.send( - stream, + socket, msg_or_type, content, self.get_parent(channel), @@ -708,7 +593,7 @@ def finish_metadata(self, parent, metadata, reply_content): """ return metadata - async def execute_request(self, stream, ident, parent): + async def execute_request(self, socket, ident, parent): """handle an execute_request""" if not self.session: return @@ -770,8 +655,8 @@ async def execute_request(self, stream, ident, parent): reply_content = json_clean(reply_content) metadata = self.finish_metadata(parent, metadata, reply_content) - reply_msg: dict[str, t.Any] = self.session.send( # type:ignore[assignment] - stream, + reply_msg = self.session.send( + socket, "execute_reply", reply_content, parent, @@ -782,7 +667,11 @@ async def execute_request(self, stream, ident, parent): self.log.debug("%s", reply_msg) if not silent and reply_msg["content"]["status"] == "error" and stop_on_error: - self._abort_queues() + # while this flag is true, + # execute requests will be aborted + self._aborting = True + self._aborted_time = time.monotonic() + self.log.info("Aborting queue") def do_execute( self, @@ -797,7 +686,7 @@ def do_execute( """Execute user code. Must be overridden by subclasses.""" raise NotImplementedError - async def complete_request(self, stream, ident, parent): + async def complete_request(self, socket, ident, parent): """Handle a completion request.""" if not self.session: return @@ -810,7 +699,7 @@ async def complete_request(self, stream, ident, parent): matches = await matches matches = json_clean(matches) - self.session.send(stream, "complete_reply", matches, parent, ident) + self.session.send(socket, "complete_reply", matches, parent, ident) def do_complete(self, code, cursor_pos): """Override in subclasses to find completions.""" @@ -822,7 +711,7 @@ def do_complete(self, code, cursor_pos): "status": "ok", } - async def inspect_request(self, stream, ident, parent): + async def inspect_request(self, socket, ident, parent): """Handle an inspect request.""" if not self.session: return @@ -839,14 +728,14 @@ async def inspect_request(self, stream, ident, parent): # Before we send this object over, we scrub it for JSON usage reply_content = json_clean(reply_content) - msg = self.session.send(stream, "inspect_reply", reply_content, parent, ident) + msg = self.session.send(socket, "inspect_reply", reply_content, parent, ident) self.log.debug("%s", msg) def do_inspect(self, code, cursor_pos, detail_level=0, omit_sections=()): """Override in subclasses to allow introspection.""" return {"status": "ok", "data": {}, "metadata": {}, "found": False} - async def history_request(self, stream, ident, parent): + async def history_request(self, socket, ident, parent): """Handle a history request.""" if not self.session: return @@ -857,7 +746,7 @@ async def history_request(self, stream, ident, parent): reply_content = await reply_content reply_content = json_clean(reply_content) - msg = self.session.send(stream, "history_reply", reply_content, parent, ident) + msg = self.session.send(socket, "history_reply", reply_content, parent, ident) self.log.debug("%s", msg) def do_history( @@ -875,13 +764,13 @@ def do_history( """Override in subclasses to access history.""" return {"status": "ok", "history": []} - async def connect_request(self, stream, ident, parent): + async def connect_request(self, socket, ident, parent): """Handle a connect request.""" if not self.session: return - content = self._recorded_ports.copy() if self._recorded_ports else {} + content = self._recorded_ports.copy() if self._recorded_ports is not None else {} content["status"] = "ok" - msg = self.session.send(stream, "connect_reply", content, parent, ident) + msg = self.session.send(socket, "connect_reply", content, parent, ident) self.log.debug("%s", msg) @property @@ -895,16 +784,16 @@ def kernel_info(self): "help_links": self.help_links, } - async def kernel_info_request(self, stream, ident, parent): + async def kernel_info_request(self, socket, ident, parent): """Handle a kernel info request.""" if not self.session: return content = {"status": "ok"} content.update(self.kernel_info) - msg = self.session.send(stream, "kernel_info_reply", content, parent, ident) + msg = self.session.send(socket, "kernel_info_reply", content, parent, ident) self.log.debug("%s", msg) - async def comm_info_request(self, stream, ident, parent): + async def comm_info_request(self, socket, ident, parent): """Handle a comm info request.""" if not self.session: return @@ -921,7 +810,7 @@ async def comm_info_request(self, stream, ident, parent): else: comms = {} reply_content = dict(comms=comms, status="ok") - msg = self.session.send(stream, "comm_info_reply", reply_content, parent, ident) + msg = self.session.send(socket, "comm_info_reply", reply_content, parent, ident) self.log.debug("%s", msg) def _send_interrupt_children(self): @@ -941,7 +830,7 @@ def _send_interrupt_children(self): else: os.kill(pid, SIGINT) - async def interrupt_request(self, stream, ident, parent): + async def interrupt_request(self, socket, ident, parent): """Handle an interrupt request.""" if not self.session: return @@ -958,31 +847,23 @@ async def interrupt_request(self, stream, ident, parent): "evalue": str(err), } - self.session.send(stream, "interrupt_reply", content, parent, ident=ident) + self.session.send(socket, "interrupt_reply", content, parent, ident=ident) return - async def shutdown_request(self, stream, ident, parent): + async def shutdown_request(self, socket, ident, parent): """Handle a shutdown request.""" if not self.session: return content = self.do_shutdown(parent["content"]["restart"]) if inspect.isawaitable(content): content = await content - self.session.send(stream, "shutdown_reply", content, parent, ident=ident) + self.session.send(socket, "shutdown_reply", content, parent, ident=ident) # same content, but different msg_id for broadcasting on IOPub self._shutdown_message = self.session.msg("shutdown_reply", content, parent) await self._at_shutdown() - self.log.debug("Stopping control ioloop") - if self.control_stream: - control_io_loop = self.control_stream.io_loop - control_io_loop.add_callback(control_io_loop.stop) - - self.log.debug("Stopping shell ioloop") - if self.shell_stream: - shell_io_loop = self.shell_stream.io_loop - shell_io_loop.add_callback(shell_io_loop.stop) + self.stop() def do_shutdown(self, restart): """Override in subclasses to do things when the frontend shuts down the @@ -990,7 +871,7 @@ def do_shutdown(self, restart): """ return {"status": "ok", "restart": restart} - async def is_complete_request(self, stream, ident, parent): + async def is_complete_request(self, socket, ident, parent): """Handle an is_complete request.""" if not self.session: return @@ -1001,14 +882,14 @@ async def is_complete_request(self, stream, ident, parent): if inspect.isawaitable(reply_content): reply_content = await reply_content reply_content = json_clean(reply_content) - reply_msg = self.session.send(stream, "is_complete_reply", reply_content, parent, ident) + reply_msg = self.session.send(socket, "is_complete_reply", reply_content, parent, ident) self.log.debug("%s", reply_msg) def do_is_complete(self, code): """Override in subclasses to find completions.""" return {"status": "unknown"} - async def debug_request(self, stream, ident, parent): + async def debug_request(self, socket, ident, parent): """Handle a debug request.""" if not self.session: return @@ -1017,7 +898,7 @@ async def debug_request(self, stream, ident, parent): if inspect.isawaitable(reply_content): reply_content = await reply_content reply_content = json_clean(reply_content) - reply_msg = self.session.send(stream, "debug_reply", reply_content, parent, ident) + reply_msg = self.session.send(socket, "debug_reply", reply_content, parent, ident) self.log.debug("%s", reply_msg) def get_process_metric_value(self, process, name, attribute=None): @@ -1033,7 +914,7 @@ def get_process_metric_value(self, process, name, attribute=None): except BaseException: return None - async def usage_request(self, stream, ident, parent): + async def usage_request(self, socket, ident, parent): """Handle a usage request.""" if not self.session: return @@ -1065,7 +946,7 @@ async def usage_request(self, stream, ident, parent): reply_content["host_cpu_percent"] = cpu_percent reply_content["cpu_count"] = psutil.cpu_count(logical=True) reply_content["host_virtual_memory"] = dict(psutil.virtual_memory()._asdict()) - reply_msg = self.session.send(stream, "usage_reply", reply_content, parent, ident) + reply_msg = self.session.send(socket, "usage_reply", reply_content, parent, ident) self.log.debug("%s", reply_msg) async def do_debug_request(self, msg): @@ -1075,7 +956,7 @@ async def do_debug_request(self, msg): # Engine methods (DEPRECATED) # --------------------------------------------------------------------------- - async def apply_request(self, stream, ident, parent): # pragma: no cover + async def apply_request(self, socket, ident, parent): # pragma: no cover """Handle an apply request.""" self.log.warning("apply_request is deprecated in kernel_base, moving to ipyparallel.") try: @@ -1098,7 +979,7 @@ async def apply_request(self, stream, ident, parent): # pragma: no cover if not self.session: return self.session.send( - stream, + socket, "apply_reply", reply_content, parent=parent, @@ -1115,7 +996,7 @@ def do_apply(self, content, bufs, msg_id, reply_metadata): # Control messages (DEPRECATED) # --------------------------------------------------------------------------- - async def abort_request(self, stream, ident, parent): # pragma: no cover + async def abort_request(self, socket, ident, parent): # pragma: no cover """abort a specific msg by id""" self.log.warning( "abort_request is deprecated in kernel_base. It is only part of IPython parallel" @@ -1123,8 +1004,6 @@ async def abort_request(self, stream, ident, parent): # pragma: no cover msg_ids = parent["content"].get("msg_ids", None) if isinstance(msg_ids, str): msg_ids = [msg_ids] - if not msg_ids: - self._abort_queues() for mid in msg_ids: self.aborted.add(str(mid)) @@ -1132,18 +1011,18 @@ async def abort_request(self, stream, ident, parent): # pragma: no cover if not self.session: return reply_msg = self.session.send( - stream, "abort_reply", content=content, parent=parent, ident=ident + socket, "abort_reply", content=content, parent=parent, ident=ident ) self.log.debug("%s", reply_msg) - async def clear_request(self, stream, idents, parent): # pragma: no cover + async def clear_request(self, socket, idents, parent): # pragma: no cover """Clear our namespace.""" self.log.warning( "clear_request is deprecated in kernel_base. It is only part of IPython parallel" ) content = self.do_clear() if self.session: - self.session.send(stream, "clear_reply", ident=idents, parent=parent, content=content) + self.session.send(socket, "clear_reply", ident=idents, parent=parent, content=content) def do_clear(self): """DEPRECATED since 4.0.3""" @@ -1161,36 +1040,8 @@ def _topic(self, topic): _aborting = Bool(False) - def _abort_queues(self): - # while this flag is true, - # execute requests will be aborted - self._aborting = True - self.log.info("Aborting queue") - - # flush streams, so all currently waiting messages - # are added to the queue - if self.shell_stream: - self.shell_stream.flush() - - # Callback to signal that we are done aborting - # dispatch functions _must_ be async - async def stop_aborting(): - self.log.info("Finishing abort") - self._aborting = False - - # put the stop-aborting event on the message queue - # so that all messages already waiting in the queue are aborted - # before we reset the flag - schedule_stop_aborting = partial(self.schedule_dispatch, stop_aborting) - - # if we have a delay, give messages this long to arrive on the queue - # before we stop aborting requests - asyncio.get_event_loop().call_later(self.stop_on_error_timeout, schedule_stop_aborting) - - def _send_abort_reply(self, stream, msg, idents): + async def _send_abort_reply(self, socket, msg, idents): """Send a reply to an aborted request""" - if not self.session: - return self.log.info(f"Aborting {msg['header']['msg_id']}: {msg['header']['msg_type']}") reply_type = msg["header"]["msg_type"].rsplit("_", 1)[0] + "_reply" status = {"status": "aborted"} @@ -1199,7 +1050,7 @@ def _send_abort_reply(self, stream, msg, idents): md.update(status) self.session.send( - stream, + socket, reply_type, metadata=md, content=status, @@ -1382,5 +1233,3 @@ async def _at_shutdown(self): ident=self._topic("shutdown"), ) self.log.debug("%s", self._shutdown_message) - if self.control_stream: - self.control_stream.flush(zmq.POLLOUT) diff --git a/pyproject.toml b/pyproject.toml index 8f9ac894..e1872446 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,16 +29,16 @@ dependencies = [ "ipython>=7.23.1", "comm>=0.1.1", "traitlets>=5.4.0", - "jupyter_client>=6.1.12", + "jupyter_client>=8.0.0", "jupyter_core>=4.12,!=5.0.*", # For tk event loop support only. "nest_asyncio", - "tornado>=6.1", "matplotlib-inline>=0.1", 'appnope;platform_system=="Darwin"', - "pyzmq>=20", + "pyzmq>=25.0", "psutil", "packaging", + "anyio>=4.0.0", ] [project.optional-dependencies] @@ -57,8 +57,8 @@ test = [ "flaky", "ipyparallel", "pre-commit", - "pytest-asyncio", - "pytest-timeout" + "pytest-timeout", + "trio", ] cov = [ "coverage[toml]", @@ -178,7 +178,6 @@ testpaths = [ "tests", "tests/inprocess" ] -asyncio_mode = "auto" timeout = 300 # Restore this setting to debug failures #timeout_method = "thread" @@ -197,7 +196,10 @@ filterwarnings= [ "ignore:unclosed TIMEOUT: + raise TimeoutError() KM.interrupt_kernel() reply = KC.get_shell_msg()["content"] diff --git a/tests/test_embed_kernel.py b/tests/test_embed_kernel.py index ff97edfa..68582407 100644 --- a/tests/test_embed_kernel.py +++ b/tests/test_embed_kernel.py @@ -206,7 +206,7 @@ def test_embed_kernel_func(): def trigger_stop(): time.sleep(1) app = IPKernelApp.instance() - app.io_loop.add_callback(app.io_loop.stop) + app.stop() IPKernelApp.clear_instance() thread = threading.Thread(target=trigger_stop) diff --git a/tests/test_io.py b/tests/test_io.py index 98f04789..638b0451 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -12,6 +12,7 @@ import pytest import zmq +import zmq.asyncio from jupyter_client.session import Session from ipykernel.iostream import MASTER, BackgroundSocket, IOPubThread, OutStream @@ -19,7 +20,7 @@ @pytest.fixture def ctx(): - ctx = zmq.Context() + ctx = zmq.asyncio.Context() yield ctx ctx.destroy() @@ -64,23 +65,23 @@ def test_io_isatty(iopub_thread): assert stream.isatty() -def test_io_thread(iopub_thread): +async def test_io_thread(anyio_backend, iopub_thread): thread = iopub_thread thread._setup_pipe_in() msg = [thread._pipe_uuid, b"a"] - thread._handle_pipe_msg(msg) + await thread._handle_pipe_msg(msg) ctx1, pipe = thread._setup_pipe_out() pipe.close() - thread._pipe_in.close() + thread._pipe_in1.close() thread._check_mp_mode = lambda: MASTER thread._really_send([b"hi"]) ctx1.destroy() - thread.close() + thread.stop() thread.close() thread._really_send(None) -def test_background_socket(iopub_thread): +async def test_background_socket(anyio_backend, iopub_thread): sock = BackgroundSocket(iopub_thread) assert sock.__class__ == BackgroundSocket with warnings.catch_warnings(): @@ -91,9 +92,10 @@ def test_background_socket(iopub_thread): sock.send(b"hi") -def test_outstream(iopub_thread): +async def test_outstream(anyio_backend, iopub_thread): session = Session() pub = iopub_thread.socket + with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) stream = OutStream(session, pub, "stdout") @@ -116,6 +118,7 @@ def test_outstream(iopub_thread): assert stream.writable() +@pytest.mark.anyio async def test_event_pipe_gc(iopub_thread): session = Session(key=b'abc') stream = OutStream( diff --git a/tests/test_ipkernel_direct.py b/tests/test_ipkernel_direct.py index c9201348..c0e1b1d1 100644 --- a/tests/test_ipkernel_direct.py +++ b/tests/test_ipkernel_direct.py @@ -4,7 +4,6 @@ import os import pytest -import zmq from IPython.core.history import DummyDB from ipykernel.comm.comm import BaseComm @@ -149,19 +148,21 @@ async def test_direct_clear(ipkernel): ipkernel.do_clear() +@pytest.mark.skip("ipykernel._cancel_on_sigint doesn't exist anymore") async def test_cancel_on_sigint(ipkernel: IPythonKernel) -> None: future: asyncio.Future = asyncio.Future() - with ipkernel._cancel_on_sigint(future): - pass + # with ipkernel._cancel_on_sigint(future): + # pass future.set_result(None) -def test_dispatch_debugpy(ipkernel: IPythonKernel) -> None: +async def test_dispatch_debugpy(ipkernel: IPythonKernel) -> None: msg = ipkernel.session.msg("debug_request", {}) msg_list = ipkernel.session.serialize(msg) - ipkernel.dispatch_debugpy([zmq.Message(m) for m in msg_list]) + await ipkernel.receive_debugpy_message(msg_list) +@pytest.mark.skip("Queues don't exist anymore") async def test_start(ipkernel: IPythonKernel) -> None: shell_future: asyncio.Future = asyncio.Future() control_future: asyncio.Future = asyncio.Future() @@ -174,14 +175,15 @@ async def fake_poll_control_queue(): ipkernel.dispatch_queue = fake_dispatch_queue # type:ignore ipkernel.poll_control_queue = fake_poll_control_queue # type:ignore - ipkernel.start() - ipkernel.debugpy_stream = None - ipkernel.start() - await ipkernel.process_one(False) + await ipkernel.start() + ipkernel.debugpy_socket = None + await ipkernel.start() + # await ipkernel.process_one(False) await shell_future await control_future +@pytest.mark.skip("Queues don't exist anymore") async def test_start_no_debugpy(ipkernel: IPythonKernel) -> None: shell_future: asyncio.Future = asyncio.Future() control_future: asyncio.Future = asyncio.Future() @@ -194,8 +196,8 @@ async def fake_poll_control_queue(): ipkernel.dispatch_queue = fake_dispatch_queue # type:ignore ipkernel.poll_control_queue = fake_poll_control_queue # type:ignore - ipkernel.debugpy_stream = None - ipkernel.start() + ipkernel.debugpy_socket = None + await ipkernel.start() await shell_future await control_future diff --git a/tests/test_kernel_direct.py b/tests/test_kernel_direct.py index dfb8a70f..ea3c6fe7 100644 --- a/tests/test_kernel_direct.py +++ b/tests/test_kernel_direct.py @@ -104,6 +104,7 @@ async def test_direct_debug_request(kernel): assert reply["header"]["msg_type"] == "debug_reply" +@pytest.mark.skip("Shell streams don't exist anymore") async def test_deprecated_features(kernel): with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) @@ -119,33 +120,26 @@ async def test_deprecated_features(kernel): async def test_process_control(kernel): from jupyter_client.session import DELIM - class FakeMsg: - def __init__(self, bytes): - self.bytes = bytes - - await kernel.process_control([FakeMsg(DELIM), 1]) + await kernel.process_control_message([DELIM, 1]) msg = kernel._prep_msg("does_not_exist") - await kernel.process_control(msg) + await kernel.process_control_message(msg) -def test_should_handle(kernel): +async def test_should_handle(kernel): msg = kernel.session.msg("debug_request", {}) kernel.aborted.add(msg["header"]["msg_id"]) - assert not kernel.should_handle(kernel.control_stream, msg, []) + assert not await kernel.should_handle(kernel.control_socket, msg, []) async def test_dispatch_shell(kernel): from jupyter_client.session import DELIM - class FakeMsg: - def __init__(self, bytes): - self.bytes = bytes - - await kernel.dispatch_shell([FakeMsg(DELIM), 1]) + await kernel.process_shell_message([DELIM, 1]) msg = kernel._prep_msg("does_not_exist") - await kernel.dispatch_shell(msg) + await kernel.process_shell_message(msg) +@pytest.mark.skip("kernelbase.do_one_iteration doesn't exist anymore") async def test_do_one_iteration(kernel): kernel.msg_queue = asyncio.Queue() await kernel.do_one_iteration() @@ -156,7 +150,7 @@ async def test_publish_debug_event(kernel): async def test_connect_request(kernel): - await kernel.connect_request(kernel.shell_stream, "foo", {}) + await kernel.connect_request(kernel.shell_socket, b"foo", {}) async def test_send_interrupt_children(kernel): diff --git a/tests/test_kernelapp.py b/tests/test_kernelapp.py index da38777d..6b9f451b 100644 --- a/tests/test_kernelapp.py +++ b/tests/test_kernelapp.py @@ -2,7 +2,6 @@ import os import threading import time -from unittest.mock import patch import pytest from jupyter_core.paths import secure_write @@ -40,7 +39,7 @@ def test_start_app(): def trigger_stop(): time.sleep(1) - app.io_loop.add_callback(app.io_loop.stop) + app.stop() thread = threading.Thread(target=trigger_stop) thread.start() @@ -121,11 +120,17 @@ def test_merge_connection_file(): @pytest.mark.skipif(trio is None, reason="requires trio") def test_trio_loop(): app = IPKernelApp(trio_loop=True) + + def trigger_stop(): + time.sleep(1) + app.stop() + + thread = threading.Thread(target=trigger_stop) + thread.start() + app.kernel = MockKernel() app.init_sockets() - with patch("ipykernel.trio_runner.TrioRunner.run", lambda _: None): - app.start() + app.start() app.cleanup_connection_file() - app.io_loop.add_callback(app.io_loop.stop) app.kernel.destroy() app.close() diff --git a/tests/test_message_spec.py b/tests/test_message_spec.py index 0c9e777c..58485d2a 100644 --- a/tests/test_message_spec.py +++ b/tests/test_message_spec.py @@ -5,6 +5,7 @@ import re import sys +import time from queue import Empty import pytest @@ -364,7 +365,6 @@ def test_execute_stop_on_error(): KC.execute(code='print("Hello")') KC.execute(code='print("world")') reply = KC.get_shell_msg(timeout=TIMEOUT) - print(reply) reply = KC.get_shell_msg(timeout=TIMEOUT) assert reply["content"]["status"] == "aborted" # second message, too @@ -595,10 +595,17 @@ def test_stream(): msg_id, reply = execute("print('hi')") - stdout = KC.get_iopub_msg(timeout=TIMEOUT) - validate_message(stdout, "stream", msg_id) - content = stdout["content"] - assert content["text"] == "hi\n" + stream = "" + t0 = time.monotonic() + while True: + msg = KC.get_iopub_msg(timeout=TIMEOUT) + validate_message(msg, "stream", msg_id) + stream += msg["content"]["text"] + assert "hi\n".startswith(stream) + if stream == "hi\n": + break + if time.monotonic() - t0 > TIMEOUT: + raise TimeoutError() def test_display_data():