# Source code for maxtext.models.models

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer models."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

from typing import Any

import jax
import jax.numpy as jnp
from jax.sharding import Mesh

from flax import linen as nn
from flax import nnx

from maxtext.common.common_types import Config, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN, MultimodalInput
from maxtext.inference import page_manager
from maxtext.layers.nnx_decoders import NNXDecoder
from maxtext.layers import initializers
from maxtext.layers import nnx_wrappers
from maxtext.layers.decoders import Decoder
from maxtext.layers.embeddings import Embed, embed_as_linen
from maxtext.layers.encoders import AudioEncoder, VisionEncoder, audio_encoder_as_linen, vision_encoder_as_linen
from maxtext.layers.multi_token_prediction import MultiTokenPredictionBlock, multi_token_prediction_block_as_linen
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.multimodal import processor as mm_processor
from maxtext.utils import max_utils

# ------------------------------------------------------------------------------
# The network: Transformer Definitions
# ------------------------------------------------------------------------------


class TransformerLinenPure(nn.Module):
  """Autoregressive transformer assembled from Flax Linen modules.

  Attributes:
    config: Model configuration, forwarded to the sub-layers.
    mesh: JAX sharding mesh, forwarded to the sub-layers.
    quant: Quantization setting, forwarded to the decoder.
    model_mode: Mode the decoder is constructed with. `init` and `apply` set it on a
      clone of this module, so it may differ from the `model_mode` given to `__call__`.
  """

  # Make new attributes required, so that all Transformer dependencies (train, decode,
  # compile, etc) will error instead of silently use defaults.
  # pylint: disable=attribute-defined-outside-init
  config: Config
  mesh: Mesh
  quant: Quant
  # Possible model_mode values can be found in maxtext.common.common_types.
  # We generally use maxtext.common.common_types.MODEL_MODE_TRAIN or
  # maxtext.common.common_types.MODEL_MODE_PREFILL for initializations here.
  # TODO: Make model_mode required after confirming no users are affected.
  model_mode: str = MODEL_MODE_TRAIN  # May be different than the model_mode passed to __call__
  # pylint: enable=attribute-defined-outside-init

  def init(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Initializes a clone of this module built for `model_mode`.

    `model_mode` is set on the clone and also forwarded as a keyword argument.
    """
    kwargs["model_mode"] = model_mode
    return nn.Module.init(self.clone(model_mode=model_mode), *args, **kwargs)

  def apply(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Applies a clone of this module built for `model_mode`.

    `model_mode` is set on the clone and also forwarded as a keyword argument.
    """
    kwargs["model_mode"] = model_mode
    return nn.Module.apply(self.clone(model_mode=model_mode), *args, **kwargs)

  def setup(self):
    """Creates the shared embedding, optional encoders, the decoder and the optional MTP block."""
    cfg = self.config
    mesh = self.mesh

    # Logits can be computed in fp32 for training stability.
    attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype
    self.shared_embedding = embed_as_linen(
        num_embeddings=cfg.vocab_size,
        num_features=cfg.emb_dim,
        dtype=cfg.dtype,
        attend_dtype=attend_dtype,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name="token_embedder",
        config=cfg,
        mesh=mesh,
    )

    if cfg.use_multimodal:
      self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh)
    else:
      self.vision_encoder = None

    if cfg.use_audio:
      self.audio_encoder = audio_encoder_as_linen(config=cfg, mesh=mesh)
    else:
      self.audio_encoder = None

    self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)

    # The MTP block only exists when it is enabled in the config.
    if cfg.mtp_num_layers > 0:
      # MTP reuses the decoder-layer blueprint for architectural consistency;
      # by convention that blueprint is the last entry returned by the decoder.
      last_layer_blueprint = self.decoder.get_decoder_layers()[-1]
      # The MTP block is pure NNX. When the decoder hands back a Linen wrapper, take the
      # wrapped NNX class (`module_class`) so parameter tracing/scoping is preserved.
      mtp_layer_nnx = getattr(last_layer_blueprint, "module_class", last_layer_blueprint)
      self.mtp_block = multi_token_prediction_block_as_linen(
          config=cfg,
          mesh=mesh,
          transformer_layer_module=mtp_layer_nnx,
          decoder=self.decoder,
          rngs=self.make_rng("mtp_block"),
      )

  def logits_from_hidden_states(self, hidden_states, deterministic, model_mode):
    """Projects hidden states to logits via `decoder.apply_output_head`.

    This function is only used for vocabulary tiling.
    """
    return self.decoder.apply_output_head(
        shared_embedding=self.shared_embedding,
        y=hidden_states,
        deterministic=deterministic,
        model_mode=model_mode,
    )

  def __call__(
      self,
      decoder_input_tokens: jnp.ndarray,
      decoder_positions: jnp.ndarray,
      decoder_segment_ids=None,
      encoder_images: None | jnp.ndarray = None,
      encoder_image_masks: None | jnp.ndarray = None,
      encoder_audios: None | jnp.ndarray = None,
      enable_dropout=True,
      model_mode=MODEL_MODE_TRAIN,
      previous_chunk=None,
      true_length: None | int = None,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
      decoder_target_tokens: None | jnp.ndarray = None,
      decoder_target_mask: None | jnp.ndarray = None,
      nnx_method=None,
      kv_caches: list[jax.Array] | None = None,
      attention_metadata: dict[str, Any] | None = None,
  ):
    """Runs the optional encoders, the decoder and the optional MTP side-car.

    Args:
      decoder_input_tokens: Input tokens for the decoder.
      decoder_positions: Positions for the decoder inputs.
      decoder_segment_ids: (Optional) Segment ids; must be None in autoregressive mode.
      encoder_images: (Optional) Images fed to the vision encoder.
      encoder_image_masks: (Optional) Passed through as `image_masks` of the multimodal input.
      encoder_audios: (Optional) Audio fed to the audio encoder.
      enable_dropout: Sub-layers are called with `deterministic=not enable_dropout`.
      model_mode: Mode forwarded to the decoder and the MTP block.
      previous_chunk: (Optional) Forwarded to the decoder.
      true_length: (Optional) Prompt length before padding. Not read by this method.
      slot: (Optional) An integer representing the decode batch index selected for this request.
      page_state: (Optional) Forwarded to the decoder.
      decoder_target_tokens: (Optional) Targets handed to the MTP block.
      decoder_target_mask: (Optional) Target mask handed to the MTP block.
      nnx_method: Not read by this method.
      kv_caches: (Optional) Forwarded to the decoder, which returns the updated value.
      attention_metadata: (Optional) Forwarded to the decoder.

    Returns:
      `(hidden_state, kv_caches)` when `config.attention == "vllm_rpa"`, otherwise the logits.

    Raises:
      ValueError: If segment ids are supplied in autoregressive mode.
    """
    if decoder_segment_ids is not None and model_mode == MODEL_MODE_AUTOREGRESSIVE:
      raise ValueError(
          "During autoregressive decoding we assume the tokens are in the active sequence"
          f" which is always {DECODING_ACTIVE_SEQUENCE_INDICATOR}."
      )

    cfg = self.config
    deterministic = not enable_dropout

    # Vision branch.
    image_embeddings = None
    deepstack_visual_embeds = None
    bidirectional_mask = None
    if cfg.use_multimodal and encoder_images is not None:
      image_embeddings, deepstack_visual_embeds = self.vision_encoder(
          input_images=encoder_images, deterministic=deterministic
      )
      bidirectional_mask = mm_processor.get_bidirectional_mask_vision(cfg, decoder_input_tokens)

    # Audio branch.
    audio_embeddings = None
    if cfg.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
      audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=deterministic)

    # Create audio mask for placeholder tokens (qwen3-omni models)
    audio_masks = None
    if audio_embeddings is not None:
      audio_masks = mm_processor.get_bidirectional_mask_audio(cfg, decoder_input_tokens)

    # Bundle the modality outputs only when at least one encoder produced embeddings.
    if image_embeddings is None and audio_embeddings is None:
      multimodal_input = None
    else:
      multimodal_input = MultimodalInput(
          image_embeddings=image_embeddings,
          image_masks=encoder_image_masks,
          audio_embeddings=audio_embeddings,
          audio_masks=audio_masks,
          bidirectional_mask=bidirectional_mask,
      )

    logits, hidden_state, kv_caches = self.decoder(
        shared_embedding=self.shared_embedding,
        decoder_input_tokens=decoder_input_tokens,
        decoder_positions=decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        previous_chunk=previous_chunk,
        slot=slot,
        page_state=page_state,
        multimodal_input=multimodal_input,
        kv_caches=kv_caches,
        attention_metadata=attention_metadata,
        deepstack_visual_embeds=deepstack_visual_embeds,
    )  # pytype: disable=wrong-keyword-args

    mtp_enabled = cfg.mtp_num_layers > 0

    # While initializing with MTP enabled, fabricate dummy targets when none were given.
    # This lets Flax trace the MTP block and create its parameters without the main
    # training pipeline having to know about this initialization detail.
    if mtp_enabled and decoder_target_tokens is None and self.is_initializing():
      dummy_shape = decoder_input_tokens.shape
      decoder_target_tokens = jnp.ones(dummy_shape, dtype=jnp.int32)
      decoder_target_mask = jnp.ones(dummy_shape, dtype=jnp.int32)
      decoder_segment_ids = jnp.ones(dummy_shape, dtype=jnp.int32)

    # The Multi-Token Prediction (MTP) block functions as a "side-car" to the main
    # model, active only during training. It computes an auxiliary loss based on
    # predicting multiple future tokens, as described in the DeepSeek-V3 paper.
    # It shares two components with the parent Transformer:
    #   1. The same `DecoderLayer` blueprint for its internal transformer blocks.
    #   2. The `shared_embedding`, for embedding future tokens and for its final
    #      logit projection.
    # Its only effect is to "sow" these losses; it does not alter the primary logits output.
    if mtp_enabled:
      self.mtp_block(
          shared_embedding=self.shared_embedding,
          main_hidden_state=hidden_state,
          input_ids=decoder_input_tokens,
          target_ids=decoder_target_tokens,
          target_mask=decoder_target_mask,
          position_ids=decoder_positions,
          decoder_segment_ids=decoder_segment_ids,
          deterministic=deterministic,
          model_mode=model_mode,
      )

    if cfg.attention == "vllm_rpa":
      # In vLLM, logits are computed separately after updating the KV cache.
      return hidden_state, kv_caches

    return logits
