Reinforcement Learning on Multi-Host TPUs

This tutorial provides step-by-step instructions for setting up the environment and training the Llama3.1 70B-IT model on the GSM8K math reasoning dataset using Pathways for orchestration on multi-host TPU-VMs, such as v5p-128.

This tutorial uses two RL algorithms, both implemented in the Tunix library, to improve the model’s reasoning capabilities:

  • Group Relative Policy Optimization (GRPO): GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the separate value function (critic) model. For each prompt, GRPO generates a group of responses and scores each one with a reward model. It then computes each response's advantage relative to the rewards of the other responses in the group and uses these advantages to update the policy.

  • Group Sequence Policy Optimization (GSPO): GSPO is an RL algorithm that improves the training efficiency and performance of LLMs by operating at the sequence level rather than the token level. It defines the importance ratio from the likelihood of the whole sequence, and performs clipping, rewarding, and optimization per sequence.
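
For reference, the quantities that distinguish the two algorithms can be written as follows. The notation follows the papers that introduced GRPO and GSPO:

  • $x$ is a prompt, and $y_1,\dots,y_G$ is a group of $G$ sampled responses with rewards $r_1,\dots,r_G$.
  • $\pi_\theta$ is the current policy, and $\pi_{\theta_\text{old}}$ is the policy that generated the responses.
  • $\varepsilon$ is the clipping range.

This is a sketch of the published formulations, not of Tunix's code. Tunix's implementation may differ in details such as KL regularization and loss normalization. That includes the `gspo-token` variant selected later in this tutorial with `loss_algo=gspo-token`.

```latex
% Group-relative advantage (shared by GRPO and GSPO)
\hat{A}_i = \frac{r_i - \operatorname{mean}(\{r_j\}_{j=1}^{G})}{\operatorname{std}(\{r_j\}_{j=1}^{G})}

% GRPO: token-level importance ratio, clipped per token
w_{i,t}(\theta) = \frac{\pi_\theta(y_{i,t} \mid x, y_{i,<t})}{\pi_{\theta_\text{old}}(y_{i,t} \mid x, y_{i,<t})}

% GSPO: sequence-level importance ratio (length-normalized)
s_i(\theta) = \left(\frac{\pi_\theta(y_i \mid x)}{\pi_{\theta_\text{old}}(y_i \mid x)}\right)^{1/|y_i|}
            = \exp\!\left(\frac{1}{|y_i|}\sum_{t=1}^{|y_i|} \log w_{i,t}(\theta)\right)

% GSPO objective: clipping is applied once per sequence
\mathcal{J}_\text{GSPO}(\theta) = \mathbb{E}\left[\frac{1}{G}\sum_{i=1}^{G}
  \min\Big(s_i(\theta)\,\hat{A}_i,\;
  \operatorname{clip}\big(s_i(\theta),\, 1-\varepsilon,\, 1+\varepsilon\big)\,\hat{A}_i\Big)\right]
```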

We use the vLLM library for efficient inference, that is, for generating the responses that are scored during RL training.

Prerequisites

Before starting, ensure you have:

  • Access to a Google Cloud Project with TPU quotas.

  • A Hugging Face account with an access token for downloading models.

  • Permissions for Google Artifact Registry (Artifact Registry Writer role).

  • Prerequisites for XPK installed (follow official documentation).

  • A Pathways-ready GKE cluster (see create GKE cluster).

  • Docker installed and configured for sudoless use. Follow the steps to configure sudoless Docker.
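
A quick way to confirm the command-line prerequisites is a small Bash check like the hypothetical `check_tools` helper below. It is not part of XPK or gcloud.

The tool list in the usage comment is an assumption about what this tutorial invokes: gcloud, kubectl, xpk, and docker. Running `docker info` without `sudo` succeeds only if your user can reach the Docker daemon, which makes it a rough test of the sudoless setup.

```bash
# Hypothetical helper (not part of XPK or gcloud): reports which of the given
# command-line tools are missing from PATH and returns non-zero if any are.
check_tools() {
  local tool missing=()
  for tool in "$@"; do
    command -v "$tool" >/dev/null 2>&1 || missing+=("$tool")
  done
  if (( ${#missing[@]} > 0 )); then
    echo "missing tools: ${missing[*]}" >&2
    return 1
  fi
}

# Example: check the tools, then confirm Docker works without sudo.
# check_tools gcloud kubectl xpk docker && docker info >/dev/null
```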

Build and upload MaxText Docker image

For instructions on building and uploading the MaxText Docker image with post-training dependencies, please refer to the official documentation.

Set Up Environment Variables

Set up the following environment variables to configure your training run. Replace placeholders with your actual values.

# -- Model configuration --
# The MaxText model name. See `ModelName` in `src/maxtext/configs/types.py` for the
# full list of supported models.
export MODEL=<MODEL_NAME> # e.g. 'llama3.1-70b-Instruct'

# Your Hugging Face access token. Required to download gated models like Llama.
# You can generate one at https://huggingface.co/settings/tokens.
export HF_TOKEN=<HF_TOKEN>

# -- MaxText configuration --
# Use a GCS bucket you own to store logs and checkpoints, ideally in the same
# region as your TPUs to minimize latency and costs.
# You can list your buckets and their locations in the
# [Cloud Console](https://console.cloud.google.com/storage/browser).
export BASE_OUTPUT_DIRECTORY=<GCS_BUCKET> # e.g., gs://my-bucket/maxtext-runs

# An arbitrary string to identify this specific run.
# We recommend including the model, user, and timestamp.
# Note: Kubernetes requires workload names to be valid DNS labels (lowercase, no underscores or periods).
export RUN_NAME=<RUN_NAME>

# The directory containing the MaxText-compatible model checkpoint.
# If you are converting from a Hugging Face checkpoint, see:
# [Checkpoint Conversion Guide](../../guides/checkpointing_solutions/convert_checkpoint.md)
export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY?}/${RUN_NAME?}/0/items

# -- Workload configuration --
# Your GCP project ID. Find it on the [Cloud Console Dashboard](https://console.cloud.google.com/home/dashboard).
# If you've already set it in your local config, you can retrieve it via:
# gcloud config get-value project
export PROJECT_ID=<PROJECT_ID>

# The GCP location (listed as "Location" in the UI) and name of your
# TPU-enabled GKE cluster. Both can be found on the
# [Cloud Console](https://console.cloud.google.com/kubernetes/list).
export ZONE=<ZONE> # e.g., 'us-central1' or 'us-central1-a'
export GKE_CLUSTER=<CLUSTER_NAME>

# For a full list of MaxText-supported TPU types, see: `src/maxtext/utils/accelerator_to_spec_map.py`. To see the TPU type
# of your cluster:

# 1. Connect to the cluster (required for kubectl commands later):
# gcloud container clusters get-credentials ${GKE_CLUSTER?} --location ${ZONE?} --project ${PROJECT_ID?}

# 2. Find your TPU type (e.g., 'v5p-128') by checking the accelerator labels on your nodes:
# kubectl get nodes -l cloud.google.com/gke-tpu-accelerator -o jsonpath='{.items[*].metadata.labels.cloud\.google\.com/gke-tpu-accelerator}' | tr ' ' '\n' | sort -u
export TPU_TYPE=<TPU_TYPE>

# The Docker image you built and pushed in the "Build and upload MaxText Docker image" step
export CLOUD_IMAGE_NAME=<IMAGE_NAME>
export DOCKER_IMAGE="gcr.io/${PROJECT_ID?}/${CLOUD_IMAGE_NAME?}"
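
Before moving on, it can help to catch unset placeholders and invalid run names early. The sketch below assumes Bash and defines a hypothetical pair of helpers, `check_vars` and `is_dns_label`; neither is part of MaxText or XPK.

`is_dns_label` checks `RUN_NAME` against the RFC 1123 DNS-label rules: at most 63 characters, made of lowercase alphanumerics or `-`, and starting and ending with an alphanumeric. XPK may impose a stricter length limit on workload names, so treat this as a minimum check.

```bash
# Hypothetical preflight helpers (not part of MaxText or XPK).

# Returns non-zero if any named variable is unset, empty, or still a <PLACEHOLDER>.
check_vars() {
  local name value status=0
  for name in "$@"; do
    value="${!name-}"   # indirect expansion: the value of the variable named $name
    if [[ -z "$value" || "$value" == "<"*">" ]]; then
      echo "error: ${name} is unset or still a placeholder" >&2
      status=1
    fi
  done
  return "$status"
}

# Returns zero if $1 is a valid RFC 1123 DNS label: at most 63 characters,
# lowercase letters, digits, or '-', starting and ending with a letter or digit.
is_dns_label() {
  local re='^[a-z0-9]([-a-z0-9]*[a-z0-9])?$'
  [[ ${#1} -le 63 && "$1" =~ $re ]]
}

# Example:
# check_vars MODEL HF_TOKEN BASE_OUTPUT_DIRECTORY RUN_NAME MAXTEXT_CKPT_PATH \
#   PROJECT_ID ZONE GKE_CLUSTER TPU_TYPE DOCKER_IMAGE && is_dns_label "${RUN_NAME?}"
```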

Get Your Model Checkpoint

Option 1: Using an existing MaxText checkpoint

If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items

Option 2: Converting from a Hugging Face checkpoint

Follow the steps in the Hugging Face to MaxText guide to convert a Hugging Face checkpoint to the MaxText format, and confirm that the converted checkpoint files were saved. Then, as in Option 1, set the following environment variable and move on.

export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
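
Whichever option you use, confirming before submission that the checkpoint path is a GCS URI you can read can save a failed job. The hypothetical `check_ckpt_path` helper below is not part of MaxText.

It first checks the URI format. If `gcloud` is installed, it then lists the path with `gcloud storage ls`, which fails when nothing matches or when you lack read access.

```bash
# Hypothetical helper (not part of MaxText): sanity-checks a checkpoint path.
check_ckpt_path() {
  local ckpt="$1"
  # Require a gs://bucket/... style URI with something after the scheme.
  if [[ "$ckpt" != gs://?* ]]; then
    echo "error: '${ckpt}' is not a gs:// URI" >&2
    return 1
  fi
  # Existence and read-access check, skipped when gcloud is not installed.
  if ! command -v gcloud >/dev/null 2>&1; then
    echo "warning: gcloud not found; skipped existence check for ${ckpt}" >&2
    return 0
  fi
  if ! gcloud storage ls "${ckpt%/}/" >/dev/null; then
    echo "error: cannot list ${ckpt}; check that it exists and is readable" >&2
    return 1
  fi
}

# Example:
# check_ckpt_path "${MAXTEXT_CKPT_PATH?}"
```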

Submit your RL workload via Pathways

See the Troubleshooting section for concise instructions on how to retry or resume a failed workload.

Ensure you have a Pathways-ready GKE cluster (as mentioned in Prerequisites) and submit the train_rl.py script via XPK.

Note: The commands below pass --zone explicitly. XPK v0.14.0+ automatically discovers your cluster’s location from GCP, so with those versions you can omit --zone. Older XPK versions require it.

Submit GRPO workload

xpk workload create-pathways --workload ${RUN_NAME?} \
--docker-image ${DOCKER_IMAGE?} --cluster ${GKE_CLUSTER?} \
--tpu-type=${TPU_TYPE?} --num-slices=1 \
--project=${PROJECT_ID?} --priority=high \
--zone=${ZONE?} \
--command "HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
python3 -m maxtext.trainers.post_train.rl.train_rl \
  model_name=${MODEL?} \
  load_parameters_path=${MAXTEXT_CKPT_PATH?} \
  run_name=${RUN_NAME?} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
  hf_access_token=${HF_TOKEN?}"

Submit GSPO workload

xpk workload create-pathways --workload ${RUN_NAME?} \
--docker-image ${DOCKER_IMAGE?} --cluster ${GKE_CLUSTER?} \
--tpu-type=${TPU_TYPE?} --num-slices=1 \
--project=${PROJECT_ID?} --priority=high \
--zone=${ZONE?} \
--command "HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
python3 -m maxtext.trainers.post_train.rl.train_rl \
  model_name=${MODEL?} \
  load_parameters_path=${MAXTEXT_CKPT_PATH?} \
  run_name=${RUN_NAME?} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
  hf_access_token=${HF_TOKEN?} \
  loss_algo=gspo-token"
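
The GRPO and GSPO submissions differ only in the trailing `loss_algo=gspo-token` argument. If you switch between them often, a hypothetical Bash helper such as `build_rl_command` can assemble the `--command` string once. It is not part of MaxText or XPK.

The helper reproduces the arguments of the two commands and appends `loss_algo=gspo-token` only when asked for GSPO. Printing the string before submitting doubles as a dry run, but note that the output includes your Hugging Face token.

```bash
# Hypothetical helper (not part of MaxText or XPK): builds the --command string
# used in the submissions above. "gspo" appends loss_algo=gspo-token; any other
# value yields the GRPO command, which sets no loss_algo.
build_rl_command() {
  local algo="$1" cmd
  cmd="HF_TOKEN=${HF_TOKEN?} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy"
  cmd+=" JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1'"
  cmd+=" python3 -m maxtext.trainers.post_train.rl.train_rl"
  cmd+=" model_name=${MODEL?} load_parameters_path=${MAXTEXT_CKPT_PATH?}"
  cmd+=" run_name=${RUN_NAME?} base_output_directory=${BASE_OUTPUT_DIRECTORY?}"
  cmd+=" hf_access_token=${HF_TOKEN?}"
  if [[ "$algo" == "gspo" ]]; then
    cmd+=" loss_algo=gspo-token"
  fi
  printf '%s\n' "$cmd"
}

# Example (same XPK flags as the submissions above):
# xpk workload create-pathways --workload ${RUN_NAME?} \
#   --docker-image ${DOCKER_IMAGE?} --cluster ${GKE_CLUSTER?} \
#   --tpu-type=${TPU_TYPE?} --num-slices=1 \
#   --project=${PROJECT_ID?} --priority=high --zone=${ZONE?} \
#   --command "$(build_rl_command gspo)"
```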

Managing Workloads

  • Monitor workload status: Check the Pathways job status with kubectl get pathwaysjob and the pod status with kubectl get pods.

  • Delete a workload: To remove a failed or unwanted Pathways job, use XPK:

    xpk workload delete \
        --workload ${RUN_NAME?} \
        --cluster ${GKE_CLUSTER?} \
        --project ${PROJECT_ID?}
    

    If the job's pods still linger after deletion, find the pod name with kubectl get pods and then run: kubectl delete pod <POD_NAME>.
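
When a job lingers, a hypothetical filter like `pods_for_workload` can narrow `kubectl get pods --no-headers` output to the pods of one run. It is not part of kubectl or XPK.

The filter assumes that a workload's pod names start with the workload name followed by `-`. Because it matches prefixes only, a run named `my-run` would also match `my-run-retry1`, so review the list before deleting anything.

```bash
# Hypothetical filter (not part of kubectl or XPK). Reads `kubectl get pods
# --no-headers` output on stdin and prints the names of pods whose name starts
# with "<workload>-". ASSUMPTION: pod names are prefixed with the workload name.
pods_for_workload() {
  awk -v prefix="$1-" 'index($1, prefix) == 1 { print $1 }'
}

# Example: inspect first, then delete the matching pods.
# kubectl get pods --no-headers | pods_for_workload "${RUN_NAME?}"
# kubectl get pods --no-headers | pods_for_workload "${RUN_NAME?}" |
#   while read -r pod; do kubectl delete pod "$pod"; done
```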

Troubleshooting

  • Authentication Issues: Ensure your HF_TOKEN environment variable is set correctly and has access to the required models.

  • Resource Quotas: Verify you have sufficient TPU quotas in your GCP project.

  • Docker Build Failures: Check that all dependencies are correctly installed and authentication is configured.

  • Workload Failures: Review the logs for specific error messages and ensure all environment variables are properly set.

  • Workload retry / resume:

    • Retry (fresh run): Use a unique run name to avoid overwriting outputs. Run export RUN_NAME=${RUN_NAME?}-retry1, then export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY?}/${RUN_NAME?}/0/items, and submit the XPK workload. If a “workload already exists” error occurs, pick a new name, or list the existing jobs with kubectl get pathwaysjob.

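    • Resume from checkpoint: Keep the same RUN_NAME and pass the checkpoint to resume from as the load_parameters_path argument of the workload command, e.g., load_parameters_path=${MAXTEXT_CKPT_PATH?}/checkpoint-0000. Then submit the workload again, deleting the previous workload first if it still exists.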

    • Tip: Verify the checkpoint exists in GCS with read access before resuming.
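
To generate unique retry names without tracking suffixes by hand, a hypothetical helper like `next_retry_name` can append or increment a `-retryN` suffix. It is not part of MaxText or XPK. The suffix format simply follows the `-retry1` example in the retry step.

```bash
# Hypothetical helper (not part of MaxText or XPK): derives a fresh run name,
# e.g. my-run -> my-run-retry1 and my-run-retry1 -> my-run-retry2.
next_retry_name() {
  local re='^(.*)-retry([0-9]+)$'
  if [[ "$1" =~ $re ]]; then
    # 10# forces base 10 so a suffix such as "08" is not parsed as octal.
    echo "${BASH_REMATCH[1]}-retry$(( 10#${BASH_REMATCH[2]} + 1 ))"
  else
    echo "$1-retry1"
  fi
}

# Example:
# export RUN_NAME="$(next_retry_name "${RUN_NAME?}")"
```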

For more detailed troubleshooting, refer to the MaxText documentation and XPK documentation.