# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parameter mappings and transformation hooks for checkpoint conversion.
This module defines the necessary components to convert model checkpoints between
MaxText and Hugging Face formats for various architectures (e.g., Gemma, Qwen).
It provides two key types of mappings for each model:
1. **Parameter Name Mappings (`PARAM_MAPPING`)**: Dictionaries that map a MaxText
parameter key to its corresponding Hugging Face parameter(s). These mappings are
generated by functions like `GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING`.
**Key: MaxText parameters, with following forms:**
- `atomic_mt_key`: A single string representing one MaxText parameter.
- `composite_mt_key`: A tuple of strings representing multiple MaxText parameters. (e.g., GPT-OSS)
**Value: corresponding Hugging Face parameters, with following forms:**
- `unscanned`: A single string.
- `scanned`: A list of strings, to be stacked along the layer axis.
- `unscanned with expert stacking`: A list of strings, to be stacked along the expert axis.
- `scanned with expert stacking`: A nested list of strings, to be stacked along both layer and expert axes.
Note: Expert stacking only applies to a subset of MoE models (e.g., Qwen MoE, DeepSeek, Mixtral),
but not to others (e.g., GPT-OSS).
2. **Hook Functions (`HOOK_FNS`)**: Dictionaries that map a MaxText parameter
name to a specific transformation function (a "hook"). These hooks handle
the actual value conversion, which can include operations like reshaping,
transposing, scaling, or padding tensors to match the target format's
requirements. These are generated by functions like
`GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN`.
The main conversion script uses these mappings to systematically transform each
parameter from the source checkpoint and build the target checkpoint.
"""
import warnings
import numpy as np
import jax
import jax.numpy as jnp
def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Builds the MaxText -> Hugging Face parameter-name mapping for Gemma3.

  The returned dictionary covers the text decoder (`model.language_model.*`),
  the vision tower (`model.vision_tower.vision_model.*`) and the multi-modal
  projector (`model.multi_modal_projector.*`).

  Args:
    config (dict): The Hugging Face model configuration dictionary. Must
      contain 'text_config' and 'vision_config' sub-dictionaries, each with a
      'num_hidden_layers' entry.
    maxtext_config: Not read by this function; it is part of the signature
      shared by the mapping builders in this module.
    scan_layers (bool, optional): If True, text decoder layers are mapped in
      scanned form (one MaxText key -> list of HF names). If False, every
      decoder layer gets its own MaxText key. Vision layers are always mapped
      one-to-one. Defaults to False.

  Returns:
    dict: Keys are `atomic_mt_key` (single MaxText parameter names). Values
      are either a single Hugging Face parameter name (unscanned form) or a
      list of Hugging Face parameter names (scanned form) for stacked text
      layers.
  """
  num_decoder_layers = config["text_config"]["num_hidden_layers"]
  num_vision_layers = config["vision_config"]["num_hidden_layers"]

  # pylint: disable=line-too-long
  mapping = {
      # Embedding & final norm
      "params-token_embedder-embedding": "model.language_model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.language_model.norm.weight",
      # Vision embed & pos
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-embedding-kernel": "model.vision_tower.vision_model.embeddings.patch_embedding.weight",
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-embedding-bias": "model.vision_tower.vision_model.embeddings.patch_embedding.bias",
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-pos_embedding": "model.vision_tower.vision_model.embeddings.position_embedding.weight",
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoder_norm-scale": "model.vision_tower.vision_model.post_layernorm.weight",
      "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoder_norm-bias": "model.vision_tower.vision_model.post_layernorm.bias",
      # Multi-modal projector
      "params-vision_encoder-VisionEmbedder_0-mm_input_projection-w": "model.multi_modal_projector.mm_input_projection_weight",
      "params-vision_encoder-VisionEmbedder_0-mm_soft_embedding_norm-scale": "model.multi_modal_projector.mm_soft_emb_norm.weight",
  }

  # (MaxText suffix, HF suffix) pairs repeated for every vision encoder block.
  vision_params = [
      ("LayerNorm_0-scale", "layer_norm1.weight"),
      ("LayerNorm_0-bias", "layer_norm1.bias"),
      ("LayerNorm_1-scale", "layer_norm2.weight"),
      ("LayerNorm_1-bias", "layer_norm2.bias"),
      ("MultiHeadDotProductAttention_0-query-kernel", "self_attn.q_proj.weight"),
      ("MultiHeadDotProductAttention_0-query-bias", "self_attn.q_proj.bias"),
      ("MultiHeadDotProductAttention_0-key-kernel", "self_attn.k_proj.weight"),
      ("MultiHeadDotProductAttention_0-key-bias", "self_attn.k_proj.bias"),
      ("MultiHeadDotProductAttention_0-value-kernel", "self_attn.v_proj.weight"),
      ("MultiHeadDotProductAttention_0-value-bias", "self_attn.v_proj.bias"),
      ("MultiHeadDotProductAttention_0-out-kernel", "self_attn.out_proj.weight"),
      ("MultiHeadDotProductAttention_0-out-bias", "self_attn.out_proj.bias"),
      ("MlpBlockViT_0-Dense_0-kernel", "mlp.fc1.weight"),
      ("MlpBlockViT_0-Dense_0-bias", "mlp.fc1.bias"),
      ("MlpBlockViT_0-Dense_1-kernel", "mlp.fc2.weight"),
      ("MlpBlockViT_0-Dense_1-bias", "mlp.fc2.bias"),
  ]
  # Vision layers are always mapped one MaxText key -> one HF name.
  for layer_idx in range(num_vision_layers):
    for mt_suffix, hf_suffix in vision_params:
      mt_key = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{layer_idx}-{mt_suffix}"
      mapping[mt_key] = f"model.vision_tower.vision_model.encoder.layers.{layer_idx}.{hf_suffix}"

  # (MaxText suffix, HF suffix) pairs repeated for every text decoder layer.
  text_params = [
      ("pre_self_attention_norm-scale", "input_layernorm.weight"),
      ("post_self_attention_norm-scale", "post_attention_layernorm.weight"),
      ("self_attention-query_norm-scale", "self_attn.q_norm.weight"),
      ("self_attention-key_norm-scale", "self_attn.k_norm.weight"),
      ("pre_ffw_norm-scale", "pre_feedforward_layernorm.weight"),
      ("post_ffw_norm-scale", "post_feedforward_layernorm.weight"),
      ("self_attention-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention-key-kernel", "self_attn.k_proj.weight"),
      ("self_attention-value-kernel", "self_attn.v_proj.weight"),
      ("self_attention-out-kernel", "self_attn.o_proj.weight"),
      ("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
      ("mlp-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp-wo-kernel", "mlp.down_proj.weight"),
  ]

  if not scan_layers:
    for layer_idx in range(num_decoder_layers):
      for mt_suffix, hf_suffix in text_params:
        mapping[f"params-decoder-layers_{layer_idx}-{mt_suffix}"] = f"model.language_model.layers.{layer_idx}.{hf_suffix}"
    return mapping

  # Gemma3 repeats a 6-layer attention pattern (5 local + 1 global),
  # scanned as layers_0..layers_5 with leftovers in layers_remainder.
  attention_pattern_length = 6
  num_remaining = num_decoder_layers % attention_pattern_length
  num_scanned = num_decoder_layers - num_remaining

  # Main scanned blocks: slot `block_idx` of the pattern stacks HF layers
  # block_idx, block_idx + 6, block_idx + 12, ... below `num_scanned`.
  for block_idx in range(attention_pattern_length):
    hf_indices = range(block_idx, num_scanned, attention_pattern_length)
    for mt_suffix, hf_suffix in text_params:
      mapping[f"params-decoder-layers-layers_{block_idx}-{mt_suffix}"] = [
          f"model.language_model.layers.{i}.{hf_suffix}" for i in hf_indices
      ]

  # Remainder layers are unscanned; the loop is a no-op when nothing is left.
  for rem_idx in range(num_remaining):
    hf_layer_idx = num_scanned + rem_idx
    for mt_suffix, hf_suffix in text_params:
      mapping[f"params-decoder-layers_remainder-layers_{rem_idx}-{mt_suffix}"] = (
          f"model.language_model.layers.{hf_layer_idx}.{hf_suffix}"
      )
  return mapping
def GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Builds the value-transformation hooks for Gemma3 parameter conversion.

  Each hook has the signature `hook(tensor, target_shape)` and converts one
  parameter between the MaxText and Hugging Face layouts. The hooks cover
  embedding scaling plus vocab padding/truncation, RMSNorm scale offsets,
  kernel transpose/reshape, and vision-specific tensor manipulations.

  Args:
    config (dict): The Hugging Face model configuration dictionary. The layer
      counts are read from `config["text_config"]["num_hidden_layers"]` and
      `config["vision_config"]["num_hidden_layers"]` (each treated as 0 when
      absent); the embedding hook reads
      `config["text_config"]["hidden_size"]` when it is called.
    maxtext_config: Not read by this function; it is part of the signature
      shared by the hook builders in this module.
    scan_layers (bool, optional): Whether the model uses scanned layers. Only
      affects which MaxText keys the text-layer hooks are registered under.
      Defaults to False.
    saving_to_hf (bool, optional): The direction of conversion. True for
      MaxText to Hugging Face, False for the reverse. Defaults to False.

  Returns:
    dict: A dictionary mapping MaxText parameter names to their corresponding
      transformation functions.
  """
  hooks = {}

  # ---- Embedding pad & scale ----
  def pad_and_scale_embedding(input_tensor, target_shape):
    """Rescales the embedding, then truncates or zero-pads the vocab axis."""
    source_vocab_size, _ = input_tensor.shape
    target_vocab_size, target_hidden_size = target_shape

    # MaxText embedding = original_embedding * sqrt(hidden_size)
    # HF embedding      = original_embedding (HF model forward pass applies scaling)
    normalizer = np.dtype("bfloat16").type(config["text_config"]["hidden_size"] ** 0.5)
    if saving_to_hf:  # MaxText to HF
      scaled = (input_tensor / normalizer).astype(input_tensor.dtype)
    else:  # HF to MaxText
      scaled = (input_tensor * normalizer).astype(input_tensor.dtype)

    if source_vocab_size > target_vocab_size:
      warnings.warn(
          f"source vocab={source_vocab_size} > target vocab={target_vocab_size}, truncate output layer for MaxText."
      )
      return scaled[:target_vocab_size, :]

    if source_vocab_size < target_vocab_size:
      warnings.warn(f"source vocab={source_vocab_size} < target vocab={target_vocab_size}, pad output layer for MaxText.")
      # Stay in the array library the input came from (JAX vs. NumPy).
      xp = jnp if isinstance(scaled, jax.Array) else np
      padding = xp.zeros((target_vocab_size - source_vocab_size, target_hidden_size), dtype=scaled.dtype)
      return xp.concatenate([scaled, padding], axis=0)

    # Vocab sizes match
    return scaled

  # ---- RMSNorm scale ----
  def scale_rmsnorm(x, target_shape):
    """Applies the +/-1 offset between the two RMSNorm scale conventions."""
    # MaxText norm = HF norm + 1; HF norm = MaxText norm - 1
    shifted = x - 1.0 if saving_to_hf else x + 1.0
    return shifted.reshape(target_shape)

  # ---- Generic reshape ----
  def reshape_kernel(x, target_shape):
    """Transposes a kernel and reshapes it to `target_shape`."""
    if saving_to_hf:
      # Reshape to the reversed target first so the transpose lands on target_shape.
      reversed_shape = np.flip(np.array(target_shape))
      return x.reshape(reversed_shape).T
    return x.T.reshape(target_shape)

  # ---- Vision reshape ----
  def vis_bias(x, target_shape):
    """Flattens a bias when saving to HF; reshapes it to `target_shape` otherwise."""
    if saving_to_hf:
      return x.flatten()
    return x.reshape(target_shape)

  def vision_patch(x, target_shape):
    """Permutes the 4 axes of the patch-embedding kernel (target_shape is unused)."""
    # The two permutations are inverses of each other.
    if saving_to_hf:
      return x.transpose(3, 2, 0, 1)
    return x.transpose(2, 3, 1, 0)

  def pos_embed(x, target_shape):
    """Drops (to HF) or adds (to MaxText) a leading axis of size 1."""
    if saving_to_hf:
      return x.squeeze(0)
    return x[None, :, :]

  # --- Embedding & final norm ---
  hooks["params-token_embedder-embedding"] = pad_and_scale_embedding
  hooks["params-decoder-decoder_norm-scale"] = scale_rmsnorm

  # --- Vision embedding & multi-modal projector ---
  # [1, 4096, 1152]
  hooks["params-vision_encoder-Gemma3VisionEncoderLayer_0-embedding-kernel"] = vision_patch
  hooks["params-vision_encoder-Gemma3VisionEncoderLayer_0-pos_embedding"] = pos_embed
  hooks["params-vision_encoder-VisionEmbedder_0-mm_input_projection-w"] = lambda x, _: x
  hooks["params-vision_encoder-VisionEmbedder_0-mm_soft_embedding_norm-scale"] = scale_rmsnorm

  # --- Text layers ---
  num_text_layers = config.get("text_config", {}).get("num_hidden_layers", 0)
  if scan_layers:
    attention_pattern_length = 6
    num_remaining = num_text_layers % attention_pattern_length
    # Scanned sub-layer prefixes, followed by the remainder (leftover) layers.
    prefixes = [f"params-decoder-layers-layers_{block_idx}-" for block_idx in range(attention_pattern_length)]
    prefixes += [f"params-decoder-layers_remainder-layers_{rem_idx}-" for rem_idx in range(num_remaining)]
  else:
    prefixes = [f"params-decoder-layers_{i}-" for i in range(num_text_layers)]

  attention_kernels = (
      "self_attention-query-kernel",
      "self_attention-key-kernel",
      "self_attention-value-kernel",
      "self_attention-out-kernel",
  )
  norm_scales = (
      "pre_self_attention_norm-scale",
      "post_self_attention_norm-scale",
      "self_attention-query_norm-scale",
      "self_attention-key_norm-scale",
      "pre_ffw_norm-scale",
      "post_ffw_norm-scale",
  )
  mlp_kernels = ("mlp-wi_0-kernel", "mlp-wi_1-kernel", "mlp-wo-kernel")
  for prefix in prefixes:
    for suffix in attention_kernels:
      hooks[prefix + suffix] = reshape_kernel
    for suffix in norm_scales:
      hooks[prefix + suffix] = scale_rmsnorm
    for suffix in mlp_kernels:
      hooks[prefix + suffix] = reshape_kernel

  # --- Vision layers ---
  num_vision_layers = config.get("vision_config", {}).get("num_hidden_layers", 0)
  for i in range(num_vision_layers):
    base = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-"
    # Attention kernels & biases
    for qkv in ("query", "key", "value"):
      hooks[base + f"MultiHeadDotProductAttention_0-{qkv}-kernel"] = reshape_kernel
      hooks[base + f"MultiHeadDotProductAttention_0-{qkv}-bias"] = vis_bias
    # [1152, 1152] -> [16, 72, 1152]
    hooks[base + "MultiHeadDotProductAttention_0-out-kernel"] = reshape_kernel
    hooks[base + "MultiHeadDotProductAttention_0-out-bias"] = vis_bias
    # MLP ViT kernels (the MLP biases and LayerNorm params get no hook)
    for dense in ("Dense_0", "Dense_1"):
      hooks[base + f"MlpBlockViT_0-{dense}-kernel"] = reshape_kernel
  return hooks
def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping between MaxText and HuggingFace Gemma2 weight paths.

  Args:
    config (dict): Model configuration dictionary containing at least
      'num_hidden_layers'.
    maxtext_config: Not read by this function; it is part of the signature
      shared by the mapping builders in this module.
    scan_layers (bool, optional): Whether the MaxText model uses layer
      scanning optimization. When True, decoder layers are stacked into a
      single tensor. Defaults to False.

  Returns:
    dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter name).
      Values are either a single string (unscanned form) or a list of strings
      (scanned form) for stacked layers when `scan_layers=True`.

  Notes:
    - MaxText uses a paired layer approach where two HF decoder layers are
      treated as one MaxText decoder layer.
    - MaxText layer `i` corresponds to HF layers `2i` and `2i+1`.
    - Local components map to even-numbered HF decoder layers (0, 2, 4...).
    - Global components map to odd-numbered HF decoder layers (1, 3, 5...).
  """
  nlayers = config["num_hidden_layers"]

  # (MaxText module, MaxText leaf, HF suffix). The MaxText name is assembled as
  # "{module}_{global|local}-{leaf}"; the tuple order fixes the insertion order.
  layer_params = (
      ("pre_self_attention_norm", "scale", "input_layernorm.weight"),
      ("mlp", "wo-kernel", "mlp.down_proj.weight"),
      ("mlp", "wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp", "wi_0-kernel", "mlp.gate_proj.weight"),
      ("post_self_attention_norm", "scale", "post_attention_layernorm.weight"),
      ("post_ffw_norm", "scale", "post_feedforward_layernorm.weight"),
      ("pre_ffw_norm", "scale", "pre_feedforward_layernorm.weight"),
      ("self_attention", "key-kernel", "self_attn.k_proj.weight"),
      ("self_attention", "out-kernel", "self_attn.o_proj.weight"),
      ("self_attention", "query-kernel", "self_attn.q_proj.weight"),
      ("self_attention", "value-kernel", "self_attn.v_proj.weight"),
  )

  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
  }

  if scan_layers:
    # Case 1: one MaxText key per component, stacking every second HF layer
    # (odd HF layers for "global", even HF layers for "local").
    for kind, first_hf_layer in (("global", 1), ("local", 0)):
      for module, leaf, hf_suffix in layer_params:
        mapping[f"params-decoder-layers-{module}_{kind}-{leaf}"] = [
            f"model.layers.{i}.{hf_suffix}" for i in range(first_hf_layer, nlayers, 2)
        ]
    return mapping

  # Case 2: scan_layers=False. Entries are added in place instead of rebuilding
  # the whole dictionary for every layer.
  for maxtext_layer_idx in range(nlayers // 2):
    local_layer_idx = maxtext_layer_idx * 2
    global_layer_idx = local_layer_idx + 1
    for kind, hf_layer_idx in (("global", global_layer_idx), ("local", local_layer_idx)):
      for module, leaf, hf_suffix in layer_params:
        mapping[f"params-decoder-layers_{maxtext_layer_idx}-{module}_{kind}-{leaf}"] = (
            f"model.layers.{hf_layer_idx}.{hf_suffix}"
        )
  return mapping
def GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Gemma2 conversion.

  This function generates a mapping of transformation functions that handle the
  necessary conversions between MaxText and HuggingFace parameter formats for
  Gemma2, including operations like padding, reshaping, and scaling.

  Args:
    config (dict): Model configuration dictionary that must contain:
      - num_hidden_layers (int): Number of layers in the model.
      - head_dim (int): Dimension of attention heads.
      - hidden_size (int): Model's hidden dimension size.
    maxtext_config: Not read by this function; it is part of the signature
      shared by the hook builders in this module.
    scan_layers (bool, optional): Controls the output format for layer
      parameters. True for batched, False for individual. Defaults to False.
    saving_to_hf (bool, optional): Determines the direction of transformation.
      True for MaxText to HuggingFace, False for the reverse. Defaults to
      False.

  Returns:
    dict: A mapping from MaxText parameter names to transformation functions.
      The value is a single function, or (for the query kernels) the list
      `[reshape_kernel, scale_query_layer]`. Every hook has the signature
      `hook(input_tensor, target_shape)`.
  """
  nlayers = config["num_hidden_layers"]

  def pad_hf_embedding_layer(input_tensor, target_shape):
    """Pads/unpads and scales the embedding layer.

    Note:
      HF embedding weights shape = [256000, d_model]
      MaxText embedding weights shape = [256128, d_model]
      MaxText pads Gemma2 embedding to 256128 for better performance.
    """
    # TODO(wenxindongwork), Perhaps, this dtype should be the activation dtype
    normalizer = np.dtype("float32").type(config["hidden_size"] ** 0.5)
    if saving_to_hf:
      # Drop the padding rows/columns, then undo the sqrt(hidden_size) scaling.
      cropped = input_tensor[: target_shape[0], : target_shape[1]]
      return (cropped / normalizer).astype(input_tensor.dtype)
    # Zero-pad up to the target shape, then apply the sqrt(hidden_size) scaling.
    padded = np.zeros(target_shape, dtype=input_tensor.dtype)
    padded[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor
    return (padded * normalizer).astype(input_tensor.dtype)

  def reshape_kernel(input_tensor, target_shape):
    """Transposes a kernel and reshapes it to `target_shape`."""
    if saving_to_hf:
      # Reshape to the reversed target first so the transpose lands on target_shape.
      flipped_target_shape = np.flip(np.array(target_shape))
      return input_tensor.reshape(flipped_target_shape).T
    return input_tensor.T.reshape(target_shape)

  def scale_rmsnorm_layer(input_tensor, target_shape):
    """Subtracts 1 when saving to HF, adds 1 otherwise, then reshapes."""
    shifted = input_tensor - 1.0 if saving_to_hf else input_tensor + 1.0
    return shifted.reshape(target_shape)

  def scale_query_layer(input_tensor, target_shape):
    """Multiplies by sqrt(head_dim) when saving to HF, by its inverse otherwise."""
    if saving_to_hf:
      depth_scale = np.dtype("float32").type(np.sqrt(config["head_dim"]))
    else:
      depth_scale = np.dtype("float32").type(1 / np.sqrt(config["head_dim"]))
    return (input_tensor * depth_scale).astype(input_tensor.dtype)

  # hook order does not affect result
  query_hook_chain = [reshape_kernel, scale_query_layer]

  # (MaxText key suffix, hook) for one decoder layer; the tuple order fixes
  # the insertion order of the returned dictionary.
  layer_hooks = (
      ("self_attention_global-query-kernel", query_hook_chain),
      ("self_attention_local-query-kernel", query_hook_chain),
      ("self_attention_global-key-kernel", reshape_kernel),
      ("self_attention_local-key-kernel", reshape_kernel),
      ("self_attention_global-value-kernel", reshape_kernel),
      ("self_attention_local-value-kernel", reshape_kernel),
      ("mlp_global-wo-kernel", reshape_kernel),
      ("mlp_global-wi_1-kernel", reshape_kernel),
      ("mlp_global-wi_0-kernel", reshape_kernel),
      ("self_attention_global-out-kernel", reshape_kernel),
      ("mlp_local-wo-kernel", reshape_kernel),
      ("mlp_local-wi_1-kernel", reshape_kernel),
      ("mlp_local-wi_0-kernel", reshape_kernel),
      ("self_attention_local-out-kernel", reshape_kernel),
      ("pre_self_attention_norm_global-scale", scale_rmsnorm_layer),
      ("post_self_attention_norm_global-scale", scale_rmsnorm_layer),
      ("post_ffw_norm_global-scale", scale_rmsnorm_layer),
      ("pre_ffw_norm_global-scale", scale_rmsnorm_layer),
      ("pre_self_attention_norm_local-scale", scale_rmsnorm_layer),
      ("post_self_attention_norm_local-scale", scale_rmsnorm_layer),
      ("post_ffw_norm_local-scale", scale_rmsnorm_layer),
      ("pre_ffw_norm_local-scale", scale_rmsnorm_layer),
  )

  mapping = {
      "params-token_embedder-embedding": pad_hf_embedding_layer,
      "params-decoder-decoder_norm-scale": scale_rmsnorm_layer,
  }

  # Scanned checkpoints use one key per component; unscanned ones use one key
  # per MaxText layer (nlayers // 2 of them).
  if scan_layers:
    prefixes = ["params-decoder-layers-"]
  else:
    prefixes = [f"params-decoder-layers_{maxtext_layer_idx}-" for maxtext_layer_idx in range(nlayers // 2)]

  # Entries are added in place instead of rebuilding the dictionary per layer.
  for prefix in prefixes:
    for suffix, hook in layer_hooks:
      mapping[prefix + suffix] = hook
  return mapping
def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Qwen weight paths.

  This function generates a dictionary that maps parameter names from a MaxText
  Qwen checkpoint to their corresponding names in the Hugging Face format.
  It handles both dense and Mixture-of-Experts (MoE) model variants; the MoE
  entries are emitted when `num_experts > 1`.

  Args:
    config (dict): Model configuration dictionary, including
      'num_hidden_layers' and optionally 'num_experts' (a missing or None
      value is treated as 0, i.e. a dense model).
    maxtext_config: Not read by this function; it is part of the signature
      shared by the mapping builders in this module.
    scan_layers (bool, optional): Whether the MaxText model uses scanned
      layers. Defaults to False.

  Returns:
    dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names).
      Values are Hugging Face parameter names in one of four forms: unscanned (string),
      scanned (list of strings), unscanned with expert stacking (list of strings),
      or scanned with expert stacking (nested list of strings).
  """
  n_layers = config["num_hidden_layers"]
  # `or 0` keeps an explicit `num_experts: None` from failing the comparison below.
  num_experts = config.get("num_experts") or 0
  is_moe = num_experts > 1

  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  if scan_layers:

    def per_layer(hf_suffix):
      """HF names of one parameter across all layers (scanned form)."""
      return [f"model.layers.{i}.{hf_suffix}" for i in range(n_layers)]

    def per_expert_and_layer(proj):
      """Nested HF names [[e0_l0, e0_l1..], [e1_l0, e1_l1..]..] for one expert projection."""
      return [
          [f"model.layers.{l}.mlp.experts.{e}.{proj}.weight" for l in range(n_layers)] for e in range(num_experts)
      ]

    # Attention and norms, shared by dense and MoE models.
    mapping.update(
        {
            "params-decoder-layers-pre_self_attention_layer_norm-scale": per_layer("input_layernorm.weight"),
            "params-decoder-layers-self_attention-query-kernel": per_layer("self_attn.q_proj.weight"),
            "params-decoder-layers-self_attention-key-kernel": per_layer("self_attn.k_proj.weight"),
            "params-decoder-layers-self_attention-value-kernel": per_layer("self_attn.v_proj.weight"),
            "params-decoder-layers-self_attention-query-bias": per_layer("self_attn.q_proj.bias"),
            "params-decoder-layers-self_attention-key-bias": per_layer("self_attn.k_proj.bias"),
            "params-decoder-layers-self_attention-value-bias": per_layer("self_attn.v_proj.bias"),
            "params-decoder-layers-self_attention-out-kernel": per_layer("self_attn.o_proj.weight"),
            "params-decoder-layers-self_attention-query_norm-scale": per_layer("self_attn.q_norm.weight"),
            "params-decoder-layers-self_attention-key_norm-scale": per_layer("self_attn.k_norm.weight"),
            "params-decoder-layers-post_self_attention_layer_norm-scale": per_layer("post_attention_layernorm.weight"),
        }
    )
    if is_moe:
      # For scanned MoE, we create a nested list: [[e0_l0, e0_l1..], [e1_l0, e1_l1..]..]
      # This follows the (experts, layers, ...) tensor layout.
      mapping.update(
          {
              "params-decoder-layers-moe_block-gate-kernel": per_layer("mlp.gate.weight"),
              "params-decoder-layers-moe_block-wi_0": per_expert_and_layer("gate_proj"),
              "params-decoder-layers-moe_block-wi_1": per_expert_and_layer("up_proj"),
              "params-decoder-layers-moe_block-wo": per_expert_and_layer("down_proj"),
          }
      )
    else:  # Dense MLP
      mapping.update(
          {
              "params-decoder-layers-mlp-wi_0-kernel": per_layer("mlp.gate_proj.weight"),
              "params-decoder-layers-mlp-wi_1-kernel": per_layer("mlp.up_proj.weight"),
              "params-decoder-layers-mlp-wo-kernel": per_layer("mlp.down_proj.weight"),
          }
      )
    return mapping

  # Unscanned layers: one MaxText key per layer.
  for i in range(n_layers):
    # Common attention and norms.
    # pylint: disable=line-too-long
    mapping.update(
        {
            f"params-decoder-layers_{i}-pre_self_attention_layer_norm-scale": f"model.layers.{i}.input_layernorm.weight",
            f"params-decoder-layers_{i}-self_attention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight",
            f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
            f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
            f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
            f"params-decoder-layers_{i}-self_attention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias",
            f"params-decoder-layers_{i}-self_attention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias",
            f"params-decoder-layers_{i}-self_attention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias",
            f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
            f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
            f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight",
        }
    )
    if is_moe:
      # For each unscanned MoE layer, map the MaxText parameter to a 1D list of all expert weights for that layer.
      mapping.update(
          {
              f"params-decoder-layers_{i}-moe_block-gate-kernel": f"model.layers.{i}.mlp.gate.weight",
              f"params-decoder-layers_{i}-moe_block-wi_0": [
                  f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)
              ],
              f"params-decoder-layers_{i}-moe_block-wi_1": [
                  f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)
              ],
              f"params-decoder-layers_{i}-moe_block-wo": [
                  f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)
              ],
          }
      )
    else:  # Dense MLP
      mapping.update(
          {
              f"params-decoder-layers_{i}-mlp-wi_0-kernel": f"model.layers.{i}.mlp.gate_proj.weight",
              f"params-decoder-layers_{i}-mlp-wi_1-kernel": f"model.layers.{i}.mlp.up_proj.weight",
              f"params-decoder-layers_{i}-mlp-wo-kernel": f"model.layers.{i}.mlp.down_proj.weight",
          }
      )
  return mapping
