Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] make per-cloud catalog lookup parallel #4483

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
20 changes: 14 additions & 6 deletions sky/clouds/service_catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION
from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL
from sky.utils import resources_utils
from sky.utils import rich_utils
from sky.utils import subprocess_utils
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky.clouds import cloud
Expand All @@ -31,8 +34,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
if single:
clouds = [clouds] # type: ignore

results = []
for cloud in clouds:
def _execute_catalog_method(cloud: str):
try:
cloud_module = importlib.import_module(
f'sky.clouds.service_catalog.{cloud.lower()}_catalog')
Expand All @@ -46,7 +48,10 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
raise AttributeError(
f'Module "{cloud}_catalog" does not '
f'implement the "{method_name}" method') from None
results.append(method(*args, **kwargs))
return method(*args, **kwargs)

results = subprocess_utils.run_in_parallel(_execute_catalog_method, clouds,
len(clouds)) # type: ignore
if single:
return results[0]
return results
Expand All @@ -71,9 +76,12 @@ def list_accelerators(
Returns: A dictionary of canonical accelerator names mapped to a list
of instance type offerings. See usage in cli.py.
"""
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
name_filter, region_filter, quantity_filter,
case_sensitive, all_regions, require_price)
with rich_utils.safe_status(
ux_utils.spinner_message('Listing accelerators')):
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
name_filter, region_filter,
quantity_filter, case_sensitive,
all_regions, require_price)
if not isinstance(results, list):
results = [results]
ret: Dict[str,
Expand Down
1 change: 1 addition & 0 deletions sky/clouds/service_catalog/aws_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def get_region_zones_for_instance_type(instance_type: str,
return us_region_list + other_region_list


@timeline.event
def list_accelerators(
gpus_only: bool,
name_filter: Optional[str],
Expand Down
9 changes: 6 additions & 3 deletions sky/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,9 +1293,12 @@ def _fill_in_launchable_resources(
if resources.cloud is not None else enabled_clouds)
# If clouds provide hints, store them for later printing.
hints: Dict[clouds.Cloud, str] = {}
for cloud in clouds_list:
feasible_resources = cloud.get_feasible_launchable_resources(
resources, num_nodes=task.num_nodes)

feasible_list = subprocess_utils.run_in_parallel(
lambda cloud, r=resources, n=task.num_nodes:
(cloud, cloud.get_feasible_launchable_resources(r, n)),
clouds_list)
for cloud, feasible_resources in feasible_list:
Comment on lines +1297 to +1301
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very minor nit: `run_in_parallel` preserves the order of its arguments, so there is no need to include `cloud` in the output; `zip(clouds_list, feasible_list)` should be fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Even knowing this, I would personally still prefer not to depend on the ordering behavior of `run_in_parallel`, for maintainability. What do you think?

if feasible_resources.hint is not None:
hints[cloud] = feasible_resources.hint
if len(feasible_resources.resources_list) > 0:
Expand Down
79 changes: 70 additions & 9 deletions sky/utils/rich_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
"""Rich status spinner utils."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great! However, since this can affect many other parts of the code, let's move it to a separate PR and keep this PR limited to the parallelism changes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but since the status of a sub-thread is silent on master, there would be a UX degradation in the interim. Is that okay?

import contextlib
import threading
from typing import Union
from typing import Dict, Optional, Union

import rich.console as rich_console

console = rich_console.Console(soft_wrap=True)
_status = None
_status_nesting_level = 0
_main_message = None

_logging_lock = threading.RLock()

# Track sub thread progress statuses
_thread_statuses: Dict[int, Optional[str]] = {}
_status_lock = threading.RLock()


class _NoOpConsoleStatus:
"""An empty class for multi-threaded console.status."""
Expand All @@ -35,15 +40,17 @@ class _RevertibleStatus:
"""A wrapper for status that can revert to previous message after exit."""

def __init__(self, message: str):
if _status is not None:
self.previous_message = _status.status
if _main_message is not None:
self.previous_message = _main_message
else:
self.previous_message = None
self.message = message

def __enter__(self):
global _status_nesting_level
_status.update(self.message)
global _main_message
_main_message = self.message
refresh()
_status_nesting_level += 1
_status.__enter__()
return _status
Expand All @@ -57,10 +64,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
_status.__exit__(exc_type, exc_val, exc_tb)
_status = None
else:
_status.update(self.previous_message)
global _main_message
_main_message = self.previous_message
refresh()

def update(self, *args, **kwargs):
_status.update(*args, **kwargs)
global _main_message
_main_message = _status.status
refresh()

def stop(self):
_status.stop()
Expand All @@ -69,16 +81,65 @@ def start(self):
_status.start()


def safe_status(msg: str) -> Union['rich_console.Status', _NoOpConsoleStatus]:
class _ThreadStatus:
    """Status handle for the calling (non-main) thread.

    Each instance records a message for the thread that created it in the
    module-level ``_thread_statuses`` dict, keyed by thread identifier, and
    calls ``refresh()`` after every change so the displayed status is
    re-rendered.

    Every read and write of ``_thread_statuses`` is done under
    ``_status_lock``: ``refresh()`` iterates the dict under that same lock,
    and an unguarded insert/delete from another thread could otherwise
    raise ``RuntimeError: dictionary changed size during iteration``.
    """

    def __init__(self, message: str):
        self.thread_id = threading.get_ident()
        self.message = message
        # Remember what this thread was showing before, so a nested status
        # can restore it on exit.
        with _status_lock:
            self.previous_message = _thread_statuses.get(self.thread_id)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        with _status_lock:
            if self.previous_message is not None:
                _thread_statuses[self.thread_id] = self.previous_message
            else:
                # No previous message: drop this thread's entry entirely.
                _thread_statuses.pop(self.thread_id, None)
        refresh()

    def update(self, new_message: str):
        """Replace this thread's message and re-render the status."""
        self.message = new_message
        self._publish(new_message)

    def stop(self):
        """Hide this thread's message (the entry is kept, set to None)."""
        self._publish(None)

    def start(self):
        """Show this thread's current message."""
        self._publish(self.message)

    def _publish(self, value):
        """Store `value` for this thread under the lock, then refresh."""
        with _status_lock:
            _thread_statuses[self.thread_id] = value
        refresh()


def refresh():
    """Refresh status to include all thread statuses.

    Builds the displayed text from the main message followed by one
    '└─' line per non-None entry in ``_thread_statuses`` and pushes it to
    the active status object. Does nothing when there is no active status
    or no main message.
    """
    with _status_lock:
        # Snapshot the globals once, inside the lock: another thread may
        # reset `_status` to None between a check and a later use.
        status = _status
        main_message = _main_message
        if status is None or main_message is None:
            return
        # Iterate over a copy so that a thread adding or removing its entry
        # cannot invalidate the iteration.
        sub_messages = [
            m for m in list(_thread_statuses.values()) if m is not None
        ]
        msg = main_message
        if sub_messages:
            msg = msg + ''.join(f'\n └─ {m}' for m in sub_messages)
        status.update(msg)


def safe_status(msg: str) -> Union['rich_console.Status', '_NoOpConsoleStatus']:
"""A wrapper for multi-threaded console.status."""
from sky import sky_logging # pylint: disable=import-outside-toplevel
global _status
if (threading.current_thread() is threading.main_thread() and
not sky_logging.is_silent()):
if sky_logging.is_silent():
return _NoOpConsoleStatus()
if threading.current_thread() is threading.main_thread():
if _status is None:
_status = console.status(msg, refresh_per_second=8)
return _RevertibleStatus(msg)
return _NoOpConsoleStatus()
else:
return _ThreadStatus(msg)


def stop_safe_status():
Expand Down
7 changes: 7 additions & 0 deletions sky/utils/subprocess_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utility functions for subprocesses."""
import collections
from multiprocessing import pool
import os
import random
Expand Down Expand Up @@ -113,6 +114,12 @@ def run_in_parallel(func: Callable,
A list of the return values of the function func, in the same order as the
arguments.
"""
if isinstance(args, collections.abc.Sized):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: how about we change the type of the input `args` to `List[Any]` and update the call sites to pass lists, so that this `collections.abc.Sized` check is no longer needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I will try this~

if len(args) == 0:
return []
# Short-circuit for single element
if len(args) == 1:
return [func(next(iter(args)))]
# Reference: https://stackoverflow.com/questions/25790279/python-multiprocessing-early-termination # pylint: disable=line-too-long
processes = num_threads if num_threads is not None else get_parallel_threads(
)
Expand Down
Loading