Source code for maxtext.models.gemma3

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Specialised layers for Gemma 3."""

import jax
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh
import jax.numpy as jnp

from flax import linen as nn
from flax import nnx

from maxtext.common.common_types import Config, AttentionType, MODEL_MODE_PREFILL
from maxtext.layers import quantizations
from maxtext.layers import nnx_wrappers
from maxtext.layers import initializers
from maxtext.layers.attentions import Attention
from maxtext.layers.linears import DenseGeneral, MlpBlock, Dropout
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.layers.initializers import variable_to_logically_partitioned
from maxtext.utils import max_utils


# One cycle of the Gemma3 attention layout: five local sliding-window layers
# followed by a single global-attention layer.
GEMMA3_ATTENTION_PATTERN = (AttentionType.LOCAL_SLIDING,) * 5 + (AttentionType.GLOBAL,)


def get_attention_type(layer_id):
  """Returns the attention type assigned to the layer at index `layer_id`.

  GEMMA3_ATTENTION_PATTERN is applied cyclically: the layer index is wrapped
  modulo the pattern length before the lookup.
  """
  pattern_position = layer_id % len(GEMMA3_ATTENTION_PATTERN)
  return GEMMA3_ATTENTION_PATTERN[pattern_position]
def get_query_pre_attn_scalar(config) -> float:
  """Returns the scalar to multiply the query by before attention.

  Args:
    config: Object exposing `model_name`, plus `head_dim` (for the 4b / 12b
      models) or `base_emb_dim` and `base_num_query_heads` (for the 27b model).

  Returns:
    The inverse square root of the per-head width used for the given model.

  Raises:
    ValueError: If `config.model_name` is not one of the supported Gemma3 names.
  """
  model_name = config.model_name

  if model_name == "gemma3-27b":
    # The 27b model scales by embedding width per query head rather than head_dim.
    width_per_head = config.base_emb_dim // config.base_num_query_heads
    return width_per_head**-0.5

  if model_name in ("gemma3-4b", "gemma3-12b"):
    return config.head_dim**-0.5

  raise ValueError(f"Unsupported model name: {config.model_name}")
# Gemma3 Decoder Layer
class Gemma3DecoderLayer(nnx.Module):
  """Gemma3 decoder layer: a self-attention block and an MLP block, each with a residual add."""

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      rngs: nnx.Rngs,
      quant: None | Quant = None,
      attention_type: AttentionType = AttentionType.LOCAL_SLIDING,
  ):
    """Builds the sub-modules of the decoder layer.

    Args:
      config: The Config object with model hyperparameters.
      mesh: The device mesh passed to the attention and MLP sub-modules.
      model_mode: One of MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, or MODEL_MODE_AUTOREGRESSIVE.
        It selects the dummy input shape and, for prefill, the activation axis names.
      rngs: The random number generators shared by all sub-modules.
      quant: The quantization configuration.
      attention_type: The type of attention used by the self-attention sub-module.
    """
    self.config = config
    self.mesh = mesh
    self.quant = quant
    self.rngs = rngs
    self.attention_type = attention_type

    def make_norm():
      # All normalizations in this layer share one RMSNorm configuration.
      return RMSNorm(
          num_features=config.emb_dim,
          dtype=config.dtype,
          weight_dtype=config.weight_dtype,
          kernel_axes=("norm",),
          rngs=self.rngs,
      )

    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
    dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

    # NOTE: sub-modules are created in a fixed order because they share `rngs`.
    self.pre_self_attention_norm = make_norm()

    self.self_attention = Attention(
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        inputs_q_shape=dummy_inputs_shape,
        inputs_kv_shape=dummy_inputs_shape,
        mesh=mesh,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        dropout_rate=config.dropout_rate,
        float32_qk_product=config.float32_qk_product,
        float32_logits=config.float32_logits,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(config),
        attention_type=self.attention_type,
        sliding_window_size=config.sliding_window_size,
        attn_logits_soft_cap=config.attn_logits_soft_cap,
        use_qk_norm=True,  # Gemma 3 models use query, key normalizations
        query_pre_attn_scalar=get_query_pre_attn_scalar(config),
        model_mode=model_mode,
        rngs=self.rngs,
    )

    # The post-attention norm is optional and only built when enabled in the config.
    self.post_self_attention_norm = make_norm() if self.config.use_post_attn_norm else None

    self.pre_ffw_norm = make_norm()

    self.mlp = MlpBlock(
        in_features=config.emb_dim,
        intermediate_dim=config.mlp_dim,
        activations=config.mlp_activations,
        intermediate_dropout_rate=config.dropout_rate,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        config=config,
        quant=self.quant,
        model_mode=model_mode,
        mesh=mesh,
        rngs=self.rngs,
    )

    # The post-feed-forward norm is optional as well.
    self.post_ffw_norm = make_norm() if self.config.use_post_ffw_norm else None

    self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

    # Prefill uses its own logical axis name for the sequence dimension.
    length_axis = "prefill_activation_norm_length" if model_mode == MODEL_MODE_PREFILL else "activation_norm_length"
    self.activation_axis_names = ("activation_batch", length_axis, "activation_embed")

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state=None,
      slot=None,
      bidirectional_mask=None,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Applies the decoder layer.

    Args:
      inputs: Either the hidden states, a 3-tuple
        `(hidden_states, stacked_kv_cache, layer_idx)` (scan-carry form), or any
        other tuple whose first element is the hidden states.
      decoder_segment_ids: Forwarded to the self-attention sub-module.
      decoder_positions: Forwarded to the self-attention sub-module.
      deterministic: Forwarded to attention, the MLP and dropout.
      model_mode: Forwarded to the self-attention sub-module.
      previous_chunk: Accepted for interface compatibility; not referenced here.
      page_state: Accepted for interface compatibility; not referenced here.
      slot: Accepted for interface compatibility; not referenced here.
      bidirectional_mask: Forwarded to the self-attention sub-module.
      kv_cache: Forwarded to the self-attention sub-module; in scan-carry form it
        is replaced by `stacked_kv_cache[layer_idx]`.
      attention_metadata: Forwarded to the self-attention sub-module.

    Returns:
      * Scan-carry form: `((layer_output, stacked_kv_cache, layer_idx + 1), None)`.
      * Otherwise, with `config.scan_layers` set: `(layer_output, None)`.
      * Otherwise: `(layer_output, kv_cache)`.
    """
    cfg = self.config
    axis_names = self.activation_axis_names

    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
    is_scan_carry = isinstance(inputs, tuple) and len(inputs) == 3
    if is_scan_carry:
      inputs, stacked_kv_cache, layer_idx = inputs
      kv_cache = stacked_kv_cache[layer_idx]
    elif isinstance(inputs, tuple):
      inputs = inputs[0]

    inputs = nn.with_logical_constraint(inputs, axis_names)
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    normed_inputs = nn.with_logical_constraint(self.pre_self_attention_norm(inputs), axis_names)

    # Self-attention block
    attention_out, kv_cache = self.self_attention(
        normed_inputs,
        normed_inputs,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        bidirectional_mask=bidirectional_mask,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    if cfg.use_post_attn_norm:
      attention_out = self.post_self_attention_norm(attention_out)
    attention_out = nn.with_logical_constraint(attention_out, axis_names)

    # First residual connection: attention output plus the layer input.
    residual = attention_out + inputs

    # MLP block.
    mlp_out = self.mlp(self.pre_ffw_norm(residual), deterministic=deterministic)
    if cfg.use_post_ffw_norm:
      mlp_out = self.post_ffw_norm(mlp_out)
    mlp_out = nn.with_logical_constraint(mlp_out, axis_names)

    # Second residual connection, followed by dropout.
    layer_output = self.dropout(mlp_out + residual, deterministic=deterministic)
    layer_output = nn.with_logical_constraint(layer_output, axis_names)

    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      zero_fraction = jnp.sum(layer_output == 0) / jnp.size(layer_output)
      self.sow("intermediates", "activation_fraction_zero", zero_fraction)

    if is_scan_carry:

      def write_back(stacked, updated):
        # Empty leaves are left untouched; non-empty ones are written into this layer's slot.
        return stacked.at[layer_idx].set(updated) if jnp.size(updated) > 0 else stacked

      stacked_kv_cache = jax.tree_util.tree_map(write_back, stacked_kv_cache, kv_cache)
      return (layer_output, stacked_kv_cache, layer_idx + 1), None

    if cfg.scan_layers:
      return layer_output, None

    return layer_output, kv_cache
# Linen-facing class for Gemma3DecoderLayer, produced by nnx_wrappers.to_linen_class
# with variable_to_logically_partitioned as the base metadata function.
Gemma3DecoderLayerToLinen = nnx_wrappers.to_linen_class(
    Gemma3DecoderLayer,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)
class Gemma3ScannableBlock(nnx.Module):
  """A group of consecutive Gemma3 decoder layers applied as one unit."""

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      rngs: nnx.Rngs,
      quant: None | Quant = None,
      num_of_layers: int = 1,
  ):
    """Creates `num_of_layers` decoder layers named `layers_0`, `layers_1`, ...

    Args:
      config: The Config object with model hyperparameters.
      mesh: The device mesh for distributed training.
      model_mode: One of MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, or MODEL_MODE_AUTOREGRESSIVE.
      rngs: The random number generators for initialization.
      quant: The quantization configuration.
      num_of_layers: The number of decoder layers held by this block.
    """
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.quant = quant
    self.rngs = rngs
    self.num_of_layers = num_of_layers

    for layer_id in range(self.num_of_layers):
      # The attention type follows the cyclic pattern, indexed by position within this block.
      decoder_layer = Gemma3DecoderLayer(
          config=self.config,
          mesh=self.mesh,
          model_mode=self.model_mode,
          rngs=self.rngs,
          quant=self.quant,
          attention_type=get_attention_type(layer_id),
      )
      setattr(self, f"layers_{layer_id}", decoder_layer)

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      slot=None,
      page_state=None,
      previous_chunk=None,
      bidirectional_mask=None,
  ):
    """Runs the input through every decoder layer of the block in order.

    Returns:
      `(y, None)` when `config.scan_layers` is set, where `y` is the hidden
      state after the last layer; otherwise whatever the last decoder layer
      returned.
    """
    cfg = self.config

    hidden = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
    hidden = checkpoint_name(hidden, "decoder_layer_input")

    for layer_id in range(self.num_of_layers):
      decoder_layer = getattr(self, f"layers_{layer_id}")
      hidden = decoder_layer(
          hidden,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          previous_chunk=previous_chunk,
          page_state=page_state,
          slot=slot,
          bidirectional_mask=bidirectional_mask,
      )
      if cfg.scan_layers:
        # With scan_layers the layer returns a tuple; keep only its first element.
        hidden = hidden[0]

    return (hidden, None) if cfg.scan_layers else hidden
