Knowledge distillation#
Overview#
Knowledge Distillation is a compression technique that transfers knowledge from a larger (teacher) model to a smaller (student) model. This allows the smaller model to achieve performance levels closer to the larger one, but with significantly fewer parameters and computational resources.
This tutorial focuses on response-based knowledge distillation, a technique where the student model is trained to replicate the outputs and behaviors of the teacher model. Within response-based knowledge distillation, two primary methods are often employed:
Offline Distillation (Dataset Generation):
The pre-trained teacher model (running in vLLM) generates a new dataset of input-output pairs.
The student model is then trained on this teacher-generated dataset using standard fine-tuning techniques in MaxText.
Online Distillation (Logit Matching):
During the training process, both the teacher model (which is typically frozen) and the student model process the same input data simultaneously.
The student model is trained by minimizing a loss function that encourages its output logits to match the logits produced by the teacher model for the same inputs.
Running Offline Distillation with MaxText#
The following recipe demonstrates the process of offline distillation using Qwen/Qwen3-32B as the teacher model and Llama-3.1-8B as the student model. Since this recipe fine-tunes the student model using Supervised Fine-Tuning (SFT), it’s crucial to use the conversational variant for both the teacher and student models. Here’s a step-by-step tutorial:
Prerequisites#
a. Setup environment variables#
export HF_TOKEN=<your-hf-token> # e.g., hf_BA6...
export RUN_NAME=<your-run-name> # e.g., distill-20260115
b. Install dependencies#
To install MaxText and its dependencies for post-training (including vLLM for the teacher), run the following:
Follow the MaxText installation instructions.
Install the additional dependencies for post-training:
bash tools/setup/setup_post_training_requirements.sh
c. Setup storage with Hyperdisk#
To store large models and datasets, attach a Hyperdisk to your TPU VM. Refer to the Google Cloud Hyperdisk documentation and TPU VM management for detailed instructions.
First, create a Hyperdisk:
export ZONE=<your-tpu-zone> # e.g., us-central1-a
export TPU_VM_NAME=<your-tpu-vm-name>
export DISK_NAME=<your-disk-name> # e.g., my-hyperdisk
export DISK_SIZE=<disk-size> # e.g., 500GB
gcloud compute disks create ${DISK_NAME} \
--size=${DISK_SIZE} \
--type=hyperdisk-balanced \
--zone=${ZONE}
Then, attach the disk to your TPU VM:
gcloud compute instances attach-disk ${TPU_VM_NAME} \
--disk=${DISK_NAME} \
--zone=${ZONE}
Inside the TPU VM, format and mount the disk (if not already mounted):
# Assuming the disk is /dev/sdb, check with lsblk
sudo mkfs.ext4 /dev/sdb
sudo mkdir -p /mnt/hyperdisk
sudo mount /dev/sdb /mnt/hyperdisk
Update the BASE_DIRECTORY to point to the mounted disk and create the directory:
export BASE_NAME=<your-base-directory> # e.g., knowledge-distillation
export BASE_DIRECTORY=/mnt/hyperdisk/${BASE_NAME}
mkdir -p ${BASE_DIRECTORY}
Note: This tutorial uses a mounted Hyperdisk for performance and reproducibility, because writing large model files and many small I/O operations directly to
gs://can be significantly slower.
Obtain and prepare the teacher model#
For the teacher model, we will use vLLM to run inference. vLLM can load Hugging Face checkpoints directly, so no conversion to MaxText format is needed for the teacher. Ensure the teacher model is supported on TPU vLLM (refer to the vLLM TPU recommended models for the latest list).
You can simply download the model from Hugging Face to your local directory:
huggingface-cli login --token $HF_TOKEN
huggingface-cli download Qwen/Qwen3-32B --repo-type model --local-dir ${BASE_DIRECTORY}/qwen3-32b
Obtain and prepare the student model#
The student model will be trained in MaxText, which uses the Orbax checkpointing format. You will download the Hugging Face weights to your mounted bucket and convert them for training.
Convert checkpoint to MaxText format#
The following command downloads the Hugging Face weights and converts them to the MaxText format.
Note: This conversion script requires PyTorch.
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
# Set the checkpoint directory
export PRE_TRAINED_MODEL_CKPT_DIRECTORY=${BASE_DIRECTORY}/llama3.1-8b-ckpt
# Convert to MaxText format
python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
model_name=llama3.1-8b \
hf_access_token=${HF_TOKEN} \
base_output_directory=${PRE_TRAINED_MODEL_CKPT_DIRECTORY} \
scan_layers=True skip_jax_distributed_system=True
Generate dataset using vLLM (Teacher Step)#
Use the provided script generate_distillation_data_vllm.py to generate the dataset from the teacher model. This script writes a Parquet dataset compatible with MaxText SFT.
Run the generation script:
export OUTPUT_DATASET=${BASE_DIRECTORY}/datasets/distillation_data.parquet
python3 -m tools.data_generation.generate_distillation_data_vllm \
--dataset-path HuggingFaceH4/ultrachat_200k \
--data-split train_sft \
--data-columns messages \
--hf-access-token $HF_TOKEN \
--teacher-model ${BASE_DIRECTORY}/qwen3-32b \
--use-chat-template \
--num-prompts 5120 \
--num-generations 2 \
--output-file ${OUTPUT_DATASET}
Fine-tune the student model using Supervised Fine Tuning (SFT)#
You can now fine-tune your smaller student model using supervised fine-tuning technique in MaxText.
Fine-tune the student model using the generated dataset#
Example command to run fine-tuning on a TPU v6e-8:
python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated src/maxtext/configs/post_train/sft.yml \
run_name=${RUN_NAME} \
base_output_directory=${BASE_DIRECTORY}/distillation/qwen3-32b-distill-llama3.1-8b \
tokenizer_path=meta-llama/Llama-3.1-8B-Instruct tokenizer_type=huggingface \
dataset_type=hf \
hf_path=parquet \
hf_train_files=${OUTPUT_DATASET} \
train_split='train' \
train_data_columns=['messages'] \
load_parameters_path=${PRE_TRAINED_MODEL_CKPT_DIRECTORY}/0/items \
model_name=llama3.1-8b \
per_device_batch_size=2 \
steps=200 \
ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
max_target_length=2048 \
hf_access_token=$HF_TOKEN \
profiler=xplane
[OPTIONAL] Fine-tune the student model using the original dataset#
The checkpoint from the student model’s fine-tuning (on the teacher-generated dataset) can be used for a subsequent fine-tuning stage. In this step, the student model is fine-tuned on the original dataset that was initially provided to the teacher model for generating the dataset.
# Get the latest checkpoint for fine-tuned student model
CHECKPOINTS_PATH=${BASE_DIRECTORY}/distillation/qwen3-32b-distill-llama3.1-8b/${RUN_NAME}/checkpoints
checkpoints=$(ls $CHECKPOINTS_PATH)
integer_dirs=()
for dir in $checkpoints; do
dir_name=$(basename "$dir")
if [[ "$dir_name" =~ ^[0-9]+$ ]]; then
integer_dirs+=("$dir_name")
fi
done
sorted_dirs=($(printf '%s\n' "${integer_dirs[@]}" | sort -n))
largest_dir="${sorted_dirs[-1]}"
FINE_TUNED_MODEL_CKPT_PATH=${CHECKPOINTS_PATH}/${largest_dir}/model_params
# Fine-tune student model on original dataset
python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \
run_name=${RUN_NAME}_stage2 \
base_output_directory=${BASE_DIRECTORY}/distillation/qwen3-32b-distill-llama3.1-8b \
tokenizer_path=meta-llama/Llama-3.1-8B-Instruct tokenizer_type=huggingface \
dataset_type=hf \
hf_path='HuggingFaceH4/ultrachat_200k' \
train_split='train_sft' \
train_data_columns=['messages'] \
load_parameters_path=${FINE_TUNED_MODEL_CKPT_PATH} \
model_name=llama3.1-8b \
per_device_batch_size=2 \
steps=200 \
ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
max_target_length=2048 \
hf_access_token=$HF_TOKEN \
profiler=xplane