SFT on single-host TPUs#
Supervised fine-tuning (SFT) is a process where a pre-trained large language model is fine-tuned on a labeled dataset to adapt the model to perform better on specific tasks.
This tutorial demonstrates step-by-step instructions for setting up the environment and then training the model on a Hugging Face dataset using SFT.
We use Tunix, a JAX-based library designed for post-training tasks, to perform SFT.
This tutorial uses a single-host TPU VM, such as v6e-8 or v5p-8. Let’s get started!
Install MaxText and Post-Training dependencies#
For instructions on installing MaxText with post-training dependencies on your VM, please refer to the official documentation and use the maxtext[tpu-post-train] installation path to include all necessary post-training dependencies.
Note: If you have previously installed MaxText with a different option (e.g., maxtext[tpu]), we strongly recommend using a fresh virtual environment for maxtext[tpu-post-train] to avoid potential library version conflicts.
Set up environment variables#
Log in to Hugging Face. Provide your access token when prompted:
hf auth login
Set up the following environment variables to configure your training run. Replace placeholders with your actual values.
# -- Model configuration --
# The MaxText model name. See `ModelName` in `src/maxtext/configs/types.py` for the
# full list of supported models.
export MODEL=<MODEL_NAME> # e.g., 'llama3.1-8b-Instruct'
# -- MaxText configuration --
# Use a GCS bucket you own to store logs and checkpoints. Ideally in the same
# region as your TPUs to minimize latency and costs.
# You can list your buckets and their locations in the
# [Cloud Console](https://console.cloud.google.com/storage/browser).
export BASE_OUTPUT_DIRECTORY=<GCS_BUCKET> # e.g., gs://my-bucket/maxtext-runs
# An arbitrary string to identify this specific run.
# We recommend including the model, user, and timestamp.
# Note: Kubernetes requires workload names to be valid DNS labels (lowercase, no underscores or periods).
export RUN_NAME=<RUN_NAME>
export STEPS=<STEPS> # e.g., 1000
export PER_DEVICE_BATCH_SIZE=<BATCH_SIZE_PER_DEVICE> # e.g., 1
# -- Dataset configuration --
export DATASET_NAME=<DATASET_NAME> # e.g., HuggingFaceH4/ultrachat_200k
export TRAIN_SPLIT=<TRAIN_SPLIT> # e.g., train_sft
export TRAIN_DATA_COLUMNS=<DATA_COLUMNS> # e.g., ['messages']
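As a concrete illustration of the naming advice in the comments above, the sketch below builds a run name from the model, user, and a timestamp, then normalizes it into a DNS-label-safe string. `make_run_name` is a hypothetical helper, not a MaxText or Kubernetes utility, and the 63-character cap reflects the DNS label length limit. The last line quotes the `['messages']` example value, because an unquoted assignment would let the shell strip the inner single quotes.

```shell
# Hypothetical helper (not part of MaxText): join model, user, and timestamp,
# lowercase the result, and replace anything outside [a-z0-9-] with a hyphen.
make_run_name() {
  printf '%s-%s-%s\n' "$1" "$2" "$3" \
    | tr '[:upper:]' '[:lower:]' \
    | sed -e 's/[^a-z0-9-]/-/g' -e 's/--*/-/g' -e 's/^-*//' \
    | cut -c1-63 \
    | sed -e 's/-*$//'
}

export MODEL='llama3.1-8b-Instruct'
export RUN_NAME="$(make_run_name "$MODEL" "${USER:-user}" "$(date +%Y%m%d-%H%M%S)")"

# Quote list-valued settings so the shell keeps the inner single quotes.
export TRAIN_DATA_COLUMNS="['messages']"
```

For example, `make_run_name "llama3.1-8b-Instruct" "Alice_B" "20250101-120000"` prints `llama3-1-8b-instruct-alice-b-20250101-120000`.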
Get your model checkpoint#
This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.
Option 1: Using an existing MaxText checkpoint#
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
Option 2: Converting a Hugging Face checkpoint#
Follow the steps in "Hugging Face to MaxText" to convert a Hugging Face checkpoint to the MaxText format. Make sure the converted checkpoint files are complete and saved. As in Option 1, set the following environment variable and move on to the next section.
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
[!IMPORTANT] Matching the scan_layers parameter: The scan_layers setting during your fine-tuning run must match the setting used when creating the checkpoint at MAXTEXT_CKPT_PATH.
If the checkpoint was converted or saved with scan_layers=False (which is common for Hugging Face conversions and inference-ready models), you must also provide scan_layers=False in the MaxText command.
If scan_layers does not match, MaxText will raise a ValueError. See the Checkpoints concept guide for more details.
Run SFT on Hugging Face Dataset#
Now you are ready to run SFT using the following command:
python3 -m maxtext.trainers.post_train.sft.train_sft \
run_name=${RUN_NAME?} \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
model_name=${MODEL?} \
load_parameters_path=${MAXTEXT_CKPT_PATH?} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
steps=${STEPS?} \
hf_path=${DATASET_NAME?} \
train_split=${TRAIN_SPLIT?} \
train_data_columns=${TRAIN_DATA_COLUMNS?} \
profiler=xplane
Your fine-tuned model checkpoints will be saved here: $BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints.
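The `${VAR?}` expansions in the command above stop at the first unset variable. If you would rather see every missing setting in one pass, a small preflight check can help. The sketch below is a hypothetical helper (`check_sft_env` is not part of MaxText) that covers the variables used in this tutorial. Unlike `${VAR?}`, it also treats an empty value as missing.

```shell
# Hypothetical preflight helper (not part of MaxText): report every unset or
# empty variable at once instead of failing on the first one.
check_sft_env() {
  missing=""
  for name in RUN_NAME BASE_OUTPUT_DIRECTORY MODEL MAXTEXT_CKPT_PATH \
              PER_DEVICE_BATCH_SIZE STEPS DATASET_NAME TRAIN_SPLIT \
              TRAIN_DATA_COLUMNS; do
    # Names come from the fixed list above, so eval only sees known identifiers.
    eval "value=\${$name:-}"
    if [ -z "$value" ]; then
      missing="$missing $name"
    fi
  done
  if [ -n "$missing" ]; then
    echo "Unset or empty variables:$missing" >&2
    return 1
  fi
}
```

Run `check_sft_env` before launching training, for example by prefixing the command with `check_sft_env &&`.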
Dataset Customization & Chat Templates#
Supervised Fine-Tuning in MaxText relies on tokenizing conversational datasets using chat templates. This requires the dataset structure and templates to be aligned.
Supported Dataset Schemas#
By default, MaxText SFT expects one of three conversational dataset structures:
- ["messages"]: A single column containing a list of dictionaries with role and content (recommended).
- ["prompt", "completion"]: Separated prompt and completion columns.
- ["question", "answer"]: Question and answer columns (e.g., math datasets).
During data processing, MaxText converts these into a unified messages schema (OpenAI-like format) before feeding it to the tokenizer:
[
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"}
]
Custom Tokenizer Chat Templates#
To customize the tokenizer’s chat formatting (e.g., adding special tokens like <start_of_turn> and <end_of_turn>), you can provide a custom chat template using the chat_template or chat_template_path configs:
- chat_template: Use this config to specify a custom Jinja2 template string directly.
- chat_template_path: Path to a custom Jinja2 template file (e.g., .jinja) or a JSON file containing the template.
- use_chat_template=True: Enables chat template formatting.
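As a rough illustration of the file-based option, the sketch below writes a small Jinja2 template to disk. The template body, the file location, and the `CHAT_TEMPLATE_PATH` variable name are all assumptions made for this example. The turn markers in particular must match the special tokens your model's tokenizer actually defines.

```shell
# Illustrative template only; adjust the turn markers to your model's tokenizer.
# CHAT_TEMPLATE_PATH is a hypothetical variable name used in this sketch.
export CHAT_TEMPLATE_PATH="${TMPDIR:-/tmp}/custom_chat_template.jinja"

# The quoted heredoc delimiter keeps the shell from expanding the template text.
cat > "$CHAT_TEMPLATE_PATH" <<'EOF'
{%- for message in messages -%}
<start_of_turn>{{ message['role'] }}
{{ message['content'] }}<end_of_turn>
{% endfor -%}
EOF
```

Assuming these configs are passed as `key=value` overrides like the other settings in this tutorial, you would then append `chat_template_path=${CHAT_TEMPLATE_PATH?} use_chat_template=True` to the training command.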
Runnable Example in the Codebase#
For a complete, runnable SFT workflow that demonstrates how to configure the training loop and use a custom dataset formatter (formatting_func_path and formatting_func_kwargs), check out the sft_qwen3_demo.ipynb Jupyter notebook.