SFT on single-host TPUs

Supervised fine-tuning (SFT) is the process of further training a pre-trained large language model on a labeled dataset so that it performs better on specific tasks.

This tutorial gives step-by-step instructions for setting up the environment and then fine-tuning a model on a Hugging Face dataset with SFT.

We use Tunix, a JAX-based library designed for post-training tasks, to perform SFT.

This tutorial uses a single-host TPU VM, such as a v6e-8 or v5p-8. Let’s get started!

Install MaxText and Post-Training dependencies

To install MaxText with post-training dependencies on your VM, follow the official installation documentation and choose the maxtext[tpu-post-train] installation path, which includes all necessary post-training dependencies.

Note: If you have previously installed MaxText with a different option (e.g., maxtext[tpu]), we strongly recommend using a fresh virtual environment for maxtext[tpu-post-train] to avoid potential library version conflicts.

Set up environment variables

Log in to Hugging Face and provide your access token when prompted:

hf auth login

Set up the following environment variables to configure your training run. Replace placeholders with your actual values.

# -- Model configuration --
# The MaxText model name. See `ModelName` in `src/maxtext/configs/types.py`
# for the full list of supported models.
export MODEL=<MODEL_NAME> # e.g., 'llama3.1-8b-Instruct'

# -- MaxText configuration --
# Use a GCS bucket you own to store logs and checkpoints, ideally in the same
# region as your TPUs to minimize latency and cost.
# You can list your buckets and their locations in the
# [Cloud Console](https://console.cloud.google.com/storage/browser).
export BASE_OUTPUT_DIRECTORY=<GCS_BUCKET> # e.g., gs://my-bucket/maxtext-runs

# An arbitrary string to identify this specific run.
# We recommend including the model, user, and timestamp.
# Note: Kubernetes requires workload names to be valid DNS labels (lowercase, no underscores or periods).
export RUN_NAME=<RUN_NAME>

export STEPS=<STEPS> # e.g., 1000
export PER_DEVICE_BATCH_SIZE=<BATCH_SIZE_PER_DEVICE> # e.g., 1

# -- Dataset configuration --
export DATASET_NAME=<DATASET_NAME> # e.g., HuggingFaceH4/ultrachat_200k
export TRAIN_SPLIT=<TRAIN_SPLIT> # e.g., train_sft
export TRAIN_DATA_COLUMNS=<DATA_COLUMNS> # e.g., "['messages']" (keep the outer quotes so the shell preserves the inner ones)
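If you would rather generate RUN_NAME than type it by hand, the following Python sketch builds one from the model, the user, and a UTC timestamp. It then sanitizes the result into a valid DNS label: lowercase letters, digits, and hyphens only, at most 63 characters, and starting and ending with an alphanumeric character. The make_run_name helper is hypothetical and is not part of MaxText.

```python
import re
from datetime import datetime, timezone
from typing import Optional


def make_run_name(model: str, user: str, when: Optional[datetime] = None) -> str:
    """Builds a DNS-label-safe run name from model, user, and a UTC timestamp.

    Hypothetical helper for illustration; not part of MaxText.
    """
    when = when or datetime.now(timezone.utc)
    raw = f"{model}-{user}-{when:%Y%m%d-%H%M%S}"
    # Lowercase, then collapse any run of characters outside [a-z0-9-]
    # (for example '.', '_' or '/') into a single hyphen.
    name = re.sub(r"[^a-z0-9-]+", "-", raw.lower())
    # A DNS label must start and end with an alphanumeric character
    # and be at most 63 characters long.
    return name.strip("-")[:63].rstrip("-")


print(make_run_name("llama3.1-8b-Instruct", "alice",
                    datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc)))
# llama3-1-8b-instruct-alice-20250102-030405
```

If you call make_run_name(model, user) without the when argument, it uses the current UTC time, so each run gets a distinct name.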

Get your model checkpoint

This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

Option 1: Using an existing MaxText checkpoint

If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items

Option 2: Converting a Hugging Face checkpoint

Follow the steps in Hugging Face to MaxText to convert a Hugging Face checkpoint to the MaxText format, and verify that the converted checkpoint files were saved correctly. Then, as in Option 1, set the following environment variable and move on.

export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items

[!IMPORTANT] Matching the scan_layers Parameter: The scan_layers setting during your fine-tuning run must match the setting used when creating the checkpoint at MAXTEXT_CKPT_PATH.

  • If the checkpoint was converted or saved with scan_layers=False (which is common for Hugging Face conversions and inference-ready models), you must also provide scan_layers=False in the MaxText command.

  • If scan_layers does not match, MaxText will raise a ValueError. See the Checkpoints concept guide for more details.

Run SFT on Hugging Face Dataset

Now you are ready to run SFT using the following command:

python3 -m maxtext.trainers.post_train.sft.train_sft \
    run_name=${RUN_NAME?} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
    model_name=${MODEL?} \
    load_parameters_path=${MAXTEXT_CKPT_PATH?} \
    per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
    steps=${STEPS?} \
    hf_path=${DATASET_NAME?} \
    train_split=${TRAIN_SPLIT?} \
    train_data_columns=${TRAIN_DATA_COLUMNS?} \
    profiler=xplane

Your fine-tuned model checkpoints will be saved here: $BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints.

Dataset Customization & Chat Templates

MaxText SFT tokenizes conversational datasets with chat templates, so the structure of your dataset and the chat template must be aligned.

Supported Dataset Schemas

By default, MaxText SFT expects one of three conversational dataset structures:

  • ["messages"]: A single column containing a list of dictionaries with role and content (recommended).

  • ["prompt", "completion"]: Separated prompt and completion columns.

  • ["question", "answer"]: Question and answer columns (e.g., math datasets).

During data processing, MaxText converts these into a unified messages schema (an OpenAI-like format) before passing them to the tokenizer:

[
  {"role": "user", "content": "Hello!"},
  {"role": "assistant", "content": "Hi there!"}
]
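Here is a minimal Python sketch of the normalization idea. It assumes the prompt/completion and question/answer columns hold plain strings, and it is not MaxText's actual implementation, which may handle additional cases. to_messages is a hypothetical name.

```python
from typing import Any, Dict, List


def to_messages(example: Dict[str, Any]) -> List[Dict[str, str]]:
    """Normalizes one raw example into a list of {"role", "content"} dicts.

    Hypothetical illustration covering the three default schemas:
    "messages", "prompt"/"completion", and "question"/"answer".
    """
    if "messages" in example:
        # Already in the unified schema.
        return list(example["messages"])
    for user_key, assistant_key in (("prompt", "completion"), ("question", "answer")):
        if user_key in example and assistant_key in example:
            return [
                {"role": "user", "content": example[user_key]},
                {"role": "assistant", "content": example[assistant_key]},
            ]
    raise ValueError(f"Unsupported schema with columns: {sorted(example)}")


print(to_messages({"question": "What is 2 + 2?", "answer": "4"}))
# [{'role': 'user', 'content': 'What is 2 + 2?'}, {'role': 'assistant', 'content': '4'}]
```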

Custom Tokenizer Chat Templates

To customize the tokenizer’s chat formatting (e.g., to add special tokens such as <start_of_turn> and <end_of_turn>), you can provide a custom chat template with the following configs:

  • chat_template: Use this config to specify a custom Jinja2 template string directly.

  • chat_template_path: Path to a custom Jinja2 template file (e.g., .jinja) or a JSON file containing the template.

  • use_chat_template=True: Enables chat template formatting.
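The sketch below renders a hypothetical Gemma-style template with the jinja2 package (assumed to be installed) so you can see what a custom template produces. It assumes the template receives the conversation in a variable named messages, as with Hugging Face tokenizers' apply_chat_template. Real templates usually also handle details such as BOS tokens and generation prompts, which this sketch omits.

```python
from jinja2 import Template

# Hypothetical Gemma-style template: wraps each turn in <start_of_turn>/<end_of_turn>
# and renames the "assistant" role to "model".
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "<start_of_turn>{{ 'model' if message['role'] == 'assistant' else message['role'] }}\n"
    "{{ message['content'] }}<end_of_turn>\n"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there!"},
]

rendered = Template(CHAT_TEMPLATE).render(messages=messages)
print(rendered)
```

You would pass a string like CHAT_TEMPLATE through chat_template. Alternatively, save it to a .jinja file and point chat_template_path at that file.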

Advanced: Custom Dataset Formatter (e.g., ShareGPT)

If your dataset is in a format not natively supported—such as ShareGPT (which uses a conversations column with from and value keys)—you can write a custom Python formatting function to convert it on-the-fly.

1. Write a custom formatting function

Create a Python file in your workspace (e.g., src/maxtext/input_pipeline/custom_formatters.py):

def format_sharegpt(example):
    """Converts ShareGPT format (from/value) to standard messages (role/content)."""
    role_map = {
        "human": "user",
        "user": "user",
        "gpt": "assistant",
        "assistant": "assistant",
        "system": "system",
    }

    messages = []
    for turn in example["conversations"]:
        role = role_map.get(turn["from"], "user")
        messages.append(
            {
                "role": role,
                "content": turn["value"],
            }
        )

    example["messages"] = messages
    return example

2. Configure MaxText to use your formatter

When starting your SFT training, pass the following parameters:

  • train_data_columns: Point to the original column name in the raw dataset ("['conversations']").

  • formatting_func_path: Point to the Python import path of your formatting function ("maxtext.input_pipeline.custom_formatters.format_sharegpt").

python3 -m maxtext.trainers.post_train.sft.train_sft \
    ... \
    train_data_columns="['conversations']" \
    formatting_func_path="maxtext.input_pipeline.custom_formatters.format_sharegpt"
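formatting_func_path is a dotted import path, so it can save time to confirm that the path resolves in your environment before you start a long run. The sketch below resolves such a path with Python's standard importlib. resolve_func is a hypothetical helper, and MaxText's own loader may resolve the path differently.

```python
import importlib
from typing import Callable


def resolve_func(dotted_path: str) -> Callable:
    """Imports 'package.module.function' and returns the function object."""
    module_path, _, func_name = dotted_path.rpartition(".")
    if not module_path:
        raise ValueError(f"Expected 'module.function', got {dotted_path!r}")
    module = importlib.import_module(module_path)  # ImportError if the module is missing
    func = getattr(module, func_name)  # AttributeError if the name is missing
    if not callable(func):
        raise TypeError(f"{dotted_path!r} does not point to a callable")
    return func


# Demonstrated with a standard-library function so the sketch runs anywhere.
print(resolve_func("json.dumps")({"ok": True}))  # {"ok": true}
```

In your MaxText environment, resolve_func("maxtext.input_pipeline.custom_formatters.format_sharegpt") should return your function. An ImportError or AttributeError usually means the path has a typo or the module is not importable.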

Runnable Example in the Codebase

For a complete, runnable SFT workflow that demonstrates how to configure the training loop and use a custom dataset formatter (formatting_func_path and formatting_func_kwargs), check out the sft_qwen3_demo.ipynb Jupyter notebook.
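The notebook shows the exact configuration syntax. To illustrate why formatting_func_kwargs is useful, the sketch below extends the ShareGPT formatter with a keyword argument that prepends a fixed system prompt. It assumes that entries in formatting_func_kwargs are forwarded to your function as keyword arguments. format_sharegpt_with_system and system_prompt are hypothetical names.

```python
from typing import Any, Dict, Optional


def format_sharegpt_with_system(
    example: Dict[str, Any], system_prompt: Optional[str] = None
) -> Dict[str, Any]:
    """Converts ShareGPT (from/value) to messages, optionally adding a system prompt."""
    role_map = {
        "human": "user",
        "user": "user",
        "gpt": "assistant",
        "assistant": "assistant",
        "system": "system",
    }
    messages = [
        {"role": role_map.get(turn["from"], "user"), "content": turn["value"]}
        for turn in example["conversations"]
    ]
    # Only prepend the system prompt if the conversation does not already have one.
    if system_prompt and not any(m["role"] == "system" for m in messages):
        messages.insert(0, {"role": "system", "content": system_prompt})
    example["messages"] = messages
    return example


example = {"conversations": [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}]}
print(format_sharegpt_with_system(example, system_prompt="Answer briefly.")["messages"])
```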