# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for decoder layers"""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module
import functools
import inspect
import warnings
from typing import Any
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from flax.nnx import wrappers as nnx_wrappers
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh
from maxtext.common.common_types import (
MODEL_MODE_AUTOREGRESSIVE,
MODEL_MODE_PREFILL,
MODEL_MODE_TRAIN,
Config,
DecoderBlockType,
ShardMode,
)
from maxtext.inference import page_manager
from maxtext.layers import initializers, linears, mhc, normalizations, quantizations
from maxtext.layers.attentions import Attention
from maxtext.layers.embeddings import Embed, PositionalEmbedding, attend_on_embedding
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.models import (
deepseek,
deepseek_batchsplit,
deepseek_batchsplit_fp8,
gemma,
gemma2,
gemma3,
gemma4,
gpt3,
gpt_oss,
llama2,
llama4,
mistral,
mixtral,
olmo3,
qwen3,
qwen3_5,
simple_layer,
)
from maxtext.multimodal import utils as mm_utils
from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding
from maxtext.utils.sharding import create_sharding
# ------------------------------------------------------------------------------
# The network: Decoder Definitions
# ------------------------------------------------------------------------------
class NNXDecoderLayer(nnx.Module):
  """Default transformer decoder layer implemented as an NNX module.

  The input is normalized once; the normalized tensor feeds both the
  self-attention block and the MLP block. Their outputs are summed, passed
  through dropout and added back onto the un-normalized input.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      model_mode: str,
      quant: None | Quant = None,
      name: str = "decoder_layer",
      *,
      rngs: nnx.Rngs,
  ):
    """Creates the norm, attention, MLP and dropout sub-modules.

    Args:
      config: Model configuration.
      mesh: Device mesh handed to the attention and MLP blocks.
      model_mode: Mode the sub-modules are built for.
      quant: Optional quantization handed to the attention and MLP blocks.
      name: Accepted for signature compatibility; not read by this constructor.
      rngs: NNX RNG container used to initialize the sub-modules.
    """
    self.config = config
    self.mesh = mesh
    self.model_mode = model_mode
    self.quant = quant
    cfg = self.config

    def _axis_order(spec):
      # Axis orders are stored in the config as comma-separated integers.
      return tuple(int(axis) for axis in spec.split(","))

    # Sub-modules are created in a fixed order: norm, attention, MLP, dropout.
    self.pre_self_attention_norm = RMSNorm(
        num_features=cfg.emb_dim,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        epsilon=cfg.normalization_layer_epsilon,
        kernel_axes=("norm",),
        rngs=rngs,
    )

    token_shape = (1, 1, cfg.emb_dim)
    self.self_attention = Attention(
        config=self.config,
        num_query_heads=cfg.num_query_heads,
        num_kv_heads=cfg.num_kv_heads,
        head_dim=cfg.head_dim,
        max_target_length=cfg.max_target_length,
        max_prefill_predict_length=cfg.max_prefill_predict_length,
        attention_kernel=cfg.attention,
        inputs_q_shape=token_shape,
        inputs_kv_shape=token_shape,
        mesh=mesh,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        dropout_rate=cfg.dropout_rate,
        float32_qk_product=cfg.float32_qk_product,
        float32_logits=cfg.float32_logits,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(cfg),
        prefill_cache_axis_order=_axis_order(cfg.prefill_cache_axis_order),
        ar_cache_axis_order=_axis_order(cfg.ar_cache_axis_order),
        compute_axis_order=_axis_order(cfg.compute_axis_order),
        reshape_q=cfg.reshape_q,
        use_mrope=cfg.use_mrope,
        mrope_section=cfg.mrope_section,
        share_kv_projections=cfg.share_kv_projections,
        model_mode=model_mode,
        rngs=rngs,
    )

    self.mlp = linears.MlpBlock(
        in_features=cfg.emb_dim,
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        model_mode=model_mode,
        config=cfg,
        quant=self.quant,
        mesh=self.mesh,
        rngs=rngs,
    )

    self.dropout = linears.Dropout(rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,))

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
      kv_cache: jax.Array | None = None,
      attention_metadata: dict[str, Any] | None = None,
  ):
    """Runs the layer on `inputs`.

    `previous_chunk`, `slot` and `page_state` are accepted but not read here.

    Returns:
      `(layer_output, kv_cache)`, where the second element is None when
      `config.scan_layers` is set and otherwise the cache returned by the
      attention block.
    """
    cfg = self.config
    constrain = functools.partial(
        sharding.maybe_shard_with_logical,
        mesh=self.mesh,
        shard_mode=cfg.shard_mode,
        debug_sharding=cfg.debug_sharding,
    )

    # The length axis name depends on the mode the layer was *built* with.
    length_axis = (
        "prefill_activation_length" if self.model_mode == MODEL_MODE_PREFILL else "activation_length_no_exp"
    )
    activation_axes = ("activation_batch", length_axis, "activation_embed")

    inputs = constrain(inputs, activation_axes)
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    normed = constrain(self.pre_self_attention_norm(inputs), activation_axes)

    attention_out, kv_cache = self.self_attention(
        normed,
        normed,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    attention_out = constrain(attention_out, activation_axes)

    # The MLP consumes the same normalized input as attention.
    mlp_out = constrain(self.mlp(normed, deterministic=deterministic), activation_axes)

    branch_sum = self.dropout(mlp_out + attention_out, deterministic=deterministic)
    layer_output = constrain(branch_sum + inputs, activation_axes)

    if cfg.record_internal_nn_metrics:
      metrics = (
          ("activation_mean", jnp.mean(layer_output)),
          ("activation_stdev", jnp.std(layer_output)),
          ("activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output)),
      )
      for metric_name, metric_value in metrics:
        self.sow(nnx.Intermediate, metric_name, metric_value)

    return layer_output, (None if cfg.scan_layers else kv_cache)
[docs]
def deepstack_process(hidden_states, bidirectional_mask, visual_embeds):
  """Add deepstack visual features onto hidden states at visual-token positions.

  Args:
    hidden_states: [batch, seq_len, hidden_dim] decoder hidden states.
    bidirectional_mask: [batch, seq_len] boolean mask, True at visual tokens.
    visual_embeds: [batch, num_visual_tokens, hidden_dim] visual features.

  Returns:
    hidden_states where, in each row, the k-th True position of the mask has
    the k-th visual embedding added; masked-out positions get a zero
    contribution.
  """
  row_ids = jnp.arange(hidden_states.shape[0])[:, None]  # [batch, 1]
  # Running count of True entries, shifted so the first visual token maps to 0.
  embed_ids = jnp.cumsum(bidirectional_mask, axis=1) - 1  # [batch, seq_len]
  # One visual embedding per sequence position; rows gathered for non-visual
  # positions are zeroed by the mask below.
  gathered = visual_embeds[row_ids, embed_ids, :]  # [batch, seq_len, hidden]
  keep = bidirectional_mask[:, :, None]  # [batch, seq_len, 1]
  return hidden_states + gathered * keep
[docs]
class NNXDecoder(nnx.Module):
"""A stack of decoder layers as a part of an encoder-decoder architecture, using NNX."""
def __init__(
    self,
    config: Config,
    mesh: Mesh,
    quant: None | Quant = None,
    model_mode: str = MODEL_MODE_TRAIN,
    *,
    rngs: nnx.Rngs,
):
  """Builds the embedding helpers, final norm, optional logits projection and the layer stack.

  Args:
    config: Model configuration; selects the decoder block type and whether layers are scanned.
    mesh: Device mesh forwarded to sub-modules.
    quant: Optional quantization forwarded to every decoder layer.
    model_mode: Mode the layers are built for.
    rngs: NNX RNG container used to initialize all sub-modules.
  """
  self.config = config
  self.mesh = mesh
  self.quant = quant
  self.model_mode = model_mode
  self.rngs = rngs
  decoder_block_classes = self.get_decoder_layers()
  # The trainable position embedder only exists when a position table size is configured.
  if config.trainable_position_size > 0:
    self.position_embedder = Embed(
        num_embeddings=config.trainable_position_size,
        num_features=config.emb_dim,
        dtype=config.dtype,
        embedding_init=nn.initializers.normal(stddev=1.0),
        config=config,
        mesh=self.mesh,
        rngs=rngs,
    )
  self.dropout = linears.Dropout(rate=config.dropout_rate, rngs=rngs, broadcast_dims=(-2,))
  self.positional_embedding = PositionalEmbedding(embedding_dims=config.base_emb_dim)
  # get_norm_layer returns a partial; the remaining norm options are supplied here.
  self.decoder_norm = self.get_norm_layer(num_features=config.emb_dim, rngs=rngs)(
      dtype=config.dtype,
      weight_dtype=config.weight_dtype,
      epsilon=config.normalization_layer_epsilon,
      kernel_axes=("norm",),
      parameter_memory_host_offload=config.parameter_memory_host_offload,
  )
  # A dedicated logits projection is only needed when logits are not computed from the embedding table.
  if not config.logits_via_embedding:
    self.logits_dense = linears.DenseGeneral(
        in_features_shape=config.emb_dim,
        out_features_shape=config.vocab_size,
        weight_dtype=config.weight_dtype,
        dtype=jnp.float32 if config.logits_dot_in_fp32 else config.dtype,
        kernel_axes=("embed_vocab", "vocab"),
        shard_mode=config.shard_mode,
        matmul_precision=self.config.matmul_precision,
        parameter_memory_host_offload=config.parameter_memory_host_offload,
        rngs=rngs,
    )
  self.scanned_layers = None
  self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK
  self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3
  self.is_gemma4 = self.config.decoder_block == DecoderBlockType.GEMMA4
  if self.config.scan_layers:
    # Scanned construction: layers are created as stacked modules.
    if self.is_deepseek:
      assert len(decoder_block_classes) == 2
      dense_cls, moe_cls = decoder_block_classes
      if config.engram_layers:
        # Layers whose index appears in config.engram_layers are created individually;
        # the runs of layers between them become scanned stacks named by their index range.
        # 1. Create Dense Chunks (Direct setattr, NO nnx.Dict)
        current_idx = 0
        while current_idx < config.first_num_dense_layers:
          if current_idx in config.engram_layers:
            layer_name = f"dense_layers_engram_{current_idx}"
            setattr(self, layer_name, self._create_single_layer(dense_cls, rngs, layer_idx=current_idx))
            current_idx += 1
          else:
            next_boundary = self._find_next_boundary(current_idx, config.first_num_dense_layers, config.engram_layers)
            chunk_name = f"dense_layers_{current_idx}_{next_boundary - 1}"
            setattr(
                self,
                chunk_name,
                self._create_scanned_layers(
                    dense_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
                ),
            )
            current_idx = next_boundary
        # 2. Create MoE Chunks (Direct setattr, NO nnx.Dict)
        current_idx = config.first_num_dense_layers
        while current_idx < config.num_decoder_layers:
          if current_idx in config.engram_layers:
            layer_name = f"moe_layers_engram_{current_idx}"
            setattr(self, layer_name, self._create_single_layer(moe_cls, rngs, layer_idx=current_idx))
            current_idx += 1
          else:
            next_boundary = self._find_next_boundary(current_idx, config.num_decoder_layers, config.engram_layers)
            chunk_name = f"moe_layers_{current_idx}_{next_boundary - 1}"
            setattr(
                self,
                chunk_name,
                self._create_scanned_layers(
                    moe_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs
                ),
            )
            current_idx = next_boundary
      else:
        # Standard DeepSeek logic when Engrams are disabled
        num_dense = config.first_num_dense_layers
        self.dense_layers = self._create_scanned_layers(
            dense_cls, length=num_dense, metadata_axis_name="dense_layers", rngs=rngs
        )
        num_moe = config.num_decoder_layers - config.first_num_dense_layers
        self.moe_layers = self._create_scanned_layers(
            moe_cls, length=num_moe, metadata_axis_name="moe_layers", rngs=rngs
        )
    elif self.is_gemma3:
      # Scan over whole attention-pattern blocks; the leftover layers go in one unscanned remainder block.
      attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
      scan_length = config.num_decoder_layers // attention_pattern_length
      num_remaining_layers = config.num_decoder_layers % attention_pattern_length
      layer_kwargs = {"num_of_layers": attention_pattern_length}
      rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
      RemattedGemma3Block = gemma3.Gemma3ScannableBlock
      if scan_length > 0:
        self.layers = self._create_scanned_layers(
            RemattedGemma3Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs
        )
      self.layers_remainder = RemattedGemma3Block(
          config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs
      )  # pytype: disable=wrong-keyword-args
    elif self.is_gemma4:
      # Same layout as the Gemma3 branch, using the Gemma4 pattern and block.
      attention_pattern_length = len(gemma4.GEMMA4_ATTENTION_PATTERN)
      scan_length = config.num_decoder_layers // attention_pattern_length
      num_remaining_layers = config.num_decoder_layers % attention_pattern_length
      layer_kwargs = {"num_of_layers": attention_pattern_length}
      rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
      RemattedGemma4Block = gemma4.Gemma4ScannableBlock
      if scan_length > 0:
        self.layers = self._create_scanned_layers(
            RemattedGemma4Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs
        )
      self.layers_remainder = RemattedGemma4Block(
          config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs
      )
    else:
      # Generic case: a single class scanned num_decoder_layers / inhomogeneous_layer_cycle_interval times.
      layer_cls = decoder_block_classes[0]
      num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval)
      layer_kwargs = {}
      if config.decoder_block == DecoderBlockType.LLAMA4:
        layer_kwargs = {
            "nope_layer_interval": self.config.nope_layer_interval,
            "interleave_moe_layer_step": self.config.interleave_moe_layer_step,
        }
      if num_layers > 0:
        self.layers = self._create_scanned_layers(
            layer_cls, length=num_layers, metadata_axis_name="layers", rngs=rngs, **layer_kwargs
        )
      else:
        self.layers = nnx.List([])
  else:
    # Unscanned construction: every layer is built individually, registered as its own
    # attribute and appended to self.layers.
    self.layers = nnx.List([])
    if self.is_deepseek:
      dense_cls, moe_cls = decoder_block_classes
      for i in range(config.first_num_dense_layers):
        self._create_and_register_layer(dense_cls, rngs, "dense_layer", i)
      for i in range(config.num_decoder_layers - config.first_num_dense_layers):
        self._create_and_register_layer(moe_cls, rngs, "moe_layer", i)
    else:
      layer_cls = decoder_block_classes[0]
      for lyr in range(config.num_decoder_layers):
        # Some block types need per-layer constructor arguments derived from the layer index.
        layer_kwargs = {}
        if config.decoder_block == DecoderBlockType.GEMMA3:
          layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)}
        elif config.decoder_block == DecoderBlockType.GEMMA4:
          layer_kwargs = {"attention_type": gemma4.get_attention_type(layer_id=lyr)}
        elif config.decoder_block == DecoderBlockType.LLAMA4:
          layer_kwargs = {
              "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval),
              "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step),
          }
        elif config.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
          layer_kwargs = {"layer_idx": lyr}
        elif config.decoder_block == DecoderBlockType.GPT_OSS:
          layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
        elif config.decoder_block == DecoderBlockType.OLMO3:
          layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)}
        self._create_and_register_layer(layer_cls, rngs, "layers", lyr, **layer_kwargs)
def _create_and_register_layer(self, layer_cls, rngs, base_name, i, **layer_kwargs):
attr_name = f"{base_name}_{i}"
layer = self._create_single_layer(layer_cls, rngs, **layer_kwargs)
setattr(self, attr_name, layer)
self.layers.append(layer)
def _create_single_layer(self, decoder_layer_class, rngs, **kwargs):
  """Instantiates one decoder layer, wrapping non-NNX (Linen) classes with `ToNNX`.

  NNX classes receive `rngs` in their constructor; any other class is built
  without it and the RNGs are handed to the `ToNNX` wrapper instead.
  """
  shared_args = {
      "config": self.config,
      "mesh": self.mesh,
      "quant": self.quant,
      "model_mode": self.model_mode,
  }
  if not issubclass(decoder_layer_class, nnx.Module):
    linen_layer = decoder_layer_class(**shared_args, **kwargs)
    return nnx_wrappers.ToNNX(linen_layer, rngs=rngs)
  return decoder_layer_class(**shared_args, rngs=rngs, **kwargs)
def _create_scanned_layers(
    self, decoder_layer_class, length: int, metadata_axis_name: str, rngs: nnx.Rngs, **layer_kwargs
):
  """Creates a scanned stack of layers using jax.lax.scan for memory-efficient initialization.

  Args:
    decoder_layer_class: Layer class to instantiate once per scan step.
    length: Number of layers in the stack; 0 short-circuits to None.
    metadata_axis_name: Name recorded on every variable for the stacked axis.
    rngs: RNG container that is forked into `length` per-layer streams.
    **layer_kwargs: Extra constructor arguments forwarded to every layer.

  Returns:
    A single module whose variables carry an extra stacked axis, or None when
    `length` is 0.
  """
  if length == 0:
    return None
  scan_axis = self.config.param_scan_axis
  # Fork rngs to get per-layer RNG states for scanning. A failure here is
  # allowed to propagate: swallowing it would only resurface as a NameError on
  # `forked_rngs` below and hide the real cause.
  forked_rngs = rngs.fork(split=length)
  rngs_graphdef, rngs_state = nnx.split(forked_rngs)

  # Build one reference layer from the first RNG slice purely to obtain the
  # graph definition used to re-assemble the stacked module at the end.
  first_rng_state = jax.tree.map(lambda x: x[0], rngs_state)
  ref_rngs = nnx.merge(rngs_graphdef, first_rng_state)
  ref_layer = decoder_layer_class(
      config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=ref_rngs, **layer_kwargs
  )
  layer_graphdef, _, _ = nnx.split(ref_layer, nnx.Param, ...)
  del ref_layer

  def scan_body(carry, rng_state_slice):
    # Each scan step initializes one layer from its own RNG slice and emits
    # its variables; lax.scan stacks them along a new leading axis.
    layer_rngs = nnx.merge(rngs_graphdef, rng_state_slice)
    layer = decoder_layer_class(
        config=self.config,
        mesh=self.mesh,
        quant=self.quant,
        model_mode=self.model_mode,
        rngs=layer_rngs,
        **layer_kwargs,
    )
    _, params, rest = nnx.split(layer, nnx.Param, ...)
    return carry, (params, rest)

  _, (stacked_params, stacked_rest) = jax.lax.scan(scan_body, None, rngs_state)

  # Only parameters are moved to the configured scan axis; the remaining
  # state keeps the stacked axis at position 0.
  if scan_axis != 0:
    stacked_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), stacked_params)

  def _add_scan_metadata(state, axis):
    def _update_leaf(leaf):
      if hasattr(leaf, "replace") and hasattr(leaf, "value"):
        replace_kwargs = {}
        if hasattr(leaf, "get_metadata"):
          replace_kwargs.update(leaf.get_metadata())
        replace_kwargs[nnx.PARTITION_NAME] = metadata_axis_name
        replace_kwargs["param_scan_axis"] = axis
        for key in ["sharding", "out_sharding", "kernel_axes", "sharding_names"]:
          val = getattr(leaf, key, None)
          if val is None and key in replace_kwargs:
            val = replace_kwargs[key]
          if val is not None:
            if isinstance(val, str):
              val = (val,)
            if isinstance(val, tuple):
              axis_names = list(val)
              # Safely insert the scan axis into the logical axes string
              if metadata_axis_name not in axis_names:
                insert_idx = min(axis, len(axis_names))
                axis_names.insert(insert_idx, metadata_axis_name)
              replace_kwargs[key] = tuple(axis_names)
        return leaf.replace(**replace_kwargs)
      return leaf

    # We must use a custom is_leaf to catch the VariableState instances
    return jax.tree.map(_update_leaf, state, is_leaf=lambda x: hasattr(x, "replace") and hasattr(x, "value"))

  stacked_params = _add_scan_metadata(stacked_params, scan_axis)
  stacked_rest = _add_scan_metadata(stacked_rest, 0)
  return nnx.merge(layer_graphdef, stacked_params, stacked_rest)
def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs):
  """Runs one unscanned layer (or block) under `jax.checkpoint`.

  The layer is split into graph definition and state so that a pure function
  can be checkpointed; the state produced inside that function is written
  back onto `layer` afterwards. Returns whatever the layer call returns.
  """
  layer_def, layer_state = nnx.split(layer)

  def pure_layer_fn(state_in, y_in):
    rebuilt = nnx.merge(layer_def, state_in)
    result = rebuilt(y_in, **kwargs)
    return result, nnx.state(rebuilt)

  remat_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
  result, updated_state = remat_fn(layer_state, y)
  nnx.update(layer, updated_state)
  return result
def _apply_layers_sequentially(self, layers, x_in, *args, length: int, kv_caches_stacked=None, **kwargs):
  """Runs a stacked layer module over its layers, one after another.

  Without KV caches the stack is driven by `jax.lax.scan`; with KV caches the
  loop is unrolled in Python (see the comment in the body for why).

  Args:
    layers: The stacked NNX module whose params are scanned over.
    x_in: The carry (hidden state) fed into the first layer.
    *args: Positional args broadcast to every layer call.
    length: Number of iterations (= number of layers).
    kv_caches_stacked: Optional indexable, mutable sequence of per-layer KV
      caches. When provided, entry i is passed as `kv_cache=` to layer i and
      is overwritten in place with the cache that layer returns.
    **kwargs: Keyword args forwarded to the layer. A kwarg is kept when its
      name is a parameter of the layer's `__call__`, or when that signature
      has a parameter named `kwargs`.

  Returns:
    Always a 3-tuple `(final_carry, layers, kv)`. `kv` is `kv_caches_stacked`
    exactly as passed when `length == 0`, and None otherwise (updated caches
    are delivered through the in-place mutation of `kv_caches_stacked`).
    On the KV-cache path `layers` is returned without a state update.
  """
  if length == 0:
    return x_in, layers, kv_caches_stacked
  policy = self.get_remat_policy()
  prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
  graphdef, params, state = nnx.split(layers, nnx.Param, ...)
  scan_axis = self.config.param_scan_axis
  # Iteration always happens over axis 0, so bring the stacked axis to the front.
  if scan_axis != 0:
    params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params)
  layer_cls = layers.__class__
  sig = inspect.signature(layer_cls.__call__)
  valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
  use_kv = kv_caches_stacked is not None

  def layer_fn(carry, scanned_vars):
    # Unpack the sliced variables for THIS layer
    if use_kv:
      current_params, current_state, kv_cache_layer = scanned_vars
    else:
      current_params, current_state = scanned_vars
      kv_cache_layer = None
    if self.config.parameter_memory_host_offload:
      current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)
    layer = nnx.merge(graphdef, current_params, current_state)
    # Build call kwargs, injecting per-layer kv_cache when available
    call_kwargs = dict(valid_kwargs)
    if kv_cache_layer is not None:
      call_kwargs["kv_cache"] = kv_cache_layer
    layer_out = layer(carry, *args, **call_kwargs)
    # Layers may return either the hidden state alone or a tuple whose second
    # element is the updated KV cache.
    if isinstance(layer_out, tuple):
      new_carry = layer_out[0]
      updated_kv = layer_out[1] if len(layer_out) > 1 else None
    else:
      new_carry = layer_out
      updated_kv = None
    # Extract the updated state to return it
    new_current_state = nnx.state(layer)
    if use_kv:
      return new_carry, (new_current_state, updated_kv)
    return new_carry, new_current_state

  layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)

  if use_kv:
    # If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan
    # because scanning requires stacking the kv_caches list, which creates a copy
    # and breaks the in-place memory updates required by vLLM's PagedAttention.
    # Therefore, we must unroll the loop statically when kv_caches is provided.
    # kv_caches_stacked is actually the original kv_caches list in this new flow
    kv_caches_list = kv_caches_stacked
    current_carry = x_in
    for i in range(length):
      # Statically slice the parameters and state for this layer
      current_params = jax.tree.map(lambda x, i=i: x[i], params)
      current_state = jax.tree.map(lambda x, i=i: x[i], state)
      # Call the layer
      current_carry, (_, updated_kv) = layer_fn(current_carry, (current_params, current_state, kv_caches_list[i]))
      # Update the list in-place (mutates the list passed by reference)
      kv_caches_list[i] = updated_kv
    # We don't need to rebuild scanned_state or return it because during
    # inference with vLLM, parameters do not change and we don't need intermediates.
    return current_carry, layers, None

  final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
  # Restore the configured parameter layout before writing state back.
  if scan_axis != 0:
    new_params, new_rest = scanned_state.split(nnx.Param, ...)
    new_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), new_params)
    scanned_state = nnx.merge_state(new_params, new_rest)
  nnx.update(layers, scanned_state)
  return final_carry, layers, None
[docs]
def get_decoder_layers(self):
  """Returns the list of layer classes for `config.decoder_block`.

  Most block types map to a one-element list. Block types that have a
  dedicated scannable variant yield that variant when `config.scan_layers`
  is set. DeepSeek yields two classes: dense first, then MoE.

  Raises:
    ValueError: if the configured block type has no entry in the table.
  """
  cfg = self.config

  def pick(unscanned_cls, scanned_cls):
    return [scanned_cls] if cfg.scan_layers else [unscanned_cls]

  layer_map = {
      DecoderBlockType.DEFAULT: [NNXDecoderLayer],
      DecoderBlockType.LLAMA2: [llama2.LlamaDecoderLayer],
      DecoderBlockType.MISTRAL: [mistral.MistralDecoderLayer],
      DecoderBlockType.MIXTRAL: [mixtral.MixtralDecoderLayer],
      DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer],
      DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer],
      DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer],
      DecoderBlockType.GEMMA4: pick(gemma4.Gemma4DecoderLayer, gemma4.Gemma4ScannableBlock),
      DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer],
      DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer],
      DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer],
      DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer],
      DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer],
      DecoderBlockType.DEEPSEEK: [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer],
      DecoderBlockType.GPT_OSS: pick(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock),
      DecoderBlockType.QWEN3_NEXT: pick(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock),
      DecoderBlockType.QWEN3_5: pick(qwen3_5.Qwen3_5DecoderLayer, qwen3_5.Qwen3_5ScannableBlock),
      DecoderBlockType.LLAMA4: pick(llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock),
      DecoderBlockType.OLMO3: pick(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock),
  }

  layer_classes = layer_map.get(cfg.decoder_block)
  if layer_classes is None:
    raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}")
  return layer_classes
[docs]
def minimal_policy(self, with_context=False, with_quantization=False):
  """Builds a save-only-these-names checkpoint policy for the core projections.

  The attention and MLP projection names are always saved; "context" and
  "quantization" are added when the corresponding flag is truthy.
  """
  saved_names = [
      "query_proj",
      "value_proj",
      "key_proj",
      "qkv_proj",
      "out_proj",
      "mlpwi_0",
      "mlpwi_1",
      "mlpwi",
      "mlpwo",
  ]
  optional_names = (("context", with_context), ("quantization", with_quantization))
  saved_names.extend(tag for tag, enabled in optional_names if enabled)
  return jax.checkpoint_policies.save_only_these_names(*saved_names)
[docs]
def get_remat_policy(self):
  """Get remat policy for jax.checkpoint.

  Returns:
    The checkpoint policy selected by `config.remat_policy`, or None when the
    setting is "none" or "full".

  Raises:
    AssertionError: if `config.remat_policy` is not one of the known names.
  """
  cfg = self.config
  remat_policy = cfg.remat_policy
  if remat_policy == "none":
    return None

  def offload_policy(saved, offloaded):
    # All offloading policies move tensors from device memory to pinned host memory.
    return jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=saved,
        names_which_can_be_offloaded=offloaded,
        offload_src="device",
        offload_dst="pinned_host",
    )

  def warn_if_scanning():
    # Shared by both quantization-checkpointing policies.
    if cfg.scan_layers:
      warnings.warn(
          "Scan layers can introduce overhead to checkpointed values that in some configurations is slower "
          "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
          "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
          "beneficial for performance."
      )

  policy = None
  if remat_policy in ("minimal_with_context", "minimal_flash"):
    if remat_policy == "minimal_flash":
      max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.")
    policy = self.minimal_policy(with_context=True)
  elif remat_policy == "minimal":
    policy = self.minimal_policy()
  elif remat_policy == "minimal_with_quantization":
    warn_if_scanning()
    policy = self.minimal_policy(with_context=False, with_quantization=True)
  elif remat_policy == "minimal_with_context_and_quantization":
    warn_if_scanning()
    policy = self.minimal_policy(with_context=True, with_quantization=True)
  elif remat_policy == "save_dot_with_context_except_mlp":
    policy = jax.checkpoint_policies.save_only_these_names(
        "query_proj",
        "value_proj",
        "key_proj",
        "qkv_proj",
        "context",
        "out_proj",
    )
  elif remat_policy == "save_dot_except_mlpwi":
    policy = jax.checkpoint_policies.save_only_these_names(
        "query_proj",
        "value_proj",
        "key_proj",
        "qkv_proj",
        "out_proj",
        "mlpwo",
    )
  elif remat_policy == "save_dot_except_mlp":
    policy = jax.checkpoint_policies.save_only_these_names(
        "query_proj",
        "value_proj",
        "key_proj",
        "qkv_proj",
        "out_proj",
    )
  elif remat_policy == "save_qkv_proj":
    policy = jax.checkpoint_policies.save_only_these_names(
        "query_proj",
        "value_proj",
        "key_proj",
        "qkv_proj",
    )
  elif remat_policy == "qkv_proj_offloaded":
    policy = offload_policy([], ["query_proj", "value_proj", "key_proj"])
  elif remat_policy == "minimal_offloaded":
    policy = offload_policy(
        [],
        [
            "query_proj",
            "value_proj",
            "key_proj",
            "qkv_proj",
            "out_proj",
            "mlpwi_0",
            "mlpwi_1",
            "mlpwi",
            "mlpwo",
        ],
    )
  elif remat_policy == "custom":
    policy = offload_policy(cfg.tensors_on_device, cfg.tensors_to_offload)
  elif remat_policy == "save_out_proj":
    policy = jax.checkpoint_policies.save_only_these_names("out_proj")
  else:
    # "full" means no policy is passed to jax.checkpoint.
    assert remat_policy == "full", "Remat policy needs to be on list of remat policies"
    policy = None
  return policy
[docs]
def get_norm_layer(self, num_features: int, rngs: nnx.Rngs):
  """Returns a partially-applied normalization constructor for the configured block type.

  The result is a `functools.partial` with `num_features` and `rngs` already
  bound; the caller supplies the remaining constructor arguments.

  Raises:
    ValueError: if the configured decoder block type is not handled here.
  """
  block_type = self.config.decoder_block

  # GPT3 is the only family that uses its own layer norm.
  if block_type == DecoderBlockType.GPT3:
    return functools.partial(
        gpt3.Gpt3LayerNorm, num_features=num_features, reductions_in_fp32=False, use_bias=True, rngs=rngs
    )

  rms_block_types = (
      DecoderBlockType.DEFAULT,
      DecoderBlockType.LLAMA2,
      DecoderBlockType.MISTRAL,
      DecoderBlockType.MIXTRAL,
      DecoderBlockType.DEEPSEEK,
      DecoderBlockType.GEMMA,
      DecoderBlockType.GEMMA2,
      DecoderBlockType.GEMMA3,
      DecoderBlockType.GEMMA4,
      DecoderBlockType.QWEN3,
      DecoderBlockType.QWEN3_MOE,
      DecoderBlockType.GPT_OSS,
      DecoderBlockType.SIMPLE,
      DecoderBlockType.SIMPLE_MLP,
      DecoderBlockType.LLAMA4,
      DecoderBlockType.OLMO3,
  )
  if block_type in rms_block_types:
    norm_cls = RMSNorm
  elif block_type in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
    norm_cls = normalizations.RMSNorm
  else:
    raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

  return functools.partial(norm_cls, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs)
def _apply_embedding(
    self,
    shared_embedding: nnx.Module,
    decoder_input_tokens,
    decoder_positions,
    deterministic,
    model_mode,
    multimodal_input=None,
):
  """Embeds the input tokens, merges multimodal embeddings and adds positional information.

  Steps, in order: token embedding, optional merge of image and audio
  embeddings, dropout, cast to `config.dtype`, optional untrainable
  positional embedding, optional trainable position embedding.

  Raises:
    ValueError: if image or audio embeddings are supplied (and enabled in the
      config) for a model name that is not on the supported list.
  """
  cfg = self.config
  image_capable_models = (
      "gemma3-4b",
      "gemma3-12b",
      "gemma3-27b",
      "gemma4-26b",
      "gemma4-31b",
      "llama4-17b-16e",
      "llama4-17b-128e",
      "qwen3-omni-30b-a3b",
  )
  audio_capable_models = ("qwen3-omni-30b-a3b",)

  y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode)

  # Merge the image / audio embeddings with the text embeddings for multimodal models
  if multimodal_input is not None:
    image_embeddings = multimodal_input.image_embeddings
    bidirectional_mask = multimodal_input.bidirectional_mask
    image_masks = multimodal_input.image_masks
    audio_embeddings = multimodal_input.audio_embeddings
    audio_masks = multimodal_input.audio_masks

    if image_embeddings is not None and cfg.use_multimodal:
      if cfg.model_name not in image_capable_models:
        raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}")
      y = mm_utils.merge_mm_embeddings(
          text_embeddings=y,
          multimodal_embeddings=image_embeddings,
          mask=bidirectional_mask,
          token_masks=image_masks,
      )

    if audio_embeddings is not None and cfg.use_audio:
      if cfg.model_name not in audio_capable_models:
        raise ValueError(f"Unsupported model_name for audio: {cfg.model_name}")
      y = mm_utils.merge_mm_embeddings(
          text_embeddings=y,
          multimodal_embeddings=audio_embeddings,
          mask=audio_masks,
          token_masks=None,
      )

  y = self.dropout(y, deterministic=deterministic).astype(cfg.dtype)

  if cfg.use_untrainable_positional_embedding:
    y += self.positional_embedding(y, decoder_positions)

  # `position_embedder` is only created when trainable_position_size > 0, so
  # the size check must come first.
  if cfg.trainable_position_size > 0 and self.position_embedder:
    y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode)
  return y
[docs]
def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
  """Applies the final norm and dropout, then projects hidden states to vocabulary logits.

  When `config.logits_via_embedding` is set the logits come from the shared
  embedding table (with optional normalization and soft-capping); otherwise
  they come from the dedicated `logits_dense` projection.
  """
  cfg = self.config

  # An explicit output sharding for the norm is only built in explicit shard mode.
  norm_out_sharding = None
  if cfg.shard_mode == ShardMode.EXPLICIT:
    norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed"))

  y = self.decoder_norm(y, out_sharding=norm_out_sharding)
  y = self.dropout(y, deterministic=deterministic)

  if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
    logits_axes = (None, None, "activation_vocab")
  else:
    logits_axes = ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")
  out_sharding = create_sharding(self.mesh, logits_axes)

  # [batch, length, emb_dim] -> [batch, length, vocab_size]
  if not cfg.logits_via_embedding:
    logits = self.logits_dense(y, out_sharding=out_sharding)
  else:
    # Use the transpose of embedding matrix for logit transform.
    if isinstance(shared_embedding, nnx.Module):
      embedding_table = shared_embedding.embedding.value
    else:
      embedding_table = shared_embedding.variables["params"]["embedding"]
    if isinstance(embedding_table, nn.spmd.LogicallyPartitioned):
      embedding_table = embedding_table.unbox()

    attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype
    logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding)

    if self.config.normalize_embedding_logits:
      # Correctly normalize pre-softmax logits for this shared case.
      logits = logits / jnp.sqrt(y.shape[-1])

    soft_cap = cfg.final_logits_soft_cap
    if soft_cap:
      logits = jnp.tanh(logits / soft_cap) * soft_cap

  if self.config.cast_logits_to_fp32:
    logits = logits.astype(jnp.float32)
  return logits
def _build_linen_params(self, moe_stack: nnx.Module) -> dict:
  """Repackages the `nnx.Param` state of `moe_stack` under Linen-style names.

  The returned dictionary mimics the variable structure expected by
  `deepseek_batchsplit.fetch_weights`. The MoE block is read from the "moe_block"
  entry, falling back to "DeepSeekMoeBlock_0", and is always exposed under
  "DeepSeekMoeBlock_0".
  """
  param_state = nnx.state(moe_stack, nnx.Param)

  # These entries keep the same name on both sides of the bridge.
  shared_names = (
      "pre_self_attention_layer_norm",
      "post_self_attention_layer_norm",
      "self_attention",
  )
  linen_params = {name: param_state[name] for name in shared_names}
  linen_params["DeepSeekMoeBlock_0"] = param_state.get("moe_block", param_state.get("DeepSeekMoeBlock_0"))
  return linen_params
def _find_next_boundary(self, current_idx, end_idx, engram_indices):
  """Returns where the current run of regular layers stops.

  That is the smallest Engram index strictly greater than `current_idx`, capped at
  `end_idx`; `end_idx` itself when no Engram index lies in between.
  """
  boundary = end_idx
  for engram_idx in engram_indices:
    # Only indices strictly inside (current_idx, boundary) can move the boundary closer.
    if current_idx < engram_idx < boundary:
      boundary = engram_idx
  return boundary
def _apply_single_engram_layer(self, y, layer_name, *args, **kwargs):
  """Runs one standalone (non-scanned) Engram layer and returns its hidden-state output.

  The layer is looked up on `self` by `layer_name`. Only two entries of `kwargs` are
  consumed: "decoder_input_tokens" (forwarded as a keyword, `None` when absent) and
  "layer_kwargs" (a dict splatted into the call, empty when absent). When the layer
  returns a tuple, only its first element is kept.
  """
  engram_layer = getattr(self, layer_name)
  result = engram_layer(
      y,
      *args,
      decoder_input_tokens=kwargs.get("decoder_input_tokens"),
      **kwargs.get("layer_kwargs", {}),
  )
  return result[0] if isinstance(result, tuple) else result
def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args, **kwargs):
  """Runs layers [current_idx, next_boundary) of a stacked `layer_stack` and writes state back.

  The stack's state is sliced along the layer axis (`config.param_scan_axis` for
  `nnx.Param` leaves, axis 0 for everything else), merged into a temporary stack, run
  through `_apply_layers_sequentially`, and the resulting state is written back into
  the same slice of `layer_stack`. Only `kwargs["layer_kwargs"]` is forwarded to the
  layers. An empty or negative range leaves `y` and `layer_stack` untouched.
  """
  num_layers = next_boundary - current_idx
  if num_layers <= 0:
    return y

  graphdef, full_state = nnx.split(layer_stack)
  stack_params, stack_rest = full_state.split(nnx.Param, ...)
  param_axis = self.config.param_scan_axis

  def take_slice(axis):
    # Leaf-wise extraction of this chunk's layers along `axis`.
    return lambda leaf: jax.lax.dynamic_slice_in_dim(leaf, current_idx, num_layers, axis=axis)

  def put_slice(axis):
    # Leaf-wise write-back of this chunk's layers into the full stack along `axis`.
    return lambda full, piece: jax.lax.dynamic_update_slice_in_dim(full, piece, current_idx, axis=axis)

  # Build a temporary stack holding only this chunk's layers.
  sub_stack = nnx.merge(
      graphdef,
      jax.tree.map(take_slice(param_axis), stack_params),
      jax.tree.map(take_slice(0), stack_rest),
  )

  y, sub_stack, _ = self._apply_layers_sequentially(
      sub_stack, y, *args, length=num_layers, **kwargs.get("layer_kwargs", {})
  )

  # Fold the chunk's post-call state back into the original stack.
  result_params, result_rest = nnx.state(sub_stack).split(nnx.Param, ...)
  nnx.update(
      layer_stack,
      jax.tree.map(put_slice(param_axis), stack_params, result_params),
      jax.tree.map(put_slice(0), stack_rest, result_rest),
  )
  return y
def _apply_interleaved_scanned_layers(self, y, layer_prefix, start_idx, end_idx, engram_indices, *args, **kwargs):
  """Applies layers [start_idx, end_idx), mixing scanned chunks with unscanned Engram layers.

  Layer indices listed in `engram_indices` are run one at a time through the attribute
  `"{layer_prefix}_engram_{idx}"`. Every maximal run of other indices is run as one
  stacked chunk stored under `"{layer_prefix}_{first}_{last}"`.

  Args:
    y: Hidden states entering layer `start_idx`.
    layer_prefix: Attribute-name prefix of the layer group (e.g. "dense_layers").
    start_idx: First layer index to apply (inclusive).
    end_idx: Layer index to stop at (exclusive).
    engram_indices: Container of layer indices that are Engram layers.
    *args: Positional arguments forwarded to every layer call.
    **kwargs: Forwarded whole to `_apply_single_engram_layer`; only
      `kwargs["layer_kwargs"]` (default `{}`) is forwarded to the scanned chunks.

  Returns:
    Hidden states after layer `end_idx - 1`; `y` unchanged when the range is empty.
  """
  layer_kwargs = kwargs.get("layer_kwargs", {})
  current_idx = start_idx
  while current_idx < end_idx:
    if current_idx in engram_indices:
      y = self._apply_single_engram_layer(y, f"{layer_prefix}_engram_{current_idx}", *args, **kwargs)
      current_idx += 1
      continue

    next_boundary = self._find_next_boundary(current_idx, end_idx, engram_indices)
    chunk_name = f"{layer_prefix}_{current_idx}_{next_boundary - 1}"
    y, updated_stack, _ = self._apply_layers_sequentially(
        getattr(self, chunk_name), y, *args, length=next_boundary - current_idx, **layer_kwargs
    )
    # Rebind the returned stack, as the other `_apply_layers_sequentially` call sites do
    # (e.g. `y, self.layers, _ = ...`); otherwise any state carried by the returned stack
    # would be dropped. This is a no-op if the same object is returned.
    setattr(self, chunk_name, updated_stack)
    current_idx = next_boundary
  return y
def __call__(
    self,
    shared_embedding: Any,
    decoder_input_tokens,
    decoder_positions,
    decoder_segment_ids=None,
    deterministic=False,
    model_mode=MODEL_MODE_TRAIN,
    previous_chunk=None,
    slot: None | int = None,
    page_state: None | page_manager.PageState = None,
    multimodal_input: None | Any = None,
    kv_caches: list[jax.Array] | None = None,
    attention_metadata=None,
    deepstack_visual_embeds: None | list[jnp.ndarray] = None,
):
  """Runs the decoder: embedding, the layer stack, then (optionally) the output head.

  Args:
    shared_embedding: Embedding module, forwarded to `_apply_embedding` and
      `apply_output_head`.
    decoder_input_tokens: Token ids; asserted to be rank 2, [batch, len].
    decoder_positions: Forwarded to the embedding step and to every layer.
    decoder_segment_ids: Forwarded to every layer.
    deterministic: Forwarded to the embedding step, every layer and the output head.
    model_mode: Forwarded to the embedding step, every layer and the output head.
    previous_chunk: Forwarded to layers only on the scanned DeepSeek and scanned
      Gemma3/Gemma4 paths.
    slot: Same forwarding as `previous_chunk`.
    page_state: Same forwarding as `previous_chunk`.
    multimodal_input: Forwarded to `_apply_embedding`; when not None its
      `bidirectional_mask` attribute is read.
    kv_caches: Optional per-layer caches; updated in place and returned.
      NOTE(review): annotated as a list, but the QWEN3_NEXT branch indexes it with the
      string keys "key_cache"/"value_cache" - confirm the intended type.
    attention_metadata: When not None, added to the per-layer keyword arguments.
    deepstack_visual_embeds: Optional per-layer visual embeddings; only consumed on the
      non-scanned path, after the layer with the matching index.

  Returns:
    A `(logits, hidden_state, kv_caches)` tuple. `logits` is None when the output head is
    skipped (see the branches at the end of this method). `hidden_state` is the output
    of the last layer, reduced by `mhc_reduce` when `cfg.mhc_expansion_rate > 1`.
  """
  cfg = self.config
  assert decoder_input_tokens.ndim == 2  # [batch, len]
  policy = self.get_remat_policy()

  # [batch, length] -> [batch, length, emb_dim]
  y = self._apply_embedding(
      shared_embedding,
      decoder_input_tokens,
      decoder_positions,
      deterministic,
      model_mode,
      multimodal_input=multimodal_input,
  )

  mhc_expand, mhc_reduce = mhc.get_functions(cfg.mhc_expansion_rate)
  if cfg.mhc_expansion_rate > 1:
    # (batch, length, emb_dim) --> (batch, length, mhc_expansion_rate, emb_dim)
    y = mhc_expand(y)

  # Positional / keyword arguments shared by every layer call below.
  layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
  layer_kwargs = {}

  # Extract the bidirectional mask locally for layer configurations
  bidirectional_mask = multimodal_input.bidirectional_mask if multimodal_input is not None else None
  if cfg.decoder_block in (DecoderBlockType.GEMMA3, DecoderBlockType.GEMMA4):
    layer_kwargs["bidirectional_mask"] = bidirectional_mask

  if attention_metadata is not None:
    layer_kwargs["attention_metadata"] = attention_metadata

  if cfg.scan_layers:
    if self.is_deepseek:
      # NOTE(review): this rebinds `layer_kwargs`, so anything added above (e.g.
      # `attention_metadata`) is not forwarded on the DeepSeek path - confirm intended.
      layer_kwargs = {
          "previous_chunk": previous_chunk,
          "page_state": page_state,
          "slot": slot,
      }
      if cfg.engram_layers:
        # Engram layers are stored unscanned, so each layer group is applied as scanned
        # chunks interleaved with single Engram layers.
        common_kwargs = {
            "layer_kwargs": layer_kwargs,
            "decoder_input_tokens": decoder_input_tokens,
        }
        y = self._apply_interleaved_scanned_layers(
            y, "dense_layers", 0, cfg.first_num_dense_layers, cfg.engram_layers, *layer_args, **common_kwargs
        )
        y = self._apply_interleaved_scanned_layers(
            y,
            "moe_layers",
            cfg.first_num_dense_layers,
            cfg.num_decoder_layers,
            cfg.engram_layers,
            *layer_args,
            **common_kwargs,
        )
      else:
        # Dense layers first, then the remaining layers as the MoE stack.
        y, self.dense_layers, _ = self._apply_layers_sequentially(
            self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs
        )
        num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers
        if cfg.use_batch_split_schedule:
          # Same call as at the top of this method; only the fp8 path below consumes it.
          policy = self.get_remat_policy()
          # The batch-split implementations take a Linen-style parameter dict, not the NNX stack.
          mock_params = self._build_linen_params(self.moe_layers)
          if cfg.use_qwix_quantization:
            y = deepseek_batchsplit_fp8.scan_batch_split_layers(
                y,
                mock_params,
                decoder_positions,
                decoder_segment_ids,
                model_mode=model_mode,
                mesh=self.mesh,
                quant=self.quant,
                cfg=cfg,
                policy=policy,
            )
          else:
            # bf16 code path
            y = deepseek_batchsplit.scan_batch_split_layers(
                y,
                mock_params,
                decoder_positions,
                mesh=self.mesh,
                cfg=cfg,
                num_layers=num_moe,
            )
        else:
          y, self.moe_layers, _ = self._apply_layers_sequentially(
              self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs
          )
    elif self.is_gemma3:
      y = self._apply_gemma3_scanned_blocks(
          y,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          bidirectional_mask,
          previous_chunk,
          page_state,
          slot,
      )
    elif self.is_gemma4:
      y = self._apply_gemma4_scanned_blocks(
          y,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          bidirectional_mask,
          previous_chunk,
          page_state,
          slot,
      )
    else:
      # Generic scanned stack: one scan step per `inhomogeneous_layer_cycle_interval` layers.
      scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
      if kv_caches is not None:
        # Pass the kv_caches list directly to avoid copying in jnp.stack,
        # which breaks vLLM PagedAttention in-place memory updates.
        # The _apply_layers_sequentially function will handle it by statically unrolling.
        y, self.layers, _ = self._apply_layers_sequentially(
            self.layers, y, *layer_args, length=scan_length, kv_caches_stacked=kv_caches, **layer_kwargs
        )
        # kv_caches list is updated in-place inside _apply_layers_sequentially
      else:
        y, self.layers, _ = self._apply_layers_sequentially(
            self.layers, y, *layer_args, length=scan_length, **layer_kwargs
        )
  else:
    # Non-scanned path: iterate over the layers in Python, rematerializing each one.
    prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg)

    # Hoisted function to preserve XLA cache ID
    def pure_layer_fn(graphdef, state_in, y_in, kv_in):
      # Functional wrapper around one layer: merge state, call, and return the new state.
      if cfg.parameter_memory_host_offload:
        state_in = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), state_in)
      merged_layer = nnx.merge(graphdef, state_in)
      # `layer_args` / `layer_kwargs` are closed over; `layer_kwargs` may gain
      # "decoder_input_tokens" inside the loop below before the first call.
      out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs)
      return out_y, out_kv, nnx.state(merged_layer)

    checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)

    for lyr, layer in enumerate(self.layers):
      graphdef, state = nnx.split(layer)

      # Select this layer's cache, if any.
      if kv_caches is not None:
        if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
          # Only the last layer of each cycle receives a (key, value) cache pair.
          if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
            kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])
          else:
            kv_cache = None
        else:
          kv_cache = kv_caches[lyr]
      else:
        kv_cache = None

      # Token ids are only passed to the layers when Engram layers are configured.
      input_tokens = decoder_input_tokens if cfg.engram_layers else None
      if input_tokens is not None:
        layer_kwargs["decoder_input_tokens"] = input_tokens

      y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache)
      # Write the layer's post-call state back into the live module.
      nnx.update(layer, new_state)

      # Store the updated cache back into the caller's container (in place).
      if kv_caches is not None and kv_cache is not None:
        if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
          if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
            kv_caches["key_cache"][lyr] = kv_cache[0]
            kv_caches["value_cache"][lyr] = kv_cache[1]
        else:
          kv_caches[lyr] = kv_cache

      # Apply the per-layer visual embeddings, when provided for this layer index.
      if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
        visual_embeds = deepstack_visual_embeds[lyr]
        if bidirectional_mask is not None and visual_embeds is not None:
          y = deepstack_process(y, bidirectional_mask, visual_embeds)

  assert isinstance(y, jax.Array)

  # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
  if cfg.mhc_expansion_rate > 1:
    # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim)
    hidden_state = mhc_reduce(y)
  else:
    hidden_state = y

  # NOTE(review): the two training checks below read `self.model_mode`, not the
  # `model_mode` argument of this call - confirm that is intended.
  # When invoking from vLLM with RPA attention, logit computation is deferred to a later stage.
  if cfg.attention == "vllm_rpa":
    logits = None
  # When in the Indexer Dense Warm-up stage, skip the expensive output head projection
  # for efficiency, as the main model is frozen and the LM loss is not needed.
  elif (cfg.use_indexer and not cfg.indexer_sparse_training) and self.model_mode == MODEL_MODE_TRAIN:
    logits = None
  # When vocab tiling is enabled in training mode, full logits are not generated, to reduce
  # memory. Instead the hidden states, which are smaller than the full logits, are recorded.
  elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
    logits = None
    self.sow(nnx.Intermediate, "hidden_states", hidden_state)
  else:
    logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)

  return logits, hidden_state, kv_caches
def _apply_gemma3_scanned_blocks(
    self,
    y,
    decoder_segment_ids,
    decoder_positions,
    deterministic,
    model_mode,
    bidirectional_mask,
    previous_chunk,
    page_state,
    slot,
):
  """Runs the scanned Gemma3 stack, then the leftover partial block when there is one.

  `cfg.num_decoder_layers` is split into full repetitions of
  `gemma3.GEMMA3_ATTENTION_PATTERN`, applied through `self.layers`, and a remainder,
  applied through `self.layers_remainder` under `jax.checkpoint`. `previous_chunk`,
  `page_state` and `slot` are passed to the remainder block only.
  """
  cfg = self.config
  positional = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
  mask_kwargs = {"bidirectional_mask": bidirectional_mask}

  pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
  num_full_blocks, num_leftover_layers = divmod(cfg.num_decoder_layers, pattern_length)

  # Full repetitions of the attention pattern go through the stacked layers.
  if num_full_blocks > 0:
    y, self.layers, _ = self._apply_layers_sequentially(
        self.layers, y, *positional, length=num_full_blocks, **mask_kwargs
    )

  if num_leftover_layers <= 0:
    return y

  # Layers that do not fill a whole pattern live in a separate, rematerialized block.
  def pure_gemma_fn(graphdef, state_in, y_in):
    merged_layer = nnx.merge(graphdef, state_in)
    out_y, _ = merged_layer(
        y_in, *positional, previous_chunk=previous_chunk, page_state=page_state, slot=slot, **mask_kwargs
    )
    return out_y, nnx.state(merged_layer)

  remat_policy = self.get_remat_policy()
  checkpointed_gemma_fn = jax.checkpoint(
      pure_gemma_fn,
      policy=remat_policy,
      prevent_cse=maxtext_utils.should_prevent_cse_in_remat(cfg),
  )
  remainder_graphdef, remainder_state = nnx.split(self.layers_remainder)
  y, updated_state = checkpointed_gemma_fn(remainder_graphdef, remainder_state, y)
  nnx.update(self.layers_remainder, updated_state)
  return y
def _apply_gemma4_scanned_blocks(
    self,
    y,
    decoder_segment_ids,
    decoder_positions,
    deterministic,
    model_mode,
    bidirectional_mask,
    previous_chunk,
    page_state,
    slot,
):
  """Runs the scanned Gemma4 stack, then the leftover partial block when there is one.

  `cfg.num_decoder_layers` is split into full repetitions of
  `gemma4.GEMMA4_ATTENTION_PATTERN`, applied through `self.layers`, and a remainder,
  applied through `self.layers_remainder` under `jax.checkpoint`. `previous_chunk`,
  `page_state` and `slot` are passed to the remainder block only.
  """
  cfg = self.config
  positional = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
  mask_kwargs = {"bidirectional_mask": bidirectional_mask}

  pattern_length = len(gemma4.GEMMA4_ATTENTION_PATTERN)
  num_full_blocks, num_leftover_layers = divmod(cfg.num_decoder_layers, pattern_length)

  # Full repetitions of the attention pattern go through the stacked layers.
  if num_full_blocks > 0:
    y, self.layers, _ = self._apply_layers_sequentially(
        self.layers, y, *positional, length=num_full_blocks, **mask_kwargs
    )

  if num_leftover_layers <= 0:
    return y

  # Layers that do not fill a whole pattern live in a separate, rematerialized block.
  def pure_gemma_fn(graphdef, state_in, y_in):
    merged_layer = nnx.merge(graphdef, state_in)
    out_y, _ = merged_layer(
        y_in, *positional, previous_chunk=previous_chunk, page_state=page_state, slot=slot, **mask_kwargs
    )
    return out_y, nnx.state(merged_layer)

  remat_policy = self.get_remat_policy()
  checkpointed_gemma_fn = jax.checkpoint(
      pure_gemma_fn,
      policy=remat_policy,
      prevent_cse=maxtext_utils.should_prevent_cse_in_remat(cfg),
  )
  remainder_graphdef, remainder_state = nnx.split(self.layers_remainder)
  y, updated_state = checkpointed_gemma_fn(remainder_graphdef, remainder_state, y)
  nnx.update(self.layers_remainder, updated_state)
  return y
[docs]
def decoder_as_linen(
    config: Config,
    mesh: Mesh,
    rngs: nnx.Rngs,
    model_mode: str,
    quant: None | Quant = None,
):
  """Builds the NNX decoder wrapped as a Linen module named "decoder".

  Args:
    config: Model configuration passed to `NNXDecoder`.
    mesh: Device mesh passed to `NNXDecoder`.
    rngs: NNX RNG container passed to `NNXDecoder`.
    model_mode: Model mode passed to `NNXDecoder`.
    quant: Optional quantization passed to `NNXDecoder`.

  Returns:
    The result of `nnx_wrappers.to_linen` for `NNXDecoder`.
  """
  # Keyword arguments destined for the NNXDecoder constructor.
  decoder_kwargs = {
      "config": config,
      "mesh": mesh,
      "model_mode": model_mode,
      "rngs": rngs,
      "quant": quant,
  }
  return nnx_wrappers.to_linen(
      NNXDecoder,
      **decoder_kwargs,
      name="decoder",
      abstract_init=False,
      metadata_fn=initializers.variable_to_logically_partitioned,
  )