[docs]
def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
    """Builds the tensor-conversion hooks used for Qwen checkpoints.

    Every hook is invoked as ``hook(input_tensor, target_shape)`` and returns the
    converted tensor. Three hooks are produced: one that pads/truncates the
    token embedding along its first axis, one that transposes and reshapes
    kernels, and one that reshapes attention biases.

    Args:
        config (dict): Model configuration. 'num_hidden_layers' is required;
            'num_experts' is optional and defaults to 0.
        maxtext_config: Accepted for signature parity with the other hook
            factories; this function does not read it.
        scan_layers (bool, optional): When True, hooks are registered once under
            the shared 'params-decoder-layers' prefix; otherwise once per layer
            under 'params-decoder-layers_{i}'. Defaults to False.
        saving_to_hf (bool, optional): Conversion direction. True converts
            MaxText to Hugging Face, False the reverse. Defaults to False.

    Returns:
        dict: MaxText parameter name -> transformation function.
    """
    layer_count = config["num_hidden_layers"]
    expert_count = config.get("num_experts", 0)

    def pad_embedding_layer(input_tensor, target_shape):
        """Matches the embedding's first axis to target_shape[0] by padding or truncating."""
        src_rows = input_tensor.shape[0]
        dst_rows = target_shape[0]
        if src_rows == dst_rows:
            return input_tensor
        if not saving_to_hf:
            # HF -> MaxText: copy the source rows into a zero-filled target-sized array.
            padded = np.zeros(target_shape, dtype=input_tensor.dtype)
            padded[:src_rows, :] = input_tensor
            return padded
        # MaxText -> HF: keep only the leading target rows.
        return input_tensor[:dst_rows, :]

    def reshape_kernel(input_tensor, target_shape):
        """Transposes a kernel and reshapes it to target_shape (order depends on direction)."""
        if not saving_to_hf:
            return input_tensor.T.reshape(target_shape)
        # Reshape to the reversed target shape first so the transpose lands on target_shape.
        reversed_shape = np.flip(np.array(target_shape))
        return input_tensor.reshape(reversed_shape).T

    def reshape_bias(input_tensor, target_shape=None):
        """Reshapes a bias to target_shape; the same call serves both directions."""
        return input_tensor.reshape(target_shape)

    hooks = {
        "params-token_embedder-embedding": pad_embedding_layer,
        "params-decoder-logits_dense-kernel": reshape_kernel,
    }

    dense_kernel_suffixes = (
        "self_attention-query-kernel",
        "self_attention-key-kernel",
        "self_attention-value-kernel",
        "self_attention-out-kernel",
        "mlp-wi_0-kernel",
        "mlp-wi_1-kernel",
        "mlp-wo-kernel",
    )
    bias_suffixes = (
        "self_attention-query-bias",
        "self_attention-key-bias",
        "self_attention-value-bias",
    )
    moe_kernel_suffixes = (
        "moe_block-gate-kernel",
        "moe_block-wi_0-kernel",
        "moe_block-wi_1-kernel",
        "moe_block-wo-kernel",
        "moe_block-wi_0",
        "moe_block-wi_1",
        "moe_block-wo",
    )

    # Scanned checkpoints expose a single stacked prefix; unscanned ones expose one per layer.
    if scan_layers:
        layer_prefixes = ["params-decoder-layers"]
    else:
        layer_prefixes = [f"params-decoder-layers_{layer_idx}" for layer_idx in range(layer_count)]

    for layer_prefix in layer_prefixes:
        hooks.update({f"{layer_prefix}-{suffix}": reshape_kernel for suffix in dense_kernel_suffixes})
        hooks.update({f"{layer_prefix}-{suffix}": reshape_bias for suffix in bias_suffixes})
        if expert_count > 1:
            hooks.update({f"{layer_prefix}-{suffix}": reshape_kernel for suffix in moe_kernel_suffixes})

    return hooks
