Sunday, December 3, 2023

Getting started with JAX multi-node applications with NVIDIA GPUs on Google Kubernetes Engine

JAX is a rapidly growing Python library for high-performance numerical computing and machine learning (ML) research. With applications in large language models, drug discovery, physics ML, reinforcement learning, and neural graphics, JAX has seen wide adoption over the past few years. JAX offers numerous benefits for developers and researchers, including an easy-to-use NumPy API, automatic differentiation, and optimization. JAX also supports distributed processing across multi-node and multi-GPU systems in a few lines of code, with accelerated performance through XLA-optimized kernels on NVIDIA GPUs.
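To make the NumPy-style API and automatic differentiation concrete, here is a minimal sketch (not taken from the sample repository) that writes a function with `jax.numpy`, differentiates it with `jax.grad`, and compiles the result with `jax.jit`. The function name `sum_of_squares` is illustrative only.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # Written as it would be with NumPy, but using jax.numpy.
    return jnp.sum(x ** 2)


# jax.grad returns a new function computing d(sum_of_squares)/dx = 2 * x;
# jax.jit compiles it with XLA for the available backend (CPU, GPU or TPU).
grad_fn = jax.jit(jax.grad(sum_of_squares))

x = jnp.array([1.0, 2.0, 3.0])
print(grad_fn(x))  # [2. 4. 6.]
```

The same code runs unchanged on a laptop CPU or on the A100 GPUs used later in this post; XLA picks the backend.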

This post shows how to run multi-node, multi-GPU JAX applications on GKE (Google Kubernetes Engine) using the A2 Ultra machine series, powered by NVIDIA A100 80GB Tensor Core GPUs. The example runs a simple Hello World application on 4 nodes, each running 8 processes on 8 GPUs, for 32 processes and 32 GPUs in total.


Install gcloud and set up your environment by running gcloud init and following the prompts.

Install Docker and log in to Google Container Registry using the gcloud credential helper.

Install kubectl and the kubectl authentication plugin for Google Cloud (gke-gcloud-auth-plugin).

Set up a GKE cluster

Clone the repository

```
$ git clone
$ cd ai-ml/gke-a100-jax
```

Enable the required APIs

```
$ gcloud services enable container.googleapis.com
$ gcloud services enable
```

Create a default VPC (if it doesn’t already exist)

```
$ gcloud compute networks create default --subnet-mode=auto --bgp-routing-mode=regional --mtu=1460
```

Create a cluster (the control nodes). Replace us-central1-c with your preferred zone.

```
$ gcloud container clusters create jax-example --zone=us-central1-c
```

Create a node pool (the compute nodes). The --enable-fast-socket and --enable-gvnic flags are required for multi-node performance. The --preemptible flag removes the need for quotas but makes the nodes preemptible; remove the flag if this is not desirable. Replace us-central1-c with your preferred zone. This might take a few minutes.

```
$ gcloud container node-pools create gpus-node-pool \
    --machine-type=a2-ultragpu-8g \
    --cluster=jax-example \
    --enable-fast-socket \
    --enable-gvnic \
    --zone=us-central1-c \
    --num-nodes=4 \
    --preemptible
```

Install the NVIDIA CUDA driver on the compute nodes

```
$ kubectl apply -f
```

Build the container image and push it to your registry. This pushes an image to <your project>/jax/hello:latest and might take a few minutes.

```
$ bash
```

In kubernetes/job.yaml and kubernetes/kustomization.yaml, replace <<PROJECT>> with your Google Cloud project ID.
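Editing the two files by hand works fine. If you prefer to script the substitution, a minimal Python sketch could look like this; the helper names `fill_placeholder` and `set_project` are hypothetical, and it assumes you run it from the ai-ml/gke-a100-jax directory so the kubernetes/ paths resolve.

```python
from pathlib import Path

PLACEHOLDER = "<<PROJECT>>"
# The two files named in the instructions, relative to ai-ml/gke-a100-jax.
MANIFESTS = ("kubernetes/job.yaml", "kubernetes/kustomization.yaml")


def fill_placeholder(text, project_id):
    """Return (new_text, number_of_replacements) for one file's contents."""
    return text.replace(PLACEHOLDER, project_id), text.count(PLACEHOLDER)


def set_project(project_id, root=".", files=MANIFESTS):
    """Rewrite each manifest in place; return the total replacements made."""
    total = 0
    for name in files:
        path = Path(root) / name
        new_text, count = fill_placeholder(path.read_text(), project_id)
        path.write_text(new_text)
        total += count
    return total
```

For example, `set_project("my-project-id")` rewrites both files and returns how many placeholders it replaced; a second run returns 0 because nothing is left to substitute, which makes it a quick check that the edit already happened.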


Run the JAX application on the compute nodes. This will create 32 pods (8 per node), each running one JAX process on one NVIDIA GPU.

```
$ cd kubernetes
$ kubectl apply -k .
```
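Each of the 32 processes must know its own process ID, the total number of processes, and the address of the coordinator (process 0). The log output further down shows processes connecting to jax-hello-world-0.headless-svc, which suggests pod 0 is reached through a headless Service. The sketch below shows one way a pod could derive these values. It assumes the pods belong to an Indexed Job, for which Kubernetes sets the JOB_COMPLETION_INDEX environment variable; the port 1234 and the NUM_PROCESSES variable are hypothetical. Check kubernetes/job.yaml for how the sample actually wires this up.

```python
import os

# Job and Service names taken from the sample's log output;
# the port is a hypothetical choice.
JOB_NAME = "jax-hello-world"
HEADLESS_SERVICE = "headless-svc"
COORDINATOR_PORT = 1234


def distributed_config(env=None):
    """Return (coordinator_address, num_processes, process_id) for this pod.

    Assumes an Indexed Job, where Kubernetes injects JOB_COMPLETION_INDEX
    (0..N-1) into each pod. NUM_PROCESSES is a hypothetical variable that
    defaults to 32 (4 nodes x 8 GPUs, one process per GPU).
    """
    env = os.environ if env is None else env
    process_id = int(env["JOB_COMPLETION_INDEX"])
    num_processes = int(env.get("NUM_PROCESSES", "32"))
    if not 0 <= process_id < num_processes:
        raise ValueError(f"process_id {process_id} not in [0, {num_processes})")
    # Pod 0 hosts the coordinator; with the headless Service as the pods'
    # subdomain, its DNS name is <hostname>.<service>.
    coordinator = f"{JOB_NAME}-0.{HEADLESS_SERVICE}:{COORDINATOR_PORT}"
    return coordinator, num_processes, process_id
```

Inside the container, these three values would be passed to `jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)`. Deriving everything from the pod's index keeps the 32 pods identical apart from that one injected number.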


Use kubectl get pods to check the status of the pods:

```
$ kubectl get pods
NAME                       READY   STATUS              RESTARTS   AGE
jax-hello-world-0-zmcrr    0/1     ContainerCreating   0          5s
jax-hello-world-1-gw4c5    0/1     ContainerCreating   0          5s
jax-hello-world-10-7f467   0/1     ContainerCreating   0          4s
```

The STATUS column will change from ContainerCreating to Pending (after a few minutes), then Running, and finally Completed.
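With 32 pods, reading the table by eye gets tedious. A small Python helper can tally the STATUS column instead. It assumes the default `kubectl get pods` column layout shown above (NAME, READY, STATUS, ...) and the `jax-hello-world-` name prefix; the function names are illustrative.

```python
from collections import Counter

EXPECTED_PODS = 32  # 4 nodes x 8 GPUs, one pod per GPU


def count_statuses(kubectl_output, prefix="jax-hello-world-"):
    """Tally the STATUS column (third field) of `kubectl get pods` output."""
    counts = Counter()
    for line in kubectl_output.splitlines():
        fields = line.split()
        # Skip the header row, blank lines, and pods from other workloads.
        if len(fields) >= 3 and fields[0].startswith(prefix):
            counts[fields[2]] += 1
    return counts


def all_completed(counts, expected=EXPECTED_PODS):
    """True once every expected pod reports Completed."""
    return counts["Completed"] == expected
```

You could feed it the output of `subprocess.run(["kubectl", "get", "pods"], capture_output=True, text=True).stdout` and poll until `all_completed` returns True.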

Once the job has completed, use kubectl logs to see the output from one pod:

```
$ kubectl logs jax-hello-world-28-56r9p

I0301 22:16:00.772733 140398129055552] Connecting to JAX distributed service on host name: jax-hello-world-0.headless-svc
Coordinator IP address: process 28/32 initialized on jax-hello-world-28
JAX global devices:[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=1, slice_index=1), StreamExecutorGpuDevice(id=2, process_index=2, slice_index=2), StreamExecutorGpuDevice(id=3, process_index=3, slice_index=3), … StreamExecutorGpuDevice(id=31, process_index=31, slice_index=3)]
JAX local devices:[StreamExecutorGpuDevice(id=28, process_index=28, slice_index=1)]
[32.]
```

Each process creates an array of length 1 equal to [1.0], and the application then sums these arrays across all processes. With 32 processes, every process should print [32.].
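The sample's application source isn't reproduced in this post. Below is a minimal sketch of a program with the same behavior, modeled on the multi-process example in the JAX documentation: each process contributes one [1.0] per local GPU, and `jax.lax.psum` inside `jax.pmap` sums across every device in every process. The environment variable names (COORDINATOR_ADDRESS, NUM_PROCESSES, PROCESS_ID) are hypothetical; when they are unset, the script runs as a single process, which is handy for local testing.

```python
import os

import jax
import jax.numpy as jnp


def maybe_initialize_distributed():
    """Join the multi-process runtime if a coordinator is configured.

    COORDINATOR_ADDRESS, NUM_PROCESSES and PROCESS_ID are hypothetical
    variable names; the sample's manifests may pass this differently.
    """
    coordinator = os.environ.get("COORDINATOR_ADDRESS")
    if coordinator is None:
        return  # Single-process run, e.g. on a workstation.
    jax.distributed.initialize(
        coordinator_address=coordinator,
        num_processes=int(os.environ["NUM_PROCESSES"]),
        process_id=int(os.environ["PROCESS_ID"]),
    )


def all_reduce_ones():
    # One [1.0] row per local device: shape (1, 1) with one GPU per process.
    xs = jnp.ones((jax.local_device_count(), 1))
    # Under pmap, psum over axis "i" sums across every device of every
    # process, so each entry becomes the global device count.
    return jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)


if __name__ == "__main__":
    maybe_initialize_distributed()  # Must run before any other JAX call.
    print("JAX global devices:", jax.devices())
    print("JAX local devices:", jax.local_devices())
    print(all_reduce_ones()[0])  # [32.] with 32 single-GPU processes
```

With 32 pods each owning one GPU, `jax.local_device_count()` is 1 and `jax.device_count()` is 32, so every process prints [32.], matching the log above. `pmap` keeps the sketch short; the same reduction can also be written with `jax.jit` and sharded arrays.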

Congratulations! You just ran JAX on 32 NVIDIA A100 GPUs in GKE. Next, learn how to run inference at scale with TensorRT on NVIDIA T4 GPUs.

Special thanks to Jarek Kazmierczak, Google Machine Learning Solution Architect, and Iris Liu, NVIDIA System Software Engineer, for their expertise and guidance on this blog post.
