# Source code for maxtext.models.gemma

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Specialised layers for Gemma."""

from typing import Optional

from flax import linen as nn
from flax import nnx
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh
import jax.numpy as jnp

from maxtext.common.common_types import Config
from maxtext.layers import initializers
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.attentions import Attention
from maxtext.layers.linears import Dropout, MlpBlock
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.inference import page_manager
from maxtext.utils import max_utils


# Decoder and Model definitions
class GemmaDecoderLayer(nnx.Module):
  """Gemma decoder layer: pre-norm self-attention followed by a pre-norm MLP.

  This is a self-attention-only block with no encoder/cross-attention input:
  the attention sub-layer receives the same normalized tensor as both of its
  inputs. Ignoring sharding constraints, the forward pass computes

    h   = inputs + self_attention(pre_self_attention_norm(inputs))
    out = dropout(h + mlp(pre_ffw_norm(h)))

  Attributes:
    config: Model configuration supplying dimensions, dtypes, dropout rate and
      the attention / quantization / metrics / scan flags read by this layer.
    mesh: Device mesh passed to the attention and MLP sub-layers.
    model_mode: Mode string used at construction time to derive the dummy
      attention input shape; also passed to the attention and MLP sub-layers.
    quant: Optional AQT quantization config passed to attention and MLP.
  """

  config: Config
  mesh: Mesh
  model_mode: str
  quant: None | Quant = None

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      quant: Optional[Quant] = None,
      *,
      rngs: nnx.Rngs,
  ):
    """Builds the norm, attention, MLP and dropout sub-layers.

    Args:
      config: Model configuration.
      mesh: Device mesh for the sub-layers.
      model_mode: Mode string; together with `config` it determines the
        (batch, seq_len) of the dummy shape handed to Attention.
      quant: Optional quantization config for attention and MLP.
      rngs: RNG streams passed to every sub-layer.
    """
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.quant = quant
    self.rngs = rngs

    # Dummy (batch, seq_len, emb_dim) shape given to Attention for both its
    # query and key/value inputs.
    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
    dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

    # NOTE: sub-layers share `self.rngs`; keep construction order unchanged so
    # RNG consumption (and therefore initialization) stays reproducible.
    self.pre_self_attention_norm = RMSNorm(
        num_features=config.emb_dim,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        kernel_axes=("norm",),
        rngs=self.rngs,
    )
    self.self_attention = Attention(
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        inputs_q_shape=dummy_inputs_shape,
        inputs_kv_shape=dummy_inputs_shape,
        mesh=self.mesh,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        dropout_rate=config.dropout_rate,
        float32_qk_product=config.float32_qk_product,
        float32_logits=config.float32_logits,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(config),
        use_ragged_attention=config.use_ragged_attention,
        ragged_block_size=config.ragged_block_size,
        model_mode=self.model_mode,
        rngs=self.rngs,
    )
    self.pre_ffw_norm = RMSNorm(
        num_features=config.emb_dim,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        kernel_axes=("norm",),
        rngs=self.rngs,
    )
    self.mlp = MlpBlock(
        config=config,
        mesh=self.mesh,
        in_features=config.emb_dim,
        intermediate_dim=config.mlp_dim,
        activations=config.mlp_activations,
        intermediate_dropout_rate=config.dropout_rate,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        quant=self.quant,
        model_mode=self.model_mode,
        rngs=self.rngs,
    )
    # broadcast_dims=(-2,): presumably one dropout mask is shared along axis -2
    # (the length axis of [batch, length, emb_dim] activations) — assumes
    # flax-style broadcast_dims semantics; confirm against linears.Dropout.
    self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

    # Logical axis names used for every sharding constraint in __call__.
    self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot=None,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Runs one decoder layer.

    Args:
      inputs: Embedded decoder inputs with shape [batch, length, emb_dim], or a
        tuple whose first element is that tensor (other elements are ignored).
      decoder_segment_ids: Segment ids forwarded to self-attention.
      decoder_positions: Positions forwarded to self-attention.
      deterministic: Forwarded to the attention, MLP and output-dropout calls
        (presumably disables dropout when True).
      model_mode: Mode string forwarded to self-attention for this call.
      previous_chunk: Unused by this layer (not forwarded anywhere).
      page_state: Unused by this layer (not forwarded anywhere).
      slot: Unused by this layer (not forwarded anywhere).
      kv_cache: KV cache forwarded to self-attention.
      attention_metadata: Forwarded to self-attention.

    Returns:
      `(layer_output, kv_cache)`, where `kv_cache` is the cache returned by
      self-attention, or `(layer_output, None)` when `config.scan_layers` is
      set.
    """
    # Unpack inputs if it's a tuple (e.g. from a previous layer returning
    # (hidden_states, kv_cache)).
    if isinstance(inputs, tuple):
      inputs = inputs[0]
    inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
    # Tag the layer input so name-based rematerialization policies can select it.
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    # Self-attention sub-block: pre-norm, attention, residual add.
    lnx = self.pre_self_attention_norm(inputs)
    lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)
    attention_lnx, kv_cache = self.self_attention(
        lnx,
        lnx,  # same tensor for the second (presumably key/value) input: self-attention only
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
    attention_lnx += inputs
    residual = attention_lnx

    # MLP sub-block: pre-norm, MLP, residual add; dropout is applied to the sum.
    attn_output = self.pre_ffw_norm(attention_lnx)
    mlp_lnx = self.mlp(attn_output, deterministic=deterministic)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)

    next_layer_addition = mlp_lnx + residual
    next_layer_addition_dropped_out = self.dropout(next_layer_addition, deterministic=deterministic)

    layer_output = next_layer_addition_dropped_out
    layer_output = nn.with_logical_constraint(
        layer_output,
        self.activation_axis_names,
    )

    # Optional activation statistics, recorded under the "intermediates" collection.
    if self.config.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      self.sow(
          "intermediates",
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

    if self.config.scan_layers:
      # NOTE(review): the KV cache is dropped under scan_layers — presumably to
      # match the output structure expected by the layer scan; confirm with caller.
      return layer_output, None
    else:
      return layer_output, kv_cache
# Linen class built from the NNX GemmaDecoderLayer via nnx_wrappers.to_linen_class.
# base_metadata_fn presumably converts NNX variable metadata into logically
# partitioned (sharding-annotated) Linen metadata — confirm in nnx_wrappers.
GemmaDecoderLayerToLinen = nnx_wrappers.to_linen_class(
    GemmaDecoderLayer,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)