Reinforcement Learning on Multi-Host TPUs
This tutorial provides step-by-step instructions for setting up the environment and training the Llama3.1 70B-IT model on the GSM8K math reasoning dataset using Pathways for orchestration on multi-host TPU-VMs, such as v5p-128.
We use two RL algorithms, implemented in the Tunix library, to enhance the model’s reasoning capabilities:
Group Relative Policy Optimization (GRPO): GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group’s performance to update the policy.
Group Sequence Policy Optimization (GSPO): GSPO is an RL algorithm that improves the training efficiency and performance of LLMs by computing importance ratios at the sequence level instead of the token level. GSPO defines the importance ratio based on sequence likelihood and performs clipping, rewarding, and optimization at the sequence level.
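For reference, the core quantities of both algorithms can be written as follows. This is a sketch of the formulations in the original GRPO (DeepSeekMath) and GSPO papers, not a transcription of the Tunix implementation: KL penalties and other details may differ, and the loss_algo=gspo-token setting used later may correspond to the paper's token-level GSPO-token variant rather than the sequence-level objective shown here. G is the number of responses sampled per prompt x, r_i is the reward of response y_i, and ε is the clipping range.

```latex
% GRPO: each response's advantage is its reward standardized within its group of G responses
\hat{A}_i = \frac{r_i - \operatorname{mean}\left(\{r_j\}_{j=1}^{G}\right)}{\operatorname{std}\left(\{r_j\}_{j=1}^{G}\right)}

% GSPO: length-normalized, sequence-level importance ratio for response y_i
s_i(\theta) = \left( \frac{\pi_\theta(y_i \mid x)}{\pi_{\theta_{\text{old}}}(y_i \mid x)} \right)^{1/|y_i|}
= \exp\!\left( \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_\theta(y_{i,t} \mid x, y_{i,<t})}{\pi_{\theta_{\text{old}}}(y_{i,t} \mid x, y_{i,<t})} \right)

% GSPO objective: clipping acts on the whole-sequence ratio s_i, not on per-token ratios
\mathcal{J}_{\text{GSPO}}(\theta) = \mathbb{E}\left[ \frac{1}{G} \sum_{i=1}^{G} \min\!\left( s_i(\theta)\,\hat{A}_i,\; \operatorname{clip}\!\left(s_i(\theta),\, 1-\varepsilon,\, 1+\varepsilon\right) \hat{A}_i \right) \right]
```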
For efficient model inference and response generation during this process, we rely on the vLLM library.
Prerequisites
Before starting, ensure you have:
Access to a Google Cloud Project with TPU quotas.
A Hugging Face account with an access token for downloading models.
Permissions for Google Artifact Registry (Artifact Registry Writer role).
XPK installed (follow official documentation).
A Pathways-ready GKE cluster (see create GKE cluster).
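Before continuing, you can confirm that the command-line tools used in this tutorial are on your PATH. The helper below, check_tools, is a hypothetical convenience function (not part of MaxText or XPK); it only checks that each named command resolves, not that it is configured or at a supported version.

```shell
#!/usr/bin/env bash
# check_tools: hypothetical helper, not part of MaxText or XPK.
# Prints "found"/"missing" for each command name and returns 1 if any is missing.
check_tools() {
  local tool missing=0
  for tool in "$@"; do
    if command -v "${tool}" >/dev/null 2>&1; then
      echo "found: ${tool}"
    else
      echo "missing: ${tool}" >&2
      missing=1
    fi
  done
  return "${missing}"
}

# Usage (tool names taken from the commands in this tutorial):
# check_tools gcloud docker kubectl xpk
```

The check does not verify versions, so it will not catch an XPK release older than the one required for automatic zone discovery.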
Set Up Environment Variables
Set up the following environment variables. Replace placeholders with your actual values.
# -- Model configuration --
export HF_MODEL=<Hugging Face Model> # e.g. 'llama3.1-70b-Instruct'
export MODEL=<MaxText Model> # e.g. 'llama3.1-70b'
export TOKENIZER=<Tokenizer> # e.g. 'meta-llama/Llama-3.1-70B-Instruct'
export HF_TOKEN=<Hugging Face access token>
# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
export WORKLOAD=<Name for this run> # e.g., llama-3-70b-grpo
export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${WORKLOAD}/0/items
# -- Workload configuration --
export TPU_TYPE=<TPU Type> # e.g., 'v5p-128'
export TPU_CLUSTER=<cluster name>
export PROJECT_ID=<GCP project ID>
export CLOUD_IMAGE_NAME=<your artifact registry image> # Name for the Docker image to be built
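Because every later command interpolates these variables, a missing value tends to surface as a confusing failure deep inside a workload. A hypothetical helper such as require_vars (our own name, not part of MaxText) can catch that early. It relies on bash indirect expansion (${!name}), so it needs bash rather than plain sh.

```shell
#!/usr/bin/env bash
# require_vars: hypothetical helper (not part of MaxText or XPK).
# Reports each named variable that is unset or empty; returns 1 if any are.
require_vars() {
  local name missing=0
  for name in "$@"; do
    # ${!name:-} is bash indirect expansion: the value of the variable named by $name.
    if [ -z "${!name:-}" ]; then
      echo "unset or empty: ${name}" >&2
      missing=1
    fi
  done
  return "${missing}"
}

# Usage, with the variable names defined above:
# require_vars HF_MODEL MODEL TOKENIZER HF_TOKEN BASE_OUTPUT_DIRECTORY WORKLOAD \
#   MAXTEXT_CKPT_PATH TPU_TYPE TPU_CLUSTER PROJECT_ID CLOUD_IMAGE_NAME
```

This only detects empty values; it cannot tell whether a placeholder was left unedited.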
Get Your Model Checkpoint
Option 1: Using an existing MaxText checkpoint
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
Option 2: Converting from a Hugging Face checkpoint
Refer to the steps in Hugging Face to MaxText to convert a Hugging Face checkpoint to MaxText format. Make sure the checkpoint files are converted and saved correctly. As in Option 1, set the following environment variable and move on.
export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
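Whichever option you choose, it is worth confirming that MAXTEXT_CKPT_PATH points at a GCS location you can read before submitting a long-running job. The sketch below uses gcloud storage ls; the function name check_ckpt is hypothetical, and a successful listing only shows that the path exists and is readable, not that it contains a valid MaxText checkpoint.

```shell
#!/usr/bin/env bash
# check_ckpt: hypothetical helper (not part of MaxText).
# Returns 0 if the GCS path can be listed, 1 if not, 2 if it is not a gs:// path.
check_ckpt() {
  local path="$1"
  case "${path}" in
    gs://*) ;;
    *)
      echo "not a GCS path: ${path}" >&2
      return 2
      ;;
  esac
  if gcloud storage ls "${path}" >/dev/null 2>&1; then
    echo "readable: ${path}"
  else
    echo "missing or not readable: ${path}" >&2
    return 1
  fi
}

# Usage:
# check_ckpt "${MAXTEXT_CKPT_PATH}"
```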
Build and upload MaxText Docker image with post-training dependencies
Before building the Docker image, authenticate to Google Artifact Registry so that you have permission to push your images and access other resources.
# Authenticate your user account for gcloud CLI access
gcloud auth login
# Configure application default credentials for Docker and other tools
gcloud auth application-default login
# Configure Docker credentials and test your access
gcloud auth configure-docker
docker run hello-world
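Since the image path used by the workload commands is gcr.io/$PROJECT_ID/..., it can help to make sure the active gcloud project matches PROJECT_ID so that gcloud commands run against the intended project. The helper name ensure_project is hypothetical; it only wraps gcloud config get-value project and gcloud config set project.

```shell
#!/usr/bin/env bash
# ensure_project: hypothetical helper (not part of MaxText or XPK).
# Switches the active gcloud project to the given ID if it is not already active.
ensure_project() {
  local want="$1" have
  have="$(gcloud config get-value project 2>/dev/null)"
  if [ "${have}" = "${want}" ]; then
    echo "gcloud project already set to ${want}"
  else
    echo "switching gcloud project from '${have}' to '${want}'"
    gcloud config set project "${want}"
  fi
}

# Usage:
# ensure_project "${PROJECT_ID}"
```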
Option 1: Install stable releases of post-training dependencies
Caution: RL in MaxText is currently broken with stable releases of post-training dependencies. We are working on fixing this and recommend following Option 2: Install from Git repositories of post-training dependencies in the meantime.
Run the following script to create a Docker image with stable releases of MaxText, Tunix, vLLM, and tpu-inference dependencies. This installs vllm-tpu, which provides TPU inference for vLLM with unified JAX and PyTorch support. The build process takes approximately 10-15 minutes.
bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training
For experimental features (such as the improved pathwaysutils resharding API), use:
bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training-experimental
Option 2: Install from Git repositories of post-training dependencies
You can also locally clone the tunix, tpu-inference, and vllm repositories and then build the Docker image with these local sources. To get a set of compatible commit IDs for maxtext, tunix, tpu-inference, and vllm, follow these steps:
Navigate to the MaxText Package Tests GitHub Actions workflow.
Select the latest successful run.
Within the workflow run, find and click on the maxtext_jupyter_notebooks (py312) job, then expand the run job.
Locate the Record Commit IDs step. The commit SHAs for maxtext, tunix, tpu-inference, and vllm that were used in that successful run are listed in the logs of this step.
Prior to installation, ensure that the maxtext, tunix, vllm, and tpu-inference repositories are synchronized to the specific commits recorded from the CI logs. For each repository, use the following command to switch to the correct commit: git checkout <commit_id>.
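The per-repository git checkout can be scripted once you have the four SHAs from the Record Commit IDs step. The sketch below assumes the repositories are already cloned as siblings under one parent directory, with directory names maxtext, tunix, tpu-inference, and vllm; the function name and the repo=sha argument format are our own illustration.

```shell
#!/usr/bin/env bash
# checkout_recorded_commits: hypothetical helper (not part of MaxText).
# Arguments: a parent directory, then one "repo=sha" pair per repository.
# Assumes each repository is cloned at <parent>/<repo>.
checkout_recorded_commits() {
  local parent="$1" pair repo sha
  shift
  for pair in "$@"; do
    repo="${pair%%=*}"   # text before the first "="
    sha="${pair#*=}"     # text after the first "="
    echo "==> ${repo} @ ${sha}"
    git -C "${parent}/${repo}" checkout "${sha}" || return 1
  done
}

# Usage from inside the maxtext directory (replace each <sha> with the recorded value):
# checkout_recorded_commits "$(dirname "$PWD")" \
#   maxtext=<sha> tunix=<sha> tpu-inference=<sha> vllm=<sha>
```

The function stops at the first repository whose checkout fails, so a typo in one SHA does not leave the later repositories silently on the wrong commit.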
Note: Clone these repositories as siblings of the maxtext directory (e.g.,
in the same parent directory). After cloning, run the build from inside the
maxtext repository so it picks up the local sources:
bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training POST_TRAINING_SOURCE=local
Upload the Docker Image
Note: You will need the Artifact Registry Writer role to push Docker images to your project’s Artifact Registry. Contact your project administrator if you don’t have this permission.
bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}
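After the upload finishes, you can check that the image is visible at the path the workload commands reference (gcr.io/$PROJECT_ID/$CLOUD_IMAGE_NAME). This sketch uses gcloud container images describe; the helper name is hypothetical, and the check assumes the upload script pushes to that same gcr.io path, which we infer from the --docker-image flag in the workload commands rather than from the script itself.

```shell
#!/usr/bin/env bash
# check_image: hypothetical helper (not part of MaxText).
# Arguments: project ID, image name. Checks that gcr.io/<project>/<image> can be described.
check_image() {
  local uri="gcr.io/${1}/${2}"
  if gcloud container images describe "${uri}" >/dev/null 2>&1; then
    echo "image available: ${uri}"
  else
    echo "image not found or not readable: ${uri}" >&2
    return 1
  fi
}

# Usage:
# check_image "${PROJECT_ID}" "${CLOUD_IMAGE_NAME}"
```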
Submit your RL workload via Pathways
See the Troubleshooting section for concise instructions on how to retry or resume a failed workload.
Ensure you have a Pathways-ready GKE cluster (as mentioned in Prerequisites) and
submit the train_rl.py script via XPK.
Note: XPK v0.14.0+ automatically discovers your cluster’s location from GCP, so you don’t need to specify --zone in the commands below. If you are using an older XPK version, add --zone=<zone> to the workload commands.
Submit GRPO workload
xpk workload create-pathways --workload $WORKLOAD \
--docker-image gcr.io/$PROJECT_ID/$CLOUD_IMAGE_NAME --cluster $TPU_CLUSTER \
--tpu-type=$TPU_TYPE --num-slices=1 \
--project=$PROJECT_ID --priority=high \
--command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
model_name=${MODEL} \
tokenizer_path=${TOKENIZER} \
load_parameters_path=${MAXTEXT_CKPT_PATH} \
run_name=${WORKLOAD} \
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
hf_access_token=${HF_TOKEN}"
Submit GSPO workload
xpk workload create-pathways --workload $WORKLOAD \
--docker-image gcr.io/$PROJECT_ID/$CLOUD_IMAGE_NAME --cluster $TPU_CLUSTER \
--tpu-type=$TPU_TYPE --num-slices=1 \
--project=$PROJECT_ID --priority=high \
--command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
model_name=${MODEL} \
tokenizer_path=${TOKENIZER} \
load_parameters_path=${MAXTEXT_CKPT_PATH} \
run_name=${WORKLOAD} \
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
hf_access_token=${HF_TOKEN} \
loss_algo=gspo-token"
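The GRPO and GSPO commands differ only in the trailing loss_algo=gspo-token override, so you may prefer to wrap them in one function. The submit_rl helper below is a hypothetical sketch that rebuilds the same xpk workload create-pathways invocation from the environment variables set earlier. It takes the workload name explicitly, which also makes it easy to give GRPO and GSPO runs distinct names.

```shell
#!/usr/bin/env bash
# submit_rl: hypothetical wrapper (not part of MaxText or XPK) around the
# `xpk workload create-pathways` commands above.
# Arguments: workload name, then optional extra train_rl overrides
# (e.g. loss_algo=gspo-token). Reads the environment variables set earlier.
submit_rl() {
  local workload="$1"
  shift
  local overrides="$*"
  # Backslash-newline inside double quotes is removed, so cmd is a single line.
  local cmd="HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy \
JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
model_name=${MODEL} tokenizer_path=${TOKENIZER} \
load_parameters_path=${MAXTEXT_CKPT_PATH} run_name=${workload} \
base_output_directory=${BASE_OUTPUT_DIRECTORY} hf_access_token=${HF_TOKEN} ${overrides}"
  xpk workload create-pathways --workload "${workload}" \
    --docker-image "gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}" --cluster "${TPU_CLUSTER}" \
    --tpu-type="${TPU_TYPE}" --num-slices=1 \
    --project="${PROJECT_ID}" --priority=high \
    --command "${cmd}"
}

# Usage:
# submit_rl "${WORKLOAD}"                              # same as the GRPO command above
# submit_rl "${WORKLOAD}-gspo" loss_algo=gspo-token    # same as the GSPO command above
```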
Managing Workloads
Monitor workload status:
Check Pathways job status: kubectl get pathwaysjob
Check pod status: kubectl get pods
Delete a workload: To remove a failed or unwanted Pathways job, use XPK:
xpk workload delete \
  --workload $WORKLOAD \
  --cluster $TPU_CLUSTER \
  --project $PROJECT_ID
If the job still lingers, run kubectl get pods to find the pod name, then run kubectl delete pod <pod-name>.
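If several pods linger, the kubectl get pods and kubectl delete pod steps can be combined. The sketch below deletes every pod whose name contains the workload name; that matching rule is our assumption about how Pathways pod names are derived, so review the pod list first. Note that a substring match on my-run also matches my-run-retry1.

```shell
#!/usr/bin/env bash
# delete_lingering_pods: hypothetical helper (not part of XPK).
# Deletes every pod whose name contains the given workload name (fixed-string match).
delete_lingering_pods() {
  local workload="$1" pod
  # `kubectl get pods -o name` prints one "pod/<name>" per line.
  for pod in $(kubectl get pods -o name | grep -F -- "${workload}"); do
    echo "deleting ${pod}"
    kubectl delete "${pod}"
  done
}

# Usage:
# delete_lingering_pods "${WORKLOAD}"
```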
Troubleshooting
Authentication Issues: Ensure your HF_TOKEN environment variable is set correctly and has access to the required models.
Resource Quotas: Verify you have sufficient TPU quotas in your GCP project.
Docker Build Failures: Check that all dependencies are correctly installed and authentication is configured.
Workload Failures: Review the logs for specific error messages and ensure all environment variables are properly set.
Workload retry / resume:
Retry (fresh run): Use a unique workload name to avoid overwriting outputs:
export WORKLOAD=${WORKLOAD}-retry1
export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${WORKLOAD}/0/items
Then submit the XPK workload. If a “workload already exists” error occurs, pick a new name or list existing jobs with kubectl get pathwaysjob.
Resume from checkpoint: Keep the same WORKLOAD and, in the submit command, replace load_parameters_path=${MAXTEXT_CKPT_PATH} with the saved checkpoint path, e.g. load_parameters_path=${MAXTEXT_CKPT_PATH}/checkpoint-0000. Then submit the workload again.
Tip: Verify the checkpoint exists in GCS with read access before resuming.
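To avoid the “workload already exists” error when retrying, you can pick the first -retryN suffix that kubectl does not report as an existing Pathways job. next_workload_name is a hypothetical helper; it only checks kubectl get pathwaysjob <name> in the current kubectl context, so it does not notice names whose outputs already exist in BASE_OUTPUT_DIRECTORY.

```shell
#!/usr/bin/env bash
# next_workload_name: hypothetical helper (not part of XPK).
# Prints <base>-retryN for the smallest N >= 1 with no existing pathwaysjob of that name.
next_workload_name() {
  local base="${1%-retry*}" n=1   # drop a trailing "-retry..." suffix, if any
  while kubectl get pathwaysjob "${base}-retry${n}" >/dev/null 2>&1; do
    n=$((n + 1))
  done
  echo "${base}-retry${n}"
}

# Usage:
# export WORKLOAD="$(next_workload_name "${WORKLOAD}")"
```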
For more detailed troubleshooting, refer to the MaxText documentation and XPK documentation.