Gemma3ScannableBlockToLinen = nnx_wrappers.to_linen_class(
    Gemma3ScannableBlock,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)


def _posemb_sincos_2d(
    h: int,
    w: int,
    *,
    width: int,
    temperature: float = 10_000.0,
    precision: str = "default",
    dtype: jnp.dtype = jnp.float32,
):
  """Builds a fixed 2D sine/cosine position embedding (follows the MoCo v3 logic).

  Args:
    h: Number of grid rows.
    w: Number of grid columns.
    width: Embedding width; must be a multiple of 4.
    temperature: Base of the geometric frequency schedule.
    precision: Precision passed to the einsum calls.
    dtype: Dtype of the returned array.

  Returns:
    An array of shape `(1, h * w, width)`.
  """
  row_idx, col_idx = jnp.mgrid[:h, :w]  # pylint: disable=unpacking-non-sequence

  assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"

  # A quarter of the channels goes to each of sin(x), cos(x), sin(y), cos(y).
  quarter = width // 4
  exponents = jnp.arange(quarter) / (quarter - 1)
  inv_freq = 1.0 / (temperature**exponents)

  def phases(grid):
    # Outer product of the flattened coordinates with the frequencies.
    return jnp.einsum("m,d->md", grid.flatten(), inv_freq, precision=precision)

  row_phase = phases(row_idx)
  col_phase = phases(col_idx)

  parts = [jnp.sin(col_phase), jnp.cos(col_phase), jnp.sin(row_phase), jnp.cos(row_phase)]
  embedding = jnp.concatenate(parts, axis=1)
  return jnp.asarray(embedding, dtype)[None, ...]
