# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2 family of model decoder layers."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module
from typing import Any
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from maxtext.common.common_types import Config
from maxtext.layers import initializers as max_initializers
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.layers.attentions import Attention
from maxtext.layers.linears import MlpBlock
from maxtext.inference import page_manager
from maxtext.utils import max_utils
# -----------------------------------------
# The Base Decoder Layer for Qwen2
# -----------------------------------------
class AttentionWithNorm(nnx.Module):
  """Shared Qwen2 building block: self-attention wrapped by two RMSNorm layers.

  Subclasses reuse `apply_attention_with_norm` and add their own feed-forward
  stage on top of the values it returns.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      quant: None | Quant,
      rngs: nnx.Rngs,
  ):
    """Builds the pre-attention norm, the attention module and the post-attention norm.

    Args:
      config: Model configuration; read for dimensions, dtypes and attention options.
      mesh: Mesh stored on the instance and forwarded to `Attention`.
      model_mode: Mode string used to look up the batch/sequence sizes and
        forwarded to `Attention`.
      quant: Optional quantization settings, stored and forwarded to `Attention`.
      rngs: NNX random number streams handed to every submodule.
    """
    self.config = config
    self.mesh = mesh
    self.quant = quant

    # Placeholder (batch, sequence, embedding) shape used only to construct `Attention`.
    batch, length = max_utils.get_batch_seq_len_for_mode(config, model_mode)
    placeholder_shape = (batch, length, config.emb_dim)

    self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

    # Both RMSNorm layers share the same settings; only the instances differ.
    norm_kwargs = {
        "num_features": config.emb_dim,
        "dtype": config.dtype,
        "weight_dtype": config.weight_dtype,
        "kernel_axes": ("norm",),
        "epsilon": config.normalization_layer_epsilon,
    }

    # Counterpart of Qwen2's `input_layernorm`.
    self.pre_self_attention_layer_norm = RMSNorm(**norm_kwargs, rngs=rngs)

    # Self-attention; the query scale is the Qwen2-specific head_dim ** -0.5.
    self.self_attention = Attention(
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        inputs_q_shape=placeholder_shape,
        inputs_kv_shape=placeholder_shape,
        mesh=mesh,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        dropout_rate=config.dropout_rate,
        float32_qk_product=config.float32_qk_product,
        float32_logits=config.float32_logits,
        quant=quant,
        kv_quant=quantizations.configure_kv_quant(config),
        use_ragged_attention=config.use_ragged_attention,
        ragged_block_size=config.ragged_block_size,
        use_qk_norm=config.use_qk_norm,
        use_bias_in_projections=config.attention_bias,
        query_pre_attn_scalar=config.head_dim**-0.5,
        model_mode=model_mode,
        use_mrope=config.use_mrope,
        mrope_section=config.mrope_section,
        rngs=rngs,
    )

    # Counterpart of Qwen2's `post_attention_layernorm`.
    self.post_self_attention_layer_norm = RMSNorm(**norm_kwargs, rngs=rngs)

  def apply_attention_with_norm(
      self,
      inputs: jnp.ndarray,
      decoder_segment_ids: None | jnp.ndarray,
      decoder_positions: None | jnp.ndarray,
      deterministic: bool,
      model_mode: str,
      kv_cache: None | jnp.ndarray = None,
      attention_metadata: None | dict[str, Any] = None,
  ):
    """Runs pre-norm, self-attention, the residual add and the post-attention norm.

    Args:
      inputs: Layer input activations.
      decoder_segment_ids: Optional segment ids forwarded to the attention module.
      decoder_positions: Optional positions forwarded to the attention module.
      deterministic: Forwarded to the attention module.
      model_mode: Forwarded to the attention module.
      kv_cache: Optional cache forwarded to the attention module.
      attention_metadata: Optional metadata dict forwarded to the attention module.

    Returns:
      A 3-tuple `(hidden_states, intermediate_inputs, kv_cache)`: the
      post-attention-normalized activations, the residual sum of `inputs` and
      the attention output, and the cache returned by the attention module.
    """
    axes = self.activation_axis_names

    layer_input = nn.with_logical_constraint(inputs, axes)
    layer_input = checkpoint_name(layer_input, "decoder_layer_input")

    # Normalize before attention.
    normed_input = nn.with_logical_constraint(self.pre_self_attention_layer_norm(layer_input), axes)

    # The same normalized tensor is passed as both of the first two arguments.
    attention_out, kv_cache = self.self_attention(
        normed_input,
        normed_input,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    attention_out = nn.with_logical_constraint(attention_out, axes)

    # Residual connection around the attention block.
    residual = layer_input + attention_out

    # Normalize the residual stream for whatever stage the caller applies next.
    normed_residual = nn.with_logical_constraint(self.post_self_attention_layer_norm(residual), axes)

    return normed_residual, residual, kv_cache
# -----------------------------------------
# The Dense Decoder Layer for Qwen2
# -----------------------------------------
class Qwen2DecoderLayer(AttentionWithNorm):
  """Dense Qwen2 decoder layer: the shared attention block followed by an MLP."""

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      quant: None | Quant,
      rngs: nnx.Rngs,
  ):
    """Builds the inherited attention/norm submodules and then the MLP block.

    Args:
      config: Model configuration; read for the MLP dimensions, activations and dtypes.
      mesh: Mesh forwarded to the base class and to `MlpBlock`.
      model_mode: Mode string forwarded to the base class and to `MlpBlock`.
      quant: Optional quantization settings forwarded to the base class and to `MlpBlock`.
      rngs: NNX random number streams forwarded to the base class and to `MlpBlock`.
    """
    super().__init__(config=config, mesh=mesh, model_mode=model_mode, quant=quant, rngs=rngs)

    self.mlp = MlpBlock(
        config=config,
        mesh=mesh,
        in_features=config.emb_dim,
        intermediate_dim=config.mlp_dim,
        activations=config.mlp_activations,
        intermediate_dropout_rate=config.dropout_rate,
        dtype=config.dtype,
        weight_dtype=config.weight_dtype,
        quant=quant,
        model_mode=model_mode,
        rngs=rngs,
    )

  def __call__(
      self,
      inputs: jnp.ndarray,
      decoder_segment_ids: None | jnp.ndarray,
      decoder_positions: None | jnp.ndarray,
      deterministic: bool,
      model_mode: str,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
      kv_cache: None | jnp.ndarray = None,
      attention_metadata: None | dict[str, Any] = None,
  ):
    """Applies attention-with-norm, then the MLP with a residual connection.

    Args:
      inputs: Layer input activations. A tuple is also accepted, in which case
        only its first element is used.
      decoder_segment_ids: Forwarded to `apply_attention_with_norm`.
      decoder_positions: Forwarded to `apply_attention_with_norm`.
      deterministic: Forwarded to the attention step and to the MLP.
      model_mode: Forwarded to `apply_attention_with_norm`.
      previous_chunk: Accepted but not used by this layer.
      page_state: Accepted but not used by this layer.
      slot: Accepted but not used by this layer.
      kv_cache: Optional cache forwarded to `apply_attention_with_norm`.
      attention_metadata: Optional metadata dict forwarded to `apply_attention_with_norm`.

    Returns:
      `(layer_output, None)` when `config.scan_layers` is truthy, otherwise
      `(layer_output, kv_cache)` with the cache returned by the attention step.
    """
    # A preceding layer may hand over a tuple; keep only its first element.
    layer_input = inputs[0] if isinstance(inputs, tuple) else inputs

    normed_residual, residual, kv_cache = self.apply_attention_with_norm(
        layer_input,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )

    axes = self.activation_axis_names
    mlp_out = nn.with_logical_constraint(self.mlp(normed_residual, deterministic=deterministic), axes)

    # Residual connection around the MLP.
    layer_output = nn.with_logical_constraint(residual + mlp_out, axes)

    # When layers are scanned, no per-layer cache is returned.
    returned_cache = None if self.config.scan_layers else kv_cache
    return layer_output, returned_cache
# Wrap the NNX-based `Qwen2DecoderLayer` with `nnx_wrappers.to_linen_class`,
# passing `variable_to_logically_partitioned` as the base metadata function.
# NOTE(review): presumably the result is a Flax Linen-compatible class for
# callers that still build models with Linen; confirm against `nnx_wrappers`.
Qwen2DecoderLayerToLinen = nnx_wrappers.to_linen_class(
Qwen2DecoderLayer,
base_metadata_fn=max_initializers.variable_to_logically_partitioned,
)