Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix mixture of sync/async sockets in IOPubThread #1275

Merged
merged 5 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ipykernel/inprocess/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import sys
from contextlib import contextmanager
from typing import cast

from anyio import TASK_STATUS_IGNORED
from anyio.abc import TaskStatus
Expand Down Expand Up @@ -146,7 +147,8 @@ def callback(msg):
assert frontend is not None
frontend.iopub_channel.call_handlers(msg)

self.iopub_thread.socket.on_recv = callback
iopub_socket = cast(DummySocket, self.iopub_thread.socket)
iopub_socket.on_recv = callback

# ------ Trait initializers -----------------------------------------------

Expand Down
3 changes: 3 additions & 0 deletions ipykernel/inprocess/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ async def poll(self, timeout=0):
assert timeout == 0
statistics = self.in_receive_stream.statistics()
return statistics.current_buffer_used != 0

def close(self):
pass
74 changes: 42 additions & 32 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.

from __future__ import annotations

import atexit
import contextvars
import io
Expand All @@ -15,7 +17,7 @@
from collections import defaultdict, deque
from io import StringIO, TextIOBase
from threading import Event, Thread, local
from typing import Any, Callable, Deque, Dict, Optional
from typing import Any, Callable

import zmq
from anyio import create_task_group, run, sleep, to_thread
Expand All @@ -25,8 +27,8 @@
# Globals
# -----------------------------------------------------------------------------

MASTER = 0
CHILD = 1
_PARENT = 0
_CHILD = 1

PIPE_BUFFER_SIZE = 1000

Expand Down Expand Up @@ -87,15 +89,19 @@ def __init__(self, socket, pipe=False):
Whether this process should listen for IOPub messages
piped from subprocesses.
"""
self.socket = socket
# ensure all of our sockets are sync zmq.Sockets
# don't create async wrappers until we are within the appropriate coroutines
self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket)
self._sync_context: zmq.Context[zmq.Socket[bytes]] = zmq.Context(socket.context)

self.background_socket = BackgroundSocket(self)
self._master_pid = os.getpid()
self._main_pid = os.getpid()
self._pipe_flag = pipe
if pipe:
self._setup_pipe_in()
self._local = threading.local()
self._events: Deque[Callable[..., Any]] = deque()
self._event_pipes: Dict[threading.Thread, Any] = {}
self._events: deque[Callable[..., Any]] = deque()
self._event_pipes: dict[threading.Thread, Any] = {}
self._event_pipe_gc_lock: threading.Lock = threading.Lock()
self._event_pipe_gc_seconds: float = 10
self._setup_event_pipe()
Expand All @@ -106,7 +112,7 @@ def __init__(self, socket, pipe=False):

def _setup_event_pipe(self):
"""Create the PULL socket listening for events that should fire in this thread."""
ctx = self.socket.context
ctx = self._sync_context
self._pipe_in0 = ctx.socket(zmq.PULL)
self._pipe_in0.linger = 0

Expand Down Expand Up @@ -141,8 +147,7 @@ def _event_pipe(self):
event_pipe = self._local.event_pipe
except AttributeError:
# new thread, new event pipe
ctx = zmq.Context(self.socket.context)
event_pipe = ctx.socket(zmq.PUSH)
event_pipe = self._sync_context.socket(zmq.PUSH)
event_pipe.linger = 0
event_pipe.connect(self._event_interface)
self._local.event_pipe = event_pipe
Expand All @@ -161,9 +166,11 @@ async def _handle_event(self):
Whenever *an* event arrives on the event stream,
*all* waiting events are processed in order.
"""
# create async wrapper within coroutine
pipe_in = zmq.asyncio.Socket.shadow(self._pipe_in0)
try:
while True:
await self._pipe_in0.recv()
await pipe_in.recv()
# freeze event count so new writes don't extend the queue
# while we are processing
n_events = len(self._events)
Expand All @@ -177,7 +184,7 @@ async def _handle_event(self):

def _setup_pipe_in(self):
"""setup listening pipe for IOPub from forked subprocesses"""
ctx = self.socket.context
ctx = self._sync_context

# use UUID to authenticate pipe messages
self._pipe_uuid = os.urandom(16)
Expand All @@ -199,6 +206,8 @@ def _setup_pipe_in(self):