def transformer_as_linen(
    config: Config,
    mesh: Mesh,
    quant: Quant,
    model_mode: str = MODEL_MODE_TRAIN,
    *,
    name: str | None = None,
) -> nnx_wrappers.ToLinen | TransformerLinenPure:
  """Builds the autoregressive Transformer as a Linen-compatible module.

  The `config.enable_nnx` flag selects the implementation:

  * truthy: a `TransformerLinen` wrapper around the NNX `Transformer` class;
  * otherwise: the pure Flax Linen `TransformerLinenPure`.

  Args:
    config (Config): The configuration object specifying model hyperparameters and options.
    mesh (Mesh): The JAX sharding mesh for device partitioning.
    quant (Quant): The quantization module or configuration to use.
    model_mode (str, optional): The operational mode for the model, e.g. training, prefill,
      or autoregressive. Defaults to `MODEL_MODE_TRAIN`.
    name (str, optional): Optional module name for Linen/NNX construction.

  Returns:
    nnx_wrappers.ToLinen | TransformerLinenPure: The constructed Transformer module.
  """
  if not config.enable_nnx:
    return TransformerLinenPure(config, mesh, quant, model_mode=model_mode, name=name)

  # Constructor keyword arguments that the wrapper hands to the NNX `Transformer`.
  nnx_ctor_kwargs = nn.FrozenDict(
      {
          "mesh": mesh,
          "config": config,
          "quant": quant,
          "model_mode": model_mode,
      }
  )
  return TransformerLinen(
      Transformer,
      args=(),
      kwargs=nnx_ctor_kwargs,
      metadata_fn=initializers.variable_to_logically_partitioned,
      name=name,
  )