[docs]
def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
    """Maps MaxText Qwen3-Next parameter names to Hugging Face weight paths.

    Layers are grouped into cycles of
    ``maxtext_config.inhomogeneous_layer_cycle_interval``; the last position of
    every cycle is mapped to the ``self_attn`` weights and all other positions
    to the ``linear_attn`` weights.

    Args:
        config (dict): Must contain 'num_hidden_layers' and 'num_experts'.
        maxtext_config: Object exposing ``inhomogeneous_layer_cycle_interval``.
        scan_layers (bool, optional): When True, one entry is produced per cycle
            position ('params-decoder-layers-layer_{block}') whose value lists
            the HF path of every layer at that position. When False, one entry
            is produced per layer ('params-decoder-layers_{i}').

    Returns:
        dict: MaxText key -> HF path (str), list of paths (scanned, or per-expert
        when unscanned), or nested list indexed [expert][layer] (scanned routed
        experts).
    """
    total_layers = config["num_hidden_layers"]
    expert_count = config["num_experts"]
    cycle = maxtext_config.inhomogeneous_layer_cycle_interval

    # Suffix tables: MaxText suffix -> HF suffix relative to "model.layers.{i}.".
    norm_params = {
        "input_layernorm-scale": "input_layernorm.weight",
        "post_attention_layernorm-scale": "post_attention_layernorm.weight",
    }
    full_attention_params = {
        "attention-attention-query-kernel": "self_attn.q_proj.weight",
        "attention-attention-key-kernel": "self_attn.k_proj.weight",
        "attention-attention-value-kernel": "self_attn.v_proj.weight",
        "attention-attention-out-kernel": "self_attn.o_proj.weight",
        "attention-attention-query_norm-scale": "self_attn.q_norm.weight",
        "attention-attention-key_norm-scale": "self_attn.k_norm.weight",
    }
    linear_attention_params = {
        "attention-in_proj_qkvz-kernel": "linear_attn.in_proj_qkvz.weight",
        "attention-in_proj_ba-kernel": "linear_attn.in_proj_ba.weight",
        "attention-conv1d-kernel": "linear_attn.conv1d.weight",
        "attention-A_log": "linear_attn.A_log",
        "attention-dt_bias": "linear_attn.dt_bias",
        "attention-norm-rms_norm-scale": "linear_attn.norm.weight",
        "attention-out_proj-kernel": "linear_attn.out_proj.weight",
    }
    gate_and_shared_params = {
        "mlp-routed_experts-gate-kernel": "mlp.gate.weight",
        "mlp-shared_expert-wi_0-kernel": "mlp.shared_expert.gate_proj.weight",
        "mlp-shared_expert-wi_1-kernel": "mlp.shared_expert.up_proj.weight",
        "mlp-shared_expert-wo-kernel": "mlp.shared_expert.down_proj.weight",
        "mlp-shared_expert_gate-kernel": "mlp.shared_expert_gate.weight",
    }
    # HF suffix here is relative to "model.layers.{i}.mlp.experts.{e}.".
    routed_expert_params = {
        "mlp-routed_experts-wi_0": "gate_proj.weight",
        "mlp-routed_experts-wi_1": "up_proj.weight",
        "mlp-routed_experts-wo": "down_proj.weight",
    }

    def non_expert_params(position):
        """Returns the norm + attention + gate/shared-expert table for one layer position."""
        closes_cycle = (position + 1) % cycle == 0
        attention_params = full_attention_params if closes_cycle else linear_attention_params
        return {**norm_params, **attention_params, **gate_and_shared_params}

    mapping = {
        "params-token_embedder-embedding": "model.embed_tokens.weight",
        "params-decoder-decoder_norm-scale": "model.norm.weight",
        "params-decoder-logits_dense-kernel": "lm_head.weight",
    }

    if scan_layers:
        for block_idx in range(cycle):
            # Every HF layer that occupies this position within its cycle.
            layer_ids = list(range(block_idx, total_layers, cycle))
            mt_prefix = f"params-decoder-layers-layer_{block_idx}"
            for mt_suffix, hf_suffix in non_expert_params(block_idx).items():
                mapping[f"{mt_prefix}-{mt_suffix}"] = [f"model.layers.{i}.{hf_suffix}" for i in layer_ids]
            for mt_suffix, hf_suffix in routed_expert_params.items():
                mapping[f"{mt_prefix}-{mt_suffix}"] = [
                    [f"model.layers.{i}.mlp.experts.{e}.{hf_suffix}" for i in layer_ids]
                    for e in range(expert_count)
                ]
        return mapping

    for i in range(total_layers):
        mt_prefix = f"params-decoder-layers_{i}"
        for mt_suffix, hf_suffix in non_expert_params(i).items():
            mapping[f"{mt_prefix}-{mt_suffix}"] = f"model.layers.{i}.{hf_suffix}"
        for mt_suffix, hf_suffix in routed_expert_params.items():
            mapping[f"{mt_prefix}-{mt_suffix}"] = [
                f"model.layers.{i}.mlp.experts.{e}.{hf_suffix}" for e in range(expert_count)
            ]
    return mapping
[docs]
def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
    """Builds the tensor-conversion hooks for Qwen3-Next checkpoints.

    Args:
        config (dict): Must contain 'num_hidden_layers'.
        maxtext_config: Object exposing ``inhomogeneous_layer_cycle_interval``.
        scan_layers (bool, optional): When True, hooks are keyed per cycle
            position ('params-decoder-layers-layer_{block}'); otherwise per
            layer ('params-decoder-layers_{i}').
        saving_to_hf (bool, optional): Conversion direction; only the attention
            kernel reshape depends on it. Defaults to False.

    Returns:
        dict: MaxText parameter name -> ``hook(input_tensor, target_shape)``.
    """

    def transpose(input_tensor, target_shape=None):
        return input_tensor.T

    def reshape_kernel(input_tensor, target_shape):
        if not saving_to_hf:
            return input_tensor.T.reshape(target_shape)
        # Reshape to the reversed target shape so the transpose yields target_shape.
        reversed_shape = np.flip(np.array(target_shape))
        return input_tensor.reshape(reversed_shape).T

    def permute_conv(input_tensor, target_shape=None):
        # Swaps the first and last axes: MT [K, 1, C] <-> HF [C, 1, K].
        return input_tensor.transpose(2, 1, 0)

    hooks = {"params-decoder-logits_dense-kernel": transpose}

    cycle = maxtext_config.inhomogeneous_layer_cycle_interval
    layer_total = config["num_hidden_layers"]

    # Pair every MaxText layer prefix with its position inside the layer cycle.
    if scan_layers:
        layer_slots = [(f"params-decoder-layers-layer_{pos}", pos) for pos in range(cycle)]
    else:
        layer_slots = [(f"params-decoder-layers_{idx}", idx % cycle) for idx in range(layer_total)]

    mlp_transposed_suffixes = (
        "routed_experts-gate-kernel",
        "shared_expert-wi_0-kernel",
        "shared_expert-wi_1-kernel",
        "shared_expert-wo-kernel",
        "shared_expert_gate-kernel",
        "routed_experts-wi_0",
        "routed_experts-wi_1",
        "routed_experts-wo",
    )

    for layer_prefix, position in layer_slots:
        closes_cycle = (position + 1) % cycle == 0
        if closes_cycle:
            # Last position of a cycle: regular attention projections.
            hooks.update(
                {
                    f"{layer_prefix}-attention-attention-{proj}-kernel": reshape_kernel
                    for proj in ("query", "key", "value", "out")
                }
            )
        else:
            # Remaining positions: linear attention projections and conv kernel.
            hooks.update(
                {
                    f"{layer_prefix}-attention-in_proj_qkvz-kernel": transpose,
                    f"{layer_prefix}-attention-in_proj_ba-kernel": transpose,
                    f"{layer_prefix}-attention-out_proj-kernel": transpose,
                    f"{layer_prefix}-attention-conv1d-kernel": permute_conv,
                }
            )
        hooks.update({f"{layer_prefix}-mlp-{suffix}": transpose for suffix in mlp_transposed_suffixes})

    return hooks
[docs]
def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
    """Maps MaxText Deepseek parameter names to Hugging Face weight paths.

    The first ``config["first_k_dense_replace"]`` layers are mapped as dense
    layers and the remaining ones as MoE layers. MoE layers are re-indexed from
    zero on the MaxText side.

    Args:
        config (dict): Must contain 'num_hidden_layers' and
            'first_k_dense_replace'; 'n_routed_experts' defaults to 0.
        maxtext_config: Not read by this function.
        scan_layers (bool, optional): Selects the scanned (stacked) or unscanned
            (per-layer) key layout. Defaults to False.

    Returns:
        dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names).
        Values are Hugging Face parameter names in one of four forms: unscanned (string),
        scanned (list of strings), unscanned with expert stacking (list of strings),
        or scanned with expert stacking (nested list of strings, [expert][layer]).
    """
    # HF configuration values (MTP layers are not considered).
    layer_count = config["num_hidden_layers"]
    dense_count = config["first_k_dense_replace"]
    expert_count = config.get("n_routed_experts", 0)

    def hf_layer_path(layer_idx, suffix):
        return f"model.layers.{layer_idx}.{suffix}"

    def hf_expert_path(layer_idx, expert_idx, suffix):
        return f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.{suffix}"

    # Attention (and layer-norm) parameters present in both dense and MoE layers.
    shared_attention = {
        "pre_self_attention_layer_norm-scale": "input_layernorm.weight",
        "post_self_attention_layer_norm-scale": "post_attention_layernorm.weight",
        "self_attention-kv_norm-scale": "self_attn.kv_a_layernorm.weight",
        "self_attention-wkv_a-kernel": "self_attn.kv_a_proj_with_mqa.weight",
        "self_attention-wkv_b-kernel": "self_attn.kv_b_proj.weight",
        "self_attention-out-kernel": "self_attn.o_proj.weight",
        # v2
        "self_attention-query-kernel": "self_attn.q_proj.weight",
        # v3
        "self_attention-q_norm-scale": "self_attn.q_a_layernorm.weight",
        "self_attention-wq_a-kernel": "self_attn.q_a_proj.weight",
        "self_attention-wq_b-kernel": "self_attn.q_b_proj.weight",
        # v3.2
        "self_attention-indexer-k_norm-bias": "self_attn.indexer.k_norm.bias",
        "self_attention-indexer-k_norm-scale": "self_attn.indexer.k_norm.weight",
        "self_attention-indexer-weights_proj-kernel": "self_attn.indexer.weights_proj.weight",
        "self_attention-indexer-wk-kernel": "self_attn.indexer.wk.weight",
        "self_attention-indexer-wq_b-kernel": "self_attn.indexer.wq_b.weight",
    }
    dense_params = {
        **shared_attention,
        "mlp-wi_0-kernel": "mlp.gate_proj.weight",
        "mlp-wi_1-kernel": "mlp.up_proj.weight",
        "mlp-wo-kernel": "mlp.down_proj.weight",
    }
    moe_params = {
        **shared_attention,
        "DeepSeekMoeBlock_0-shared_experts-wi_0-kernel": "mlp.shared_experts.gate_proj.weight",
        "DeepSeekMoeBlock_0-shared_experts-wi_1-kernel": "mlp.shared_experts.up_proj.weight",
        "DeepSeekMoeBlock_0-shared_experts-wo-kernel": "mlp.shared_experts.down_proj.weight",
        "DeepSeekMoeBlock_0-MoeBlock_0-gate-kernel": "mlp.gate.weight",
        # v3
        "DeepSeekMoeBlock_0-MoeBlock_0-gate-bias": "mlp.gate.e_score_correction_bias",
    }
    # Routed experts; the HF suffix is relative to "model.layers.{i}.mlp.experts.{e}.".
    expert_params = {
        "DeepSeekMoeBlock_0-MoeBlock_0-wi_0": "gate_proj.weight",
        "DeepSeekMoeBlock_0-MoeBlock_0-wi_1": "up_proj.weight",
        "DeepSeekMoeBlock_0-MoeBlock_0-wo": "down_proj.weight",
    }

    mapping = {
        "params-token_embedder-embedding": "model.embed_tokens.weight",
        "params-decoder-decoder_norm-scale": "model.norm.weight",
        "params-decoder-logits_dense-kernel": "lm_head.weight",
    }

    if scan_layers:
        dense_ids = range(dense_count)
        moe_ids = range(dense_count, layer_count)
        mapping.update(
            {
                f"params-decoder-dense_layers-{mt_suffix}": [hf_layer_path(i, hf_suffix) for i in dense_ids]
                for mt_suffix, hf_suffix in dense_params.items()
            }
        )
        mapping.update(
            {
                f"params-decoder-moe_layers-{mt_suffix}": [hf_layer_path(i, hf_suffix) for i in moe_ids]
                for mt_suffix, hf_suffix in moe_params.items()
            }
        )
        # Nested as [[e0_l0, e0_l1, ..], [e1_l0, e1_l1, ..], ..].
        mapping.update(
            {
                f"params-decoder-moe_layers-{mt_suffix}": [
                    [hf_expert_path(i, e, hf_suffix) for i in moe_ids] for e in range(expert_count)
                ]
                for mt_suffix, hf_suffix in expert_params.items()
            }
        )
        return mapping

    for i in range(dense_count):
        for mt_suffix, hf_suffix in dense_params.items():
            mapping[f"params-decoder-dense_layers_{i}-{mt_suffix}"] = hf_layer_path(i, hf_suffix)

    # MaxText numbers MoE layers from zero; HF keeps the global layer index.
    for moe_pos, i in enumerate(range(dense_count, layer_count)):
        for mt_suffix, hf_suffix in moe_params.items():
            mapping[f"params-decoder-moe_layers_{moe_pos}-{mt_suffix}"] = hf_layer_path(i, hf_suffix)
        for mt_suffix, hf_suffix in expert_params.items():
            mapping[f"params-decoder-moe_layers_{moe_pos}-{mt_suffix}"] = [
                hf_expert_path(i, e, hf_suffix) for e in range(expert_count)
            ]
    return mapping
