maxtext.integration.vllm.torchax_converter.gemma4_moe module
Gemma4 MaxText to vLLM weight converter.
Supports gemma4-26b (MoE: 128 routed + 1 shared expert).
- MaxText Gemma4 stores layers in a scanned-block structure:
state['base']['decoder']['scanned_blocks']['layers_{slot}']
where slot ∈ [0..5]. Slots 0–4 are local sliding-window attention layers and slot 5 is a global attention layer. The 'L' dimension (axis 1 of each weight tensor) holds num_reps = num_layers // 6 repetitions of each slot. The final vLLM layer index is rep * 6 + slot.
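The slot-to-layer arithmetic described above can be sketched as follows. This is only an illustration of the index formula, not the converter's actual code: the helper names (vllm_layer_index, is_global_attention, iter_layers) are hypothetical, and the layer count used in the usage note is an arbitrary example value.

```python
NUM_SLOTS = 6  # slots 0-4: local sliding-window attention, slot 5: global attention


def vllm_layer_index(rep: int, slot: int) -> int:
    """Map a (repetition, slot) pair to the flat vLLM layer index."""
    if not 0 <= slot < NUM_SLOTS:
        raise ValueError(f"slot must be in [0, {NUM_SLOTS - 1}], got {slot}")
    return rep * NUM_SLOTS + slot


def is_global_attention(slot: int) -> bool:
    """Only the last slot of each 6-layer block uses global attention."""
    return slot == NUM_SLOTS - 1


def iter_layers(num_layers: int):
    """Yield (layer_index, rep, slot) for every layer of a scanned checkpoint.

    rep indexes the 'L' dimension (axis 1) of the tensors stored under
    layers_{slot}; num_layers is assumed to be a multiple of NUM_SLOTS.
    """
    num_reps = num_layers // NUM_SLOTS
    for rep in range(num_reps):
        for slot in range(NUM_SLOTS):
            yield vllm_layer_index(rep, slot), rep, slot
```

For a hypothetical 12-layer model, iter_layers(12) visits layer indices 0 through 11 in order, and rep 1, slot 5 maps to layer 11.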
Global attention (slot 5) uses a shared KV projection: 'key' serves as both K and V, and there is no separate 'value' tensor.
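One way to picture the shared KV rule is a small selector that returns the source tensors for K and V per slot. This is a sketch under assumptions: the function name select_kv and the flat dict layout are hypothetical, and whether the vLLM side needs the key weights duplicated into a separate value projection is not stated here.

```python
GLOBAL_SLOT = 5  # the global attention slot described above


def select_kv(slot_weights: dict, slot: int):
    """Return the (K, V) source tensors for one slot.

    slot_weights is a hypothetical dict holding the 'key' tensor and,
    for local-attention slots only, a 'value' tensor.
    """
    key = slot_weights["key"]
    if slot == GLOBAL_SLOT:
        # Shared KV projection: the same tensor is used for both K and V.
        return key, key
    return key, slot_weights["value"]
```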
Key names and tensor transformations are derived from the MaxText HF param mapping at src/maxtext/checkpoint_conversion/utils/param_mapping.py.
Attention: Gemma4 uses SEPARATE q/k/v proj weights (not fused QKV).
MoE (26B): gate+up proj are fused into experts.gate_up_proj (E, 2*d_inner, d_model).
Embedding: MaxText stores embedding * sqrt(d_model); divide out before writing to vLLM.
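The embedding and MoE transformations can be illustrated with plain nested lists. This is a minimal sketch, not the converter's implementation: the function names are hypothetical, the inputs are assumed to be already oriented as (E, d_inner, d_model), and placing gate before up in the fused tensor is an assumption suggested by the name gate_up_proj rather than something this page confirms.

```python
import math


def unscale_embedding(rows, d_model):
    """Divide every element by sqrt(d_model) to undo the stored scaling."""
    scale = math.sqrt(d_model)
    return [[x / scale for x in row] for row in rows]


def fuse_gate_up(gate, up):
    """Concatenate per-expert gate and up matrices along the d_inner axis.

    gate, up: nested lists shaped (E, d_inner, d_model).
    Returns a nested list shaped (E, 2 * d_inner, d_model), with the gate
    rows first (assumed ordering).
    """
    if len(gate) != len(up):
        raise ValueError("gate and up must have the same number of experts")
    return [g + u for g, u in zip(gate, up)]
```

With E = 2, d_inner = 1 and d_model = 2, fusing gate [[[1, 2]], [[3, 4]]] with up [[[5, 6]], [[7, 8]]] gives [[[1, 2], [5, 6]], [[3, 4], [7, 8]]], which has shape (2, 2, 2).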
- class maxtext.integration.vllm.torchax_converter.gemma4_moe.Gemma4MaxTextToVLLMConverter(config, mesh)
Bases: BaseMaxTextToVLLMConverter
Converts MaxText Gemma4 weights to the layout expected by a vLLM Gemma4 model.
- NUM_SLOTS = 6