Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default to ubuntu for GCP and avoid key pair checking #1641

Closed
wants to merge 16 commits into from
Closed
187 changes: 109 additions & 78 deletions sky/authentication.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module to enable a single SkyPilot key for all VMs in each cloud."""
import copy
import functools
import json
import os
import re
import socket
Expand Down Expand Up @@ -132,6 +133,104 @@ def _wait_for_compute_global_operation(project_name: str, operation_name: str,
return result


def _maybe_gcp_add_ssh_key_to_account(compute, project, config: Dict[str, Any],
os_login_enabled: bool):
"""Add ssh key to GCP account if using Debian image without cloud-init.

This function is for backward compatibility. It is only used when the user
is using the old Debian image without cloud-init. In this case, we need to
add the ssh key to the GCP account so that we can ssh into the instance.
"""
private_key_path, public_key_path = get_or_generate_keys()
user = config['auth']['ssh_user']

node_config = config.get('available_node_types',
{}).get('ray_head_default',
cblmemo marked this conversation as resolved.
Show resolved Hide resolved
{}).get('node_config', {})
image_id = node_config.get('disks', [{}])[0].get('initializeParams',
{}).get('sourceImage')
# image_id is None when TPU VM is used, as TPU VM does not use image.
if image_id is not None and 'debian' not in image_id.lower():
image_infos = clouds.GCP.get_image_infos(image_id)
if 'debian' not in json.dumps(image_infos).lower():
# The non-Debian images have the ssh key setup by cloud-init.
return
logger.info('Adding ssh key to GCP account.')
if os_login_enabled:
# Add ssh key to GCP with oslogin
subprocess.run(
'gcloud compute os-login ssh-keys add '
f'--key-file={public_key_path}',
check=True,
shell=True,
stdout=subprocess.DEVNULL)
# Enable ssh port for all the instances
enable_ssh_cmd = ('gcloud compute firewall-rules create '
'allow-ssh-ingress-from-iap '
'--direction=INGRESS '
'--action=allow '
'--rules=tcp:22 '
'--source-ranges=0.0.0.0/0')
proc = subprocess.run(enable_ssh_cmd,
check=False,
shell=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE)
if proc.returncode != 0 and 'already exists' not in proc.stderr.decode(
'utf-8'):
subprocess_utils.handle_returncode(proc.returncode, enable_ssh_cmd,
'Failed to enable ssh port.',
proc.stderr.decode('utf-8'))
return config

# OS Login is not enabled for the project. Add the ssh key directly to the
# metadata.
project_keys: str = next( # type: ignore
(item for item in project['commonInstanceMetadata'].get('items', [])
if item['key'] == 'ssh-keys'), {}).get('value', '')
ssh_keys = project_keys.split('\n') if project_keys else []

# Get public key from file.
with open(public_key_path, 'r') as f:
public_key = f.read()

# Check if ssh key in Google Project's metadata
public_key_token = public_key.split(' ')[1]

key_found = False
for key in ssh_keys:
key_list = key.split(' ')
if len(key_list) != 3:
continue
if user == key_list[-1] and os.path.exists(
private_key_path) and key_list[1] == public_key.split(' ')[1]:
key_found = True

if not key_found:
new_ssh_key = '{user}:ssh-rsa {public_key_token} {user}'.format(
user=user, public_key_token=public_key_token)
metadata = project['commonInstanceMetadata'].get('items', [])

ssh_key_index = [
k for k, v in enumerate(metadata) if v['key'] == 'ssh-keys'
]
assert len(ssh_key_index) <= 1

if len(ssh_key_index) == 0:
metadata.append({'key': 'ssh-keys', 'value': new_ssh_key})
else:
first_ssh_key_index = ssh_key_index[0]
metadata[first_ssh_key_index]['value'] += '\n' + new_ssh_key

project['commonInstanceMetadata']['items'] = metadata

operation = compute.projects().setCommonInstanceMetadata(
project=project['name'],
body=project['commonInstanceMetadata']).execute()
_wait_for_compute_global_operation(project['name'], operation['name'],
compute)


# Snippets of code inspired from
# https://github.com/ray-project/ray/blob/master/python/ray/autoscaler/_private/gcp/config.py
# Takes in config, a yaml dict and outputs a postprocessed dict
Expand All @@ -140,15 +239,16 @@ def _wait_for_compute_global_operation(project_name: str, operation_name: str,
# Retry for the GCP as sometimes there will be connection reset by peer error.
@common_utils.retry
def setup_gcp_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
private_key_path, public_key_path = get_or_generate_keys()
_, public_key_path = get_or_generate_keys()
with open(public_key_path, 'r') as f:
public_key = f.read()
config = copy.deepcopy(config)

project_id = config['provider']['project_id']
compute = gcp.build('compute',
'v1',
credentials=None,
cache_discovery=False)
user = config['auth']['ssh_user']

try:
project = compute.projects().get(project=project_id).execute()
Expand Down Expand Up @@ -191,7 +291,8 @@ def setup_gcp_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
(item for item in project['commonInstanceMetadata'].get('items', [])
if item['key'] == 'enable-oslogin'), {}).get('value', 'False')

if project_oslogin.lower() == 'true':
oslogin_enabled = project_oslogin.lower() == 'true'
if oslogin_enabled:
# project.
logger.info(
f'OS Login is enabled for GCP project {project_id}. Running '
Expand All @@ -218,81 +319,11 @@ def setup_gcp_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
'account information.')
config['auth']['ssh_user'] = account.replace('@', '_').replace('.', '_')

# Add ssh key to GCP with oslogin
subprocess.run(
'gcloud compute os-login ssh-keys add '
f'--key-file={public_key_path}',
check=True,
shell=True,
stdout=subprocess.DEVNULL)
# Enable ssh port for all the instances
enable_ssh_cmd = ('gcloud compute firewall-rules create '
'allow-ssh-ingress-from-iap '
'--direction=INGRESS '
'--action=allow '
'--rules=tcp:22 '
'--source-ranges=0.0.0.0/0')
proc = subprocess.run(enable_ssh_cmd,
check=False,
shell=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE)
if proc.returncode != 0 and 'already exists' not in proc.stderr.decode(
'utf-8'):
subprocess_utils.handle_returncode(proc.returncode, enable_ssh_cmd,
'Failed to enable ssh port.',
proc.stderr.decode('utf-8'))
return config

# OS Login is not enabled for the project. Add the ssh key directly to the
# metadata.
# TODO(zhwu): Use cloud init to add ssh public key, to avoid the permission
# issue. A blocker is that the cloud init is not installed in the debian
# image by default.
project_keys: str = next( # type: ignore
(item for item in project['commonInstanceMetadata'].get('items', [])
if item['key'] == 'ssh-keys'), {}).get('value', '')
ssh_keys = project_keys.split('\n') if project_keys else []

# Get public key from file.
with open(public_key_path, 'r') as f:
public_key = f.read()

# Check if ssh key in Google Project's metadata
public_key_token = public_key.split(' ')[1]

key_found = False
for key in ssh_keys:
key_list = key.split(' ')
if len(key_list) != 3:
continue
if user == key_list[-1] and os.path.exists(
private_key_path) and key_list[1] == public_key.split(' ')[1]:
key_found = True

if not key_found:
new_ssh_key = '{user}:ssh-rsa {public_key_token} {user}'.format(
user=user, public_key_token=public_key_token)
metadata = project['commonInstanceMetadata'].get('items', [])

ssh_key_index = [
k for k, v in enumerate(metadata) if v['key'] == 'ssh-keys'
]
assert len(ssh_key_index) <= 1

if len(ssh_key_index) == 0:
metadata.append({'key': 'ssh-keys', 'value': new_ssh_key})
else:
first_ssh_key_index = ssh_key_index[0]
metadata[first_ssh_key_index]['value'] += '\n' + new_ssh_key

project['commonInstanceMetadata']['items'] = metadata

operation = compute.projects().setCommonInstanceMetadata(
project=project['name'],
body=project['commonInstanceMetadata']).execute()
_wait_for_compute_global_operation(project['name'], operation['name'],
compute)
config = _replace_cloud_init_ssh_info_in_config(config, public_key)
# This function is for backward compatibility, as the user using the old
# Debian-based image may not have the cloud-init enabled, and we need to
# add the ssh key to the account.
_maybe_gcp_add_ssh_key_to_account(compute, project, config, oslogin_enabled)
Michaelvll marked this conversation as resolved.
Show resolved Hide resolved
return config


Expand Down
41 changes: 21 additions & 20 deletions sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import subprocess
import time
import typing
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from sky import clouds
from sky import exceptions
Expand Down Expand Up @@ -235,17 +235,15 @@ def get_egress_cost(self, num_gigabytes):
def is_same_cloud(self, other):
return isinstance(other, GCP)

def get_image_size(self, image_id: str, region: Optional[str]) -> float:
del region # unused
if image_id.startswith('skypilot:'):
return DEFAULT_GCP_IMAGE_GB
@classmethod
def get_image_infos(cls, image_id) -> Dict[str, Any]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: get_image_info() and s/infos/info elsewhere?

try:
compute = gcp.build('compute',
'v1',
credentials=None,
cache_discovery=False)
except gcp.credential_error_exception() as e:
return DEFAULT_GCP_IMAGE_GB
return {}
try:
image_attrs = image_id.split('/')
if len(image_attrs) == 1:
Expand All @@ -254,7 +252,7 @@ def get_image_size(self, image_id: str, region: Optional[str]) -> float:
image_name = image_attrs[-1]
image_infos = compute.images().get(project=project,
image=image_name).execute()
return float(image_infos['diskSizeGb'])
return image_infos
except gcp.http_error_exception() as e:
if e.resp.status == 403:
with ux_utils.print_exception_no_traceback():
Expand All @@ -266,6 +264,15 @@ def get_image_size(self, image_id: str, region: Optional[str]) -> float:
'GCP.') from None
raise

def get_image_size(self, image_id: str, region: Optional[str]) -> float:
del region # unused
if image_id.startswith('skypilot:'):
return DEFAULT_GCP_IMAGE_GB
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: how do we guarantee that the ubuntu & debian tags have the same size, DEFAULT_GCP_IMAGE_GB?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The image size can be got using gcloud compute images describe projects/deeplearning-platform-release/global/images/common-cu113-v20230501-ubuntu-2004-py37, and it seems both of them have the same size of 50GB. Added a comment for the hack.

image_infos = self.get_image_infos(image_id)
if 'diskSizeGb' not in image_infos:
return DEFAULT_GCP_IMAGE_GB
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm reading get_image_size() for the first time and this seems a bit unclear. If user passes in a custom image that does not have diskSizeGb, returning a default 50GB seems like a guess (maybe ok), rather than we the func name suggests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the images should have that field, but this is just a safeguard to avoid the function raising the error, as the image size check is not critical. Added a comment.

return float(image_infos['diskSizeGb'])

@classmethod
def get_default_instance_type(
cls,
Expand All @@ -287,10 +294,10 @@ def make_deploy_resources_variables(

# gcloud compute images list \
# --project deeplearning-platform-release \
# --no-standard-images
# --no-standard-images | grep ubuntu-2004
# We use the debian image, as the ubuntu image has some connectivity
# issue when first booted.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remnant?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are using ubuntu image now, we use grep ubuntu-2004 to find the correct images.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, meant that

        # We use the debian image, as the ubuntu image has some connectivity
        # issue when first booted.

could be removed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, good catch! Removed. Thanks!

image_id = 'skypilot:cpu-debian-10'
image_id = 'skypilot:cpu-ubuntu-2004'
Michaelvll marked this conversation as resolved.
Show resolved Hide resolved

r = resources
# Find GPU spec, if any.
Expand Down Expand Up @@ -330,17 +337,11 @@ def make_deploy_resources_variables(
resources_vars['gpu'] = 'nvidia-tesla-{}'.format(
acc.lower())
resources_vars['gpu_count'] = acc_count
if acc == 'K80':
# Though the image is called cu113, it actually has later
# versions of CUDA as noted below.
# CUDA driver version 470.57.02, CUDA Library 11.4
image_id = 'skypilot:k80-debian-10'
else:
# Though the image is called cu113, it actually has later
# versions of CUDA as noted below.
# CUDA driver version 510.47.03, CUDA Library 11.6
# Does not support torch==1.13.0 with cu117
image_id = 'skypilot:gpu-debian-10'
# Though the image is called cu113, it actually has later
# versions of CUDA as noted below.
# CUDA driver version 510.47.03, CUDA Library 11.6
# K80: CUDA driver version 470.103.01, CUDA Library 11.4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L340-343 is out of touch w/ L344. Remove/move?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rephrased the comments and moved it to near the definition of the image tag (the module-level variable). PTAL.

image_id = 'skypilot:gpu-ubuntu-2004'

if resources.image_id is not None:
if None in resources.image_id:
Expand Down
50 changes: 50 additions & 0 deletions sky/skylet/providers/gcp/node_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,53 @@ def _get_cached_node(self, node_id: str) -> GCPNode:
@staticmethod
def bootstrap_config(cluster_config):
return bootstrap_gcp(cluster_config)

def get_command_runner(
self,
log_prefix,
node_id,
auth_config,
cluster_name,
process_runner,
use_internal_ip,
docker_config,
):
from ray.autoscaler._private.command_runner import (
DockerCommandRunner,
SSHCommandRunner,
)

class SSHCommandRunnerWithRetry(SSHCommandRunner):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe consider moving this to sky/skylet/providers/command_runner.py after #1910 is merged?

def _run_helper(
self, final_cmd, with_output=False, exit_on_fail=False, silent=False
):
"""Wrapper around _run_helper to retry on failure."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add something like the following to document intention?