class MlpBlockViT(nnx.Module):
  """Feed-forward block of the vision transformer: Dense -> GELU -> Dropout -> Dense."""

  def __init__(
      self,
      config: Config,
      block_id: int,
      *,
      rngs: nnx.Rngs,
  ):
    """Creates the two dense projections and the dropout layer.

    Args:
      config: The Config object with the ViT hyperparameters.
      block_id: Index of the enclosing encoder block; stored on the module.
      rngs: The random number generators for initialization.
    """
    self.config = config
    self.block_id = block_id
    self.rngs = rngs

    def dense(in_dim, out_dim):
      # Both projections share everything except their feature sizes.
      return DenseGeneral(
          in_features_shape=in_dim,
          out_features_shape=out_dim,
          dtype=self.config.dtype_mm,
          use_bias=True,
          matmul_precision=self.config.matmul_precision,
          rngs=self.rngs,
      )

    self.Dense_0 = dense(self.config.hidden_size_for_vit, self.config.intermediate_size_for_vit)
    self.Dropout_0 = Dropout(rate=self.config.dropout_rate, rngs=self.rngs)
    self.Dense_1 = dense(self.config.intermediate_size_for_vit, self.config.hidden_size_for_vit)

  def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
    """Applies the feed-forward block to `x`."""
    expanded = nnx.gelu(self.Dense_0(x))
    expanded = self.Dropout_0(expanded, deterministic=deterministic)
    return self.Dense_1(expanded)
