Skip to content

Commit

Permalink
Replace thread add_task with start_soon
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 15, 2024
1 parent 1fe492a commit 9203727
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 37 deletions.
17 changes: 9 additions & 8 deletions ipykernel/kernelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import uuid
import warnings
from datetime import datetime
from functools import partial
from signal import SIGINT, SIGTERM, Signals

from .thread import CONTROL_THREAD_NAME
Expand Down Expand Up @@ -536,7 +537,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
self.control_stop = threading.Event()
if not self._is_test and self.control_socket is not None:
if self.control_thread:
self.control_thread.add_task(self.control_main)
self.control_thread.start_soon(self.control_main)
self.control_thread.start()
else:
tg.start_soon(self.control_main)
Expand All @@ -551,11 +552,11 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:

# Assign tasks to and start shell channel thread.
manager = self.shell_channel_thread.manager
self.shell_channel_thread.add_task(self.shell_channel_thread_main)
self.shell_channel_thread.add_task(
manager.listen_from_control, self.shell_main, self.shell_channel_thread
self.shell_channel_thread.start_soon(self.shell_channel_thread_main)
self.shell_channel_thread.start_soon(
partial(manager.listen_from_control, self.shell_main, self.shell_channel_thread)
)
self.shell_channel_thread.add_task(manager.listen_from_subshells)
self.shell_channel_thread.start_soon(manager.listen_from_subshells)
self.shell_channel_thread.start()
else:
if not self._is_test and self.shell_socket is not None:
Expand Down Expand Up @@ -1085,7 +1086,7 @@ async def create_subshell_request(self, socket, ident, parent) -> None:
# This should only be called in the control thread if it exists.
# Request is passed to shell channel thread to process.
other_socket = await self.shell_channel_thread.manager.get_control_other_socket(
self.control_thread.get_task_group()
self.control_thread
)
await other_socket.asend_json({"type": "create"})
reply = await other_socket.arecv_json()
Expand All @@ -1109,7 +1110,7 @@ async def delete_subshell_request(self, socket, ident, parent) -> None:
# This should only be called in the control thread if it exists.
# Request is passed to shell channel thread to process.
other_socket = await self.shell_channel_thread.manager.get_control_other_socket(
self.control_thread.get_task_group()
self.control_thread
)
await other_socket.asend_json({"type": "delete", "subshell_id": subshell_id})
reply = await other_socket.arecv_json()
Expand All @@ -1126,7 +1127,7 @@ async def list_subshell_request(self, socket, ident, parent) -> None:
# This should only be called in the control thread if it exists.
# Request is passed to shell channel thread to process.
other_socket = await self.shell_channel_thread.manager.get_control_other_socket(
self.control_thread.get_task_group()
self.control_thread
)
await other_socket.asend_json({"type": "list"})
reply = await other_socket.arecv_json()
Expand Down
2 changes: 1 addition & 1 deletion ipykernel/shellchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
def manager(self) -> SubshellManager:
# Lazy initialisation.
if self._manager is None:
self._manager = SubshellManager(self._context, self._shell_socket, self.get_task_group)
self._manager = SubshellManager(self._context, self._shell_socket)
return self._manager

def run(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions ipykernel/subshell.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ async def create_pair_socket(
) -> None:
"""Create inproc PAIR socket, for communication with shell channel thread.
Should be called from this thread, so usually via add_task before the
Should be called from this thread, so usually via start_soon before the
thread is started.
"""
assert current_thread() == self
self._pair_socket = zmq_anyio.Socket(context, zmq.PAIR)
self._pair_socket.connect(address)
self.add_task(self._pair_socket.start)
self.start_soon(self._pair_socket.start)

def run(self) -> None:
try:
Expand Down
14 changes: 6 additions & 8 deletions ipykernel/subshell_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import typing as t
import uuid
from dataclasses import dataclass
from functools import partial
from threading import Lock, current_thread, main_thread
from typing import Callable

import zmq
import zmq_anyio
Expand Down Expand Up @@ -44,13 +44,11 @@ def __init__(
self,
context: zmq.Context, # type: ignore[type-arg]
shell_socket: zmq_anyio.Socket,
get_task_group: Callable[[], TaskGroup],
):
assert current_thread() == main_thread()

self._context: zmq.Context = context # type: ignore[type-arg]
self._shell_socket = shell_socket
self._get_task_group = get_task_group
self._cache: dict[str, Subshell] = {}
self._lock_cache = Lock()
self._lock_shell_socket = Lock()
Expand Down Expand Up @@ -91,9 +89,9 @@ def close(self) -> None:
break
self._stop_subshell(subshell)

async def get_control_other_socket(self, task_group: TaskGroup) -> zmq_anyio.Socket:
async def get_control_other_socket(self, thread: BaseThread) -> zmq_anyio.Socket:
if not self._control_other_socket.started.is_set():
task_group.start_soon(self._control_other_socket.start)
thread.start_soon(self._control_other_socket.start)
await self._control_other_socket.started.wait()
return self._control_other_socket

Expand Down Expand Up @@ -134,7 +132,7 @@ async def listen_from_control(self, subshell_task: t.Any, thread: BaseThread) ->
assert current_thread().name == SHELL_CHANNEL_THREAD_NAME

if not self._control_shell_channel_socket.started.is_set():
thread.get_task_group().start_soon(self._control_shell_channel_socket.start)
thread.start_soon(self._control_shell_channel_socket.start)
await self._control_shell_channel_socket.started.wait()
socket = self._control_shell_channel_socket
while True:
Expand Down Expand Up @@ -200,8 +198,8 @@ async def _create_subshell(self, subshell_task: t.Any) -> str:
await self._send_stream.send(subshell_id)

address = self._get_inproc_socket_address(subshell_id)
thread.add_task(thread.create_pair_socket, self._context, address)
thread.add_task(subshell_task, subshell_id)
thread.start_soon(partial(thread.create_pair_socket, self._context, address))
thread.start_soon(partial(subshell_task, subshell_id))
thread.start()

return subshell_id
Expand Down
35 changes: 17 additions & 18 deletions ipykernel/thread.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Base class for threads."""
import typing as t
from threading import Event, Thread
from __future__ import annotations

import queue
from collections.abc import Awaitable
from threading import Thread
from typing import Callable

from anyio import create_task_group, run, to_thread
from anyio.abc import TaskGroup

CONTROL_THREAD_NAME = "Control"
SHELL_CHANNEL_THREAD_NAME = "Shell channel"
Expand All @@ -17,31 +20,27 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self.pydev_do_not_trace = True
self.is_pydev_daemon_thread = True
self.__stop = Event()
self._tasks_and_args: list[tuple[t.Any, t.Any]] = []

def get_task_group(self) -> TaskGroup:
return self._task_group
self._tasks = queue.Queue()

def add_task(self, task: t.Any, *args: t.Any) -> None:
# May only add tasks before the thread is started.
self._tasks_and_args.append((task, args))
def start_soon(self, task: Callable[[], Awaitable[None]] | None) -> None:
self._tasks.put(task)

def run(self) -> t.Any:
def run(self) -> None:
"""Run the thread."""
return run(self._main)
run(self._main)

async def _main(self) -> None:
async with create_task_group() as tg:
self._task_group = tg
for task, args in self._tasks_and_args:
tg.start_soon(task, *args)
await to_thread.run_sync(self.__stop.wait)
while True:
task = await to_thread.run_sync(self._tasks.get)
if task is None:
break
tg.start_soon(task)
tg.cancel_scope.cancel()

def stop(self) -> None:
"""Stop the thread.
This method is threadsafe.
"""
self.__stop.set()
self._tasks.put(None)

0 comments on commit 9203727

Please sign in to comment.