Skip to content

Commit

Permalink
[UX]: support sub thread status attaching
Browse files Browse the repository at this point in the history
Signed-off-by: Aylei <[email protected]>
  • Loading branch information
aylei committed Dec 20, 2024
1 parent a73a9cb commit 8ece27c
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 12 deletions.
12 changes: 9 additions & 3 deletions sky/clouds/service_catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION
from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL
from sky.utils import resources_utils
from sky.utils import rich_utils
from sky.utils import subprocess_utils
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky.clouds import cloud
Expand Down Expand Up @@ -71,9 +74,12 @@ def list_accelerators(
Returns: A dictionary of canonical accelerator names mapped to a list
of instance type offerings. See usage in cli.py.
"""
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
name_filter, region_filter, quantity_filter,
case_sensitive, all_regions, require_price)
with rich_utils.safe_status(
ux_utils.spinner_message('Listing accelerators')):
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
name_filter, region_filter,
quantity_filter, case_sensitive,
all_regions, require_price)
if not isinstance(results, list):
results = [results]
ret: Dict[str,
Expand Down
79 changes: 70 additions & 9 deletions sky/utils/rich_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
"""Rich status spinner utils."""
import contextlib
import threading
from typing import Union
from typing import Dict, Optional, Union

import rich.console as rich_console

console = rich_console.Console(soft_wrap=True)
_status = None
_status_nesting_level = 0
_main_message = None

_logging_lock = threading.RLock()

# Track sub thread progress statuses
_thread_statuses: Dict[int, Optional[str]] = {}
_status_lock = threading.RLock()


class _NoOpConsoleStatus:
"""An empty class for multi-threaded console.status."""
Expand All @@ -35,15 +40,17 @@ class _RevertibleStatus:
"""A wrapper for status that can revert to previous message after exit."""

def __init__(self, message: str):
if _status is not None:
self.previous_message = _status.status
if _main_message is not None:
self.previous_message = _main_message
else:
self.previous_message = None
self.message = message

def __enter__(self):
global _status_nesting_level
_status.update(self.message)
global _main_message
_main_message = self.message
refresh()
_status_nesting_level += 1
_status.__enter__()
return _status
Expand All @@ -57,10 +64,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
_status.__exit__(exc_type, exc_val, exc_tb)
_status = None
else:
_status.update(self.previous_message)
global _main_message
_main_message = self.previous_message
refresh()

def update(self, *args, **kwargs):
_status.update(*args, **kwargs)
global _main_message
_main_message = _status.status
refresh()

def stop(self):
_status.stop()
Expand All @@ -69,16 +81,65 @@ def start(self):
_status.start()


def safe_status(msg: str) -> Union['rich_console.Status', _NoOpConsoleStatus]:
class _ThreadStatus:
"""A wrapper of sub thread status"""

def __init__(self, message: str):
self.thread_id = threading.get_ident()
self.message = message
self.previous_message = _thread_statuses.get(self.thread_id)

def __enter__(self):
self.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.previous_message is not None:
_thread_statuses[self.thread_id] = self.previous_message
else:
# No previous message, remove the thread status
if self.thread_id in _thread_statuses:
del _thread_statuses[self.thread_id]
refresh()

def update(self, new_message: str):
self.message = new_message
_thread_statuses[self.thread_id] = new_message
refresh()

def stop(self):
_thread_statuses[self.thread_id] = None
refresh()

def start(self):
_thread_statuses[self.thread_id] = self.message
refresh()


def refresh():
"""Refresh status to include all thread statuses."""
if _status is None or _main_message is None:
return
with _status_lock:
msg = _main_message
for v in _thread_statuses.values():
if v is not None:
msg = msg + f'\n └─ {v}'
_status.update(msg)


def safe_status(msg: str) -> Union['rich_console.Status', '_NoOpConsoleStatus']:
"""A wrapper for multi-threaded console.status."""
from sky import sky_logging # pylint: disable=import-outside-toplevel
global _status
if (threading.current_thread() is threading.main_thread() and
not sky_logging.is_silent()):
if sky_logging.is_silent():
return _NoOpConsoleStatus()
if threading.current_thread() is threading.main_thread():
if _status is None:
_status = console.status(msg, refresh_per_second=8)
return _RevertibleStatus(msg)
return _NoOpConsoleStatus()
else:
return _ThreadStatus(msg)


def stop_safe_status():
Expand Down

0 comments on commit 8ece27c

Please sign in to comment.