[docs]
def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
    """Creates parameter transformation functions for Deepseek.

    Every listed kernel is converted by the same hook, which transposes the
    tensor and reshapes it to the requested target shape.

    Args:
        config (dict): Must contain 'num_hidden_layers' and
            'first_k_dense_replace' (the number of leading dense layers).
        maxtext_config: Not read by this function.
        scan_layers (bool, optional): Selects the scanned ('dense_layers' /
            'moe_layers') or unscanned ('dense_layers_{i}' / 'moe_layers_{j}')
            key layout. Defaults to False.
        saving_to_hf (bool, optional): Conversion direction. True converts
            MaxText to Hugging Face, False the reverse. Defaults to False.

    Returns:
        dict: MaxText parameter name -> transformation function. Keys are
        inserted in a fixed order (the order of the suffix tuples below).
    """

    def reshape_kernel(input_tensor, target_shape):
        """Reshapes and transposes kernel weights between MaxText and HF."""
        if saving_to_hf:
            # Reshape to the reversed target shape so the transpose yields target_shape.
            flipped_target_shape = np.flip(np.array(target_shape))
            return input_tensor.reshape(flipped_target_shape).T
        return input_tensor.T.reshape(target_shape)

    num_main_layers = config["num_hidden_layers"]
    first_num_dense_layers = config["first_k_dense_replace"]

    mapping = {
        "params-decoder-logits_dense-kernel": reshape_kernel,
    }

    # Ordered tuples (not sets): iterating a set of strings follows hash order,
    # which is randomized per interpreter run, so the resulting dict's key
    # order would not be reproducible.
    attention_need_reshape = (
        "self_attention-wkv_a-kernel",  # transpose
        "self_attention-wkv_b-kernel",
        "self_attention-out-kernel",
        # v2
        "self_attention-query-kernel",
        # v3
        "self_attention-wq_a-kernel",  # transpose
        "self_attention-wq_b-kernel",
        # v3.2
        "self_attention-indexer-weights_proj-kernel",  # transpose
        "self_attention-indexer-wk-kernel",  # transpose
        "self_attention-indexer-wq_b-kernel",
    )
    dense_need_reshape = attention_need_reshape + (
        "mlp-wi_0-kernel",  # transpose
        "mlp-wi_1-kernel",  # transpose
        "mlp-wo-kernel",  # transpose
    )
    moe_need_reshape = attention_need_reshape + (
        "DeepSeekMoeBlock_0-shared_experts-wi_0-kernel",  # transpose
        "DeepSeekMoeBlock_0-shared_experts-wi_1-kernel",  # transpose
        "DeepSeekMoeBlock_0-shared_experts-wo-kernel",  # transpose
        "DeepSeekMoeBlock_0-MoeBlock_0-gate-kernel",  # transpose
        "DeepSeekMoeBlock_0-MoeBlock_0-wi_0",  # transpose
        "DeepSeekMoeBlock_0-MoeBlock_0-wi_1",  # transpose
        "DeepSeekMoeBlock_0-MoeBlock_0-wo",  # transpose
    )

    # scan
    if scan_layers:
        for key in dense_need_reshape:
            mapping[f"params-decoder-dense_layers-{key}"] = reshape_kernel
        for key in moe_need_reshape:
            mapping[f"params-decoder-moe_layers-{key}"] = reshape_kernel
        return mapping

    # unscan
    for i in range(first_num_dense_layers):
        for key in dense_need_reshape:
            mapping[f"params-decoder-dense_layers_{i}-{key}"] = reshape_kernel
    # MoE layers are re-indexed from zero on the MaxText side.
    for moe_layer_idx in range(num_main_layers - first_num_dense_layers):
        for key in moe_need_reshape:
            mapping[f"params-decoder-moe_layers_{moe_layer_idx}-{key}"] = reshape_kernel
    return mapping
[docs]
def DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN():
    """Returns the Deepseek NNX-to-vLLM hook table, which is currently empty.

    Returns:
        dict: A new empty dictionary; no parameter transformations are registered.
    """
    hooks = {}
    return hooks
[docs]
def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
    """Generates mapping from MaxText gpt-oss to Hugging Face weight paths.

    Args:
        config (dict): Hugging Face config; must contain 'num_hidden_layers'.
        maxtext_config: Object exposing ``inhomogeneous_layer_cycle_interval``
            (read in both scanned and unscanned mode).
        scan_layers (bool, optional): When True, one group of keys is emitted per
            position in the layer cycle ('params-decoder-layers-layers_{block}')
            and each value lists the HF path of every layer at that position.
            When False, one group is emitted per layer
            ('params-decoder-layers_{i}'). Defaults to False.

    Returns:
        dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter) or
        `composite_mt_key` (a tuple of MaxText parameters). Values are Hugging Face parameter
        names either a single string (unscanned form) or a list of strings (scanned form).

    Notes:
        - Handles `composite_mt_key`: multiple MaxText keys map to HF key(s)
            - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj
            - (GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias): mlp.experts.gate_up_proj_bias
    """
    layer_total = config["num_hidden_layers"]  # hf config
    cycle = maxtext_config.inhomogeneous_layer_cycle_interval

    # (MaxText suffix, HF suffix relative to "model.layers.{i}.") for one-to-one parameters.
    atomic_params = (
        # Layer norms
        ("pre_self_attention_layer_norm-scale", "input_layernorm.weight"),
        ("post_self_attention_layer_norm-scale", "post_attention_layernorm.weight"),
        # Attention
        ("GptOssAttention-query-kernel", "self_attn.q_proj.weight"),
        ("GptOssAttention-query-bias", "self_attn.q_proj.bias"),
        ("GptOssAttention-key-kernel", "self_attn.k_proj.weight"),
        ("GptOssAttention-key-bias", "self_attn.k_proj.bias"),
        ("GptOssAttention-value-kernel", "self_attn.v_proj.weight"),
        ("GptOssAttention-value-bias", "self_attn.v_proj.bias"),
        ("GptOssAttention-out-kernel", "self_attn.o_proj.weight"),
        ("GptOssAttention-out-bias", "self_attn.o_proj.bias"),
        ("GptOssAttention-sinks", "self_attn.sinks"),
        # MLP router
        ("GptOssMlp-gate-kernel", "mlp.router.weight"),
        ("GptOssMlp-gate-bias", "mlp.router.bias"),
        # MLP experts, down projection
        ("GptOssMlp-wo", "mlp.experts.down_proj"),
        ("GptOssMlp-wo_bias", "mlp.experts.down_proj_bias"),
    )
    # Pairs of MaxText parameters that share one fused HF gate/up tensor.
    composite_params = (
        (("GptOssMlp-wi_0", "GptOssMlp-wi_1"), "mlp.experts.gate_up_proj"),
        (("GptOssMlp-wi_0_bias", "GptOssMlp-wi_1_bias"), "mlp.experts.gate_up_proj_bias"),
    )

    # Non-layer parameters.
    mapping = {
        "params-token_embedder-embedding": "model.embed_tokens.weight",
        "params-decoder-decoder_norm-scale": "model.norm.weight",
        "params-decoder-logits_dense-kernel": "lm_head.weight",
    }

    if scan_layers:
        for block_idx in range(cycle):
            # All HF layers that fall on this position of the cycle.
            layer_ids = range(block_idx, layer_total, cycle)
            mt_prefix = f"params-decoder-layers-layers_{block_idx}"
            for mt_suffix, hf_suffix in atomic_params:
                mapping[f"{mt_prefix}-{mt_suffix}"] = [f"model.layers.{i}.{hf_suffix}" for i in layer_ids]
            for (first_suffix, second_suffix), hf_suffix in composite_params:
                composite_key = (f"{mt_prefix}-{first_suffix}", f"{mt_prefix}-{second_suffix}")
                mapping[composite_key] = [f"model.layers.{i}.{hf_suffix}" for i in layer_ids]
        return mapping

    for i in range(layer_total):
        mt_prefix = f"params-decoder-layers_{i}"
        for mt_suffix, hf_suffix in atomic_params:
            mapping[f"{mt_prefix}-{mt_suffix}"] = f"model.layers.{i}.{hf_suffix}"
        for (first_suffix, second_suffix), hf_suffix in composite_params:
            composite_key = (f"{mt_prefix}-{first_suffix}", f"{mt_prefix}-{second_suffix}")
            mapping[composite_key] = f"model.layers.{i}.{hf_suffix}"
    return mapping
[docs]
def GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
    """Transformation hooks for gpt-oss parameters.

    Args:
        config (dict): Hugging Face config; must contain 'num_hidden_layers'.
        maxtext_config: Object exposing ``inhomogeneous_layer_cycle_interval``.
        scan_layers (bool, optional): When True, hooks are keyed per position in
            the layer cycle ('params-decoder-layers-layers_{block}'); otherwise
            per layer ('params-decoder-layers_{i}'). Defaults to False.
        saving_to_hf (bool, optional): Conversion direction. True converts
            MaxText to Hugging Face, False the reverse. Defaults to False.

    Returns:
        dict: MaxText key (a string, or a tuple of strings for a
        `composite_mt_key`) -> ``hook(input_tensor, target_shape)``.

    Notes:
        - Handles `composite_mt_key` where multiple MaxText keys map to HF key(s)
            - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj
            - (GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias): mlp.experts.gate_up_proj_bias
        - The composite keys are transformed via the `interleave` hook.
    """

    def transpose(input_tensor, target_shape=None):
        return input_tensor.T

    def reshape_kernel(input_tensor, target_shape):
        """Transposes a kernel and reshapes it to target_shape (order depends on direction)."""
        if not saving_to_hf:
            return input_tensor.T.reshape(target_shape)
        # Reshape to the reversed target shape so the transpose yields target_shape.
        reversed_shape = np.flip(np.array(target_shape))
        return input_tensor.reshape(reversed_shape).T

    def reshape_bias(input_tensor, target_shape=None):
        """Reshapes a bias to target_shape; the operation is the same in both directions."""
        return input_tensor.reshape(target_shape)

    def interleave(input_tensor, target_shape=None):
        """Converts between the MaxText pair (wi_0, wi_1) and the fused HF tensor.

        - saving_to_hf: ``input_tensor`` is a pair ``(wi_0, wi_1)`` in key order;
          returns one tensor of ``target_shape`` whose last axis alternates
          wi_0 (even positions) and wi_1 (odd positions).
        - otherwise: ``input_tensor`` is the fused tensor; returns wi_0 and wi_1
          stacked on a new LAST axis, in key order.
        """
        if not saving_to_hf:
            evens = input_tensor[..., ::2]
            odds = input_tensor[..., 1::2]
            return np.stack([evens, odds], axis=-1)
        first, second = input_tensor
        fused = np.empty(target_shape, dtype=first.dtype)
        fused[..., ::2] = first
        fused[..., 1::2] = second
        return fused

    layer_total = config["num_hidden_layers"]  # hf config
    cycle = maxtext_config.inhomogeneous_layer_cycle_interval

    hooks = {"params-decoder-logits_dense-kernel": transpose}

    if scan_layers:
        layer_prefixes = [f"params-decoder-layers-layers_{pos}" for pos in range(cycle)]
    else:
        layer_prefixes = [f"params-decoder-layers_{idx}" for idx in range(layer_total)]

    for layer_prefix in layer_prefixes:
        # Attention kernels and biases.
        for proj in ("query", "key", "value"):
            hooks[f"{layer_prefix}-GptOssAttention-{proj}-kernel"] = reshape_kernel
            hooks[f"{layer_prefix}-GptOssAttention-{proj}-bias"] = reshape_bias
        hooks[f"{layer_prefix}-GptOssAttention-out-kernel"] = reshape_kernel
        # MLP router kernel.
        hooks[f"{layer_prefix}-GptOssMlp-gate-kernel"] = transpose
        # `composite_mt_key`: one hook converts both MaxText params at once.
        for pair in (("wi_0", "wi_1"), ("wi_0_bias", "wi_1_bias")):
            composite_key = tuple(f"{layer_prefix}-GptOssMlp-{name}" for name in pair)
            hooks[composite_key] = interleave

    return hooks
[docs]
def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
    """Returns mapping from MaxText to HuggingFace Qwen3-Omni weight paths.

    The result merges three groups of entries, all pointing at HF paths under
    the ``thinker.`` namespace: the text decoder (delegated to
    `QWEN_MAXTEXT_TO_HF_PARAM_MAPPING` and then prefixed), the vision encoder
    and the audio encoder.

    Args:
        config (dict): Model configuration; ``config["thinker_config"]`` must
            provide 'text_config', 'vision_config' and 'audio_config'.
        maxtext_config: Forwarded unchanged to the text mapping function.
        scan_layers (bool, optional): Forwarded to the text mapping function.
            The vision and audio entries are always emitted per layer.
            Defaults to False.

    Returns:
        dict: Combined mapping from all modalities.
    """
    thinker_prefix = "thinker"

    def add_prefix_recursive(value):
        """Prefixes a path with 'thinker.', descending into (nested) lists."""
        if not isinstance(value, list):
            return f"{thinker_prefix}.{value}"
        return [add_prefix_recursive(item) for item in value]

    # --- Text: reuse the Qwen mapping, then move every HF path under "thinker.".
    text_config = config["thinker_config"]["text_config"]
    num_experts_text = text_config.get("num_experts", 0)
    n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
    text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
        config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
        maxtext_config=maxtext_config,
        scan_layers=scan_layers,
    )
    mapping = {mt_key: add_prefix_recursive(hf_value) for mt_key, hf_value in text_mapping.items()}

    # --- Vision
    vision_config = config["thinker_config"]["vision_config"]
    n_vision_layers = vision_config["depth"]
    vision_encoder = "params-vision_encoder-Qwen3OmniMoeVisionEncoder_0"
    vision_projector = "params-vision_encoder-Qwen3OmniMoeVisionProjector_0"

    # Patch embedding and positional embedding.
    mapping[f"{vision_encoder}-patch_embed-proj-kernel"] = "thinker.visual.patch_embed.proj.weight"
    mapping[f"{vision_encoder}-patch_embed-proj-bias"] = "thinker.visual.patch_embed.proj.bias"
    mapping[f"{vision_encoder}-pos_embed_interpolate-pos_embed"] = "thinker.visual.pos_embed.weight"

    vision_block_params = (
        # Layer norms
        ("ln1-scale", "norm1.weight"),
        ("ln1-bias", "norm1.bias"),
        ("ln2-scale", "norm2.weight"),
        ("ln2-bias", "norm2.bias"),
        # Attention: the separate MaxText Q/K/V entries all point at the same
        # fused HF qkv tensor; the split/fusion is left to the hook functions.
        ("attn-attn-query-kernel", "attn.qkv.weight"),
        ("attn-attn-query-bias", "attn.qkv.bias"),
        ("attn-attn-key-kernel", "attn.qkv.weight"),
        ("attn-attn-key-bias", "attn.qkv.bias"),
        ("attn-attn-value-kernel", "attn.qkv.weight"),
        ("attn-attn-value-bias", "attn.qkv.bias"),
        ("attn-attn-out-kernel", "attn.proj.weight"),
        ("attn-attn-out-bias", "attn.proj.bias"),
        # MLP
        ("mlp-kernel", "mlp.linear_fc1.weight"),
        ("mlp-bias", "mlp.linear_fc1.bias"),
        ("mlp_out-kernel", "mlp.linear_fc2.weight"),
        ("mlp_out-bias", "mlp.linear_fc2.bias"),
    )
    for i in range(n_vision_layers):
        for mt_suffix, hf_suffix in vision_block_params:
            mapping[f"{vision_encoder}-blocks_{i}-{mt_suffix}"] = f"thinker.visual.blocks.{i}.{hf_suffix}"

    # The deep mergers and the final projector share one parameter layout.
    merger_params = (
        ("ln_q-scale", "ln_q.weight"),
        ("ln_q-bias", "ln_q.bias"),
        ("mlp_0-kernel", "mlp.0.weight"),
        ("mlp_0-bias", "mlp.0.bias"),
        ("mlp_2-kernel", "mlp.2.weight"),
        ("mlp_2-bias", "mlp.2.bias"),
    )
    # One deep merger per configured deepstack index (default indexes: 8, 16, 24).
    deepstack_indexes = vision_config.get("deepstack_visual_indexes", [8, 16, 24])
    for merger_idx, _unused in enumerate(deepstack_indexes):
        for mt_suffix, hf_suffix in merger_params:
            mapping[f"{vision_encoder}-merger_{merger_idx}-{mt_suffix}"] = (
                f"thinker.visual.merger_list.{merger_idx}.{hf_suffix}"
            )
    # Final merger (vision projector).
    for mt_suffix, hf_suffix in merger_params:
        mapping[f"{vision_projector}-merger-{mt_suffix}"] = f"thinker.visual.merger.{hf_suffix}"

    # --- Audio
    audio_config = config["thinker_config"]["audio_config"]
    n_audio_layers = audio_config["encoder_layers"]
    audio_encoder = "params-audio_encoder-Qwen3OmniAudioEncoder_0"
    audio_projector = "params-audio_encoder-Qwen3OmniAudioProjector_0"

    # Three conv layers, each with kernel and bias; conv_out maps a kernel only.
    for conv_name in ("conv2d1", "conv2d2", "conv2d3"):
        mapping[f"{audio_encoder}-{conv_name}-kernel"] = f"thinker.audio_tower.{conv_name}.weight"
        mapping[f"{audio_encoder}-{conv_name}-bias"] = f"thinker.audio_tower.{conv_name}.bias"
    mapping[f"{audio_encoder}-conv_out-kernel"] = "thinker.audio_tower.conv_out.weight"

    audio_layer_params = (
        # Layer norms
        ("input_layer_norm-scale", "self_attn_layer_norm.weight"),
        ("input_layer_norm-bias", "self_attn_layer_norm.bias"),
        ("post_attention_layer_norm-scale", "final_layer_norm.weight"),
        ("post_attention_layer_norm-bias", "final_layer_norm.bias"),
        # Attention (separate Q/K/V on both sides)
        ("self_attention_audio-query-kernel", "self_attn.q_proj.weight"),
        ("self_attention_audio-query-bias", "self_attn.q_proj.bias"),
        ("self_attention_audio-key-kernel", "self_attn.k_proj.weight"),
        ("self_attention_audio-key-bias", "self_attn.k_proj.bias"),
        ("self_attention_audio-value-kernel", "self_attn.v_proj.weight"),
        ("self_attention_audio-value-bias", "self_attn.v_proj.bias"),
        ("self_attention_audio-out-kernel", "self_attn.out_proj.weight"),
        ("self_attention_audio-out-bias", "self_attn.out_proj.bias"),
        # MLP (two linear layers: fc1 and fc2)
        ("AudioMLP-wi-kernel", "fc1.weight"),
        ("AudioMLP-wi-bias", "fc1.bias"),
        ("AudioMLP-wo-kernel", "fc2.weight"),
        ("AudioMLP-wo-bias", "fc2.bias"),
    )
    for i in range(n_audio_layers):
        for mt_suffix, hf_suffix in audio_layer_params:
            mapping[f"{audio_encoder}-layers_{i}-{mt_suffix}"] = f"thinker.audio_tower.layers.{i}.{hf_suffix}"

    # Post layer norm.
    mapping[f"{audio_encoder}-layernorm_post-scale"] = "thinker.audio_tower.ln_post.weight"
    mapping[f"{audio_encoder}-layernorm_post-bias"] = "thinker.audio_tower.ln_post.bias"

    # Audio projector (two linear layers).
    for proj_name in ("proj1", "proj2"):
        mapping[f"{audio_projector}-{proj_name}-kernel"] = f"thinker.audio_tower.{proj_name}.weight"
        mapping[f"{audio_projector}-{proj_name}-bias"] = f"thinker.audio_tower.{proj_name}.bias"

    return mapping
