Via localhost or single-host VM#

Objective#

This guide provides comprehensive instructions for setting up MaxText on a local machine or single-host environment, covering everything from cloning the repo and dependency installation to building with Docker. By walking through the process of pre-training a small model, you will gain the foundational knowledge to run jobs on TPUs/GPUs.

Prerequisites#

Before you can begin a training run, you need to configure your storage environment and set up the basic MaxText configuration.

Set up a Google Cloud Storage bucket#

You’ll need a GCS bucket to store all your training artifacts, such as logs, metrics, and model checkpoints.

  1. In your Google Cloud project, create a new storage bucket.

  2. Your TPU or GPU VMs require read/write access to this bucket. The simplest way to grant this is by assigning the Storage Admin (roles/storage.admin) role to the service account associated with your VMs.
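
The two steps above can be sketched with the gcloud CLI. This is a minimal sketch, not the only way to do it: the bucket name, location, and service account below are hypothetical placeholders, and the `run` helper is an illustrative dry-run wrapper that only prints each command unless `DRY_RUN=0` is set.

```shell
# Hypothetical values; replace them with your own project settings.
BUCKET="your-bucket-name"
LOCATION="us-central1"
SERVICE_ACCOUNT="my-vm-sa@my-project.iam.gserviceaccount.com"

# Illustrative helper: prints the command by default,
# and executes it only when DRY_RUN=0 is set.
run() {
  if [ "${DRY_RUN:-1}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Step 1: create the bucket.
run gcloud storage buckets create "gs://${BUCKET}" --location="${LOCATION}"

# Step 2: grant the VM service account the Storage Admin role on the bucket.
run gcloud storage buckets add-iam-policy-binding "gs://${BUCKET}" \
  --member="serviceAccount:${SERVICE_ACCOUNT}" \
  --role="roles/storage.admin"
```

Review the printed commands first, then rerun with `DRY_RUN=0` once the values match your project.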

Set up MaxText#

MaxText uses a primary YAML file, configs/base.yml, to manage its settings. This default configuration sets up a Llama 2-style decoder-only model with approximately 1 billion parameters.

  • Before running your first model, take a moment to review this file. Pay special attention to these core settings:

    • run_name: The name for your experiment.

    • per_device_batch_size: Controls how many examples are processed per chip. You may need to lower this for larger models to avoid running out of memory.

    • max_target_length: The maximum sequence length for the model.

    • learning_rate: The core hyperparameter for the optimizer.

    • Model shape parameters: base_num_decoder_layers, base_emb_dim, base_num_query_heads, base_num_kv_heads, and head_dim.

  • Override settings (optional): You can modify training parameters in two ways: by editing configs/base.yml directly, or by passing them as command-line arguments to the training script, which is the recommended method. Overrides use the key=value form. For example, to change the number of training steps, pass steps=500 to the training command.

  • Note: You must set base_output_directory, which is initialized in configs/base.yml, to a folder within the GCS bucket you just created (e.g., gs://your-bucket-name/maxtext-output).
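
As an illustration of command-line overrides, the sketch below assembles a training command from key=value arguments and checks that base_output_directory is a GCS path before printing it. The `check_gcs_path` and `train_cmd` helpers and the variable values are hypothetical and not part of MaxText; the command is printed rather than executed.

```shell
# Hypothetical values for illustration.
RUN_NAME="my-first-run"
BASE_OUTPUT_DIRECTORY="gs://your-bucket-name/maxtext-output"

# Illustrative helper: succeed only for paths of the form gs://<something>.
check_gcs_path() {
  case "$1" in
    gs://?*) return 0 ;;
    *) echo "expected a gs:// path, got: $1" >&2; return 1 ;;
  esac
}

# Print the training command with overrides; remove "echo" to run it on the VM.
train_cmd() {
  check_gcs_path "${BASE_OUTPUT_DIRECTORY}" || return 1
  echo python3 -m maxtext.trainers.pre_train.train \
    run_name="${RUN_NAME}" \
    base_output_directory="${BASE_OUTPUT_DIRECTORY}" \
    per_device_batch_size=1 \
    steps=500
}

train_cmd
```

Keeping overrides on the command line leaves configs/base.yml untouched, so each run records exactly which settings differed from the defaults.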

Development#

Developing on a single-host TPU/GPU VM is a convenient way to learn MaxText, although it does not scale to multiple hosts. The following sections describe how to run MaxText on TPU/GPU VMs.

Run MaxText on single host VM#

  1. Create a single-host VM of your choice and SSH into it. You can use any available single-host TPU, such as v5litepod-8, v5p-8, or v4-8. For GPUs, you can use nvidia-h100-mega-80gb, nvidia-h200-141gb, or nvidia-b200. To set up a TPU VM, follow the Cloud TPU documentation at https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm. For a GPU setup, refer to the guide at https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus.

  2. For instructions on installing MaxText on your VM, please refer to the official documentation.
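
Once installation has finished, a quick sanity check is to confirm that JAX can see the accelerators on the VM. This is a minimal sketch that assumes JAX was installed as a MaxText dependency; `check_devices` is a hypothetical helper name, and it is only defined here, not called.

```shell
# Illustrative helper: print the number of accelerator devices JAX detects,
# followed by the device list. Assumes python3 and jax are on the PATH
# of the environment where MaxText was installed.
check_devices() {
  python3 -c 'import jax; print(jax.device_count(), jax.devices())'
}
```

Run `check_devices` on the VM. If the list shows only CPU devices, the accelerator runtime is probably not set up correctly and a training job would not use the TPU or GPU.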

Run a Test Training Job#

After the installation is complete, run a short training job using synthetic data to confirm everything is working correctly. This command trains a model for just 10 steps. Remember to replace $YOUR_JOB_NAME with a unique name for your run and gs://<my-bucket> with the path to the GCS bucket you configured in the prerequisites.

python3 -m maxtext.trainers.pre_train.train \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  dataset_type=synthetic \
  steps=10
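
The `${YOUR_JOB_NAME?}` form in the command above is standard shell parameter expansion: the shell aborts with an error if the variable is unset, so a forgotten job name cannot silently become an empty run_name. The sketch below demonstrates this behavior; the timestamp-based name is just one hypothetical way to keep run names unique.

```shell
# With the variable unset, the expansion fails. The subshell keeps the
# failure from terminating the current shell.
unset YOUR_JOB_NAME
( echo "run_name=${YOUR_JOB_NAME?}" ) 2>/dev/null || echo "YOUR_JOB_NAME is not set"

# Hypothetical naming scheme: a fixed prefix plus a timestamp.
YOUR_JOB_NAME="test-run-$(date +%Y%m%d-%H%M%S)"
echo "run_name=${YOUR_JOB_NAME?}"
```

Note that `${VAR?}` only guards against an unset variable; use `${VAR:?}` if an empty value should also be rejected.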

Optional: If you want to try training on a real dataset, see Data Input Pipeline for data input options from sources like HuggingFace, Grain, and TFDS.

Generate sample output (decoding)#

To demonstrate model output, run the following command:

python3 -m maxtext.inference.decode \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  per_device_batch_size=1

Note: Because the model hasn’t been properly trained, the output text will be random. To generate meaningful output, you need to load a trained checkpoint using the load_parameters_path argument.
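
A sketch of the same decode command with load_parameters_path added is shown below. The checkpoint location is a hypothetical placeholder, `decode_cmd` is an illustrative helper, and a real checkpoint may need further arguments that match how the model was trained; the command is printed rather than executed.

```shell
# Hypothetical checkpoint location; substitute the path of a real trained checkpoint.
LOAD_PARAMETERS_PATH="gs://your-bucket-name/path/to/checkpoint"
YOUR_JOB_NAME="${YOUR_JOB_NAME:-decode-test}"

# Print the decode command for review; remove "echo" to run it on the VM.
decode_cmd() {
  echo python3 -m maxtext.inference.decode \
    run_name="${YOUR_JOB_NAME}" \
    base_output_directory="gs://your-bucket-name" \
    per_device_batch_size=1 \
    load_parameters_path="${LOAD_PARAMETERS_PATH}"
}

decode_cmd
```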

Running models using provided configs#

MaxText provides configs for many open-source models that you can use directly to run training jobs on those model-specific architectures. These model-specific YAML files are located in src/maxtext/configs/models for TPU-oriented defaults, and src/maxtext/configs/models/gpu for GPU-oriented defaults.

Training on TPUs#

To use a pre-configured model for TPUs, you override the model_name parameter, and MaxText will automatically load the corresponding configuration from the src/maxtext/configs/models directory and merge it with the settings from src/maxtext/configs/base.yml.

llama3-8b (TPU):
python3 -m maxtext.trainers.pre_train.train \
  model_name=llama3-8b \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  dataset_type=synthetic \
  steps=10

qwen3-4b (TPU):
python3 -m maxtext.trainers.pre_train.train \
  model_name=qwen3-4b \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  dataset_type=synthetic \
  steps=10
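
To see which values model_name might accept, you can list the YAML files in the models directory. The sketch below assumes that each model_name corresponds to a file name without its .yml extension, which matches the two examples above but is not confirmed here for every config; `list_model_names` is a hypothetical helper.

```shell
# Illustrative helper: print the base name of every .yml file in a config directory.
# The default directory is the TPU-oriented models path, relative to the repo root.
list_model_names() {
  config_dir="${1:-src/maxtext/configs/models}"
  for f in "${config_dir}"/*.yml; do
    # Skip the unexpanded pattern when the directory has no .yml files.
    [ -e "$f" ] || continue
    basename "$f" .yml
  done
}
```

Run `list_model_names` from the root of the MaxText checkout, or pass another directory as the first argument.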

Training on GPUs#

To use a GPU-optimized configuration, you should specify the path to the model’s YAML file within the src/maxtext/configs/models/gpu directory as the main config file in the command. These files typically inherit from base.yml and set the appropriate model_name internally, as well as GPU-specific settings.

mixtral-8x7b (GPU):
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/gpu/models/mixtral_8x7b.yml \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  dataset_type=synthetic \
  steps=10

This will load gpu/mixtral_8x7b.yml, which inherits from base.yml.

llama3-8b (GPU):
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/gpu/models/llama3-8b.yml \
  run_name=${YOUR_JOB_NAME?} \
  base_output_directory=gs://<my-bucket> \
  dataset_type=synthetic \
  steps=10