JAX is a rapidly growing Python library for high-performance numerical computing and machine learning (ML) research. With applications in large language models, drug discovery, physics ML, reinforcement learning, and neural graphics, JAX has seen broad adoption in the past few years. JAX offers numerous benefits for developers and researchers, including an easy-to-use NumPy API, automatic differentiation, and optimization. JAX also supports distributed processing across multi-node, multi-GPU systems in a few lines of code, with accelerated performance through XLA-optimized kernels on NVIDIA GPUs.
This post shows how to run multi-node, multi-GPU JAX applications on GKE (Google Kubernetes Engine) using the A2 Ultra machine series, powered by NVIDIA A100 80GB Tensor Core GPUs. The example runs a simple Hello World application on 4 nodes, each with 8 processes and 8 GPUs (one GPU per process).
Prerequisites
Install gcloud and set up your environment by running gcloud init and following the prompts
Install Docker and log in to Google Container Registry using the gcloud credential helper (gcloud auth configure-docker)
Install kubectl and the GKE authentication plugin (gke-gcloud-auth-plugin)
Set up a GKE cluster
Clone the repository
```
$ git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples
$ cd kubernetes-engine-samples/ai-ml/gke-a100-jax
```
Enable the Kubernetes Engine and Container Registry APIs
```
$ gcloud services enable container.googleapis.com
$ gcloud services enable containerregistry.googleapis.com
```
Create a default VPC (if it doesn’t already exist)
```
$ gcloud compute networks create default --subnet-mode=auto --bgp-routing-mode=regional --mtu=1460
```
Create a cluster (the control nodes). Replace us-central1-c with your preferred zone.
```
$ gcloud container clusters create jax-example --zone=us-central1-c
```
Create a node pool (the compute nodes). The --enable-fast-socket and --enable-gvnic flags are required for multi-node performance. The --preemptible flag can remove the need for additional GPU quota, but it makes the nodes preemptible; remove the flag if this is not desirable. Replace us-central1-c with your preferred zone. This might take a few minutes.
```
$ gcloud container node-pools create gpus-node-pool --machine-type=a2-ultragpu-8g --cluster=jax-example --enable-fast-socket --enable-gvnic --zone=us-central1-c --num-nodes=4 --preemptible
```
Install the NVIDIA CUDA driver on the compute nodes
```
$ kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/cos/daemonset-preloaded-latest.yaml
```
Build the container image and push it to your registry. This pushes an image to gcr.io/<your project>/jax/hello:latest and might take a few minutes.
```
$ bash build_push_container.sh
```
In kubernetes/job.yaml and kubernetes/kustomization.yaml, replace <<PROJECT>> with your Google Cloud project ID.
Run JAX
Run the JAX application on the compute nodes. This creates 32 pods (8 per node), each running one JAX process on one NVIDIA GPU.
```
$ cd kubernetes
$ kubectl apply -k .
```
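Each of the 32 processes has to tell JAX which process it is, how many processes there are, and where the coordinator runs. The repository's application code is not reproduced here, so the following is a minimal sketch under stated assumptions: it hypothetically assumes the Job is a Kubernetes Indexed Job (consistent with pod names such as jax-hello-world-0-zmcrr), so that each pod receives its index in the JOB_COMPLETION_INDEX environment variable, and it reuses the coordinator host name jax-hello-world-0.headless-svc and port 1234 that appear in the pod logs. The names NUM_PROCESSES, COORDINATOR, distributed_config, and main are illustrative.

```python
import os

# Hypothetical assumptions, not taken from the repository:
# - the pods belong to a Kubernetes Indexed Job, so JOB_COMPLETION_INDEX is set;
# - process 0 is reachable as jax-hello-world-0.headless-svc on port 1234,
#   matching the coordinator shown in the pod logs.
NUM_PROCESSES = 32  # 4 nodes x 8 GPUs, one JAX process per GPU
COORDINATOR = "jax-hello-world-0.headless-svc:1234"


def distributed_config(env=os.environ):
    """Return keyword arguments for jax.distributed.initialize()."""
    return {
        "coordinator_address": COORDINATOR,
        "num_processes": NUM_PROCESSES,
        "process_id": int(env["JOB_COMPLETION_INDEX"]),
    }


def main():
    # Imported here so distributed_config() can be used without JAX installed.
    import jax

    jax.distributed.initialize(**distributed_config())
    print(f"JAX process {jax.process_index()}/{jax.process_count()} initialized")
    print("JAX local devices:", jax.local_devices())


# Join the distributed runtime only when running inside an indexed Job pod.
if __name__ == "__main__" and "JOB_COMPLETION_INDEX" in os.environ:
    main()
```

Deriving the process ID from the pod's own index, rather than from a hard-coded value, lets a single container image and a single Job spec serve all 32 pods.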
Run kubectl get pods to check the status.
```
$ kubectl get pods

NAME                       READY   STATUS              RESTARTS   AGE
jax-hello-world-0-zmcrr    0/1     ContainerCreating   0          5s
jax-hello-world-1-gw4c5    0/1     ContainerCreating   0          5s
jax-hello-world-10-7f467   0/1     ContainerCreating   0          4s
```
The status changes from ContainerCreating to Pending (after a few minutes), then Running, and finally Completed.
Once the job has completed, use kubectl logs to see the output from one pod
```
$ kubectl logs jax-hello-world-28-56r9p

I0301 22:16:00.772733 140398129055552 distributed.py:79] Connecting to JAX distributed service on 10.68.4.3:1234
…
Coordinator host name: jax-hello-world-0.headless-svc
Coordinator IP address: 10.68.4.3
JAX process 28/32 initialized on jax-hello-world-28
JAX global devices:[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=2, process_index=2, slice_index=2), StreamExecutorGpuDevice(id=3, process_index=3, slice_index=3), … StreamExecutorGpuDevice(id=31, process_index=31, slice_index=3)]
JAX local devices:[StreamExecutorGpuDevice(id=28, process_index=28, slice_index=1)]
[32.]
```
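The command above inspects a single pod. To check all of them, a small script can loop over the pods and compare the last line of each log. This is a sketch that assumes kubectl is already configured for the cluster and that, as in the sample log, [32.] is the last line each pod prints; the function names and the jax-hello-world- prefix filter are illustrative.

```python
import subprocess

EXPECTED_LAST_LINE = "[32.]"  # final line of the sample log for one pod


def last_nonempty_line(text):
    """Return the last non-blank line of a pod's log ("" if there is none)."""
    lines = [line for line in text.splitlines() if line.strip()]
    return lines[-1].strip() if lines else ""


def pods_with_unexpected_output(prefix="jax-hello-world-"):
    """Return (pod, last line) pairs for pods whose log does not end with [32.]."""
    # `kubectl get pods -o name` prints one "pod/<name>" entry per line.
    names = subprocess.run(
        ["kubectl", "get", "pods", "-o", "name"],
        capture_output=True, text=True, check=True,
    ).stdout.split()
    bad = []
    for name in names:
        if not name.startswith("pod/" + prefix):
            continue  # skip pods that do not belong to this job
        log = subprocess.run(
            ["kubectl", "logs", name],
            capture_output=True, text=True, check=True,
        ).stdout
        last = last_nonempty_line(log)
        if last != EXPECTED_LAST_LINE:
            bad.append((name, last))
    return bad
```

Once the job has completed, calling pods_with_unexpected_output() from your workstation should return an empty list; any entries point to pods worth inspecting with kubectl logs.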
The application creates an array of length 1 equal to [1.0] on each process and then sums these arrays across all processes. With 32 processes, every process should print [32.].
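The reduction can be sketched in a few lines of JAX. This is an illustrative reconstruction, not the code from the repository: it assumes jax.distributed.initialize() has already been called in each process, and it uses jax.pmap with jax.lax.psum, which is one of several ways JAX can express a sum across all devices. The function name global_sum_of_ones is hypothetical.

```python
import jax
import jax.numpy as jnp


def global_sum_of_ones():
    """Sum a [1.0] contribution from every device across all processes."""
    # One 1.0 per local device; each pod in this example owns a single GPU.
    x = jnp.ones((jax.local_device_count(),))
    # psum over the pmapped axis reduces across every device of every process
    # that joined the distributed runtime, not just the local ones.
    return jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(x)


if __name__ == "__main__":
    # With 32 processes holding one GPU each, every process prints [32.].
    print(global_sum_of_ones())
```

Run as a single process on a machine without a distributed runtime, the same function simply sums over the local devices, which makes it easy to try out before deploying to the cluster.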
Congratulations! You just ran JAX on 32 NVIDIA A100 GPUs in GKE. Next, learn how to run inference at scale with TensorRT on NVIDIA T4 GPUs.
Special thanks to Jarek Kazmierczak, Google Machine Learning Solution Architect, and Iris Liu, NVIDIA System Software Engineer, for their expertise and guidance on this blog post.
Cloud Blog