class Encoder1DBlock(nnx.Module):
  """One pre-norm transformer encoder block: self-attention then MLP, each with a residual add."""

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      block_id: int,
      *,
      rngs: nnx.Rngs,
  ):
    """Creates the layer norms, attention, MLP and dropout sub-modules.

    Args:
      config: The Config object with the ViT hyperparameters.
      mesh: The device mesh passed to the attention sub-module.
      block_id: Index of this block; stored and forwarded to the MLP block.
      rngs: The random number generators for initialization.
    """
    self.block_id = block_id
    self.config = config
    self.mesh = mesh
    self.rngs = rngs
    # Number of patches: (image size // patch size) squared.
    self.seq_len = (self.config.image_size_for_vit // self.config.patch_size_for_vit) ** 2

    hidden_size = self.config.hidden_size_for_vit
    num_heads = self.config.num_attention_heads_for_vit
    head_dim = hidden_size // num_heads
    token_shape = (self.config.per_device_batch_size, self.seq_len, hidden_size)

    def make_layer_norm():
      return nnx.LayerNorm(
          num_features=self.config.hidden_size_for_vit,
          epsilon=self.config.normalization_layer_epsilon,
          rngs=self.rngs,
      )

    self.LayerNorm_0 = make_layer_norm()
    self.MultiHeadDotProductAttention_0 = Attention(
        config=self.config,
        num_query_heads=num_heads,
        num_kv_heads=num_heads,
        head_dim=head_dim,
        max_target_length=self.seq_len,
        float32_qk_product=self.config.float32_qk_product,
        float32_logits=self.config.float32_logits,
        dtype=self.config.dtype_mm,
        weight_dtype=self.config.weight_dtype,
        mesh=self.mesh,
        attention_kernel="dot_product",
        inputs_q_shape=token_shape,
        inputs_kv_shape=token_shape,
        dropout_rate=0,
        is_nope_layer=True,
        use_bias_in_projections=True,
        attention_type=AttentionType.FULL,
        use_qk_norm=False,
        query_pre_attn_scalar=1 / head_dim**0.5,
        model_mode="train",
        rngs=self.rngs,
    )
    self.LayerNorm_1 = make_layer_norm()
    self.MlpBlockViT_0 = MlpBlockViT(
        block_id=self.block_id,
        config=self.config,
        rngs=self.rngs,
    )
    self.Dropout_0 = Dropout(rate=self.config.dropout_rate, rngs=self.rngs)

  def __call__(self, x: jax.Array, deterministic: bool = False) -> jax.Array:
    """Applies the encoder block to `x` and returns an array produced by two residual adds."""
    # Attention sub-block.
    attn_in = self.LayerNorm_0(x)
    attn_out, _ = self.MultiHeadDotProductAttention_0(inputs_q=attn_in, inputs_kv=attn_in, deterministic=deterministic)
    x = x + self.Dropout_0(attn_out, deterministic=deterministic)

    # MLP sub-block; the same dropout module is reused.
    mlp_out = self.MlpBlockViT_0(self.LayerNorm_1(x), deterministic=deterministic)
    x = x + self.Dropout_0(mlp_out, deterministic=deterministic)
    return x
class Encoder(nnx.Module):
  """Stack of `Encoder1DBlock` layers followed by a final LayerNorm."""

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      *,
      rngs: nnx.Rngs,
  ):
    """Creates `config.num_hidden_layers_for_vit` blocks named `encoderblock_<i>` and the final norm.

    Args:
      config: The Config object with the ViT hyperparameters.
      mesh: The device mesh passed to every encoder block.
      rngs: The random number generators for initialization.
    """
    self.config = config
    self.mesh = mesh
    self.rngs = rngs

    for block_idx in range(self.config.num_hidden_layers_for_vit):
      block = Encoder1DBlock(
          block_id=block_idx,
          config=self.config,
          mesh=self.mesh,
          rngs=self.rngs,
      )
      setattr(self, f"encoderblock_{block_idx}", block)

    self.encoder_norm = nnx.LayerNorm(
        num_features=self.config.hidden_size_for_vit,
        epsilon=self.config.normalization_layer_epsilon,
        rngs=self.rngs,
    )

  def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
    """Applies every encoder block in order, then the final LayerNorm."""
    # TODO(aireenmei, hengtaoguo): add if-scan branch to enable scan support for vision encoder
    for block_idx in range(self.config.num_hidden_layers_for_vit):
      block = getattr(self, f"encoderblock_{block_idx}")
      x = block(x, deterministic=deterministic)
    return self.encoder_norm(x)
