Checkpoint Conversion Utilities
This guide provides instructions to use checkpoint conversion scripts to convert model checkpoints bidirectionally between Hugging Face and MaxText formats.
Supported models
The following models are supported:
| Model Family | Sizes | HF \(\to\) Orbax (scan) | HF \(\to\) Orbax (unscan) | Orbax (scan) \(\to\) HF | Orbax (unscan) \(\to\) HF |
|---|---|---|---|---|---|
| Gemma2 | 2B, 9B, 27B | √ | √ | √ | √ |
| Gemma3 (Multimodal) | 4B, 12B, 27B | √ | √ | √ | √ |
| Llama3.1 | 8B, 70B, 405B | √ | √ | √ | √ |
| Qwen2.5 | 1.5B, 7B, 14B | √ | √ | √ | √ |
| Qwen3 | 0.6B, 4B, 8B, 14B, 32B | √ | √ | √ | √ |
| Qwen3 MoE | 30B, 235B, 480B | √ | √ | √ | √ |
| Mixtral | 8x7B, 8x22B | √ | √ | √ | √ |
| GPT-OSS | 20B, 120B | √ | √ | √ | √ |
| DeepSeek2 | 16B | √ | √ | √ | √ |
| DeepSeek3 | 671B | √ | √ | √ | √ |
| DeepSeek3.2 | 671B | √ | √ | - | - |
| Qwen3 Next | 80B | √ | √ | √ | √ |
Prerequisites
- MaxText must be installed in a Python virtual environment using the maxtext[tpu] option. For instructions on installing MaxText on your VM, please refer to the official installation documentation.
- Hugging Face model checkpoints are cached locally at $HOME/.cache/huggingface/hub before conversion. Ensure you have sufficient disk space.
- Authenticate via the Hugging Face CLI if using private or gated models.
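Before downloading a large model, it can help to check how much space is free on the disk that holds the Hugging Face cache. The following is a minimal sketch using only the Python standard library. It assumes the cache lives under HF_HOME when that variable is set and under $HOME/.cache/huggingface otherwise; the helper names hf_hub_cache_dir and free_gib are illustrative and not part of MaxText.

```python
import os
import shutil


def hf_hub_cache_dir() -> str:
    """Return the expected Hugging Face hub cache directory (assumed layout)."""
    hf_home = os.environ.get("HF_HOME") or os.path.join(
        os.path.expanduser("~"), ".cache", "huggingface"
    )
    return os.path.join(hf_home, "hub")


def free_gib(path: str) -> float:
    """Free space in GiB on the filesystem that holds (or will hold) path."""
    probe = os.path.abspath(path)
    # The cache directory may not exist yet, so walk up to an existing parent.
    while not os.path.exists(probe):
        parent = os.path.dirname(probe)
        if parent == probe:
            break
        probe = parent
    return shutil.disk_usage(probe).free / 2**30


if __name__ == "__main__":
    cache = hf_hub_cache_dir()
    print(f"{cache}: {free_gib(cache):.1f} GiB free")
```

Compare the reported value with the size of the safetensors files you plan to download, and redirect HF_HOME to a larger disk if it falls short.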
Hugging Face to MaxText
Use the to_maxtext.py script to convert a Hugging Face model checkpoint into a MaxText checkpoint. The script will automatically download the specified model from the Hugging Face Hub, perform conversion, and save converted checkpoints to the given output directory.
Note: For more information, check out the qwen3-4b example script and the gemma3-4b example script.
Setup Environment
# Install PyTorch (in MaxText virtual environment)
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
# Set up environment variables
export MODEL=<HF_MODEL> # e.g. 'llama3.1-8b-Instruct'
export BASE_OUTPUT_DIRECTORY=<CKPT_PATH> # e.g., gs://my-bucket/my-checkpoint-directory
export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX
export LAZY_LOAD_TENSORS=<LAZY_LOAD> # Set to True to save RAM
Run Conversion
# Optional: If you run out of disk space when downloading Hugging Face safetensors,
# customize your "HF_HOME" to redirect the cache to a larger or mounted disk (e.g., on a TPU VM).
# export HF_HOME="/dev/shm/huggingface_tmp"
python3 -m maxtext.checkpoint_conversion.to_maxtext \
    model_name=${MODEL?} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
    scan_layers=True \
    use_multimodal=false \
    hardware=cpu \
    skip_jax_distributed_system=true \
    checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \
    checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
    --lazy_load_tensors=${LAZY_LOAD_TENSORS?} \
    --save_dtype=bfloat16
You can find your converted checkpoint files under ${BASE_OUTPUT_DIRECTORY}/0/items.
Key Parameters
- model_name: The specific model identifier. It must match a supported entry in the MaxText globals.py.
- scan_layers: Controls whether the output uses a scanned (scan_layers=true) or unscanned (scan_layers=false) checkpoint format. Refer here for more information.
- use_multimodal: Indicates whether multimodality is used; important for Gemma3.
- base_output_directory: The path where the converted Orbax checkpoint will be stored; it can be a Google Cloud Storage (GCS) path or a local path.
- hardware=cpu: Runs the conversion script on a CPU machine.
- checkpoint_storage_use_zarr3 and checkpoint_storage_use_ocdbt: These storage flags enable McJAX compatibility when set to True (the default). For Pathways, set them to False.
- --lazy_load_tensors (Optional): Enables on-demand loading of weights to prevent out-of-memory (OOM) errors. Highly recommended for large models. For example, converting a Llama3.1-70B model with --lazy_load_tensors=true uses around 200GB of RAM and completes in ~10 minutes.
- --hf_model_path (Optional): Specifies a custom remote or local directory containing the model weights. If unspecified, the default Hugging Face repository ID is used (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
- --save_dtype (Optional): Specifies the data type of the saved model weights. Defaults to bfloat16 to save memory.
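When the conversion is launched from a Python driver rather than a shell, the same arguments can be assembled programmatically. The sketch below only rebuilds the command shown above; the function build_to_maxtext_args and its parameters are hypothetical and not part of MaxText. It mirrors the shell expression $((1 - USE_PATHWAYS)), so both storage flags are 1 for McJAX and 0 for Pathways.

```python
import shlex


def build_to_maxtext_args(model, output_dir, use_pathways=False, lazy_load=True):
    """Assemble the to_maxtext command line as a list of arguments."""
    # McJAX keeps both storage flags on; Pathways turns them off,
    # mirroring $((1 - USE_PATHWAYS)) in the shell command.
    storage_flag = 0 if use_pathways else 1
    return [
        "python3", "-m", "maxtext.checkpoint_conversion.to_maxtext",
        f"model_name={model}",
        f"base_output_directory={output_dir}",
        "scan_layers=True",
        "use_multimodal=false",
        "hardware=cpu",
        "skip_jax_distributed_system=true",
        f"checkpoint_storage_use_zarr3={storage_flag}",
        f"checkpoint_storage_use_ocdbt={storage_flag}",
        f"--lazy_load_tensors={lazy_load}",
        "--save_dtype=bfloat16",
    ]


if __name__ == "__main__":
    args = build_to_maxtext_args("qwen3-4b", "gs://my-bucket/my-checkpoint-directory")
    # Print the command for inspection; pass `args` to subprocess.run to execute it.
    print(shlex.join(args))
```

Keeping the two storage flags tied to a single switch avoids the easy mistake of setting one for Pathways and forgetting the other.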
MaxText to Hugging Face
Use the to_huggingface.py script to convert a MaxText checkpoint into the Hugging Face format. This is useful for sharing your models or integrating them with the Hugging Face ecosystem.
Note: For more information, check out the qwen3-4b example script.
Setup Environment
# Install PyTorch (in MaxText virtual environment)
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
# Set up environment variables
export MODEL=<MODEL_NAME> # e.g. 'qwen3-4b'
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
export BASE_OUTPUT_DIRECTORY=<HF_CKPT_PATH> # e.g., gs://my-bucket/my-checkpoint-directory
Run Conversion
The following command converts a MaxText checkpoint and saves the result to a local path, to GCS (gs://), or directly to the Hugging Face Hub (hf://).
python3 -m maxtext.checkpoint_conversion.to_huggingface \
    model_name=${MODEL?} \
    load_parameters_path=${MAXTEXT_CKPT_PATH?} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
    hardware=cpu \
    skip_jax_distributed_system=true \
    scan_layers=false \
    use_multimodal=false \
    weight_dtype=bfloat16
Key Parameters
- model_name: The specific model identifier. It must match a supported entry in the MaxText globals.py.
- load_parameters_path: The path to the MaxText Orbax checkpoint.
- scan_layers: Specifies whether the MaxText checkpoint uses a scanned (scan_layers=true) or unscanned (scan_layers=false) format. Refer here for more information.
- use_multimodal: Indicates whether multimodality is used; important for Gemma3.
- hardware=cpu: Runs the conversion script on a CPU machine.
- base_output_directory: The path where the converted checkpoint will be stored; it can be a Google Cloud Storage (GCS) path, a Hugging Face Hub path, or a local path.
- weight_dtype: Sets the dtype of the resulting Hugging Face weights. The default value is float32. We recommend using bfloat16 to save memory and speed up conversion.
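Because base_output_directory selects between three kinds of destination, a driver script may want to check where the output will land before starting a long conversion. The helper below is a hypothetical sketch that classifies the string by the prefixes mentioned above (gs:// and hf://); it is not the logic to_huggingface itself uses.

```python
def classify_output_destination(base_output_directory: str) -> str:
    """Classify a destination string by its URI prefix (illustrative only)."""
    if base_output_directory.startswith("gs://"):
        return "gcs"
    if base_output_directory.startswith("hf://"):
        return "huggingface_hub"
    # Anything without a recognized scheme is treated as a local path.
    return "local"


if __name__ == "__main__":
    for destination in (
        "gs://my-bucket/my-checkpoint-directory",
        "hf://my-org/my-model",  # hypothetical repository name
        "/tmp/my-checkpoint-directory",
    ):
        print(destination, "->", classify_output_destination(destination))
```

A check like this makes it easy to confirm, for example, that Hugging Face CLI authentication is in place before choosing an hf:// destination.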
Verifying conversion correctness
To verify a conversion, use the test script. It runs a forward pass on both the original and converted models and compares the output logits. The same check applies to both conversion directions.
Note: This correctness test will only work when MaxText is installed from source by following the installation instructions here.
Setup Environment
# Set up environment variables
export MODEL=<MODEL_NAME> # e.g. 'qwen3-4b'
export MAXTEXT_CKPT_PATH=<CKPT_PATH> # e.g., gs://my-bucket/my-model-checkpoint/0/items
export HF_CKPT_PATH=<HF_CKPT_PATH> # e.g., gs://my-bucket/my-checkpoint-directory
Run Correctness Test
python3 -m tests.utils.forward_pass_logit_checker src/maxtext/configs/base.yml \
    load_parameters_path=${MAXTEXT_CKPT_PATH?} \
    model_name=${MODEL?} \
    skip_jax_distributed_system=true \
    scan_layers=false \
    max_prefill_predict_length=4 \
    max_target_length=8 \
    use_multimodal=false \
    --run_hf_model=True \
    --hf_model_path=${HF_CKPT_PATH?} \
    --max_kl_div=0.015
Key Parameters
- load_parameters_path: The path to the MaxText Orbax checkpoint (e.g., gs://your-bucket/maxtext-checkpoint/0/items).
- model_name: The corresponding model name in the MaxText configuration (e.g., qwen3-4b).
- scan_layers: Specifies whether the MaxText checkpoint uses a scanned (scan_layers=true) or unscanned (scan_layers=false) format. Refer here for more information.
- use_multimodal: Indicates whether multimodality is used.
- --run_hf_model (Optional): Indicates whether to load the Hugging Face model from hf_model_path. If not set, the MaxText logits are compared with pre-saved golden logits.
- --hf_model_path (Optional): The path to the Hugging Face checkpoint (used if --run_hf_model=True).
- --golden_logits_path (Optional): The path to the pre-saved golden logits (used if --run_hf_model is not set).
- --max_kl_div: The maximum KL divergence tolerated during comparisons.
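To make the --max_kl_div tolerance concrete, the sketch below computes D_KL(P_golden || Q_model) per token position from raw logits and checks the largest value against a threshold. It is a pure-Python illustration, not the checker's implementation; the function names are hypothetical, and comparing the per-token maximum against the tolerance is an assumption about how the flag is applied.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one token position."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


def kl_divergence(golden_logits, model_logits):
    """D_KL(P_golden || Q_model) for a single token position, in nats."""
    log_p = log_softmax(golden_logits)
    log_q = log_softmax(model_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def check_kl(golden, model, max_kl_div=0.015):
    """Return (passed, per_token_kl) for two sequences of logit vectors."""
    per_token = [kl_divergence(g, m) for g, m in zip(golden, model)]
    # Assumption: the check passes when no position exceeds the tolerance.
    return max(per_token) <= max_kl_div, per_token


if __name__ == "__main__":
    golden = [[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]]
    model = [[2.01, 0.99, 0.1], [0.5, 0.52, 3.0]]
    passed, per_token = check_kl(golden, model)
    print("passed:", passed)
    print("average KL:", sum(per_token) / len(per_token))
    print("max KL:", max(per_token))
```

Small numerical differences between frameworks (for example from bfloat16 rounding) yield a KL value close to zero, whereas a mis-mapped weight typically pushes it far above the tolerance.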
Example of Successful Conversion Verification
Here is part of the forward_pass_logit_checker output for gemma2-2b.
--- Prompt: What is the ---
--- MaxText model top 10 tokens ---
| Token ID | Token | Score |
|------------|----------------------|------------|
| 5830 | difference | 27.2500 |
| 1963 | best | 26.6250 |
| 5316 | average | 26.3750 |
| 2669 | change | 26.1250 |
| 12070 | percentage | 26.1250 |
| 1618 | value | 25.8750 |
| 1546 | most | 25.7500 |
| 66202 | molar | 25.5000 |
| 3051 | total | 25.5000 |
| 1503 | name | 25.3750 |
--- HF model top 10 tokens ---
| Token ID | Token | Score |
|------------|----------------------|------------|
| 5830 | difference | 27.2500 |
| 1963 | best | 26.6250 |
| 5316 | average | 26.3750 |
| 12070 | percentage | 26.1250 |
| 2669 | change | 26.1250 |
| 1618 | value | 25.8750 |
| 1546 | most | 25.7500 |
| 66202 | molar | 25.5000 |
| 3051 | total | 25.5000 |
| 6187 | purpose | 25.3750 |
--- Similarity Metrics of Top Tokens ---
| Metric | Value |
|--------------------------------|----------------------|
| overlap_count | 9/10 |
| jaccard_similarity | 0.8181818181818182 |
| rank_agreement_percentage | 70.0 |
Average KL divergence per token (D_KL(P_golden || Q_model)): 0.000409
Max KL divergence for a single token in the set: 0.003497
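The similarity metrics in this output can be reproduced from the two top-10 token ID lists. The sketch below uses definitions inferred from the numbers above (set overlap, Jaccard similarity, and the share of ranks holding the same token in both lists); these definitions are an assumption, not taken from the checker's source, and the function name is illustrative.

```python
def top_token_similarity(maxtext_ids, hf_ids):
    """Compare two equally long, rank-ordered lists of top token IDs."""
    shared = set(maxtext_ids) & set(hf_ids)
    union = set(maxtext_ids) | set(hf_ids)
    # A rank "agrees" when both models place the same token at that position.
    same_rank = sum(a == b for a, b in zip(maxtext_ids, hf_ids))
    return {
        "overlap_count": f"{len(shared)}/{len(maxtext_ids)}",
        "jaccard_similarity": len(shared) / len(union),
        "rank_agreement_percentage": 100.0 * same_rank / len(maxtext_ids),
    }


# Token IDs copied from the two tables above, in rank order.
MAXTEXT_TOP10 = [5830, 1963, 5316, 2669, 12070, 1618, 1546, 66202, 3051, 1503]
HF_TOP10 = [5830, 1963, 5316, 12070, 2669, 1618, 1546, 66202, 3051, 6187]

if __name__ == "__main__":
    for metric, value in top_token_similarity(MAXTEXT_TOP10, HF_TOP10).items():
        print(f"{metric}: {value}")
```

With these inputs, nine tokens are shared out of eleven distinct ones, and seven of the ten ranks hold the same token. The two tokens with the tied score of 26.1250 swap places and the last entry differs, which is why rank agreement is lower than overlap even though the conversion is correct.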
Troubleshooting and Development
Adding New Models
To extend conversion support to a new model architecture, you must define its specific parameter and configuration mappings. The conversion logic is decoupled, so you only need to modify the mapping files.
- Add parameter mappings:
  - In utils/param_mapping.py, add the parameter name mappings (def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING). These are the one-to-one mappings of parameter names per layer.
  - In utils/param_mapping.py, add the hook_fn logic (def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN). This is the transformation needed per layer.
- Add Hugging Face weights shape: In utils/globals.py, define the tensor shapes of the Hugging Face format (def {MODEL}_HF_WEIGHTS_TO_SHAPE). These are used to check that tensor shapes match after to_huggingface conversion.
- Register model key: In utils/utils.py, add the new model key in HF_IDS.
- Add transformer config: In utils/hf_model_configs.py, add the transformers.Config object, describing the Hugging Face model configuration (defined in src/maxtext/configs/models). This configuration must precisely match the MaxText model’s architecture.
Here is an example PR that adds support for the Gemma3 multimodal model.
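The split between name mappings and hook functions can be illustrated with a toy model. Everything in the sketch below is hypothetical: the MYMODEL prefix, the function signatures, the parameter names on both sides, and the use of nested lists in place of real tensors. It only shows the pattern of looking up the target name and applying an optional per-parameter transformation; consult utils/param_mapping.py for the actual conventions.

```python
def transpose(matrix):
    """Swap the two axes of a 2-D weight stored as nested lists."""
    return [list(column) for column in zip(*matrix)]


def MYMODEL_MAXTEXT_TO_HF_PARAM_MAPPING(num_layers):
    """Hypothetical one-to-one name mapping, with one entry per layer."""
    mapping = {"token_embedder/embedding": "model.embed_tokens.weight"}
    for layer in range(num_layers):
        mapping[f"layers_{layer}/query/kernel"] = (
            f"model.layers.{layer}.self_attn.q_proj.weight"
        )
    return mapping


def MYMODEL_MAXTEXT_TO_HF_PARAM_HOOK_FN(num_layers):
    """Hypothetical transformations, keyed by MaxText parameter name."""
    return {f"layers_{layer}/query/kernel": transpose for layer in range(num_layers)}


def convert(maxtext_params, num_layers):
    """Rename every parameter and apply its hook, if one is registered."""
    mapping = MYMODEL_MAXTEXT_TO_HF_PARAM_MAPPING(num_layers)
    hooks = MYMODEL_MAXTEXT_TO_HF_PARAM_HOOK_FN(num_layers)
    hf_params = {}
    for maxtext_name, tensor in maxtext_params.items():
        hook = hooks.get(maxtext_name, lambda t: t)  # identity when no hook
        hf_params[mapping[maxtext_name]] = hook(tensor)
    return hf_params


if __name__ == "__main__":
    toy_checkpoint = {
        "token_embedder/embedding": [[1, 2]],
        "layers_0/query/kernel": [[1, 2, 3], [4, 5, 6]],
    }
    for name, tensor in convert(toy_checkpoint, num_layers=1).items():
        print(name, tensor)
```

Keeping names and transformations in separate tables means a new architecture usually needs only new table entries, which matches the decoupled design described above.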
Common Errors
- “Type ShapeDtypeStruct is not a valid JAX type”: Usually caused by a mismatch in the scan_layers flag.
- If the converted checkpoint loads without errors but produces nonsensical output, there is likely an error in the Q/K/V weight reshaping logic during conversion.
- If the model generates repetitive text sequences, check whether the layer normalization parameters were mapped correctly.
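A reshaping bug can pass every shape check because two different transformations may yield tensors of identical shape. The toy example below, which uses nested lists rather than real weights and is not MaxText code, reshapes the same six values in two ways. Both results have shape (2, 3), but the elements land in different positions, which is the kind of silent error that only a logit comparison reveals.

```python
def reshape(flat, rows, cols):
    """Row-major reshape of a flat list into a rows x cols nested list."""
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def transpose(matrix):
    """Swap the two axes of a nested-list matrix."""
    return [list(column) for column in zip(*matrix)]


def shape(matrix):
    """Return (rows, cols) of a rectangular nested list."""
    return (len(matrix), len(matrix[0]))


flat_weights = [0, 1, 2, 3, 4, 5]

# Option A: reshape straight to (2, 3).
direct = reshape(flat_weights, 2, 3)

# Option B: reshape to (3, 2) first, then transpose to (2, 3).
via_transpose = transpose(reshape(flat_weights, 3, 2))

if __name__ == "__main__":
    print("direct:       ", direct, shape(direct))
    print("via transpose:", via_transpose, shape(via_transpose))
    print("same shape:", shape(direct) == shape(via_transpose))
    print("same values:", direct == via_transpose)
```

If only one of the two orderings matches what the target framework expects, the checkpoint still loads, so comparing logits as described in the verification section is the more reliable test.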