class TransformerLinen(nnx_wrappers.ToLinen):
  """Linen-facing wrapper for the NNX `Transformer`.

  `init` and `apply` take a `model_mode` keyword, write it into the wrapped module's
  constructor kwargs on a clone of this wrapper, and also forward it as a call keyword.
  """

  def init(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Initializes a clone whose constructor kwargs carry `model_mode`."""
    kwargs["model_mode"] = model_mode
    rebound = self.clone(
        kwargs=self.kwargs.copy({"model_mode": model_mode})  # type: ignore[wrong-arg-types]
    )
    return nnx_wrappers.ToLinen.init(rebound, *args, **kwargs)

  def apply(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    """Applies a clone whose constructor kwargs carry `model_mode`."""
    kwargs["model_mode"] = model_mode
    rebound = self.clone(
        kwargs=self.kwargs.copy({"model_mode": model_mode})  # type: ignore[wrong-arg-types]
    )
    return nnx_wrappers.ToLinen.apply(rebound, *args, **kwargs)
class Transformer(nnx.Module):
  """An autoregressive transformer model (NNX implementation)."""

  # Make new attributes required, so that all Transformer dependencies (train, decode,
  # compile, etc) will error instead of silently use defaults.
  # pylint: disable=attribute-defined-outside-init
  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      quant: Quant,
      *,
      model_mode: str = MODEL_MODE_TRAIN,
      rngs: nnx.Rngs,
  ):
    """Builds the token embedder, optional encoders, the decoder and the optional MTP block.

    Args:
      config: Model configuration.
      mesh: JAX sharding mesh forwarded to the sub-layers.
      quant: Quantization setting forwarded to the decoder.
      model_mode: Mode the decoder is constructed with; also selects the dummy input
        shape via `max_utils.get_batch_seq_len_for_mode`.
      rngs: NNX RNG container forwarded to the sub-layers.

    Raises:
      ImportError: If `config.attention == "vllm_rpa"` and `tpu_inference` cannot be imported.
    """
    self.config = config
    self.mesh = mesh
    self.quant = quant
    self.model_mode = model_mode

    cfg = self.config
    mesh = self.mesh
    self.token_embedder = Embed(
        mesh=self.mesh,
        num_embeddings=cfg.vocab_size,
        num_features=cfg.emb_dim,
        dtype=cfg.dtype,
        attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
        embedding_init=nn.initializers.normal(stddev=1.0),
        config=cfg,
        rngs=rngs,
    )
    self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None
    self.audio_encoder = AudioEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_audio else None

    if cfg.pure_nnx_decoder:
      self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)
    else:
      # The Linen decoder is bridged into NNX; it is lazily initialized further below.
      decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode)
      self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs)

    # Filled in by __call__ when vocab tiling is enabled.
    self.hidden_states = None

    # Dummy inputs used to lazily initialize the bridged Linen decoder.
    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode)
    dummy_decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
    dummy_decoder_positions = jnp.ones((batch_size, seq_len), dtype=jnp.int32)

    if self.config.attention == "vllm_rpa":
      try:
        # pylint: disable=import-outside-toplevel
        from tpu_inference.layers.common.attention_metadata import AttentionMetadata  # pytype: disable=import-error
      except ImportError as e:
        raise ImportError(
            "vLLM RPA attention requires the vllm-tpu package. Please install it with `pip install vllm-tpu`."
        ) from e
      dummy_attention_metadata = AttentionMetadata(
          input_positions=jnp.ones((batch_size * seq_len,), dtype=jnp.int32),
          block_tables=jnp.ones((seq_len,), dtype=jnp.int32),
          seq_lens=jnp.ones((1,), dtype=jnp.int32),
          query_start_loc=jnp.ones((2,), dtype=jnp.int32),
          request_distribution=jnp.ones((3,), dtype=jnp.int32),
      )
    else:
      dummy_attention_metadata = None

    if not cfg.pure_nnx_decoder:
      self.decoder.lazy_init(
          shared_embedding=self.token_embedder,
          decoder_input_tokens=dummy_decoder_input_tokens,
          decoder_positions=dummy_decoder_positions,
          attention_metadata=dummy_attention_metadata,
      )

    # If MTP is enabled via config, set up the MTP block.
    if self.config.mtp_num_layers > 0:
      # Get the list of layer blueprints for the current model.
      layer_types = self.decoder.get_decoder_layers()
      # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
      # By convention, this is the last layer in the list.
      mtp_layer = layer_types[-1]
      self.mtp_block = MultiTokenPredictionBlock(
          config=self.config,
          mesh=self.mesh,
          transformer_layer_module=mtp_layer,
          decoder=self.decoder,
          rngs=rngs,
      )

  def no_op(self, *args, **kwargs):
    """A no-op method to allow the model to be used in a lazy context.

    Accepts and ignores any arguments; returns None.
    """
    return

  def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32):
    """Placeholder KV-cache initializer: performs no work and always returns True.

    Args:
      cache_size: The maximum size of the KV cache. Not read by this method.
      batch_size: The batch size for the cache. Not read by this method.
      dtype: Data type for the cache. Defaults to `jnp.float32`. Not read by this method.

    Returns:
      Always True.
    """
    return True

  def __call__(
      self,
      decoder_input_tokens: jnp.ndarray,
      decoder_positions: jnp.ndarray,
      decoder_segment_ids=None,
      cache=None,
      encoder_images: jax.Array | None = None,
      encoder_image_masks: jax.Array | None = None,
      encoder_audios: jax.Array | None = None,
      enable_dropout=True,
      model_mode=MODEL_MODE_TRAIN,
      previous_chunk=None,
      true_length: int | None = None,
      slot: int | None = None,
      page_state: page_manager.PageState | None = None,
      decoder_target_tokens: jax.Array | None = None,
      decoder_target_mask: jax.Array | None = None,
      kv_caches: list[jax.Array] | None = None,
      attention_metadata: dict[str, Any] | None = None,
  ):
    """Runs the optional encoders, the decoder and the optional MTP side-car.

    Args:
      decoder_input_tokens: Input tokens for the decoder.
      decoder_positions: Positions for the decoder inputs.
      decoder_segment_ids: Segment IDs for the decoder inputs (optional); must be None in
        autoregressive mode.
      cache: Not read by this method.
      encoder_images: Encoder images for multimodal models (optional).
      encoder_image_masks: Passed through as `image_masks` of the multimodal input (optional).
      encoder_audios: Encoder audio for multimodal models (optional).
      enable_dropout: Whether to enable dropout. Defaults to True. Sub-layers are called
        with `deterministic=not enable_dropout`.
      model_mode: Mode forwarded to the decoder and the MTP block.
      previous_chunk: Previous chunk, forwarded to the decoder (optional).
      true_length: True length of the prompt before padding (optional). Not read by this method.
      slot: An integer representing the decode batch index selected for this request (optional).
      page_state: Page state for paged attention (optional).
      decoder_target_tokens: Target tokens for the decoder (optional, used in MTP).
      decoder_target_mask: Target mask for the decoder (optional, used in MTP).
      kv_caches: List of KV caches for each attention layer, used when invoking from vLLM (optional).
      attention_metadata: Mapping to store attention metadata, used when invoking from vLLM (optional).

    Returns:
      `(hidden_state, kv_caches)` when `config.attention == "vllm_rpa"`, otherwise the logits.

    Raises:
      ValueError: If `decoder_segment_ids` is given in autoregressive mode.
    """
    if decoder_segment_ids is not None and model_mode == MODEL_MODE_AUTOREGRESSIVE:
      raise ValueError(
          "During autoregressive decoding we assume the tokens are in the active sequence"
          f" which is always {DECODING_ACTIVE_SEQUENCE_INDICATOR}."
      )

    cfg = self.config
    deterministic = not enable_dropout

    bidirectional_mask = None
    image_embeddings = None
    deepstack_visual_embeds = None
    if cfg.use_multimodal and encoder_images is not None:
      image_embeddings, deepstack_visual_embeds = self.vision_encoder(
          input_images=encoder_images, deterministic=deterministic
      )
      bidirectional_mask = mm_processor.get_bidirectional_mask_vision(cfg, decoder_input_tokens)

    audio_embeddings = None
    if cfg.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
      audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=deterministic)

    # Create audio mask for placeholder tokens (qwen3-omni models)
    audio_masks = None
    if audio_embeddings is not None:
      audio_masks = mm_processor.get_bidirectional_mask_audio(cfg, decoder_input_tokens)

    multimodal_input = None
    if image_embeddings is not None or audio_embeddings is not None:
      multimodal_input = MultimodalInput(
          image_embeddings=image_embeddings,
          image_masks=encoder_image_masks,
          audio_embeddings=audio_embeddings,
          audio_masks=audio_masks,
          bidirectional_mask=bidirectional_mask,
      )

    # "intermediates" is requested as a mutable collection when any of these features is on.
    mutable_collections = []
    if cfg.record_internal_nn_metrics or cfg.distill_beta > 0.0 or cfg.load_balance_loss_weight > 0.0:
      mutable_collections.append("intermediates")

    decoder_kwargs = {
        "shared_embedding": self.token_embedder,
        "decoder_input_tokens": decoder_input_tokens,
        "decoder_positions": decoder_positions,
        "decoder_segment_ids": decoder_segment_ids,
        "deterministic": deterministic,
        "model_mode": model_mode,
        "previous_chunk": previous_chunk,
        "slot": slot,
        "page_state": page_state,
        "multimodal_input": multimodal_input,
        "kv_caches": kv_caches,
        "attention_metadata": attention_metadata,
        "deepstack_visual_embeds": deepstack_visual_embeds,
    }
    if not cfg.pure_nnx_decoder:
      # Only the bridged Linen decoder is called with `mutable`.
      decoder_kwargs["mutable"] = mutable_collections
    logits, hidden_state, kv_caches = self.decoder(**decoder_kwargs)  # pytype: disable=wrong-keyword-args

    # Materialize hidden state when vocab tiling is enabled
    if cfg.num_vocab_tiling > 1:
      self.hidden_states = hidden_state

    # NOTE: the step that fabricated dummy MTP targets while initializing (guarded by
    # `self.is_initializing()`) is disabled in this class; targets are passed to the
    # MTP block exactly as received.

    # The Multi-Token Prediction (MTP) block functions as a "side-car" to the main
    # model, active only during training. It computes an auxiliary loss based on
    # predicting multiple future tokens, as described in the DeepSeek-V3 paper.
    # To ensure architectural consistency, it uses two key components from the parent Transformer:
    #   1. The same `DecoderLayer` blueprint for its internal transformer blocks.
    #   2. The shared token embedder, for both embedding future tokens and for its final
    #      logit projection.
    # Its only effect is to "sow" these losses; it does not alter the primary logits output.
    if cfg.mtp_num_layers > 0:
      self.mtp_block(
          shared_embedding=self.token_embedder,
          main_hidden_state=hidden_state,
          input_ids=decoder_input_tokens,
          target_ids=decoder_target_tokens,
          target_mask=decoder_target_mask,
          position_ids=decoder_positions,
          decoder_segment_ids=decoder_segment_ids,
          deterministic=deterministic,
          model_mode=model_mode,
      )

    if cfg.attention == "vllm_rpa":
      # In vLLM, logits are computed separately after updating the KV cache.
      return hidden_state, kv_caches

    return logits