"""
Copyright 2023-2026 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Decoder layer definition for Olmo 3 models."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module
from typing import Optional
from flax import linen as nn
from flax import nnx
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import AttentionType, Config
from maxtext.layers import attentions
from maxtext.layers import initializers
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.attentions import Attention
from maxtext.layers.linears import MlpBlock
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.utils import max_utils
# -----------------------------------------
# The Decoder Layer for Olmo3 models
# -----------------------------------------
OLMO3_ATTENTION_PATTERN = (
attentions.AttentionType.LOCAL_SLIDING,
attentions.AttentionType.LOCAL_SLIDING,
attentions.AttentionType.LOCAL_SLIDING,
attentions.AttentionType.GLOBAL,
)
[docs]
def get_attention_type(layer_id):
"""Get attention type based on layer ID."""
layer_id %= len(OLMO3_ATTENTION_PATTERN)
return OLMO3_ATTENTION_PATTERN[layer_id]
[docs]
class Olmo3DecoderLayer(nnx.Module):
"""Transformer decoder layer that attends to the encoder."""
def __init__(
self,
config: Config,
mesh: Mesh,
model_mode: str,
attention_type: AttentionType,
quant: Optional[Quant] = None,
rngs: nnx.Rngs = None,
):
self.config = config
self.mesh = mesh
self.model_mode = model_mode
self.attention_type = attention_type
self.quant = quant
batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)
self.post_self_attention_layer_norm = RMSNorm(
num_features=dummy_inputs_shape[-1],
dtype=config.dtype,
weight_dtype=jnp.float32,
kernel_axes=("norm",),
epsilon=config.normalization_layer_epsilon,
rngs=rngs,
)
self.post_mlp_layer_norm = RMSNorm(
num_features=dummy_inputs_shape[-1],
dtype=config.dtype,
weight_dtype=jnp.float32,
kernel_axes=("norm",),
epsilon=config.normalization_layer_epsilon,
rngs=rngs,
)
# Match HF runtime: a single rotary (with rope_scaling/YaRN) is applied to every layer,
# including sliding-window. The "RoPE scaling is not applied to sliding window attention
# layers" comment in HF's modular_olmo3.py is unimplemented design intent — its code
# passes the same (cos, sin) to all layers.
current_rope_type = config.rope_type.lower()
# Self-attention block
self.attention = Attention(
config=config,
num_query_heads=config.num_query_heads,
num_kv_heads=config.num_kv_heads,
head_dim=config.head_dim,
max_target_length=config.max_target_length,
max_prefill_predict_length=config.max_prefill_predict_length,
attention_kernel=config.attention,
inputs_q_shape=dummy_inputs_shape,
inputs_kv_shape=dummy_inputs_shape,
mesh=mesh,
dtype=config.dtype,
weight_dtype=config.weight_dtype,
dropout_rate=config.dropout_rate,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(config),
use_bias_in_projections=config.attention_bias,
attention_type=self.attention_type,
sliding_window_size=config.sliding_window_size,
query_pre_attn_scalar=(config.head_dim**-0.5),
model_mode=model_mode,
use_qk_norm=config.use_qk_norm,
rope_type=current_rope_type,
rngs=rngs,
)
self.mlp = MlpBlock(
in_features=config.emb_dim,
intermediate_dim=config.mlp_dim,
activations=config.mlp_activations,
intermediate_dropout_rate=config.dropout_rate,
dtype=config.dtype,
weight_dtype=config.weight_dtype,
config=config,
mesh=mesh,
quant=quant,
model_mode=model_mode,
rngs=rngs,
)
def __call__(
self,
inputs,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk=None,
page_state=None,
slot=None,
kv_cache=None,
attention_metadata=None,
):
cfg = self.config
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
is_scan_carry = False
if isinstance(inputs, tuple) and len(inputs) == 3:
hidden_states, stacked_kv_cache, layer_idx = inputs
kv_cache = stacked_kv_cache[layer_idx]
inputs = hidden_states
is_scan_carry = True
elif isinstance(inputs, tuple):
inputs = inputs[0]
inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
attention_lnx, kv_cache = self.attention(
inputs,
inputs,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
attention_lnx = nn.with_logical_constraint(
attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
)
# Normalize stream before addition
attention_lnx = self.post_self_attention_layer_norm(attention_lnx)
attention_lnx = nn.with_logical_constraint(
attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
)
intermediate_inputs = inputs + attention_lnx
# Fully Connected
mlp_lnx = self.mlp(intermediate_inputs)
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
# Normalize stream before addition
mlp_lnx = self.post_mlp_layer_norm(mlp_lnx)
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
layer_output = mlp_lnx + intermediate_inputs
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
layer_output = nn.with_logical_constraint(
layer_output,
("activation_batch", "activation_norm_length", "activation_embed"),
)
if cfg.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
if is_scan_carry:
def update_cache(cache, val):
if jnp.size(val) > 0:
return cache.at[layer_idx].set(val)
return cache
stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
return (layer_output, stacked_kv_cache, layer_idx + 1), None
elif cfg.scan_layers:
return layer_output, None
else:
return layer_output, kv_cache
Olmo3DecoderLayerToLinen = nnx_wrappers.to_linen_class(
Olmo3DecoderLayer,
base_metadata_fn=initializers.variable_to_logically_partitioned,
)
[docs]
class Olmo3ScannableBlock(nnx.Module):
"""A repeatable block of Olmo 3 decoder layers.
This block applies multiple decoder layers sequentially, using the attention
pattern defined by OLMO3_ATTENTION_PATTERN. It's designed to be
used with `nn.scan` for efficient compilation.
Attributes:
config: Config, MaxText model config
mesh: Mesh, JAX device mesh (used for sharding)
num_of_layers: int, number of decoder layers in the block
quant: Optional[Quant], quantization config
"""
def __init__(
self,
config: Config,
mesh: Mesh,
model_mode: str,
quant: Optional[Quant] = None,
rngs: nnx.Rngs = None,
):
self.config = config
self.mesh = mesh
self.model_mode = model_mode
self.quant = quant
for layer_id in range(config.inhomogeneous_layer_cycle_interval):
attention_type = get_attention_type(layer_id)
layer_name = f"layers_{layer_id}"
layer = Olmo3DecoderLayer(
config=config,
mesh=mesh,
model_mode=model_mode,
attention_type=attention_type,
quant=self.quant,
rngs=rngs,
)
setattr(self, layer_name, layer)
def __call__(
self,
inputs,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk=None,
page_state=None,
slot=None,
kv_cache=None,
attention_metadata=None,
):
cfg = self.config
inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
y = inputs
for layer_id in range(cfg.inhomogeneous_layer_cycle_interval):
layer_name = f"layers_{layer_id}"
layer = getattr(self, layer_name)
y = layer(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
if cfg.scan_layers:
y = y[0]
if cfg.scan_layers:
return y, None
else:
return y
Olmo3ScannableBlockToLinen = nnx_wrappers.to_linen_class(
Olmo3ScannableBlock,
base_metadata_fn=initializers.variable_to_logically_partitioned,
)