Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] make per-cloud catalog lookup parallel #4483

Merged
merged 12 commits into from
Jan 13, 2025
2 changes: 1 addition & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3309,7 +3309,7 @@ def error_message() -> str:
# even if some of them raise exceptions. We should replace it with
# multi-process.
rich_utils.stop_safe_status()
subprocess_utils.run_in_parallel(_setup_node, range(num_nodes))
subprocess_utils.run_in_parallel(_setup_node, list(range(num_nodes)))

if detach_setup:
# Only set this when setup needs to be run outside the self._setup()
Expand Down
10 changes: 7 additions & 3 deletions sky/clouds/service_catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION
from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL
from sky.utils import resources_utils
from sky.utils import subprocess_utils

if typing.TYPE_CHECKING:
from sky.clouds import cloud
Expand All @@ -31,8 +32,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
if single:
clouds = [clouds] # type: ignore

results = []
for cloud in clouds:
def _execute_catalog_method(cloud: str):
try:
cloud_module = importlib.import_module(
f'sky.clouds.service_catalog.{cloud.lower()}_catalog')
Expand All @@ -46,7 +46,11 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
raise AttributeError(
f'Module "{cloud}_catalog" does not '
f'implement the "{method_name}" method') from None
results.append(method(*args, **kwargs))
return method(*args, **kwargs)

results = subprocess_utils.run_in_parallel(_execute_catalog_method,
args=list(clouds),
num_threads=len(clouds))
if single:
return results[0]
return results
Expand Down
9 changes: 6 additions & 3 deletions sky/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,9 +1293,12 @@ def _fill_in_launchable_resources(
if resources.cloud is not None else enabled_clouds)
# If clouds provide hints, store them for later printing.
hints: Dict[clouds.Cloud, str] = {}
for cloud in clouds_list:
feasible_resources = cloud.get_feasible_launchable_resources(
resources, num_nodes=task.num_nodes)

feasible_list = subprocess_utils.run_in_parallel(
lambda cloud, r=resources, n=task.num_nodes:
(cloud, cloud.get_feasible_launchable_resources(r, n)),
clouds_list)
for cloud, feasible_resources in feasible_list:
Comment on lines +1297 to +1301
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very minor nit: `run_in_parallel` keeps the original order of the arguments, so there is no need to include `cloud` in the output; `zip(clouds_list, feasible_list)` should be fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Even with this new knowledge, I would personally still prefer not to depend on the ordering behavior of the `run_in_parallel` function, for maintainability. WDYT?

if feasible_resources.hint is not None:
hints[cloud] = feasible_resources.hint
if len(feasible_resources.resources_list) > 0:
Expand Down
2 changes: 1 addition & 1 deletion sky/provision/azure/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def create_single_instance(vm_i):
_create_vm(compute_client, vm_name, node_tags, provider_config,
node_config, network_interface.id)

subprocess_utils.run_in_parallel(create_single_instance, range(count))
subprocess_utils.run_in_parallel(create_single_instance, list(range(count)))

# Update disk performance tier
performance_tier = node_config.get('disk_performance_tier', None)
Expand Down
3 changes: 2 additions & 1 deletion sky/provision/kubernetes/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,8 @@ def _create_pod_thread(i: int):

# Create pods in parallel
pods = subprocess_utils.run_in_parallel(_create_pod_thread,
range(to_start_count), _NUM_THREADS)
list(range(to_start_count)),
_NUM_THREADS)

# Process created pods
for pod in pods:
Expand Down
20 changes: 12 additions & 8 deletions sky/utils/accelerator_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

from sky.clouds import service_catalog
from sky.utils import rich_utils
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -88,14 +89,17 @@ def canonicalize_accelerator_name(accelerator: str,
if accelerator.lower() in mapping:
return mapping[accelerator.lower()]

# _ACCELERATORS may not be comprehensive.
# Users may manually add new accelerators to the catalogs, or download new
# catalogs (that have new accelerators) without upgrading SkyPilot.
# To cover such cases, we should search the accelerator name
# in the service catalog.
searched = service_catalog.list_accelerators(name_filter=accelerator,
case_sensitive=False,
clouds=cloud_str)
# Listing accelerators can be time-consuming since canonicalizing usually
# involves catalog reading with cache not warmed up.
with rich_utils.safe_status('Listing accelerators...'):
# _ACCELERATORS may not be comprehensive.
# Users may manually add new accelerators to the catalogs, or download
# new catalogs (that have new accelerators) without upgrading SkyPilot.
# To cover such cases, we should search the accelerator name
# in the service catalog.
searched = service_catalog.list_accelerators(name_filter=accelerator,
case_sensitive=False,
clouds=cloud_str)
names = list(searched.keys())

# Exact match.
Expand Down
9 changes: 7 additions & 2 deletions sky/utils/subprocess_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import resource
import subprocess
import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import colorama
import psutil
Expand Down Expand Up @@ -97,7 +97,7 @@ def get_parallel_threads(cloud_str: Optional[str] = None) -> int:


def run_in_parallel(func: Callable,
args: Iterable[Any],
args: List[Any],
num_threads: Optional[int] = None) -> List[Any]:
"""Run a function in parallel on a list of arguments.

Expand All @@ -113,6 +113,11 @@ def run_in_parallel(func: Callable,
A list of the return values of the function func, in the same order as the
arguments.
"""
if len(args) == 0:
return []
# Short-circuit for single element
if len(args) == 1:
return [func(args[0])]
# Reference: https://stackoverflow.com/questions/25790279/python-multiprocessing-early-termination # pylint: disable=line-too-long
processes = num_threads if num_threads is not None else get_parallel_threads(
)
Expand Down
Loading