This page contains instructions for setting up Ray on GKE with TPUs. For general setup instructions, please refer to the README file. For more information about TPUs on GKE, see this page.
- If needed, `git clone https://github.com/GoogleCloudPlatform/ai-on-gke`
- `cd ai-on-gke/gke-platform`
- Edit `variables.tf` with your GCP settings.
- Change the region or zone to one where TPUs are available (see this link for details). For v4 TPUs (the default type), set the region to `us-central2` or the zone to `us-central2-b`.
- Set the following flags (note that TPUs are currently only supported on GKE Standard):

  ```hcl
  variable "enable_autopilot" {
    type        = bool
    description = "Set to true to enable GKE Autopilot clusters"
    default     = false
  }

  variable "enable_tpu" {
    type        = bool
    description = "Set to true to create TPU node pool"
    default     = true
  }
  ```
- Run `terraform init`
- Run `terraform apply`
The TPU Initialization Webhook can automatically inject the `TPU_WORKER_ID` and `TPU_WORKER_HOSTNAMES` environment variables necessary for multi-host TPU clusters. The webhook needs to be installed once per GKE cluster. The instructions can be found here.
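Once the webhook is installed, each TPU worker pod should come up with these variables already set. The following is a minimal sketch for verifying this from a Ray task; the `check_tpu_env` name is hypothetical, and the `google.com/tpu` resource name follows the sample code later on this page:

```python
import ray

# Connect to the Ray cluster (adjust the address for your setup).
ray.init()

@ray.remote(resources={"google.com/tpu": 4})
def check_tpu_env():
    import os
    # Both variables should have been injected by the webhook.
    return (os.environ.get("TPU_WORKER_ID"),
            os.environ.get("TPU_WORKER_HOSTNAMES"))

print(ray.get(check_tpu_env.remote()))
```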
- Get the GKE cluster name and location/region from `gke-platform/variables.tf`, then run `gcloud container clusters get-credentials %gke_cluster_name% --location=%location%` (see the instructions for configuring `gcloud`).
- `cd ../user`
- Edit `variables.tf` with your GCP settings. The `<your user name>` that you specify will become a K8s namespace for your Ray services.
- Set `enable_tpu` to `true`.
- Run `terraform init`
- Run `terraform apply`
This should deploy a KubeRay cluster with a single TPU worker node (a v4 TPU with `2x2x1` topology). To deploy a multi-host Ray cluster, modify `tpu_topology` here as well as the KubeRay cluster spec here (for example, a `2x2x2` topology spans two TPU VM hosts).
Install JupyterHub according to the instructions in the README. A basic JAX program can be found here. For a more advanced workload running Stable Diffusion on TPUs, see here.
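As a quick smoke test before running larger workloads, here is a minimal sketch of a JAX-on-Ray program; it assumes JAX with TPU support is installed in the worker image, the `tpu_devices` name is hypothetical, and the `google.com/tpu` resource name follows the sample code below:

```python
import ray

# Connect to the Ray cluster (adjust the address for your setup).
ray.init()

@ray.remote(resources={"google.com/tpu": 4})
def tpu_devices():
    import jax
    # On a single-host v4 slice with 2x2x1 topology, this should
    # report 4 TPU devices.
    return jax.device_count()

print(ray.get(tpu_devices.remote()))
```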
To manually set the `TPU_WORKER_ID` and `TPU_WORKER_HOSTNAMES` environment variables before initializing JAX, without using the webhook, run the following sample code:
```python
import ray

# Connect to the Ray cluster (adjust the address for your setup).
ray.init()

@ray.remote(resources={"google.com/tpu": 4})
def get_hostname():
    import time
    # Sleep so that concurrent tasks keep holding their TPU resources,
    # which spreads the tasks across distinct TPU hosts.
    time.sleep(1)
    return ray.util.get_node_ip_address()

@ray.remote(resources={"google.com/tpu": 4})
def init_tpu_env_from_ray(id_hostname_map):
    import os
    import time
    time.sleep(1)
    # Look up this host's worker ID and export the variables JAX expects.
    hostname = ray.util.get_node_ip_address()
    worker_id = id_hostname_map[hostname]
    os.environ["TPU_WORKER_ID"] = str(worker_id)
    os.environ["TPU_WORKER_HOSTNAMES"] = ",".join(list(id_hostname_map))
    return "TPU_WORKER_ID: " + os.environ["TPU_WORKER_ID"] + \
        " TPU_WORKER_HOSTNAMES: " + os.environ["TPU_WORKER_HOSTNAMES"]

def init_jax_from_ray(num_workers: int):
    # Collect the IP address of each TPU host and assign each one a stable ID.
    results = ray.get([get_hostname.remote() for _ in range(num_workers)])
    id_hostname_map = {
        hostname: worker_id for worker_id, hostname in enumerate(set(results))}
    result = [init_tpu_env_from_ray.remote(id_hostname_map)
              for _ in range(num_workers)]
    print(ray.get(result))

init_jax_from_ray(num_workers=2)
```
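A note on the sample above: each task requests four TPU chips, which is everything a single v4 host offers, and the `time.sleep(1)` calls keep concurrent tasks holding those resources, so Ray has to place each of the `num_workers` tasks on a distinct TPU host. The hostname-to-ID map built from the first round of tasks then gives every host a stable `TPU_WORKER_ID` in the second round.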