maxtext.checkpoint_conversion.utils.hf_shape module

Checkpoint conversion utilities that map Hugging Face parameter names to their expected tensor shapes.

maxtext.checkpoint_conversion.utils.hf_shape.GEMMA3_HF_WEIGHTS_TO_SHAPE(config)

Generates a shape mapping for Hugging Face Gemma3 parameters.

This function computes the expected shapes for all parameters in a Hugging Face Gemma3 model, including both the text and vision components. The shapes are derived from the provided model configuration.

Parameters:

config (dict) – The Hugging Face model configuration dictionary. It must contain ‘text_config’ and ‘vision_config’ sub-dictionaries with all necessary architectural details (e.g., hidden_size, num_layers).

Returns:

A dictionary where keys are Hugging Face parameter names (e.g., ‘model.language_model.embed_tokens.weight’) and values are lists of integers representing the tensor’s shape.

Return type:

dict
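The text and vision sub-dictionaries feed different parts of the mapping. The sketch below is a hypothetical, heavily reduced helper (`sketch_gemma3_embedding_shapes` is not part of MaxText) that shows this split for two embedding tensors. The vision parameter name and the `num_channels`/`patch_size` keys are assumptions modeled on SigLIP-style vision configs and have not been checked against this module.

```python
from typing import Dict, List


def sketch_gemma3_embedding_shapes(config: Dict) -> Dict[str, List[int]]:
    """Illustrative subset of a Gemma3-style shape map (hypothetical helper).

    Text shapes are read from config["text_config"] and vision shapes from
    config["vision_config"]. The vision parameter name is an assumption.
    """
    text = config["text_config"]
    vision = config["vision_config"]
    return {
        # Token embedding table: one hidden_size row per vocabulary entry.
        "model.language_model.embed_tokens.weight": [
            text["vocab_size"],
            text["hidden_size"],
        ],
        # Assumed name for a Conv2d patch embedding. PyTorch Conv2d weights
        # are laid out as [out_channels, in_channels, kernel_h, kernel_w].
        "model.vision_tower.vision_model.embeddings.patch_embedding.weight": [
            vision["hidden_size"],
            vision["num_channels"],
            vision["patch_size"],
            vision["patch_size"],
        ],
    }
```

The real function covers every text and vision parameter. This sketch only shows which sub-dictionary each shape is read from.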

maxtext.checkpoint_conversion.utils.hf_shape.GEMMA4_HF_WEIGHTS_TO_SHAPE(config)

Generates a shape mapping for Hugging Face Gemma4 parameters.

Handles both multimodal (with vision tower) and text-only variants, as well as MoE (26B) and dense (31B) text configurations. Shapes are per-layer aware: local (sliding) attention layers use head_dim, while global (full) attention layers use global_head_dim and num_global_key_value_heads.

Parameters:

config (dict) – The Hugging Face model configuration dictionary. Must contain ‘text_config’ with architectural details. May contain ‘vision_config’ for multimodal models.

Returns:

A dictionary mapping Hugging Face parameter names to their shapes.

Return type:

dict
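The per-layer distinction between local and global attention can be sketched as follows. This is a hypothetical illustration, not the MaxText implementation. It uses `head_dim`, `global_head_dim`, and `num_global_key_value_heads` as described above. The `layer_types` list, its `"sliding_attention"`/`"full_attention"` values, and the parameter-name prefix are assumptions borrowed from Gemma3-style Hugging Face configs. The sketch also assumes the query head count is the same for both layer kinds.

```python
from typing import Dict, List


def sketch_gemma4_attention_shapes(text_config: Dict) -> Dict[str, List[int]]:
    """Hypothetical q/k projection shapes for mixed local/global attention layers.

    Assumes text_config["layer_types"] holds "sliding_attention" or
    "full_attention" per layer, and that only the head dimension and KV head
    count change between the two layer kinds.
    """
    hidden = text_config["hidden_size"]
    n_heads = text_config["num_attention_heads"]
    shapes = {}
    for i, layer_type in enumerate(text_config["layer_types"]):
        if layer_type == "full_attention":
            # Global layers: global head size and global KV head count.
            head_dim = text_config["global_head_dim"]
            n_kv = text_config["num_global_key_value_heads"]
        else:
            # Local (sliding-window) layers: regular head size and KV heads.
            head_dim = text_config["head_dim"]
            n_kv = text_config["num_key_value_heads"]
        prefix = f"model.language_model.layers.{i}.self_attn"  # assumed name
        # nn.Linear weights are stored as [out_features, in_features].
        shapes[f"{prefix}.q_proj.weight"] = [n_heads * head_dim, hidden]
        shapes[f"{prefix}.k_proj.weight"] = [n_kv * head_dim, hidden]
    return shapes
```

Because the shape depends on each layer's type, a global layer's projections can differ from those of the local layer next to it even within the same model.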

maxtext.checkpoint_conversion.utils.hf_shape.GEMMA2_HF_WEIGHTS_TO_SHAPE(config)

Returns a mapping between Hugging Face weight paths and weight shapes.

Parameters:

config (dict) – Model configuration dictionary, defined in model_configs.py

Returns:

A mapping where:
  • Keys are Hugging Face model parameter paths.
  • Values are parameter shapes, as lists.

Return type:

dict

maxtext.checkpoint_conversion.utils.hf_shape.DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config)

Returns a mapping between Hugging Face weight paths and their shapes, derived from the HF config.

Parameters:

config (dict) – HF configuration dictionary, e.g., https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/config.json

Returns:

A mapping where:
  • Keys are Hugging Face model parameter paths.
  • Values are parameter shapes, as lists.

Return type:

dict

To check the expected mapping, print the parameter names and shapes of the reference model:

```python
from transformers import AutoModelForCausalLM

model_name = "deepseek-ai/DeepSeek-V2-Lite"
model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto")
for name, val in model.named_parameters():
    print(name, val.shape)
```
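Instead of comparing the printed shapes by eye, you can diff them against the returned mapping. The sketch below uses a hypothetical helper, `diff_shape_maps`, which is not part of MaxText. It treats both inputs as plain `name -> shape` dictionaries and reports three kinds of discrepancy.

```python
from typing import Dict, List, Sequence, Tuple


def diff_shape_maps(
    expected: Dict[str, Sequence[int]], actual: Dict[str, Sequence[int]]
) -> Tuple[List[str], List[str], Dict[str, Tuple[List[int], List[int]]]]:
    """Compares an expected shape map with shapes observed on a real model.

    Returns (missing, unexpected, mismatched):
      missing: names in `expected` that the model does not have
      unexpected: names the model has that `expected` does not list
      mismatched: name -> (expected_shape, actual_shape) where shapes differ
    """
    missing = sorted(set(expected) - set(actual))
    unexpected = sorted(set(actual) - set(expected))
    mismatched = {
        name: (list(expected[name]), list(actual[name]))
        for name in sorted(set(expected) & set(actual))
        if list(expected[name]) != list(actual[name])
    }
    return missing, unexpected, mismatched
```

For example, `actual = {name: list(val.shape) for name, val in model.named_parameters()}` can be compared against `DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config)`. By default, `named_parameters()` reports tied weights only once, so a tied `lm_head.weight` may appear as missing without indicating a real shape error.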

maxtext.checkpoint_conversion.utils.hf_shape.QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config)

Returns a mapping between Hugging Face Qwen3-Next weight paths and their shapes.

maxtext.checkpoint_conversion.utils.hf_shape.GPT_OSS_HF_WEIGHTS_TO_SHAPE(config)

Returns a mapping between Hugging Face GptOss weight paths and their shapes.

maxtext.checkpoint_conversion.utils.hf_shape.QWEN_HF_WEIGHTS_TO_SHAPE(config)

Returns a mapping between Hugging Face Qwen weight paths and their Hugging Face weight shapes.

Parameters:

config (dict) – HF configuration dictionary (from Qwen3TextConfig.to_dict()), e.g., https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json

Returns:

A mapping where:
  • Keys are Hugging Face model parameter paths.
  • Values are parameter shapes, as lists.

Return type:

dict

To check the expected mapping, print the parameter names and shapes of the reference model:

```python
from transformers import AutoModelForCausalLM

model_name = "Qwen/Qwen3-0.6B"
model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto")
for name, val in model.named_parameters():
    print(name, val.shape)
```
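For Qwen3, the attention projection shapes depend on the config's explicit `head_dim`, not on `hidden_size // num_attention_heads`. The sketch below is a hypothetical helper (`sketch_qwen3_attention_shapes`, not part of MaxText) covering one decoder layer's attention weights, including the per-head `q_norm`/`k_norm` RMSNorm weights. It falls back to `hidden_size // num_attention_heads` only if `head_dim` is absent, which is an assumption made for illustration.

```python
from typing import Dict, List


def sketch_qwen3_attention_shapes(config: Dict, layer: int = 0) -> Dict[str, List[int]]:
    """Hypothetical subset of Qwen3 attention shapes for one decoder layer."""
    hidden = config["hidden_size"]
    # Prefer the explicit head_dim: it need not equal hidden_size // num_heads.
    head_dim = config.get("head_dim", hidden // config["num_attention_heads"])
    q_out = config["num_attention_heads"] * head_dim
    kv_out = config["num_key_value_heads"] * head_dim
    prefix = f"model.layers.{layer}.self_attn"
    return {
        # nn.Linear weights are stored as [out_features, in_features].
        f"{prefix}.q_proj.weight": [q_out, hidden],
        f"{prefix}.k_proj.weight": [kv_out, hidden],
        f"{prefix}.v_proj.weight": [kv_out, hidden],
        f"{prefix}.o_proj.weight": [hidden, q_out],
        # Per-head RMSNorm weights applied to queries and keys.
        f"{prefix}.q_norm.weight": [head_dim],
        f"{prefix}.k_norm.weight": [head_dim],
    }
```

Consider illustrative values with hidden_size=1024, num_attention_heads=16, num_key_value_heads=8 and head_dim=128 (assumed to match Qwen3-0.6B; confirm against the linked config.json). In that case q_proj comes out as [2048, 1024], even though 1024 // 16 = 64.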

maxtext.checkpoint_conversion.utils.hf_shape.LLAMA31_HF_WEIGHTS_TO_SHAPE(config)

Returns a mapping between Hugging Face weight paths and weight shapes.

Parameters:

config (dict) – Model configuration dictionary, defined in model_configs.py

Returns:

A mapping where:
  • Keys are Hugging Face model parameter paths.
  • Values are parameter shapes, as lists.

Return type:

dict
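The sketch below shows, under stated assumptions, how a Llama-style shape map can be built from a configuration. It is a hypothetical helper (`sketch_llama_shapes`, not the MaxText implementation). It reads Hugging Face `config.json` key names (`hidden_size`, `num_attention_heads`, `num_key_value_heads`, `intermediate_size`, `vocab_size`, `num_hidden_layers`), which may differ from the keys in model_configs.py. It also assumes `head_dim = hidden_size // num_attention_heads` and an untied `lm_head`.

```python
from typing import Dict, List


def sketch_llama_shapes(config: Dict) -> Dict[str, List[int]]:
    """Hypothetical Llama-style shape map built from HF config.json key names."""
    hidden = config["hidden_size"]
    n_heads = config["num_attention_heads"]
    n_kv = config["num_key_value_heads"]
    head_dim = hidden // n_heads  # assumption: no separate head_dim override
    inter = config["intermediate_size"]
    vocab = config["vocab_size"]

    shapes = {"model.embed_tokens.weight": [vocab, hidden]}
    for i in range(config["num_hidden_layers"]):
        p = f"model.layers.{i}"
        shapes.update({
            # Attention: grouped-query K/V use n_kv heads instead of n_heads.
            f"{p}.self_attn.q_proj.weight": [n_heads * head_dim, hidden],
            f"{p}.self_attn.k_proj.weight": [n_kv * head_dim, hidden],
            f"{p}.self_attn.v_proj.weight": [n_kv * head_dim, hidden],
            f"{p}.self_attn.o_proj.weight": [hidden, n_heads * head_dim],
            # Gated MLP: gate/up expand to intermediate_size, down projects back.
            f"{p}.mlp.gate_proj.weight": [inter, hidden],
            f"{p}.mlp.up_proj.weight": [inter, hidden],
            f"{p}.mlp.down_proj.weight": [hidden, inter],
            # RMSNorm scales are 1-D vectors of size hidden_size.
            f"{p}.input_layernorm.weight": [hidden],
            f"{p}.post_attention_layernorm.weight": [hidden],
        })
    shapes["model.norm.weight"] = [hidden]
    shapes["lm_head.weight"] = [vocab, hidden]
    return shapes
```

Each layer contributes nine entries, so a model with N layers yields 9N + 3 keys under these assumptions.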

maxtext.checkpoint_conversion.utils.hf_shape.MIXTRAL_HF_WEIGHTS_TO_SHAPE(config)

Returns a mapping of Hugging Face parameter names to their tensor shapes.

Parameters:

config (dict) – The model configuration dictionary.

Returns:

A dictionary mapping Hugging Face parameter paths to their tensor shapes.
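The mixture-of-experts block is where Mixtral's shapes differ most from a dense model. The hypothetical helper below (`sketch_mixtral_moe_shapes`, not part of MaxText) sketches one layer's router and per-expert weights. It assumes the per-expert `w1`/`w2`/`w3` naming used by the `transformers` Mixtral implementation and the HF config keys `hidden_size`, `intermediate_size` and `num_local_experts`. Other releases may store expert weights differently.

```python
from typing import Dict, List


def sketch_mixtral_moe_shapes(config: Dict, layer: int = 0) -> Dict[str, List[int]]:
    """Hypothetical MoE block shapes for one Mixtral decoder layer."""
    hidden = config["hidden_size"]
    inter = config["intermediate_size"]
    n_experts = config["num_local_experts"]
    p = f"model.layers.{layer}.block_sparse_moe"
    # Router: produces one logit per expert for every token.
    shapes = {f"{p}.gate.weight": [n_experts, hidden]}
    for j in range(n_experts):
        # Each expert is a gated MLP: act(w1(x)) * w3(x), then w2 projects back.
        shapes[f"{p}.experts.{j}.w1.weight"] = [inter, hidden]
        shapes[f"{p}.experts.{j}.w3.weight"] = [inter, hidden]
        shapes[f"{p}.experts.{j}.w2.weight"] = [hidden, inter]
    return shapes
```

Under these assumptions, each layer adds 3 × num_local_experts expert matrices plus one router matrix.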