# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""vLLM adapter for MaxText models."""
import os
import jax
from flax import nnx
import flax.linen as nn
from jax import numpy as jnp
from jax.experimental.pallas import tpu as pltpu
from jax.sharding import Mesh
from maxtext.configs import pyconfig
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE
from maxtext.utils import max_logging
from maxtext.utils import model_creation_utils
try:
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
except ImportError:
  # Mock for documentation builds or environments without tpu_inference, so
  # that this module can still be imported and its annotations resolved.
  class AttentionMetadata:
    """Stand-in type declaring the `input_positions` field accessed by this module."""

    input_positions: jax.Array
from vllm.config import VllmConfig
def next_power_of_two(x: int) -> int:
  """Finds the smallest power of 2 >= x using bit manipulation.

  Args:
    x: The input number. Must be a positive integer.

  Returns:
    The smallest integer power of 2 that is >= x.

  Raises:
    ValueError: If `x` is not positive.
  """
  # Validate explicitly rather than with `assert`: assertions are removed under
  # `python -O`, in which case a non-positive input would silently yield a
  # meaningless result (e.g. 2 for x == 0).
  if x <= 0:
    raise ValueError(f"x must be a positive integer, got {x}.")
  # (x - 1).bit_length() is the number of bits needed to represent x - 1, so
  # shifting 1 by that amount gives the first power of two that is >= x.
  # This also covers x == 1, because (0).bit_length() == 0.
  return 1 << (x - 1).bit_length()
