# Source code for maxtext.models.models
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer models."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module
from typing import Any
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from flax import linen as nn
from flax import nnx
from maxtext.common.common_types import Config, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN, MultimodalInput
from maxtext.inference import page_manager
from maxtext.layers.nnx_decoders import NNXDecoder
from maxtext.layers import initializers
from maxtext.layers import nnx_wrappers
from maxtext.layers.decoders import Decoder
from maxtext.layers.embeddings import Embed, embed_as_linen
from maxtext.layers.encoders import AudioEncoder, VisionEncoder, audio_encoder_as_linen, vision_encoder_as_linen
from maxtext.layers.multi_token_prediction import MultiTokenPredictionBlock, multi_token_prediction_block_as_linen
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.multimodal import processor as mm_processor
from maxtext.utils import max_utils
# ------------------------------------------------------------------------------
# The network: Transformer Definitions
# ------------------------------------------------------------------------------
class TransformerLinenPure(nn.Module):
  """An autoregressive transformer model implemented as a pure Flax Linen module.

  Owns a token embedding table (registered as "token_embedder" and passed to the
  decoder and MTP block as `shared_embedding`), optional vision/audio encoders,
  the `Decoder` stack and, when `config.mtp_num_layers > 0`, a Multi-Token
  Prediction (MTP) block.
  """

  # Make new attributes required, so that all Transformer dependencies (train, decode,
  # compile, etc) will error instead of silently use defaults.
  # pylint: disable=attribute-defined-outside-init
  config: Config
  mesh: Mesh
  quant: Quant
  # Possible model_mode values can be found in maxtext.common.common_types.
  # We generally use maxtext.common.common_types.MODEL_MODE_TRAIN or
  # maxtext.common.common_types.MODEL_MODE_PREFILL for initializations here.
  # TODO: Make model_mode required after confirming no users are affected.
  model_mode: str = MODEL_MODE_TRAIN  # May be different than the model_mode passed to __call__
  # pylint: enable=attribute-defined-outside-init

  def init(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Initializes the model.

    The module is cloned with `model_mode` as its attribute (so `setup` hands it
    to the `Decoder`), and the same `model_mode` is forwarded to `__call__`.

    Args:
      *args: Positional arguments forwarded to `nn.Module.init`.
      model_mode: Mode used both for building submodules and for the call.
      **kwargs: Keyword arguments forwarded to `nn.Module.init`.

    Returns:
      Whatever `nn.Module.init` returns for the cloned module.
    """
    module = self.clone(model_mode=model_mode)
    kwargs["model_mode"] = model_mode
    return nn.Module.init(module, *args, **kwargs)

  def apply(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Applies the model.

    The module is cloned with `model_mode` as its attribute (so `setup` hands it
    to the `Decoder`), and the same `model_mode` is forwarded to `__call__`.

    Args:
      *args: Positional arguments forwarded to `nn.Module.apply`.
      model_mode: Mode used both for building submodules and for the call.
      **kwargs: Keyword arguments forwarded to `nn.Module.apply`.

    Returns:
      Whatever `nn.Module.apply` returns for the cloned module.
    """
    module = self.clone(model_mode=model_mode)
    kwargs["model_mode"] = model_mode
    return nn.Module.apply(module, *args, **kwargs)

  def setup(self):
    """Creates the shared embedding, optional encoders, decoder and MTP block.

    The vision encoder is only built when `config.use_multimodal` is set, the
    audio encoder only when `config.use_audio` is set, and the MTP block only
    when `config.mtp_num_layers > 0`.
    """
    cfg = self.config
    mesh = self.mesh
    self.shared_embedding = embed_as_linen(
        num_embeddings=cfg.vocab_size,
        num_features=cfg.emb_dim,
        dtype=cfg.dtype,
        attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
        embedding_init=nn.initializers.normal(stddev=1.0),
        name="token_embedder",
        config=cfg,
        mesh=self.mesh,
    )
    self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_multimodal else None
    self.audio_encoder = audio_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_audio else None
    self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
    # If MTP is enabled via config, set up the MTP block.
    if self.config.mtp_num_layers > 0:
      # Get the list of layer blueprints for the current model.
      # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
      # By convention, this is the last layer in the list.
      layer_types = self.decoder.get_decoder_layers()
      mtp_layer_linen = layer_types[-1]
      # UNWRAP: The MTP block is pure NNX. If the decoder returned a Linen wrapper,
      # extract the native NNX class to preserve parameter tracing/scoping.
      mtp_layer_nnx = getattr(mtp_layer_linen, "module_class", mtp_layer_linen)
      self.mtp_block = multi_token_prediction_block_as_linen(
          config=self.config,
          mesh=self.mesh,
          transformer_layer_module=mtp_layer_nnx,
          decoder=self.decoder,
          rngs=self.make_rng("mtp_block"),
      )

  def __call__(
      self,
      decoder_input_tokens: jnp.ndarray,
      decoder_positions: jnp.ndarray,
      decoder_segment_ids=None,
      encoder_images: None | jnp.ndarray = None,
      encoder_image_masks: None | jnp.ndarray = None,
      encoder_audios: None | jnp.ndarray = None,
      enable_dropout=True,
      model_mode=MODEL_MODE_TRAIN,
      previous_chunk=None,
      true_length: None | int = None,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
      decoder_target_tokens: None | jnp.ndarray = None,
      decoder_target_mask: None | jnp.ndarray = None,
      nnx_method=None,
      kv_caches: list[jax.Array] | None = None,
      attention_metadata: dict[str, Any] | None = None,
  ):
    """Applies Transformer decoder-branch on encoded-input and target.

    Runs the vision/audio encoders when their inputs are given, then the
    decoder, and finally the MTP block when `config.mtp_num_layers > 0`.

    Args:
      decoder_input_tokens: Input token ids for the decoder.
      decoder_positions: Position ids for `decoder_input_tokens`.
      decoder_segment_ids: (Optional) Segment ids for the decoder inputs. Must
        be None when `model_mode` is `MODEL_MODE_AUTOREGRESSIVE`.
      encoder_images: (Optional) Images for the vision encoder; only used when
        `config.use_multimodal` is set.
      encoder_image_masks: (Optional) Image masks forwarded in `MultimodalInput`.
      encoder_audios: (Optional) Audio for the audio encoder; only used when
        `config.use_multimodal` is set and an audio encoder was built.
      enable_dropout: Whether dropout is active; submodules receive
        `deterministic=not enable_dropout`.
      model_mode: Mode forwarded to the decoder and the MTP block.
      previous_chunk: (Optional) Forwarded to the decoder.
      true_length: (Optional) Prompt length before padding. Not used by this
        method.
      slot: (Optional) An integer representing the decode batch index selected
        for this request.
      page_state: (Optional) Paged-attention state forwarded to the decoder.
      decoder_target_tokens: (Optional) Target tokens consumed by the MTP block.
      decoder_target_mask: (Optional) Target mask consumed by the MTP block.
      nnx_method: Not used by this method.
      kv_caches: (Optional) KV caches forwarded to the decoder, used when
        invoking from vLLM.
      attention_metadata: (Optional) Attention metadata forwarded to the
        decoder, used when invoking from vLLM.

    Returns:
      The decoder logits, or the tuple `(hidden_state, kv_caches)` when
      `config.attention == "vllm_rpa"` (vLLM computes logits separately).

    Raises:
      ValueError: If `decoder_segment_ids` is given in autoregressive mode.
    """
    if decoder_segment_ids is not None and model_mode == MODEL_MODE_AUTOREGRESSIVE:
      raise ValueError(
          f"During autoregressive decoding we assume the tokens are in the active sequence"
          f" which is always {DECODING_ACTIVE_SEQUENCE_INDICATOR}."
      )
    bidirectional_mask = None
    image_embeddings = None
    audio_embeddings = None
    deepstack_visual_embeds = None
    if self.config.use_multimodal and encoder_images is not None:
      image_embeddings, deepstack_visual_embeds = self.vision_encoder(
          input_images=encoder_images, deterministic=not enable_dropout
      )
      bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)
    if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
      audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=not enable_dropout)
    # Create audio mask for placeholder tokens (qwen3-omni models)
    audio_masks = None
    if audio_embeddings is not None:
      audio_masks = mm_processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)
    multimodal_input = None
    if image_embeddings is not None or audio_embeddings is not None:
      multimodal_input = MultimodalInput(
          image_embeddings=image_embeddings,
          image_masks=encoder_image_masks,
          audio_embeddings=audio_embeddings,
          audio_masks=audio_masks,
          bidirectional_mask=bidirectional_mask,
      )
    logits, hidden_state, kv_caches = self.decoder(
        shared_embedding=self.shared_embedding,
        decoder_input_tokens=decoder_input_tokens,
        decoder_positions=decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=not enable_dropout,
        model_mode=model_mode,
        previous_chunk=previous_chunk,
        slot=slot,
        page_state=page_state,
        multimodal_input=multimodal_input,
        kv_caches=kv_caches,
        attention_metadata=attention_metadata,
        deepstack_visual_embeds=deepstack_visual_embeds,
    )  # pytype: disable=wrong-keyword-args
    # If we are initializing the model AND MTP is enabled, we must create
    # dummy target tensors. This allows Flax to trace the MTPBlock and create
    # all its necessary parameters, without requiring the main training pipeline
    # to be aware of this initialization detail.
    if self.is_initializing() and self.config.mtp_num_layers > 0:
      if decoder_target_tokens is None:
        dummy_shape = decoder_input_tokens.shape
        decoder_target_tokens = jnp.ones(dummy_shape, dtype=jnp.int32)
        decoder_target_mask = jnp.ones(dummy_shape, dtype=jnp.int32)
        decoder_segment_ids = jnp.ones(dummy_shape, dtype=jnp.int32)
    # The Multi-Token Prediction (MTP) block functions as a "side-car" to the main
    # model, active only during training. It computes an auxiliary loss based on
    # predicting multiple future tokens, as described in the DeepSeek-V3 paper.
    # To ensure architectural consistency, it uses two key components from the parent Transformer:
    # 1. The same `DecoderLayer` blueprint for its internal transformer blocks.
    # 2. The `shared_embedding` for both embedding future tokens and for its final
    # logit projection.
    # Its only effect is to "sow" these losses; it does not alter the primary logits output.
    if self.config.mtp_num_layers > 0:
      self.mtp_block(
          shared_embedding=self.shared_embedding,
          main_hidden_state=hidden_state,
          input_ids=decoder_input_tokens,
          target_ids=decoder_target_tokens,
          target_mask=decoder_target_mask,
          position_ids=decoder_positions,
          decoder_segment_ids=decoder_segment_ids,
          deterministic=not enable_dropout,
          model_mode=model_mode,
      )
    if self.config.attention == "vllm_rpa":
      # In vLLM, logits are computed separately after updating the KV cache.
      return hidden_state, kv_caches
    return logits
def transformer_as_linen(
    config: Config,
    mesh: Mesh,
    quant: Quant,
    model_mode: str = MODEL_MODE_TRAIN,
    *,
    name: str | None = None,
) -> nnx_wrappers.ToLinen | TransformerLinenPure:
  """Constructs a Transformer model as a Linen or NNX module.

  If `config.enable_nnx` is True, returns a `TransformerLinen`, a `ToLinen`
  wrapper around the NNX `Transformer` whose constructor kwargs are `mesh`,
  `config`, `quant` and `model_mode`. Otherwise returns the pure Flax Linen
  implementation, `TransformerLinenPure`. The returned module is used for
  training, evaluation and decoding alike.

  Args:
    config (Config): The configuration object specifying model hyperparameters and options.
    mesh (Mesh): The JAX sharding mesh for device partitioning.
    quant (Quant): The quantization module or configuration to use.
    model_mode (str, optional): The operational mode for the model, e.g.
      training, prefill, or autoregressive. Defaults to `MODEL_MODE_TRAIN`.
    name (str, optional): Optional module name for Linen/NNX construction.

  Returns:
    nnx_wrappers.ToLinen | TransformerLinenPure:
      A constructed Transformer model compatible with the specified framework (Linen or NNX).
  """
  if config.enable_nnx:
    return TransformerLinen(
        Transformer,
        args=(),
        # A FrozenDict, so TransformerLinen.init/apply can derive an updated
        # copy with a different `model_mode`.
        kwargs=nn.FrozenDict(
            {
                "mesh": mesh,
                "config": config,
                "quant": quant,
                "model_mode": model_mode,
            }
        ),
        metadata_fn=initializers.variable_to_logically_partitioned,
        name=name,
    )
  return TransformerLinenPure(config, mesh, quant, model_mode=model_mode, name=name)
class TransformerLinen(nnx_wrappers.ToLinen):
  """Transformer model as a linen module.

  A `ToLinen` wrapper whose `init`/`apply` accept a `model_mode` keyword. That
  value replaces the "model_mode" entry in the wrapped module's constructor
  kwargs (`self.kwargs`) and is also forwarded to the call itself.
  """

  def init(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Initializes the model.

    Args:
      *args: Positional arguments forwarded to `ToLinen.init`.
      model_mode: Mode written into the constructor kwargs and forwarded to
        the call.
      **kwargs: Keyword arguments forwarded to `ToLinen.init`.

    Returns:
      Whatever `nnx_wrappers.ToLinen.init` returns for the cloned wrapper.
    """
    # Assumes self.kwargs is a FrozenDict (as built by transformer_as_linen):
    # copy(...) returns a new mapping with "model_mode" added/replaced.
    model_kwargs = self.kwargs.copy({"model_mode": model_mode})  # type: ignore[wrong-arg-types]
    module = self.clone(kwargs=model_kwargs)
    kwargs["model_mode"] = model_mode
    return nnx_wrappers.ToLinen.init(module, *args, **kwargs)

  def apply(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Applies the model.

    Args:
      *args: Positional arguments forwarded to `ToLinen.apply`.
      model_mode: Mode written into the constructor kwargs and forwarded to
        the call.
      **kwargs: Keyword arguments forwarded to `ToLinen.apply`.

    Returns:
      Whatever `nnx_wrappers.ToLinen.apply` returns for the cloned wrapper.
    """
    # Assumes self.kwargs is a FrozenDict (as built by transformer_as_linen):
    # copy(...) returns a new mapping with "model_mode" added/replaced.
    model_kwargs = self.kwargs.copy({"model_mode": model_mode})  # type: ignore[wrong-arg-types]
    module = self.clone(kwargs=model_kwargs)
    kwargs["model_mode"] = model_mode
    return nnx_wrappers.ToLinen.apply(module, *args, **kwargs)
class Transformer(nnx.Module):
  """An autoregressive transformer model implemented as an NNX module."""

  # Make new attributes required, so that all Transformer dependencies (train, decode,
  # compile, etc) will error instead of silently use defaults.
  # pylint: disable=attribute-defined-outside-init
  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      quant: Quant,
      *,
      model_mode: str = MODEL_MODE_TRAIN,
      rngs: nnx.Rngs,
  ):
    """Builds the token embedding, optional encoders, decoder and MTP block.

    Args:
      config: Model configuration.
      mesh: Device mesh used for sharding.
      quant: Quantization configuration forwarded to the decoder.
      model_mode: Mode used to build the decoder and to size the dummy inputs
        used for the Linen decoder's lazy initialization.
      rngs: NNX RNG streams passed to every submodule.

    Raises:
      ImportError: If `config.attention == "vllm_rpa"` and the
        `tpu_inference` package (vllm-tpu) is not installed.
    """
    self.config = config
    self.mesh = mesh
    self.quant = quant
    self.model_mode = model_mode
    cfg = self.config
    mesh = self.mesh
    self.token_embedder = Embed(
        mesh=self.mesh,
        num_embeddings=cfg.vocab_size,
        num_features=cfg.emb_dim,
        dtype=cfg.dtype,
        attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
        embedding_init=nn.initializers.normal(stddev=1.0),
        config=cfg,
        rngs=rngs,
    )
    self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None
    self.audio_encoder = AudioEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_audio else None
    if cfg.pure_nnx_decoder:
      self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)
    else:
      decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
      self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)
    # Only populated in __call__ when config.num_vocab_tiling > 1.
    self.hidden_states = None

    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode)
    dummy_decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    dummy_decoder_positions = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    if self.config.attention == "vllm_rpa":
      try:
        # pylint: disable=import-outside-toplevel
        from tpu_inference.layers.common.attention_metadata import AttentionMetadata  # pytype: disable=import-error
      except ImportError as e:
        raise ImportError(
            "vLLM RPA attention requires the vllm-tpu package. Please install it with `pip install vllm-tpu`."
        ) from e
      dummy_attention_metadata = AttentionMetadata(
          input_positions=jnp.ones((batch_size * seq_len,), dtype=jnp.int32),
          block_tables=jnp.ones((seq_len,), dtype=jnp.int32),
          seq_lens=jnp.ones((1,), dtype=jnp.int32),
          query_start_loc=jnp.ones((2,), dtype=jnp.int32),
          request_distribution=jnp.ones((3,), dtype=jnp.int32),
      )
    else:
      dummy_attention_metadata = None

    # The Linen decoder behind ToNNX creates its variables lazily by tracing a
    # call on dummy inputs.
    if not cfg.pure_nnx_decoder:
      self.decoder.lazy_init(
          shared_embedding=self.token_embedder,
          decoder_input_tokens=dummy_decoder_input_tokens,
          decoder_positions=dummy_decoder_positions,
          attention_metadata=dummy_attention_metadata,
      )

    # If MTP is enabled via config, set up the MTP block.
    if self.config.mtp_num_layers > 0:
      # Get the list of layer blueprints for the current model.
      layer_types = self.decoder.get_decoder_layers()
      # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
      # By convention, this is the last layer in the list.
      # NOTE(review): unlike TransformerLinenPure.setup, this does not unwrap a
      # Linen wrapper via `module_class` before handing the layer to the MTP
      # block; confirm this is intended for the ToNNX-decoder path.
      mtp_layer = layer_types[-1]
      self.mtp_block = MultiTokenPredictionBlock(
          config=self.config,
          mesh=self.mesh,
          transformer_layer_module=mtp_layer,
          decoder=self.decoder,
          rngs=rngs,
      )

  def no_op(self, *args, **kwargs):
    """A no-op method to allow the model to be used in a lazy context.

    Accepts and ignores any arguments and returns None.
    """
    return

  def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32):
    """Placeholder KV-cache initializer.

    Currently a no-op: no cache state is created and the arguments are ignored.

    Args:
      cache_size: The maximum size of the KV cache (unused).
      batch_size: The batch size for which the cache is initialized (unused).
      dtype: Data type for the cache (unused). Defaults to `jnp.float32`.

    Returns:
      Always True.
    """
    return True

  def __call__(
      self,
      decoder_input_tokens: jnp.ndarray,
      decoder_positions: jnp.ndarray,
      decoder_segment_ids=None,
      cache=None,
      encoder_images: jax.Array | None = None,
      encoder_image_masks: jax.Array | None = None,
      encoder_audios: jax.Array | None = None,
      enable_dropout=True,
      model_mode=MODEL_MODE_TRAIN,
      previous_chunk=None,
      true_length: int | None = None,
      slot: int | None = None,
      page_state: page_manager.PageState | None = None,
      decoder_target_tokens: jax.Array | None = None,
      decoder_target_mask: jax.Array | None = None,
      kv_caches: list[jax.Array] | None = None,
      attention_metadata: dict[str, Any] | None = None,
  ):
    """Applies the Transformer to a batch of decoder inputs.

    Runs the vision/audio encoders when their inputs are given, then the
    decoder, and finally the Multi-Token Prediction (MTP) block when
    `config.mtp_num_layers > 0`.

    Args:
      decoder_input_tokens: Input tokens for the decoder.
      decoder_positions: Positional encodings for the decoder inputs.
      decoder_segment_ids: Segment IDs for the decoder inputs (optional). Must
        be None when `model_mode` is `MODEL_MODE_AUTOREGRESSIVE`.
      cache: Not used by this method.
      encoder_images: Encoder images for multimodal models (optional); only
        used when `config.use_multimodal` is set.
      encoder_image_masks: Image masks forwarded in `MultimodalInput` (optional).
      encoder_audios: Audio inputs for the audio encoder (optional); only used
        when `config.use_multimodal` is set and an audio encoder was built.
      enable_dropout: Whether to enable dropout. Defaults to True.
      model_mode: Mode forwarded to the decoder and the MTP block.
      previous_chunk: Previous chunk for incremental decoding (optional).
      true_length: True length of the prompt before padding (optional). Not
        used by this method.
      slot: An integer representing the decode batch index selected for this
        request (optional).
      page_state: Page state for paged attention (optional).
      decoder_target_tokens: Target tokens for the decoder (optional, used in MTP).
      decoder_target_mask: Target mask for the decoder (optional, used in MTP).
      kv_caches: List of KV caches for each attention layer, used when invoking
        from vLLM (optional).
      attention_metadata: Mapping to store attention metadata, used when
        invoking from vLLM (optional).

    Returns:
      The decoder logits, or the tuple `(hidden_state, kv_caches)` when
      `config.attention == "vllm_rpa"` (vLLM computes logits separately).

    Raises:
      ValueError: If `decoder_segment_ids` is given in autoregressive mode.
    """
    if decoder_segment_ids is not None and model_mode == MODEL_MODE_AUTOREGRESSIVE:
      raise ValueError(
          f"During autoregressive decoding we assume the tokens are in the active sequence"
          f" which is always {DECODING_ACTIVE_SEQUENCE_INDICATOR}."
      )

    bidirectional_mask = None
    image_embeddings = None
    deepstack_visual_embeds = None
    if self.config.use_multimodal and encoder_images is not None:
      image_embeddings, deepstack_visual_embeds = self.vision_encoder(
          input_images=encoder_images, deterministic=not enable_dropout
      )
      bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)

    audio_embeddings = None
    if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
      audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=not enable_dropout)
    # Create audio mask for placeholder tokens (qwen3-omni models)
    audio_masks = None
    if audio_embeddings is not None:
      audio_masks = mm_processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)

    multimodal_input = None
    if image_embeddings is not None or audio_embeddings is not None:
      multimodal_input = MultimodalInput(
          image_embeddings=image_embeddings,
          image_masks=encoder_image_masks,
          audio_embeddings=audio_embeddings,
          audio_masks=audio_masks,
          bidirectional_mask=bidirectional_mask,
      )

    decoder_kwargs = {
        "shared_embedding": self.token_embedder,
        "decoder_input_tokens": decoder_input_tokens,
        "decoder_positions": decoder_positions,
        "decoder_segment_ids": decoder_segment_ids,
        "deterministic": not enable_dropout,
        "model_mode": model_mode,
        "previous_chunk": previous_chunk,
        "slot": slot,
        "page_state": page_state,
        "multimodal_input": multimodal_input,
        "kv_caches": kv_caches,
        "attention_metadata": attention_metadata,
        "deepstack_visual_embeds": deepstack_visual_embeds,
    }
    if not self.config.pure_nnx_decoder:
      # Only the ToNNX-wrapped Linen decoder takes `mutable`. Request the
      # "intermediates" collection when any of these options is enabled.
      needs_intermediates = (
          self.config.record_internal_nn_metrics
          or self.config.distill_beta > 0.0
          or self.config.load_balance_loss_weight > 0.0
      )
      decoder_kwargs["mutable"] = ["intermediates"] if needs_intermediates else []
    logits, hidden_state, kv_caches = self.decoder(**decoder_kwargs)

    # Materialize hidden state when vocab tiling is enabled
    if self.config.num_vocab_tiling > 1:
      self.hidden_states = hidden_state

    # NOTE: the Linen model fabricates dummy MTP targets at init time so Flax can
    # trace the MTP parameters. That is not done here because the NNX MTP block
    # is already constructed in __init__; the targets are passed through as
    # given (possibly None).
    #
    # The Multi-Token Prediction (MTP) block functions as a "side-car" to the main
    # model, active only during training. It computes an auxiliary loss based on
    # predicting multiple future tokens, as described in the DeepSeek-V3 paper.
    # To ensure architectural consistency, it uses two key components from the parent Transformer:
    # 1. The same `DecoderLayer` blueprint for its internal transformer blocks.
    # 2. The `shared_embedding` for both embedding future tokens and for its final
    # logit projection.
    # Its only effect is to "sow" these losses; it does not alter the primary logits output.
    if self.config.mtp_num_layers > 0:
      self.mtp_block(
          shared_embedding=self.token_embedder,
          main_hidden_state=hidden_state,
          input_ids=decoder_input_tokens,
          target_ids=decoder_target_tokens,
          target_mask=decoder_target_mask,
          position_ids=decoder_positions,
          decoder_segment_ids=decoder_segment_ids,
          deterministic=not enable_dropout,
          model_mode=model_mode,
      )

    if self.config.attention == "vllm_rpa":
      # In vLLM, logits are computed separately after updating the KV cache.
      return hidden_state, kv_caches
    return logits