TPU Provisioner: Add support for v6e & cross project reservations #851

Open · wants to merge 6 commits into base: main
2 changes: 1 addition & 1 deletion tpu-provisioner/Dockerfile
@@ -1,5 +1,5 @@
# Build the manager binary
FROM golang:1.22 as builder
FROM golang:1.23 as builder
ARG TARGETOS
ARG TARGETARCH

30 changes: 16 additions & 14 deletions tpu-provisioner/cmd/main.go
@@ -79,10 +79,11 @@ func main() {
GCPCluster string `envconfig:"GCP_CLUSTER"`
GCPNodeServiceAccount string `envconfig:"GCP_NODE_SERVICE_ACCOUNT"`

GCPNodeTags []string `envconfig:"GCP_NODE_TAGS"`
GCPPodToNodeLabels []string `envconfig:"GCP_POD_TO_NODE_LABELS"`
GCPNodeSecondaryDisk string `envconfig:"GCP_NODE_SECONDARY_DISK" default:""`
GCPNodeSecureBoot bool `envconfig:"GCP_NODE_SECURE_BOOT" default:"true"`
GCPNodeTags []string `envconfig:"GCP_NODE_TAGS"`
GCPPodToNodeLabels []string `envconfig:"GCP_POD_TO_NODE_LABELS"`
GCPNodeSecondaryDisk string `envconfig:"GCP_NODE_SECONDARY_DISK" default:""`
GCPNodeSecureBoot bool `envconfig:"GCP_NODE_SECURE_BOOT" default:"true"`
GCPNodeAdditionalNetworks string `envconfig:"GCP_NODE_ADDITIONAL_NETWORKS" default:""`
Collaborator Author commented:
Format: network-1:subnet-1,network-2:subnet-2


// GCPForceOnDemand forces the controller to create nodes on demand, even if
// the Pod requests a reservation or spot.
@@ -201,16 +202,17 @@ func main() {
provider = &cloud.GKE{
Service: containers,
ClusterContext: cloud.GKEContext{
ProjectID: cfg.GCPProjectID,
ClusterLocation: cfg.GCPClusterLocation,
Cluster: cfg.GCPCluster,
NodeZone: cfg.GCPZone,
NodeServiceAccount: cfg.GCPNodeServiceAccount,
NodeSecondaryDisk: cfg.GCPNodeSecondaryDisk,
NodeTags: cfg.GCPNodeTags,
PodToNodeLabels: cfg.GCPPodToNodeLabels,
NodeSecureBoot: cfg.GCPNodeSecureBoot,
ForceOnDemand: cfg.GCPForceOnDemand,
ProjectID: cfg.GCPProjectID,
ClusterLocation: cfg.GCPClusterLocation,
Cluster: cfg.GCPCluster,
NodeZone: cfg.GCPZone,
NodeServiceAccount: cfg.GCPNodeServiceAccount,
NodeAdditionalNetworks: cfg.GCPNodeAdditionalNetworks,
NodeSecondaryDisk: cfg.GCPNodeSecondaryDisk,
NodeTags: cfg.GCPNodeTags,
PodToNodeLabels: cfg.GCPPodToNodeLabels,
NodeSecureBoot: cfg.GCPNodeSecureBoot,
ForceOnDemand: cfg.GCPForceOnDemand,
},
Recorder: mgr.GetEventRecorderFor("tpu-provisioner"),
}
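For reference, the new GCP_NODE_ADDITIONAL_NETWORKS setting is read like the existing environment variables above. A minimal sketch, assuming the envconfig struct tags in this file come from github.com/kelseyhightower/envconfig; the example value is illustrative only:

```go
// Minimal sketch, not part of this PR: reading the new variable with
// github.com/kelseyhightower/envconfig (assumed from the struct tags above).
package main

import (
	"fmt"
	"log"

	"github.com/kelseyhightower/envconfig"
)

type config struct {
	// e.g. GCP_NODE_ADDITIONAL_NETWORKS="network-1:subnet-1,network-2:subnet-2"
	GCPNodeAdditionalNetworks string `envconfig:"GCP_NODE_ADDITIONAL_NETWORKS" default:""`
}

func main() {
	var cfg config
	if err := envconfig.Process("", &cfg); err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.GCPNodeAdditionalNetworks)
}
```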
2 changes: 1 addition & 1 deletion tpu-provisioner/go.mod
@@ -1,6 +1,6 @@
module github.com/GoogleCloudPlatform/ai-on-gke/tpu-provisioner

go 1.22.0
go 1.23.0

require (
cloud.google.com/go/compute/metadata v0.3.0
5 changes: 5 additions & 0 deletions tpu-provisioner/internal/cloud/common.go
@@ -26,6 +26,11 @@ const (

// AnnotationCopyLabels is a comma-separated list of labels to copy from the Pod to the node pool config (Nodes).
AnnotationCopyLabels = "tpu-provisioner.cloud.google.com/copy-labels"
// AnnotationAdditionalNodeNetworks is a comma-separated list of additional networks and subnets to attach to the node pool.
// Format: "<network-name>:<subnet-name>, ..."
AnnotationAdditionalNodeNetworks = "tpu-provisioner.cloud.google.com/additional-node-networks"
// AnnotationNodeServiceAccount is the GCP service account to use for the node pool.
AnnotationNodeServiceAccount = "tpu-provisioner.cloud.google.com/node-service-account"

EventNodePoolCreationStarted = "NodePoolCreationStarted"
EventNodePoolCreationSucceeded = "NodePoolCreationSucceeded"
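To show how a workload would opt into these, here is a minimal sketch of a Pod carrying both new annotations; the Pod name, networks, and service account are hypothetical values, not ones from this PR:

```go
// Minimal sketch, not part of this PR: a Pod with the two new annotations.
// All concrete values (name, networks, service account) are hypothetical.
package main

import (
	"fmt"

	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func main() {
	pod := corev1.Pod{
		ObjectMeta: metav1.ObjectMeta{
			Name: "tpu-worker-0",
			Annotations: map[string]string{
				// Comma-separated "<network-name>:<subnet-name>" pairs.
				"tpu-provisioner.cloud.google.com/additional-node-networks": "vpc-1:subnet-1, vpc-2:subnet-2",
				// GCP service account to use for the created node pool.
				"tpu-provisioner.cloud.google.com/node-service-account": "tpu-nodes@my-project.iam.gserviceaccount.com",
			},
		},
	}
	fmt.Println(pod.Annotations)
}
```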
51 changes: 48 additions & 3 deletions tpu-provisioner/internal/cloud/gke.go
@@ -42,6 +42,7 @@ const (
V4PodSliceAccelerator = "tpu-v4-podslice"
V5ePodSliceAccelerator = "tpu-v5-lite-podslice"
V5pPodSliceAccelerator = "tpu-v5p-slice"
V6eSliceAccelerator = "tpu-v6e-slice"

// Resource type labels
GoogleTPUResource = "google.com/tpu"
@@ -323,11 +324,18 @@ func (g *GKE) nodePoolForPod(name string, p *corev1.Pod) (*containerv1beta1.Node

if !g.ClusterContext.ForceOnDemand {
if resName, ok := p.Spec.NodeSelector["cloud.google.com/reservation-name"]; ok {
var resVal string
resProj, ok := p.Spec.NodeSelector["cloud.google.com/reservation-project"]
if ok {
resVal = fmt.Sprintf("projects/%s/reservations/%s", resProj, resName)
} else {
resVal = resName
}
reservation = &containerv1beta1.ReservationAffinity{
ConsumeReservationType: "SPECIFIC_RESERVATION",
Key: "compute.googleapis.com/reservation-name",
Values: []string{
resName,
resVal,
},
}
}
@@ -355,10 +363,44 @@ func (g *GKE) nodePoolForPod(name string, p *corev1.Pod) (*containerv1beta1.Node
}
}

var networkConfig *containerv1beta1.NodeNetworkConfig
var additionalNodeNetworks []*containerv1beta1.AdditionalNodeNetworkConfig
// additional-node-networks: "vpc1:subnet1, vpc2:subnet2"
additionalNodeNetworksCSV := g.ClusterContext.NodeAdditionalNetworks
if getAnnotation(p, AnnotationAdditionalNodeNetworks) != "" {
additionalNodeNetworksCSV = getAnnotation(p, AnnotationAdditionalNodeNetworks)
}
for _, pair := range strings.Split(additionalNodeNetworksCSV, ",") {
pair = strings.TrimSpace(pair)
if pair == "" {
continue
}

netAndSubnet := strings.SplitN(pair, ":", 2)
if len(netAndSubnet) != 2 {
return nil, fmt.Errorf("invalid additional network annotation: %v", pair)
}

additionalNodeNetworks = append(additionalNodeNetworks, &containerv1beta1.AdditionalNodeNetworkConfig{
Network: strings.TrimSpace(netAndSubnet[0]),
Subnetwork: strings.TrimSpace(netAndSubnet[1]),
})
}
if len(additionalNodeNetworks) > 0 {
networkConfig = &containerv1beta1.NodeNetworkConfig{
AdditionalNodeNetworkConfigs: additionalNodeNetworks,
}
}

nodeServiceAccount := g.ClusterContext.NodeServiceAccount
if sa, ok := p.Annotations[AnnotationNodeServiceAccount]; ok {
nodeServiceAccount = sa
}

return &containerv1beta1.NodePool{
Name: name,
Config: &containerv1beta1.NodeConfig{
ServiceAccount: g.ClusterContext.NodeServiceAccount,
ServiceAccount: nodeServiceAccount,
ShieldedInstanceConfig: &containerv1beta1.ShieldedInstanceConfig{
EnableIntegrityMonitoring: true,
EnableSecureBoot: g.ClusterContext.NodeSecureBoot,
@@ -387,6 +429,7 @@ func (g *GKE) nodePoolForPod(name string, p *corev1.Pod) (*containerv1beta1.Node
MaxSurge: 1,
},
MaxPodsConstraint: &containerv1beta1.MaxPodsConstraint{MaxPodsPerNode: maxPodsPerNode},
NetworkConfig: networkConfig,
}, nil
}

@@ -438,7 +481,7 @@ func tpuTopologyToNodeCount(accelerator, topo string) (int, error) {
switch accelerator {
case V4PodSliceAccelerator, V5pPodSliceAccelerator:
expectedDims = 3
case V5ePodSliceAccelerator:
case V5ePodSliceAccelerator, V6eSliceAccelerator:
expectedDims = 2
default:
return 0, fmt.Errorf("invalid accelerator: %v", accelerator)
@@ -475,6 +518,8 @@ func tpuMachineType(accel string, tpuRequest int) (string, error) {
return fmt.Sprintf("ct5lp-hightpu-%vt", tpuRequest), nil
case V5pPodSliceAccelerator: // v5p
return fmt.Sprintf("ct5p-hightpu-%vt", tpuRequest), nil
case V6eSliceAccelerator: // v6e
return fmt.Sprintf("ct6e-standard-%vt", tpuRequest), nil
}

return "", fmt.Errorf("invalid accelerator: %v", accel)
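As a quick illustration of the two behavioral changes in this file, the cross-project reservation value and the v6e topology handling, here is a minimal standalone sketch; the helper names are illustrative and not the PR's functions:

```go
// Minimal sketch, not part of this PR: standalone helpers mirroring the
// logic added above (names are illustrative).
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// reservationValue builds the ReservationAffinity value: the bare reservation
// name for same-project reservations, or the fully qualified
// "projects/<project>/reservations/<name>" form when a project is given.
func reservationValue(name, project string) string {
	if project == "" {
		return name
	}
	return fmt.Sprintf("projects/%s/reservations/%s", project, name)
}

// v6eNodeCount: v6e (like v5e) uses 2D topologies, so the node count is the
// chip count (product of the two dims) divided by the chips per node.
func v6eNodeCount(topology string, chipsPerNode int) (int, error) {
	dims := strings.Split(topology, "x")
	if len(dims) != 2 {
		return 0, fmt.Errorf("expected a 2D topology, got %q", topology)
	}
	chips := 1
	for _, d := range dims {
		n, err := strconv.Atoi(d)
		if err != nil {
			return 0, err
		}
		chips *= n
	}
	return chips / chipsPerNode, nil
}

func main() {
	// Matches the test case below: projects/tpu-rsv-project/reservations/tpu-rsv
	fmt.Println(reservationValue("tpu-rsv", "tpu-rsv-project"))

	// 16x16 = 256 chips; 4 chips per ct6e-standard-4t node -> 64 nodes.
	nodes, _ := v6eNodeCount("16x16", 4)
	fmt.Printf("ct6e-standard-4t x %d nodes\n", nodes)
}
```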
15 changes: 8 additions & 7 deletions tpu-provisioner/internal/cloud/gke_context.go
@@ -3,13 +3,14 @@ package cloud
import "fmt"

type GKEContext struct {
ProjectID string
ClusterLocation string
Cluster string
NodeZone string
NodeServiceAccount string
NodeSecondaryDisk string
NodeTags []string
ProjectID string
ClusterLocation string
Cluster string
NodeZone string
NodeServiceAccount string
NodeAdditionalNetworks string
NodeSecondaryDisk string
NodeTags []string
// PodToNodeLabels is a list of key=value pairs that will be copied from the Pod
// to the Node.
PodToNodeLabels []string
124 changes: 124 additions & 0 deletions tpu-provisioner/internal/cloud/gke_test.go
@@ -67,6 +67,16 @@ func Test_tpuTopologyToNodeCount(t *testing.T) {
topo: "not-a-topo",
err: true,
},
{
accel: "tpu-v6e-slice",
topo: "16x16",
count: 64,
},
{
accel: "tpu-v6e-slice",
topo: "1x1x1",
err: true,
},
}

for _, c := range cases {
@@ -341,6 +351,39 @@ func TestNodePoolForPod(t *testing.T) {
UpgradeSettings: &container.UpgradeSettings{MaxSurge: 1},
},
},
{
desc: "pod with cross-project reservation selector",
selector: map[string]string{
"cloud.google.com/reservation-name": "tpu-rsv",
"cloud.google.com/reservation-project": "tpu-rsv-project",
},
want: &containerv1beta1.NodePool{
Config: &container.NodeConfig{
Labels: map[string]string{
"google.com/nodepool-manager": "tpu-provisioner",
"google.com/tpu-provisioner-jobset-name": "jobset-test",
"google.com/tpu-provisioner-jobset-namespace": "default",
"google.com/tpu-provisioner-parent-kind": "job",
"google.com/tpu-provisioner-parent-name": "jobset-test-job-1-0",
"google.com/tpu-provisioner-parent-namespace": "default",
},
MachineType: "ct5p-hightpu-4t",
ReservationAffinity: &container.ReservationAffinity{
ConsumeReservationType: "SPECIFIC_RESERVATION",
Key: "compute.googleapis.com/reservation-name",
Values: []string{"projects/tpu-rsv-project/reservations/tpu-rsv"},
},
ShieldedInstanceConfig: &container.ShieldedInstanceConfig{EnableIntegrityMonitoring: true},
},
InitialNodeCount: 512,
Locations: []string{""},
Management: &container.NodeManagement{AutoRepair: true, AutoUpgrade: false},
MaxPodsConstraint: &container.MaxPodsConstraint{MaxPodsPerNode: 15},
Name: "test-pool",
PlacementPolicy: &container.PlacementPolicy{TpuTopology: "8x16x16", Type: "COMPACT"},
UpgradeSettings: &container.UpgradeSettings{MaxSurge: 1},
},
},
{
desc: "pod with reservation selector but on demand is forced",
selector: map[string]string{"cloud.google.com/reservation-name": "tpu-rsv"},
@@ -515,6 +558,87 @@ func TestNodePoolForPod(t *testing.T) {
UpgradeSettings: &container.UpgradeSettings{MaxSurge: 1},
},
},
{
desc: "additional node networks configured in cluster context",
gkeContext: GKEContext{
NodeAdditionalNetworks: "network-1:subnet-1, network-2:subnet-2",
},
want: &containerv1beta1.NodePool{
Config: &container.NodeConfig{
Labels: map[string]string{
"google.com/nodepool-manager": "tpu-provisioner",
"google.com/tpu-provisioner-jobset-name": "jobset-test",
"google.com/tpu-provisioner-jobset-namespace": "default",
"google.com/tpu-provisioner-parent-kind": "job",
"google.com/tpu-provisioner-parent-name": "jobset-test-job-1-0",
"google.com/tpu-provisioner-parent-namespace": "default",
},
MachineType: "ct5p-hightpu-4t",
ShieldedInstanceConfig: &container.ShieldedInstanceConfig{EnableIntegrityMonitoring: true},
},
InitialNodeCount: 512,
Locations: []string{""},
Management: &container.NodeManagement{AutoRepair: true, AutoUpgrade: false},
MaxPodsConstraint: &container.MaxPodsConstraint{MaxPodsPerNode: 15},
Name: "test-pool",
PlacementPolicy: &container.PlacementPolicy{TpuTopology: "8x16x16", Type: "COMPACT"},
UpgradeSettings: &container.UpgradeSettings{MaxSurge: 1},
NetworkConfig: &container.NodeNetworkConfig{
AdditionalNodeNetworkConfigs: []*container.AdditionalNodeNetworkConfig{
{
Network: "network-1",
Subnetwork: "subnet-1",
},
{
Network: "network-2",
Subnetwork: "subnet-2",
},
},
},
},
},
{
desc: "pod requesting additional node networks",
gkeContext: GKEContext{
NodeAdditionalNetworks: "should-be-overriden-1:should-be-overriden-2",
},
additionalAnnotations: map[string]string{
"tpu-provisioner.cloud.google.com/additional-node-networks": "network-1:subnet-1, network-2:subnet-2",
},
want: &containerv1beta1.NodePool{
Config: &container.NodeConfig{
Labels: map[string]string{
"google.com/nodepool-manager": "tpu-provisioner",
"google.com/tpu-provisioner-jobset-name": "jobset-test",
"google.com/tpu-provisioner-jobset-namespace": "default",
"google.com/tpu-provisioner-parent-kind": "job",
"google.com/tpu-provisioner-parent-name": "jobset-test-job-1-0",
"google.com/tpu-provisioner-parent-namespace": "default",
},
MachineType: "ct5p-hightpu-4t",
ShieldedInstanceConfig: &container.ShieldedInstanceConfig{EnableIntegrityMonitoring: true},
},
InitialNodeCount: 512,
Locations: []string{""},
Management: &container.NodeManagement{AutoRepair: true, AutoUpgrade: false},
MaxPodsConstraint: &container.MaxPodsConstraint{MaxPodsPerNode: 15},
Name: "test-pool",
PlacementPolicy: &container.PlacementPolicy{TpuTopology: "8x16x16", Type: "COMPACT"},
UpgradeSettings: &container.UpgradeSettings{MaxSurge: 1},
NetworkConfig: &container.NodeNetworkConfig{
AdditionalNodeNetworkConfigs: []*container.AdditionalNodeNetworkConfig{
{
Network: "network-1",
Subnetwork: "subnet-1",
},
{
Network: "network-2",
Subnetwork: "subnet-2",
},
},
},
},
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {