diff --git a/dashboard/src/components/AzureProvisionerSettings.tsx b/dashboard/src/components/AzureProvisionerSettings.tsx index 06fbf87ba2..6c81a2d0b9 100644 --- a/dashboard/src/components/AzureProvisionerSettings.tsx +++ b/dashboard/src/components/AzureProvisionerSettings.tsx @@ -33,7 +33,9 @@ import InputRow from "./form-components/InputRow"; import Button from "./porter/Button"; import Error from "./porter/Error"; import Icon from "./porter/Icon"; +import InputSlider from "./porter/InputSlider"; import Link from "./porter/Link"; +import Select from "./porter/Select"; import Spacer from "./porter/Spacer"; import Step from "./porter/Step"; import Text from "./porter/Text"; @@ -53,6 +55,7 @@ type Props = RouteComponentProps & { provisionerError?: string; credentialId: string; clusterId?: number; + gpuModal?: boolean; }; const VALID_CIDR_RANGE_PATTERN = @@ -71,6 +74,11 @@ const AzureProvisionerSettings: React.FC = (props) => { const [clusterName, setClusterName] = useState(""); const [azureLocation, setAzureLocation] = useState("eastus"); const [machineType, setMachineType] = useState("Standard_B2als_v2"); + const [gpuMinInstances, setGpuMinInstances] = useState(1); + const [gpuMaxInstances, setGpuMaxInstances] = useState(5); + const [gpuInstanceType, setGpuInstanceType] = useState( + "Standard_NC4as_T4_v3" + ); const [isExpanded, setIsExpanded] = useState(false); const [minInstances, setMinInstances] = useState(1); const [maxInstances, setMaxInstances] = useState(10); @@ -85,6 +93,12 @@ const AzureProvisionerSettings: React.FC = (props) => { regionFilteredMachineTypeOptions, setRegionFilteredMachineTypeOptions, ] = useState(azureSupportedMachineTypes(azureLocation)); + const [ + regionFilteredGPUMachineTypeOptions, + setRegionFilteredGPUMachineTypeOptions, + ] = useState( + azureSupportedMachineTypes(azureLocation, true) + ); const { showIntercomWithMessage } = useIntercom(); @@ -92,6 +106,9 @@ const AzureProvisionerSettings: React.FC = (props) => { setRegionFilteredMachineTypeOptions( azureSupportedMachineTypes(azureLocation) ); + setRegionFilteredGPUMachineTypeOptions( + azureSupportedMachineTypes(azureLocation, true) + ); }, [azureLocation]); const markStepStarted = async ( @@ -188,6 +205,42 @@ const AzureProvisionerSettings: React.FC = (props) => { console.log(err); } + const nodePools = [ + new AKSNodePool({ + instanceType: "Standard_B2als_v2", + minInstances: 1, + maxInstances: 3, + nodePoolType: NodePoolType.SYSTEM, + mode: "User", + }), + new AKSNodePool({ + instanceType: "Standard_B2as_v2", + minInstances: 1, + maxInstances: 3, + nodePoolType: NodePoolType.MONITORING, + mode: "User", + }), + new AKSNodePool({ + instanceType: machineType, + minInstances: minInstances || 1, + maxInstances: maxInstances || 10, + nodePoolType: NodePoolType.APPLICATION, + mode: "User", + }), + ]; + + // Conditionally add the last EKSNodeGroup if gpuModal is enabled + if (props.gpuModal) { + nodePools.push( + new AKSNodePool({ + instanceType: gpuInstanceType, + minInstances: gpuMinInstances || 0, + maxInstances: gpuMaxInstances || 5, + nodePoolType: NodePoolType.CUSTOM, + }) + ); + } + const data = new Contract({ cluster: new Cluster({ projectId: currentProject.id, @@ -201,29 +254,7 @@ const AzureProvisionerSettings: React.FC = (props) => { clusterVersion: clusterVersion || "v1.27.3", cidrRange: cidrRange || "10.78.0.0/16", location: azureLocation, - nodePools: [ - new AKSNodePool({ - instanceType: "Standard_B2als_v2", - minInstances: 1, - maxInstances: 3, - nodePoolType: NodePoolType.SYSTEM, - mode: "User", - }), - new AKSNodePool({ - instanceType: "Standard_B2as_v2", - minInstances: 1, - maxInstances: 3, - nodePoolType: NodePoolType.MONITORING, - mode: "User", - }), - new AKSNodePool({ - instanceType: machineType, - minInstances: minInstances || 1, - maxInstances: maxInstances || 10, - nodePoolType: NodePoolType.APPLICATION, - mode: "User", - }), - ], + nodePools, skuTier, }), }, @@ -317,7 +348,10 @@ const AzureProvisionerSettings: React.FC = (props) => { // TODO: pass in contract as the already parsed object, rather than JSON (requires changes to AWS/GCP provisioning) const contract = Contract.fromJsonString( - JSON.stringify(props.selectedClusterVersion) + JSON.stringify(props.selectedClusterVersion), + { + ignoreUnknownFields: true, + } ); if ( @@ -471,6 +505,46 @@ const AzureProvisionerSettings: React.FC = (props) => { ); }; + if (props.gpuModal) { + return ( + <> +