# Source code for maxtext.checkpoint_conversion.utils.param_mapping

#  Copyright 2023–2026 Google LLC
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""Parameter mappings and transformation hooks for checkpoint conversion.

This module defines the necessary components to convert model checkpoints between
MaxText and Hugging Face formats for various architectures (e.g., Gemma, Qwen).
It provides two key types of mappings for each model:

1.  **Parameter Name Mappings (`PARAM_MAPPING`)**: Dictionaries that map a MaxText
    parameter key to its corresponding Hugging Face parameter(s). These mappings are
    generated by functions like `GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING`.

    **Key: MaxText parameters, in the following forms:**
    - `atomic_mt_key`: A single string representing one MaxText parameter.
    - `composite_mt_key`: A tuple of strings representing multiple MaxText parameters. (e.g., GPT-OSS)

    **Value: corresponding Hugging Face parameters, in the following forms:**
    - `unscanned`: A single string.
    - `scanned`: A list of strings, to be stacked along the layer axis.
    - `unscanned with expert stacking`: A list of strings, to be stacked along the expert axis.
    - `scanned with expert stacking`: A nested list of strings, to be stacked along both layer and expert axes.
    Note: Expert stacking only applies to a subset of MoE models (e.g., Qwen MoE, DeepSeek, Mixtral),
      but not others (e.g., GPT-OSS).

2.  **Hook Functions (`HOOK_FNS`)**: Dictionaries that map a MaxText parameter
    name to a specific transformation function (a "hook"). These hooks handle
    the actual value conversion, which can include operations like reshaping,
    transposing, scaling, or padding tensors to match the target format's
    requirements. These are generated by functions like
    `GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN`.