Fix the ssh connection issue caused by control master for GCP with ubuntu image
Before the fix, the ssh connection will be disconnected when ray trying to setup the runtime dependencies, which is probably because the ssh connection is unstable when the cluster is just provisioned ray-project/ray#16539 (comment). We added retry for the ssh commands executed by ray up, which is ok since our setup commands are idempotent.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Added. Thanks!

retry_cnt = 0
import click

while True:
try:
return super()._run_helper(
final_cmd, with_output, exit_on_fail, silent
)
except click.ClickException as e:
retry_cnt += 1
if retry_cnt > 3:
raise e
logger.info(f"Retrying SSH command in 5 seconds: {e}")
time.sleep(5)

# Adopted from super().get_command_runner()
common_args = {
"log_prefix": log_prefix,
"node_id": node_id,
"provider": self,
"auth_config": auth_config,
"cluster_name": cluster_name,
"process_runner": process_runner,
"use_internal_ip": use_internal_ip,
}
if docker_config and docker_config["container_name"] != "":
return DockerCommandRunner(docker_config, **common_args)
else:
return SSHCommandRunnerWithRetry(**common_args)
2 changes: 1 addition & 1 deletion sky/templates/aws-ray.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ setup_commands:
(type -a pip | grep -q pip3) || echo 'alias pip=pip3' >> ~/.bashrc;
(which conda > /dev/null 2>&1 && conda init > /dev/null) || (wget -nc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && bash Miniconda3-latest-Linux-x86_64.sh -b && eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true);
source ~/.bashrc;
(pip3 list | grep ray | grep {{ray_version}} 2>&1 > /dev/null || pip3 install -U ray[default]=={{ray_version}}) && mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app;
(pip3 list | grep ray | grep {{ray_version}} 2>&1 > /dev/null || pip3 uninstall -y ray ray-cpp && pip3 install -U ray[default]=={{ray_version}}) && mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: since we do a grep ray ... || ..., this means ray is not found. Why is pip3 uninstall necessary? Why add ray-cpp?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use the grep {{ray_version}} to check the correct ray version is installed on the remote VM. It is possible that ray with another version is pre-installed on the VM. For example, the ubuntu image on GCP has the ray==2.4.0 installed, which will cause pip problem if we directly pip install -U ray[default]==2.0.1 (causing the ray package corrupted.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Can we copy this comment into the j2 files? Something like "Ensure only one Ray version (which is our ray_version) is installed, regardless of if the image comes pre-installed with another Ray version."

(pip3 list | grep skypilot && [ "$(cat {{sky_remote_path}}/current_sky_wheel_hash)" == "{{sky_wheel_hash}}" ]) || (pip3 uninstall skypilot -y; pip3 install "$(echo {{sky_remote_path}}/{{sky_wheel_hash}}/skypilot-{{sky_version}}*.whl)[aws]" && echo "{{sky_wheel_hash}}" > {{sky_remote_path}}/current_sky_wheel_hash || exit 1);
sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf';
sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload;
Expand Down
2 changes: 1 addition & 1 deletion sky/templates/azure-ray.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ setup_commands:
(type -a pip | grep -q pip3) || echo 'alias pip=pip3' >> ~/.bashrc;
which conda > /dev/null 2>&1 || (wget -nc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && bash Miniconda3-latest-Linux-x86_64.sh -b && eval "$(/home/azureuser/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true);
source ~/.bashrc;
(pip3 list | grep ray | grep {{ray_version}} 2>&1 > /dev/null || pip3 install -U ray[default]=={{ray_version}}) && mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app && touch ~/.sudo_as_admin_successful;
(pip3 list | grep ray | grep {{ray_version}} 2>&1 > /dev/null || pip3 uninstall -y ray ray-cpp && pip3 install -U ray[default]=={{ray_version}}) && mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app && touch ~/.sudo_as_admin_successful;
(pip3 list | grep skypilot && [ "$(cat {{sky_remote_path}}/current_sky_wheel_hash)" == "{{sky_wheel_hash}}" ]) || (pip3 uninstall skypilot -y; pip3 install "$(echo {{sky_remote_path}}/{{sky_wheel_hash}}/skypilot-{{sky_version}}*.whl)[azure]" && echo "{{sky_wheel_hash}}" > {{sky_remote_path}}/current_sky_wheel_hash || exit 1);
sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf';
sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload;
Expand Down
Loading