Source code for maxtext.integration.vllm.torchax_converter.gemma4_moe

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gemma4 MaxText to vLLM weight converter.

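Supports gemma4-26b (MoE: 128 routed + 1 shared expert).  Any other model name
takes the dense-MLP conversion path, which the code comments label as gemma4-31b.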

MaxText Gemma4 stores layers in a scanned-block structure:
  state['base']['decoder']['scanned_blocks']['layers_{slot}']
where slot ∈ [0..5].  Slots 0–4 are local-sliding-window attention layers
and slot 5 is a global attention layer.  The 'L' dimension (axis 1 of each
weight tensor) holds 'num_reps = num_layers // 6' repetitions of each slot.
Final vLLM layer index = rep * 6 + slot.

Global attention (slot 5) uses a shared KV projection — 'key' serves as
both K and V; there is no separate 'value' tensor.

Key names and tensor transformations are derived from the MaxText HF param mapping
at src/maxtext/checkpoint_conversion/utils/param_mapping.py.

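Attention: the source checkpoint keeps SEPARATE q/k/v projection kernels (not fused QKV);
this converter concatenates them into one TP-interleaved self_attn.qkv_proj.weight per layer.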
MoE (26B): gate+up proj are fused into experts.gate_up_proj (E, 2*d_inner, d_model).
Embedding: MaxText stores embedding * sqrt(d_model); divide out before writing to vLLM.
"""

import gc
import logging

import jax
import jax.numpy as jnp
from tpu_inference.layers.common.moe import MoEBackend
from tpu_inference.layers.common.process_weights.moe_weights import FusedMoEWeights
from tpu_inference.layers.common.process_weights.moe_weights import process_moe_weights

from maxtext.integration.vllm.torchax_converter.base import BaseMaxTextToVLLMConverter
from maxtext.integration.vllm.torchax_converter.base import timer
from maxtext.integration.vllm.torchax_converter.base import GREEN
from maxtext.integration.vllm.torchax_converter.base import RESET


class Gemma4MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter):
  """Converts MaxText Gemma4 weights to the layout expected by a vLLM Gemma4 model.

  Layers are read from six scanned "slots" (``layers_0`` .. ``layers_5``); every
  slot tensor carries ``num_reps`` stacked repetitions on axis 1.  The vLLM layer
  index of repetition ``rep`` in slot ``slot`` is ``rep * NUM_SLOTS + slot``.
  The last slot is treated as the global-attention layer.

  Relies on attributes that are not assigned in this class (``num_layers``,
  ``vllm_tp``, ``config``) -- presumably set by the base-class constructor.
  """

  NUM_SLOTS = 6  # 5 local + 1 global

  # Key prefix shared by every per-layer entry written to ``self.vllm_state``.
  _LAYER_PREFIX = "vllm_model.language_model.model.layers"

  def __init__(self, config, mesh):
    """Initialise the converter and derive the scan-repetition count.

    Args:
      config: Model config; ``model_name`` and ``base_emb_dim`` are read here.
      mesh: Forwarded unchanged to the base-class constructor.

    Raises:
      AssertionError: If ``num_layers`` is not a multiple of ``NUM_SLOTS``.
    """
    super().__init__(config, mesh)
    # Raised explicitly (rather than via ``assert``) so the check survives
    # ``python -O``; otherwise ``num_reps`` would silently floor and the
    # trailing layers would never be converted.  The exception type and
    # message are unchanged.
    if self.num_layers % self.NUM_SLOTS != 0:
      raise AssertionError(f"num_layers {self.num_layers} must be divisible by {self.NUM_SLOTS}")
    self.num_reps = self.num_layers // self.NUM_SLOTS
    self.is_moe = config.model_name == "gemma4-26b"
    self.d_model = config.base_emb_dim

  # --- 1. Top-Level Entry Point ---
  def convert(self, model_state: dict):
    """Convert a MaxText Gemma4 model state into vLLM weight tensors.

    Args:
      model_state: Nested dict; reads ``model_state["base"]["decoder"]["scanned_blocks"]``
        plus the global entries consumed by ``_convert_global``.

    Returns:
      ``self.vllm_state``: a fresh dict mapping vLLM parameter names to arrays.
    """
    logging.info(
        "\n%sStarting Gemma4 Conversion (is_moe=%s, num_layers=%d, num_reps=%d)...%s",
        GREEN,
        self.is_moe,
        self.num_layers,
        self.num_reps,
        RESET,
    )
    self.vllm_state = {}
    blocks = model_state["base"]["decoder"]["scanned_blocks"]
    prefix = self._LAYER_PREFIX

    with timer("Convert Global Weights"):
      self._convert_global(model_state)

    with timer("Convert Layer Norms"):
      self._convert_norms(blocks, prefix)

    with timer("Convert Attention Weights"):
      self._convert_attn_weights(blocks, prefix)

    if self.is_moe:
      with timer("Convert MoE Weights"):
        self._convert_moe_weights(blocks, prefix)
    else:
      with timer("Convert Dense MLP Weights"):
        self._convert_dense_mlp_weights(blocks, prefix)

    return self.vllm_state

  # --- Abstract method implementations (delegate to Gemma4-specific methods) ---
  def _convert_global(self, params):
    """Convert non-layered weights (embed_tokens, lm_head, final norm).

    Writes three entries into ``self.vllm_state``; ``lm_head.weight`` is the
    same array object as ``embed_tokens.weight`` (tied).
    """
    # Gemma4 uses tied embeddings: no logits_dense; lm_head.weight = embed_tokens.weight.
    # MaxText stores embedding pre-multiplied by sqrt(hidden_size) (applied during HF->MaxText
    # conversion in param_mapping.py). vLLM/tpu-inference apply sqrt(hidden_size) at runtime,
    # so divide out the pre-multiplied factor to give vLLM the raw embedding.
    logging.info("_convert_global: embed_tokens (de-normalize) + lm_head (tied) + final_norm...")
    normalizer = self.d_model**0.5

    @jax.jit
    def _denorm_embed(x):
      # Cast back so the result keeps the dtype of the stored embedding.
      return (x / normalizer).astype(x.dtype)

    raw_embedding = _denorm_embed(params["base"]["token_embedder"]["embedding"])
    self.vllm_state["vllm_model.language_model.model.embed_tokens.weight"] = raw_embedding
    self.vllm_state["vllm_model.language_model.lm_head.weight"] = raw_embedding  # tied
    self.vllm_state["vllm_model.language_model.model.norm.weight"] = params["base"]["decoder"]["decoder_norm"]["scale"]
    logging.info("_convert_global: done")

  def _convert_attn(self, params):
    """Satisfy abstract interface; Gemma4 uses _convert_attn_weights instead."""
    blocks = params["base"]["decoder"]["scanned_blocks"]
    self._convert_attn_weights(blocks, self._LAYER_PREFIX)

  def _convert_moe(self, params):
    """Satisfy abstract interface; Gemma4 uses _convert_moe_weights/_convert_dense_mlp_weights."""
    blocks = params["base"]["decoder"]["scanned_blocks"]
    if self.is_moe:
      self._convert_moe_weights(blocks, self._LAYER_PREFIX)
    else:
      self._convert_dense_mlp_weights(blocks, self._LAYER_PREFIX)

  # --- 2. Static helpers ---
  @staticmethod
  def _tp_interleave(parts, tp):
    """Fuse 2-D weights row-wise so that each TP shard's rows are contiguous.

    Every array in ``parts`` is split into ``tp`` equal row groups and the
    groups are interleaved: ``[a_tp0, b_tp0, ..., a_tp1, b_tp1, ...]``.

    Args:
      parts: Sequence of 2-D arrays sharing the same number of columns; each
        row count must be divisible by ``tp`` (otherwise the reshape fails).
      tp: Number of tensor-parallel shards.

    Returns:
      A 2-D array of shape ``(sum(rows), cols)``.
    """
    num_cols = parts[0].shape[1]
    sharded = [p.reshape(tp, p.shape[0] // tp, p.shape[1]) for p in parts]
    return jnp.concatenate(sharded, axis=1).reshape(-1, num_cols)

  @staticmethod
  @jax.jit
  def _pack_attn(q, k, v, o, qnorm, knorm):
    """Prepares separate q/k/v, o, and norms for all layers in a slot.

    Input shapes (MaxText scanned, scan axis at index 1):
      q/k/v: (d_model, L, nH, D)
      o:     (nH, L, D, d_model)   # scan axis is 1
      norms: (D, L)

    Returns:
      Six tuples of length L: (nH*D, d_model) arrays for q/k/v,
      (d_model, nH*D) arrays for o, and (D,) arrays for each norm.
    """
    # q/k/v: (d_model, L, nH, D) -> (L, nH, D, d_model) -> (L, nH*D, d_model)
    q = jnp.transpose(q, (1, 2, 3, 0)).reshape(q.shape[1], -1, q.shape[0])
    k = jnp.transpose(k, (1, 2, 3, 0)).reshape(k.shape[1], -1, k.shape[0])
    v = jnp.transpose(v, (1, 2, 3, 0)).reshape(v.shape[1], -1, v.shape[0])

    # o: (nH, L, D, d_model) -> (L, d_model, nH, D) -> (L, d_model, nH*D)
    o = jnp.transpose(o, (1, 3, 0, 2)).reshape(o.shape[1], o.shape[3], -1)

    # norms: (D, L) -> (L, D)
    qnorm = jnp.transpose(qnorm, (1, 0))
    knorm = jnp.transpose(knorm, (1, 0))

    return (
        jnp.unstack(q),
        jnp.unstack(k),
        jnp.unstack(v),
        jnp.unstack(o),
        jnp.unstack(qnorm),
        jnp.unstack(knorm),
    )

  # --- 3. Per-layer norms ---
  def _convert_norms(self, blocks, prefix):
    """Converts all 4 per-layer norm vectors across all layers.

    Args:
      blocks: The ``scanned_blocks`` dict holding ``layers_{slot}`` entries.
      prefix: Key prefix for the per-layer vLLM parameter names.
    """

    @jax.jit
    def _unstack_norm(x):
      # x: (d_model, L) -> L tensors of (d_model,)
      return jnp.unstack(x, axis=1)

    for slot in range(self.NUM_SLOTS):
      slot_data = blocks[f"layers_{slot}"]
      pre_attn = _unstack_norm(slot_data["pre_self_attention_norm"]["scale"])
      post_attn = _unstack_norm(slot_data["post_self_attention_norm"]["scale"])
      pre_ffw = _unstack_norm(slot_data["pre_ffw_norm"]["scale"])
      post_ffw = _unstack_norm(slot_data["post_ffw_norm"]["scale"])

      for rep in range(self.num_reps):
        i = rep * self.NUM_SLOTS + slot
        self.vllm_state[f"{prefix}.{i}.input_layernorm.weight"] = pre_attn[rep]
        self.vllm_state[f"{prefix}.{i}.post_attention_layernorm.weight"] = post_attn[rep]
        self.vllm_state[f"{prefix}.{i}.pre_feedforward_layernorm.weight"] = pre_ffw[rep]
        self.vllm_state[f"{prefix}.{i}.post_feedforward_layernorm.weight"] = post_ffw[rep]

      # Drop this slot's temporaries before the next slot is unpacked.
      del pre_attn, post_attn, pre_ffw, post_ffw
      gc.collect()

  # --- 4. Per-layer attention weights ---
  def _convert_attn_weights(self, blocks, prefix):
    """Converts q/k/v proj (fused into qkv_proj), o proj, q-norm, k-norm for all layers.

    The source stores separate query/key/value kernels; they are written here
    as a single TP-interleaved ``self_attn.qkv_proj.weight`` per layer.
    Global attention layers (the last slot) have no 'value' tensor; the key
    kernel is reused as the value projection.

    Tensor transformations (per layer, after removing the scan axis):
      q/k/v kernel: (d_model, nH, D) → (nH*D, d_model)   [transpose then reshape]
      out kernel:   (nH, D, d_model) → (d_model, nH*D)   [transpose then reshape]
      norms:        (D,) → (D,)                          [identity]

    Args:
      blocks: The ``scanned_blocks`` dict holding ``layers_{slot}`` entries.
      prefix: Key prefix for the per-layer vLLM parameter names.
    """

    @jax.jit
    def _pack_local(attn):
      q = attn["query"]["kernel"]
      k = attn["key"]["kernel"]
      v = attn["value"]["kernel"]
      return Gemma4MaxTextToVLLMConverter._pack_attn(
          q,
          k,
          v,
          attn["out"]["kernel"],
          attn["query_norm"]["scale"],
          attn["key_norm"]["scale"],
      )

    @jax.jit
    def _pack_global(attn):
      # Global: no 'value'; key used as both K and V (shared KV projection).
      q = attn["query"]["kernel"]
      k = attn["key"]["kernel"]
      return Gemma4MaxTextToVLLMConverter._pack_attn(
          q,
          k,
          k,
          attn["out"]["kernel"],
          attn["query_norm"]["scale"],
          attn["key_norm"]["scale"],
      )

    for slot in range(self.NUM_SLOTS):
      is_global = slot == self.NUM_SLOTS - 1
      attn = blocks[f"layers_{slot}"]["self_attention"]

      pack_fn = _pack_global if is_global else _pack_local
      q_layers, k_layers, v_layers, o_layers, qnorm_layers, knorm_layers = pack_fn(attn)

      num_kv_heads = self.config.global_num_kv_heads if is_global else self.config.base_num_kv_heads
      # Never split into more shards than there are KV heads.
      tp = min(self.vllm_tp, num_kv_heads)

      for rep in range(self.num_reps):
        i = rep * self.NUM_SLOTS + slot
        # QKVParallelLinear (vLLM) expects TP-interleaved layout:
        #   [q_tp0, k_tp0, v_tp0, q_tp1, k_tp1, v_tp1, ...]
        qkv = self._tp_interleave((q_layers[rep], k_layers[rep], v_layers[rep]), tp)
        self.vllm_state[f"{prefix}.{i}.self_attn.qkv_proj.weight"] = qkv
        self.vllm_state[f"{prefix}.{i}.self_attn.o_proj.weight"] = o_layers[rep]
        self.vllm_state[f"{prefix}.{i}.self_attn.q_norm.weight"] = qnorm_layers[rep]
        self.vllm_state[f"{prefix}.{i}.self_attn.k_norm.weight"] = knorm_layers[rep]

      # Drop this slot's temporaries before the next slot is unpacked.
      del q_layers, k_layers, v_layers, o_layers, qnorm_layers, knorm_layers
      gc.collect()

  # --- 5a. MoE weights (gemma4-26b only) ---
  def _convert_moe_weights(self, blocks, prefix):
    """Converts router, routed experts (fused gate+up), shared expert, MoE norms (26B).

    Tensor transformations (scanned source → per-layer value):
      router.proj.weight:      gate.kernel (d_model, L, E) → (E, d_model)
      router.scale:            pre_forward_scale_2 (d_model, L) → (d_model,)
      moe.per_expert_scale:    per_expert_scale (E, L) → (E,)
      moe.experts.w13_weight:  wi_0+wi_1 (E, L, d_model, d_inner) fused to
                               (E, 2*d_inner, d_model), then process_moe_weights
      moe.experts.w2_weight:   wo (E, L, d_inner, d_model) → (E, d_model, d_inner),
                               then process_moe_weights
      mlp.gate_up_proj.weight: shared wi_0/wi_1 (d_model, L, d_sh) → (d_sh, d_model)
                               each, fused TP-interleaved
      mlp.down_proj.weight:    shared wo (d_sh, L, d_model) → (d_model, d_sh)
      extra norms:             (d_model, L) → (d_model,)

    Args:
      blocks: The ``scanned_blocks`` dict holding ``layers_{slot}`` entries.
      prefix: Key prefix for the per-layer vLLM parameter names.
    """

    # NOTE(review): unlike the other pack helpers in this class, this one is
    # not wrapped in jax.jit -- confirm whether that is intentional.
    def _pack_moe(routed, shared, extra):
      # Router proj: (d_model, L, E) -> L × (E, d_model)
      router_proj = jnp.unstack(jnp.transpose(routed["gate"]["kernel"], (1, 2, 0)), axis=0)
      # Router scale: (d_model, L) -> L × (d_model,)
      router_scale = jnp.unstack(extra["pre_forward_scale_2"], axis=1)
      # Per-expert scale: (E, L) -> L × (E,)
      per_expert_scale = jnp.unstack(routed["per_expert_scale"], axis=1)

      # Fused gate+up proj for routed experts (HF format):
      #   wi_0 (gate): (E, L, d_model, d_inner) -> (L, E, d_inner, d_model)
      #   wi_1 (up):   (E, L, d_model, d_inner) -> (L, E, d_inner, d_model)
      #   concat along axis 2: (L, E, 2*d_inner, d_model) = gate_up_proj
      w0 = jnp.transpose(routed["wi_0"], (1, 0, 3, 2))  # (L, E, d_inner, d_model)
      w1 = jnp.transpose(routed["wi_1"], (1, 0, 3, 2))  # (L, E, d_inner, d_model)
      gate_up = jnp.concatenate([w0, w1], axis=2)  # (L, E, 2*d_inner, d_model)
      gate_up_proj = jnp.unstack(gate_up, axis=0)

      # Down proj: (E, L, d_inner, d_model) -> L × (E, d_model, d_inner)
      down_proj = jnp.unstack(jnp.transpose(routed["wo"], (1, 0, 3, 2)), axis=0)

      # Shared expert:
      #   wi_0/wi_1: (d_model, L, d_sh) -> L × (d_sh, d_model)
      #   wo:        (d_sh, L, d_model) -> L × (d_model, d_sh)
      sh_gate = jnp.unstack(jnp.transpose(shared["wi_0"]["kernel"], (1, 2, 0)), axis=0)
      sh_up = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0)
      sh_down = jnp.unstack(jnp.transpose(shared["wo"]["kernel"], (1, 2, 0)), axis=0)

      # Extra MoE norms: (d_model, L) -> L × (d_model,)
      pre_ln_2 = jnp.unstack(extra["pre_feedforward_layernorm_2"]["scale"], axis=1)
      post_ln_1 = jnp.unstack(extra["post_feedforward_layernorm_1"]["scale"], axis=1)
      post_ln_2 = jnp.unstack(extra["post_feedforward_layernorm_2"]["scale"], axis=1)

      return (
          router_proj,
          router_scale,
          per_expert_scale,
          gate_up_proj,
          down_proj,
          sh_gate,
          sh_up,
          sh_down,
          pre_ln_2,
          post_ln_1,
          post_ln_2,
      )

    for slot in range(self.NUM_SLOTS):
      moe_block = blocks[f"layers_{slot}"]["mlp"]["moe_block"]
      routed = moe_block["MoeBlock_0"]
      shared = moe_block["shared_experts"]
      extra = blocks[f"layers_{slot}"]["mlp"]

      (
          router_proj,
          router_scale,
          per_expert_scale,
          gate_up_proj,
          down_proj,
          sh_gate,
          sh_up,
          sh_down,
          pre_ln_2,
          post_ln_1,
          post_ln_2,
      ) = _pack_moe(routed, shared, extra)

      for rep in range(self.num_reps):
        i = rep * self.NUM_SLOTS + slot
        p = f"{prefix}.{i}"

        # Router
        self.vllm_state[f"{p}.router.proj.weight"] = router_proj[rep]
        self.vllm_state[f"{p}.router.scale"] = router_scale[rep]
        self.vllm_state[f"{p}.moe.per_expert_scale"] = per_expert_scale[rep]

        # Routed experts: apply process_moe_weights (GMM_TP: swapaxes + pad + TP reorder)
        # to produce the post-processed format that llm_state holds after model init.
        processed = process_moe_weights(
            FusedMoEWeights(
                w13_weight=gate_up_proj[rep],
                w13_weight_scale=None,
                w13_bias=None,
                w2_weight=down_proj[rep],
                w2_weight_scale=None,
                w2_bias=None,
            ),
            moe_backend=MoEBackend.GMM_TP,
            w13_reorder_size=self.vllm_tp,
            w13_interleave=False,  # Gemma4 uses gelu, not swiglu
        )
        self.vllm_state[f"{p}.moe.experts.w13_weight"] = processed.w13_weight
        self.vllm_state[f"{p}.moe.experts.w2_weight"] = processed.w2_weight

        # Shared expert: gate+up fused, TP-interleaved (MergedColumnParallelLinear,
        # spec=P('model', None)): [gate_tp0, up_tp0, gate_tp1, up_tp1, ...]
        # sh_gate[rep] / sh_up[rep] are each (d_sh, d_model).
        shared_gate_up = self._tp_interleave((sh_gate[rep], sh_up[rep]), self.vllm_tp)
        self.vllm_state[f"{p}.mlp.gate_up_proj.weight"] = shared_gate_up
        self.vllm_state[f"{p}.mlp.down_proj.weight"] = sh_down[rep]

        # Extra MoE norms
        self.vllm_state[f"{p}.pre_feedforward_layernorm_2.weight"] = pre_ln_2[rep]
        self.vllm_state[f"{p}.post_feedforward_layernorm_1.weight"] = post_ln_1[rep]
        self.vllm_state[f"{p}.post_feedforward_layernorm_2.weight"] = post_ln_2[rep]

      # Drop this slot's temporaries before the next slot is unpacked.
      del router_proj, router_scale, per_expert_scale, gate_up_proj, down_proj
      del sh_gate, sh_up, sh_down, pre_ln_2, post_ln_1, post_ln_2
      gc.collect()

  # --- 5b. Dense MLP weights (gemma4-31b only) ---
  def _convert_dense_mlp_weights(self, blocks, prefix):
    """Converts gate/up/down projections for all layers (31B only).

    Tensor transformations:
      wi_0 (gate): (d_model, L, d_mlp) → L × (d_mlp, d_model)
      wi_1 (up):   (d_model, L, d_mlp) → L × (d_mlp, d_model)
      wo (down):   (d_mlp, L, d_model) → L × (d_model, d_mlp)

    Args:
      blocks: The ``scanned_blocks`` dict holding ``layers_{slot}`` entries.
      prefix: Key prefix for the per-layer vLLM parameter names.
    """

    @jax.jit
    def _pack_mlp(mlp):
      gate = jnp.unstack(jnp.transpose(mlp["wi_0"]["kernel"], (1, 2, 0)), axis=0)
      up = jnp.unstack(jnp.transpose(mlp["wi_1"]["kernel"], (1, 2, 0)), axis=0)
      down = jnp.unstack(jnp.transpose(mlp["wo"]["kernel"], (1, 2, 0)), axis=0)
      return gate, up, down

    for slot in range(self.NUM_SLOTS):
      mlp = blocks[f"layers_{slot}"]["mlp"]
      gate_layers, up_layers, down_layers = _pack_mlp(mlp)

      for rep in range(self.num_reps):
        i = rep * self.NUM_SLOTS + slot
        p = f"{prefix}.{i}"
        self.vllm_state[f"{p}.mlp.gate_proj.weight"] = gate_layers[rep]
        self.vllm_state[f"{p}.mlp.up_proj.weight"] = up_layers[rep]
        self.vllm_state[f"{p}.mlp.down_proj.weight"] = down_layers[rep]

      # Drop this slot's temporaries before the next slot is unpacked.
      del gate_layers, up_layers, down_layers
      gc.collect()