The main conversion script uses these mappings to systematically transform each
parameter from the source checkpoint and build the target checkpoint.
"""

import warnings
import numpy as np

import jax
import jax.numpy as jnp


def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Builds the MaxText -> Hugging Face parameter-name mapping for Gemma3.

  Covers the token embedding, the final decoder norm, the vision tower (patch
  embedding, position embedding, encoder blocks, post layer norm), the
  multi-modal projector, and every text decoder layer. The HF names use the
  `model.language_model.*`, `model.vision_tower.*` and
  `model.multi_modal_projector.*` prefixes. NOTE(review): that is the multimodal
  checkpoint layout (it looks like `Gemma3ForConditionalGeneration` rather than
  `Gemma3ForCausalLM`); confirm against the target transformers version.

  Args:
    config (dict): The Hugging Face model configuration. Must contain
      `text_config["num_hidden_layers"]` and `vision_config["num_hidden_layers"]`.
    maxtext_config: MaxText configuration. Not read by this function; accepted so
      every mapping builder shares one signature.
    scan_layers (bool, optional): If True, text decoder parameters are keyed by
      scanned block and map to a list of HF names to be stacked along the layer
      axis. If False, every decoder layer gets its own key. Defaults to False.

  Returns:
    dict: Keys are `atomic_mt_key` strings (single MaxText parameter names).
    Values are a single HF name (unscanned form) or a list of HF names (scanned
    form, text decoder only).

  Notes:
    In scanned mode the decoder is treated as a repeating 6-layer attention
    pattern (5 local + 1 global). Block `layers_{b}` collects HF layers
    b, b+6, b+12, ... below the largest multiple of 6 that does not exceed
    `num_hidden_layers`. The leftover `num_hidden_layers % 6` layers are keyed
    one by one under `layers_remainder-layers_{r}`.
  """
  num_text_layers = config["text_config"]["num_hidden_layers"]
  num_vision_layers = config["vision_config"]["num_hidden_layers"]

  vision_root = "params-vision_encoder-Gemma3VisionEncoderLayer_0"
  hf_vision_root = "model.vision_tower.vision_model"
  hf_text_root = "model.language_model"

  # pylint: disable=line-too-long
  mapping = {
      # Embedding & final norm
      "params-token_embedder-embedding": f"{hf_text_root}.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": f"{hf_text_root}.norm.weight",
      # Vision embed & pos, vision post-norm
      f"{vision_root}-embedding-kernel": f"{hf_vision_root}.embeddings.patch_embedding.weight",
      f"{vision_root}-embedding-bias": f"{hf_vision_root}.embeddings.patch_embedding.bias",
      f"{vision_root}-pos_embedding": f"{hf_vision_root}.embeddings.position_embedding.weight",
      f"{vision_root}-Transformer-encoder_norm-scale": f"{hf_vision_root}.post_layernorm.weight",
      f"{vision_root}-Transformer-encoder_norm-bias": f"{hf_vision_root}.post_layernorm.bias",
      # Multi-modal projector
      "params-vision_encoder-VisionEmbedder_0-mm_input_projection-w": "model.multi_modal_projector.mm_input_projection_weight",
      "params-vision_encoder-VisionEmbedder_0-mm_soft_embedding_norm-scale": "model.multi_modal_projector.mm_soft_emb_norm.weight",
  }

  # Per-encoder-block vision suffixes, generated in the order:
  # LayerNorm_0/1 (scale, bias), q/k/v/out (kernel, bias), Dense_0/1 (kernel, bias).
  kernel_or_bias = (("kernel", "weight"), ("bias", "bias"))
  vision_suffixes = []
  for norm_idx in range(2):
    vision_suffixes += [
        (f"LayerNorm_{norm_idx}-{mt_leaf}", f"layer_norm{norm_idx + 1}.{hf_leaf}")
        for mt_leaf, hf_leaf in (("scale", "weight"), ("bias", "bias"))
    ]
  for mt_proj, hf_proj in (("query", "q_proj"), ("key", "k_proj"), ("value", "v_proj"), ("out", "out_proj")):
    vision_suffixes += [
        (f"MultiHeadDotProductAttention_0-{mt_proj}-{mt_leaf}", f"self_attn.{hf_proj}.{hf_leaf}")
        for mt_leaf, hf_leaf in kernel_or_bias
    ]
  for dense_idx in range(2):
    vision_suffixes += [
        (f"MlpBlockViT_0-Dense_{dense_idx}-{mt_leaf}", f"mlp.fc{dense_idx + 1}.{hf_leaf}")
        for mt_leaf, hf_leaf in kernel_or_bias
    ]

  for layer in range(num_vision_layers):
    block_prefix = f"{vision_root}-Transformer-encoderblock_{layer}"
    for mt_suffix, hf_suffix in vision_suffixes:
      mapping[f"{block_prefix}-{mt_suffix}"] = f"{hf_vision_root}.encoder.layers.{layer}.{hf_suffix}"

  # Text decoder suffixes: (MaxText suffix, HF suffix under layers.{i}).
  text_suffixes = (
      ("pre_self_attention_norm-scale", "input_layernorm.weight"),
      ("post_self_attention_norm-scale", "post_attention_layernorm.weight"),
      ("self_attention-query_norm-scale", "self_attn.q_norm.weight"),
      ("self_attention-key_norm-scale", "self_attn.k_norm.weight"),
      ("pre_ffw_norm-scale", "pre_feedforward_layernorm.weight"),
      ("post_ffw_norm-scale", "post_feedforward_layernorm.weight"),
      ("self_attention-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention-key-kernel", "self_attn.k_proj.weight"),
      ("self_attention-value-kernel", "self_attn.v_proj.weight"),
      ("self_attention-out-kernel", "self_attn.o_proj.weight"),
      ("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
      ("mlp-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp-wo-kernel", "mlp.down_proj.weight"),
  )

  def hf_text_name(layer, hf_suffix):
    return f"{hf_text_root}.layers.{layer}.{hf_suffix}"

  if not scan_layers:
    for layer in range(num_text_layers):
      for mt_suffix, hf_suffix in text_suffixes:
        mapping[f"params-decoder-layers_{layer}-{mt_suffix}"] = hf_text_name(layer, hf_suffix)
    return mapping

  pattern_len = 6
  num_scanned = pattern_len * (num_text_layers // pattern_len)

  # Scanned blocks: params-decoder-layers-layers_{block}-{param}
  for block in range(pattern_len):
    stacked_layers = range(block, num_scanned, pattern_len)
    for mt_suffix, hf_suffix in text_suffixes:
      mapping[f"params-decoder-layers-layers_{block}-{mt_suffix}"] = [
          hf_text_name(layer, hf_suffix) for layer in stacked_layers
      ]

  # Remainder (unscanned): params-decoder-layers_remainder-layers_{rem}-{param}
  for rem in range(num_text_layers - num_scanned):
    for mt_suffix, hf_suffix in text_suffixes:
      mapping[f"params-decoder-layers_remainder-layers_{rem}-{mt_suffix}"] = hf_text_name(num_scanned + rem, hf_suffix)
  return mapping
def GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Hook functions for Gemma3 parameter conversion.

  Builds a dictionary of transformation functions (hooks) that convert Gemma3
  parameter values between MaxText and Hugging Face layouts. The hooks cover:
    - embedding scaling and vocab padding/truncation,
    - the RMSNorm +/-1 offset,
    - kernel reshape/transpose,
    - vision-specific tensor manipulations.
  Every hook is called as `hook(tensor, target_shape)`.

  Args:
    config (dict): The Hugging Face model configuration dictionary. Reads
      `text_config["hidden_size"]` (only when the embedding hook runs), plus
      `text_config["num_hidden_layers"]` and `vision_config["num_hidden_layers"]`.
      A missing sub-config or layer count is treated as 0 layers.
    maxtext_config: MaxText configuration. Not read by this function.
    scan_layers (bool, optional): Whether the model uses scanned layers.
      Defaults to False.
    saving_to_hf (bool, optional): The direction of conversion. True for
      MaxText to Hugging Face, False for the reverse. Defaults to False.

  Returns:
    dict: A dictionary mapping MaxText parameter names to their corresponding
    transformation functions. Vision LayerNorm parameters, vision MLP biases and
    the patch-embedding bias get no entry here.
  """
  hooks = {}
  # Checkpoint format being produced; named in the vocab-size warnings below.
  target_format = "Hugging Face" if saving_to_hf else "MaxText"

  # ---- Embedding pad & scale ----
  def pad_and_scale_embedding(input_tensor, target_shape):
    source_vocab_size, _ = input_tensor.shape
    target_vocab_size, target_hidden_size = target_shape

    # MaxText embedding = original_embedding * sqrt(hidden_size)
    # HF embedding = original_embedding (HF model forward pass applies scaling)
    # Note: config["text_config"]["hidden_size"] is the HF hidden size.
    normalizer = np.dtype("bfloat16").type(config["text_config"]["hidden_size"] ** 0.5)

    # Apply scaling first, keeping the input dtype.
    if saving_to_hf:  # MaxText to HF
      scaled_tensor = (input_tensor / normalizer).astype(input_tensor.dtype)
    else:  # HF to MaxText
      scaled_tensor = (input_tensor * normalizer).astype(input_tensor.dtype)

    # Then match the target vocab size by truncating or zero-padding rows.
    if source_vocab_size > target_vocab_size:
      warnings.warn(
          f"source vocab={source_vocab_size} > target vocab={target_vocab_size}, "
          f"truncate output layer for {target_format}."
      )
      return scaled_tensor[:target_vocab_size, :]
    if source_vocab_size < target_vocab_size:
      warnings.warn(
          f"source vocab={source_vocab_size} < target vocab={target_vocab_size}, "
          f"pad output layer for {target_format}."
      )
      padding_shape = (target_vocab_size - source_vocab_size, target_hidden_size)
      # Stay in the input's array library: jnp for JAX arrays, np otherwise.
      xp = jnp if isinstance(scaled_tensor, jax.Array) else np
      padding = xp.zeros(padding_shape, dtype=scaled_tensor.dtype)
      return xp.concatenate([scaled_tensor, padding], axis=0)
    # Vocab sizes match.
    return scaled_tensor

  # ---- RMSNorm scale ----
  def scale_rmsnorm(x, target_shape):
    # MaxText norm = HF norm +1; HF norm = MaxText norm -1
    if saving_to_hf:
      return (x - 1.0).reshape(target_shape)
    return (x + 1.0).reshape(target_shape)

  # ---- Generic reshape ----
  def reshape_kernel(x, target_shape):
    # MaxText -> HF: reshape to the reversed HF shape, then transpose into it.
    # HF -> MaxText: transpose first, then reshape to the MaxText shape.
    if saving_to_hf:
      flipped = np.flip(np.array(target_shape))
      return x.reshape(flipped).T
    return x.T.reshape(target_shape)

  # ---- Vision reshape ----
  def vis_bias(x, target_shape):
    if saving_to_hf:
      return x.flatten()
    return x.reshape(target_shape)

  def vision_patch(x, target_shape):
    # Axis permutations between the MaxText and HF 4-D patch-embedding kernels;
    # the two permutations are inverses of each other.
    if saving_to_hf:
      return x.transpose(3, 2, 0, 1)
    return x.transpose(2, 3, 1, 0)

  def pos_embed(x, target_shape):
    # MaxText keeps a leading singleton axis (e.g. [1, 4096, 1152]); HF drops it.
    if saving_to_hf:
      return x.squeeze(0)
    return x[None, :, :]

  # ---Embedding & final norm---
  hooks["params-token_embedder-embedding"] = pad_and_scale_embedding
  hooks["params-decoder-decoder_norm-scale"] = scale_rmsnorm
  # ---Vision embedding & multi-modal projector---
  hooks["params-vision_encoder-Gemma3VisionEncoderLayer_0-embedding-kernel"] = vision_patch
  hooks["params-vision_encoder-Gemma3VisionEncoderLayer_0-pos_embedding"] = pos_embed
  hooks["params-vision_encoder-VisionEmbedder_0-mm_input_projection-w"] = lambda x, _: x
  hooks["params-vision_encoder-VisionEmbedder_0-mm_soft_embedding_norm-scale"] = scale_rmsnorm

  # Text layers
  tc = config.get("text_config", {})
  nlayers = tc.get("num_hidden_layers", 0)
  if scan_layers:
    attention_pattern_length = 6
    num_remaining = nlayers % attention_pattern_length
    # Scanned sub-layer prefixes
    prefixes = [f"params-decoder-layers-layers_{block_idx}-" for block_idx in range(attention_pattern_length)]
    # Remainder sub-layer prefixes
    if num_remaining > 0:
      prefixes += [f"params-decoder-layers_remainder-layers_{rem_idx}-" for rem_idx in range(num_remaining)]
  else:
    prefixes = [f"params-decoder-layers_{i}-" for i in range(nlayers)]

  for pref in prefixes:
    # Attention Q/K/V/O
    hooks[pref + "self_attention-query-kernel"] = reshape_kernel
    hooks[pref + "self_attention-key-kernel"] = reshape_kernel
    hooks[pref + "self_attention-value-kernel"] = reshape_kernel
    hooks[pref + "self_attention-out-kernel"] = reshape_kernel
    # Norm scales
    for nm in [
        "pre_self_attention_norm-scale",
        "post_self_attention_norm-scale",
        "self_attention-query_norm-scale",
        "self_attention-key_norm-scale",
        "pre_ffw_norm-scale",
        "post_ffw_norm-scale",
    ]:
      hooks[pref + nm] = scale_rmsnorm
    # MLP
    hooks[pref + "mlp-wi_0-kernel"] = reshape_kernel
    hooks[pref + "mlp-wi_1-kernel"] = reshape_kernel
    hooks[pref + "mlp-wo-kernel"] = reshape_kernel

  # Vision layers
  vc = config.get("vision_config", {})
  nvis = vc.get("num_hidden_layers", 0)
  for i in range(nvis):
    base = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-"
    # Attention kernels & biases
    for qkv in ["query", "key", "value"]:
      hooks[base + f"MultiHeadDotProductAttention_0-{qkv}-kernel"] = reshape_kernel
      hooks[base + f"MultiHeadDotProductAttention_0-{qkv}-bias"] = vis_bias
    # [1152, 1152] -> [16, 72, 1152]
    hooks[base + "MultiHeadDotProductAttention_0-out-kernel"] = reshape_kernel
    hooks[base + "MultiHeadDotProductAttention_0-out-bias"] = vis_bias
    # MLP ViT kernels
    for dense in ["Dense_0", "Dense_1"]:
      hooks[base + f"MlpBlockViT_0-{dense}-kernel"] = reshape_kernel

  return hooks
def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Builds the MaxText -> Hugging Face parameter-name mapping for Gemma2.

  Args:
    config (dict): Model configuration; must contain 'num_hidden_layers'.
    maxtext_config: MaxText configuration. Not read by this function.
    scan_layers (bool, optional): If True, each MaxText decoder parameter maps to
      the list of HF names that are stacked into one scanned tensor. If False,
      every MaxText decoder layer gets its own keys. Defaults to False.

  Returns:
    dict: Keys are `atomic_mt_key` strings (single MaxText parameter names).
    Values are a single HF name (unscanned form) or a list of HF names
    (scanned form).

  Notes:
    - One MaxText decoder layer holds a "local" and a "global" sub-layer, i.e.
      two HF decoder layers: MaxText layer `i` covers HF layers `2i` (local)
      and `2i+1` (global).
    - Local components therefore map to even HF indices (0, 2, 4, ...) and
      global components to odd HF indices (1, 3, 5, ...).
    - Unscanned mode emits `num_hidden_layers // 2` MaxText layers.
  """
  nlayers = config["num_hidden_layers"]

  # (MaxText suffix template, HF suffix under model.layers.{i}); "{kind}" is
  # replaced with "global" or "local".
  suffix_table = (
      ("pre_self_attention_norm_{kind}-scale", "input_layernorm.weight"),
      ("mlp_{kind}-wo-kernel", "mlp.down_proj.weight"),
      ("mlp_{kind}-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp_{kind}-wi_0-kernel", "mlp.gate_proj.weight"),
      ("post_self_attention_norm_{kind}-scale", "post_attention_layernorm.weight"),
      ("post_ffw_norm_{kind}-scale", "post_feedforward_layernorm.weight"),
      ("pre_ffw_norm_{kind}-scale", "pre_feedforward_layernorm.weight"),
      ("self_attention_{kind}-key-kernel", "self_attn.k_proj.weight"),
      ("self_attention_{kind}-out-kernel", "self_attn.o_proj.weight"),
      ("self_attention_{kind}-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention_{kind}-value-kernel", "self_attn.v_proj.weight"),
  )
  # (sub-layer kind, parity of its HF layer index): global -> odd, local -> even.
  kinds = (("global", 1), ("local", 0))

  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
  }

  if scan_layers:
    for kind, parity in kinds:
      hf_layers = range(parity, nlayers, 2)
      for mt_template, hf_suffix in suffix_table:
        mt_key = "params-decoder-layers-" + mt_template.format(kind=kind)
        mapping[mt_key] = [f"model.layers.{i}.{hf_suffix}" for i in hf_layers]
    return mapping

  for mt_layer in range(nlayers // 2):
    for kind, parity in kinds:
      hf_layer = 2 * mt_layer + parity
      for mt_template, hf_suffix in suffix_table:
        mt_key = f"params-decoder-layers_{mt_layer}-" + mt_template.format(kind=kind)
        mapping[mt_key] = f"model.layers.{hf_layer}.{hf_suffix}"
  return mapping
def GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Gemma2 conversion.

  Each hook is called as `hook(tensor, target_shape)` and converts one parameter
  between MaxText and Hugging Face layouts (embedding pad/scale, kernel
  reshape/transpose, RMSNorm +/-1 offset, query depth scaling).

  Args:
    config (dict): Model configuration dictionary. Reads:
      - num_hidden_layers (int): read immediately.
      - head_dim (int): read only when a query-scaling hook runs.
      - hidden_size (int): read only when the embedding hook runs.
    maxtext_config: MaxText configuration. Not read by this function.
    scan_layers (bool, optional): True registers one hook per scanned parameter;
      False registers hooks for each of the `num_hidden_layers // 2` MaxText
      layers. Defaults to False.
    saving_to_hf (bool, optional): Direction of transformation. True for MaxText
      to HuggingFace, False for the reverse. Defaults to False.

  Returns:
    dict: A mapping from MaxText parameter names to transformation functions.
    The value can be a single function or a list of functions to be applied
    sequentially (the query kernels use `[reshape_kernel, scale_query_layer]`).
  """
  nlayers = config["num_hidden_layers"]

  def pad_hf_embedding_layer(input_tensor, target_shape):
    """Pads/unpads and scales the embedding layer.

    Note:
      HF embedding weights shape = [256000, d_model]
      MaxText embedding weights shape = [256128, d_model]
      MaxText pads Gemma2 embedding to 256128 for better performance.
    """
    # TODO(wenxindongwork), Perhaps, this dtype should be the activation dtype
    normalizer = np.dtype("float32").type(config["hidden_size"] ** 0.5)
    original_dtype = input_tensor.dtype
    if saving_to_hf:
      # Drop the padded rows/cols, then undo the sqrt(hidden_size) scaling.
      rows, cols = target_shape[0], target_shape[1]
      return (input_tensor[:rows, :cols] / normalizer).astype(original_dtype)
    # Zero-pad up to the MaxText shape, then apply the sqrt(hidden_size) scaling.
    padded = np.zeros(target_shape, dtype=original_dtype)
    padded[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor
    return (padded * normalizer).astype(original_dtype)

  def reshape_kernel(input_tensor, target_shape):
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    reversed_shape = tuple(target_shape)[::-1]
    return input_tensor.reshape(reversed_shape).T

  def scale_rmsnorm_layer(input_tensor, target_shape):
    # MaxText stores RMSNorm scale as HF scale + 1.
    offset = -1.0 if saving_to_hf else 1.0
    return (input_tensor + offset).reshape(target_shape)

  def scale_query_layer(input_tensor, target_shape):
    root_depth = np.sqrt(config["head_dim"])
    factor = root_depth if saving_to_hf else 1 / root_depth
    depth_scale = np.dtype("float32").type(factor)
    return (input_tensor * depth_scale).astype(input_tensor.dtype)

  # hook order does not affect result
  query_hook_chain = [reshape_kernel, scale_query_layer]

  # (per-layer MaxText suffix, hook), in registration order.
  kernel_hooks = (
      ("self_attention_global-query-kernel", query_hook_chain),
      ("self_attention_local-query-kernel", query_hook_chain),
      ("self_attention_global-key-kernel", reshape_kernel),
      ("self_attention_local-key-kernel", reshape_kernel),
      ("self_attention_global-value-kernel", reshape_kernel),
      ("self_attention_local-value-kernel", reshape_kernel),
      ("mlp_global-wo-kernel", reshape_kernel),
      ("mlp_global-wi_1-kernel", reshape_kernel),
      ("mlp_global-wi_0-kernel", reshape_kernel),
      ("self_attention_global-out-kernel", reshape_kernel),
      ("mlp_local-wo-kernel", reshape_kernel),
      ("mlp_local-wi_1-kernel", reshape_kernel),
      ("mlp_local-wi_0-kernel", reshape_kernel),
      ("self_attention_local-out-kernel", reshape_kernel),
  )
  norm_hooks = tuple(
      (f"{norm}_{kind}-scale", scale_rmsnorm_layer)
      for kind in ("global", "local")
      for norm in ("pre_self_attention_norm", "post_self_attention_norm", "post_ffw_norm", "pre_ffw_norm")
  )

  if scan_layers:
    layer_prefixes = ["params-decoder-layers-"]
  else:
    layer_prefixes = [f"params-decoder-layers_{idx}-" for idx in range(nlayers // 2)]

  hooks = {
      "params-token_embedder-embedding": pad_hf_embedding_layer,
      "params-decoder-decoder_norm-scale": scale_rmsnorm_layer,
  }
  for prefix in layer_prefixes:
    for suffix, hook in kernel_hooks + norm_hooks:
      hooks[prefix + suffix] = hook
  return hooks
def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Qwen weight paths.

  Maps parameter names from a MaxText Qwen checkpoint to their Hugging Face
  names, for both dense and Mixture-of-Experts (MoE) variants. A model is
  treated as MoE when `num_experts > 1`; otherwise dense MLP keys are emitted.
  The attention/norm parameter list is shared by the scanned and unscanned
  branches so the two cannot drift apart.

  Args:
    config (dict): Model configuration dictionary, including 'num_hidden_layers'
      and optionally 'num_experts' (defaults to 0).
    maxtext_config: MaxText configuration. Not read by this function.
    scan_layers (bool, optional): Whether the MaxText model uses scanned layers.
      Defaults to False.

  Returns:
    dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter
    names). Values are Hugging Face parameter names in one of four forms:
    unscanned (string), scanned (list of strings, one per layer), unscanned with
    expert stacking (list of strings, one per expert), or scanned with expert
    stacking (nested list: outer axis experts, inner axis layers).
  """
  n_layers = config["num_hidden_layers"]
  num_experts = config.get("num_experts", 0)
  is_moe = num_experts > 1

  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  # Attention projections/biases and per-layer norms, common to dense and MoE:
  # (MaxText suffix, HF suffix under model.layers.{i}).
  attention_and_norm_params = (
      ("pre_self_attention_layer_norm-scale", "input_layernorm.weight"),
      ("self_attention-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention-key-kernel", "self_attn.k_proj.weight"),
      ("self_attention-value-kernel", "self_attn.v_proj.weight"),
      ("self_attention-query-bias", "self_attn.q_proj.bias"),
      ("self_attention-key-bias", "self_attn.k_proj.bias"),
      ("self_attention-value-bias", "self_attn.v_proj.bias"),
      ("self_attention-out-kernel", "self_attn.o_proj.weight"),
      ("self_attention-query_norm-scale", "self_attn.q_norm.weight"),
      ("self_attention-key_norm-scale", "self_attn.k_norm.weight"),
      ("post_self_attention_layer_norm-scale", "post_attention_layernorm.weight"),
  )
  # Dense MLP: (MaxText suffix, HF suffix under model.layers.{i}).
  dense_mlp_params = (
      ("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
      ("mlp-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp-wo-kernel", "mlp.down_proj.weight"),
  )
  # MoE experts: (MaxText suffix, HF projection name under mlp.experts.{e}).
  expert_params = (
      ("moe_block-wi_0", "gate_proj"),
      ("moe_block-wi_1", "up_proj"),
      ("moe_block-wo", "down_proj"),
  )
  layer_params = attention_and_norm_params if is_moe else attention_and_norm_params + dense_mlp_params

  if scan_layers:
    for mt_suffix, hf_suffix in layer_params:
      mapping[f"params-decoder-layers-{mt_suffix}"] = [f"model.layers.{i}.{hf_suffix}" for i in range(n_layers)]
    if is_moe:
      mapping["params-decoder-layers-moe_block-gate-kernel"] = [
          f"model.layers.{i}.mlp.gate.weight" for i in range(n_layers)
      ]
      # Nested list [[e0_l0, e0_l1, ...], [e1_l0, e1_l1, ...], ...]; this follows
      # the (experts, layers, ...) tensor layout.
      for mt_suffix, proj in expert_params:
        mapping[f"params-decoder-layers-{mt_suffix}"] = [
            [f"model.layers.{layer}.mlp.experts.{expert}.{proj}.weight" for layer in range(n_layers)]
            for expert in range(num_experts)
        ]
    return mapping

  for i in range(n_layers):
    for mt_suffix, hf_suffix in layer_params:
      mapping[f"params-decoder-layers_{i}-{mt_suffix}"] = f"model.layers.{i}.{hf_suffix}"
    if is_moe:
      mapping[f"params-decoder-layers_{i}-moe_block-gate-kernel"] = f"model.layers.{i}.mlp.gate.weight"
      # One flat list of all expert weights for this layer.
      for mt_suffix, proj in expert_params:
        mapping[f"params-decoder-layers_{i}-{mt_suffix}"] = [
            f"model.layers.{i}.mlp.experts.{expert}.{proj}.weight" for expert in range(num_experts)
        ]
  return mapping
def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Qwen.

  This function provides a dictionary of transformation functions (hooks) for
  converting Qwen model parameters between MaxText and Hugging Face formats.
  It handles embedding padding/truncation and kernel/bias reshaping.

  Args:
    config (dict): Hugging Face model configuration dictionary, including
      'num_hidden_layers' and optionally 'num_experts'.
    maxtext_config: MaxText configuration object. Unused here; kept for
      signature parity with the other hook builders in this module.
    scan_layers (bool, optional): Whether the model uses scanned layers.
      Defaults to False.
    saving_to_hf (bool, optional): The direction of conversion. True for
      MaxText to Hugging Face, False for the reverse. Defaults to False.

  Returns:
    dict: A dictionary mapping MaxText parameter names to their corresponding
      transformation functions. Each hook is called as
      `hook(input_tensor, target_shape)`.
  """
  n_layers = config["num_hidden_layers"]
  num_experts = config.get("num_experts", 0)

  def pad_embedding_layer(input_tensor, target_shape):
    """Pads or truncates the embedding's vocab axis (axis 0) to `target_shape[0]`.

    The choice is driven by the shapes, not by the conversion direction. A
    larger source is truncated, and a smaller source is zero-padded, so the
    result always has `target_shape[0]` rows.
    """
    source_vocab_size = input_tensor.shape[0]
    target_vocab_size = target_shape[0]
    if source_vocab_size == target_vocab_size:
      return input_tensor
    if source_vocab_size > target_vocab_size:
      # Drop trailing rows beyond the target vocab size.
      return input_tensor[:target_vocab_size, :]
    # Zero-fill the rows that the source does not have.
    padded_tensor = np.zeros(target_shape, dtype=input_tensor.dtype)
    padded_tensor[:source_vocab_size, :] = input_tensor
    return padded_tensor

  def reshape_kernel(input_tensor, target_shape):
    """Reshapes and transposes kernel weights between MaxText and HF."""
    if saving_to_hf:
      # Reshape to the reversed HF shape, then transpose into the HF layout.
      flipped_target_shape = np.flip(np.array(target_shape))
      return input_tensor.reshape(flipped_target_shape).T
    return input_tensor.T.reshape(target_shape)

  def reshape_bias(input_tensor, target_shape=None):
    """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
    # saving_to_hf: MaxText [heads, head_dim] -> HF [hidden_dim] (flatten)
    # loading_to_maxtext: HF [hidden_dim] -> MaxText [heads, head_dim]
    return input_tensor.reshape(target_shape)

  mapping = {
      "params-token_embedder-embedding": pad_embedding_layer,
      "params-decoder-logits_dense-kernel": reshape_kernel,
  }

  kernel_hooks = [
      "self_attention-query-kernel",
      "self_attention-key-kernel",
      "self_attention-value-kernel",
      "self_attention-out-kernel",
      "mlp-wi_0-kernel",
      "mlp-wi_1-kernel",
      "mlp-wo-kernel",
  ]
  bias_hooks = [
      "self_attention-query-bias",
      "self_attention-key-bias",
      "self_attention-value-bias",
  ]
  # Both the bare and the "-kernel"-suffixed expert names are registered.
  moe_kernel_hooks = [
      "moe_block-gate-kernel",
      "moe_block-wi_0-kernel",
      "moe_block-wi_1-kernel",
      "moe_block-wo-kernel",
      "moe_block-wi_0",
      "moe_block-wi_1",
      "moe_block-wo",
  ]

  # Scanned models have one stacked "layers" prefix; unscanned ones have one prefix per layer.
  if scan_layers:
    layer_prefixes = ["params-decoder-layers"]
  else:
    layer_prefixes = [f"params-decoder-layers_{i}" for i in range(n_layers)]

  for layer_prefix in layer_prefixes:
    for key in kernel_hooks:
      mapping[f"{layer_prefix}-{key}"] = reshape_kernel
    for key in bias_hooks:
      mapping[f"{layer_prefix}-{key}"] = reshape_bias
    if num_experts > 1:
      for key in moe_kernel_hooks:
        mapping[f"{layer_prefix}-{key}"] = reshape_kernel
  return mapping
def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Builds the MaxText -> Hugging Face weight-path mapping for Qwen3-Next.

  All MaxText keys start with 'params-' and use '-' separators. Layers are
  grouped in cycles of `maxtext_config.inhomogeneous_layer_cycle_interval`.
  The last position in each cycle uses full self-attention (`self_attn.*`);
  every other position uses linear attention (`linear_attn.*`).

  Args:
    config (dict): Hugging Face config with 'num_hidden_layers' and 'num_experts'.
    maxtext_config: MaxText config exposing `inhomogeneous_layer_cycle_interval`.
    scan_layers (bool, optional): If True, map each cycle position
      ('params-decoder-layers-layer_{k}') to the list of HF layers it stacks.
      Otherwise, map each layer ('params-decoder-layers_{i}') to its own HF paths.

  Returns:
    dict: MaxText key -> HF path (str), list of HF paths (scanned), or, for
      routed experts, a list per expert (nested per layer when scanned).
  """
  num_main_layers = config["num_hidden_layers"]
  num_experts = config["num_experts"]
  cycle = maxtext_config.inhomogeneous_layer_cycle_interval

  # (MaxText key suffix, HF key suffix) pairs applied under each layer prefix.
  norm_pairs = (
      ("input_layernorm-scale", "input_layernorm.weight"),
      ("post_attention_layernorm-scale", "post_attention_layernorm.weight"),
  )
  full_attention_pairs = (
      ("attention-attention-query-kernel", "self_attn.q_proj.weight"),
      ("attention-attention-key-kernel", "self_attn.k_proj.weight"),
      ("attention-attention-value-kernel", "self_attn.v_proj.weight"),
      ("attention-attention-out-kernel", "self_attn.o_proj.weight"),
      ("attention-attention-query_norm-scale", "self_attn.q_norm.weight"),
      ("attention-attention-key_norm-scale", "self_attn.k_norm.weight"),
  )
  linear_attention_pairs = (
      ("attention-in_proj_qkvz-kernel", "linear_attn.in_proj_qkvz.weight"),
      ("attention-in_proj_ba-kernel", "linear_attn.in_proj_ba.weight"),
      ("attention-conv1d-kernel", "linear_attn.conv1d.weight"),
      ("attention-A_log", "linear_attn.A_log"),
      ("attention-dt_bias", "linear_attn.dt_bias"),
      ("attention-norm-rms_norm-scale", "linear_attn.norm.weight"),
      ("attention-out_proj-kernel", "linear_attn.out_proj.weight"),
  )
  mlp_pairs = (
      ("mlp-routed_experts-gate-kernel", "mlp.gate.weight"),
      ("mlp-shared_expert-wi_0-kernel", "mlp.shared_expert.gate_proj.weight"),
      ("mlp-shared_expert-wi_1-kernel", "mlp.shared_expert.up_proj.weight"),
      ("mlp-shared_expert-wo-kernel", "mlp.shared_expert.down_proj.weight"),
      ("mlp-shared_expert_gate-kernel", "mlp.shared_expert_gate.weight"),
  )
  # Routed experts: HF path is model.layers.{i}.mlp.experts.{e}.<suffix>.
  expert_pairs = (
      ("mlp-routed_experts-wi_0", "gate_proj.weight"),
      ("mlp-routed_experts-wi_1", "up_proj.weight"),
      ("mlp-routed_experts-wo", "down_proj.weight"),
  )

  def per_layer_pairs(block_idx):
    """Non-expert (MaxText, HF) suffix pairs for a given position in the cycle."""
    is_full_attention_layer = (block_idx + 1) % cycle == 0
    attention_pairs = full_attention_pairs if is_full_attention_layer else linear_attention_pairs
    return norm_pairs + attention_pairs + mlp_pairs

  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  if scan_layers:
    for block_idx in range(cycle):
      layer_ids = list(range(block_idx, num_main_layers, cycle))
      prefix = f"params-decoder-layers-layer_{block_idx}"
      for mt_suffix, hf_suffix in per_layer_pairs(block_idx):
        mapping[f"{prefix}-{mt_suffix}"] = [f"model.layers.{i}.{hf_suffix}" for i in layer_ids]
      # Nested as [expert][layer] to follow the (experts, layers, ...) stacking.
      for mt_suffix, hf_suffix in expert_pairs:
        mapping[f"{prefix}-{mt_suffix}"] = [
            [f"model.layers.{i}.mlp.experts.{e}.{hf_suffix}" for i in layer_ids] for e in range(num_experts)
        ]
  else:
    for i in range(num_main_layers):
      prefix = f"params-decoder-layers_{i}"
      for mt_suffix, hf_suffix in per_layer_pairs(i % cycle):
        mapping[f"{prefix}-{mt_suffix}"] = f"model.layers.{i}.{hf_suffix}"
      for mt_suffix, hf_suffix in expert_pairs:
        mapping[f"{prefix}-{mt_suffix}"] = [
            f"model.layers.{i}.mlp.experts.{e}.{hf_suffix}" for e in range(num_experts)
        ]
  return mapping
def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Returns Qwen3-Next transformation hooks keyed by hyphenated 'params-' MaxText paths.

  Args:
    config (dict): Hugging Face config with 'num_hidden_layers'.
    maxtext_config: MaxText config exposing `inhomogeneous_layer_cycle_interval`.
    scan_layers (bool, optional): Register hooks per cycle position
      ('params-decoder-layers-layer_{k}') instead of per layer.
    saving_to_hf (bool, optional): Direction used by the attention kernel
      reshape. True for MaxText -> HF.

  Returns:
    dict: MaxText key -> callable(input_tensor, target_shape).
  """

  def transpose(input_tensor, target_shape=None):
    """Reverses all axes (`.T`); `target_shape` is ignored."""
    return input_tensor.T

  def reshape_kernel(input_tensor, target_shape):
    """Reshapes and transposes full-attention kernels between MaxText and HF."""
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    reversed_shape = np.flip(np.array(target_shape))
    return input_tensor.reshape(reversed_shape).T

  def permute_conv(input_tensor, target_shape=None):
    """Swaps axes 0 and 2: MT [K, 1, C] <-> HF [C, 1, K]."""
    return input_tensor.transpose(2, 1, 0)

  # Hooks for the linear-attention positions of the cycle, in registration order.
  linear_attention_hooks = (
      ("in_proj_qkvz-kernel", transpose),
      ("in_proj_ba-kernel", transpose),
      ("out_proj-kernel", transpose),
      ("conv1d-kernel", permute_conv),
  )
  # MLP params (under "<prefix>-mlp-") that are all plain transposes.
  mlp_suffixes = (
      "routed_experts-gate-kernel",
      "shared_expert-wi_0-kernel",
      "shared_expert-wi_1-kernel",
      "shared_expert-wo-kernel",
      "shared_expert_gate-kernel",
      "routed_experts-wi_0",
      "routed_experts-wi_1",
      "routed_experts-wo",
  )

  hooks = {"params-decoder-logits_dense-kernel": transpose}

  cycle = maxtext_config.inhomogeneous_layer_cycle_interval
  num_main_layers = config["num_hidden_layers"]

  # (MaxText key prefix, position within the attention cycle) for every hooked layer.
  if scan_layers:
    layer_specs = [(f"params-decoder-layers-layer_{k}", k) for k in range(cycle)]
  else:
    layer_specs = [(f"params-decoder-layers_{k}", k % cycle) for k in range(num_main_layers)]

  for prefix, block_idx in layer_specs:
    if (block_idx + 1) % cycle == 0:
      # Last position of the cycle: full self-attention.
      for proj in ("query", "key", "value", "out"):
        hooks[f"{prefix}-attention-attention-{proj}-kernel"] = reshape_kernel
    else:
      for suffix, hook in linear_attention_hooks:
        hooks[f"{prefix}-attention-{suffix}"] = hook
    for suffix in mlp_suffixes:
      hooks[f"{prefix}-mlp-{suffix}"] = transpose
  return hooks
def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Generates a parameter mapping from MaxText to HuggingFace Deepseek weight paths.

  The first `first_k_dense_replace` HF layers are dense ("dense_layers" in
  MaxText). The remaining `num_hidden_layers` layers are MoE ("moe_layers").
  Attention keys for the v2, v3 and v3.2 variants are all emitted.

  Args:
    config (dict): Hugging Face config with 'num_hidden_layers',
      'first_k_dense_replace' and optionally 'n_routed_experts'.
    maxtext_config: Unused; kept for signature parity with other mapping builders.
    scan_layers (bool, optional): Whether MaxText layers are scanned.

  Returns:
    dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names).
    Values are Hugging Face parameter names in one of four forms: unscanned (string),
    scanned (list of strings), unscanned with expert stacking (list of strings),
    or scanned with expert stacking (nested list of strings).
  """
  # Hugging Face configuration values (MTP layers excluded).
  num_main_layers = config["num_hidden_layers"]
  first_num_dense_layers = config["first_k_dense_replace"]
  num_experts = config.get("n_routed_experts", 0)

  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  # Attention/norm params present in every decoder layer, dense or MoE.
  attention_keys = {
      "pre_self_attention_layer_norm-scale": "input_layernorm.weight",
      "post_self_attention_layer_norm-scale": "post_attention_layernorm.weight",
      "self_attention-kv_norm-scale": "self_attn.kv_a_layernorm.weight",
      "self_attention-wkv_a-kernel": "self_attn.kv_a_proj_with_mqa.weight",
      "self_attention-wkv_b-kernel": "self_attn.kv_b_proj.weight",
      "self_attention-out-kernel": "self_attn.o_proj.weight",
      # v2
      "self_attention-query-kernel": "self_attn.q_proj.weight",
      # v3
      "self_attention-q_norm-scale": "self_attn.q_a_layernorm.weight",
      "self_attention-wq_a-kernel": "self_attn.q_a_proj.weight",
      "self_attention-wq_b-kernel": "self_attn.q_b_proj.weight",
      # v3.2
      "self_attention-indexer-k_norm-bias": "self_attn.indexer.k_norm.bias",
      "self_attention-indexer-k_norm-scale": "self_attn.indexer.k_norm.weight",
      "self_attention-indexer-weights_proj-kernel": "self_attn.indexer.weights_proj.weight",
      "self_attention-indexer-wk-kernel": "self_attn.indexer.wk.weight",
      "self_attention-indexer-wq_b-kernel": "self_attn.indexer.wq_b.weight",
  }
  dense_layer_keys = {
      **attention_keys,
      "mlp-wi_0-kernel": "mlp.gate_proj.weight",
      "mlp-wi_1-kernel": "mlp.up_proj.weight",
      "mlp-wo-kernel": "mlp.down_proj.weight",
  }
  moe_layer_keys = {
      **attention_keys,
      "DeepSeekMoeBlock_0-shared_experts-wi_0-kernel": "mlp.shared_experts.gate_proj.weight",
      "DeepSeekMoeBlock_0-shared_experts-wi_1-kernel": "mlp.shared_experts.up_proj.weight",
      "DeepSeekMoeBlock_0-shared_experts-wo-kernel": "mlp.shared_experts.down_proj.weight",
      "DeepSeekMoeBlock_0-MoeBlock_0-gate-kernel": "mlp.gate.weight",
      # v3
      "DeepSeekMoeBlock_0-MoeBlock_0-gate-bias": "mlp.gate.e_score_correction_bias",
  }
  # Routed experts; HF path is model.layers.{i}.mlp.experts.{e}.<suffix>.
  moe_expert_keys = {
      "DeepSeekMoeBlock_0-MoeBlock_0-wi_0": "gate_proj.weight",
      "DeepSeekMoeBlock_0-MoeBlock_0-wi_1": "up_proj.weight",
      "DeepSeekMoeBlock_0-MoeBlock_0-wo": "down_proj.weight",
  }

  dense_layer_ids = range(first_num_dense_layers)
  moe_layer_ids = range(first_num_dense_layers, num_main_layers)

  if scan_layers:
    mapping.update({
        f"params-decoder-dense_layers-{mt_key}": [f"model.layers.{i}.{hf_key}" for i in dense_layer_ids]
        for mt_key, hf_key in dense_layer_keys.items()
    })
    mapping.update({
        f"params-decoder-moe_layers-{mt_key}": [f"model.layers.{i}.{hf_key}" for i in moe_layer_ids]
        for mt_key, hf_key in moe_layer_keys.items()
    })
    # Nested as [expert][layer]: [[e0_l0, e0_l1..], [e1_l0, e1_l1..]..]
    mapping.update({
        f"params-decoder-moe_layers-{mt_key}": [
            [f"model.layers.{i}.mlp.experts.{e}.{hf_key}" for i in moe_layer_ids] for e in range(num_experts)
        ]
        for mt_key, hf_key in moe_expert_keys.items()
    })
    return mapping

  for i in dense_layer_ids:
    mapping.update({
        f"params-decoder-dense_layers_{i}-{mt_key}": f"model.layers.{i}.{hf_key}"
        for mt_key, hf_key in dense_layer_keys.items()
    })
  # MaxText numbers MoE layers from 0, offset from the HF layer index.
  for moe_idx, i in enumerate(moe_layer_ids):
    mapping.update({
        f"params-decoder-moe_layers_{moe_idx}-{mt_key}": f"model.layers.{i}.{hf_key}"
        for mt_key, hf_key in moe_layer_keys.items()
    })
    mapping.update({
        f"params-decoder-moe_layers_{moe_idx}-{mt_key}": [
            f"model.layers.{i}.mlp.experts.{e}.{hf_key}" for e in range(num_experts)
        ]
        for mt_key, hf_key in moe_expert_keys.items()
    })
  return mapping
def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Deepseek.

  Every registered hook is the same `reshape_kernel` function. Keys cover the
  v2, v3 and v3.2 attention variants plus the dense and MoE MLP kernels.

  Args:
    config (dict): Hugging Face config with 'num_hidden_layers' and
      'first_k_dense_replace'.
    maxtext_config: Unused; kept for signature parity with other hook builders.
    scan_layers (bool, optional): Register hooks for stacked
      'dense_layers'/'moe_layers' keys instead of per-layer keys.
    saving_to_hf (bool, optional): True for MaxText -> HF, False for the reverse.

  Returns:
    dict: MaxText key -> callable(input_tensor, target_shape). Keys are
      inserted in a fixed, deterministic order.
  """

  def reshape_kernel(input_tensor, target_shape):
    """Reshapes and transposes kernel weights between MaxText and HF."""
    if saving_to_hf:
      flipped_target_shape = np.flip(np.array(target_shape))
      return input_tensor.reshape(flipped_target_shape).T
    return input_tensor.T.reshape(target_shape)

  num_main_layers = config["num_hidden_layers"]
  first_num_dense_layers = config["first_k_dense_replace"]

  mapping = {
      "params-decoder-logits_dense-kernel": reshape_kernel,
  }

  # Ordered tuples rather than sets. Iteration order over a set of str depends
  # on per-process hash randomization, which made the hook dict's order
  # differ between runs.
  attention_need_reshape = (
      "self_attention-wkv_a-kernel",  # transpose
      "self_attention-wkv_b-kernel",
      "self_attention-out-kernel",
      # v2
      "self_attention-query-kernel",
      # v3
      "self_attention-wq_a-kernel",  # transpose
      "self_attention-wq_b-kernel",
      # v3.2
      "self_attention-indexer-weights_proj-kernel",  # transpose
      "self_attention-indexer-wk-kernel",  # transpose
      "self_attention-indexer-wq_b-kernel",
  )
  dense_need_reshape = attention_need_reshape + (
      "mlp-wi_0-kernel",  # transpose
      "mlp-wi_1-kernel",  # transpose
      "mlp-wo-kernel",  # transpose
  )
  moe_need_reshape = attention_need_reshape + (
      "DeepSeekMoeBlock_0-shared_experts-wi_0-kernel",  # transpose
      "DeepSeekMoeBlock_0-shared_experts-wi_1-kernel",  # transpose
      "DeepSeekMoeBlock_0-shared_experts-wo-kernel",  # transpose
      "DeepSeekMoeBlock_0-MoeBlock_0-gate-kernel",  # transpose
      "DeepSeekMoeBlock_0-MoeBlock_0-wi_0",  # transpose
      "DeepSeekMoeBlock_0-MoeBlock_0-wi_1",  # transpose
      "DeepSeekMoeBlock_0-MoeBlock_0-wo",  # transpose
  )

  if scan_layers:
    for key in dense_need_reshape:
      mapping[f"params-decoder-dense_layers-{key}"] = reshape_kernel
    for key in moe_need_reshape:
      mapping[f"params-decoder-moe_layers-{key}"] = reshape_kernel
  else:
    for i in range(first_num_dense_layers):
      for key in dense_need_reshape:
        mapping[f"params-decoder-dense_layers_{i}-{key}"] = reshape_kernel
    # MaxText numbers MoE layers from 0, offset from the HF layer index.
    for i in range(first_num_dense_layers, num_main_layers):
      moe_layer_idx = i - first_num_dense_layers
      for key in moe_need_reshape:
        mapping[f"params-decoder-moe_layers_{moe_layer_idx}-{key}"] = reshape_kernel
  return mapping
def DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN():
  """Returns the hook table for DeepSeek NNX -> vLLM conversion.

  No hooks are registered for this path: the result is always a new, empty
  dict. Callers may mutate it without affecting later calls.
  """
  hooks = dict()
  return hooks
def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Generates mapping from MaxText gpt-oss to Hugging Face weight paths.

  Args:
    config (dict): Hugging Face config with 'num_hidden_layers'.
    maxtext_config: MaxText config exposing `inhomogeneous_layer_cycle_interval`.
    scan_layers (bool, optional): If True, key by scan block
      ('params-decoder-layers-layers_{k}') and map to lists of HF layers.

  Returns:
    dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter) or
    `composite_mt_key` (a tuple of MaxText parameters). Values are Hugging Face
    parameter names, either a single string (unscanned form) or a list of
    strings (scanned form).

  Notes:
    - Handles the inhomogeneous scan block structure, based on `inhomogeneous_layer_cycle_interval`.
      Scan block k stacks HF layers k, k + interval, k + 2 * interval, ...
    - Handles `composite_mt_key`: multiple MaxText keys map to HF key(s)
      - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj
      - (GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias): mlp.experts.gate_up_proj_bias
  """
  n_layers = config["num_hidden_layers"]  # hf config
  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval

  # Per-layer (MaxText suffix, HF suffix) pairs, in registration order. A tuple
  # MaxText suffix denotes a `composite_mt_key`.
  layer_pairs = (
      # Layer Norms
      ("pre_self_attention_layer_norm-scale", "input_layernorm.weight"),
      ("post_self_attention_layer_norm-scale", "post_attention_layernorm.weight"),
      # GptOssAttention
      ("GptOssAttention-query-kernel", "self_attn.q_proj.weight"),
      ("GptOssAttention-query-bias", "self_attn.q_proj.bias"),
      ("GptOssAttention-key-kernel", "self_attn.k_proj.weight"),
      ("GptOssAttention-key-bias", "self_attn.k_proj.bias"),
      ("GptOssAttention-value-kernel", "self_attn.v_proj.weight"),
      ("GptOssAttention-value-bias", "self_attn.v_proj.bias"),
      ("GptOssAttention-out-kernel", "self_attn.o_proj.weight"),
      ("GptOssAttention-out-bias", "self_attn.o_proj.bias"),
      ("GptOssAttention-sinks", "self_attn.sinks"),
      # GptOssMlp: gate/router
      ("GptOssMlp-gate-kernel", "mlp.router.weight"),
      ("GptOssMlp-gate-bias", "mlp.router.bias"),
      # GptOssMlp: experts down projection
      ("GptOssMlp-wo", "mlp.experts.down_proj"),
      ("GptOssMlp-wo_bias", "mlp.experts.down_proj_bias"),
      # GptOssMlp: experts fused gate/up projection (composite keys)
      (("GptOssMlp-wi_0", "GptOssMlp-wi_1"), "mlp.experts.gate_up_proj"),
      (("GptOssMlp-wi_0_bias", "GptOssMlp-wi_1_bias"), "mlp.experts.gate_up_proj_bias"),
  )

  def mt_key(prefix, suffix):
    """Prefixes an atomic suffix, or every member of a composite suffix tuple."""
    if isinstance(suffix, tuple):
      return tuple(f"{prefix}-{part}" for part in suffix)
    return f"{prefix}-{suffix}"

  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  if scan_layers:
    for block_idx in range(layer_cycle_interval):
      # HF layers that collapse into this scan block.
      hf_indices = range(block_idx, n_layers, layer_cycle_interval)
      prefix = f"params-decoder-layers-layers_{block_idx}"
      for suffix, hf_suffix in layer_pairs:
        mapping[mt_key(prefix, suffix)] = [f"model.layers.{i}.{hf_suffix}" for i in hf_indices]
  else:
    for i in range(n_layers):
      prefix = f"params-decoder-layers_{i}"
      for suffix, hf_suffix in layer_pairs:
        mapping[mt_key(prefix, suffix)] = f"model.layers.{i}.{hf_suffix}"
  return mapping
def GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Transformation hooks for gpt-oss parameters.

  Args:
    config (dict): Hugging Face config with 'num_hidden_layers'.
    maxtext_config: MaxText config exposing `inhomogeneous_layer_cycle_interval`.
    scan_layers (bool, optional): Register hooks per scan block instead of per layer.
    saving_to_hf (bool, optional): True for MaxText -> HF, False for the reverse.

  Returns:
    dict: MaxText key (str, or tuple for composite keys) ->
      callable(input_tensor, target_shape).

  Notes:
    - Handles the inhomogeneous scan block structure (inhomogeneous_layer_cycle_interval)
    - Handles `composite_mt_key` where multiple MaxText keys map to HF key(s)
      - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj
      - (GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias): mlp.experts.gate_up_proj_bias
    - The composite keys are transformed via the `interleave` hook.
  """

  def transpose(input_tensor, target_shape=None):
    """Reverses all axes (`.T`); `target_shape` is ignored."""
    return input_tensor.T

  def reshape_kernel(input_tensor, target_shape):
    """Reshapes and transposes kernel weights between MaxText and HF."""
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    return input_tensor.reshape(np.flip(np.array(target_shape))).T

  def reshape_bias(input_tensor, target_shape=None):
    """Reshapes a bias to `target_shape`.

    The same reshape serves both directions:
    MaxText [heads, head_dim] <-> HF [hidden_dim].
    """
    return input_tensor.reshape(target_shape)

  def interleave(input_tensor, target_shape=None):
    """Converts between the MaxText pair (wi_0, wi_1) and the fused HF wi_0_1.

    - MaxText -> HF: `input_tensor` is a 2-sequence ordered like the composite
      key. Even columns of the last axis take wi_0, odd columns take wi_1.
    - HF -> MaxText: `input_tensor` is the fused tensor. Even/odd columns are
      split and stacked on a new trailing axis (index 0 = wi_0, 1 = wi_1).
    """
    if not saving_to_hf:
      gate = input_tensor[..., 0::2]
      up = input_tensor[..., 1::2]
      return np.stack([gate, up], axis=-1)
    gate, up = input_tensor
    fused = np.empty(target_shape, dtype=gate.dtype)
    fused[..., 0::2] = gate
    fused[..., 1::2] = up
    return fused

  n_layers = config["num_hidden_layers"]  # hf config
  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval

  if scan_layers:
    prefixes = [f"params-decoder-layers-layers_{idx}" for idx in range(layer_cycle_interval)]
  else:
    prefixes = [f"params-decoder-layers_{idx}" for idx in range(n_layers)]

  hooks = {"params-decoder-logits_dense-kernel": transpose}
  for prefix in prefixes:
    # Attention kernels & biases
    for proj in ("query", "key", "value"):
      hooks[f"{prefix}-GptOssAttention-{proj}-kernel"] = reshape_kernel
      hooks[f"{prefix}-GptOssAttention-{proj}-bias"] = reshape_bias
    hooks[f"{prefix}-GptOssAttention-out-kernel"] = reshape_kernel
    # MLP router and fused expert projections (composite keys)
    hooks[f"{prefix}-GptOssMlp-gate-kernel"] = transpose
    hooks[(f"{prefix}-GptOssMlp-wi_0", f"{prefix}-GptOssMlp-wi_1")] = interleave
    hooks[(f"{prefix}-GptOssMlp-wi_0_bias", f"{prefix}-GptOssMlp-wi_1_bias")] = interleave
  return hooks
def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Qwen3-Omni weight paths.

  This function combines mappings from different modalities (text, vision,
  audio) into a unified parameter mapping for the multi-modal Qwen3-Omni
  model. Every HF path is under the "thinker." namespace. Only the text
  decoder honors `scan_layers`. Vision blocks and audio encoder layers are
  always mapped per layer.

  Args:
    config (dict): Hugging Face model configuration. Reads
      `thinker_config.text_config` ('num_hidden_layers', optional
      'num_experts'), `thinker_config.vision_config` ('depth', optional
      'deepstack_visual_indexes') and `thinker_config.audio_config`
      ('encoder_layers').
    maxtext_config: MaxText configuration object, forwarded to the text
      mapping builder.
    scan_layers (bool, optional): Whether the text decoder uses scanned
      layers. Defaults to False.

  Returns:
    dict: Combined mapping from all modalities.

  Note:
    The vision Q/K/V params all map to the same fused HF `attn.qkv` tensor.
    Splitting/fusing them is left to the corresponding hook functions.
  """
  mapping = {}

  # Text decoder: reuse the Qwen3-MoE mapping, then prefix every HF path.
  text_config = config["thinker_config"]["text_config"]
  num_experts_text = text_config.get("num_experts", 0)
  n_layers_text = text_config["num_hidden_layers"]
  text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
      config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
      maxtext_config=maxtext_config,
      scan_layers=scan_layers,
  )

  def add_prefix_recursive(value):
    """Recursively add 'thinker.' prefix to strings, handling nested lists."""
    if isinstance(value, list):
      return [add_prefix_recursive(v) for v in value]
    return f"thinker.{value}"

  # Build the prefixed copy without mutating the dict returned above.
  mapping.update({key: add_prefix_recursive(value) for key, value in text_mapping.items()})

  # Vision mapping
  vision_config = config["thinker_config"]["vision_config"]
  n_vision_layers = vision_config["depth"]

  # Vision patch embedding
  mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-kernel"] = (
      "thinker.visual.patch_embed.proj.weight"
  )
  mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-patch_embed-proj-bias"] = (
      "thinker.visual.patch_embed.proj.bias"
  )
  # Vision positional embedding
  mapping["params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-pos_embed_interpolate-pos_embed"] = (
      "thinker.visual.pos_embed.weight"
  )

  # Vision blocks; the count comes from vision_config["depth"].
  for i in range(n_vision_layers):
    prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-blocks_{i}"
    hf_prefix = f"thinker.visual.blocks.{i}"
    # Layer norms
    mapping[f"{prefix}-ln1-scale"] = f"{hf_prefix}.norm1.weight"
    mapping[f"{prefix}-ln1-bias"] = f"{hf_prefix}.norm1.bias"
    mapping[f"{prefix}-ln2-scale"] = f"{hf_prefix}.norm2.weight"
    mapping[f"{prefix}-ln2-bias"] = f"{hf_prefix}.norm2.bias"
    # Attention: HF has a fused QKV tensor, while MaxText has separate Q/K/V.
    # All three point at the same HF tensor; the hooks do the split/fusion.
    mapping[f"{prefix}-attn-attn-query-kernel"] = f"{hf_prefix}.attn.qkv.weight"
    mapping[f"{prefix}-attn-attn-query-bias"] = f"{hf_prefix}.attn.qkv.bias"
    mapping[f"{prefix}-attn-attn-key-kernel"] = f"{hf_prefix}.attn.qkv.weight"
    mapping[f"{prefix}-attn-attn-key-bias"] = f"{hf_prefix}.attn.qkv.bias"
    mapping[f"{prefix}-attn-attn-value-kernel"] = f"{hf_prefix}.attn.qkv.weight"
    mapping[f"{prefix}-attn-attn-value-bias"] = f"{hf_prefix}.attn.qkv.bias"
    mapping[f"{prefix}-attn-attn-out-kernel"] = f"{hf_prefix}.attn.proj.weight"
    mapping[f"{prefix}-attn-attn-out-bias"] = f"{hf_prefix}.attn.proj.bias"
    # MLP
    mapping[f"{prefix}-mlp-kernel"] = f"{hf_prefix}.mlp.linear_fc1.weight"
    mapping[f"{prefix}-mlp-bias"] = f"{hf_prefix}.mlp.linear_fc1.bias"
    mapping[f"{prefix}-mlp_out-kernel"] = f"{hf_prefix}.mlp.linear_fc2.weight"
    mapping[f"{prefix}-mlp_out-bias"] = f"{hf_prefix}.mlp.linear_fc2.bias"

  # Vision deepstack mergers: one per entry of deepstack_visual_indexes
  # (default [8, 16, 24]). Only the number of entries is used here.
  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [8, 16, 24])
  for merger_idx, _ in enumerate(deepstack_indexes):
    prefix = f"params-vision_encoder-Qwen3OmniMoeVisionEncoder_0-merger_{merger_idx}"
    hf_prefix = f"thinker.visual.merger_list.{merger_idx}"
    mapping[f"{prefix}-ln_q-scale"] = f"{hf_prefix}.ln_q.weight"
    mapping[f"{prefix}-ln_q-bias"] = f"{hf_prefix}.ln_q.bias"
    mapping[f"{prefix}-mlp_0-kernel"] = f"{hf_prefix}.mlp.0.weight"
    mapping[f"{prefix}-mlp_0-bias"] = f"{hf_prefix}.mlp.0.bias"
    mapping[f"{prefix}-mlp_2-kernel"] = f"{hf_prefix}.mlp.2.weight"
    mapping[f"{prefix}-mlp_2-bias"] = f"{hf_prefix}.mlp.2.bias"

  # Vision projector (final merger)
  mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-ln_q-scale"] = "thinker.visual.merger.ln_q.weight"
  mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-ln_q-bias"] = "thinker.visual.merger.ln_q.bias"
  mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-kernel"] = (
      "thinker.visual.merger.mlp.0.weight"
  )
  mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_0-bias"] = "thinker.visual.merger.mlp.0.bias"
  mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-kernel"] = (
      "thinker.visual.merger.mlp.2.weight"
  )
  mapping["params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger-mlp_2-bias"] = "thinker.visual.merger.mlp.2.bias"

  # Audio mapping
  audio_config = config["thinker_config"]["audio_config"]
  n_audio_layers = audio_config["encoder_layers"]

  # Audio conv layers (conv2d1, conv2d2, conv2d3)
  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d1-kernel"] = "thinker.audio_tower.conv2d1.weight"
  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d1-bias"] = "thinker.audio_tower.conv2d1.bias"
  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d2-kernel"] = "thinker.audio_tower.conv2d2.weight"
  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d2-bias"] = "thinker.audio_tower.conv2d2.bias"
  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d3-kernel"] = "thinker.audio_tower.conv2d3.weight"
  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv2d3-bias"] = "thinker.audio_tower.conv2d3.bias"
  # Audio conv output projection (kernel only; no bias is mapped)
  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-conv_out-kernel"] = "thinker.audio_tower.conv_out.weight"

  # Audio encoder layers; the count comes from audio_config["encoder_layers"].
  for i in range(n_audio_layers):
    prefix = f"params-audio_encoder-Qwen3OmniAudioEncoder_0-layers_{i}"
    hf_prefix = f"thinker.audio_tower.layers.{i}"
    # Layer norms
    mapping[f"{prefix}-input_layer_norm-scale"] = f"{hf_prefix}.self_attn_layer_norm.weight"
    mapping[f"{prefix}-input_layer_norm-bias"] = f"{hf_prefix}.self_attn_layer_norm.bias"
    mapping[f"{prefix}-post_attention_layer_norm-scale"] = f"{hf_prefix}.final_layer_norm.weight"
    mapping[f"{prefix}-post_attention_layer_norm-bias"] = f"{hf_prefix}.final_layer_norm.bias"
    # Attention (separate Q/K/V on both sides)
    mapping[f"{prefix}-self_attention_audio-query-kernel"] = f"{hf_prefix}.self_attn.q_proj.weight"
    mapping[f"{prefix}-self_attention_audio-query-bias"] = f"{hf_prefix}.self_attn.q_proj.bias"
    mapping[f"{prefix}-self_attention_audio-key-kernel"] = f"{hf_prefix}.self_attn.k_proj.weight"
    mapping[f"{prefix}-self_attention_audio-key-bias"] = f"{hf_prefix}.self_attn.k_proj.bias"
    mapping[f"{prefix}-self_attention_audio-value-kernel"] = f"{hf_prefix}.self_attn.v_proj.weight"
    mapping[f"{prefix}-self_attention_audio-value-bias"] = f"{hf_prefix}.self_attn.v_proj.bias"
    mapping[f"{prefix}-self_attention_audio-out-kernel"] = f"{hf_prefix}.self_attn.out_proj.weight"
    mapping[f"{prefix}-self_attention_audio-out-bias"] = f"{hf_prefix}.self_attn.out_proj.bias"
    # MLP (AudioMLP has 2 linear layers: fc1 and fc2)
    mapping[f"{prefix}-AudioMLP-wi-kernel"] = f"{hf_prefix}.fc1.weight"
    mapping[f"{prefix}-AudioMLP-wi-bias"] = f"{hf_prefix}.fc1.bias"
    mapping[f"{prefix}-AudioMLP-wo-kernel"] = f"{hf_prefix}.fc2.weight"
    mapping[f"{prefix}-AudioMLP-wo-bias"] = f"{hf_prefix}.fc2.bias"

  # Audio post layer norm
  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-layernorm_post-scale"] = "thinker.audio_tower.ln_post.weight"
  mapping["params-audio_encoder-Qwen3OmniAudioEncoder_0-layernorm_post-bias"] = "thinker.audio_tower.ln_post.bias"

  # Audio projector (2 linear layers)
  mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj1-kernel"] = "thinker.audio_tower.proj1.weight"
  mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj1-bias"] = "thinker.audio_tower.proj1.bias"
  mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj2-kernel"] = "thinker.audio_tower.proj2.weight"
  mapping["params-audio_encoder-Qwen3OmniAudioProjector_0-proj2-bias"] = "thinker.audio_tower.proj2.bias"

  return mapping
def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Qwen3-Omni.

  Builds a dictionary of transformation functions (hooks) for converting
  Qwen3-Omni parameters between MaxText and Hugging Face formats:

  - Text (thinker LLM) hooks are delegated to `QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN`.
  - Vision hooks transpose/reshape dense kernels, re-lay out the 3D patch
    embedding convolution, and split Hugging Face's fused QKV projection.
  - Audio hooks transpose/reshape dense kernels and attention projections and
    re-lay out the 2D convolutions.

  Args:
    config (dict): Hugging Face model configuration. Reads
      `thinker_config.text_config` (`num_hidden_layers`, optional
      `num_experts`), `thinker_config.vision_config` (`depth`, `hidden_size`,
      optional `deepstack_visual_indexes`) and `thinker_config.audio_config`
      (`encoder_layers`, `d_model`).
    maxtext_config: MaxText configuration, forwarded to the text hook builder.
    scan_layers (bool, optional): Forwarded to the text hook builder. Vision
      and audio keys are always per-layer. Defaults to False.
    saving_to_hf (bool, optional): The direction of conversion. True for
      MaxText to Hugging Face, False for the reverse. Defaults to False.

  Returns:
    dict: A dictionary mapping MaxText parameter names to transformation
    functions with the signature `hook(input_tensor, target_shape)`.

  Note:
    The vision QKV split hooks only implement Hugging Face -> MaxText. When
    `saving_to_hf` is True they raise `NotImplementedError` if invoked.
  """
  mapping = {}

  # ---- Text: reuse the Qwen3 (MoE) hook builder on the thinker text config.
  text_config = config["thinker_config"]["text_config"]
  num_experts_text = text_config.get("num_experts", 0)
  n_layers_text = text_config["num_hidden_layers"]
  mapping.update(
      QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
          config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
          maxtext_config=maxtext_config,
          scan_layers=scan_layers,
          saving_to_hf=saving_to_hf,
      )
  )

  # ---- Shared hooks (used by both the vision and the audio towers).
  def reshape_kernel(input_tensor, target_shape):
    """Converts a dense kernel between MaxText and HF layouts via transpose + reshape."""
    if saving_to_hf:
      flipped_target_shape = np.flip(np.array(target_shape))
      return input_tensor.reshape(flipped_target_shape).T
    return input_tensor.T.reshape(target_shape)

  def make_square_attn_proj_hook(dim):
    """Builds a hook for an attention projection stored as (dim, dim) in HF.

    MaxText Q/K/V kernels are (dim, num_heads, head_dim) and output kernels are
    (num_heads, head_dim, dim); both flatten row-major to (dim, dim), so one
    transpose-and-reshape covers every case.
    """

    def attn_proj_hook(input_tensor, target_shape):
      if saving_to_hf:
        return input_tensor.reshape(dim, dim).T
      return input_tensor.T.reshape(target_shape)

    return attn_proj_hook

  # ---- Vision tower.
  vision_config = config["thinker_config"]["vision_config"]
  n_vision_layers = vision_config["depth"]
  hidden_size = vision_config["hidden_size"]

  def reshape_conv3d_patch_embed(input_tensor, target_shape):
    """Re-lays out the 3D conv patch embedding weight.

    HF: (out_channels, in_channels, temporal, height, width)
    MaxText: (temporal, height, width, in_channels, out_channels)
    """
    if saving_to_hf:
      return input_tensor.transpose(4, 3, 0, 1, 2)
    return input_tensor.transpose(2, 3, 4, 1, 0)

  def make_qkv_split_hook(slot, is_bias):
    """Builds a hook extracting Q (slot 0), K (slot 1) or V (slot 2) from HF's fused QKV.

    The fused tensor stacks Q, K and V along axis 0 in `hidden_size` chunks.
    Kernels are transposed before reshaping into `target_shape`; biases are
    reshaped directly.
    """
    start = slot * hidden_size
    # V keeps the open-ended slice so an unexpectedly sized tensor still fails
    # loudly in the final reshape instead of being truncated.
    stop = None if slot == 2 else start + hidden_size

    def split_hook(input_tensor, target_shape):
      if saving_to_hf:
        # Re-fusing three MaxText tensors into one HF tensor cannot be done by
        # a hook that only sees a single input tensor.
        raise NotImplementedError("Use fusion hook for MaxText->HF")
      chunk = input_tensor[start:stop]
      if is_bias:
        return chunk.reshape(target_shape)
      return chunk.T.reshape(target_shape)

    return split_hook

  vision_prefix = "params-vision_encoder-Qwen3OmniMoeVisionEncoder_0"
  mapping[f"{vision_prefix}-patch_embed-proj-kernel"] = reshape_conv3d_patch_embed

  # Key order per block matches: query-kernel, query-bias, key-kernel, ...
  qkv_split_hooks = {
      f"attn-attn-{proj}-{kind}": make_qkv_split_hook(slot, kind == "bias")
      for slot, proj in enumerate(("query", "key", "value"))
      for kind in ("kernel", "bias")
  }
  vision_attn_out = make_square_attn_proj_hook(hidden_size)

  for i in range(n_vision_layers):
    prefix = f"{vision_prefix}-blocks_{i}"
    for suffix, hook in qkv_split_hooks.items():
      mapping[f"{prefix}-{suffix}"] = hook
    mapping[f"{prefix}-attn-attn-out-kernel"] = vision_attn_out
    # attn-attn-out-bias needs no hook (no reshape required).
    mapping[f"{prefix}-mlp-kernel"] = reshape_kernel
    mapping[f"{prefix}-mlp_out-kernel"] = reshape_kernel

  # Deep-stack mergers: one per entry of deepstack_visual_indexes.
  deepstack_indexes = vision_config.get("deepstack_visual_indexes", [8, 16, 24])
  for merger_idx in range(len(deepstack_indexes)):
    prefix = f"{vision_prefix}-merger_{merger_idx}"
    mapping[f"{prefix}-mlp_0-kernel"] = reshape_kernel
    mapping[f"{prefix}-mlp_2-kernel"] = reshape_kernel

  # Vision projector (final merger).
  vision_projector_prefix = "params-vision_encoder-Qwen3OmniMoeVisionProjector_0-merger"
  mapping[f"{vision_projector_prefix}-mlp_0-kernel"] = reshape_kernel
  mapping[f"{vision_projector_prefix}-mlp_2-kernel"] = reshape_kernel

  # ---- Audio tower.
  audio_config = config["thinker_config"]["audio_config"]
  n_audio_layers = audio_config["encoder_layers"]
  hidden_size_audio = audio_config["d_model"]

  def reshape_conv2d_audio(input_tensor, target_shape):
    """Re-lays out an audio Conv2D weight.

    HF: (out_channels, in_channels, height, width)
    MaxText: (height, width, in_channels, out_channels)
    """
    if saving_to_hf:
      return input_tensor.transpose(3, 2, 0, 1)
    return input_tensor.transpose(2, 3, 1, 0)

  def reshape_audio_attn_qkv_bias(input_tensor, target_shape):
    """Converts an audio Q/K/V bias: HF (hidden_size,) <-> MaxText (num_heads, head_dim)."""
    if saving_to_hf:
      return input_tensor.reshape(hidden_size_audio)
    return input_tensor.reshape(target_shape)

  audio_attn_proj = make_square_attn_proj_hook(hidden_size_audio)
  audio_prefix = "params-audio_encoder-Qwen3OmniAudioEncoder_0"

  for conv_name in ("conv2d1", "conv2d2", "conv2d3"):
    mapping[f"{audio_prefix}-{conv_name}-kernel"] = reshape_conv2d_audio
  mapping[f"{audio_prefix}-conv_out-kernel"] = reshape_kernel

  for i in range(n_audio_layers):
    prefix = f"{audio_prefix}-layers_{i}"
    for proj in ("query", "key", "value"):
      mapping[f"{prefix}-self_attention_audio-{proj}-kernel"] = audio_attn_proj
      mapping[f"{prefix}-self_attention_audio-{proj}-bias"] = reshape_audio_attn_qkv_bias
    mapping[f"{prefix}-self_attention_audio-out-kernel"] = audio_attn_proj
    mapping[f"{prefix}-AudioMLP-wi-kernel"] = reshape_kernel
    mapping[f"{prefix}-AudioMLP-wo-kernel"] = reshape_kernel

  # Audio projector.
  audio_projector_prefix = "params-audio_encoder-Qwen3OmniAudioProjector_0"
  mapping[f"{audio_projector_prefix}-proj1-kernel"] = reshape_kernel
  mapping[f"{audio_projector_prefix}-proj2-kernel"] = reshape_kernel

  return mapping
def QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN(target_shape=None):
  """Returns the NNX -> vLLM parameter hooks for Qwen3.

  No Qwen3 parameter is given a transformation here, so the result is always
  an empty dictionary (a new object on every call).

  Args:
    target_shape: Not used by this function.

  Returns:
    dict: An empty mapping of NNX parameter names to transformation functions.
  """
  del target_shape  # Unused.
  return dict()
def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Maps MaxText parameter names to Hugging Face Llama 3.1 parameter names.

  Args:
    config (dict): Hugging Face model configuration; only `num_hidden_layers`
      is read.
    maxtext_config: MaxText configuration (not used by this mapping).
    scan_layers (bool, optional): When True, each per-layer MaxText parameter
      is one stacked tensor and maps to a list of HF names ordered by layer.
      When False, every layer has its own MaxText key. Defaults to False.

  Returns:
    dict: Keys are `atomic_mt_key` strings. Values are a single HF name
    (unscanned form) or a list of HF names, one per layer (scanned form).
  """
  num_layers = config["num_hidden_layers"]

  # (MaxText suffix, HF suffix) for every decoder-layer parameter, in the
  # order the keys are emitted.
  per_layer_names = (
      ("self_attention-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention-key-kernel", "self_attn.k_proj.weight"),
      ("self_attention-value-kernel", "self_attn.v_proj.weight"),
      ("self_attention-out-kernel", "self_attn.o_proj.weight"),
      ("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
      ("mlp-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp-wo-kernel", "mlp.down_proj.weight"),
      ("pre_self_attention_layer_norm-scale", "input_layernorm.weight"),
      ("post_self_attention_layer_norm-scale", "post_attention_layernorm.weight"),
  )

  param_map = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
  }

  if scan_layers:
    param_map.update(
        (
            f"params-decoder-layers-{mt_name}",
            [f"model.layers.{idx}.{hf_name}" for idx in range(num_layers)],
        )
        for mt_name, hf_name in per_layer_names
    )
  else:
    param_map.update(
        (f"params-decoder-layers_{idx}-{mt_name}", f"model.layers.{idx}.{hf_name}")
        for idx in range(num_layers)
        for mt_name, hf_name in per_layer_names
    )
  return param_map
def LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for converting between MaxText and HuggingFace formats.

  Generates a mapping of transformation hooks for Llama 3.1 parameters that
  need more than a rename: kernel transposes/reshapes, RoPE layout changes
  for the query/key kernels, and query scaling.

  Args:
    config (dict): Hugging Face model configuration. Reads `num_hidden_layers`
      when building the hooks. The query-scaling hook reads `head_dim`; when
      it is absent or None it falls back to
      `hidden_size // num_attention_heads`.
    maxtext_config: MaxText configuration (not used by these hooks).
    scan_layers (bool, optional): If True, hooks are keyed by the scanned
      `params-decoder-layers-...` names; otherwise one set per layer.
      Defaults to False.
    saving_to_hf (bool, optional): True for MaxText -> HF, False for the
      reverse. Defaults to False.

  Returns:
    dict: Maps MaxText parameter names to a hook `fn(tensor, target_shape)`
    or to an ordered list of hooks (a chain whose order matters). Parameters
    without an entry (embedding, norm scales) need no transformation.
  """
  nlayers = config["num_hidden_layers"]

  def resolve_head_dim():
    """Returns `head_dim`, deriving it from hidden_size / num_attention_heads if missing."""
    # HF Llama configs may omit `head_dim` (or set it to None); transformers
    # then uses hidden_size // num_attention_heads.
    head_dim = config.get("head_dim")
    if head_dim is None:
      head_dim = config["hidden_size"] // config["num_attention_heads"]
    return head_dim

  def scale_query_layer(input_tensor, target_shape):
    """Rescales the query kernel by sqrt(head_dim) (to HF) or 1/sqrt(head_dim) (to MaxText)."""
    # MaxText's query kernel is stored pre-multiplied by 1/sqrt(head_dim)
    # relative to HF's. Scale in float32, then cast back to keep the dtype.
    sqrt_head_dim = np.sqrt(resolve_head_dim())
    depth_scale = np.dtype("float32").type(sqrt_head_dim if saving_to_hf else 1 / sqrt_head_dim)
    original_dtype = input_tensor.dtype
    output_tensor = input_tensor.astype(np.float32) * depth_scale
    return output_tensor.astype(original_dtype)

  def adjust_rope(input_tensor, target_shape):
    """Converts the last axis between MaxText's interleaved and HF's concatenated RoPE layout."""
    arr = input_tensor
    if saving_to_hf:
      # Interleaved [x0, y0, x1, y1, ...] -> concatenated [x0, x1, ..., y0, y1, ...].
      evens = arr[..., ::2]
      odds = arr[..., 1::2]
      return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1)
    # Concatenated halves -> interleaved.
    half_dim = arr.shape[-1] // 2
    first_half = arr[..., :half_dim]
    second_half = arr[..., half_dim:]
    return jax.numpy.stack([first_half, second_half], axis=-1).reshape(arr.shape)

  def reshape_kernel(input_tensor, target_shape):
    """Converts a kernel between MaxText and HF layouts via reshape + transpose."""
    if saving_to_hf:
      flipped_target_shape = np.flip(np.array(target_shape))
      return input_tensor.reshape(flipped_target_shape).transpose()
    return input_tensor.transpose().reshape(target_shape)

  # Caveat: hook order affects the result. Chains are written for the
  # MaxText -> HF direction and reversed for HF -> MaxText.
  query_hook_chain = [scale_query_layer, adjust_rope, reshape_kernel]
  key_hook_chain = [adjust_rope, reshape_kernel]
  if not saving_to_hf:
    query_hook_chain.reverse()
    key_hook_chain.reverse()

  # Per-layer hooks, keyed by the MaxText name suffix.
  layer_hooks = {
      "self_attention-query-kernel": query_hook_chain,
      "self_attention-key-kernel": key_hook_chain,
      "self_attention-value-kernel": reshape_kernel,
      "self_attention-out-kernel": reshape_kernel,
      "mlp-wi_0-kernel": reshape_kernel,
      "mlp-wi_1-kernel": reshape_kernel,
      "mlp-wo-kernel": reshape_kernel,
  }

  hook_fns = {"params-decoder-logits_dense-kernel": reshape_kernel}
  if scan_layers:
    for suffix, hook in layer_hooks.items():
      hook_fns[f"params-decoder-layers-{suffix}"] = hook
  else:
    for layer_idx in range(nlayers):
      for suffix, hook in layer_hooks.items():
        hook_fns[f"params-decoder-layers_{layer_idx}-{suffix}"] = hook
  return hook_fns
def LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN():
  """Defines and returns hook functions for NNX -> vLLM weight transformations.

  These hooks are applied to specific Llama 3.1 weights when converting from
  MaxText (NNX) to a vLLM-compatible layout. They handle RoPE reordering and
  query scaling, which are not simple re-mappings.

  Returns:
    dict: Maps NNX parameter paths to single-argument transformation
    functions `fn(arr) -> arr`.
  """

  def reorder_rope(arr):
    """Reorders Rotary Position Embedding (RoPE) weights on the last axis.

    MaxText and vLLM may order RoPE dimensions differently. The last axis is
    split into even- and odd-indexed entries, which are then concatenated,
    e.g. [x0, x1, x2, x3] -> [x0, x2, x1, x3].

    Args:
      arr: The input weight array.

    Returns:
      The reordered weight array.
    """
    evens = arr[..., ::2]
    odds = arr[..., 1::2]
    return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1)

  def transform_query_kernel(arr):
    """Scales the query kernel by sqrt(head_dim), then applies RoPE reordering.

    head_dim is taken from the last axis of `arr`. The multiply is done in
    float32 and the result is cast back to the input dtype, so low-precision
    weights (e.g. bfloat16) are not promoted to float32 by the float32 scale.

    Args:
      arr: The query kernel weight array.

    Returns:
      The transformed query kernel array, in the same dtype as `arr`.
    """
    head_dim = arr.shape[-1]
    depth_scale = np.dtype("float32").type(np.sqrt(head_dim))
    original_dtype = arr.dtype
    scaled = (arr.astype(np.float32) * depth_scale).astype(original_dtype)
    return reorder_rope(scaled)

  hook_fns = {
      "base.decoder.layers.self_attention.query.kernel": transform_query_kernel,
      "base.decoder.layers.self_attention.key.kernel": reorder_rope,
  }
  return hook_fns
def MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Generates the mapping of parameter names from MaxText to Hugging Face for Mixtral.

  Args:
    config (dict): Hugging Face model configuration; `num_hidden_layers` is read.
    maxtext_config: MaxText configuration; `num_experts` is read.
    scan_layers (bool, optional): Whether MaxText stacks decoder layers into a
      single scanned parameter. Defaults to False.

  Returns:
    dict: Keys are `atomic_mt_key` strings (single MaxText parameter names).
    Values take one of four forms: an unscanned string, a scanned list of
    strings (one per layer), an unscanned expert-stacked list (one per
    expert), or a scanned expert-stacked nested list indexed
    `[expert][layer]` (experts outer, layers inner, matching the multi-axis
    stacking logic).
  """
  n_experts = maxtext_config.num_experts
  n_layers = config["num_hidden_layers"]

  # (MaxText suffix, HF suffix) for per-layer, non-expert parameters.
  dense_params = (
      ("self_attention-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention-key-kernel", "self_attn.k_proj.weight"),
      ("self_attention-value-kernel", "self_attn.v_proj.weight"),
      ("self_attention-out-kernel", "self_attn.o_proj.weight"),
      ("pre_self_attention_layer_norm-scale", "input_layernorm.weight"),
      ("post_self_attention_layer_norm-scale", "post_attention_layernorm.weight"),
      ("MoeBlock_0-gate-kernel", "block_sparse_moe.gate.weight"),
  )
  # (MaxText suffix, HF expert weight name) for expert-stacked parameters.
  expert_params = (
      ("MoeBlock_0-wi_0", "w1"),
      ("MoeBlock_0-wi_1", "w3"),
      ("MoeBlock_0-wo", "w2"),
  )

  def expert_name(layer, expert, weight):
    return f"model.layers.{layer}.block_sparse_moe.experts.{expert}.{weight}.weight"

  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  if scan_layers:
    for mt_suffix, hf_suffix in dense_params:
      mapping[f"params-decoder-layers-{mt_suffix}"] = [
          f"model.layers.{layer}.{hf_suffix}" for layer in range(n_layers)
      ]
    for mt_suffix, weight in expert_params:
      mapping[f"params-decoder-layers-{mt_suffix}"] = [
          [expert_name(layer, expert, weight) for layer in range(n_layers)] for expert in range(n_experts)
      ]
    return mapping

  for layer in range(n_layers):
    mt_prefix = f"params-decoder-layers_{layer}"
    for mt_suffix, hf_suffix in dense_params:
      mapping[f"{mt_prefix}-{mt_suffix}"] = f"model.layers.{layer}.{hf_suffix}"
    # One MaxText expert tensor -> num_experts HF tensors.
    for mt_suffix, weight in expert_params:
      mapping[f"{mt_prefix}-{mt_suffix}"] = [expert_name(layer, expert, weight) for expert in range(n_experts)]
  return mapping
def MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Generates parameter conversion hooks for Mixtral between MaxText and Hugging Face.

  Args:
    config (dict): Hugging Face model configuration. `num_hidden_layers` is
      read when building unscanned hooks; `hidden_size` is read when attention
      kernels are converted to HF.
    maxtext_config: MaxText configuration; `head_dim` is read for query scaling.
    scan_layers (bool, optional): If True, hooks are keyed by scanned
      `params-decoder-layers-...` names; otherwise one set per layer
      (`params-decoder-layers_{i}-...`). Defaults to False.
    saving_to_hf (bool, optional): True for MaxText -> HF, False for the
      reverse. Defaults to False.

  Returns:
    dict: Maps MaxText parameter names to a hook `fn(tensor, target_shape)`
    or to a list of such hooks.
  """
  hooks = {}

  def reshape_and_transpose_attention(x, target_shape):
    """Q/K/V kernels: MaxText [hidden, n_heads, h_dim] <-> HF [n_heads * h_dim, hidden]."""
    if saving_to_hf:
      # (H, N, D) -> (H, N*D) -> (N*D, H)
      return x.reshape(config["hidden_size"], -1).transpose()
    # (N*D, H) -> (H, N*D) -> (H, N, D)
    return x.transpose().reshape(target_shape)

  def reshape_and_transpose_attention_out(x, target_shape):
    """Output kernel: MaxText [n_heads, h_dim, hidden] <-> HF [hidden, n_heads * h_dim]."""
    if saving_to_hf:
      # (N, D, H) -> (N*D, H) -> (H, N*D). The head axes lead in MaxText, so
      # they must be flattened as the leading dimension; reshaping to
      # (H, -1) is only equivalent when N*D == H.
      return x.reshape(-1, config["hidden_size"]).transpose()
    # (H, N*D) -> (N*D, H) -> (N, D, H)
    return x.transpose().reshape(target_shape)

  def reshape_kernel(x, target_shape):
    """Plain transpose; it is its own inverse, so it serves both directions."""
    return x.transpose()

  def scale_query_layer(input_tensor, target_shape):
    """Applies sqrt(head_dim) toward HF or 1/sqrt(head_dim) toward MaxText, keeping dtype."""
    if saving_to_hf:
      depth_scale = np.dtype("float32").type(np.sqrt(maxtext_config.head_dim))
    else:
      depth_scale = np.dtype("float32").type(1 / np.sqrt(maxtext_config.head_dim))
    return (input_tensor * depth_scale).astype(input_tensor.dtype)

  # Reshaping and scaling commute, so hook order does not affect the result.
  query_hook_chain = [reshape_and_transpose_attention, scale_query_layer]

  # "{i}" is a literal placeholder, expanded per layer below.
  layer_prefix = "params-decoder-layers" if scan_layers else "params-decoder-layers_{i}"
  plan = [
      (f"{layer_prefix}-self_attention-query-kernel", query_hook_chain),
      (f"{layer_prefix}-self_attention-key-kernel", reshape_and_transpose_attention),
      (f"{layer_prefix}-self_attention-value-kernel", reshape_and_transpose_attention),
      (f"{layer_prefix}-self_attention-out-kernel", reshape_and_transpose_attention_out),
      (f"{layer_prefix}-MoeBlock_0-wi_0", reshape_kernel),
      (f"{layer_prefix}-MoeBlock_0-wi_1", reshape_kernel),
      (f"{layer_prefix}-MoeBlock_0-wo", reshape_kernel),
      (f"{layer_prefix}-MoeBlock_0-gate-kernel", reshape_kernel),
      ("params-decoder-logits_dense-kernel", reshape_kernel),
  ]

  for maxtext_pattern, op_func in plan:
    if "{i}" in maxtext_pattern:
      for i in range(config["num_hidden_layers"]):
        hooks[maxtext_pattern.format(i=i)] = op_func
    else:
      hooks[maxtext_pattern] = op_func
  return hooks