class Einsum(nnx.Module):
  """Holds a single weight `w` and contracts it with an input through `jnp.einsum`."""

  def __init__(
      self,
      shape: tuple[int, ...],
      initializer: nnx.initializers.Initializer = nnx.initializers.normal(),
      dtype: jnp.dtype | None = None,
      precision: str = "default",
      *,
      rngs: nnx.Rngs,
  ):
    """Initializes the weight.

    Args:
      shape: Shape of the weight.
      initializer: Called as `initializer(key, shape, dtype)` to create the weight.
      dtype: Dtype handed to the initializer.
      precision: Precision forwarded to `jnp.einsum` at call time.
      rngs: Source of the `params` key used by the initializer.
    """
    self.precision = precision
    initial_value = initializer(rngs.params(), shape, dtype)
    self.w = nnx.Param(initial_value)

  def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
    """Evaluates `eqn` with `x` as the first operand and the weight as the second."""
    operands = (x, self.w)
    return jnp.einsum(eqn, *operands, precision=self.precision)
class VisionEmbedder(nnx.Module):
  """Normalizes image embeddings and projects them from the ViT width to `config.emb_dim`."""

  def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs):
    """Creates the RMSNorm and the projection weight.

    Args:
      config: The Config object with model hyperparameters.
      mesh: The device mesh; stored on the module.
      rngs: The random number generators for initialization.
    """
    self.config = config
    self.mesh = mesh
    self.rngs = rngs

    vit_width = self.config.hidden_size_for_vit

    self.mm_soft_embedding_norm = RMSNorm(
        num_features=vit_width,
        dtype=self.config.dtype_mm,
        weight_dtype=self.config.weight_dtype,
        epsilon=self.config.normalization_layer_epsilon,
        kernel_axes=("norm",),
        rngs=self.rngs,
    )
    self.mm_input_projection = Einsum(
        shape=(vit_width, self.config.emb_dim),
        precision=self.config.matmul_precision,
        rngs=self.rngs,
    )

  def __call__(self, x: jax.Array, eqn: str = "...tm,md->...td") -> jax.Array:
    """Applies the norm, then the einsum projection described by `eqn`."""
    normed = self.mm_soft_embedding_norm(x)
    return self.mm_input_projection(eqn, normed)
