# Source code for maxtext.integration.tunix.weight_mapping.gpt_oss

# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mapping MaxText GPT-OSS (MoE) weights to vLLM/tpu-inference keys."""

from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass
class GPT_OSS_VLLM_MAPPING:
  """Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.

  Supports:
    - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...)
    - Contiguous blocks (any other `interleave_style`; e.g., with 36 layers and
      an interval of 2, Block 0 -> Layers 0..17, Block 1 -> Layers 18..35)
  """

  @staticmethod
  def lora_to_hf_mappings():
    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.

    Returns:
      None, as LoRA mappings are not defined for this model.
    """
    return None

  @staticmethod
  def to_hf_hook_fns():
    """Returns the hook functions for this model.

    Returns:
      An empty dict: no hook functions are defined here.
    """
    return {}

  @staticmethod
  def to_hf_transpose_keys():
    """Returns keys that need to be transposed.

    Returns:
      An empty dict: no keys are marked for transposition here.
    """
    return {}

  @staticmethod
  def to_hf_mapping(
      layer_cycle_interval: int = 2, total_num_layers: int = 36, interleave_style: str = "modulo"
  ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]:
    """Returns the weight mapping for the model.

    One source block `base.decoder.layers.layers_{i}` is emitted for each
    `i in range(layer_cycle_interval)`. Each source parameter maps to a target
    key pattern of the form `layers\\.(a|b|c).<suffix>`, where the alternation
    lists every target layer index served by that block.

    Args:
      layer_cycle_interval: The interval at which layers are cycled; also the
        number of source blocks. Zero raises ZeroDivisionError.
      total_num_layers: The total number of layers in the model.
      interleave_style: The style of interleaving used for the layers.
        "modulo" assigns block `i` the layers `i, i + interval, ...`. Any other
        value assigns block `i` a contiguous run of
        `total_num_layers // layer_cycle_interval` layers; if the division is
        not exact, the trailing layers are not mapped.

    Returns:
      A dictionary mapping MaxText parameter names to
      `(vLLM key pattern, axis-name tuple)` pairs. The axis-name tuple holds
      one entry (a string or None) per array dimension -- presumably the
      sharding axes; confirm against the consumer of this mapping.
    """
    # --- 1. Global Parameters (exist once per model) ---
    mapping = {
        "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)),
        "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)),
        "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")),
    }

    # --- 2. Per-layer parameters: (source suffix, target suffix, axis names) ---
    # Order matters only for dict insertion order; it matches the sequence
    # norms -> attention -> MoE router -> MoE experts.
    per_layer_params = (
        # Norms.
        ("pre_self_attention_layer_norm.scale", "pre_attention_norm.scale", (None, "layer")),
        ("post_self_attention_layer_norm.scale", "pre_mlp_norm.scale", (None, "layer")),
        # Attention kernels.
        ("GptOssAttention.query.kernel", "attn.kernel_q_DNH", (None, "layer", "model", None)),
        ("GptOssAttention.key.kernel", "attn.kernel_k_DKH", (None, "layer", "model", None)),
        ("GptOssAttention.value.kernel", "attn.kernel_v_DKH", (None, "layer", "model", None)),
        ("GptOssAttention.out.kernel", "attn.kernel_o_proj_NHD", ("model", "layer", None, None)),
        # Attention biases and sinks.
        ("GptOssAttention.query.bias", "attn.bias_q_NH", (None, "layer", None)),
        ("GptOssAttention.key.bias", "attn.bias_k_KH", (None, "layer", None)),
        ("GptOssAttention.value.bias", "attn.bias_v_KH", (None, "layer", None)),
        ("GptOssAttention.out.bias", "attn.bias_o_D", (None, "layer")),
        ("GptOssAttention.sinks", "attn.sinks_N", (None, "layer")),
        # MoE Router.
        ("GptOssMlp.gate.kernel", "custom_module.router.kernel_DE", (None, "layer", "model")),
        ("GptOssMlp.gate.bias", "custom_module.router.bias_E", ("model", "layer")),
        # MoE experts: gate_proj (wi_0) and up_proj (wi_1) are kept separate.
        ("GptOssMlp.wi_0", "custom_module.gate_proj_kernel", ("model", "layer", None)),
        ("GptOssMlp.wi_0_bias", "custom_module.gate_proj_bias", ("model", "layer")),
        ("GptOssMlp.wi_1", "custom_module.up_proj_kernel", ("model", "layer", None)),
        ("GptOssMlp.wi_1_bias", "custom_module.up_proj_bias", ("model", "layer")),
        # MoE experts: down projection (wo).
        ("GptOssMlp.wo", "custom_module.mlp2_weight_EFD", ("model", "layer", None)),
        ("GptOssMlp.wo_bias", "custom_module.mlp2_bias_ED", ("model", "layer")),
    )

    # --- 3. Layer Mapping Loop ---
    layers_per_block = total_num_layers // layer_cycle_interval
    for block_idx in range(layer_cycle_interval):
      if interleave_style == "modulo":
        target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
      else:
        start = block_idx * layers_per_block
        target_indices = range(start, start + layers_per_block)

      src_block = f"base.decoder.layers.layers_{block_idx}"
      regex_indices = "|".join(map(str, target_indices))
      # Only the dot after "layers" is escaped; dots in the appended suffixes
      # are left as-is.
      layer_regex = rf"layers\.({regex_indices})"

      for src_suffix, tgt_suffix, axis_names in per_layer_params:
        mapping[f"{src_block}.{src_suffix}"] = (f"{layer_regex}.{tgt_suffix}", axis_names)

    return mapping