# Source code for maxtext.models.gpt_oss

"""Copyright 2026 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Decoder layer definition for GPT OSS models."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module


from typing import Optional

from flax import linen as nn
from flax import nnx
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import AttentionType, Config
from maxtext.layers import attentions
from maxtext.layers import initializers
from maxtext.layers import moe
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.attentions import Attention
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.inference import page_manager
from maxtext.utils import max_utils

# -----------------------------------------
# The Decoder Layer for GPT OSS models
# -----------------------------------------

# Attention types cycled through by consecutive decoder layers: layer i uses
# entry i % len(GPT_OSS_ATTENTION_PATTERN) (see get_attention_type).
GPT_OSS_ATTENTION_PATTERN = (attentions.AttentionType.LOCAL_SLIDING, attentions.AttentionType.GLOBAL)


def get_attention_type(layer_id):
  """Get attention type based on layer ID.

  Layer ids cycle through ``GPT_OSS_ATTENTION_PATTERN``. With the current
  two-entry pattern, even ids map to ``LOCAL_SLIDING`` and odd ids map to
  ``GLOBAL``.

  Args:
    layer_id: integer index of the decoder layer.

  Returns:
    The pattern entry at ``layer_id % len(GPT_OSS_ATTENTION_PATTERN)``.
  """
  pattern_index = layer_id % len(GPT_OSS_ATTENTION_PATTERN)
  return GPT_OSS_ATTENTION_PATTERN[pattern_index]
class GptOssDecoderLayer(nnx.Module):
  """Single GPT OSS transformer decoder layer (self-attention + routed MoE).

  This is a decoder-only block: queries, keys and values all come from the
  same normalized hidden states, and there is no encoder input. The
  computation is:

    lnx = RMSNorm(x)
    h = x + SelfAttention(lnx)
    out = Dropout(h + RoutedMoE(RMSNorm(h)))

  Attributes:
    config: Config, MaxText model config.
    mesh: Mesh, JAX device mesh (used for sharding).
    model_mode: str, model mode used at construction time to derive the
      placeholder activation shape passed to the sub-layers.
    attention_type: AttentionType, attention variant for this layer (one of
      the entries of GPT_OSS_ATTENTION_PATTERN when built by the scannable
      block).
    quant: Optional[Quant], quantization config shared by attention and MoE.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      attention_type: AttentionType,
      quant: Optional[Quant] = None,
      rngs: nnx.Rngs = None,
  ):
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.attention_type = attention_type
    self.quant = quant

    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
    # (batch, seq, emb_dim) placeholder shape handed to the sub-layers at
    # construction time.
    dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

    # Norm applied before self-attention.
    self.pre_self_attention_layer_norm = RMSNorm(
        num_features=dummy_inputs_shape[-1],
        dtype=config.dtype,
        weight_dtype=jnp.float32,
        kernel_axes=("norm",),
        epsilon=config.normalization_layer_epsilon,
        rngs=rngs,
    )
    # Norm applied to the attention residual output, before the MoE MLP.
    self.post_self_attention_layer_norm = RMSNorm(
        num_features=dummy_inputs_shape[-1],
        dtype=config.dtype,
        weight_dtype=jnp.float32,
        kernel_axes=("norm",),
        epsilon=config.normalization_layer_epsilon,
        rngs=rngs,
    )

    # Self-attention block
    self.GptOssAttention = Attention(
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        inputs_q_shape=dummy_inputs_shape,
        inputs_kv_shape=dummy_inputs_shape,
        mesh=mesh,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        dropout_rate=config.dropout_rate,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(config),
        use_bias_in_projections=config.attention_bias,
        attention_type=self.attention_type,
        sliding_window_size=config.sliding_window_size,
        # head_dim ** -0.5, i.e. 1 / sqrt(head_dim).
        query_pre_attn_scalar=(config.head_dim**-0.5),
        model_mode=model_mode,
        rngs=rngs,
    )

    # Routed mixture-of-experts feed-forward block.
    self.GptOssMlp = moe.RoutedMoE(
        config=config,
        num_experts=config.num_experts,
        num_experts_per_tok=config.num_experts_per_tok,
        mesh=mesh,
        kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
        kernel_axes=("embed", None),
        intermediate_dim=config.mlp_dim,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        quant=self.quant,
        rngs=rngs,
    )

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot=None,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Applies the decoder layer.

    Args:
      inputs: hidden states, or a tuple. A 3-tuple
        ``(hidden_states, stacked_kv_cache, layer_idx)`` is treated as a scan
        carry: this layer's cache is read from ``stacked_kv_cache[layer_idx]``
        and the incoming ``kv_cache`` argument is ignored. For any other tuple,
        only its first element is used.
      decoder_segment_ids: forwarded to attention.
      decoder_positions: forwarded to attention.
      deterministic: forwarded to attention and to the output dropout.
      model_mode: forwarded to attention.
      previous_chunk: accepted but currently unused by this layer.
      page_state: accepted but currently unused by this layer.
      slot: accepted but currently unused by this layer.
      kv_cache: cache passed to attention (unless overridden by a scan carry).
      attention_metadata: forwarded to attention.

    Returns:
      * scan carry input: ``((layer_output, stacked_kv_cache, layer_idx + 1), None)``
        with this layer's slice of ``stacked_kv_cache`` updated;
      * otherwise, if ``config.scan_layers``: ``(layer_output, None)``;
      * otherwise: ``(layer_output, kv_cache)`` with the cache returned by
        attention.
    """
    # NOTE(review): previous_chunk / page_state / slot are not forwarded to
    # attention; confirm this is intended for chunked-prefill / paged inference.
    cfg = self.config
    activation_axes = ("activation_batch", "activation_norm_length", "activation_embed")

    # Unpack inputs if it's a tuple (e.g. from a previous layer returning
    # (hidden_states, kv_cache), or a (hidden_states, stacked_kv_cache, layer_idx)
    # scan carry).
    is_scan_carry = False
    if isinstance(inputs, tuple) and len(inputs) == 3:
      hidden_states, stacked_kv_cache, layer_idx = inputs
      kv_cache = stacked_kv_cache[layer_idx]
      inputs = hidden_states
      is_scan_carry = True
    elif isinstance(inputs, tuple):
      inputs = inputs[0]

    inputs = nn.with_logical_constraint(inputs, activation_axes)
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    # Self-attention sub-block: pre-norm, attention, residual add.
    lnx = self.pre_self_attention_layer_norm(inputs)
    lnx = nn.with_logical_constraint(lnx, activation_axes)
    attention_lnx, kv_cache = self.GptOssAttention(
        lnx,
        lnx,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    attention_lnx = nn.with_logical_constraint(attention_lnx, activation_axes)
    intermediate_inputs = inputs + attention_lnx

    # Fully connected (MoE) sub-block: norm, routed experts, residual add.
    hidden_states = self.post_self_attention_layer_norm(intermediate_inputs)
    hidden_states = nn.with_logical_constraint(hidden_states, activation_axes)
    mlp_lnx, load_balance_loss, _ = self.GptOssMlp(hidden_states)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, activation_axes)

    layer_output = mlp_lnx + intermediate_inputs
    # broadcast_dims=(-2,) reuses one dropout mask along axis -2 (presumably
    # the sequence axis, given the (batch, length, embed) logical axes above).
    # NOTE(review): this builds a Linen module inside an NNX __call__, which
    # presumably relies on running under the Linen wrapper; confirm before
    # using this NNX class standalone.
    layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
    layer_output = nn.with_logical_constraint(layer_output, activation_axes)

    if cfg.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
      self.sow("intermediates", "moe_lb_loss", load_balance_loss)

    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      self.sow(
          "intermediates",
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

    if is_scan_carry:
      # Write this layer's updated cache back into its slot of the stacked
      # cache. Leaves of the returned cache with zero elements leave the
      # corresponding stacked leaf unchanged.
      def update_cache(cache, val):
        if jnp.size(val) > 0:
          return cache.at[layer_idx].set(val)
        return cache

      stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
      return (layer_output, stacked_kv_cache, layer_idx + 1), None
    elif cfg.scan_layers:
      return layer_output, None
    else:
      return layer_output, kv_cache
# Linen module class generated from the NNX GptOssDecoderLayer by
# nnx_wrappers.to_linen_class; variable metadata is produced by
# initializers.variable_to_logically_partitioned.
GptOssDecoderLayerToLinen = nnx_wrappers.to_linen_class(
    GptOssDecoderLayer,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)
class GptOssScannableBlock(nnx.Module):
  """A repeatable block of GPT OSS decoder layers.

  The block owns ``config.inhomogeneous_layer_cycle_interval`` decoder layers,
  stored as attributes ``layers_0``, ``layers_1``, ..., and applies them
  sequentially. Layer ``i`` gets ``get_attention_type(i)``, so the layers cycle
  through GPT_OSS_ATTENTION_PATTERN. It's designed to be used with `nn.scan`
  for efficient compilation.

  Attributes:
    config: Config, MaxText model config (also determines the number of layers
      via ``inhomogeneous_layer_cycle_interval``).
    mesh: Mesh, JAX device mesh (used for sharding).
    model_mode: str, model mode passed to every contained layer.
    quant: Optional[Quant], quantization config shared by all layers.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      quant: Optional[Quant] = None,
      rngs: nnx.Rngs = None,
  ):
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.quant = quant

    for layer_id in range(config.inhomogeneous_layer_cycle_interval):
      attention_type = get_attention_type(layer_id)
      # NOTE(review): attribute names presumably become parameter paths;
      # keep the "layers_{i}" naming stable for checkpoint compatibility.
      layer_name = f"layers_{layer_id}"
      layer = GptOssDecoderLayer(
          config=config,
          mesh=mesh,
          model_mode=model_mode,
          attention_type=attention_type,
          quant=self.quant,
          rngs=rngs,
      )
      setattr(self, layer_name, layer)

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot=None,
      kv_cache=None,
      attention_metadata=None,
  ):
    """Runs every layer of the block in order.

    All arguments other than ``inputs`` are forwarded unchanged to each layer,
    except ``kv_cache``, which is replaced by the value returned from the
    previous layer.

    Returns:
      ``(y, kv_cache)``: the last layer's output and the second value it
      returned. With ``config.scan_layers`` the layers return ``None`` for
      that value. If the block has no layers, this is ``(inputs, kv_cache)``
      after the input sharding constraint.
    """
    cfg = self.config
    inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    y = inputs
    for layer_id in range(cfg.inhomogeneous_layer_cycle_interval):
      layer_name = f"layers_{layer_id}"
      layer = getattr(self, layer_name)
      # NOTE(review): the kv_cache returned by one layer is passed to the next
      # layer; confirm that is intended if this block is ever run with a
      # per-layer cache outside scan mode.
      y, kv_cache = layer(
          y,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          previous_chunk=previous_chunk,
          page_state=page_state,
          slot=slot,
          kv_cache=kv_cache,
          attention_metadata=attention_metadata,
      )
    return y, kv_cache
# Linen module class generated from the NNX GptOssScannableBlock by
# nnx_wrappers.to_linen_class; variable metadata is produced by
# initializers.variable_to_logically_partitioned.
GptOssScannableBlockToLinen = nnx_wrappers.to_linen_class(
    GptOssScannableBlock,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)