def _gemma4_decoder_layer_pairs(num_experts, v_norm_with_scale, has_value_kernel=True):
  """Returns ordered `(maxtext_suffix, hf_suffix)` pairs for one Gemma4 decoder layer.

  The MaxText suffix is appended to a layer prefix such as `params-decoder-layers_3`,
  and the HF suffix to a layer prefix such as `model.layers.3`.

  Args:
    num_experts: Number of routed experts. Values > 1 select the MoE parameter set;
      any other value selects the dense MLP parameters.
    v_norm_with_scale: Whether to include the value-norm scale parameter.
    has_value_kernel: If False, the value projection kernel is omitted. The caller uses
      this for global layers when `share_kv_projections` is enabled.

  Returns:
    A list of `(maxtext_suffix, hf_suffix)` string tuples.
  """
  pairs = [
      ("self_attention-query-kernel", "self_attn.q_proj.weight"),
      ("self_attention-key-kernel", "self_attn.k_proj.weight"),
  ]
  if has_value_kernel:
    pairs.append(("self_attention-value-kernel", "self_attn.v_proj.weight"))
  pairs += [
      ("self_attention-out-kernel", "self_attn.o_proj.weight"),
      ("self_attention-query_norm-scale", "self_attn.q_norm.weight"),
      ("self_attention-key_norm-scale", "self_attn.k_norm.weight"),
  ]
  if v_norm_with_scale:
    pairs.append(("self_attention-value_norm-scale", "self_attn.v_norm.weight"))
  pairs += [
      ("pre_self_attention_norm-scale", "input_layernorm.weight"),
      ("post_self_attention_norm-scale", "post_attention_layernorm.weight"),
      ("pre_ffw_norm-scale", "pre_feedforward_layernorm.weight"),
      ("post_ffw_norm-scale", "post_feedforward_layernorm.weight"),
  ]
  if num_experts > 1:
    pairs += [
        ("mlp-pre_feedforward_layernorm_2-scale", "pre_feedforward_layernorm_2.weight"),
        ("mlp-post_feedforward_layernorm_1-scale", "post_feedforward_layernorm_1.weight"),
        ("mlp-post_feedforward_layernorm_2-scale", "post_feedforward_layernorm_2.weight"),
        ("mlp-pre_forward_scale_2", "router.scale"),
        ("mlp-moe_block-MoeBlock_0-gate-kernel", "router.proj.weight"),
        # HF stores gate and up projections in one fused tensor; both MaxText halves map to it.
        ("mlp-moe_block-MoeBlock_0-wi_0", "experts.gate_up_proj"),
        ("mlp-moe_block-MoeBlock_0-wi_1", "experts.gate_up_proj"),
        ("mlp-moe_block-MoeBlock_0-wo", "experts.down_proj"),
        ("mlp-moe_block-MoeBlock_0-per_expert_scale", "router.per_expert_scale"),
        # The shared expert is stored under the HF `mlp.*` names.
        ("mlp-moe_block-shared_experts-wi_0-kernel", "mlp.gate_proj.weight"),
        ("mlp-moe_block-shared_experts-wi_1-kernel", "mlp.up_proj.weight"),
        ("mlp-moe_block-shared_experts-wo-kernel", "mlp.down_proj.weight"),
    ]
  else:
    pairs += [
        ("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
        ("mlp-wi_1-kernel", "mlp.up_proj.weight"),
        ("mlp-wo-kernel", "mlp.down_proj.weight"),
    ]
  pairs.append(("layer_scalar", "layer_scalar"))
  return pairs


def GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping between MaxText and HuggingFace Gemma4 weight paths.

  Args:
    config: HF config dict. Text settings are read from `config["text_config"]` when
      present, otherwise from `config` itself. A non-empty `config["vision_config"]`
      adds the vision-tower parameters and moves text weights under
      `model.language_model`.
    maxtext_config: MaxText config; `share_kv_projections` and `v_norm_with_scale`
      are read.
    scan_layers: If True, decoder layers are grouped into
      `scanned_blocks-layers_{k}` (k in [0, 6)). Their values are lists of HF keys to
      stack along the layer axis. Layers past the last full block are mapped
      one-to-one under `layers_remainder-layers_{r}`. If False, each
      `layers_{i}` maps to a single HF key.

  Returns:
    Dict from MaxText parameter key to an HF key (str) or list of HF keys.
  """
  # Gemma 4 uses a block pattern of length 6: 5 local layers, then 1 global layer.
  attention_pattern_length = 6
  tcfg = config.get("text_config", config)
  nlayers = tcfg["num_hidden_layers"]
  share_kv_projections = maxtext_config.share_kv_projections
  v_norm_with_scale = maxtext_config.v_norm_with_scale
  # A missing, None or 0 expert count all mean a dense (non-MoE) MLP.
  num_experts = tcfg.get("num_experts") or 1

  vcfg = config.get("vision_config", {})
  text_base = "model.language_model" if vcfg else "model"

  mapping = {
      "params-token_embedder-embedding": f"{text_base}.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": f"{text_base}.norm.weight",
  }

  if vcfg:
    vis_prefix = "params-vision_encoder-Gemma4VisionEncoderLayer_0"
    mapping.update(
        {
            f"{vis_prefix}-vision_entry-input_projection-kernel": (
                "model.vision_tower.patch_embedder.input_proj.weight"
            ),
            f"{vis_prefix}-vision_entry-pos_emb_param": (
                "model.vision_tower.patch_embedder.position_embedding_table"
            ),
            "params-vision_encoder-Gemma4VisionProjector_0-projection-kernel": (
                "model.embed_vision.embedding_projection.weight"
            ),
            f"{vis_prefix}-std_scale": "model.vision_tower.std_scale",
            f"{vis_prefix}-std_bias": "model.vision_tower.std_bias",
        }
    )
    vision_layer_pairs = (
        ("attention-query-kernel", "self_attn.q_proj.linear.weight"),
        ("attention-key-kernel", "self_attn.k_proj.linear.weight"),
        ("attention-value-kernel", "self_attn.v_proj.linear.weight"),
        ("attention-out-kernel", "self_attn.o_proj.linear.weight"),
        ("attention-query_norm-scale", "self_attn.q_norm.weight"),
        ("attention-key_norm-scale", "self_attn.k_norm.weight"),
        ("pre_attention_norm-scale", "input_layernorm.weight"),
        ("post_attention_norm-scale", "post_attention_layernorm.weight"),
        ("pre_ffw_norm-scale", "pre_feedforward_layernorm.weight"),
        ("post_ffw_norm-scale", "post_feedforward_layernorm.weight"),
        ("mlp-wi_0-kernel", "mlp.gate_proj.linear.weight"),
        ("mlp-wi_1-kernel", "mlp.up_proj.linear.weight"),
        ("mlp-wo-kernel", "mlp.down_proj.linear.weight"),
    )
    for i in range(vcfg.get("num_hidden_layers", 0)):
      prefix = f"{vis_prefix}-layer_{i}"
      hf_prefix = f"model.vision_tower.encoder.layers.{i}"
      for mt_suffix, hf_suffix in vision_layer_pairs:
        mapping[f"{prefix}-{mt_suffix}"] = f"{hf_prefix}.{hf_suffix}"

  def layer_pairs(hf_layer_idx):
    """Decoder parameter pairs for the HF layer (or block position) `hf_layer_idx`."""
    is_global = hf_layer_idx % attention_pattern_length == attention_pattern_length - 1
    # Global layers with shared K/V projections have no separate value kernel. This is
    # applied uniformly to scanned, remainder and unscanned layers.
    return _gemma4_decoder_layer_pairs(
        num_experts,
        v_norm_with_scale,
        has_value_kernel=not (share_kv_projections and is_global),
    )

  if scan_layers:
    num_remaining = nlayers % attention_pattern_length
    num_scanned = nlayers - num_remaining
    # Scanned blocks: sub-layer k stacks HF layers k, k + 6, k + 12, ...
    for layer_in_block in range(attention_pattern_length):
      hf_indices = list(range(layer_in_block, num_scanned, attention_pattern_length))
      if not hf_indices:
        # Fewer layers than one full block: everything lives in the remainder.
        continue
      prefix = f"params-decoder-scanned_blocks-layers_{layer_in_block}"
      for mt_suffix, hf_suffix in layer_pairs(layer_in_block):
        mapping[f"{prefix}-{mt_suffix}"] = [f"{text_base}.layers.{i}.{hf_suffix}" for i in hf_indices]
    # Remainder layers past the last full block are mapped one-to-one.
    for rem_idx in range(num_remaining):
      hf_layer_idx = num_scanned + rem_idx
      prefix = f"params-decoder-layers_remainder-layers_{rem_idx}"
      hf_prefix = f"{text_base}.layers.{hf_layer_idx}"
      for mt_suffix, hf_suffix in layer_pairs(hf_layer_idx):
        mapping[f"{prefix}-{mt_suffix}"] = f"{hf_prefix}.{hf_suffix}"
  else:
    for i in range(nlayers):
      prefix = f"params-decoder-layers_{i}"
      hf_prefix = f"{text_base}.layers.{i}"
      for mt_suffix, hf_suffix in layer_pairs(i):
        mapping[f"{prefix}-{mt_suffix}"] = f"{hf_prefix}.{hf_suffix}"

  return mapping
def GEMMA4_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Gemma4.

  Args:
    config: HF config dict. Text settings are read from `config["text_config"]` when
      present, otherwise from `config`. A non-empty `config["vision_config"]` adds
      vision-tower hooks.
    maxtext_config: MaxText config; `share_kv_projections` and `v_norm_with_scale`
      are read.
    scan_layers: If True, register hooks for `scanned_blocks-layers_{k}` and
      `layers_remainder-layers_{r}` keys. If False, register them for per-layer
      `layers_{i}` keys.
    saving_to_hf: Conversion direction. True means MaxText -> HF; False means
      HF -> MaxText.

  Returns:
    Dict from MaxText parameter key (or a tuple of keys for fused tensors) to a
    `hook(tensor_or_list, target_shape)` transform. Keys without a hook are copied
    as-is.
  """
  # Gemma 4 uses a block pattern of length 6: 5 local layers, then 1 global layer.
  attention_pattern_length = 6
  tcfg = config.get("text_config", config)
  nlayers = tcfg["num_hidden_layers"]
  share_kv_projections = maxtext_config.share_kv_projections
  # A missing, None or 0 expert count all mean a dense (non-MoE) MLP.
  num_experts = tcfg.get("num_experts") or 1
  hooks = {}

  def pad_hf_embedding_layer(input_tensor, target_shape):
    # The MaxText embedding equals the HF embedding times sqrt(hidden_size), zero-padded
    # to the MaxText shape.
    normalizer = np.dtype("float32").type(tcfg["hidden_size"] ** 0.5)
    if saving_to_hf:
      target_tensor = input_tensor[: target_shape[0], : target_shape[1]]
      target_tensor = target_tensor / normalizer
      return target_tensor.astype(input_tensor.dtype)
    target_tensor = np.zeros(target_shape, dtype=input_tensor.dtype)
    target_tensor[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor
    target_tensor = target_tensor * normalizer
    return target_tensor.astype(input_tensor.dtype)

  def reshape_kernel(input_tensor, target_shape):
    # The MaxText kernel is the transpose of the HF [out, in] weight, reshaped.
    if saving_to_hf:
      flipped_target_shape = np.flip(np.array(target_shape))
      return input_tensor.reshape(flipped_target_shape).T
    return input_tensor.T.reshape(target_shape)

  def scale_rmsnorm_layer(input_tensor, target_shape):
    # Shift of 1.0 is now folded into Gemma 4 text and vision checkpoint weights
    return input_tensor.reshape(target_shape)

  def split_moe_wi0(input_tensor, target_shape):
    if saving_to_hf:
      raise NotImplementedError("Saving to HF for fused gate_up_proj requires custom concat hook.")
    # input_tensor: [E, 2*FF, H] -> gate half as [E, H, FF]
    _, two_FF, _ = input_tensor.shape
    FF = two_FF // 2
    return input_tensor[:, :FF, :].transpose(0, 2, 1)

  def split_moe_wi1(input_tensor, target_shape):
    if saving_to_hf:
      raise NotImplementedError("Saving to HF for fused gate_up_proj requires custom concat hook.")
    # input_tensor: [E, 2*FF, H] -> up half as [E, H, FF]
    _, two_FF, _ = input_tensor.shape
    FF = two_FF // 2
    return input_tensor[:, FF:, :].transpose(0, 2, 1)

  def reshape_moe_wo(input_tensor, target_shape):
    # Swaps the last two axes; applying it twice is the identity, so it serves both directions.
    return input_tensor.transpose(0, 2, 1)

  def moe_gate_up_hook(weight_list, target_shape):
    # Inverse of split_moe_wi0/wi1: fuse MaxText wi_0, wi_1 -> HF experts.gate_up_proj.
    # weight_list: [wi_0, wi_1], each [..., H, FF]; returns [..., 2*FF, H].
    wi_0 = jnp.asarray(weight_list[0])
    wi_1 = jnp.asarray(weight_list[1])
    return jnp.swapaxes(jnp.concatenate([wi_0, wi_1], axis=-1), -2, -1)

  hooks["params-token_embedder-embedding"] = pad_hf_embedding_layer
  hooks["params-decoder-decoder_norm-scale"] = scale_rmsnorm_layer
  # No logits_dense-kernel hook: the output projection reuses the embedding
  # (logits_via_embedding: True).

  kernel_keys = [
      "self_attention-query-kernel",
      "self_attention-key-kernel",
      "self_attention-value-kernel",
      "self_attention-out-kernel",
      "mlp-wi_0-kernel",
      "mlp-wi_1-kernel",
      "mlp-wo-kernel",
      "mlp-moe_block-shared_experts-wi_0-kernel",
      "mlp-moe_block-shared_experts-wi_1-kernel",
      "mlp-moe_block-shared_experts-wo-kernel",
  ]
  moe_kernel_keys = [
      # `gate-kernel` (router) has shape [feature, num_experts] in MaxText, but [num_experts, feature] in HF
      "mlp-moe_block-MoeBlock_0-gate-kernel",
  ]
  norm_keys = [
      "self_attention-query_norm-scale",
      "self_attention-key_norm-scale",
  ]
  if maxtext_config.v_norm_with_scale:
    norm_keys.append("self_attention-value_norm-scale")
  norm_keys.extend(
      [
          "pre_self_attention_norm-scale",
          "post_self_attention_norm-scale",
          "pre_ffw_norm-scale",
          "post_ffw_norm-scale",
      ]
  )
  if num_experts > 1:
    norm_keys.extend(
        [
            "mlp-pre_feedforward_layernorm_2-scale",
            "mlp-post_feedforward_layernorm_1-scale",
            "mlp-post_feedforward_layernorm_2-scale",
        ]
    )
  # `pre_forward_scale_2`, `per_expert_scale` and `layer_scalar` are intentionally left
  # out of norm_keys. They are not RMSNorm scales, so they get no hook (identity copy).

  vcfg = config.get("vision_config", {})
  if vcfg:
    nvis = vcfg.get("num_hidden_layers", 0)

    def reshape_vision_patch(x, target_shape):
      # HF and MaxText both use (H, W, C) patch flattening; only in/out features swap.
      return x.T

    def reshape_pos_emb(x, target_shape):
      return x.transpose(1, 0, 2)

    vis_prefix = "params-vision_encoder-Gemma4VisionEncoderLayer_0"
    hooks[f"{vis_prefix}-vision_entry-input_projection-kernel"] = reshape_vision_patch
    hooks[f"{vis_prefix}-vision_entry-pos_emb_param"] = reshape_pos_emb
    hooks["params-vision_encoder-Gemma4VisionProjector_0-projection-kernel"] = reshape_kernel

    vision_layer_hooks = (
        ("attention-query-kernel", reshape_kernel),
        ("attention-key-kernel", reshape_kernel),
        ("attention-value-kernel", reshape_kernel),
        ("attention-out-kernel", reshape_kernel),
        ("attention-query_norm-scale", scale_rmsnorm_layer),
        ("attention-key_norm-scale", scale_rmsnorm_layer),
        ("pre_attention_norm-scale", scale_rmsnorm_layer),
        ("post_attention_norm-scale", scale_rmsnorm_layer),
        ("pre_ffw_norm-scale", scale_rmsnorm_layer),
        ("post_ffw_norm-scale", scale_rmsnorm_layer),
        ("mlp-wi_0-kernel", reshape_kernel),
        ("mlp-wi_1-kernel", reshape_kernel),
        ("mlp-wo-kernel", reshape_kernel),
    )
    for i in range(nvis):
      prefix = f"{vis_prefix}-layer_{i}"
      for suffix, fn in vision_layer_hooks:
        hooks[f"{prefix}-{suffix}"] = fn

  def is_global_layer(hf_layer_idx):
    return hf_layer_idx % attention_pattern_length == attention_pattern_length - 1

  def register_decoder_layer_hooks(prefix, is_global):
    """Registers every decoder-layer hook under the MaxText key prefix `prefix`."""
    for key in kernel_keys:
      # With shared K/V projections, global layers have no separate value kernel.
      if share_kv_projections and is_global and key == "self_attention-value-kernel":
        continue
      hooks[f"{prefix}-{key}"] = reshape_kernel
    for key in moe_kernel_keys:
      hooks[f"{prefix}-{key}"] = reshape_kernel
    for key in norm_keys:
      hooks[f"{prefix}-{key}"] = scale_rmsnorm_layer
    wi0_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"
    wi1_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"
    if saving_to_hf and num_experts > 1:
      # NOTE(review): this hook is keyed by the (wi_0, wi_1) tuple, while the Gemma4
      # param mapping lists wi_0 and wi_1 as separate keys. Confirm the converter
      # resolves composite hook keys in this case.
      hooks[(wi0_key, wi1_key)] = moe_gate_up_hook
    else:
      hooks[wi0_key] = split_moe_wi0
      hooks[wi1_key] = split_moe_wi1
    # Registered for both directions: reshape_moe_wo is its own inverse.
    hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wo"] = reshape_moe_wo

  if scan_layers:
    num_remaining = nlayers % attention_pattern_length
    num_scanned = nlayers - num_remaining
    for layer_in_block in range(attention_pattern_length):
      register_decoder_layer_hooks(
          f"params-decoder-scanned_blocks-layers_{layer_in_block}", is_global_layer(layer_in_block)
      )
    for rem_idx in range(num_remaining):
      register_decoder_layer_hooks(
          f"params-decoder-layers_remainder-layers_{rem_idx}", is_global_layer(num_scanned + rem_idx)
      )
  else:
    for i in range(nlayers):
      register_decoder_layer_hooks(f"params-decoder-layers_{i}", is_global_layer(i))
  return hooks
def OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
  """Returns mapping from MaxText to HuggingFace Olmo3 weight paths.

  Olmo3 uses an inhomogeneous layer cycle (4 layers: 3 sliding, 1 global). MaxText
  handles this by defining sub-layers (layers_0, layers_1, ...) within a block.

  Args:
    config: HF config dict; `num_hidden_layers` is read.
    maxtext_config: MaxText config. `inhomogeneous_layer_cycle_interval` sets the
      cycle length and defaults to 4 when absent, matching
      OLMO3_MAXTEXT_TO_HF_PARAM_HOOK_FN.
    scan_layers: If True, `layers-layers_{k}` maps to the list of HF layers
      `k, k + cycle, k + 2*cycle, ...`, to be stacked. If False, `layers_{i}` maps
      one-to-one to `model.layers.{i}`.

  Returns:
    Dict from MaxText parameter key to an HF key (str) or list of HF keys.
  """
  n_layers = config["num_hidden_layers"]
  # Default Olmo3 cycle length is 4 if not specified in the MaxText config.
  layer_cycle_interval = getattr(maxtext_config, "inhomogeneous_layer_cycle_interval", 4)

  # (MaxText suffix, HF suffix) for one decoder layer. Olmo3 has QK norms and only
  # post-attention / post-MLP layer norms.
  layer_pairs = (
      # Attention projections
      ("attention-query-kernel", "self_attn.q_proj.weight"),
      ("attention-key-kernel", "self_attn.k_proj.weight"),
      ("attention-value-kernel", "self_attn.v_proj.weight"),
      ("attention-out-kernel", "self_attn.o_proj.weight"),
      # QK norms (Olmo3 specific)
      ("attention-query_norm-scale", "self_attn.q_norm.weight"),
      ("attention-key_norm-scale", "self_attn.k_norm.weight"),
      # MLP (wi_0=gate, wi_1=up, wo=down)
      ("mlp-wi_0-kernel", "mlp.gate_proj.weight"),
      ("mlp-wi_1-kernel", "mlp.up_proj.weight"),
      ("mlp-wo-kernel", "mlp.down_proj.weight"),
      # Layer norms
      ("post_self_attention_layer_norm-scale", "post_attention_layernorm.weight"),
      ("post_mlp_layer_norm-scale", "post_feedforward_layernorm.weight"),
  )

  # Base mapping for embeddings, the final norm and the output head.
  mapping = {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-decoder_norm-scale": "model.norm.weight",
      "params-decoder-logits_dense-kernel": "lm_head.weight",
  }

  if scan_layers:
    for cycle_idx in range(layer_cycle_interval):
      hf_indices = range(cycle_idx, n_layers, layer_cycle_interval)
      prefix = f"params-decoder-layers-layers_{cycle_idx}"
      for mt_suffix, hf_suffix in layer_pairs:
        mapping[f"{prefix}-{mt_suffix}"] = [f"model.layers.{i}.{hf_suffix}" for i in hf_indices]
  else:
    for i in range(n_layers):
      prefix = f"params-decoder-layers_{i}"
      hf_prefix = f"model.layers.{i}"
      for mt_suffix, hf_suffix in layer_pairs:
        mapping[f"{prefix}-{mt_suffix}"] = f"{hf_prefix}.{hf_suffix}"
  return mapping
def OLMO3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
  """Creates parameter transformation functions for Olmo3.

  Args:
    config: HF config dict; `num_hidden_layers` sets the number of unscanned layers.
    maxtext_config: MaxText config. `inhomogeneous_layer_cycle_interval` (default 4)
      sets the number of scanned sub-layers.
    scan_layers: Register hooks for `params-decoder-layers-layers_{k}` instead of
      `params-decoder-layers_{i}`.
    saving_to_hf: Conversion direction. True means MaxText -> HF.

  Returns:
    Dict from MaxText parameter key to a `hook(tensor, target_shape)` transform.
  """

  def reshape_kernel(input_tensor, target_shape):
    # The MaxText kernel is the transpose of the HF [out, in] weight, reshaped to the
    # MaxText target shape; saving reverses both steps.
    if not saving_to_hf:
      return input_tensor.T.reshape(target_shape)
    return input_tensor.reshape(tuple(reversed(target_shape))).T

  def scale_rmsnorm_layer(input_tensor, target_shape):
    # Layer-norm weights are copied unchanged (only reshaped): no 1.0 offset is added
    # or removed. Olmo3 checkpoints typically hold norm weights near 1.0.
    return input_tensor.reshape(target_shape)

  def adapt_olmo3_qk_norm(input_tensor, target_shape):
    # QK-norm weights are also copied unchanged. This is kept as a separate hook so
    # QK-norm handling can diverge later without touching the layer norms.
    return input_tensor.reshape(target_shape)

  def pad_hf_embedding_layer(input_tensor, target_shape):
    # Truncate (saving) or zero-pad (loading) the vocabulary dimension.
    src_vocab, dst_vocab = input_tensor.shape[0], target_shape[0]
    if src_vocab == dst_vocab:
      return input_tensor
    if saving_to_hf:
      return input_tensor[:dst_vocab, :]
    padded = np.zeros(target_shape, dtype=input_tensor.dtype)
    padded[:src_vocab, :] = input_tensor
    return padded

  hooks = {
      "params-token_embedder-embedding": pad_hf_embedding_layer,
      "params-decoder-logits_dense-kernel": reshape_kernel,
      "params-decoder-decoder_norm-scale": scale_rmsnorm_layer,
  }

  layer_kernel_keys = (
      "attention-query-kernel",
      "attention-key-kernel",
      "attention-value-kernel",
      "attention-out-kernel",
      "mlp-wi_0-kernel",
      "mlp-wi_1-kernel",
      "mlp-wo-kernel",
  )
  # Per-layer norm keys paired with their transform: QK norms use their own adaptor.
  layer_norm_hooks = (
      ("post_self_attention_layer_norm-scale", scale_rmsnorm_layer),
      ("post_mlp_layer_norm-scale", scale_rmsnorm_layer),
      ("attention-query_norm-scale", adapt_olmo3_qk_norm),
      ("attention-key_norm-scale", adapt_olmo3_qk_norm),
  )

  cycle_len = getattr(maxtext_config, "inhomogeneous_layer_cycle_interval", 4)
  n_layers = config["num_hidden_layers"]
  if scan_layers:
    layer_prefixes = [f"params-decoder-layers-layers_{c}" for c in range(cycle_len)]
  else:
    layer_prefixes = [f"params-decoder-layers_{idx}" for idx in range(n_layers)]

  for layer_prefix in layer_prefixes:
    hooks.update({f"{layer_prefix}-{name}": reshape_kernel for name in layer_kernel_keys})
    hooks.update({f"{layer_prefix}-{name}": fn for name, fn in layer_norm_hooks})
  return hooks
# {maxtext model name: mapping factory}. Each value is a function (for example
# GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING) that builds
# {maxtext weight name: hf weight name(s)}. Consecutive models of one family share a
# factory, so each family is registered in one `dict.fromkeys` group.
PARAM_MAPPING = {
    **dict.fromkeys(("gemma2-2b", "gemma2-9b", "gemma2-27b"), GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING),
    **dict.fromkeys(("gemma3-4b", "gemma3-12b", "gemma3-27b"), GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING),
    **dict.fromkeys(("gemma4-26b", "gemma4-31b"), GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING),
    **dict.fromkeys(
        (
            "qwen2.5-1.5b",
            "qwen2.5-7b",
            "qwen2.5-14b",
            "qwen3-0.6b",
            "qwen3-1.7b",
            "qwen3-1.7b-base",
            "qwen3-4b",
            "qwen3-4b-base",
            "qwen3-4b-thinking-2507",
            "qwen3-8b",
            "qwen3-8b-base",
            "qwen3-14b",
            "qwen3-14b-base",
            "qwen3-32b",
        ),
        QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(("llama3.1-8b", "llama3.1-70b", "llama3.1-405b"), LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING),
    **dict.fromkeys(
        ("qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-235b-a22b", "qwen3-coder-480b-a35b"),
        QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
    ),
    **dict.fromkeys(("deepseek2-16b", "deepseek3-671b", "deepseek3.2-671b"), DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING),
    **dict.fromkeys(("gpt-oss-20b", "gpt-oss-120b"), GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING),
    "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING,
    **dict.fromkeys(("mixtral-8x7b", "mixtral-8x22b"), MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING),
    **dict.fromkeys(("olmo3-7b", "olmo3-7b-pt", "olmo3-32b"), OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING),
}

# {maxtext model name: hook factory}. Each value is a function (for example
# GEMMA4_MAXTEXT_TO_HF_PARAM_HOOK_FN) that builds
# {maxtext weight name: bi-directional transform}. Same model keys, in the same
# order, as PARAM_MAPPING.
HOOK_FNS = {
    **dict.fromkeys(("gemma2-2b", "gemma2-9b", "gemma2-27b"), GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN),
    **dict.fromkeys(("gemma3-4b", "gemma3-12b", "gemma3-27b"), GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN),
    **dict.fromkeys(("gemma4-26b", "gemma4-31b"), GEMMA4_MAXTEXT_TO_HF_PARAM_HOOK_FN),
    **dict.fromkeys(
        (
            "qwen2.5-1.5b",
            "qwen2.5-7b",
            "qwen2.5-14b",
            "qwen3-0.6b",
            "qwen3-1.7b",
            "qwen3-1.7b-base",
            "qwen3-4b",
            "qwen3-4b-base",
            "qwen3-4b-thinking-2507",
            "qwen3-8b",
            "qwen3-8b-base",
            "qwen3-14b",
            "qwen3-14b-base",
            "qwen3-32b",
        ),
        QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(("llama3.1-8b", "llama3.1-70b", "llama3.1-405b"), LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN),
    **dict.fromkeys(
        ("qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-235b-a22b", "qwen3-coder-480b-a35b"),
        QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    ),
    **dict.fromkeys(("deepseek2-16b", "deepseek3-671b", "deepseek3.2-671b"), DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN),
    **dict.fromkeys(("gpt-oss-20b", "gpt-oss-120b"), GPT_OSS_TO_HF_PARAM_HOOK_FN),
    "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
    **dict.fromkeys(("mixtral-8x7b", "mixtral-8x22b"), MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN),
    **dict.fromkeys(("olmo3-7b", "olmo3-7b-pt", "olmo3-32b"), OLMO3_MAXTEXT_TO_HF_PARAM_HOOK_FN),
}

# {model family name: NNX -> vLLM hook function}. Keyed by family prefix rather than
# by full model name.
VLLM_HOOK_FNS = {
    "qwen3": QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN,
    "llama3.1": LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN,
    "deepseek3": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
}