# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for Tunix integration."""
import re
import maxtext.integration.tunix.weight_mapping as weight_mapping # pylint: disable=consider-using-from-import
from maxtext.checkpoint_conversion.utils.param_mapping import PARAM_MAPPING
from maxtext.checkpoint_conversion.utils.param_mapping import VLLM_HOOK_FNS
# Module-level StandaloneVllmWeightMapping instance, created once at import
# time. VllmWeightMapping indexes it by model name when constructed with
# use_standalone_mappings=True. Each entry it returns is expected to provide
# to_hf_mapping / to_hf_transpose_keys / to_hf_hook_fns / lora_to_hf_mappings.
STANDALONE_VLLM_WEIGHT_MAPPING = weight_mapping.StandaloneVllmWeightMapping()
# Static sharding table: the architectural knowledge (how each parameter is
# sharded) that the original HF name mapping does not carry.
#
# Keys are "generalized" MaxText parameter names such as
# "base.decoder.layers.mlp.wi_0.kernel"; a single key covers both the scanned
# and the unscanned variant of a layer parameter. Each value is a per-dimension
# tuple of axis names. None presumably marks an unsharded dimension and "layer"
# the stacked-layer dimension (TODO confirm against the sharding consumer).
_SHARDING_KNOWLEDGE_MAP = {
  # Non-layer parameters.
  "base.token_embedder.embedding": ("model", None),
  "base.decoder.decoder_norm.scale": (None,),
  "base.decoder.logits_dense.kernel": (None, "model"),
  # Attention (shared by scanned and unscanned layouts).
  "base.decoder.layers.pre_self_attention_layer_norm.scale": (None, "layer"),
  "base.decoder.layers.self_attention.query.kernel": (None, "layer", "model", None),
  "base.decoder.layers.self_attention.key.kernel": (None, "layer", "model", None),
  "base.decoder.layers.self_attention.value.kernel": (None, "layer", "model", None),
  "base.decoder.layers.self_attention.query_norm.scale": (None, "layer"),
  "base.decoder.layers.self_attention.key_norm.scale": (None, "layer"),
  "base.decoder.layers.self_attention.out.kernel": ("model", "layer", None, None),
  "base.decoder.layers.post_self_attention_layer_norm.scale": (None, "layer"),
  # Dense MLP (shared by scanned and unscanned layouts).
  "base.decoder.layers.mlp.wi_0.kernel": (None, "layer", "model"),
  "base.decoder.layers.mlp.wi_1.kernel": (None, "layer", "model"),
  "base.decoder.layers.mlp.wo.kernel": ("model", "layer", None),
  # MoE (shared by scanned and unscanned layouts).
  "base.decoder.layers.moe_block.gate.kernel": (None, "layer", "model"),
  "base.decoder.layers.moe_block.wi_0": ("expert", "layer", None, "model"),
  "base.decoder.layers.moe_block.wi_1": ("expert", "layer", None, "model"),
  "base.decoder.layers.moe_block.wo": ("expert", "layer", "model", None),
  # DeepSeek attention.
  "base.decoder.layers.self_attention.wq_a.kernel": (None, "layer", "model", None),
  "base.decoder.layers.self_attention.wq_b.kernel": (None, "layer", "model", None),
  "base.decoder.layers.self_attention.q_norm.scale": (None, "layer"),
  "base.decoder.layers.self_attention.wkv_a.kernel": (None, "layer", "model", None),
  "base.decoder.layers.self_attention.wkv_b.kernel": (None, "layer", "model", None),
  "base.decoder.layers.self_attention.kv_norm.scale": (None, "layer"),
  # DeepSeek MoE.
  "base.decoder.layers.moe_block.shared_experts.wi_0.kernel": (None, "layer", "model"),
  "base.decoder.layers.moe_block.shared_experts.wi_1.kernel": (None, "layer", "model"),
  "base.decoder.layers.moe_block.shared_experts.wo.kernel": ("model", "layer", None),
  # NOTE(review): a bias is usually lower-rank than its kernel; confirm this
  # 3-axis spec matches the actual rank of gate.bias.
  "base.decoder.layers.moe_block.gate.bias": (None, "layer", "model"),
}
class VllmWeightMapping:
  """Maps MaxText model weights to vLLM's model weights.

  Two sources of mapping information are supported, selected by
  ``use_standalone_mappings``:

  * ``True``: every query is delegated to the ``model_name`` entry of
    ``STANDALONE_VLLM_WEIGHT_MAPPING``.
  * ``False``: the MaxText-to-HF name map produced by
    ``PARAM_MAPPING[model_name]`` is generalized and paired with the static
    sharding table ``_SHARDING_KNOWLEDGE_MAP``.

  Args:
    model_name: Key into ``PARAM_MAPPING`` / ``STANDALONE_VLLM_WEIGHT_MAPPING``.
      Its prefix up to the first "-" is used as the model family when looking
      up ``VLLM_HOOK_FNS``.
    config: Forwarded as the first positional argument to the
      ``PARAM_MAPPING[model_name]`` factory (its type is not visible here).
    use_standalone_mappings: Whether to delegate to the standalone mappings.
  """

  def __init__(self, model_name, config=None, use_standalone_mappings=False):
    self.model_name = model_name
    self.config = config
    self.use_standalone_mappings = use_standalone_mappings
    # Module-level table shared by all instances; this class only reads it.
    self._sharding_knowledge_map = _SHARDING_KNOWLEDGE_MAP

  def to_hf_mapping(self):
    """Returns the MaxText-to-vLLM parameter mapping for ``model_name``.

    Returns:
      With standalone mappings, whatever the standalone entry's
      ``to_hf_mapping()`` returns. Otherwise, the result of
      ``convert_hf_map_to_sharding_map`` applied to the scanned-layer
      (``scan_layers=True``) ``PARAM_MAPPING`` for this model: a dict from
      generalized MaxText names to ``(HF wildcard name, sharding tuple)``.
    """
    if self.use_standalone_mappings:
      return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping()
    hf_mapping = PARAM_MAPPING[self.model_name](self.config, maxtext_config=None, scan_layers=True)
    return self.convert_hf_map_to_sharding_map(hf_mapping)

  def to_hf_transpose_keys(self):
    """Returns the standalone mapping's transpose-key table.

    Presumably this lists parameters that must be transposed during
    conversion (TODO confirm). Only standalone mappings provide it; otherwise
    an empty dict is returned.
    """
    if self.use_standalone_mappings:
      return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_transpose_keys()
    return {}

  def to_hf_hook_fns(self):
    """Returns a mapping from MaxText parameter names to transformation functions.

    With standalone mappings, the standalone entry's hook functions are
    returned. Otherwise the model family (``model_name`` up to its first "-")
    selects a factory in ``VLLM_HOOK_FNS``, which is called to build the hooks.
    Families without a registered factory get an empty dict.
    """
    if self.use_standalone_mappings:
      return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_hook_fns()
    model_family = self.model_name.split("-")[0]
    if model_family in VLLM_HOOK_FNS:
      return VLLM_HOOK_FNS[model_family]()
    return {}

  def lora_to_hf_mappings(self):
    """Returns the standalone LoRA-to-HF mappings.

    Returns None when standalone mappings are disabled.
    """
    if self.use_standalone_mappings:
      return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].lora_to_hf_mappings()
    return None

  def _generalize_maxtext_key(self, maxtext_key):
    """Generalizes a dash-separated MaxText key to the dotted sharding-table form.

    Example: "params-decoder-layers_0-mlp-wi_0-kernel" becomes
    "base.decoder.layers.mlp.wi_0.kernel".

    Args:
      maxtext_key: MaxText parameter name, e.g. "params-decoder-layers-...".

    Returns:
      The generalized key that is looked up in the sharding table.
    """
    # 'params-decoder-layers_0-mlp-...' -> 'base.decoder.layers_0.mlp....'
    # str.replace rewrites every occurrence, not only the leading prefix.
    generic_key = maxtext_key.replace("params-", "base.").replace("-", ".")
    # 'base.decoder.dense_layers.mlp....' -> 'base.decoder.layers.mlp....'
    generic_key = re.sub(r"\.dense_layers\.", ".layers.", generic_key)
    # 'base.decoder.moe_layers.mlp....' -> 'base.decoder.layers.mlp....'
    generic_key = re.sub(r"\.moe_layers\.", ".layers.", generic_key)
    # '...layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0' -> '...layers.moe_block.wi_0'
    generic_key = re.sub(r"DeepSeekMoeBlock_0\.MoeBlock_0\.", "moe_block.", generic_key)
    # '...DeepSeekMoeBlock_0.shared_experts....' -> '...moe_block.shared_experts....'
    generic_key = re.sub(
        r"DeepSeekMoeBlock_0\.shared_experts\.",
        "moe_block.shared_experts.",
        generic_key,
    )
    # Drop unscanned layer indices: 'layers_3.' -> 'layers.'.
    # NOTE(review): for an unscanned 'dense_layers_<n>.' / 'moe_layers_<n>.'
    # group this yields 'dense_layers.' / 'moe_layers.' (the earlier rules
    # need a '.' right after the group name); confirm that is intended.
    generic_key = re.sub(r"layers_(\d+)\.", "layers.", generic_key)
    return generic_key

  def _generalize_hf_value(self, hf_value):
    """Extracts the first HF name from ``hf_value`` and wildcards its indices.

    Args:
      hf_value: A single HF name (str), a list of names (scanned dense /
        unscanned MoE), or a list of lists of names (scanned MoE).

    Returns:
      The first name with "layers.<n>." and "experts.<n>." replaced by
      "layers.*." / "experts.*.", or None when there is no name to use
      (an empty list, or a nested list whose first entry is empty).

    Raises:
      TypeError: If ``hf_value`` is neither a str nor a list.
    """
    if isinstance(hf_value, str):
      first_name = hf_value
    elif isinstance(hf_value, list):
      if not hf_value:
        return None
      first_entry = hf_value[0]
      if isinstance(first_entry, list):
        # Scanned MoE (nested list). Guard against an empty inner list,
        # which previously raised IndexError.
        if not first_entry:
          return None
        first_name = first_entry[0]
      else:
        first_name = first_entry  # Scanned dense / unscanned MoE
    else:
      raise TypeError(f"Unknown value type in map: {type(hf_value)}")
    # Replace layer and expert indices with wildcards.
    wildcard_name = re.sub(r"layers\.(\d+)\.", "layers.*.", first_name)
    wildcard_name = re.sub(r"experts\.(\d+)\.", "experts.*.", wildcard_name)
    return wildcard_name

  def _correct_hf_wildcard_name(self, wildcard_name):
    """Rewrites HF parameter-name suffixes to the naming used in the sharding map.

    Rules, checked in order (first match wins):
      * contains "layernorm.weight" or "_norm.weight": ".weight" -> ".scale"
      * "model.embed_tokens.weight" -> "model.embed.embedding"
      * "lm_head.weight" -> "model.lm_head"
      * "model.norm.weight" -> "model.norm.scale"
      * any other name ending in ".weight": ".weight" -> ".kernel"
    Names matching none of these (e.g. biases) are returned unchanged.
    """
    corrected_name = wildcard_name
    if "layernorm.weight" in corrected_name or "_norm.weight" in corrected_name:
      # Fix all layer norms.
      corrected_name = corrected_name.replace(".weight", ".scale")
    elif corrected_name == "model.embed_tokens.weight":
      corrected_name = "model.embed.embedding"
    elif corrected_name == "lm_head.weight":
      corrected_name = "model.lm_head"
    elif corrected_name == "model.norm.weight":
      corrected_name = "model.norm.scale"
    elif corrected_name.endswith(".weight"):
      # Fix all other weights (MLP, Attn).
      corrected_name = corrected_name.replace(".weight", ".kernel")
    return corrected_name

  def convert_hf_map_to_sharding_map(self, hf_mapping):
    """Converts a MaxText-to-HF name map into a generic MaxText-to-vLLM sharding map.

    Args:
      hf_mapping (dict): A MaxText-to-HF parameter-name map, such as the one
        returned by ``PARAM_MAPPING[model_name](...)``.
        - Keys are MaxText param names (e.g., "params-decoder-layers...").
        - Values are HF param names (str), lists of names, or lists of lists
          of names.

    Returns:
      dict: A mapping from generalized MaxText names (e.g.,
      "base.decoder.layers.mlp.wi_0.kernel") to a tuple containing:
      (str: generalized HF/vLLM name, tuple: sharding specification).
      Entries without an HF name or without a sharding rule are omitted.
      When several MaxText keys generalize to the same name, the last one wins.
    """
    sharding_map = {}
    for maxtext_key, hf_value in hf_mapping.items():
      # 1. Generalize the MaxText key.
      generic_key = self._generalize_maxtext_key(maxtext_key)
      # 2. Generalize the Hugging Face (HF) value name.
      wildcard_name = self._generalize_hf_value(hf_value)
      if wildcard_name is None:
        continue
      # 3. Correct the generated wildcard name.
      corrected_name = self._correct_hf_wildcard_name(wildcard_name)
      # 4. Look up the sharding tuple.
      sharding_tuple = self._sharding_knowledge_map.get(generic_key)
      if sharding_tuple is None:
        # Only non-layer keys are reported. Keys containing "layers." are
        # skipped silently, because only the generic "base.decoder.layers.*"
        # rules are expected to exist.
        if "layers." not in generic_key:
          print(f"Warning: No sharding rule found for key: {generic_key}")
        continue
      # 5. Assemble the final map entry.
      sharding_map[generic_key] = (corrected_name, sharding_tuple)
    return sharding_map