Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
weih1121 committed Jan 2, 2025
1 parent 613b355 commit 3c3c42c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 15 deletions.
37 changes: 36 additions & 1 deletion sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,42 @@ def _restore_block(new_block: Dict[str, Any], old_block: Dict[str, Any]):
return common_utils.dump_yaml_str(new_config)


def get_expirable_clouds(
        enabled_clouds: Sequence[clouds.Cloud]) -> List[clouds.Cloud]:
    """Returns the clouds that use local credentials which can expire.

    A cloud is included when both of the following hold:

      1. Its `remote_identity` setting selects local credentials. The
         setting is read from the SkyPilot config and falls back to
         `schemas.get_default_remote_identity()` when unset. It may be a
         plain string, or a list of mappings in which the first value of
         any entry selects local credentials.
      2. `cloud.can_credential_expire()` returns a truthy value.

    Args:
        enabled_clouds: The cloud objects to check.

    Returns:
        The matching clouds, in the same order as `enabled_clouds`.
    """
    local_credentials_value = (
        schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value)
    expirable_clouds = []
    for cloud in enabled_clouds:
        # The config is keyed by the lower-cased string form of the cloud.
        cloud_key = str(cloud).lower()
        remote_identities = skypilot_config.get_nested(
            (cloud_key, 'remote_identity'), None)
        if remote_identities is None:
            remote_identities = schemas.get_default_remote_identity(cloud_key)

        if isinstance(remote_identities, str):
            uses_local_credentials = (
                remote_identities == local_credentials_value)
        elif isinstance(remote_identities, list):
            # Only the first value of each mapping is inspected.
            # `next(..., None)` tolerates an empty mapping instead of
            # raising IndexError.
            # NOTE(review): presumably each entry is a single-key mapping
            # of cluster-name pattern -> identity; confirm against schema.
            uses_local_credentials = any(
                next(iter(profile.values()), None) == local_credentials_value
                for profile in remote_identities)
        else:
            # Any other type never matches local credentials.
            uses_local_credentials = False

        # Short-circuit: the cloud is only asked whether its credentials
        # can expire when it is actually configured for local credentials.
        if uses_local_credentials and cloud.can_credential_expire():
            expirable_clouds.append(cloud)
    return expirable_clouds


@functools.lru_cache(maxsize=1)
def get_remote_identity(cloud: Optional[clouds.Cloud],
cluster_name: str) -> str:
Expand Down Expand Up @@ -787,7 +823,6 @@ def write_cluster_config(
if (remote_identity_config ==
schemas.RemoteIdentityOptions.NO_UPLOAD.value):
excluded_clouds.add(cloud_obj)

credentials = sky_check.get_cloud_credential_file_mounts(excluded_clouds)

auth_config = {'ssh_private_key': auth.PRIVATE_SSH_KEY_PATH}
Expand Down
29 changes: 15 additions & 14 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import sky
from sky import backends
from sky import check as sky_check
from sky import cloud_stores
from sky import clouds
from sky import exceptions
Expand Down Expand Up @@ -62,7 +63,6 @@
from sky.utils import log_utils
from sky.utils import resources_utils
from sky.utils import rich_utils
from sky.utils import schemas
from sky.utils import subprocess_utils
from sky.utils import timeline
from sky.utils import ux_utils
Expand Down Expand Up @@ -1998,20 +1998,21 @@ def provision_with_retries(
skip_unnecessary_provisioning else None)

failover_history: List[Exception] = list()

cloud = to_provision.cloud
# If the jobs controller/server uses local credentials that can expire,
# clusters may be leaked once those credentials lapse. Check the enabled
# clouds for expiring credentials and warn the user to switch to
# credentials that never expire or to a service account.
if task.is_controller_task():
remote_identity = backend_utils.get_remote_identity(
cloud, cluster_name)
local_credentials_value = schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value # pylint: disable=line-too-long
use_local_cred = remote_identity == local_credentials_value
expirable_cred = to_provision.cloud.can_credential_expire()
if use_local_cred and expirable_cred:
warnings = (
f'\nWarning: Expiring credentials detected for {cloud}.'
'Clusters may be leaked if the credentials expire while '
'jobs are running. It is recommended to use credentials '
'that never expire or a service account.')
enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh()
expirable_clouds = backend_utils.get_expirable_clouds(
enabled_clouds)

if len(expirable_clouds) > 0:
warnings = (f'\nWarning: Expiring credentials detected for '
f'{expirable_clouds}. Clusters may be leaked if '
f'the credentials expire while jobs are running. '
f'It is recommended to use credentials that never'
f' expire or a service account.')
click.secho(warnings, fg='yellow')

# Retrying launchable resources.
Expand Down

0 comments on commit 3c3c42c

Please sign in to comment.