def visionembedder_as_linen(
    config: Config,
    mesh: Mesh,
):
  """Wraps `VisionEmbedder` with `nnx_wrappers.to_linen` under the name "VisionEmbedder_0"."""
  embedder = nnx_wrappers.to_linen(
      VisionEmbedder,
      config,
      mesh=mesh,
      name="VisionEmbedder_0",
      abstract_init=False,
      metadata_fn=variable_to_logically_partitioned,
  )
  return embedder
class VisionExit(nnx.Module):
  """The vision exit layer.

  Possibly downsample the soft tokens to a required output length.

  Attributes:
    output_length: The embed will be spatially avg-pooled to this output length.
  """

  def __init__(self, output_length: int = 256, *, rngs: nnx.Rngs):
    """Stores the target token count; the layer has no parameters of its own.

    Args:
      output_length: Number of tokens to produce per example.
      rngs: The random number generators; stored on the module.
    """
    self.output_length = output_length
    self.rngs = rngs

  def __call__(self, x):
    """Average-pools the token grid down to `output_length` tokens.

    Args:
      x: Array indexed as `[batch, tokens, embed]`; the token axis is reshaped
        into a square grid before pooling.

    Returns:
      `x` unchanged when it already has `output_length` tokens, otherwise an
      array of shape `[batch, pooled_tokens, embed]`.

    Raises:
      AssertionError: If the current or target length is not a perfect square,
        or the current grid width is not a multiple of the target grid width.
    """
    cur_length = x.shape[1]
    if cur_length == self.output_length:
      return x

    cur_width = int(cur_length**0.5)
    assert cur_width**2 == cur_length, f"{cur_length=} is not a perfect square"
    output_width = int(self.output_length**0.5)
    # The `=` specifier belongs inside the braces so the message reads `self.output_length=<value>`.
    assert output_width**2 == self.output_length, f"Cannot pool {x.shape=} to {self.output_length=}!"

    batch_size = x.shape[0]
    embed_dim = x.shape[-1]
    x = jnp.reshape(x, (batch_size, cur_width, cur_width, embed_dim))

    assert not cur_width % output_width, f"{cur_width=} {output_width=}"
    window = cur_width // output_width
    window_shape = (window, window)
    # Non-overlapping windows: the stride equals the window size.
    x = nnx.avg_pool(x, window_shape=window_shape, strides=window_shape)

    batch_size, height, width, embed_dim = x.shape
    return jnp.reshape(x, (batch_size, height * width, embed_dim))
def vision_exit_as_linen(x: jax.Array, output_length: int) -> jax.Array:
  """Wraps `VisionExit` with `nnx.bridge.to_linen` and applies it to `x`."""
  vision_exit = nnx.bridge.to_linen(VisionExit, output_length=output_length)
  return vision_exit(x)
