Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
weih1121 committed Jan 2, 2025
1 parent 3c3c42c commit c7ba4f8
Showing 1 changed file with 13 additions and 28 deletions.
41 changes: 13 additions & 28 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,33 +686,6 @@ def get_expirable_clouds(
return expirable_clouds


@functools.lru_cache(maxsize=1)
def get_remote_identity(cloud: Optional[clouds.Cloud],
cluster_name: str) -> str:
"""Retrieves the remote identity for a given cloud and cluster name.
Args:
cloud: The cloud object or None.
cluster_name: The name of the cluster.
Returns:
The remote identity as a string.
"""
remote_identity_config = skypilot_config.get_nested(
(str(cloud).lower(), 'remote_identity'), None)
remote_identity = schemas.get_default_remote_identity(str(cloud).lower())
if isinstance(remote_identity_config, str):
remote_identity = remote_identity_config
if isinstance(remote_identity_config, list):
# Some clouds (e.g., AWS) support specifying multiple service accounts
# chosen based on the cluster name. Do the matching here to pick the
# correct one.
for profile in remote_identity_config:
if fnmatch.fnmatchcase(cluster_name, list(profile.keys())[0]):
return list(profile.values())[0]
return remote_identity


# TODO: too many things happening here - leaky abstraction. Refactor.
@timeline.event
def write_cluster_config(
Expand Down Expand Up @@ -792,7 +765,19 @@ def write_cluster_config(
assert cluster_name is not None
excluded_clouds = set()

remote_identity = get_remote_identity(cloud, cluster_name)
remote_identity_config = skypilot_config.get_nested(
(str(cloud).lower(), 'remote_identity'), None)
remote_identity = schemas.get_default_remote_identity(str(cloud).lower())
if isinstance(remote_identity_config, str):
remote_identity = remote_identity_config
if isinstance(remote_identity_config, list):
# Some clouds (e.g., AWS) support specifying multiple service accounts
# chosen based on the cluster name. Do the matching here to pick the
# correct one.
for profile in remote_identity_config:
if fnmatch.fnmatchcase(cluster_name, list(profile.keys())[0]):
remote_identity = list(profile.values())[0]
break
if remote_identity != schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value:
# If LOCAL_CREDENTIALS is not specified, we add the cloud to the
# excluded_clouds set, but we must also check if the cloud supports
Expand Down

0 comments on commit c7ba4f8

Please sign in to comment.