def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.HyperParameters:
  """Generates a MaxText configuration from a vLLM configuration.

  The result is built from the base MaxText vLLM config file
  (`inference/vllm.yml` under `MAXTEXT_CONFIGS_DIR`) plus a set of overrides:
  the user-supplied `additional_config["maxtext_config"]` entries, followed by
  adjustments derived from the vLLM sharding config and the Hugging Face model
  config (KV-head replication and MoE MLP padding).

  Args:
    vllm_config: The vLLM configuration object containing model, load and
      sharding parameters.
    mesh: The JAX mesh device for model sharding. Not read by this function;
      accepted so the signature stays stable for callers.

  Returns:
    A `pyconfig.HyperParameters` object configured for MaxText.
  """
  # Work on a copy of the user-supplied overrides: the adjustments below write
  # into this dict and must not leak back into the caller's vLLM config.
  if "maxtext_config" in vllm_config.additional_config:
    overrides = dict(vllm_config.additional_config["maxtext_config"])
  else:
    overrides = {}

  if vllm_config.load_config.load_format == "dummy":
    if overrides.get("load_parameters_path") is not None:
      max_logging.log(
          "Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped."
      )
      overrides["load_parameters_path"] = None

  # Add base config path to positional args
  base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
  argv_list = ["", str(base_config_path)]

  # Gather sharding information from vLLM config to determine transformations to apply
  sharding_config = vllm_config.sharding_config
  tp = sharding_config.tp_size
  ep = sharding_config.expert_size
  attn_dp = sharding_config.attn_dp_size

  # Combined parallelism degrees used below: one for KV heads, one for the MoE MLP dimension.
  kv_tp_size = tp * ep
  moe_mlp_tp_size = tp * attn_dp

  # Gather information on the hidden size of MoE models to determine if padding is needed
  # to meet MLP MoE requirements for tpu-inference GMM_v2 kernel.
  # Multimodal HF configs nest the language-model settings under `text_config`.
  hf_config = (
      vllm_config.model_config.hf_config.text_config
      if hasattr(vllm_config.model_config.hf_config, "text_config")
      else vllm_config.model_config.hf_config
  )
  # `hidden_size` here is the per-expert MoE intermediate size; it is None when
  # the HF config does not define `moe_intermediate_size`.
  hidden_size = getattr(hf_config, "moe_intermediate_size", None)
  num_lanes = pltpu.get_tpu_info().num_lanes
  use_global_kv_heads = hasattr(hf_config, "num_global_key_value_heads")
  num_kv_heads = hf_config.num_key_value_heads
  # Get the number KV heads used in global attention layers if specified.
  num_global_kv_heads = hf_config.num_global_key_value_heads if use_global_kv_heads else None

  max_logging.log(
      f"vLLM sharding config: hidden_size={hidden_size}, kv_heads={num_kv_heads}, global_kv_heads={num_global_kv_heads}, "
      f"num_lanes={num_lanes}, tp={tp}, attn_dp={attn_dp}, ep={ep}, moe_mlp_tp_size={moe_mlp_tp_size}"
  )

  # Replicate the number of KV heads if it is less than the total degree of model parallelism.
  # Only applied when the parallelism degree is an exact multiple of the head count.
  if kv_tp_size % num_kv_heads == 0 and num_kv_heads < kv_tp_size:
    max_logging.log(
        f"Padding num_kv_heads from {num_kv_heads} to {kv_tp_size} to match the degree of tensor parallelism."
    )
    overrides["base_num_kv_heads"] = kv_tp_size

  # Replicate the number of global KV heads if it is less than the total degree of model parallelism
  if use_global_kv_heads and kv_tp_size % num_global_kv_heads == 0 and num_global_kv_heads < kv_tp_size:
    max_logging.log(
        f"Padding num_global_kv_heads from {num_global_kv_heads} "
        f"to {kv_tp_size} to match the degree of tensor parallelism."
    )
    overrides["global_num_kv_heads"] = kv_tp_size

  # Pad the hidden size of MoE models if the MLP dimension is less than expected by the GMM_v2 kernel in tpu-inference.
  # The GMM_v2 kernel requires the MLP dimension per expert to be at least 2x the number of TPU lanes
  # to ensure efficient execution. See the validate_inputs() method in the following file for more details:
  # https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/megablox/gmm_v2.py
  if hidden_size is not None and (hidden_size // moe_mlp_tp_size) % (2 * num_lanes) != 0:
    padded_hidden_size = next_power_of_two(hidden_size)
    # Keep doubling until each shard's slice reaches the 2 * num_lanes minimum.
    while (padded_hidden_size // moe_mlp_tp_size) < (2 * num_lanes):
      padded_hidden_size = next_power_of_two(padded_hidden_size + 1)
    max_logging.log(
        f"Padding moe_intermediate_size from {hidden_size} to {padded_hidden_size} to match MLP MoE requirements."
    )
    overrides["padded_base_moe_mlp_dim"] = padded_hidden_size

  maxtext_config = pyconfig.initialize(argv_list, **overrides)
  return maxtext_config
class MaxTextForCausalLM(nnx.Module):
  """A vLLM-compatible causal language model wrapper for MaxText.

  This class serves as the primary interface for integrating MaxText models
  into the vLLM serving framework, specifically for causal language modeling
  tasks. It handles configuration generation, model initialization, and execution
  of the decoding step.
  """

  # Signal to tpu-inference model_loader that this class manages its own
  # JIT-sharded initialization (via create_nnx_model with out_shardings).
  # When True, model_loader skips wrapping __init__ in an outer bare @jax.jit.
  _self_manages_sharding: bool = True

  def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
    """Initializes the MaxTextForCausalLM model.

    The underlying MaxText model is only built here when the vLLM load format
    is "dummy"; otherwise `self.model` stays `None` until `load_weights` is
    called.

    Args:
      vllm_config: The vLLM configuration object.
      rng_key: A JAX random key for model initialization.
      mesh: The JAX mesh device for model sharding.
    """
    self.vllm_config = vllm_config
    self.cfg = vllm_config.model_config
    self.maxtext_config = generate_maxtext_config(vllm_config, mesh)

    # Model configuration
    self.mesh = mesh
    self.model_mode = MODEL_MODE_AUTOREGRESSIVE
    self.is_text_generation_model = True

    # Model creation
    self.model: nnx.Module | None = None

    # Indicates that the model handles its own sharding logic
    self._self_manages_sharding = True

    # Handle dummy weight loading during initialization
    if vllm_config.load_config.load_format == "dummy":
      self.load_weights(rng_key)
    elif self.maxtext_config.load_parameters_path is None:
      max_logging.log("Warning: No load_parameters_path provided. The model will be initialized with random weights.")

  def __call__(
      self,
      kv_caches: list[jax.Array],
      input_ids: jax.Array,
      attention_metadata: AttentionMetadata,
      *args,
      **kwargs,
  ) -> tuple[list[jax.Array], jax.Array, list[jax.Array], list[jax.Array] | None]:
    """Performs a forward pass through the causal language model.

    Args:
      kv_caches: A list of JAX arrays representing the KV caches.
      input_ids: A JAX array of input token IDs.
      attention_metadata: Attention metadata for the decoding process.
      *args: Extra positional arguments. Accepted but not used.
      **kwargs: Extra keyword arguments, forwarded to the underlying model.

    Returns:
      A tuple containing:
        - updated_kv_caches: The KV caches returned by the underlying model.
        - hidden: The hidden states, reshaped to 2D with the last dimension preserved.
        - aux_hidden_states: A list of auxiliary hidden states (always empty here).
        - expert_indices: A list of expert indices or None (always None here).

    Raises:
      ValueError: If the model is not an instance of `nnx.Module`.
    """
    if not isinstance(self.model, nnx.Module):
      raise ValueError("Model must be an instance of type nnx.Module.")

    # Insert a length-1 axis at position 1 so inputs are at least 2D with a batch dimension.
    input_ids = jnp.expand_dims(input_ids, axis=1)
    input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1)

    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
      aux_hidden_states = []
      expert_indices = None
      hidden, kv_caches = self.model(
          decoder_input_tokens=input_ids,
          decoder_positions=input_positions,
          kv_caches=kv_caches,
          attention_metadata=attention_metadata,
          model_mode=self.model_mode,
          **kwargs,
      )

      # To be compatible with vLLM, we reshape to (batch * seq, dim).
      hidden = hidden.reshape((-1, hidden.shape[-1]))

    return kv_caches, hidden, aux_hidden_states, expert_indices

  def forward(self, *args, **kwargs):
    """Alias for __call__ for compatibility.

    Args:
      *args: Variable length argument list.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      The result of the `__call__` method.
    """
    return self(*args, **kwargs)

  def get_input_embeddings(self) -> jax.Array:
    """Returns the input embeddings of the model.

    Returns:
      The `embedding` attribute of the model's token embedder.

    Raises:
      ValueError: If the model has not been initialized.
    """
    if not isinstance(self.model, nnx.Module):
      raise ValueError("Model is not initialized.")

    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
      return self.model.token_embedder.embedding

  def embed_input_ids(self, input_ids: jax.Array) -> jax.Array:
    """Embeds the input token IDs using the model's token embedder.

    Args:
      input_ids: A JAX array of input token IDs.

    Returns:
      A JAX array of embedded input tokens.

    Raises:
      ValueError: If the model has not been initialized.
    """
    if not isinstance(self.model, nnx.Module):
      raise ValueError("Model is not initialized.")

    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
      return self.model.token_embedder(input_ids)

  def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
    """Computes the logits from the hidden states using the underlying decoder model.

    Args:
      hidden_states: A JAX array of hidden states.

    Returns:
      A JAX array of logits.

    Raises:
      ValueError: If the model has not been initialized.
    """
    if not isinstance(self.model, nnx.Module):
      raise ValueError("Model is not initialized.")

    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
      # Reshape to (num_tokens, 1, hidden_dim) for decoder output head
      y = jnp.expand_dims(hidden_states, axis=1)

      # Compute logits using the MaxText decoder's output head
      logits = self.model.decoder.apply_output_head(self.model.token_embedder, y, True, self.model_mode)

      # Reshape back to (num_tokens, vocab_size)
      return logits.squeeze(1)

  def load_weights(self, rng_key: jax.Array) -> None:
    """Builds the underlying MaxText model and stores it on `self.model`.

    The model is created through `model_creation_utils.from_pretrained` using
    this wrapper's MaxText config, mesh and model mode. Does nothing if a model
    has already been created.

    Args:
      rng_key: A JAX random key for model initialization.
    """
    if self.model is not None:
      return

    with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
      model = model_creation_utils.from_pretrained(
          self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
      )
      self.model = nnx.data(model)