# Source code for maxtext.integration.tunix.weight_mapping.deepseek3

# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mapping MaxText Deepseek (MoE) weights to vLLM/tpu-inference keys."""

from dataclasses import dataclass


@dataclass
class DEEPSEEK_VLLM_MAPPING:
  """Mapping MaxText Deepseek-V3 weights to Tunix/vLLM NNX keys."""

  @staticmethod
  def to_hf_hook_fns() -> dict:
    """Returns per-weight hook functions applied during conversion.

    Returns:
      A dictionary keyed by MaxText parameter path. Each value is a callable
      that takes a weight and returns it reshaped to 2D when it is 3D, or
      unchanged otherwise.
    """

    def flatten_3d_to_2d(val):
      # Converts (Rank, Heads, HeadDim) -> (Rank, Heads * HeadDim)
      if val.ndim == 3:
        return val.reshape(val.shape[0], -1)
      return val

    return {
        # MaxText MLA weights are 3D (Rank, Heads, HeadDim).
        # tpu-inference expects 2D (Rank, Heads*HeadDim) before it splits them.
        "base.decoder.layers.self_attention.wq_b.kernel": flatten_3d_to_2d,
        "base.decoder.layers.self_attention.wkv_b.kernel": flatten_3d_to_2d,
        "base.decoder.layers.self_attention.out.kernel": flatten_3d_to_2d,
    }

  @staticmethod
  def to_hf_transpose_keys() -> dict:
    """Returns the keys of weights that need to be transposed.

    Returns:
      An empty dictionary, as no keys require transposition for this mapping.
    """
    return {}

  @staticmethod
  def lora_to_hf_mappings() -> None:
    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.

    Returns:
      None, as LoRA mappings are not defined for this model.
    """
    return None

  @staticmethod
  def to_hf_mapping() -> dict:
    """Returns the weight mapping for the model.

    Returns:
      A dictionary keyed by MaxText parameter path. Each value is a
      `(target_key, axes)` pair, where `target_key` is the destination
      parameter name (`*` stands in for the layer index) and `axes` is a tuple
      of axis names or None, one per weight dimension.
      NOTE(review): `axes` is presumably the sharding spec used by the
      consumer -- confirm against the caller.
    """
    mapping = {
        # --- Base Model Params ---
        # Map to HF names to be safe with loader regexes
        "base.token_embedder.embedding": ("model.embed_tokens.weight", ("model", None)),
        "base.decoder.decoder_norm.scale": ("model.norm.weight", (None,)),
        "base.decoder.logits_dense.kernel": ("lm_head.weight", (None, "model")),
        # MLA LAYERS (Map to HF Keys to trigger loader splitting logic)
        # Norms
        "base.decoder.layers.pre_self_attention_layer_norm.scale": (
            "model.layers.*.input_layernorm.weight",
            (None, "layer"),
        ),
        "base.decoder.layers.post_self_attention_layer_norm.scale": (
            "model.layers.*.post_attention_layernorm.weight",
            (None, "layer"),
        ),
        # MLA Norms
        "base.decoder.layers.self_attention.kv_norm.scale": (
            "model.layers.*.self_attn.kv_a_layernorm.weight",
            (None, "layer"),
        ),
        "base.decoder.layers.self_attention.q_norm.scale": (
            "model.layers.*.self_attn.q_a_layernorm.weight",
            (None, "layer"),
        ),
        # MLA Projections
        # We use HF names here so `DeepSeekV3WeightLoader` detects "kv_b_proj"
        # and performs the necessary split into k_b and v_b for the MLA kernel.
        "base.decoder.layers.self_attention.wq_a.kernel": (
            "model.layers.*.self_attn.q_a_proj.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.wq_b.kernel": (
            "model.layers.*.self_attn.q_b_proj.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.wkv_a.kernel": (
            "model.layers.*.self_attn.kv_a_proj_with_mqa.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.wkv_b.kernel": (
            "model.layers.*.self_attn.kv_b_proj.weight",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.out.kernel": (
            "model.layers.*.self_attn.o_proj.weight",
            ("model", "layer", None, None),
        ),
        # DENSE MLP LAYERS (Map to vllm keys for safety/consistency)
        "base.decoder.layers.mlp.wi_0.kernel": ("model.layers.*.mlp.gate_proj.weight", (None, "layer", "model")),
        "base.decoder.layers.mlp.wi_1.kernel": ("model.layers.*.mlp.up_proj.weight", (None, "layer", "model")),
        "base.decoder.layers.mlp.wo.kernel": ("model.layers.*.mlp.down_proj.weight", ("model", "layer", None)),
        # MOE LAYERS (Map to INTERNAL keys to bypass loader stacking)
        # Since MaxText experts are already fused/stacked, we map directly to the
        # internal `tpu-inference` param names. The loader will fail to find
        # "experts.{i}" in the name and fall back to loading these directly,
        # which is exactly what we want for performance.
        # Shared Experts
        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel": (
            "layers.*.shared_experts.kernel_gating_DF",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel": (
            "layers.*.shared_experts.kernel_up_proj_DF",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.shared_experts.wo.kernel": (
            "layers.*.shared_experts.kernel_down_proj_FD",
            ("model", "layer", None),
        ),
        # Router
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel": (
            "layers.*.custom_module.router.kernel_DE",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias": (
            "layers.*.custom_module.router.bias_E",
            (None, "layer", "model"),
        ),
        # Routed Experts (Fused)
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_0": (
            "layers.*.custom_module.kernel_gating_EDF",
            ("expert", "layer", None, "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wi_1": (
            "layers.*.custom_module.kernel_up_proj_EDF",
            ("expert", "layer", None, "model"),
        ),
        "base.decoder.layers.DeepSeekMoeBlock_0.MoeBlock_0.wo": (
            "layers.*.custom_module.kernel_down_proj_EFD",
            ("expert", "layer", "model", None),
        ),
        # MTP BLOCK (Included for completeness, but typically skipped by current loader)
        "base.mtp_block.mtp_layer_1.embedding_norm.scale": ("mtp_block.layer.pre_norm.scale", (None,)),
        "base.mtp_block.mtp_layer_1.hidden_state_norm.scale": ("mtp_block.layer.post_norm.scale", (None,)),
        "base.mtp_block.mtp_layer_1.projection_layer.kernel": ("mtp_block.layer.projection.kernel", (None, "model")),
    }
    return mapping