[docs]
def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Builds the parameter transformation hooks for Qwen3-Omni.

  The result merges three groups of hooks: text hooks obtained from
  `QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN`, vision encoder hooks and audio encoder
  hooks. Vision and audio hooks are registered per layer index and do not
  depend on `scan_layers`.

  Args:
    config (dict): Hugging Face model configuration. Read from it are
      `thinker_config.text_config` (`num_hidden_layers`, optional
      `num_experts`), `thinker_config.vision_config` (`depth`, `hidden_size`,
      optional `deepstack_visual_indexes`) and `thinker_config.audio_config`
      (`encoder_layers`, `d_model`).
    maxtext_config: MaxText configuration; only passed through to the text
      hook builder.
    scan_layers (bool, optional): Passed through to the text hook builder.
      Defaults to False.
    saving_to_hf (bool, optional): The direction of conversion. True for
      MaxText to Hugging Face, False for the reverse. Defaults to False.

  Returns:
    dict: A dictionary mapping MaxText parameter names to their corresponding
      transformation functions.
  """
  thinker_config = config["thinker_config"]
  hooks = {}

  # ---- Text hooks: reuse the Qwen hook builder with a flattened text config ----
  text_config = thinker_config["text_config"]
  text_num_experts = text_config.get("num_experts", 0)
  text_num_layers = text_config["num_hidden_layers"]
  hooks.update(
      QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
          config={"num_hidden_layers": text_num_layers, "num_experts": text_num_experts},
          maxtext_config=maxtext_config,
          scan_layers=scan_layers,
          saving_to_hf=saving_to_hf,
      )
  )

  # ---- Vision hooks ----
  vision_config = thinker_config["vision_config"]
  vision_depth = vision_config["depth"]
  vision_hidden = vision_config["hidden_size"]

  def reshape_kernel_vision(input_tensor, target_shape):
    """Transposes a vision dense kernel, reshaping on the MaxText side."""
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    # Reshape to the reversed target shape first so the transpose lands on target_shape.
    reversed_shape = np.flip(np.array(target_shape))
    return input_tensor.reshape(reversed_shape).T

  def reshape_conv3d_patch_embed(input_tensor, target_shape):
    """Permutes the axes of the 3D conv patch embedding weight.

    HF: (out_channels, in_channels, temporal, height, width)
    MaxText: (temporal, height, width, in_channels, out_channels)
    """
    # MaxText -> HF: (T, H, W, C_in, C_out) -> (C_out, C_in, T, H, W); HF -> MaxText is the inverse.
    axis_order = (4, 3, 0, 1, 2) if saving_to_hf else (2, 3, 4, 1, 0)
    return input_tensor.transpose(*axis_order)

  def split_qkv_query(input_tensor, target_shape):
    """Takes the Q rows out of the fused HF QKV weight (HF -> MaxText only).

    HF has fused QKV: (3*hidden_size, hidden_size)
    MaxText Q: (hidden_size, num_heads, head_dim)
    """
    if saving_to_hf:
      # The MaxText -> HF direction needs all of Q, K and V at once.
      raise NotImplementedError("Use fusion hook for MaxText->HF")
    return input_tensor[:vision_hidden, :].T.reshape(target_shape)

  def split_qkv_key(input_tensor, target_shape):
    """Takes the K rows out of the fused HF QKV weight (HF -> MaxText only)."""
    if saving_to_hf:
      raise NotImplementedError("Use fusion hook for MaxText->HF")
    return input_tensor[vision_hidden : 2 * vision_hidden, :].T.reshape(target_shape)

  def split_qkv_value(input_tensor, target_shape):
    """Takes the V rows out of the fused HF QKV weight (HF -> MaxText only)."""
    if saving_to_hf:
      raise NotImplementedError("Use fusion hook for MaxText->HF")
    return input_tensor[2 * vision_hidden :, :].T.reshape(target_shape)

  def split_qkv_bias_query(input_tensor, target_shape):
    """Takes the Q segment out of the fused HF QKV bias (HF -> MaxText only)."""
    if saving_to_hf:
      raise NotImplementedError("Use fusion hook for MaxText->HF")
    return input_tensor[:vision_hidden].reshape(target_shape)

  def split_qkv_bias_key(input_tensor, target_shape):
    """Takes the K segment out of the fused HF QKV bias (HF -> MaxText only)."""
    if saving_to_hf:
      raise NotImplementedError("Use fusion hook for MaxText->HF")
    return input_tensor[vision_hidden : 2 * vision_hidden].reshape(target_shape)

  def split_qkv_bias_value(input_tensor, target_shape):
    """Takes the V segment out of the fused HF QKV bias (HF -> MaxText only)."""
    if saving_to_hf:
      raise NotImplementedError("Use fusion hook for MaxText->HF")
    return input_tensor[2 * vision_hidden :].reshape(target_shape)

  def reshape_vision_attn_out(input_tensor, target_shape):
    """Converts the vision attention output projection.

    HF: (hidden_size, hidden_size)
    MaxText: (num_heads, head_dim, hidden_size)
    """
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    return input_tensor.reshape(vision_hidden, vision_hidden).T

  # Vision patch embedding
  hooks["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-kernel"] = reshape_conv3d_patch_embed

  # Hooks registered for every vision block, in registration order.
  # attn-attn-out-bias is intentionally absent: it needs no reshape.
  vision_block_hooks = (
      ("attn-attn-query-kernel", split_qkv_query),
      ("attn-attn-query-bias", split_qkv_bias_query),
      ("attn-attn-key-kernel", split_qkv_key),
      ("attn-attn-key-bias", split_qkv_bias_key),
      ("attn-attn-value-kernel", split_qkv_value),
      ("attn-attn-value-bias", split_qkv_bias_value),
      ("attn-attn-out-kernel", reshape_vision_attn_out),
      ("mlp-kernel", reshape_kernel_vision),
      ("mlp_out-kernel", reshape_kernel_vision),
  )
  for block_idx in range(vision_depth):
    block_prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-blocks_{block_idx}"
    for suffix, hook in vision_block_hooks:
      hooks[f"{block_prefix}-{suffix}"] = hook

  # One merger MLP per deepstack index; only the count of indexes matters here.
  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [8, 16, 24])
  for merger_idx, _unused in enumerate(deepstack_indexes):
    merger_prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-merger_{merger_idx}"
    for mlp_name in ("mlp_0", "mlp_2"):
      hooks[f"{merger_prefix}-{mlp_name}-kernel"] = reshape_kernel_vision

  # Vision projector (final merger)
  for mlp_name in ("mlp_0", "mlp_2"):
    hooks[f"params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-{mlp_name}-kernel"] = reshape_kernel_vision

  # ---- Audio hooks ----
  audio_config = thinker_config["audio_config"]
  audio_depth = audio_config["encoder_layers"]
  audio_hidden = audio_config["d_model"]

  def reshape_kernel_audio(input_tensor, target_shape):
    """Transposes an audio dense kernel, reshaping on the MaxText side."""
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    reversed_shape = np.flip(np.array(target_shape))
    return input_tensor.reshape(reversed_shape).T

  def reshape_conv2d_audio(input_tensor, target_shape):
    """Permutes the axes of an audio Conv2D weight.

    HF: (out_channels, in_channels, height, width)
    MaxText: (height, width, in_channels, out_channels)
    """
    # MaxText -> HF: (H, W, C_in, C_out) -> (C_out, C_in, H, W); HF -> MaxText is the inverse.
    axis_order = (3, 2, 0, 1) if saving_to_hf else (2, 3, 1, 0)
    return input_tensor.transpose(*axis_order)

  def reshape_audio_attn_qkv(input_tensor, target_shape):
    """Converts an audio attention Q/K/V projection.

    HF: (hidden_size, hidden_size)
    MaxText: (hidden_size, num_heads, head_dim)
    """
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    return input_tensor.reshape(audio_hidden, audio_hidden).T

  def reshape_audio_attn_out(input_tensor, target_shape):
    """Converts the audio attention output projection.

    HF: (hidden_size, hidden_size)
    MaxText: (num_heads, head_dim, hidden_size)
    """
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    return input_tensor.reshape(audio_hidden, audio_hidden).T

  def reshape_audio_attn_qkv_bias(input_tensor, target_shape):
    """Converts an audio attention Q/K/V bias.

    HF: (hidden_size,)
    MaxText: (num_heads, head_dim)
    """
    return input_tensor.reshape(audio_hidden if saving_to_hf else target_shape)

  # Audio conv layers
  for conv_name in ("conv2d1", "conv2d2", "conv2d3"):
    hooks[f"params-audio_encoder-Qwen3OmniAudioEncoder_0-{conv_name}-kernel"] = reshape_conv2d_audio
  # Audio conv output projection
  hooks["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv_out-kernel"] = reshape_kernel_audio

  # Hooks registered for every audio encoder layer, in registration order.
  audio_layer_hooks = (
      ("self_attention_audio-query-kernel", reshape_audio_attn_qkv),
      ("self_attention_audio-query-bias", reshape_audio_attn_qkv_bias),
      ("self_attention_audio-key-kernel", reshape_audio_attn_qkv),
      ("self_attention_audio-key-bias", reshape_audio_attn_qkv_bias),
      ("self_attention_audio-value-kernel", reshape_audio_attn_qkv),
      ("self_attention_audio-value-bias", reshape_audio_attn_qkv_bias),
      ("self_attention_audio-out-kernel", reshape_audio_attn_out),
      ("AudioMLP-wi-kernel", reshape_kernel_audio),
      ("AudioMLP-wo-kernel", reshape_kernel_audio),
  )
  for layer_idx in range(audio_depth):
    layer_prefix = f"params-audio_encoder-Qwen3OmniAudioEncoder_0-layers_{layer_idx}"
    for suffix, hook in audio_layer_hooks:
      hooks[f"{layer_prefix}-{suffix}"] = hook

  # Audio projector
  for proj_name in ("proj1", "proj2"):
    hooks[f"params-audio_encoder-Qwen3OmniAudioProjector_0-{proj_name}-kernel"] = reshape_kernel_audio

  return hooks
[docs]
def QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN(target_shape=None):
  """Returns the NNX -> vLLM parameter transformation hooks for Qwen3.

  No hooks are registered, so the result is always a new empty dictionary.

  Args:
    target_shape: Not used by this function.

  Returns:
    dict: A dictionary mapping NNX parameter names to their corresponding
      transformation functions (currently empty).
  """
  no_hooks = {}
  return no_hooks
[docs]
def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Maps MaxText parameter names to Hugging Face LLaMA3.1 parameter names.

  Args:
    config (dict): Model configuration dictionary containing:
      - num_hidden_layers (int): The number of decoder layers.
    maxtext_config: Not used by this function.
    scan_layers (bool, optional): If True, MaxText layers are 'stacked'
      into a single param. Defaults to False.

  Returns:
    dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names).
      Values are either a single string (unscanned form) or a list of strings
      (scanned form) holding one HF name per layer when `scan_layers=True`.
  """
  layer_count = config["num_hidden_layers"]

  # (MaxText suffix, HF suffix under `model.layers.{idx}`), in emission order.
  per_layer_pairs = (
      ("self_attention-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention-key-kernel", "self_attn.k_proj.weight"),
      ("self_attention-value-kernel", "self_attn.v_proj.weight"),
      ("self_attention-out-kernel", "self_attn.o_proj.weight"),
      ("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
      ("mlp-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp-wo-kernel", "mlp.down_proj.weight"),
      ("pre_self_attention_layer_norm-scale", "input_layernorm.weight"),
      ("post_self_attention_layer_norm-scale", "post_attention_layernorm.weight"),
  )

  # Parameters that exist once per model, independent of the layer layout.
  name_map = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
  }

  if scan_layers:
    # One stacked MaxText parameter per kind, listing the HF name of every layer.
    for mt_suffix, hf_suffix in per_layer_pairs:
      name_map[f"params-decoder-layers-{mt_suffix}"] = [
          f"model.layers.{idx}.{hf_suffix}" for idx in range(layer_count)
      ]
    return name_map

  # One MaxText parameter per layer and kind.
  for idx in range(layer_count):
    for mt_suffix, hf_suffix in per_layer_pairs:
      name_map[f"params-decoder-layers_{idx}-{mt_suffix}"] = f"model.layers.{idx}.{hf_suffix}"
  return name_map
