Source code for maxtext.models.llama2

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer model definition."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

import functools
from flax import nnx
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Config
from maxtext.common.common_types import MODEL_MODE_PREFILL
from maxtext.inference import page_manager
from maxtext.layers import initializers
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.attentions import Attention
from maxtext.layers.linears import Dropout, MlpBlock
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.utils import max_utils
from maxtext.utils.sharding import create_sharding, maybe_shard_with_logical
from maxtext.layers.learn_to_init_layer import LearnToInitDecoderLayer

# -----------------------------------------
# The Decoder Layer specific for Llama2
# -----------------------------------------


class LlamaDecoderLayer(nnx.Module):
  """Llama2 decoder layer: RMSNorm -> self-attention -> RMSNorm -> MLP.

  Both the attention output and the MLP output are added back to their
  respective inputs (residual connections), and dropout is applied to the
  final sum. Attention is computed with the normalized inputs used as both the
  query and the key/value source.
  """

  def __init__(
      self,
      config: Config,
      model_mode: str,
      mesh: Mesh,
      rngs: nnx.Rngs,
      quant: None | Quant = None,
  ):
    """Builds the norm, attention, MLP and dropout submodules.

    Args:
      config: Model configuration; supplies dimensions, dtypes, attention
        options, dropout rate and sharding settings.
      model_mode: Selects the activation axis names (prefill uses
        "prefill_activation_norm_length") and is forwarded to `Attention`,
        `MlpBlock` and `max_utils.get_batch_seq_len_for_mode`.
      mesh: Device mesh forwarded to `Attention`, `MlpBlock` and the logical
        sharding helper.
      rngs: NNX RNG container passed to every submodule.
      quant: Optional quantization object forwarded to `Attention` and
        `MlpBlock`.
    """
    self.config = config
    self.mesh = mesh
    self.quant = quant

    # Only the middle (sequence-length) logical axis depends on the mode.
    norm_length_axis = (
        "prefill_activation_norm_length" if model_mode == MODEL_MODE_PREFILL else "activation_norm_length"
    )
    self.activation_axis_names = ("activation_batch", norm_length_axis, "activation_embed")

    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
    attention_input_shape = (batch_size, seq_len, config.emb_dim)

    def parse_axis_order(spec):
      # Config stores axis orders as comma-separated integers, e.g. "0,1,2".
      return tuple(int(axis) for axis in spec.split(","))

    # Both layer norms share every argument except the RNG stream they consume.
    build_rms_norm = functools.partial(
        RMSNorm,
        num_features=config.emb_dim,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        shard_mode=config.shard_mode,
        kernel_axes=("norm",),
        epsilon=config.normalization_layer_epsilon,
    )

    # NOTE: submodules are created in the same order as before so that `rngs`
    # is consumed identically.
    self.pre_self_attention_layer_norm = build_rms_norm(rngs=rngs)

    self.self_attention = Attention(
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        inputs_q_shape=attention_input_shape,
        inputs_kv_shape=attention_input_shape,
        mesh=mesh,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        dropout_rate=config.dropout_rate,
        float32_qk_product=config.float32_qk_product,
        float32_logits=config.float32_logits,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(config),
        prefill_cache_axis_order=parse_axis_order(config.prefill_cache_axis_order),
        ar_cache_axis_order=parse_axis_order(config.ar_cache_axis_order),
        compute_axis_order=parse_axis_order(config.compute_axis_order),
        reshape_q=config.reshape_q,
        use_ragged_attention=config.use_ragged_attention,
        ragged_block_size=config.ragged_block_size,
        model_mode=model_mode,
        attn_logits_soft_cap=config.attn_logits_soft_cap,
        rngs=rngs,
    )

    self.post_self_attention_layer_norm = build_rms_norm(rngs=rngs)

    self.mlp = MlpBlock(
        in_features=config.emb_dim,
        intermediate_dim=config.mlp_dim,
        activations=config.mlp_activations,
        intermediate_dropout_rate=config.dropout_rate,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        config=config,
        mesh=mesh,
        quant=self.quant,
        model_mode=model_mode,
        rngs=rngs,
    )

    self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)

    # Pre-bind the arguments that never change between sharding calls.
    self._maybe_shard_with_logical = functools.partial(
        maybe_shard_with_logical,
        mesh=self.mesh,
        shard_mode=config.shard_mode,
        debug_sharding=config.debug_sharding,
        extra_stack_level=1,
    )

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Applies the decoder layer.

    Args:
      inputs: Either the hidden states directly, or a tuple. A 3-tuple is
        unpacked as `(hidden_states, stacked_kv_cache, layer_idx)` and this
        layer's cache is read from `stacked_kv_cache[layer_idx]`, overriding
        the `kv_cache` argument. For any other tuple only element 0 is used.
      decoder_segment_ids: Forwarded to self-attention.
      decoder_positions: Forwarded to self-attention.
      deterministic: Forwarded to self-attention, the MLP and dropout.
      model_mode: Forwarded to self-attention.
      previous_chunk: Forwarded to self-attention.
      slot: Forwarded to self-attention.
      page_state: Forwarded to self-attention.
      kv_cache: Forwarded to self-attention unless `inputs` is a 3-tuple.
      attention_metadata: Forwarded to self-attention.

    Returns:
      A 2-tuple whose form depends on the call:
        * 3-tuple `inputs`:
          `((layer_output, stacked_kv_cache, layer_idx + 1), None)`, where the
          stacked cache has been updated at `layer_idx` with the cache returned
          by self-attention (leaves of size 0 are left untouched).
        * otherwise, when `config.scan_layers` is truthy: `(layer_output, None)`.
        * otherwise: `(layer_output, kv_cache)` with the cache returned by
          self-attention.
    """
    cfg = self.config
    # Plain aliases (not wrappers) so the sharding helper is still invoked
    # directly from this frame.
    shard = self._maybe_shard_with_logical
    activation_axes = self.activation_axis_names

    # A 3-tuple carries the stacked KV cache and the index of this layer.
    carries_kv_stack = isinstance(inputs, tuple) and len(inputs) == 3
    if carries_kv_stack:
      inputs, stacked_kv_cache, layer_idx = inputs
      kv_cache = stacked_kv_cache[layer_idx]
    elif isinstance(inputs, tuple):
      inputs = inputs[0]

    inputs = shard(inputs, activation_axes)
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    activation_sharding = create_sharding(self.mesh, activation_axes)

    # --- Self-attention sub-block (pre-norm + residual) ---
    normed_inputs = self.pre_self_attention_layer_norm(inputs, out_sharding=activation_sharding)
    normed_inputs = shard(normed_inputs, activation_axes)

    attention_out, kv_cache = self.self_attention(
        normed_inputs,
        normed_inputs,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        slot=slot,
        page_state=page_state,
        previous_chunk=previous_chunk,
        out_sharding=activation_sharding,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    attention_out = shard(attention_out, activation_axes)
    post_attention = inputs + attention_out

    # --- MLP sub-block (pre-norm + residual) ---
    mlp_inputs = self.post_self_attention_layer_norm(post_attention, out_sharding=activation_sharding)
    mlp_inputs = shard(mlp_inputs, activation_axes)

    mlp_intermediate_sharding = create_sharding(
        self.mesh, ("activation_batch", "activation_length", "activation_mlp")
    )
    mlp_out = self.mlp(
        mlp_inputs,
        deterministic=deterministic,
        intermediate_sharding=mlp_intermediate_sharding,
        out_sharding=activation_sharding,
    )
    mlp_out = shard(mlp_out, activation_axes)

    layer_output = mlp_out + post_attention
    layer_output = self.dropout(layer_output, deterministic=deterministic)
    layer_output = shard(layer_output, activation_axes)

    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      zero_fraction = jnp.sum(layer_output == 0) / jnp.size(layer_output)
      self.sow("intermediates", "activation_fraction_zero", zero_fraction)

    if carries_kv_stack:

      def write_layer_slot(stacked, layer_value):
        # Leaves with no elements are skipped; the stacked leaf is returned as is.
        return stacked.at[layer_idx].set(layer_value) if jnp.size(layer_value) > 0 else stacked

      stacked_kv_cache = jax.tree_util.tree_map(write_layer_slot, stacked_kv_cache, kv_cache)
      return (layer_output, stacked_kv_cache, layer_idx + 1), None

    if cfg.scan_layers:
      return layer_output, None

    return layer_output, kv_cache
class LlamaLTIDecoderLayer(LearnToInitDecoderLayer):
  """`LearnToInitDecoderLayer` bound to `LlamaDecoderLayer` as its base layer.

  This is a temporary, Llama-specific wrapper that exists until the
  learn-to-init (LTI) layer is generalized to other models.
  """

  def __init__(self, *args, **kwargs):
    """Forwards all arguments to the parent, fixing `base_layer_cls`.

    `base_layer_cls` is always `LlamaDecoderLayer`; callers must not pass it
    themselves (doing so raises `TypeError` for a duplicate keyword).
    """
    super().__init__(
        *args,
        base_layer_cls=LlamaDecoderLayer,
        **kwargs,
    )
# Linen-facing classes generated from the NNX layers defined in this module.
# Both are produced by `nnx_wrappers.to_linen_class` with the same
# `base_metadata_fn`, so that shared argument is bound once here.
_to_linen_with_logical_metadata = functools.partial(
    nnx_wrappers.to_linen_class,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)

LlamaLTIDecoderLayerToLinen = _to_linen_with_logical_metadata(LlamaLTIDecoderLayer)

LlamaDecoderLayerToLinen = _to_linen_with_logical_metadata(LlamaDecoderLayer)