diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 286788625bd..ba0dd633d54 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -456,6 +456,24 @@ Available fields and semantics: # Reference: https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview storage_account: user-storage-account-name + # Specify subnet_id to use for instances (optional). + # SkyPilot created new vnet and subnet by default but it will reuse exisiting subnet if specified. + subnet_id: /subscriptions/subscription-id/resourceGroups/resource-group-name/providers/Microsoft.Network/virtualNetworks/vnet-name/subnets/subnet-name + + # Set existing managed identity for instances (optional). + msi_id: /subscriptions/subscription-id/resourceGroups/resource-group-name/providers/Microsoft.ManagedIdentity/userAssignedIdentities/msi-name + + # Should instances be assigned private IPs only? (optional) + # + # Set to true to use private IPs to communicate between the local client and + # any SkyPilot nodes. This requires the networking stack be properly set up. + # + # When set to true, SkyPilot will only use private subnets to launch nodes and won't expose + # instances on public IP addresses. + # Reference: https://learn.microsoft.com/en-us/azure/virtual-network/virtual-network-manage-subnet?tabs=azure-portal + # Default: false. + use_internal_ips: true + # Advanced Kubernetes configurations (optional). kubernetes: # The networking mode for accessing SSH jump pod (optional). diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index eb76d2b5e48..c3815d0fa31 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -366,6 +366,17 @@ def make_deploy_resources_variables( if resource_group_name is None: resource_group_name = f'{cluster_name.name_on_cloud}-{region_name}' + # Determine subnet_id if configured + subnet_id = skypilot_config.get_nested(('azure', 'subnet_id'), None) + + # Detemine if msi_id is configured + msi_id = skypilot_config.get_nested(('azure', 'msi_id'), None) + + # Determine if internal IPs should be used + use_internal_ips = skypilot_config.get_nested( + ('azure', 'use_internal_ips'), False) + + # Setup commands to eliminate the banner and restart sshd. # This script will modify /etc/ssh/sshd_config and add a bash script # into .bashrc. The bash script will restart sshd if it has not been @@ -423,6 +434,9 @@ def _failover_disk_tier() -> Optional[resources_utils.DiskTier]: 'azure_subscription_id': self.get_project_id(dryrun), 'resource_group': resource_group_name, 'use_external_resource_group': use_external_resource_group, + 'subnet_id': subnet_id, + 'use_internal_ips': use_internal_ips, + 'msi_id': msi_id, } # Setting disk performance tier for high disk tier. diff --git a/sky/provision/azure/azure-config-template.json b/sky/provision/azure/azure-config-template.json index 0c70c4d3999..2a476ead9bd 100644 --- a/sky/provision/azure/azure-config-template.json +++ b/sky/provision/azure/azure-config-template.json @@ -25,7 +25,19 @@ "metadata": { "description": "Name of the Network Security Group associated with the SkyPilot cluster." } - } + }, + "existingSubnet": { + "type": "string", + "metadata": { + "description": "Existing subnet id to use." + } + }, + "existingMSI": { + "type": "string", + "metadata": { + "description": "Existing MSI id to use." + } + } }, "variables": { "contributor": "[subscriptionResourceId('Microsoft.Authorization/roleDefinitions', 'b24988ac-6180-42a0-ab88-20f7382dd24c')]", @@ -40,12 +52,14 @@ "resources": [ { "type": "Microsoft.ManagedIdentity/userAssignedIdentities", + "condition": "[equals(parameters('existingMSI'), '')]", "apiVersion": "2018-11-30", "location": "[variables('location')]", "name": "[variables('msiName')]" }, { "type": "Microsoft.Authorization/roleAssignments", + "condition": "[equals(parameters('existingMSI'), '')]", "apiVersion": "2020-08-01-preview", "name": "[guid(variables('roleAssignmentName'))]", "properties": { @@ -86,6 +100,7 @@ "apiVersion": "2019-11-01", "name": "[variables('vnetName')]", "location": "[variables('location')]", + "condition": "[equals(parameters('existingSubnet'), '')]", "properties": { "addressSpace": { "addressPrefixes": [ diff --git a/sky/provision/azure/config.py b/sky/provision/azure/config.py index e7ab59daa33..a88818f80c3 100644 --- a/sky/provision/azure/config.py +++ b/sky/provision/azure/config.py @@ -86,6 +86,9 @@ def bootstrap_instances( 'use_external_resource_group field') use_external_resource_group = provider_config['use_external_resource_group'] + subnet_id = provider_config.get('subnet_id', '') + msi_id = provider_config.get('msi_id', '') + if 'tags' in provider_config: params['tags'] = provider_config['tags'] @@ -142,12 +145,15 @@ def bootstrap_instances( cluster_id, nsg_name = get_cluster_id_and_nsg_name( resource_group=provider_config['resource_group'], cluster_name_on_cloud=cluster_name_on_cloud) + + # subnet_mask is generated only for new subnets subnet_mask = provider_config.get('subnet_mask') - if subnet_mask is None: - # choose a random subnet, skipping most common value of 0 - random.seed(cluster_id) - subnet_mask = f'10.{random.randint(1, 254)}.0.0/16' - logger.info(f'Using subnet mask: {subnet_mask}') + # choose a random subnet, skipping most common value of 0 + random.seed(cluster_id) + subnet_mask = f'10.{random.randint(1, 254)}.0.0/16' + if subnet_id == '': + # subnet_mask is not used if subnet_id is provided + logger.info(f'Using subnet mask: {subnet_mask}') parameters = { 'properties': { @@ -165,6 +171,12 @@ def bootstrap_instances( }, 'location': { 'value': params['location'] + }, + 'existingSubnet': { + 'value': subnet_id + }, + 'existingMSI': { + 'value': msi_id } }, } @@ -213,8 +225,9 @@ def bootstrap_instances( ).result().properties.outputs # append output resource ids to be used with vm creation - provider_config['msi'] = outputs['msi']['value'] + provider_config['msi'] = outputs['msi']['value'] if msi_id == '' else msi_id provider_config['nsg'] = outputs['nsg']['value'] - provider_config['subnet'] = outputs['subnet']['value'] + provider_config[ + 'subnet'] = outputs['subnet']['value'] if subnet_id == '' else subnet_id return config diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index 1140704a708..6fc2658dc84 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -47,6 +47,13 @@ provider: # leakage. disable_launch_config_check: true + {%- if subnet_id is not none %} + subnet_id: {{subnet_id}} + {%- endif %} + use_internal_ips: {{use_internal_ips}} + {%- if msi_id is not none %} + msi_id: {{msi_id}} + {%- endif %} auth: ssh_user: azureuser diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 851e77a57fc..a947934a198 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -797,6 +797,15 @@ def get_config_schema(): 'resource_group_vm': { 'type': 'string', }, + 'subnet_id': { + 'type': 'string', + }, + 'use_internal_ips': { + 'type': 'boolean', + }, + 'msi_id': { + 'type': 'string', + }, } }, 'kubernetes': {