class Gemma3VisionEncoderLayer(nnx.Module):
  """Gemma 3 vision encoder: patch embedding, position embedding, transformer encoder and exit pooling."""

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      *,
      rngs: nnx.Rngs,
  ):
    """Creates the patch-embedding conv, position embedding, dropout, encoder and exit layer.

    Args:
      config: The Config object with the ViT hyperparameters.
      mesh: The device mesh passed to the transformer encoder.
      rngs: The random number generators for initialization.
    """
    self.config = config
    self.mesh = mesh
    self.rngs = rngs

    self.embedding = nnx.Conv(
        in_features=self.config.num_channels_for_vit,
        out_features=self.config.hidden_size_for_vit,
        kernel_size=(self.config.patch_size_for_vit, self.config.patch_size_for_vit),
        strides=self.config.conv_stride_for_vit,
        padding="VALID",
        precision=self.config.matmul_precision,
        rngs=self.rngs,
    )
    # One position per patch: the grid has (image size // patch size) entries per side.
    self.pos_embedding = self._get_posemb(
        self.config.posemb_type_for_vit,
        seqshape=(
            self.config.image_size_for_vit // self.config.patch_size_for_vit,
            self.config.image_size_for_vit // self.config.patch_size_for_vit,
        ),
        width=self.config.hidden_size_for_vit,
        dtype=self.config.dtype_mm,
    )
    self.Dropout_0 = Dropout(rate=self.config.dropout_rate, rngs=self.rngs)
    self.Transformer = Encoder(
        config=self.config,
        mesh=self.mesh,
        rngs=self.rngs,
    )
    self.VisionExit = VisionExit(output_length=256, rngs=self.rngs)

  def _get_posemb(
      self,
      typ: str,
      *,
      seqshape: tuple[int, int],
      width: int,
      dtype: jnp.dtype = jnp.float32,
  ):
    """Returns the position embedding.

    Args:
      typ: "learn" for a trainable parameter or "sincos2d" for a fixed sine/cosine table.
      seqshape: Patch grid size as `(rows, cols)`.
      width: Embedding width.
      dtype: Dtype of the embedding.

    Returns:
      An `nnx.Param` of shape `(1, rows * cols, width)` for "learn", or the
      result of `_posemb_sincos_2d` for "sincos2d".

    Raises:
      ValueError: If `typ` is neither "learn" nor "sincos2d".
    """
    if typ == "learn":
      shape = (1, seqshape[0] * seqshape[1], width)
      initializer = nnx.initializers.normal(stddev=1 / (width**0.5))
      return nnx.Param(initializer(self.rngs.params(), shape, dtype))
    elif typ == "sincos2d":
      return _posemb_sincos_2d(*seqshape, width=width, dtype=dtype, precision=self.config.matmul_precision)
    else:
      raise ValueError(f"Unknown posemb type: {typ}")

  def __call__(self, inputs, deterministic, train=False):
    """ViT model that transforms image inputs to image embeddings.

    Args:
      inputs: jnp.array shaped [B, N, H, W, C], e.g. [4, 1, 896, 896, 3]. A 4-D
        input [B, H, W, C] is also accepted and treated as N=1.
      deterministic: Forwarded to the embedding dropout and to the transformer
        encoder; when True, dropout is disabled.
      train: Accepted for interface compatibility; not referenced in this method.

    Returns:
      jnp.array for image embeddings, shaped [B, N, P, D], e.g. [4, 1, 256, 1152]
    """
    # currently only supports N=1, the inputs shape is [B, H, W, C]
    if len(inputs.shape) == 4:
      inputs = inputs[:, None, :]
    b, n, h, w, c = inputs.shape
    x = jnp.reshape(inputs, [b * n, h, w, c])

    # Gemma3 uses conv2d with stride 14 and kernel size 14 to extract patches.
    x = self.embedding(x)
    bn, h, w, c = x.shape
    x = jnp.reshape(x, [bn, h * w, c])

    x = self.pos_embedding + x
    # Forward the caller's flag so this dropout follows the same train/eval
    # setting as the encoder below instead of the Dropout module's own default.
    x = self.Dropout_0(x, deterministic=deterministic)

    # Transformer encoder to extract image features.
    x = self.Transformer(x, deterministic=deterministic)

    # Gemma3 use a vision exit layer to downsample the soft tokens to a required output length.
    x = self.VisionExit(x)
    bn, l, c = x.shape
    x = jnp.reshape(x, [b, n, l, c])
    return x
def gemma3visionencoder_as_linen(
    config: Config,
    mesh: Mesh,
):
  """Wraps `Gemma3VisionEncoderLayer` with `nnx_wrappers.to_linen` under the name "Gemma3VisionEncoderLayer_0"."""
  return nnx_wrappers.to_linen(
      Gemma3VisionEncoderLayer,
      config=config,
      mesh=mesh,
      name="Gemma3VisionEncoderLayer_0",
      abstract_init=False,
      metadata_fn=variable_to_logically_partitioned,
  )