"""Copyright 2026 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Decoder layer definition for GPT OSS models."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module
from typing import Optional
from flax import linen as nn
from flax import nnx
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import AttentionType, Config
from maxtext.layers import attentions
from maxtext.layers import initializers
from maxtext.layers import moe
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.attentions import Attention
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.inference import page_manager
from maxtext.utils import max_utils
# -----------------------------------------
# The Decoder Layer for GPT OSS models
# -----------------------------------------
# Repeating per-layer attention schedule. `get_attention_type` indexes this
# tuple with `layer_id` modulo its length, so consecutive layer ids alternate
# between the LOCAL_SLIDING and GLOBAL attention types.
GPT_OSS_ATTENTION_PATTERN = (
attentions.AttentionType.LOCAL_SLIDING,
attentions.AttentionType.GLOBAL,
)
[docs]
def get_attention_type(layer_id):
  """Return the attention type assigned to the layer with index `layer_id`.

  The index wraps around GPT_OSS_ATTENTION_PATTERN, so the pattern repeats
  every `len(GPT_OSS_ATTENTION_PATTERN)` layers.
  """
  pattern_position = layer_id % len(GPT_OSS_ATTENTION_PATTERN)
  return GPT_OSS_ATTENTION_PATTERN[pattern_position]
[docs]
class GptOssDecoderLayer(nnx.Module):
  """Single GPT OSS decoder layer: pre-norm self-attention followed by a routed MoE block.

  With a residual connection around each sub-block the layer computes:

    x = inputs + Attention(RMSNorm(inputs))
    y = x + RoutedMoE(RMSNorm(x))

  and then applies dropout to `y`. There is no cross-attention: the normalized
  layer input is passed as both the query and the key/value input of the
  attention module.

  Attributes:
    config: Config, MaxText model config.
    mesh: Mesh, JAX device mesh handed to the attention and MoE sub-modules.
    model_mode: str, mode used to derive the dummy input shape and passed to attention.
    attention_type: AttentionType, attention variant used by this layer.
    quant: Optional[Quant], quantization config forwarded to attention and MoE.
    pre_self_attention_layer_norm: RMSNorm applied to the layer input.
    post_self_attention_layer_norm: RMSNorm applied before the MoE block.
    GptOssAttention: Attention, the self-attention module.
    GptOssMlp: moe.RoutedMoE, the mixture-of-experts feed-forward module.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      attention_type: AttentionType,
      quant: Optional[Quant] = None,
      rngs: Optional[nnx.Rngs] = None,
  ):
    """Builds the two RMSNorm layers, the attention module and the routed MoE.

    Args:
      config: MaxText model config; supplies dimensions, dtypes and feature flags.
      mesh: JAX device mesh forwarded to the attention and MoE sub-modules.
      model_mode: mode string used to look up the (batch, sequence) sizes for the
        dummy input shape and forwarded to the attention module.
      attention_type: attention variant for this layer (see `get_attention_type`).
      quant: optional quantization config forwarded to attention and MoE.
      rngs: NNX RNG container forwarded to every sub-module constructor.
    """
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.attention_type = attention_type
    self.quant = quant

    # Only used to tell the sub-modules what input shape to expect.
    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
    dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

    self.pre_self_attention_layer_norm = RMSNorm(
        num_features=dummy_inputs_shape[-1],
        dtype=config.dtype,
        weight_dtype=jnp.float32,
        kernel_axes=("norm",),
        epsilon=config.normalization_layer_epsilon,
        rngs=rngs,
    )

    self.post_self_attention_layer_norm = RMSNorm(
        num_features=dummy_inputs_shape[-1],
        dtype=config.dtype,
        weight_dtype=jnp.float32,
        kernel_axes=("norm",),
        epsilon=config.normalization_layer_epsilon,
        rngs=rngs,
    )

    # Self-attention block
    self.GptOssAttention = Attention(
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        inputs_q_shape=dummy_inputs_shape,
        inputs_kv_shape=dummy_inputs_shape,
        mesh=mesh,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        dropout_rate=config.dropout_rate,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(config),
        use_bias_in_projections=config.attention_bias,
        attention_type=self.attention_type,
        sliding_window_size=config.sliding_window_size,
        query_pre_attn_scalar=(config.head_dim**-0.5),
        model_mode=model_mode,
        rngs=rngs,
    )

    # Mixture-of-experts feed-forward block
    self.GptOssMlp = moe.RoutedMoE(
        config=config,
        num_experts=config.num_experts,
        num_experts_per_tok=config.num_experts_per_tok,
        mesh=mesh,
        kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
        kernel_axes=("embed", None),
        intermediate_dim=config.mlp_dim,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        quant=self.quant,
        rngs=rngs,
    )

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot=None,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Applies self-attention and the MoE block to `inputs`.

    Args:
      inputs: either the hidden states, or a tuple. A 3-tuple is treated as a
        scan carry `(hidden_states, stacked_kv_cache, layer_idx)`; for any other
        tuple only the first element is used.
      decoder_segment_ids: forwarded to the attention module.
      decoder_positions: forwarded (positionally) to the attention module.
      deterministic: forwarded to the attention module and to dropout.
      model_mode: forwarded to the attention module.
      previous_chunk: accepted for interface compatibility; not used here.
      page_state: accepted for interface compatibility; not used here.
      slot: accepted for interface compatibility; not used here.
      kv_cache: KV cache forwarded to the attention module. Replaced by
        `stacked_kv_cache[layer_idx]` when `inputs` is a scan carry.
      attention_metadata: forwarded to the attention module.

    Returns:
      A 2-tuple whose form depends on how the layer was called:
        * scan carry input: `((layer_output, stacked_kv_cache, layer_idx + 1), None)`
          where the stacked cache has been updated at `layer_idx`;
        * otherwise, if `config.scan_layers` is set: `(layer_output, None)`;
        * otherwise: `(layer_output, kv_cache)` with the cache returned by attention.
    """
    cfg = self.config

    # Unpack inputs if it's a tuple. A 3-tuple is the scan carry
    # (hidden_states, stacked_kv_cache, layer_idx); any other tuple carries the
    # hidden states in its first element.
    is_scan_carry = False
    if isinstance(inputs, tuple) and len(inputs) == 3:
      hidden_states, stacked_kv_cache, layer_idx = inputs
      kv_cache = stacked_kv_cache[layer_idx]
      inputs = hidden_states
      is_scan_carry = True
    elif isinstance(inputs, tuple):
      inputs = inputs[0]

    inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    lnx = self.pre_self_attention_layer_norm(inputs)
    lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

    # Self-attention: the same normalized tensor is used for queries and keys/values.
    attention_lnx, kv_cache = self.GptOssAttention(
        lnx,
        lnx,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    attention_lnx = nn.with_logical_constraint(
        attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
    )
    # First residual connection.
    intermediate_inputs = inputs + attention_lnx

    # Fully Connected
    hidden_states = self.post_self_attention_layer_norm(intermediate_inputs)
    hidden_states = nn.with_logical_constraint(
        hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
    )

    # The third value returned by the MoE block is not needed here.
    mlp_lnx, load_balance_loss, _ = self.GptOssMlp(hidden_states)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

    # Second residual connection.
    layer_output = mlp_lnx + intermediate_inputs
    # NOTE(review): a Linen `nn.Dropout` is constructed inline inside an NNX
    # module; presumably this relies on the layer being run through its Linen
    # wrapper - confirm before calling the NNX class directly.
    layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)

    layer_output = nn.with_logical_constraint(
        layer_output,
        ("activation_batch", "activation_norm_length", "activation_embed"),
    )

    if cfg.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
      self.sow("intermediates", "moe_lb_loss", load_balance_loss)

    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      self.sow(
          "intermediates",
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

    if is_scan_carry:

      def update_cache(cache, val):
        # Write this layer's cache into its slot of the stacked cache; leaves
        # with zero elements are left untouched.
        if jnp.size(val) > 0:
          return cache.at[layer_idx].set(val)
        return cache

      stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
      return (layer_output, stacked_kv_cache, layer_idx + 1), None
    elif cfg.scan_layers:
      return layer_output, None
    else:
      return layer_output, kv_cache
# Class produced by `nnx_wrappers.to_linen_class` from the NNX
# `GptOssDecoderLayer`, with `initializers.variable_to_logically_partitioned`
# passed as the base metadata function (presumably a Linen-compatible wrapper
# whose variables carry logical partitioning metadata - confirm against
# `nnx_wrappers`).
GptOssDecoderLayerToLinen = nnx_wrappers.to_linen_class(
GptOssDecoderLayer,
base_metadata_fn=initializers.variable_to_logically_partitioned,
)
[docs]
class GptOssScannableBlock(nnx.Module):
  """One repeat unit of GPT OSS decoder layers.

  The block owns `config.inhomogeneous_layer_cycle_interval` decoder layers,
  stored as attributes `layers_0`, `layers_1`, ..., whose attention types follow
  GPT_OSS_ATTENTION_PATTERN, and runs them one after another. It is intended
  to be the unit that gets repeated (e.g. scanned) over the model depth.

  Attributes:
    config: Config, MaxText model config
    mesh: Mesh, JAX device mesh (used for sharding)
    model_mode: str, mode string forwarded to every decoder layer at construction
    quant: Optional[Quant], quantization config
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      quant: Optional[Quant] = None,
      rngs: nnx.Rngs = None,
  ):
    """Creates the decoder layers of the block in index order."""
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.quant = quant

    for idx in range(config.inhomogeneous_layer_cycle_interval):
      setattr(
          self,
          f"layers_{idx}",
          GptOssDecoderLayer(
              config=config,
              mesh=mesh,
              model_mode=model_mode,
              attention_type=get_attention_type(idx),
              quant=self.quant,
              rngs=rngs,
          ),
      )

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot=None,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Runs every layer of the block in order.

    Each layer receives the previous layer's output and the `kv_cache` value
    returned by the previous layer; all other arguments are passed through
    unchanged. Returns the last layer's `(output, kv_cache)` pair.
    """
    activation_axes = ("activation_batch", "activation_norm_length", "activation_embed")
    hidden = nn.with_logical_constraint(inputs, activation_axes)
    hidden = checkpoint_name(hidden, "decoder_layer_input")

    # Arguments that are identical for every layer in the block.
    passthrough_kwargs = {
        "previous_chunk": previous_chunk,
        "page_state": page_state,
        "slot": slot,
        "attention_metadata": attention_metadata,
    }

    num_layers = self.config.inhomogeneous_layer_cycle_interval
    for idx in range(num_layers):
      decoder_layer = getattr(self, f"layers_{idx}")
      hidden, kv_cache = decoder_layer(
          hidden,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          kv_cache=kv_cache,
          **passthrough_kwargs,
      )
    return hidden, kv_cache
# Class produced by `nnx_wrappers.to_linen_class` from the NNX
# `GptOssScannableBlock`, with `initializers.variable_to_logically_partitioned`
# passed as the base metadata function (presumably a Linen-compatible wrapper
# whose variables carry logical partitioning metadata - confirm against
# `nnx_wrappers`).
GptOssScannableBlockToLinen = nnx_wrappers.to_linen_class(
GptOssScannableBlock,
base_metadata_fn=initializers.variable_to_logically_partitioned,
)