[docs]
def LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Builds the MaxText <-> Hugging Face conversion hooks for LLaMA3.1.

  Three elementary hooks are defined (query scaling, RoPE layout change and
  kernel transpose/reshape). Query and key kernels receive a list of hooks
  whose order is reversed for the Hugging Face -> MaxText direction.

  Args:
    config (dict): Model configuration. `num_hidden_layers` is read when this
      function is called; `head_dim` is read when the query scaling hook runs.
    maxtext_config: Not used by this function.
    scan_layers (bool, optional): If True, hooks are keyed by the stacked
      `params-decoder-layers-...` names; otherwise one key per layer index.
      Defaults to False.
    saving_to_hf (bool, optional): True for MaxText to Hugging Face, False for
      the reverse. Defaults to False.

  Returns:
    dict: MaxText parameter names mapped to a hook function or, for query and
      key kernels, a list of hook functions.
  """
  layer_count = config["num_hidden_layers"]

  def scale_query_layer(input_tensor, target_shape):
    # Multiply by sqrt(head_dim) towards HF and by its reciprocal towards MaxText.
    # The product is computed in float32, then cast back to the input dtype.
    if saving_to_hf:
      depth_scale = np.dtype("float32").type(np.sqrt(config["head_dim"]))
    else:
      depth_scale = np.dtype("float32").type(1 / np.sqrt(config["head_dim"]))
    source_dtype = input_tensor.dtype
    scaled = input_tensor.astype(np.float32) * depth_scale
    return scaled.astype(source_dtype)

  def adjust_rope(input_tensor, target_shape):
    if saving_to_hf:
      # MaxText's interleaved layout -> HF's concatenated layout (evens first, then odds).
      even_slots = input_tensor[..., ::2]
      odd_slots = input_tensor[..., 1::2]
      return jax.numpy.concatenate((even_slots, odd_slots), axis=input_tensor.ndim - 1)
    # HF's concatenated layout -> MaxText's interleaved layout.
    split_at = input_tensor.shape[-1] // 2
    halves = [input_tensor[..., :split_at], input_tensor[..., split_at:]]
    return jax.numpy.stack(halves, axis=-1).reshape(input_tensor.shape)

  def reshape_kernel(input_tensor, target_shape):
    if not saving_to_hf:
      return input_tensor.transpose().reshape(target_shape)
    # Reshape to the reversed target shape so that the transpose yields target_shape.
    reversed_shape = np.flip(np.array(target_shape))
    return input_tensor.reshape(reversed_shape).transpose()

  # caveat: hook order does affect result.
  # The lists below are in to_huggingface order; to_maxtext applies them reversed.
  query_hook_chain = [scale_query_layer, adjust_rope, reshape_kernel]
  key_hook_chain = [adjust_rope, reshape_kernel]
  if not saving_to_hf:
    query_hook_chain = query_hook_chain[::-1]
    key_hook_chain = key_hook_chain[::-1]

  # (suffix under the layer prefix, hook or hook chain), in registration order.
  layer_hooks = (
      ("self_attention-query-kernel", query_hook_chain),
      ("self_attention-key-kernel", key_hook_chain),
      ("self_attention-value-kernel", reshape_kernel),
      ("self_attention-out-kernel", reshape_kernel),
      ("mlp-wi_0-kernel", reshape_kernel),
      ("mlp-wi_1-kernel", reshape_kernel),
      ("mlp-wo-kernel", reshape_kernel),
  )

  if scan_layers:
    layer_prefixes = ["params-decoder-layers"]
  else:
    layer_prefixes = [f"params-decoder-layers_{idx}" for idx in range(layer_count)]

  hook_fns = {"params-decoder-logits_dense-kernel": reshape_kernel}
  for layer_prefix in layer_prefixes:
    for suffix, hook in layer_hooks:
      hook_fns[f"{layer_prefix}-{suffix}"] = hook
  return hook_fns
[docs]
def LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN():
  """Returns the weight transformation hooks for LLaMA3.1 query and key kernels.

  The hooks cover transformations that are not plain re-mappings: RoPE
  reordering for the key kernel, and query scaling followed by RoPE
  reordering for the query kernel.

  Returns:
    A dictionary where keys are parameter names (in the
    `base.decoder.layers...` form) and values are the corresponding
    transformation functions, each taking a single array argument.
  """

  def reorder_rope(arr):
    """Moves even-indexed entries of the last axis in front of odd-indexed ones.

    Args:
      arr: The input weight array.

    Returns:
      The array with the last axis rearranged as (evens, odds).
    """
    last_axis = arr.ndim - 1
    even_slots, odd_slots = arr[..., ::2], arr[..., 1::2]
    return jax.numpy.concatenate((even_slots, odd_slots), axis=last_axis)

  def transform_query_kernel(arr):
    """Scales the query kernel, then applies RoPE reordering.

    The scale is the square root of the size of the last axis, which is
    taken as the head dimension.

    Args:
      arr: The query kernel weight array.

    Returns:
      The transformed query kernel array.
    """
    depth_scale = np.dtype("float32").type(np.sqrt(arr.shape[-1]))
    return reorder_rope(arr * depth_scale)

  return {
      "base.decoder.layers.self_attention.query.kernel": transform_query_kernel,
      "base.decoder.layers.self_attention.key.kernel": reorder_rope,
  }
[docs]
def MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Maps MaxText parameter names to Hugging Face parameter names for Mixtral.

  Args:
    config (dict): Model configuration; `num_hidden_layers` is read.
    maxtext_config: MaxText configuration; `num_experts` is read.
    scan_layers (bool, optional): If True, keys use the stacked
      `params-decoder-layers-...` form. Defaults to False.

  Returns:
    dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names). Values
      are Hugging Face parameter names in one of four forms: unscanned string,
      scanned list of strings, unscanned with expert stacking (list of strings),
      or scanned with expert stacking (nested list of strings, experts outermost).
  """
  # Top-level, non-layer-specific parameters
  name_map = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }
  num_experts = maxtext_config.num_experts
  layer_ids = range(config["num_hidden_layers"])

  # (MaxText suffix, HF suffix under `model.layers.{i}`) for attention, RMSNorm and the MoE gate.
  dense_pairs = (
      ("self_attention-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention-key-kernel", "self_attn.k_proj.weight"),
      ("self_attention-value-kernel", "self_attn.v_proj.weight"),
      ("self_attention-out-kernel", "self_attn.o_proj.weight"),
      ("pre_self_attention_layer_norm-scale", "input_layernorm.weight"),
      ("post_self_attention_layer_norm-scale", "post_attention_layernorm.weight"),
      ("MoeBlock_0-gate-kernel", "block_sparse_moe.gate.weight"),
  )
  # (MaxText suffix, HF per-expert weight name): wi_0 <- w1, wi_1 <- w3, wo <- w2.
  expert_pairs = (
      ("MoeBlock_0-wi_0", "w1"),
      ("MoeBlock_0-wi_1", "w3"),
      ("MoeBlock_0-wo", "w2"),
  )

  if scan_layers:
    for mt_suffix, hf_suffix in dense_pairs:
      name_map[f"params-decoder-layers-{mt_suffix}"] = [f"model.layers.{layer}.{hf_suffix}" for layer in layer_ids]
    # Outer axis is experts and inner axis is layers to align with logic in
    # _build_multi_axis_stacked_tensor()
    for mt_suffix, hf_name in expert_pairs:
      name_map[f"params-decoder-layers-{mt_suffix}"] = [
          [f"model.layers.{layer}.block_sparse_moe.experts.{expert}.{hf_name}.weight" for layer in layer_ids]
          for expert in range(num_experts)
      ]
    return name_map

  for layer in layer_ids:
    mt_prefix = f"params-decoder-layers_{layer}"
    hf_prefix = f"model.layers.{layer}"
    for mt_suffix, hf_suffix in dense_pairs:
      name_map[f"{mt_prefix}-{mt_suffix}"] = f"{hf_prefix}.{hf_suffix}"
    # MoE expert weights (1 MaxText param -> one HF param per expert)
    for mt_suffix, hf_name in expert_pairs:
      name_map[f"{mt_prefix}-{mt_suffix}"] = [
          f"{hf_prefix}.block_sparse_moe.experts.{expert}.{hf_name}.weight" for expert in range(num_experts)
      ]
  return name_map
