# Source code for maxtext.integration.vllm.torchax_converter.qwen3_moe

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Qwen3 MaxText to vLLM weight converters."""

import functools
import gc
import logging

import jax
import jax.numpy as jnp
from jaxtyping import PyTree

from maxtext.integration.vllm.torchax_converter.base import BaseMaxTextToVLLMConverter


class Qwen3MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter):
  """Qwen3-specific MaxText to vLLM converter.

  Each ``_convert_*`` method reads MaxText parameters from ``params["base"]``
  and writes the converted weights into ``self.vllm_state`` under
  ``vllm_model.*`` keys.

  NOTE(review): ``self.vllm_state``, ``self.num_layers``, ``self.vllm_tp`` and
  ``self.config`` are not defined in this class; they are presumably provided
  by ``BaseMaxTextToVLLMConverter`` -- confirm against the base class.
  """

  def _convert_global(self, params):
    """Convert the non-layer weights: token embedding, final norm and LM head."""
    logging.info("_convert_global: embed_tokens...")
    self._to_embed_tokens(params)
    logging.info("_convert_global: final_norm...")
    self._to_final_norm(params)
    logging.info("_convert_global: lm_head...")
    self._to_lm_head(params)
    logging.info("_convert_global: done")

  def _convert_attn(self, params):
    """Convert the per-layer norms and self-attention weights.

    Writes ``input_layernorm``, ``post_attention_layernorm`` and the
    ``self_attn.*`` entries for every layer into ``self.vllm_state``.

    Raises:
      AssertionError: if a layer-norm scale does not unstack into exactly
        ``self.num_layers`` layers.
    """
    logging.info("_convert_attn: pre_self_attention_layer_norm...")
    pre_ln = params["base"]["decoder"]["layers"]["pre_self_attention_layer_norm"]["scale"]
    self._store_layer_norm(pre_ln, "input_layernorm.weight")
    del pre_ln

    logging.info("_convert_attn: post_self_attention_layer_norm...")
    post_ln = params["base"]["decoder"]["layers"]["post_self_attention_layer_norm"]["scale"]
    self._store_layer_norm(post_ln, "post_attention_layernorm.weight")
    del post_ln

    logging.info("_convert_attn: self_attention (qkv/o/norms)...")
    attn = params["base"]["decoder"]["layers"]["self_attention"]
    self_attn = self._to_attn(attn)
    for key, per_layer in self_attn.items():
      self.vllm_state.update({f"vllm_model.model.layers.{i}.{key}": layer for i, layer in enumerate(per_layer)})
    del attn, self_attn
    logging.info("_convert_attn: done")
    gc.collect()

  def _store_layer_norm(self, stacked_scale, weight_name):
    """Unstack a stacked layer-norm scale and store one entry per layer.

    Args:
      stacked_scale: 2-D scale array; it is transposed with ``(1, 0)`` and then
        split along the new leading axis, one slice per layer.
      weight_name: key suffix placed after ``vllm_model.model.layers.{i}.``.

    Raises:
      AssertionError: if the number of slices differs from ``self.num_layers``.
    """
    per_layer = self._transpose_unstack(stacked_scale)
    # Raised explicitly (not via `assert`) so the check still runs under
    # `python -O`; the exception type and message match the former assert.
    if len(per_layer) != self.num_layers:
      raise AssertionError(f"Expected {self.num_layers} layers, got {len(per_layer)}")
    for i, layer in enumerate(per_layer):
      self.vllm_state[f"vllm_model.model.layers.{i}.{weight_name}"] = layer

  def _convert_moe(self, params):
    """Convert the MoE block: router gate, expert down (w2) and fused gate+up (w13).

    Entries of the extracted ``moe`` dict are deleted as soon as they have been
    converted, followed by ``gc.collect()``, so references are dropped early.
    """
    logging.info("_convert_moe: extracting moe_block...")
    moe = params["base"]["decoder"]["layers"]["moe_block"].to_pure_dict()
    prefix = "vllm_model.model.layers"

    logging.info("_convert_moe: gate weights...")
    self.vllm_state.update(
        {f"{prefix}.{i}.mlp.gate.weight": weight for i, weight in enumerate(self._to_mlp_gate(moe["gate"]["kernel"]))}
    )
    del moe["gate"]
    gc.collect()

    logging.info("_convert_moe: expert down (w2) weights...")
    self.vllm_state.update(
        {f"{prefix}.{i}.mlp.experts.w2_weight": weight for i, weight in enumerate(self._to_mlp_expert_down(moe["wo"]))}
    )
    del moe["wo"]
    gc.collect()

    logging.info("_convert_moe: expert gate+up (w13) weights (fuse_all jit+vmap)...")
    self._to_mlp_expert_gate_up(
        moe["wi_0"],
        moe["wi_1"],
        prefix,
        "mlp.experts.w13_weight",
    )
    del moe["wi_0"], moe["wi_1"], moe
    logging.info("_convert_moe: done")
    gc.collect()

  def _to_final_norm(self, params):
    """Store the decoder's final norm scale as ``vllm_model.model.norm.weight``."""
    self.vllm_state["vllm_model.model.norm.weight"] = jnp.array(params["base"]["decoder"]["decoder_norm"]["scale"])

  def _to_embed_tokens(self, params):
    """Store the token embedding table as ``vllm_model.model.embed_tokens.weight``."""
    self.vllm_state["vllm_model.model.embed_tokens.weight"] = jnp.array(params["base"]["token_embedder"]["embedding"])

  def _to_lm_head(self, params):
    """Store the transposed logits kernel as ``vllm_model.lm_head.weight``."""
    self.vllm_state["vllm_model.lm_head.weight"] = self._transpose_2d(params["base"]["decoder"]["logits_dense"]["kernel"])

  def _to_attn(self, attn: PyTree) -> dict[str, tuple[jax.Array, ...]]:
    """Convert MaxText attention parameters into per-layer vLLM weights.

    Args:
      attn: the ``self_attention`` parameter subtree.

    Returns:
      A dict keyed by vLLM weight suffix (e.g. ``"self_attn.qkv_proj.weight"``)
      whose values are per-layer sequences of arrays, one entry per layer.
    """
    # The packing TP size is capped at the configured KV head count.
    tp = min(self.vllm_tp, self.config.base_num_kv_heads)
    compute = self._make_attn_compute(tp)
    return compute(attn)

  @staticmethod
  @functools.lru_cache(maxsize=8)
  def _make_attn_compute(tp: int):
    """Build the cached JIT that packs QKV and output projections for a TP size.

    Args:
      tp: number of tensor-parallel groups the heads are split into.

    Returns:
      A jitted function mapping the attention subtree to a dict of per-layer
      weights. It raises ``ValueError`` at trace time if ``tp`` is not positive
      or does not divide both the query and KV head counts.
    """

    @jax.jit
    def _compute(attn):
      # Move the layer axis to the front; axes are then read as
      # [layers, d_model, heads, head_dim].
      q = jnp.transpose(attn["query"]["kernel"], (1, 0, 2, 3))
      k = jnp.transpose(attn["key"]["kernel"], (1, 0, 2, 3))
      v = jnp.transpose(attn["value"]["kernel"], (1, 0, 2, 3))

      num_q_heads = q.shape[2]
      num_kv_heads = k.shape[2]
      head_dim = q.shape[3]
      num_layers, d_model = q.shape[0], q.shape[1]

      # Shapes are static under jit, so this runs once at trace time and
      # replaces the opaque reshape / division error seen otherwise.
      if tp <= 0 or num_q_heads % tp or num_kv_heads % tp:
        raise ValueError(
            f"tp={tp} must be positive and divide num_q_heads={num_q_heads} and num_kv_heads={num_kv_heads}"
        )

      kv_per_tp = num_kv_heads // tp
      q_per_tp = num_q_heads // tp

      q_by_tp = q.reshape(num_layers, d_model, tp, q_per_tp, head_dim)
      k_by_tp = k.reshape(num_layers, d_model, tp, kv_per_tp, head_dim)
      v_by_tp = v.reshape(num_layers, d_model, tp, kv_per_tp, head_dim)

      # Within each TP group the heads are laid out as [q heads, k heads, v heads].
      qkv_by_tp = jnp.concatenate([q_by_tp, k_by_tp, v_by_tp], axis=3)
      qkv_flat = qkv_by_tp.reshape(num_layers, d_model, -1)
      qkv_proj = jnp.transpose(qkv_flat, (0, 2, 1))

      # Output projection: bring axes 1 and 3 to the front, then flatten the
      # remaining two axes into one.
      o = jnp.transpose(attn["out"]["kernel"], (1, 3, 0, 2))
      o_proj = o.reshape(o.shape[0], o.shape[1], -1)

      q_norm = jnp.transpose(attn["query_norm"]["scale"], (1, 0))
      k_norm = jnp.transpose(attn["key_norm"]["scale"], (1, 0))

      return {
          "self_attn.qkv_proj.weight": jnp.unstack(qkv_proj),
          "self_attn.o_proj.weight": jnp.unstack(o_proj),
          "self_attn.q_norm.weight": jnp.unstack(q_norm),
          "self_attn.k_norm.weight": jnp.unstack(k_norm),
      }

    return _compute

  def _to_mlp_gate(self, param):
    """Transpose the stacked router gate kernel and split it per layer."""
    param = self._transpose_gate(param)
    return self._unstack_layer(param)

  def _to_mlp_expert_down(self, param):
    """Reorder the stacked expert down-projection weights and split them per layer.

    The previous implementation applied ``(1, 0, 3, 2)`` followed by
    ``(0, 1, 3, 2)``. Those two permutations compose to ``(1, 0, 2, 3)``, so a
    single transpose yields the same values without materializing an
    intermediate copy of the expert weights.
    """
    # Swap only the two leading axes; the trailing two axes keep their order.
    # Presumably [experts, layers, ...] -> [layers, experts, ...] -- TODO confirm.
    param = jnp.transpose(param, (1, 0, 2, 3))
    return self._unstack_layer(param)

  def _to_mlp_expert_gate_up(self, wi_0, wi_1, layer_key_prefix, layer_key_suffix):
    """Fuse MoE gate and up projections into the vLLM `w13` layout.

    Args:
      wi_0: stacked gate-projection weights for all layers.
      wi_1: stacked up-projection weights for all layers.
      layer_key_prefix: key prefix placed before the layer index.
      layer_key_suffix: key suffix placed after the layer index.

    Each fused layer is stored in ``self.vllm_state`` under
    ``{layer_key_prefix}.{i}.{layer_key_suffix}`` after swapping its last two axes.
    """
    fuse_all = self._make_fuse_all(self.vllm_tp)
    logging.info("_to_mlp_expert_gate_up: dispatching _fuse_all (single JIT+vmap)...")
    fused = fuse_all(wi_0, wi_1)
    logging.info(
        "_to_mlp_expert_gate_up: _fuse_all complete, shape=%s, unstacking layers...",
        fused.shape,
    )
    del wi_0, wi_1
    gc.collect()

    for i, layer_i in enumerate(jnp.unstack(fused, axis=0)):
      layer_i = jnp.transpose(layer_i, (0, 2, 1))
      self.vllm_state[f"{layer_key_prefix}.{i}.{layer_key_suffix}"] = layer_i
      # Collect garbage after every eighth layer.
      if i % 8 == 7:
        gc.collect()
    del fused
    gc.collect()

  @staticmethod
  @functools.lru_cache(maxsize=8)
  def _make_fuse_all(tp: int):
    """Build the cached JIT that fuses all expert gate and up weights.

    Args:
      tp: number of tensor-parallel chunks ``d_inner`` is split into.

    Returns:
      A jitted function ``(wi_0, wi_1) -> fused`` whose result has one leading
      entry per layer. It raises ``ValueError`` at trace time if ``tp`` is not
      positive or does not divide ``d_inner``.
    """

    @jax.jit
    def _fuse_all(wi_0, wi_1):
      # Bring the layer axis to the front so vmap maps over layers.
      wi_0 = jnp.transpose(wi_0, (1, 0, 2, 3))
      wi_1 = jnp.transpose(wi_1, (1, 0, 2, 3))

      def _fuse_single(w0, w1):
        # [e, d_model, d_inner] -> [e, 2*padded_chunk_size*tp, d_model]
        w0 = jnp.transpose(w0, (0, 2, 1))
        w1 = jnp.transpose(w1, (0, 2, 1))

        num_experts, d_inner, d_model = w0.shape
        # Static-shape check at trace time; replaces the opaque reshape /
        # division error seen otherwise.
        if tp <= 0 or d_inner % tp:
          raise ValueError(f"tp={tp} must be positive and divide d_inner={d_inner}")
        chunk_size = d_inner // tp

        # Pad each TP chunk to the next multiple of 128 for TPU GMM alignment,
        # matching process_w13_for_gmm in tpu_inference.
        # Example: d_inner=768 gives chunk=192 -> 256 with tp=4, or 96 -> 128 with tp=8.
        padded_chunk_size = ((chunk_size + 127) // 128) * 128
        pad_amount = padded_chunk_size - chunk_size

        gate_chunks = w0.reshape(num_experts, tp, chunk_size, d_model)
        up_chunks = w1.reshape(num_experts, tp, chunk_size, d_model)

        if pad_amount > 0:
          gate_chunks = jnp.pad(gate_chunks, ((0, 0), (0, 0), (0, pad_amount), (0, 0)))
          up_chunks = jnp.pad(up_chunks, ((0, 0), (0, 0), (0, pad_amount), (0, 0)))

        # Per TP chunk the layout is [gate rows, up rows].
        combined = jnp.stack([gate_chunks, up_chunks], axis=2)
        return combined.reshape(num_experts, 2 * padded_chunk_size * tp, d_model)

      return jax.vmap(_fuse_single)(wi_0, wi_1)

    return _fuse_all

  @staticmethod
  @jax.jit
  def _unstack_layer(param):
    """Split a stacked layer tensor into a tuple of per-layer tensors."""
    return jnp.unstack(param, axis=0)

  @staticmethod
  @jax.jit
  def _transpose_unstack(x):
    """Transpose a 2-D array and split it along the new leading axis."""
    return jnp.unstack(jnp.transpose(x, (1, 0)))

  @staticmethod
  @jax.jit
  def _transpose_2d(x):
    """Return the transpose of a 2-D array."""
    return jnp.transpose(x, (1, 0))

  @staticmethod
  @jax.jit
  def _transpose_gate(param):
    """Permute a 3-D gate kernel's axes with ``(1, 2, 0)``."""
    return jnp.transpose(param, (1, 2, 0))

  @staticmethod
  @jax.jit
  def _transpose_expert_down(param):
    """Permute a 4-D expert tensor's axes with ``(1, 0, 3, 2)``.

    No longer called by ``_to_mlp_expert_down``; kept so any other caller of
    this helper keeps working.
    """
    return jnp.transpose(param, (1, 0, 3, 2))