diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 8974a0129bd..99726719f1d 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -3309,7 +3309,7 @@ def error_message() -> str: # even if some of them raise exceptions. We should replace it with # multi-process. rich_utils.stop_safe_status() - subprocess_utils.run_in_parallel(_setup_node, range(num_nodes)) + subprocess_utils.run_in_parallel(_setup_node, list(range(num_nodes))) if detach_setup: # Only set this when setup needs to be run outside the self._setup() diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index d28b530ff06..3aad5a0b7fd 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -10,6 +10,7 @@ from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL from sky.utils import resources_utils +from sky.utils import subprocess_utils if typing.TYPE_CHECKING: from sky.clouds import cloud @@ -31,8 +32,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs): if single: clouds = [clouds] # type: ignore - results = [] - for cloud in clouds: + def _execute_catalog_method(cloud: str): try: cloud_module = importlib.import_module( f'sky.clouds.service_catalog.{cloud.lower()}_catalog') @@ -46,7 +46,11 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs): raise AttributeError( f'Module "{cloud}_catalog" does not ' f'implement the "{method_name}" method') from None - results.append(method(*args, **kwargs)) + return method(*args, **kwargs) + + results = subprocess_utils.run_in_parallel(_execute_catalog_method, + args=list(clouds), + num_threads=len(clouds)) if single: return results[0] return results diff --git a/sky/optimizer.py b/sky/optimizer.py index 2f70dd39429..d22029f1dc9 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -1293,9 +1293,12 @@ def _fill_in_launchable_resources( if resources.cloud is not None else enabled_clouds) # If clouds provide hints, store them for later printing. hints: Dict[clouds.Cloud, str] = {} - for cloud in clouds_list: - feasible_resources = cloud.get_feasible_launchable_resources( - resources, num_nodes=task.num_nodes) + + feasible_list = subprocess_utils.run_in_parallel( + lambda cloud, r=resources, n=task.num_nodes: + (cloud, cloud.get_feasible_launchable_resources(r, n)), + clouds_list) + for cloud, feasible_resources in feasible_list: if feasible_resources.hint is not None: hints[cloud] = feasible_resources.hint if len(feasible_resources.resources_list) > 0: diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index 229d7361e22..4e461375a14 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -343,7 +343,7 @@ def create_single_instance(vm_i): _create_vm(compute_client, vm_name, node_tags, provider_config, node_config, network_interface.id) - subprocess_utils.run_in_parallel(create_single_instance, range(count)) + subprocess_utils.run_in_parallel(create_single_instance, list(range(count))) # Update disk performance tier performance_tier = node_config.get('disk_performance_tier', None) diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index c431b023ab9..a849dfc3044 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -804,7 +804,8 @@ def _create_pod_thread(i: int): # Create pods in parallel pods = subprocess_utils.run_in_parallel(_create_pod_thread, - range(to_start_count), _NUM_THREADS) + list(range(to_start_count)), + _NUM_THREADS) # Process created pods for pod in pods: diff --git a/sky/utils/accelerator_registry.py b/sky/utils/accelerator_registry.py index 78a708efb91..11dd5280ac4 100644 --- a/sky/utils/accelerator_registry.py +++ b/sky/utils/accelerator_registry.py @@ -3,6 +3,7 @@ from typing import Optional from sky.clouds import service_catalog +from sky.utils import rich_utils from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -88,14 +89,17 @@ def canonicalize_accelerator_name(accelerator: str, if accelerator.lower() in mapping: return mapping[accelerator.lower()] - # _ACCELERATORS may not be comprehensive. - # Users may manually add new accelerators to the catalogs, or download new - # catalogs (that have new accelerators) without upgrading SkyPilot. - # To cover such cases, we should search the accelerator name - # in the service catalog. - searched = service_catalog.list_accelerators(name_filter=accelerator, - case_sensitive=False, - clouds=cloud_str) + # Listing accelerators can be time-consuming since canonicalizing usually + # involves catalog reading with cache not warmed up. + with rich_utils.safe_status('Listing accelerators...'): + # _ACCELERATORS may not be comprehensive. + # Users may manually add new accelerators to the catalogs, or download + # new catalogs (that have new accelerators) without upgrading SkyPilot. + # To cover such cases, we should search the accelerator name + # in the service catalog. + searched = service_catalog.list_accelerators(name_filter=accelerator, + case_sensitive=False, + clouds=cloud_str) names = list(searched.keys()) # Exact match. diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index 992c6bbe3ff..88d351632a3 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -5,7 +5,7 @@ import resource import subprocess import time -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import colorama import psutil @@ -97,7 +97,7 @@ def get_parallel_threads(cloud_str: Optional[str] = None) -> int: def run_in_parallel(func: Callable, - args: Iterable[Any], + args: List[Any], num_threads: Optional[int] = None) -> List[Any]: """Run a function in parallel on a list of arguments. @@ -113,6 +113,11 @@ def run_in_parallel(func: Callable, A list of the return values of the function func, in the same order as the arguments. """ + if len(args) == 0: + return [] + # Short-circuit for single element + if len(args) == 1: + return [func(args[0])] # Reference: https://stackoverflow.com/questions/25790279/python-multiprocessing-early-termination # pylint: disable=line-too-long processes = num_threads if num_threads is not None else get_parallel_threads( )