Source code for maxtext.models.mistral

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer model definition."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

from flax import linen as nn
from flax import nnx
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Config
from maxtext.inference import page_manager
from maxtext.layers import initializers, nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.attentions import Attention
from maxtext.layers.linears import Dropout, MlpBlock
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.utils import max_utils

# -----------------------------------------
# The Decoder Layer for Mistral
# -----------------------------------------


[docs] class MistralDecoderLayer(nnx.Module): """Transformer decoder layer that attends to the encoder.""" def __init__( self, config: Config, model_mode: str, mesh: Mesh, *, rngs: nnx.Rngs, quant: None | Quant = None, ): self.config = config self.mesh = mesh self.quant = quant self.rngs = rngs batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) self.pre_self_attention_layer_norm = RMSNorm( num_features=config.emb_dim, dtype=config.dtype, weight_dtype=config.weight_dtype, kernel_axes=("norm",), epsilon=config.normalization_layer_epsilon, rngs=self.rngs, ) self.self_attention = Attention( config=config, num_query_heads=config.num_query_heads, num_kv_heads=config.num_kv_heads, head_dim=config.head_dim, max_target_length=config.max_target_length, max_prefill_predict_length=config.max_prefill_predict_length, attention_kernel=config.attention, inputs_q_shape=dummy_inputs_shape, inputs_kv_shape=dummy_inputs_shape, mesh=mesh, dtype=config.dtype, weight_dtype=config.weight_dtype, dropout_rate=config.dropout_rate, float32_qk_product=config.float32_qk_product, float32_logits=config.float32_logits, quant=self.quant, kv_quant=quantizations.configure_kv_quant(config), prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))), ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))), compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))), reshape_q=config.reshape_q, use_ragged_attention=config.use_ragged_attention, ragged_block_size=config.ragged_block_size, model_mode=model_mode, rngs=self.rngs, ) self.post_self_attention_layer_norm = RMSNorm( num_features=config.emb_dim, dtype=config.dtype, weight_dtype=config.weight_dtype, kernel_axes=("norm",), epsilon=config.normalization_layer_epsilon, rngs=self.rngs, ) self.mlp = MlpBlock( mesh=self.mesh, in_features=config.emb_dim, intermediate_dim=config.mlp_dim, activations=config.mlp_activations, intermediate_dropout_rate=config.dropout_rate, dtype=config.dtype, weight_dtype=config.weight_dtype, config=config, quant=self.quant, model_mode=model_mode, rngs=self.rngs, ) self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") def __call__( self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, previous_chunk=None, slot: None | int = None, page_state: None | page_manager.PageState = None, kv_cache=None, attention_metadata=None, ): cfg = self.config # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) is_scan_carry = False if isinstance(inputs, tuple) and len(inputs) == 3: hidden_states, stacked_kv_cache, layer_idx = inputs kv_cache = stacked_kv_cache[layer_idx] inputs = hidden_states is_scan_carry = True elif isinstance(inputs, tuple): inputs = inputs[0] inputs = nn.with_logical_constraint(inputs, self.activation_axis_names) inputs = checkpoint_name(inputs, "decoder_layer_input") lnx = self.pre_self_attention_layer_norm(inputs) lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) attention_lnx, kv_cache = self.self_attention( lnx, lnx, decoder_positions, decoder_segment_ids=decoder_segment_ids, deterministic=deterministic, model_mode=model_mode, slot=slot, page_state=page_state, previous_chunk=previous_chunk, kv_cache=kv_cache, attention_metadata=attention_metadata, ) attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names) intermediate_inputs = inputs + attention_lnx # Fully Connected hidden_states = self.post_self_attention_layer_norm(intermediate_inputs) hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names) # MLP block. mlp_lnx = self.mlp(hidden_states, deterministic=deterministic) mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names) layer_output = mlp_lnx + intermediate_inputs layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names) if cfg.record_internal_nn_metrics: self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( "intermediates", "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) if is_scan_carry: def update_cache(cache, val): if jnp.size(val) > 0: return cache.at[layer_idx].set(val) return cache stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) return (layer_output, stacked_kv_cache, layer_idx + 1), None elif cfg.scan_layers: return layer_output, None else: return layer_output, kv_cache
MistralDecoderLayerToLinen = nnx_wrappers.to_linen_class( MistralDecoderLayer, base_metadata_fn=initializers.variable_to_logically_partitioned, )