[docs]
def MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Generates parameter conversion hooks for Mixtral between MaxText and Hugging Face.

  Args:
    config (dict): Model configuration. `hidden_size` is read by the attention
      hook in the MaxText -> HF direction; `num_hidden_layers` is read only
      when `scan_layers` is False.
    maxtext_config: MaxText configuration; `head_dim` is read by the query
      scaling hook when it runs.
    scan_layers (bool, optional): If True, hooks are keyed by the stacked
      `params-decoder-layers-...` names; otherwise one key per layer index.
      Defaults to False.
    saving_to_hf (bool, optional): True for MaxText to Hugging Face, False for
      the reverse. Defaults to False.

  Returns:
    dict: MaxText parameter names mapped to a hook function or, for the query
      kernel, a list of hook functions.
  """

  def reshape_and_transpose_attention(x, target_shape):
    """MaxText: [hidden, n_heads, h_dim] <-> HF: [n_heads * h_dim, hidden]"""
    if not saving_to_hf:
      # (N*D, H) -> (H, N*D) -> (H, N, D)
      return x.transpose().reshape(target_shape)
    # (H, N, D) -> (H, N*D) -> (N*D, H)
    # NOTE(review): this hook is also registered for the out kernel; reshaping to
    # (hidden_size, -1) presumably relies on n_heads * h_dim == hidden_size there - confirm.
    return x.reshape(config["hidden_size"], -1).transpose()

  def reshape_kernel(x, target_shape):
    # Plain transpose; target_shape is not needed.
    return x.transpose()

  def scale_query_layer(input_tensor, target_shape):
    # Multiply by sqrt(head_dim) towards HF and by its reciprocal towards MaxText,
    # casting the product back to the input dtype.
    head_dim_root = np.sqrt(maxtext_config.head_dim)
    factor = head_dim_root if saving_to_hf else 1 / head_dim_root
    depth_scale = np.dtype("float32").type(factor)
    return (input_tensor * depth_scale).astype(input_tensor.dtype)

  # hook order does not affect result
  query_hook_chain = [reshape_and_transpose_attention, scale_query_layer]

  # (suffix under the layer prefix, hook or hook chain), in registration order.
  layer_plan = (
      ("self_attention-query-kernel", query_hook_chain),
      ("self_attention-key-kernel", reshape_and_transpose_attention),
      ("self_attention-value-kernel", reshape_and_transpose_attention),
      ("self_attention-out-kernel", reshape_and_transpose_attention),
      ("MoeBlock_0-wi_0", reshape_kernel),
      ("MoeBlock_0-wi_1", reshape_kernel),
      ("MoeBlock_0-wo", reshape_kernel),
      ("MoeBlock_0-gate-kernel", reshape_kernel),
  )

  hooks = {}
  if scan_layers:
    for suffix, op_func in layer_plan:
      hooks[f"params-decoder-layers-{suffix}"] = op_func
  else:
    # Keys are grouped by parameter kind first, then by layer index.
    for suffix, op_func in layer_plan:
      for layer in range(config["num_hidden_layers"]):
        hooks[f"params-decoder-layers_{layer}-{suffix}"] = op_func
  hooks["params-decoder-logits_dense-kernel"] = reshape_kernel
  return hooks
[docs]
def GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""Returns mapping between MaxText and HuggingFace Gemma4 weight paths."""
tcfg = config.get("text_config", config)
nlayers = tcfg["num_hidden_layers"]
share_kv_projections = maxtext_config.share_kv_projections
# Gemma 4 uses a block pattern of length 6: 5 local, 1 global
vcfg = config.get("vision_config", {})
text_base = "model.language_model" if vcfg else "model"
mapping = {
"params-token_embedder-embedding": f"{text_base}.embed_tokens.weight",
"params-decoder-decoder_norm-scale": f"{text_base}.norm.weight",
}
if vcfg:
nvis = vcfg.get("num_hidden_layers", 0)
mapping.update(
{
"params-vision_encoder-Gemma4VisionEncoderLayer_0-vision_entry-input_projection-kernel": (
"model.vision_tower.patch_embedder.input_proj.weight"
),
"params-vision_encoder-Gemma4VisionEncoderLayer_0-vision_entry-pos_emb_param": (
"model.vision_tower.patch_embedder.position_embedding_table"
),
"params-vision_encoder-Gemma4VisionProjector_0-projection-kernel": (
"model.embed_vision.embedding_projection.weight"
),
"params-vision_encoder-Gemma4VisionEncoderLayer_0-std_scale": ("model.vision_tower.std_scale"),
"params-vision_encoder-Gemma4VisionEncoderLayer_0-std_bias": ("model.vision_tower.std_bias"),
}
)
for i in range(nvis):
prefix = f"params-vision_encoder-Gemma4VisionEncoderLayer_0-layer_{i}"
hf_prefix = f"model.vision_tower.encoder.layers.{i}"
mapping.update(
{
f"{prefix}-attention-query-kernel": f"{hf_prefix}.self_attn.q_proj.linear.weight",
f"{prefix}-attention-key-kernel": f"{hf_prefix}.self_attn.k_proj.linear.weight",
f"{prefix}-attention-value-kernel": f"{hf_prefix}.self_attn.v_proj.linear.weight",
f"{prefix}-attention-out-kernel": f"{hf_prefix}.self_attn.o_proj.linear.weight",
f"{prefix}-attention-query_norm-scale": f"{hf_prefix}.self_attn.q_norm.weight",
f"{prefix}-attention-key_norm-scale": f"{hf_prefix}.self_attn.k_norm.weight",
f"{prefix}-pre_attention_norm-scale": f"{hf_prefix}.input_layernorm.weight",
f"{prefix}-post_attention_norm-scale": f"{hf_prefix}.post_attention_layernorm.weight",
f"{prefix}-pre_ffw_norm-scale": f"{hf_prefix}.pre_feedforward_layernorm.weight",
f"{prefix}-post_ffw_norm-scale": f"{hf_prefix}.post_feedforward_layernorm.weight",
f"{prefix}-mlp-wi_0-kernel": f"{hf_prefix}.mlp.gate_proj.linear.weight",
f"{prefix}-mlp-wi_1-kernel": f"{hf_prefix}.mlp.up_proj.linear.weight",
f"{prefix}-mlp-wo-kernel": f"{hf_prefix}.mlp.down_proj.linear.weight",
}
)
if scan_layers:
attention_pattern_length = 6
num_remaining = nlayers % attention_pattern_length
num_scanned = nlayers - num_remaining
num_experts = tcfg.get("num_experts")
num_experts = num_experts if num_experts is not None else 1
# Main scanned blocks
for layer_in_block in range(attention_pattern_length):
hf_indices = list(range(layer_in_block, num_scanned, attention_pattern_length))
prefix = f"params-decoder-scanned_blocks-layers_{layer_in_block}"
mapping.update(
{
f"{prefix}-self_attention-query-kernel": [
f"{text_base}.layers.{i}.self_attn.q_proj.weight" for i in hf_indices
],
f"{prefix}-self_attention-key-kernel": [
f"{text_base}.layers.{i}.self_attn.k_proj.weight" for i in hf_indices
],
f"{prefix}-self_attention-value-kernel": (
None
if share_kv_projections and layer_in_block == 5
else [f"{text_base}.layers.{i}.self_attn.v_proj.weight" for i in hf_indices]
),
f"{prefix}-self_attention-out-kernel": [
f"{text_base}.layers.{i}.self_attn.o_proj.weight" for i in hf_indices
],
f"{prefix}-self_attention-query_norm-scale": [
f"{text_base}.layers.{i}.self_attn.q_norm.weight" for i in hf_indices
],
f"{prefix}-self_attention-key_norm-scale": [
f"{text_base}.layers.{i}.self_attn.k_norm.weight" for i in hf_indices
],
}
)
if maxtext_config.v_norm_with_scale:
mapping.update(
{
f"{prefix}-self_attention-value_norm-scale": [
f"{text_base}.layers.{i}.self_attn.v_norm.weight" for i in hf_indices
]
}
)
mapping.update(
{
f"{prefix}-pre_self_attention_norm-scale": [
f"{text_base}.layers.{i}.input_layernorm.weight" for i in hf_indices
],
f"{prefix}-post_self_attention_norm-scale": [
f"{text_base}.layers.{i}.post_attention_layernorm.weight" for i in hf_indices
],
f"{prefix}-pre_ffw_norm-scale": [
f"{text_base}.layers.{i}.pre_feedforward_layernorm.weight" for i in hf_indices
],
f"{prefix}-post_ffw_norm-scale": [
f"{text_base}.layers.{i}.post_feedforward_layernorm.weight" for i in hf_indices
],
f"{prefix}-mlp-pre_feedforward_layernorm_2-scale": [
f"{text_base}.layers.{i}.pre_feedforward_layernorm_2.weight" if num_experts > 1 else None
for i in hf_indices
],
f"{prefix}-mlp-post_feedforward_layernorm_1-scale": [
f"{text_base}.layers.{i}.post_feedforward_layernorm_1.weight" if num_experts > 1 else None
for i in hf_indices
],
f"{prefix}-mlp-post_feedforward_layernorm_2-scale": [
f"{text_base}.layers.{i}.post_feedforward_layernorm_2.weight" if num_experts > 1 else None
for i in hf_indices
],
f"{prefix}-mlp-pre_forward_scale_2": [
f"{text_base}.layers.{i}.router.scale" if num_experts > 1 else None for i in hf_indices
],
f"{prefix}-mlp-wi_0-kernel": [
f"{text_base}.layers.{i}.mlp.gate_proj.weight" if num_experts == 1 else None for i in hf_indices
],
f"{prefix}-mlp-wi_1-kernel": [
f"{text_base}.layers.{i}.mlp.up_proj.weight" if num_experts == 1 else None for i in hf_indices
],
f"{prefix}-mlp-wo-kernel": [
f"{text_base}.layers.{i}.mlp.down_proj.weight" if num_experts == 1 else None for i in hf_indices
],
f"{prefix}-mlp-moe_block-MoeBlock_0-gate-kernel": [
f"{text_base}.layers.{i}.router.proj.weight" if num_experts > 1 else None for i in hf_indices
],
f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0": [
f"{text_base}.layers.{i}.experts.gate_up_proj" if num_experts > 1 else None for i in hf_indices
],
f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1": [
f"{text_base}.layers.{i}.experts.gate_up_proj" if num_experts > 1 else None for i in hf_indices
],
f"{prefix}-mlp-moe_block-MoeBlock_0-wo": [
f"{text_base}.layers.{i}.experts.down_proj" if num_experts > 1 else None for i in hf_indices
],
f"{prefix}-mlp-moe_block-MoeBlock_0-per_expert_scale": [
f"{text_base}.layers.{i}.router.per_expert_scale" if num_experts > 1 else None for i in hf_indices
],
f"{prefix}-mlp-moe_block-shared_experts-wi_0-kernel": [
f"{text_base}.layers.{i}.mlp.gate_proj.weight" if num_experts > 1 else None for i in hf_indices
],
f"{prefix}-mlp-moe_block-shared_experts-wi_1-kernel": [
f"{text_base}.layers.{i}.mlp.up_proj.weight" if num_experts > 1 else None for i in hf_indices
],
f"{prefix}-mlp-moe_block-shared_experts-wo-kernel": [
f"{text_base}.layers.{i}.mlp.down_proj.weight" if num_experts > 1 else None for i in hf_indices
],
f"{prefix}-layer_scalar": [f"{text_base}.layers.{i}.layer_scalar" for i in hf_indices],
}
)
mapping = {
k: v
for k, v in mapping.items()
if (isinstance(v, list) and len(v) > 0 and v[0] is not None) or (not isinstance(v, list) and v is not None)
}
# Remainder layers
if num_remaining > 0:
for rem_idx in range(num_remaining):
hf_layer_idx = num_scanned + rem_idx
# Remaining layers use local attention type logic
is_global = False # For gemma 4 it is unlikely the remainder is global but safe to determine
layer_in_block = hf_layer_idx % 6
is_global = layer_in_block == 5
prefix = f"params-decoder-layers_remainder-layers_{rem_idx}"
hf_prefix = f"{text_base}.layers.{hf_layer_idx}"
mapping.update(
{
f"{prefix}-self_attention-query-kernel": (f"{hf_prefix}.self_attn.q_proj.weight"),
f"{prefix}-self_attention-key-kernel": (f"{hf_prefix}.self_attn.k_proj.weight"),
f"{prefix}-self_attention-value-kernel": (
f"{hf_prefix}.self_attn.k_proj.weight"
if share_kv_projections and is_global
else f"{hf_prefix}.self_attn.v_proj.weight"
),
f"{prefix}-self_attention-out-kernel": (f"{hf_prefix}.self_attn.o_proj.weight"),
f"{prefix}-self_attention-query_norm-scale": (f"{hf_prefix}.self_attn.q_norm.weight"),
f"{prefix}-self_attention-key_norm-scale": (f"{hf_prefix}.self_attn.k_norm.weight"),
}
)
if maxtext_config.v_norm_with_scale:
mapping.update({f"{prefix}-self_attention-value_norm-scale": (f"{hf_prefix}.self_attn.v_norm.weight")})
mapping.update(
{
f"{prefix}-pre_self_attention_norm-scale": (f"{hf_prefix}.input_layernorm.weight"),
f"{prefix}-post_self_attention_norm-scale": (f"{hf_prefix}.post_attention_layernorm.weight"),
f"{prefix}-pre_ffw_norm-scale": (f"{hf_prefix}.pre_feedforward_layernorm.weight"),
f"{prefix}-post_ffw_norm-scale": (f"{hf_prefix}.post_feedforward_layernorm.weight"),
f"{prefix}-mlp-pre_feedforward_layernorm_2-scale": (
f"{hf_prefix}.pre_feedforward_layernorm_2.weight" if num_experts > 1 else None
),
f"{prefix}-mlp-post_feedforward_layernorm_1-scale": (
f"{hf_prefix}.post_feedforward_layernorm_1.weight" if num_experts > 1 else None
),
f"{prefix}-mlp-post_feedforward_layernorm_2-scale": (
f"{hf_prefix}.post_feedforward_layernorm_2.weight" if num_experts > 1 else None
),
f"{prefix}-mlp-pre_forward_scale_2": (f"{hf_prefix}.router.scale" if num_experts > 1 else None),
f"{prefix}-mlp-wi_0-kernel": f"{hf_prefix}.mlp.gate_proj.weight" if num_experts == 1 else None,
f"{prefix}-mlp-wi_1-kernel": f"{hf_prefix}.mlp.up_proj.weight" if num_experts == 1 else None,
f"{prefix}-mlp-wo-kernel": f"{hf_prefix}.mlp.down_proj.weight" if num_experts == 1 else None,
f"{prefix}-mlp-moe_block-MoeBlock_0-gate-kernel": f"{hf_prefix}.router.proj.weight"
if num_experts > 1
else None,
f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0": f"{hf_prefix}.experts.gate_up_proj"
if num_experts > 1
else None,
f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1": f"{hf_prefix}.experts.gate_up_proj"
if num_experts > 1
else None,
f"{prefix}-mlp-moe_block-MoeBlock_0-wo": f"{hf_prefix}.experts.down_proj" if num_experts > 1 else None,
f"{prefix}-mlp-moe_block-MoeBlock_0-per_expert_scale": f"{hf_prefix}.router.per_expert_scale"
if num_experts > 1
else None,
f"{prefix}-mlp-moe_block-shared_experts-wi_0-kernel": f"{hf_prefix}.mlp.gate_proj.weight"
if num_experts > 1
else None,
f"{prefix}-mlp-moe_block-shared_experts-wi_1-kernel": f"{hf_prefix}.mlp.up_proj.weight"
if num_experts > 1
else None,
f"{prefix}-mlp-moe_block-shared_experts-wo-kernel": f"{hf_prefix}.mlp.down_proj.weight"
if num_experts > 1
else None,
f"{prefix}-layer_scalar": f"{hf_prefix}.layer_scalar",
}
)
mapping = {k: v for k, v in mapping.items() if v is not None}
else:
for i in range(nlayers):
prefix = f"params-decoder-layers_{i}"
hf_prefix = f"{text_base}.layers.{i}"
is_global = i % 6 == 5
num_experts = tcfg.get("num_experts")
num_experts = num_experts if num_experts is not None else 1
mapping.update(
{
f"{prefix}-self_attention-query-kernel": (f"{hf_prefix}.self_attn.q_proj.weight"),
f"{prefix}-self_attention-key-kernel": (f"{hf_prefix}.self_attn.k_proj.weight"),
f"{prefix}-self_attention-value-kernel": (
None if share_kv_projections and is_global else f"{hf_prefix}.self_attn.v_proj.weight"
),
f"{prefix}-self_attention-out-kernel": (f"{hf_prefix}.self_attn.o_proj.weight"),
f"{prefix}-self_attention-query_norm-scale": (f"{hf_prefix}.self_attn.q_norm.weight"),
f"{prefix}-self_attention-key_norm-scale": (f"{hf_prefix}.self_attn.k_norm.weight"),
}
)
if maxtext_config.v_norm_with_scale:
mapping.update({f"{prefix}-self_attention-value_norm-scale": (f"{hf_prefix}.self_attn.v_norm.weight")})
mapping.update(
{
f"{prefix}-pre_self_attention_norm-scale": (f"{hf_prefix}.input_layernorm.weight"),
f"{prefix}-post_self_attention_norm-scale": (f"{hf_prefix}.post_attention_layernorm.weight"),
f"{prefix}-pre_ffw_norm-scale": (f"{hf_prefix}.pre_feedforward_layernorm.weight"),
f"{prefix}-post_ffw_norm-scale": (f"{hf_prefix}.post_feedforward_layernorm.weight"),
f"{prefix}-mlp-pre_feedforward_layernorm_2-scale": (
f"{hf_prefix}.pre_feedforward_layernorm_2.weight" if num_experts > 1 else None
),
f"{prefix}-mlp-post_feedforward_layernorm_1-scale": (
f"{hf_prefix}.post_feedforward_layernorm_1.weight" if num_experts > 1 else None
),
f"{prefix}-mlp-post_feedforward_layernorm_2-scale": (
f"{hf_prefix}.post_feedforward_layernorm_2.weight" if num_experts > 1 else None
),
f"{prefix}-mlp-pre_forward_scale_2": (f"{hf_prefix}.router.scale" if num_experts > 1 else None),
f"{prefix}-mlp-wi_0-kernel": f"{hf_prefix}.mlp.gate_proj.weight" if num_experts == 1 else None,
f"{prefix}-mlp-wi_1-kernel": f"{hf_prefix}.mlp.up_proj.weight" if num_experts == 1 else None,
f"{prefix}-mlp-wo-kernel": f"{hf_prefix}.mlp.down_proj.weight" if num_experts == 1 else None,
f"{prefix}-mlp-moe_block-MoeBlock_0-gate-kernel": f"{hf_prefix}.router.proj.weight"
if num_experts > 1
else None,
f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0": f"{hf_prefix}.experts.gate_up_proj" if num_experts > 1 else None,
f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1": f"{hf_prefix}.experts.gate_up_proj" if num_experts > 1 else None,
f"{prefix}-mlp-moe_block-MoeBlock_0-wo": f"{hf_prefix}.experts.down_proj" if num_experts > 1 else None,
f"{prefix}-mlp-moe_block-MoeBlock_0-per_expert_scale": f"{hf_prefix}.router.per_expert_scale"
if num_experts > 1
else None,
f"{prefix}-mlp-moe_block-shared_experts-wi_0-kernel": f"{hf_prefix}.mlp.gate_proj.weight"
if num_experts > 1
else None,
f"{prefix}-mlp-moe_block-shared_experts-wi_1-kernel": f"{hf_prefix}.mlp.up_proj.weight"
if num_experts > 1
else None,
f"{prefix}-mlp-moe_block-shared_experts-wo-kernel": f"{hf_prefix}.mlp.down_proj.weight"
if num_experts > 1
else None,
f"{prefix}-layer_scalar": f"{hf_prefix}.layer_scalar",
}
)
mapping = {k: v for k, v in mapping.items() if v is not None}
return mapping
[docs]
def GEMMA4_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Gemma4.

  Args:
    config: Hugging Face config mapping. Text settings are read from
      `config["text_config"]` when that key exists, otherwise from `config`
      itself. An optional, non-empty `config["vision_config"]` additionally
      enables the vision-encoder hooks.
    maxtext_config: Object exposing `share_kv_projections` and
      `v_norm_with_scale`.
    scan_layers: When true, hooks are registered under
      `params-decoder-scanned_blocks-layers_{0..5}` plus
      `params-decoder-layers_remainder-layers_{k}` for the
      `num_hidden_layers % 6` trailing layers. Otherwise one set of hooks is
      registered per layer under `params-decoder-layers_{i}`.
    saving_to_hf: Conversion direction captured by the hooks (MaxText -> HF
      when true).

  Returns:
    A dict whose keys are MaxText parameter paths and whose values are
    callables `hook(tensor, target_shape)`. When `saving_to_hf` is true and the
    text config reports more than one expert, the fused expert projection is
    registered under a tuple key `(wi_0 path, wi_1 path)`; that hook receives a
    list `[wi_0, wi_1]` instead of a single tensor.
  """
  text_cfg = config.get("text_config", config)
  num_text_layers = text_cfg["num_hidden_layers"]
  kv_shared = maxtext_config.share_kv_projections
  # Layers are grouped in blocks of six; index 5 within a block is "global".
  pattern_len = 6

  def pad_hf_embedding_layer(input_tensor, target_shape):
    # Crop (saving) or zero-pad (loading) to `target_shape`, and rescale by
    # sqrt(hidden_size): divide when saving to HF, multiply otherwise.
    scale = np.dtype("float32").type(text_cfg["hidden_size"] ** 0.5)
    if saving_to_hf:
      resized = input_tensor[: target_shape[0], : target_shape[1]]
      rescaled = resized / scale
    else:
      resized = np.zeros(target_shape, dtype=input_tensor.dtype)
      resized[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor
      rescaled = resized * scale
    return rescaled.astype(input_tensor.dtype)

  def reshape_kernel(input_tensor, target_shape):
    # Transpose between the two layouts; the reshape target is reversed when
    # saving because the transpose is applied after the reshape.
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    reversed_shape = np.flip(np.array(target_shape))
    return input_tensor.reshape(reversed_shape).T

  def scale_rmsnorm_layer(input_tensor, target_shape):
    # Shift of 1.0 is folded into the Gemma 4 checkpoint weights, so only a
    # reshape is applied here.
    return input_tensor.reshape(target_shape)

  def _slice_fused_gate_up(input_tensor, take_second_half):
    # Shared body of split_moe_wi0 / split_moe_wi1.
    # input_tensor: [E, 2*FF, H] -> returns [E, H, FF].
    if saving_to_hf:
      raise NotImplementedError("Saving to HF for fused gate_up_proj requires custom concat hook.")
    _, fused_width, _ = input_tensor.shape
    half = fused_width // 2
    piece = input_tensor[:, half:, :] if take_second_half else input_tensor[:, :half, :]
    return piece.transpose(0, 2, 1)

  def split_moe_wi0(input_tensor, target_shape):
    return _slice_fused_gate_up(input_tensor, take_second_half=False)

  def split_moe_wi1(input_tensor, target_shape):
    return _slice_fused_gate_up(input_tensor, take_second_half=True)

  def reshape_moe_wo(input_tensor, target_shape):
    # Swap the last two axes of a 3-D tensor ([E, H, FF] <-> [E, FF, H]).
    return input_tensor.transpose(0, 2, 1)

  def moe_gate_up_hook(weight_list, target_shape):
    # Inverse of split_moe_wi0/wi1: fuse MaxText wi_0, wi_1 into HF
    # experts.gate_up_proj. Each input is [..., H, FF]; result is [..., 2*FF, H].
    gate = jnp.asarray(weight_list[0])
    up = jnp.asarray(weight_list[1])
    fused = jnp.concatenate([gate, up], axis=-1)
    return jnp.swapaxes(fused, -2, -1)

  # No hook is registered for logits_dense-kernel.
  hooks = {
      "params-token_embedder-embedding": pad_hf_embedding_layer,
      "params-decoder-decoder_norm-scale": scale_rmsnorm_layer,
  }

  kernel_keys = (
      "self_attention-query-kernel",
      "self_attention-key-kernel",
      "self_attention-value-kernel",
      "self_attention-out-kernel",
      "mlp-wi_0-kernel",
      "mlp-wi_1-kernel",
      "mlp-wo-kernel",
      "mlp-moe_block-shared_experts-wi_0-kernel",
      "mlp-moe_block-shared_experts-wi_1-kernel",
      "mlp-moe_block-shared_experts-wo-kernel",
  )
  # `gate-kernel` (router) has shape [feature, num_experts] in MaxText, but
  # [num_experts, feature] in HF.
  moe_kernel_keys = ("mlp-moe_block-MoeBlock_0-gate-kernel",)

  norm_keys = ["self_attention-query_norm-scale", "self_attention-key_norm-scale"]
  if maxtext_config.v_norm_with_scale:
    norm_keys += ["self_attention-value_norm-scale"]
  norm_keys += [
      "pre_self_attention_norm-scale",
      "post_self_attention_norm-scale",
      "pre_ffw_norm-scale",
      "post_ffw_norm-scale",
  ]

  num_experts = text_cfg.get("num_experts")
  if num_experts is None:
    num_experts = 1
  if num_experts > 1:
    norm_keys += [
        "mlp-pre_feedforward_layernorm_2-scale",
        "mlp-post_feedforward_layernorm_1-scale",
        "mlp-post_feedforward_layernorm_2-scale",
    ]
  # `pre_forward_scale_2`, `per_expert_scale`, and `layer_scalar` get no hook
  # at all, so they are absent from the returned dict.

  vision_cfg = config.get("vision_config", {})
  if vision_cfg:

    def reshape_vision_patch(x, target_shape):
      # Just transpose between out_features/in_features.
      return x.T

    def reshape_pos_emb(x, target_shape):
      return x.transpose(1, 0, 2)

    encoder = "params-vision_encoder-Gemma4VisionEncoderLayer_0"
    hooks[f"{encoder}-vision_entry-input_projection-kernel"] = reshape_vision_patch
    hooks[f"{encoder}-vision_entry-pos_emb_param"] = reshape_pos_emb
    hooks["params-vision_encoder-Gemma4VisionProjector_0-projection-kernel"] = reshape_kernel

    # (suffix, hook) pairs in registration order for every vision layer.
    vision_layer_hooks = (
        ("attention-query-kernel", reshape_kernel),
        ("attention-key-kernel", reshape_kernel),
        ("attention-value-kernel", reshape_kernel),
        ("attention-out-kernel", reshape_kernel),
        ("attention-query_norm-scale", scale_rmsnorm_layer),
        ("attention-key_norm-scale", scale_rmsnorm_layer),
        ("pre_attention_norm-scale", scale_rmsnorm_layer),
        ("post_attention_norm-scale", scale_rmsnorm_layer),
        ("pre_ffw_norm-scale", scale_rmsnorm_layer),
        ("post_ffw_norm-scale", scale_rmsnorm_layer),
        ("mlp-wi_0-kernel", reshape_kernel),
        ("mlp-wi_1-kernel", reshape_kernel),
        ("mlp-wo-kernel", reshape_kernel),
    )
    for vis_idx in range(vision_cfg.get("num_hidden_layers", 0)):
      for suffix, hook in vision_layer_hooks:
        hooks[f"{encoder}-layer_{vis_idx}-{suffix}"] = hook

  def register_decoder_layer(prefix, is_global):
    # Registers every decoder hook that lives under `prefix`.
    # The value-kernel hook is not registered for a global layer when K/V
    # projections are shared.
    skip_value_kernel = kv_shared and is_global
    for key in kernel_keys:
      if skip_value_kernel and key == "self_attention-value-kernel":
        continue
      hooks[f"{prefix}-{key}"] = reshape_kernel
    for key in moe_kernel_keys:
      hooks[f"{prefix}-{key}"] = reshape_kernel
    for key in norm_keys:
      hooks[f"{prefix}-{key}"] = scale_rmsnorm_layer

    # Specialized hooks for the 3-D expert tensors.
    wi0_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"
    wi1_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"
    if saving_to_hf and num_experts > 1:
      hooks[(wi0_key, wi1_key)] = moe_gate_up_hook
    else:
      hooks[wi0_key] = split_moe_wi0
      hooks[wi1_key] = split_moe_wi1
    hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wo"] = reshape_moe_wo

  if not scan_layers:
    for layer_idx in range(num_text_layers):
      register_decoder_layer(f"params-decoder-layers_{layer_idx}", layer_idx % 6 == 5)
    return hooks

  # Scanned sub-layer prefixes: one per position inside the repeating block.
  for layer_in_block in range(pattern_len):
    register_decoder_layer(
        f"params-decoder-scanned_blocks-layers_{layer_in_block}",
        layer_in_block % 6 == 5,
    )

  # Remainder sub-layer prefixes: trailing layers that do not fill a block.
  full_blocks, num_remaining = divmod(num_text_layers, pattern_len)
  first_remainder_idx = pattern_len * full_blocks
  for rem_idx in range(num_remaining):
    real_layer_idx = first_remainder_idx + rem_idx
    register_decoder_layer(
        f"params-decoder-layers_remainder-layers_{rem_idx}",
        real_layer_idx % 6 == 5,
    )
  return hooks
[docs]
def OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Olmo3 weight paths.

  Olmo3 uses an inhomogeneous layer cycle. In scanned mode MaxText exposes one
  sub-layer (`layers_0`, `layers_1`, ...) per position in the cycle, and each
  sub-layer collects the HF layers at indices `k, k + cycle, k + 2 * cycle, ...`.

  Args:
    config: Mapping that provides `"num_hidden_layers"`.
    maxtext_config: Object that may expose
      `inhomogeneous_layer_cycle_interval`. When the attribute is missing the
      cycle length falls back to 4, matching
      `OLMO3_MAXTEXT_TO_HF_PARAM_HOOK_FN`.
    scan_layers: When true, values are lists of HF paths (one per stacked
      layer) keyed by `params-decoder-layers-layers_{k}-...`. Otherwise values
      are single HF paths keyed by `params-decoder-layers_{i}-...`.

  Returns:
    Dict from MaxText parameter path to HF parameter path (or list of paths).
  """
  n_layers = config["num_hidden_layers"]
  # Fall back to the 4-layer cycle when the config does not specify one, so
  # this function and the hook function agree on the default.
  layer_cycle_interval = getattr(maxtext_config, "inhomogeneous_layer_cycle_interval", 4)

  # Base mapping for embeddings and global norms.
  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  # (MaxText suffix, HF suffix) for every per-layer parameter, in insertion
  # order. Shared by the scanned and unscanned branches so they cannot drift.
  layer_params = (
      # Attention projections
      ("attention-query-kernel", "self_attn.q_proj.weight"),
      ("attention-key-kernel", "self_attn.k_proj.weight"),
      ("attention-value-kernel", "self_attn.v_proj.weight"),
      ("attention-out-kernel", "self_attn.o_proj.weight"),
      # QK norms
      ("attention-query_norm-scale", "self_attn.q_norm.weight"),
      ("attention-key_norm-scale", "self_attn.k_norm.weight"),
      # MLP (wi_0=gate, wi_1=up, wo=down)
      ("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
      ("mlp-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp-wo-kernel", "mlp.down_proj.weight"),
      # Layer norms
      ("post_self_attention_layer_norm-scale", "post_attention_layernorm.weight"),
      ("post_mlp_layer_norm-scale", "post_feedforward_layernorm.weight"),
  )

  if scan_layers:
    # Scanned: MaxText `layers_k` stacks HF layers [k, k+cycle, k+2*cycle, ...].
    for cycle_idx in range(layer_cycle_interval):
      hf_indices = range(cycle_idx, n_layers, layer_cycle_interval)
      prefix = f"params-decoder-layers-layers_{cycle_idx}"
      for mt_suffix, hf_suffix in layer_params:
        mapping[f"{prefix}-{mt_suffix}"] = [f"model.layers.{i}.{hf_suffix}" for i in hf_indices]
  else:
    # Unscanned: direct 1-to-1 mapping.
    for i in range(n_layers):
      prefix = f"params-decoder-layers_{i}"
      hf_prefix = f"model.layers.{i}"
      for mt_suffix, hf_suffix in layer_params:
        mapping[f"{prefix}-{mt_suffix}"] = f"{hf_prefix}.{hf_suffix}"
  return mapping
[docs]
def OLMO3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Olmo3.

  Args:
    config: Mapping that provides `"num_hidden_layers"`.
    maxtext_config: Object that may expose
      `inhomogeneous_layer_cycle_interval`; 4 is used when it is absent.
    scan_layers: When true, hooks are keyed by
      `params-decoder-layers-layers_{k}-...` for each position `k` in the layer
      cycle. Otherwise they are keyed by `params-decoder-layers_{i}-...` for
      each of the `num_hidden_layers` layers.
    saving_to_hf: Conversion direction captured by the hooks (MaxText -> HF
      when true).

  Returns:
    Dict from MaxText parameter path to a callable `hook(tensor, target_shape)`.
  """

  def reshape_kernel(input_tensor, target_shape):
    # Standard transpose for kernels (HF: [Out, In] <-> MaxText: [In, Out]).
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    reversed_shape = np.flip(np.array(target_shape))
    return input_tensor.reshape(reversed_shape).T

  def scale_rmsnorm_layer(input_tensor, target_shape):
    # Norm scales are passed through unchanged apart from the reshape.
    return input_tensor.reshape(target_shape)

  def adapt_olmo3_qk_norm(input_tensor, target_shape):
    # QK-norm scales are also a plain reshape; kept as a separate hook so the
    # two kinds of norm can diverge later without touching the key table.
    return input_tensor.reshape(target_shape)

  def pad_hf_embedding_layer(input_tensor, target_shape):
    # Adjust the vocab (first) dimension: truncate when saving to HF,
    # zero-pad otherwise. Equal sizes return the input object untouched.
    source_vocab_size = input_tensor.shape[0]
    target_vocab_size = target_shape[0]
    if source_vocab_size == target_vocab_size:
      return input_tensor
    if saving_to_hf:
      return input_tensor[:target_vocab_size, :]
    padded_tensor = np.zeros(target_shape, dtype=input_tensor.dtype)
    padded_tensor[:source_vocab_size, :] = input_tensor
    return padded_tensor

  # Per-layer suffix -> hook, in registration order: kernels first, then the
  # two layer norms, then the two attention QK norms.
  per_layer_hooks = {
      "attention-query-kernel": reshape_kernel,
      "attention-key-kernel": reshape_kernel,
      "attention-value-kernel": reshape_kernel,
      "attention-out-kernel": reshape_kernel,
      "mlp-wi_0-kernel": reshape_kernel,
      "mlp-wi_1-kernel": reshape_kernel,
      "mlp-wo-kernel": reshape_kernel,
      "post_self_attention_layer_norm-scale": scale_rmsnorm_layer,
      "post_mlp_layer_norm-scale": scale_rmsnorm_layer,
      "attention-query_norm-scale": adapt_olmo3_qk_norm,
      "attention-key_norm-scale": adapt_olmo3_qk_norm,
  }

  cycle_len = getattr(maxtext_config, "inhomogeneous_layer_cycle_interval", 4)
  n_layers = config["num_hidden_layers"]
  if scan_layers:
    layer_prefixes = [f"params-decoder-layers-layers_{cycle_idx}" for cycle_idx in range(cycle_len)]
  else:
    layer_prefixes = [f"params-decoder-layers_{i}" for i in range(n_layers)]

  hooks = {
      "params-token_embedder-embedding": pad_hf_embedding_layer,
      "params-decoder-logits_dense-kernel": reshape_kernel,
      "params-decoder-decoder_norm-scale": scale_rmsnorm_layer,
  }
  for prefix in layer_prefixes:
    for suffix, hook in per_layer_hooks.items():
      hooks[f"{prefix}-{suffix}"] = hook
  return hooks
# Maps each MaxText model name to the function that builds its
# {maxtext weight name: hf weight name} mapping.
# Entries are grouped by builder and listed in registration order; the Qwen
# builder appears twice so that iteration order over the dict is unchanged.
PARAM_MAPPING = {
    **dict.fromkeys(
        ("gemma2-2b", "gemma2-9b", "gemma2-27b"),
        GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        ("gemma3-4b", "gemma3-12b", "gemma3-27b"),
        GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        ("gemma4-26b", "gemma4-31b"),
        GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        (
            "qwen2.5-1.5b",
            "qwen2.5-7b",
            "qwen2.5-14b",
            "qwen3-0.6b",
            "qwen3-1.7b",
            "qwen3-1.7b-base",
            "qwen3-4b",
            "qwen3-4b-base",
            "qwen3-4b-thinking-2507",
            "qwen3-8b",
            "qwen3-8b-base",
            "qwen3-14b",
            "qwen3-14b-base",
            "qwen3-32b",
        ),
        QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        ("llama3.1-8b", "llama3.1-70b", "llama3.1-405b"),
        LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        ("qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-235b-a22b", "qwen3-coder-480b-a35b"),
        QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        ("deepseek2-16b", "deepseek3-671b", "deepseek3.2-671b"),
        DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        ("gpt-oss-20b", "gpt-oss-120b"),
        GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        ("qwen3-omni-30b-a3b",),
        QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        ("qwen3-next-80b-a3b",),
        QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        ("mixtral-8x7b", "mixtral-8x22b"),
        MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(
        ("olmo3-7b", "olmo3-7b-pt", "olmo3-32b"),
        OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
}
# Maps each MaxText model name to the function that builds its
# {maxtext weight name: bi-directional transform} hook table.
# Entries are grouped by builder and listed in registration order; the Qwen
# builder appears twice so that iteration order over the dict is unchanged.
HOOK_FNS = {
    **dict.fromkeys(
        ("gemma2-2b", "gemma2-9b", "gemma2-27b"),
        GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        ("gemma3-4b", "gemma3-12b", "gemma3-27b"),
        GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        ("gemma4-26b", "gemma4-31b"),
        GEMMA4_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        (
            "qwen2.5-1.5b",
            "qwen2.5-7b",
            "qwen2.5-14b",
            "qwen3-0.6b",
            "qwen3-1.7b",
            "qwen3-1.7b-base",
            "qwen3-4b",
            "qwen3-4b-base",
            "qwen3-4b-thinking-2507",
            "qwen3-8b",
            "qwen3-8b-base",
            "qwen3-14b",
            "qwen3-14b-base",
            "qwen3-32b",
        ),
        QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        ("llama3.1-8b", "llama3.1-70b", "llama3.1-405b"),
        LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        ("qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-235b-a22b", "qwen3-coder-480b-a35b"),
        QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        ("deepseek2-16b", "deepseek3-671b", "deepseek3.2-671b"),
        DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        ("gpt-oss-20b", "gpt-oss-120b"),
        GPT_OSS_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        ("qwen3-omni-30b-a3b",),
        QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        ("qwen3-next-80b-a3b",),
        QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        ("mixtral-8x7b", "mixtral-8x22b"),
        MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(
        ("olmo3-7b", "olmo3-7b-pt", "olmo3-32b"),
        OLMO3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
}
# Registry of the *_NNX_TO_VLLM_PARAM_HOOK_FN callables, keyed by a short name.
# NOTE(review): unlike HOOK_FNS, the keys here ("qwen3", "llama3.1",
# "deepseek3") look like model-family names rather than full model names --
# confirm how callers resolve a model name to one of these keys.
VLLM_HOOK_FNS = {
"qwen3": QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN,
"llama3.1": LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN,
"deepseek3": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
}