Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle resizing in pty-shell #2803

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions modal/_utils/shell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
import asyncio
import contextlib
import errno
import fcntl
import os
import select
import signal
import struct
import sys
import termios
import threading
from collections.abc import Coroutine
from queue import Empty, Queue
from types import FrameType
from typing import Callable, Optional

from modal._pty import raw_terminal, set_nonblocking
Expand Down Expand Up @@ -77,3 +84,54 @@ async def _write():
yield
os.write(quit_pipe_write, b"\n")
write_task.cancel()


class WindowSizeHandler:
"""Handles terminal window resize events."""

def __init__(self):
"""Initialize window size handler. Must be called from the main thread to set signals properly.
In case this is invoked from a thread that is not the main thread, e.g. in tests, the context manager
becomes a no-op."""
self._is_main_thread = threading.current_thread() is threading.main_thread()
self._event_queue: Queue[tuple[int, int]] = Queue()

if self._is_main_thread and hasattr(signal, "SIGWINCH"):
signal.signal(signal.SIGWINCH, self._queue_resize_event)

def _queue_resize_event(self, signum: Optional[int] = None, frame: Optional[FrameType] = None) -> None:
"""Signal handler for SIGWINCH that queues events."""
try:
hw = struct.unpack("hh", fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"1234"))
rows, cols = hw
self._event_queue.put((rows, cols))
except Exception:
# ignore failed window size reads
pass

@contextlib.asynccontextmanager
async def watch_window_size(self, handler: Callable[[int, int], Coroutine]):
"""Context manager that processes window resize events from the queue.
Can be run from any thread. If the window manager was initialized from a thread that is not the main thread,
e.g. in tests, this context manager is a no-op.

Args:
handler: Callback function to handle window resize events
"""
if not self._is_main_thread:
yield
return

async def process_events():
while True:
try:
rows, cols = self._event_queue.get_nowait()
await handler(rows, cols)
except Empty:
await asyncio.sleep(0.1)

event_task = asyncio.create_task(process_events())
try:
yield
finally:
event_task.cancel()
4 changes: 3 additions & 1 deletion modal/cli/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from modal._pty import get_pty_info
from modal._utils.async_utils import synchronizer
from modal._utils.grpc_utils import retry_transient_errors
from modal._utils.shell_utils import WindowSizeHandler
from modal.cli.utils import ENV_OPTION, display_table, is_tty, stream_app_logs, timestamp_to_local
from modal.client import _Client
from modal.config import config
Expand Down Expand Up @@ -79,7 +80,8 @@ async def exec(
res: api_pb2.ContainerExecResponse = await client.stub.ContainerExec(req)

if pty:
await _ContainerProcess(res.exec_id, client).attach()
window_size_handler = WindowSizeHandler()
await _ContainerProcess(res.exec_id, client).attach(window_size_handler=window_size_handler)
else:
# TODO: redirect stderr to its own stream?
await _ContainerProcess(res.exec_id, client, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT).wait()
Expand Down
5 changes: 5 additions & 0 deletions modal/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import typer
from typing_extensions import TypedDict

from .._utils.shell_utils import WindowSizeHandler
from ..app import App, LocalEntrypoint
from ..config import config
from ..environments import ensure_env
Expand Down Expand Up @@ -461,6 +462,8 @@ def shell(
if pty is None:
pty = is_tty()

window_size_handler = WindowSizeHandler()

if platform.system() == "Windows":
raise InvalidError("`modal shell` is currently not supported on Windows")

Expand Down Expand Up @@ -503,6 +506,7 @@ def shell(
volumes=function_spec.volumes,
region=function_spec.scheduler_placement.proto.regions if function_spec.scheduler_placement else None,
pty=pty,
window_size_handler=window_size_handler,
proxy=function_spec.proxy,
)
else:
Expand All @@ -518,6 +522,7 @@ def shell(
volumes=volumes,
region=region.split(",") if region else [],
pty=pty,
window_size_handler=window_size_handler,
)

# NB: invoking under bash makes --cmd a lot more flexible.
Expand Down
23 changes: 18 additions & 5 deletions modal/container_process.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Copyright Modal Labs 2024
import asyncio
import platform
import struct
from typing import Generic, Optional, TypeVar

from modal_proto import api_pb2

from ._utils.async_utils import TaskContext, synchronize_api
from ._utils.deprecation import deprecation_error
from ._utils.grpc_utils import retry_transient_errors
from ._utils.shell_utils import stream_from_stdin, write_to_fd
from ._utils.shell_utils import WindowSizeHandler, stream_from_stdin, write_to_fd
from .client import _Client
from .exception import InteractiveTimeoutError, InvalidError
from .io_streams import _StreamReader, _StreamWriter
Expand Down Expand Up @@ -115,7 +116,7 @@ async def wait(self) -> int:
self._returncode = resp.exit_code
return self._returncode

async def attach(self, *, pty: Optional[bool] = None):
async def attach(self, *, window_size_handler: WindowSizeHandler, pty: Optional[bool] = None):
if platform.system() == "Windows":
print("interactive exec is not currently supported on Windows.")
return
Expand Down Expand Up @@ -151,6 +152,17 @@ async def _handle_input(data: bytes, message_index: int):
self.stdin.write(data)
await self.stdin.drain()

async def _send_window_resize(rows: int, cols: int):
# create resize sequence:
# - magic byte 0xC1 to identify the resize sequence
# - 2 bytes for the number of rows (big-endian)
# - 2 bytes for the number of columns (big-endian)
magic = bytes([0xC1])
dims = struct.pack(">HH", rows, cols)
resize_data = magic + dims
self.stdin.write(resize_data)
await self.stdin.drain()

async with TaskContext() as tc:
stdout_task = tc.create_task(_write_to_fd_loop(self.stdout))
stderr_task = tc.create_task(_write_to_fd_loop(self.stderr))
Expand All @@ -159,9 +171,10 @@ async def _handle_input(data: bytes, message_index: int):
# time out if we can't connect to the server fast enough
await asyncio.wait_for(on_connect.wait(), timeout=60)

async with stream_from_stdin(_handle_input, use_raw_terminal=True):
await stdout_task
await stderr_task
async with window_size_handler.watch_window_size(_send_window_resize):
async with stream_from_stdin(_handle_input, use_raw_terminal=True):
await stdout_task
await stderr_task

# TODO: this doesn't work right now.
# if exit_status != 0:
Expand Down
11 changes: 9 additions & 2 deletions modal/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ._utils.deprecation import deprecation_error
from ._utils.grpc_utils import retry_transient_errors
from ._utils.name_utils import check_object_name, is_valid_tag
from ._utils.shell_utils import WindowSizeHandler
from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
from .cls import _Cls
from .config import config, logger
Expand Down Expand Up @@ -563,7 +564,12 @@ def heartbeat():


async def _interactive_shell(
_app: _App, cmds: list[str], environment_name: str = "", pty: bool = True, **kwargs: Any
_app: _App,
cmds: list[str],
environment_name: str = "",
pty: bool = True,
window_size_handler: Optional[WindowSizeHandler] = None,
**kwargs: Any,
) -> None:
"""Run an interactive shell (like `bash`) within the image for this app.

Expand Down Expand Up @@ -611,7 +617,8 @@ async def _interactive_shell(
container_process = await sandbox.exec(
*sandbox_cmds, pty_info=get_pty_info(shell=True) if pty else None
)
await container_process.attach()
assert window_size_handler is not None, "window_size_handler must be provided when pty is True"
await container_process.attach(window_size_handler=window_size_handler)
else:
container_process = await sandbox.exec(
*sandbox_cmds, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT
Expand Down
Loading