Elastic training with Pathways

This guide shows how to run elastic training on a multi-slice TPU cluster: training that survives a slice failure in-process, without restarting the job. You launch a Qwen3 0.6B run across several TPU slices with Pathways, lose a slice mid-run, and watch training recover from the last checkpoint on the same controller.

Important

This guide is a demonstration of the elastic training mechanism, not a production recipe. It uses a small model (Qwen3 0.6B) and synthetic data so you can see recovery happen on a short run, then tear everything down. The exact slice counts, timeouts, and checkpoint cadence here are illustrative; tune them for your own model and hardware. Treat it as a starting point to understand the feature, not a configuration to copy verbatim into a long-running job.

What is elastic training?

Large model training runs across many TPU slices. When one slice fails (a hardware fault, a preemption, a network blip), the default outcome is that the whole job crashes and restarts: it pays for XLA compilation again and loses all progress since the last checkpoint.

Elastic training keeps the training process alive instead. Three components make that possible:

  • Pathways orchestrates training across the slices. Its Resource Manager detects when a slice goes down and reports it to the training process.

  • MaxText wraps the training loop with elastic_retry. When Pathways reports a failure, the wrapper catches the exception inside the same Python process, cleans up, and restarts training without exiting.

  • Orbax handles checkpointing. Each checkpoint writes to GCS and creates a commit_success marker only after all data is flushed, so a checkpoint interrupted mid-write has no marker and is safely discarded on recovery.
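
The commit-marker idea in the Orbax bullet can be illustrated with a small local-filesystem sketch. This is not Orbax's implementation or API: the function names (save_checkpoint, latest_committed_step), the data.bin file name, and the directory layout are hypothetical, and a temporary local directory stands in for GCS. Only the commit_success marker name comes from the text above.

```python
import tempfile
from pathlib import Path
from typing import Optional

MARKER = "commit_success"  # marker name from the guide; everything else is assumed


def save_checkpoint(root: Path, step: int, payload: bytes, fail_mid_write: bool = False) -> None:
    """Write the checkpoint data first and the marker as the very last action."""
    step_dir = root / str(step)
    step_dir.mkdir(parents=True, exist_ok=True)
    (step_dir / "data.bin").write_bytes(payload)  # hypothetical file name
    if fail_mid_write:
        # Simulate a failure after some data landed but before the commit.
        return
    (step_dir / MARKER).touch()


def latest_committed_step(root: Path) -> Optional[int]:
    """Return the highest step whose directory contains the marker, or None."""
    committed = [
        int(d.name)
        for d in root.iterdir()
        if d.is_dir() and d.name.isdigit() and (d / MARKER).exists()
    ]
    return max(committed, default=None)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    save_checkpoint(root, 100, b"weights@100")
    save_checkpoint(root, 200, b"partial", fail_mid_write=True)
    # Step 200 has data but no marker, so recovery ignores it.
    restored_from = latest_committed_step(root)
```

Because the marker is written last, a reader never has to inspect the data to decide whether a checkpoint is complete; the presence of the marker is the whole test.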

Because the controller process never exits, the expensive XLA recompile is skipped and recovery is fast.
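
The in-process retry described above can be sketched in a few lines. This is a toy model, not MaxText's elastic_retry: the exception class SliceDownError, the function elastic_retry_sketch, and the fake trainer are all hypothetical names invented for illustration, and the real wrapper also handles cleanup and waiting for slices, which is elided here.

```python
class SliceDownError(RuntimeError):
    """Hypothetical stand-in for the failure Pathways reports to the trainer."""


def elastic_retry_sketch(train_fn, max_attempts):
    """Call train_fn until it returns, retrying in the same process on SliceDownError."""
    for attempt in range(1, max_attempts + 1):
        try:
            return train_fn(attempt)
        except SliceDownError:
            # Real cleanup (releasing buffers, waiting for slices) is elided.
            if attempt == max_attempts:
                raise


def make_fake_trainer(fail_at_step, checkpoint_period, total_steps):
    """Build a toy training loop that fails once and resumes from its last commit."""
    state = {"committed": 0, "failed_once": False, "history": []}

    def train(attempt):
        step = state["committed"]  # resume from the last committed checkpoint
        while step < total_steps:
            step += 1
            state["history"].append((attempt, step))
            if step % checkpoint_period == 0:
                state["committed"] = step
            if step == fail_at_step and not state["failed_once"]:
                state["failed_once"] = True
                raise SliceDownError(f"slice lost at step {step}")
        return step

    return train, state


train, state = make_fake_trainer(fail_at_step=150, checkpoint_period=100, total_steps=300)
final_step = elastic_retry_sketch(train, max_attempts=10)
first_step_after_retry = next(s for a, s in state["history"] if a == 2)
```

The key property is that the retry happens inside one call stack: nothing outside elastic_retry_sketch sees the failure unless the attempts run out.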

Note

This demo shows recovery via checkpoint restore on a fixed mesh: when a slice is lost, Pathways waits for a replacement, then all slices restore from the last committed checkpoint. It does not show elastic degradation (continuing on fewer slices at reduced throughput), which requires dynamic mesh resize and is not covered here.

1. Prerequisites

This guide assumes you already have a Pathways-enabled GKE cluster created with xpk, and a MaxText Docker image in your Artifact Registry. If you don’t:

  1. Install XPK and create a Pathways GKE cluster. Follow Running MaxText with XPK and the Pathways & XPK cluster guide. Cluster creation and management are out of scope for this page.

  2. Build and upload the MaxText Docker image. See Build MaxText.

Note

If you installed xpk inside a Python virtual environment (venv), reactivate it (e.g., source <VENV_NAME>/bin/activate) in any new terminal before running xpk commands, or you will hit a Command xpk not found error.

2. Environment configuration

Set these environment variables in your shell. Replace the placeholders with your own values.

# Google Cloud Configuration
export PROJECT_ID=<GCP project ID>
export ZONE=<GCP location>        # e.g., 'us-central1'
export GKE_CLUSTER=<cluster name> # your Pathways-enabled cluster

# Workload Configuration
# Kubernetes requires workload names to be valid DNS labels (lowercase, no underscores/periods).
export RUN_NAME="elastic-qwen3-$(date +%Y%m%d-%H%M%S)"

# TPU type and slice count. For supported types see src/maxtext/utils/accelerator_to_spec_map.py.
export TPU_TYPE="v5litepod-16"  # one slice = 16 v5e chips
export NUM_SLICES=3             # total slices in the run

# MaxText & Storage Configuration
export BASE_OUTPUT_DIRECTORY=<gcs bucket path>  # e.g., gs://my-bucket/maxtext-runs
export DOCKER_IMAGE="gcr.io/${PROJECT_ID?}/<your maxtext image>"
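
The DNS-label constraint on RUN_NAME can be checked before submitting. The sketch below tests a name against the general RFC 1123 label rule (1 to 63 characters, lowercase alphanumerics or '-', alphanumeric at both ends). The helper name is_valid_dns_label is hypothetical, and the sketch assumes nothing about any additional limits your tooling may impose on workload names.

```python
import re

# RFC 1123 label: 1-63 chars, lowercase alphanumerics or '-',
# starting and ending with an alphanumeric character.
_DNS_LABEL = re.compile(r"[a-z0-9]([-a-z0-9]{0,61}[a-z0-9])?")


def is_valid_dns_label(name: str) -> bool:
    """Return True if name is a valid DNS label under the rule above."""
    return _DNS_LABEL.fullmatch(name) is not None


checks = {
    name: is_valid_dns_label(name)
    for name in (
        "elastic-qwen3-20250101-120000",  # shape produced by the RUN_NAME line
        "Elastic_Qwen3.run",              # uppercase, underscore, period
        "-leading-dash",                  # must start with an alphanumeric
    )
}
```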

3. Launch the elastic workload

Submit the run with xpk workload create-pathways. Two sets of flags make it elastic:

  • On the xpk side, --elastic-slices tells Pathways how many slices the workload is allowed to lose and keep going, and --max-slice-restarts caps how many times a slice’s workers may be restarted.

  • On the MaxText side (inside --command), elastic_enabled=true turns on the elastic_retry wrapper, and enable_single_controller=True runs training through Pathways. checkpoint_period is kept small so a recovery rewinds only a little.

xpk workload create-pathways \
  --workload=${RUN_NAME?} \
  --cluster=${GKE_CLUSTER?} \
  --project=${PROJECT_ID?} \
  --zone=${ZONE?} \
  --tpu-type=${TPU_TYPE?} \
  --num-slices=${NUM_SLICES?} \
  --docker-image=${DOCKER_IMAGE?} \
  --elastic-slices=1 \
  --max-slice-restarts=10 \
  --command="python3 -m maxtext.trainers.pre_train.train \
    src/maxtext/configs/base.yml \
    base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
    run_name=${RUN_NAME?} \
    model_name=qwen3-0.6b \
    dataset_type=synthetic \
    per_device_batch_size=1 \
    max_target_length=2048 \
    attention=flash \
    remat_policy=full \
    steps=5000 \
    enable_checkpointing=true \
    checkpoint_period=100 \
    enable_single_controller=True \
    elastic_enabled=true \
    elastic_timeout_seconds=300 \
    elastic_max_retries=10"

Note

--elastic-slices=1 means the run tolerates losing one slice at a time out of ${NUM_SLICES}. Keep --max-slice-restarts and the MaxText elastic_max_retries consistent with how many failures you want to ride out.

Warning

Do not enable profiling in an elastic run. An elastic event (a slice going down and recovering) in the middle of a profile is not supported, so this example leaves the profiler off (profiler is unset). Profile a separate, non-elastic run if you need performance traces.

Watch training start

List the workload and follow its logs through the Cloud Console (Kubernetes Engine → Workloads → your run → Logs), or:

xpk workload list --cluster=${GKE_CLUSTER?} --project=${PROJECT_ID?} --zone=${ZONE?}

After XLA compilation (a couple of minutes) you should see messages confirming that elastic training is enabled, followed by a steady stream of steps:

Elastic utils: Elastic training enabled.
Elastic Retry Enabled
completed step: 8, seconds: 0.159, TFLOP/s/device: 43.430, loss: 220.774
completed step: 9, seconds: 0.166, TFLOP/s/device: 41.524, loss: 217.296

Let it run until the step counter is well past the first checkpoint before you inject a failure (for example, step ~130 with checkpoint_period=100, which leaves time for the step-100 checkpoint to commit), so there is a complete checkpoint to recover from.
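
A quick back-of-the-envelope calculation shows how much work a failure costs for a given checkpoint_period. The sketch assumes that checkpoints land on exact multiples of checkpoint_period and that the most recent one has already committed when the failure hits; the function names are hypothetical. The 0.16 seconds per step is rounded from the sample log above, and the estimate excludes the time spent waiting for a replacement slice and restoring the checkpoint.

```python
def last_checkpoint_step(current_step: int, checkpoint_period: int) -> int:
    """Largest multiple of checkpoint_period that is <= current_step (0 if none yet)."""
    return (current_step // checkpoint_period) * checkpoint_period


def steps_to_redo(failed_step: int, checkpoint_period: int) -> int:
    """Steps that must be repeated after rewinding to the last checkpoint."""
    return failed_step - last_checkpoint_step(failed_step, checkpoint_period)


# Failure at step 150 with checkpoint_period=100 rewinds to step 100.
redo = steps_to_redo(150, 100)

# Rough compute time lost, using ~0.16 s/step from the sample log.
seconds_lost = redo * 0.16
```

With these numbers the rewind repeats 50 steps, or roughly 8 seconds of compute, which is why a small checkpoint_period suits a short demo.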

4. Simulate a slice failure

To see recovery, remove a worker on one slice. Connect to the cluster and delete a worker pod immediately (--grace-period=0 --force), so it does not drain gracefully. This mimics an abrupt hardware failure rather than a clean shutdown:

gcloud container clusters get-credentials ${GKE_CLUSTER?} --location ${ZONE?} --project ${PROJECT_ID?}

# Pick a worker pod on one slice and remove it immediately.
WORKER=$(kubectl get pods -o name | grep "${RUN_NAME?}" | grep worker | head -1)
kubectl delete ${WORKER?} --grace-period=0 --force

Warning

This deliberate pod deletion is only for observing recovery in this demo. Do not remove pods this way against a real training job.

5. Verify in-process recovery

Recovery shows up in the same controller log you were already watching, which is the point: the controller process never exited. Within seconds of the termination you should see Pathways report the slice down and elastic_retry restore the last committed checkpoint:

Slice down event detected. Retrying.
Found commit_success file. Keeping gs://.../checkpoints/100/.
Elastic attempt 2 out of 10
Restoring checkpoint from gs://.../checkpoints/100.
completed step: 101, ...

The step counter dropping (for example 150 -> 101) is the rewind to the last committed checkpoint. Training then continues from there on the same controller, with no JobSet restart: the slice failure cost a short rewind instead of a full job restart.
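
If you save the controller log to a file, the rewind can be found mechanically by looking for a step number that does not increase. The sketch below assumes only the "completed step: N" line format shown in this guide; the function name find_rewinds is hypothetical, and the sample lines are abbreviated placeholders in the guide's format rather than real output.

```python
import re

STEP_RE = re.compile(r"completed step: (\d+)")


def find_rewinds(log_lines):
    """Return (previous_step, resumed_step) pairs where the step counter went backwards."""
    rewinds = []
    previous = None
    for line in log_lines:
        match = STEP_RE.search(line)
        if match is None:
            continue  # not a step line
        step = int(match.group(1))
        if previous is not None and step <= previous:
            rewinds.append((previous, step))
        previous = step
    return rewinds


# Abbreviated placeholder lines in the guide's log format.
sample_log = [
    "completed step: 149, ...",
    "completed step: 150, ...",
    "Slice down event detected. Retrying.",
    "Elastic attempt 2 out of 10",
    "completed step: 101, ...",
    "completed step: 102, ...",
]

rewinds = find_rewinds(sample_log)
```

Each pair corresponds to one recovery, so the length of the result is a rough count of elastic events in the run.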

6. Clean up

Delete the workload to stop the meter. TPU slices are expensive, so don’t skip this.

xpk workload delete --workload=${RUN_NAME?} --cluster=${GKE_CLUSTER?} --project=${PROJECT_ID?} --zone=${ZONE?}

If you created the cluster only for this demo, delete it too (see the XPK documentation for xpk cluster delete).

Going further

  • The elastic flags are documented in src/maxtext/configs/base.yml: elastic_enabled, elastic_timeout_seconds, elastic_max_retries, plus enable_single_controller (runs training through Pathways) and checkpoint_period.

  • A larger model changes the checkpoint size that streams through Pathways during recovery; size the controller and adjust checkpoint_period accordingly.

  • Custom Pathways server args can be passed through xpk with --custom-pathways-proxy-server-args if you need finer control than --elastic-slices exposes.

More information