From bf104470d2efe50437523697b736ac66b2c2fc16 Mon Sep 17 00:00:00 2001 From: Min RK Date: Sat, 26 Oct 2024 19:31:46 +0200 Subject: [PATCH] fix mixture of sync/async sockets in IOPubThread (#1275) --- ipykernel/inprocess/ipkernel.py | 4 +- ipykernel/inprocess/socket.py | 3 ++ ipykernel/iostream.py | 75 +++++++++++++++++++-------------- tests/test_io.py | 4 +- tests/test_kernel.py | 24 +++++------ 5 files changed, 64 insertions(+), 46 deletions(-) diff --git a/ipykernel/inprocess/ipkernel.py b/ipykernel/inprocess/ipkernel.py index 114e231d..c6f8c612 100644 --- a/ipykernel/inprocess/ipkernel.py +++ b/ipykernel/inprocess/ipkernel.py @@ -6,6 +6,7 @@ import logging import sys from contextlib import contextmanager +from typing import cast from anyio import TASK_STATUS_IGNORED from anyio.abc import TaskStatus @@ -146,7 +147,8 @@ def callback(msg): assert frontend is not None frontend.iopub_channel.call_handlers(msg) - self.iopub_thread.socket.on_recv = callback + iopub_socket = cast(DummySocket, self.iopub_thread.socket) + iopub_socket.on_recv = callback # ------ Trait initializers ----------------------------------------------- diff --git a/ipykernel/inprocess/socket.py b/ipykernel/inprocess/socket.py index edc77c28..5a2e0008 100644 --- a/ipykernel/inprocess/socket.py +++ b/ipykernel/inprocess/socket.py @@ -63,3 +63,6 @@ async def poll(self, timeout=0): assert timeout == 0 statistics = self.in_receive_stream.statistics() return statistics.current_buffer_used != 0 + + def close(self): + pass diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index 81170b97..d8171017 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -3,6 +3,8 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import atexit import contextvars import io @@ -15,7 +17,7 @@ from collections import defaultdict, deque from io import StringIO, TextIOBase from threading import Event, Thread, local -from typing import Any, Callable, Optional +from typing import Any, Callable import zmq from anyio import create_task_group, run, sleep, to_thread @@ -25,8 +27,8 @@ # Globals # ----------------------------------------------------------------------------- -MASTER = 0 -CHILD = 1 +_PARENT = 0 +_CHILD = 1 PIPE_BUFFER_SIZE = 1000 @@ -87,9 +89,16 @@ def __init__(self, socket, pipe=False): Whether this process should listen for IOPub messages piped from subprocesses. """ - self.socket = socket + # ensure all of our sockets as sync zmq.Sockets + # don't create async wrappers until we are within the appropriate coroutines + self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket) + if self.socket.context is None: + # bug in pyzmq, shadow socket doesn't always inherit context attribute + self.socket.context = socket.context # type:ignore[unreachable] + self._context = socket.context + self.background_socket = BackgroundSocket(self) - self._master_pid = os.getpid() + self._main_pid = os.getpid() self._pipe_flag = pipe if pipe: self._setup_pipe_in() @@ -106,8 +115,7 @@ def __init__(self, socket, pipe=False): def _setup_event_pipe(self): """Create the PULL socket listening for events that should fire in this thread.""" - ctx = self.socket.context - self._pipe_in0 = ctx.socket(zmq.PULL) + self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket) self._pipe_in0.linger = 0 _uuid = b2a_hex(os.urandom(16)).decode("ascii") @@ -141,8 +149,8 @@ def _event_pipe(self): event_pipe = self._local.event_pipe except AttributeError: # new thread, new event pipe - ctx = zmq.Context(self.socket.context) - event_pipe = ctx.socket(zmq.PUSH) + # create sync base socket + event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket) event_pipe.linger = 0 event_pipe.connect(self._event_interface) self._local.event_pipe = event_pipe @@ -161,9 +169,11 @@ async def _handle_event(self): Whenever *an* event arrives on the event stream, *all* waiting events are processed in order. """ + # create async wrapper within coroutine + pipe_in = zmq.asyncio.Socket(self._pipe_in0) try: while True: - await self._pipe_in0.recv() + await pipe_in.recv() # freeze event count so new writes don't extend the queue # while we are processing n_events = len(self._events) @@ -177,12 +187,12 @@ async def _handle_event(self): def _setup_pipe_in(self): """setup listening pipe for IOPub from forked subprocesses""" - ctx = self.socket.context + ctx = self._context # use UUID to authenticate pipe messages self._pipe_uuid = os.urandom(16) - self._pipe_in1 = ctx.socket(zmq.PULL) + self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket) self._pipe_in1.linger = 0 try: @@ -199,6 +209,8 @@ def _setup_pipe_in(self): async def _handle_pipe_msgs(self): """handle pipe messages from a subprocess""" + # create async wrapper within coroutine + self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1) try: while True: await self._handle_pipe_msg() @@ -209,8 +221,8 @@ async def _handle_pipe_msgs(self): async def _handle_pipe_msg(self, msg=None): """handle a pipe message from a subprocess""" - msg = msg or await self._pipe_in1.recv_multipart() - if not self._pipe_flag or not self._is_master_process(): + msg = msg or await self._async_pipe_in1.recv_multipart() + if not self._pipe_flag or not self._is_main_process(): return if msg[0] != self._pipe_uuid: print("Bad pipe message: %s", msg, file=sys.__stderr__) @@ -225,14 +237,14 @@ def _setup_pipe_out(self): pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port) return ctx, pipe_out - def _is_master_process(self): - return os.getpid() == self._master_pid + def _is_main_process(self): + return os.getpid() == self._main_pid def _check_mp_mode(self): """check for forks, and switch to zmq pipeline if necessary""" - if not self._pipe_flag or self._is_master_process(): - return MASTER - return CHILD + if not self._pipe_flag or self._is_main_process(): + return _PARENT + return _CHILD def start(self): """Start the IOPub thread""" @@ -265,7 +277,8 @@ def close(self): self._pipe_in0.close() if self._pipe_flag: self._pipe_in1.close() - self.socket.close() + if self.socket is not None: + self.socket.close() self.socket = None @property @@ -301,12 +314,12 @@ def _really_send(self, msg, *args, **kwargs): return mp_mode = self._check_mp_mode() - - if mp_mode != CHILD: - # we are master, do a regular send + if mp_mode != _CHILD: + # we are the main parent process, do a regular send + assert self.socket is not None self.socket.send_multipart(msg, *args, **kwargs) else: - # we are a child, pipe to master + # we are a child, pipe to parent process # new context/socket for every pipe-out # since forks don't teardown politely, use ctx.term to ensure send has completed ctx, pipe_out = self._setup_pipe_out() @@ -379,7 +392,7 @@ class OutStream(TextIOBase): flush_interval = 0.2 topic = None encoding = "UTF-8" - _exc: Optional[Any] = None + _exc: Any = None def fileno(self): """ @@ -477,7 +490,7 @@ def __init__( self._thread_to_parent = {} self._thread_to_parent_header = {} self._parent_header_global = {} - self._master_pid = os.getpid() + self._main_pid = os.getpid() self._flush_pending = False self._subprocess_flush_pending = False self._buffer_lock = threading.RLock() @@ -569,8 +582,8 @@ def _setup_stream_redirects(self, name): self.watch_fd_thread.daemon = True self.watch_fd_thread.start() - def _is_master_process(self): - return os.getpid() == self._master_pid + def _is_main_process(self): + return os.getpid() == self._main_pid def set_parent(self, parent): """Set the parent header.""" @@ -674,7 +687,7 @@ def _flush(self): ident=self.topic, ) - def write(self, string: str) -> Optional[int]: # type:ignore[override] + def write(self, string: str) -> int: """Write to current stream after encoding if necessary Returns @@ -700,7 +713,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override] msg = "I/O operation on closed file" raise ValueError(msg) - is_child = not self._is_master_process() + is_child = not self._is_main_process() # only touch the buffer in the IO thread to avoid races with self._buffer_lock: self._buffers[frozenset(parent.items())].write(string) @@ -708,7 +721,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override] # mp.Pool cannot be trusted to flush promptly (or ever), # and this helps. if self._subprocess_flush_pending: - return None + return 0 self._subprocess_flush_pending = True # We can not rely on self._io_loop.call_later from a subprocess self.pub_thread.schedule(self._flush) diff --git a/tests/test_io.py b/tests/test_io.py index e49bc276..e3ff2815 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -15,7 +15,7 @@ import zmq.asyncio from jupyter_client.session import Session -from ipykernel.iostream import MASTER, BackgroundSocket, IOPubThread, OutStream +from ipykernel.iostream import _PARENT, BackgroundSocket, IOPubThread, OutStream @pytest.fixture() @@ -73,7 +73,7 @@ async def test_io_thread(anyio_backend, iopub_thread): ctx1, pipe = thread._setup_pipe_out() pipe.close() thread._pipe_in1.close() - thread._check_mp_mode = lambda: MASTER + thread._check_mp_mode = lambda: _PARENT thread._really_send([b"hi"]) ctx1.destroy() thread.stop() diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 89d5e390..8efc3dcc 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -32,10 +32,10 @@ ) -def _check_master(kc, expected=True, stream="stdout"): +def _check_main(kc, expected=True, stream="stdout"): execute(kc=kc, code="import sys") flush_channels(kc) - msg_id, content = execute(kc=kc, code="print(sys.%s._is_master_process())" % stream) + msg_id, content = execute(kc=kc, code="print(sys.%s._is_main_process())" % stream) stdout, stderr = assemble_output(kc.get_iopub_msg) assert stdout.strip() == repr(expected) @@ -56,7 +56,7 @@ def test_simple_print(): stdout, stderr = assemble_output(kc.get_iopub_msg) assert stdout == "hi\n" assert stderr == "" - _check_master(kc, expected=True) + _check_main(kc, expected=True) def test_print_to_correct_cell_from_thread(): @@ -168,7 +168,7 @@ def test_capture_fd(): stdout, stderr = assemble_output(iopub) assert stdout == "capsys\n" assert stderr == "" - _check_master(kc, expected=True) + _check_main(kc, expected=True) @pytest.mark.skip(reason="Currently don't capture during test as pytest does its own capturing") @@ -182,7 +182,7 @@ def test_subprocess_peek_at_stream_fileno(): stdout, stderr = assemble_output(iopub) assert stdout == "CAP1\nCAP2\n" assert stderr == "" - _check_master(kc, expected=True) + _check_main(kc, expected=True) def test_sys_path(): @@ -218,7 +218,7 @@ def test_sys_path_profile_dir(): def test_subprocess_print(): """printing from forked mp.Process""" with new_kernel() as kc: - _check_master(kc, expected=True) + _check_main(kc, expected=True) flush_channels(kc) np = 5 code = "\n".join( @@ -238,8 +238,8 @@ def test_subprocess_print(): for n in range(np): assert stdout.count(str(n)) == 1, stdout assert stderr == "" - _check_master(kc, expected=True) - _check_master(kc, expected=True, stream="stderr") + _check_main(kc, expected=True) + _check_main(kc, expected=True, stream="stderr") @flaky(max_runs=3) @@ -261,8 +261,8 @@ def test_subprocess_noprint(): assert stdout == "" assert stderr == "" - _check_master(kc, expected=True) - _check_master(kc, expected=True, stream="stderr") + _check_main(kc, expected=True) + _check_main(kc, expected=True, stream="stderr") @flaky(max_runs=3) @@ -287,8 +287,8 @@ def test_subprocess_error(): assert stdout == "" assert "ValueError" in stderr - _check_master(kc, expected=True) - _check_master(kc, expected=True, stream="stderr") + _check_main(kc, expected=True) + _check_main(kc, expected=True, stream="stderr") # raw_input tests