maxtext.integration.tunix.weight_mapping.gpt_oss module

Mapping MaxText GPT-OSS (MoE) weights to vLLM/tpu-inference keys.

class maxtext.integration.tunix.weight_mapping.gpt_oss.GPT_OSS_VLLM_MAPPING

Bases: object

Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.

Supports:
  • Modulo interleaving (e.g., Block 0 -> Layers 0, 2, 4…)
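
The modulo interleaving rule above can be illustrated with a small sketch. It assumes that, for a cycle interval k, scanned block b holds global layers b, b + k, b + 2k, and so on; the helper name `layers_for_block` is hypothetical and is not part of MaxText.

```python
def layers_for_block(block_idx: int, layer_cycle_interval: int, total_num_layers: int) -> list[int]:
    """Global layer indices held by one scanned block (assumed modulo rule)."""
    if not 0 <= block_idx < layer_cycle_interval:
        raise ValueError("block_idx must be in [0, layer_cycle_interval)")
    # Block b owns every layer whose index is congruent to b modulo the interval.
    return list(range(block_idx, total_num_layers, layer_cycle_interval))


# With the documented defaults (interval 2, 36 layers), block 0 starts at layers 0, 2, 4.
first_three = layers_for_block(0, 2, 36)[:3]
```

Under this assumption, each of the two blocks would cover 18 of the 36 layers, with block 1 holding the odd-numbered ones.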

static lora_to_hf_mappings()

Provides the mapping for LoRA (Low-Rank Adaptation) weights.

Returns:

None, as LoRA mappings are not defined for this model.

static to_hf_hook_fns()

Returns hook functions to fuse interleaved weights.

static to_hf_transpose_keys()

Returns keys that need to be transposed.

static to_hf_mapping(layer_cycle_interval=2, total_num_layers=36, interleave_style='modulo')

Returns the weight mapping for the model.

Parameters:
  • layer_cycle_interval (int): The interval at which layers are cycled.

  • total_num_layers (int): The total number of layers in the model.

  • interleave_style (str): The style of interleaving used for the layers.

Returns:

A dictionary mapping MaxText parameter names to vLLM parameter names.

Return type:

Dict[str, Tuple[str, Tuple[str | None, …]]]
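
A minimal sketch of how a dictionary with this return type might be consumed. The entries below are made-up placeholders, not the keys actually returned by `to_hf_mapping`, and the helper `rename_keys` is hypothetical. Reading the inner tuple as per-axis annotations is an assumption based only on the type signature.

```python
from typing import Dict, Optional, Tuple

# Alias for the documented return type.
WeightMapping = Dict[str, Tuple[str, Tuple[Optional[str], ...]]]

# Hypothetical entries shaped like the documented return type.
example_mapping: WeightMapping = {
    "hypothetical.maxtext.embedding": ("hypothetical.vllm.embedding", ("model", None)),
    "hypothetical.maxtext.norm.scale": ("hypothetical.vllm.norm.scale", (None,)),
}


def rename_keys(mapping: WeightMapping) -> Dict[str, str]:
    """Keep only the source-name -> target-name part of each entry."""
    return {source: target for source, (target, _axes) in mapping.items()}


renamed = rename_keys(example_mapping)
```

In real use the dictionary would come from `GPT_OSS_VLLM_MAPPING.to_hf_mapping(...)` rather than being written by hand.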