diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 13b9bb2dbce..19fdef0bfad 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -686,33 +686,6 @@ def get_expirable_clouds( return expirable_clouds -@functools.lru_cache(maxsize=1) -def get_remote_identity(cloud: Optional[clouds.Cloud], - cluster_name: str) -> str: - """Retrieves the remote identity for a given cloud and cluster name. - - Args: - cloud: The cloud object or None. - cluster_name: The name of the cluster. - - Returns: - The remote identity as a string. - """ - remote_identity_config = skypilot_config.get_nested( - (str(cloud).lower(), 'remote_identity'), None) - remote_identity = schemas.get_default_remote_identity(str(cloud).lower()) - if isinstance(remote_identity_config, str): - remote_identity = remote_identity_config - if isinstance(remote_identity_config, list): - # Some clouds (e.g., AWS) support specifying multiple service accounts - # chosen based on the cluster name. Do the matching here to pick the - # correct one. - for profile in remote_identity_config: - if fnmatch.fnmatchcase(cluster_name, list(profile.keys())[0]): - return list(profile.values())[0] - return remote_identity - - # TODO: too many things happening here - leaky abstraction. Refactor. @timeline.event def write_cluster_config( @@ -792,7 +765,19 @@ def write_cluster_config( assert cluster_name is not None excluded_clouds = set() - remote_identity = get_remote_identity(cloud, cluster_name) + remote_identity_config = skypilot_config.get_nested( + (str(cloud).lower(), 'remote_identity'), None) + remote_identity = schemas.get_default_remote_identity(str(cloud).lower()) + if isinstance(remote_identity_config, str): + remote_identity = remote_identity_config + if isinstance(remote_identity_config, list): + # Some clouds (e.g., AWS) support specifying multiple service accounts + # chosen based on the cluster name. Do the matching here to pick the + # correct one. + for profile in remote_identity_config: + if fnmatch.fnmatchcase(cluster_name, list(profile.keys())[0]): + remote_identity = list(profile.values())[0] + break if remote_identity != schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value: # If LOCAL_CREDENTIALS is not specified, we add the cloud to the # excluded_clouds set, but we must also check if the cloud supports