async def _handle_pipe_msgs(self):
"""handle pipe messages from a subprocess"""
# create async wrapper within coroutine
self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1)
minrk marked this conversation as resolved.
Show resolved Hide resolved
try:
while True:
await self._handle_pipe_msg()
Expand All @@ -209,8 +218,8 @@ async def _handle_pipe_msgs(self):

async def _handle_pipe_msg(self, msg=None):
"""handle a pipe message from a subprocess"""
msg = msg or await self._pipe_in1.recv_multipart()
if not self._pipe_flag or not self._is_master_process():
msg = msg or await self._async_pipe_in1.recv_multipart()
if not self._pipe_flag or not self._is_main_process():
return
if msg[0] != self._pipe_uuid:
print("Bad pipe message: %s", msg, file=sys.__stderr__)
Expand All @@ -225,14 +234,14 @@ def _setup_pipe_out(self):
pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)
return ctx, pipe_out

def _is_master_process(self):
return os.getpid() == self._master_pid
def _is_main_process(self):
return os.getpid() == self._main_pid

def _check_mp_mode(self):
"""check for forks, and switch to zmq pipeline if necessary"""
if not self._pipe_flag or self._is_master_process():
return MASTER
return CHILD
if not self._pipe_flag or self._is_main_process():
return _PARENT
return _CHILD

def start(self):
"""Start the IOPub thread"""
Expand Down Expand Up @@ -265,7 +274,8 @@ def close(self):
self._pipe_in0.close()
if self._pipe_flag:
self._pipe_in1.close()
self.socket.close()
if self.socket is not None:
self.socket.close()
self.socket = None

@property
Expand Down Expand Up @@ -301,12 +311,12 @@ def _really_send(self, msg, *args, **kwargs):
return

mp_mode = self._check_mp_mode()

if mp_mode != CHILD:
# we are master, do a regular send
if mp_mode != _CHILD:
# we are the main parent process, do a regular send
assert self.socket is not None
self.socket.send_multipart(msg, *args, **kwargs)
else:
# we are a child, pipe to master
# we are a child, pipe to parent process
# new context/socket for every pipe-out
# since forks don't tear down politely, use ctx.term to ensure send has completed
ctx, pipe_out = self._setup_pipe_out()
Expand Down Expand Up @@ -379,7 +389,7 @@ class OutStream(TextIOBase):
flush_interval = 0.2
topic = None
encoding = "UTF-8"
_exc: Optional[Any] = None
_exc: Any = None

def fileno(self):
"""
Expand Down Expand Up @@ -470,14 +480,14 @@ def __init__(
self.pub_thread = pub_thread
self.name = name
self.topic = b"stream." + name.encode()
self._parent_header: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar(
self._parent_header: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
"parent_header"
)
self._parent_header.set({})
self._thread_to_parent = {}
self._thread_to_parent_header = {}
self._parent_header_global = {}
self._master_pid = os.getpid()
self._main_pid = os.getpid()
self._flush_pending = False
self._subprocess_flush_pending = False
self._buffer_lock = threading.RLock()
Expand Down Expand Up @@ -569,8 +579,8 @@ def _setup_stream_redirects(self, name):
self.watch_fd_thread.daemon = True
self.watch_fd_thread.start()

def _is_master_process(self):
return os.getpid() == self._master_pid
def _is_main_process(self):
return os.getpid() == self._main_pid

def set_parent(self, parent):
"""Set the parent header."""
Expand Down Expand Up @@ -674,7 +684,7 @@ def _flush(self):
ident=self.topic,
)

def write(self, string: str) -> Optional[int]: # type:ignore[override]
def write(self, string: str) -> int:
"""Write to current stream after encoding if necessary

Returns
Expand All @@ -700,15 +710,15 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
msg = "I/O operation on closed file"
raise ValueError(msg)

is_child = not self._is_master_process()
is_child = not self._is_main_process()
# only touch the buffer in the IO thread to avoid races
with self._buffer_lock:
self._buffers[frozenset(parent.items())].write(string)
if is_child:
# mp.Pool cannot be trusted to flush promptly (or ever),
# and this helps.
if self._subprocess_flush_pending:
return None
return 0
self._subprocess_flush_pending = True
# We cannot rely on self._io_loop.call_later from a subprocess
self.pub_thread.schedule(self._flush)
Expand Down
Loading