From 3e7cb9ba39fbdd4c217ec2295f44badf6d851ce6 Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Wed, 16 Oct 2024 14:56:16 -0400 Subject: [PATCH] Add support for v6e --- tpu-provisioner/internal/cloud/gke.go | 5 ++++- tpu-provisioner/internal/cloud/gke_test.go | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tpu-provisioner/internal/cloud/gke.go b/tpu-provisioner/internal/cloud/gke.go index cba11e6b7..3d4ef5da3 100644 --- a/tpu-provisioner/internal/cloud/gke.go +++ b/tpu-provisioner/internal/cloud/gke.go @@ -42,6 +42,7 @@ const ( V4PodSliceAccelerator = "tpu-v4-podslice" V5ePodSliceAccelerator = "tpu-v5-lite-podslice" V5pPodSliceAccelerator = "tpu-v5p-slice" + V6eSliceAccelerator = "tpu-v6e-slice" // Resource type labels GoogleTPUResource = "google.com/tpu" @@ -438,7 +439,7 @@ func tpuTopologyToNodeCount(accelerator, topo string) (int, error) { switch accelerator { case V4PodSliceAccelerator, V5pPodSliceAccelerator: expectedDims = 3 - case V5ePodSliceAccelerator: + case V5ePodSliceAccelerator, V6eSliceAccelerator: expectedDims = 2 default: return 0, fmt.Errorf("invalid accelerator: %v", accelerator) @@ -475,6 +476,8 @@ func tpuMachineType(accel string, tpuRequest int) (string, error) { return fmt.Sprintf("ct5lp-hightpu-%vt", tpuRequest), nil case V5pPodSliceAccelerator: // v5p return fmt.Sprintf("ct5p-hightpu-%vt", tpuRequest), nil + case V6eSliceAccelerator: // v6e + return fmt.Sprintf("ct6e-standard-%vt", tpuRequest), nil } return "", fmt.Errorf("invalid accelerator: %v", accel) diff --git a/tpu-provisioner/internal/cloud/gke_test.go b/tpu-provisioner/internal/cloud/gke_test.go index 0b5adc5e0..fa68da354 100644 --- a/tpu-provisioner/internal/cloud/gke_test.go +++ b/tpu-provisioner/internal/cloud/gke_test.go @@ -67,6 +67,16 @@ func Test_tpuTopologyToNodeCount(t *testing.T) { topo: "not-a-topo", err: true, }, + { + accel: "tpu-v6e-slice", + topo: "16x16", + count: 64, + }, + { + accel: "tpu-v6e-slice", + topo: "1x1x1", + err: true, + }, } for _, c := range cases {