# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for decoder layers"""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module
import functools
from typing import Any
import warnings
from flax import linen as nn
from flax import nnx
from flax.linen.partitioning import ScanIn
import jax
from jax.ad_checkpoint import checkpoint_name
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Config, DecoderBlockType, ShardMode
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN
from maxtext.inference import page_manager
from maxtext.layers import linears
from maxtext.layers import mhc
from maxtext.layers import normalizations
from maxtext.layers import pipeline
from maxtext.layers import quantizations
from maxtext.layers.attentions import attention_as_linen
from maxtext.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen
from maxtext.layers.normalizations import rms_norm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.models import (
deepseek,
deepseek_batchsplit,
deepseek_batchsplit_fp8,
gemma,
gemma2,
gemma3,
gemma4,
gpt3,
gpt_oss,
llama2,
llama4,
mistral,
mixtral,
olmo3,
qwen2,
qwen3,
qwen3_custom,
qwen3_5,
simple_layer,
)
from maxtext.multimodal import utils as mm_utils
from maxtext.utils.sharding import create_sharding
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import sharding
# ------------------------------------------------------------------------------
# The network: Decoder Definitions
# ------------------------------------------------------------------------------
[docs]
class DecoderLayer(nn.Module):
  """Default decoder block with parallel self-attention and MLP branches.

  The input is RMS-normalized once and the normalized activations feed both the
  self-attention layer and the MLP block. The two branch outputs are summed,
  passed through dropout and added back to the un-normalized input.

  This is the core, reusable building block for both the main model's
  decoder stack and the auxiliary MTP layers.
  """

  config: Config
  mesh: Mesh
  model_mode: str
  quant: None | Quant = None

  @nn.compact
  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
      kv_cache: jax.Array | None = None,
      attention_metadata: dict[str, Any] | None = None,
  ):
    """Runs the block on `inputs`.

    Args:
      inputs: embedded inputs to the decoder with shape [batch, length, emb_dim].
      decoder_segment_ids: forwarded to the self-attention layer.
      decoder_positions: forwarded to the self-attention layer.
      deterministic: forwarded to attention, the MLP block and the residual dropout.
      model_mode: mode passed to the attention layer and the MLP block.
      previous_chunk: accepted for signature compatibility; not used by this block.
      slot: accepted for signature compatibility; not used by this block.
      page_state: accepted for signature compatibility; not used by this block.
      kv_cache: forwarded to the self-attention layer.
      attention_metadata: forwarded to the self-attention layer.

    Returns:
      `(layer_output, None)` when `config.scan_layers` is set, otherwise
      `(layer_output, kv_cache)` with the cache returned by the attention layer.
    """
    cfg = self.config
    mesh = self.mesh

    # NOTE: the length axis name follows the module attribute `self.model_mode`,
    # while attention and MLP below are driven by the `model_mode` call argument.
    if self.model_mode == MODEL_MODE_PREFILL:
      length_axis = "prefill_activation_length"
    else:
      length_axis = "activation_length"
    activation_axes = ("activation_batch", length_axis, "activation_embed")

    def constrain(activations):
      """Applies this block's logical activation sharding to `activations`."""
      return sharding.maybe_shard_with_logical(
          activations,
          activation_axes,
          mesh=mesh,
          shard_mode=cfg.shard_mode,
          debug_sharding=cfg.debug_sharding,
      )

    def parse_axis_order(spec):
      """Turns a comma separated string such as "1,2,0,3" into a tuple of ints."""
      return tuple(int(axis) for axis in spec.split(","))

    inputs = checkpoint_name(constrain(inputs), "decoder_layer_input")

    # Single pre-norm shared by the attention and MLP branches.
    pre_norm = rms_norm(
        num_features=inputs.shape[-1],
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="pre_self_attention_norm",
        epsilon=cfg.normalization_layer_epsilon,
        kernel_axes=("norm",),
    )
    normed = constrain(pre_norm(inputs))

    # Self-attention branch.
    self_attention = attention_as_linen(
        config=self.config,
        num_query_heads=cfg.num_query_heads,
        num_kv_heads=cfg.num_kv_heads,
        head_dim=cfg.head_dim,
        max_target_length=cfg.max_target_length,
        max_prefill_predict_length=cfg.max_prefill_predict_length,
        attention_kernel=cfg.attention,
        inputs_q_shape=normed.shape,
        inputs_kv_shape=normed.shape,
        mesh=mesh,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        dropout_rate=cfg.dropout_rate,
        name="self_attention",
        float32_qk_product=cfg.float32_qk_product,
        float32_logits=cfg.float32_logits,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(cfg),
        prefill_cache_axis_order=parse_axis_order(cfg.prefill_cache_axis_order),
        ar_cache_axis_order=parse_axis_order(cfg.ar_cache_axis_order),
        compute_axis_order=parse_axis_order(cfg.compute_axis_order),
        reshape_q=cfg.reshape_q,
        use_mrope=cfg.use_mrope,
        mrope_section=cfg.mrope_section,
        share_kv_projections=cfg.share_kv_projections,
        model_mode=model_mode,
    )
    attention_out, kv_cache = self_attention(
        normed,
        normed,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    attention_out = constrain(attention_out)

    # MLP branch, fed by the same normalized activations as attention.
    mlp = linears.mlp_block(
        in_features=normed.shape[-1],
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="mlp",
        model_mode=model_mode,
        config=cfg,
        quant=self.quant,
        mesh=self.mesh,
    )
    mlp_out = constrain(mlp(normed, deterministic=deterministic))

    # Sum the branches, apply dropout, then add the residual connection.
    branch_sum = mlp_out + attention_out
    residual_dropout = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))
    layer_output = constrain(residual_dropout(branch_sum, deterministic=deterministic) + inputs)

    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      self.sow(
          "intermediates",
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

    # Under scan the second element is always None; otherwise it carries the kv cache.
    return layer_output, (None if cfg.scan_layers else kv_cache)
[docs]
class SequentialBlockDecoderLayers(nn.Module):
  """Sequential unscanned series of decoder layers.

  Instantiates `num_decoder_layers` copies of `decoder_layer` (named
  `layers_0`, `layers_1`, ...) and applies them one after another.
  """

  decoder_layer: Any
  num_decoder_layers: int
  config: Config
  mesh: Mesh
  quant: Quant
  model_mode: str

  @nn.compact
  def __call__(
      self,
      inputs: jnp.ndarray,
      decoder_segment_ids,
      decoder_positions,
      deterministic: bool,
      model_mode,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
  ) -> jnp.ndarray:
    """Applies the layers in order.

    Args:
      inputs: activations fed to the first layer.
      decoder_segment_ids: forwarded unchanged to every layer.
      decoder_positions: forwarded unchanged to every layer.
      deterministic: forwarded unchanged to every layer.
      model_mode: forwarded to every layer, both as constructor field and call argument.
      slot: forwarded unchanged to every layer.
      page_state: forwarded unchanged to every layer.

    Returns:
      `(activations, None)` when `config.scan_layers` is set. Otherwise the return
      value of the last layer, passed through unchanged (or `inputs` if there are
      no layers).
    """
    hidden = inputs
    layer_result = inputs
    for lyr in range(self.num_decoder_layers):
      layer_result = self.decoder_layer(
          config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode
      )(
          hidden,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          slot=slot,
          page_state=page_state,
      )
      # Decoder layers return (outputs, None) under scan and may return
      # (outputs, kv_cache) otherwise; only the activations are chained into
      # the next layer, never the whole tuple.
      if self.config.scan_layers or isinstance(layer_result, tuple):
        hidden = layer_result[0]
      else:
        hidden = layer_result
    if self.config.scan_layers:
      return hidden, None  # pytype: disable=bad-return-type
    return layer_result
[docs]
def deepstack_process(hidden_states, bidirectional_mask, visual_embeds):
  """Process deepstack visual embeddings by adding them to hidden states at visual token positions.

  The k-th True entry of each mask row receives the k-th visual embedding of
  that batch element. Positions where the mask is False are returned
  untouched: they are selected with `jnp.where` rather than multiplied by
  zero, so a NaN/Inf sitting in an unused visual row cannot leak into them.

  Args:
    hidden_states: [batch, seq_len, hidden_dim] decoder hidden states
    bidirectional_mask: [batch, seq_len] boolean mask marking visual token positions
    visual_embeds: [batch, num_visual_tokens, hidden_dim] visual features from encoder layer

  Returns:
    Updated hidden_states with visual features added at visual positions
  """
  # Use cumsum to map each True position in mask to its index in visual_embeds.
  # Positions before the first visual token would get -1; clamp them to 0 so the
  # gather never relies on negative-index wraparound (their value is discarded).
  visual_token_idx = jnp.maximum(jnp.cumsum(bidirectional_mask, axis=1) - 1, 0)  # [batch, seq_len]

  # Gather visual tokens: for each position, get the corresponding visual token
  batch_idx = jnp.arange(hidden_states.shape[0])[:, jnp.newaxis]  # [batch, 1]
  visual_embeds_scattered = visual_embeds[batch_idx, visual_token_idx, :]  # [batch, seq_len, hidden]

  # Expand mask to [batch, seq_len, 1] for broadcasting and select per position.
  is_visual = jnp.asarray(bidirectional_mask, dtype=bool)[:, :, jnp.newaxis]
  return jnp.where(is_visual, hidden_states + visual_embeds_scattered, hidden_states)
[docs]
class Decoder(nn.Module):
"""A stack of decoder layers as a part of an encoder-decoder architecture."""
config: Config
mesh: Mesh
quant: None | Quant = None
model_mode: str = MODEL_MODE_TRAIN
[docs]
def setup(self):
"""Initialize decoder layer."""
self.decoder_layer = self.get_decoder_layers()
self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim)
if self.config.using_pipeline_parallelism:
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
remat_policy = self.get_remat_policy()
self.pipeline_module = pipeline.create_pipeline(
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
)
[docs]
def minimal_policy(self, with_context=False, with_quantization=False):
  """Helper for creating minimal checkpoint policies.

  Args:
    with_context: also save tensors checkpointed under the name "context".
    with_quantization: also save tensors checkpointed under the name "quantization".

  Returns:
    The policy built by `jax.checkpoint_policies.save_only_these_names` for the
    projection and MLP tensor names plus the requested extras.
  """
  saved_names = (
      "query_proj",
      "value_proj",
      "key_proj",
      "qkv_proj",
      "out_proj",
      "mlpwi_0",
      "mlpwi_1",
      "mlpwi",
      "mlpwo",
  )
  if with_context:
    saved_names += ("context",)
  if with_quantization:
    saved_names += ("quantization",)
  return jax.checkpoint_policies.save_only_these_names(*saved_names)
[docs]
def get_remat_policy(self):
  """Translate `config.remat_policy` into a JAX checkpoint policy.

  Returns:
    `None` for the "none" and "full" policies, otherwise a policy callable
    built with `jax.checkpoint_policies` (directly or via `minimal_policy`).

  Raises:
    AssertionError: if `config.remat_policy` is not one of the known names.
  """
  cfg = self.config
  remat_policy = cfg.remat_policy
  if remat_policy == "none":
    return None

  # "minimal" family: projections + MLP tensors, optionally with context/quantization.
  if remat_policy in ("minimal_with_context", "minimal_flash"):
    if remat_policy == "minimal_flash":
      max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.")
    return self.minimal_policy(with_context=True)
  if remat_policy == "minimal":
    # save all except context
    return self.minimal_policy()
  if remat_policy in ("minimal_with_quantization", "minimal_with_context_and_quantization"):
    if cfg.scan_layers:
      warnings.warn(
          "Scan layers can introduce overhead to checkpointed values that in some configurations is slower "
          "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
          "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
          "beneficial for performance."
      )
    return self.minimal_policy(
        with_context=remat_policy == "minimal_with_context_and_quantization",
        with_quantization=True,
    )

  # Policies that keep a fixed set of named tensors on device.
  attention_projections = ("query_proj", "value_proj", "key_proj", "qkv_proj")
  names_to_save = {
      "save_dot_with_context_except_mlp": (*attention_projections, "context", "out_proj"),
      "save_dot_except_mlpwi": (*attention_projections, "out_proj", "mlpwo"),
      "save_dot_except_mlp": (*attention_projections, "out_proj"),
      "save_qkv_proj": attention_projections,
      "save_out_proj": ("out_proj",),
  }
  if remat_policy in names_to_save:
    return jax.checkpoint_policies.save_only_these_names(*names_to_save[remat_policy])

  # Policies that offload a fixed set of named tensors to pinned host memory.
  names_to_offload = {
      "qkv_proj_offloaded": ["query_proj", "value_proj", "key_proj"],
      # offload all except context
      "minimal_offloaded": [
          "query_proj",
          "value_proj",
          "key_proj",
          "qkv_proj",
          "out_proj",
          "mlpwi_0",
          "mlpwi_1",
          "mlpwi",
          "mlpwo",
      ],
  }
  if remat_policy in names_to_offload:
    return jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=[],
        names_which_can_be_offloaded=names_to_offload[remat_policy],
        offload_src="device",
        offload_dst="pinned_host",
    )

  # User-provided split between on-device and offloaded tensors.
  if remat_policy == "custom":
    return jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=cfg.tensors_on_device,
        names_which_can_be_offloaded=cfg.tensors_to_offload,
        offload_src="device",
        offload_dst="pinned_host",
    )

  # Kept as an assert so that the exception type seen by callers is unchanged.
  assert remat_policy == "full", "Remat policy needs to be on list of remat policies"
  return None
[docs]
def get_decoder_layers(self):
  """Retrieves a list of decoder layer classes based on the `decoder_block` config.

  Block types that have a scannable variant return it when
  `config.scan_layers` is set and the per-layer class otherwise. DeepSeek is
  the only block type that returns two classes (dense first, then MoE).

  Returns:
    A list containing one or more `nn.Module` classes for the decoder.

  Raises:
    ValueError: if `config.decoder_block` is not a known block type.
  """
  cfg = self.config
  # Ordered (block type, factory) pairs; the first entry equal to the configured
  # block wins. Factories are lazy so only the selected classes are looked up.
  layer_table = (
      (DecoderBlockType.DEFAULT, lambda: [DecoderLayer]),
      (DecoderBlockType.LLAMA2, lambda: [llama2.LlamaDecoderLayerToLinen]),
      (DecoderBlockType.LLAMA2LTI, lambda: [llama2.LlamaLTIDecoderLayerToLinen]),
      # TODO(ranran): update to Mistral with sliding window attention
      (DecoderBlockType.MISTRAL, lambda: [mistral.MistralDecoderLayerToLinen]),
      (DecoderBlockType.MIXTRAL, lambda: [mixtral.MixtralDecoderLayerToLinen]),
      (
          DecoderBlockType.DEEPSEEK,
          lambda: [
              deepseek.DeepSeekDenseLayerToLinen,
              deepseek.DeepSeekMoELayerToLinen,
          ],
      ),
      (DecoderBlockType.GEMMA, lambda: [gemma.GemmaDecoderLayerToLinen]),
      (DecoderBlockType.GEMMA2, lambda: [gemma2.Gemma2DecoderLayerToLinen]),
      (DecoderBlockType.GEMMA3, lambda: [gemma3.Gemma3DecoderLayerToLinen]),
      (
          DecoderBlockType.GEMMA4,
          lambda: [gemma4.Gemma4ScannableBlockToLinen] if cfg.scan_layers else [gemma4.Gemma4DecoderLayerToLinen],
      ),
      (DecoderBlockType.GPT3, lambda: [gpt3.Gpt3DecoderLayerToLinen]),
      (
          DecoderBlockType.GPT_OSS,
          lambda: [gpt_oss.GptOssScannableBlockToLinen] if cfg.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen],
      ),
      (DecoderBlockType.QWEN2, lambda: [qwen2.Qwen2DecoderLayerToLinen]),
      (DecoderBlockType.QWEN3, lambda: [qwen3.Qwen3DecoderLayerToLinen]),
      (DecoderBlockType.QWEN3_MOE, lambda: [qwen3.Qwen3MoeDecoderLayerToLinen]),
      (DecoderBlockType.QWEN3_CUSTOM_MOE, lambda: [qwen3_custom.Qwen3CustomMoeDecoderLayerToLinen]),
      (
          DecoderBlockType.QWEN3_NEXT,
          lambda: [qwen3.Qwen3NextScannableBlockToLinen] if cfg.scan_layers else [qwen3.Qwen3NextDecoderLayerToLinen],
      ),
      (
          DecoderBlockType.QWEN3_5,
          lambda: [qwen3_5.Qwen3_5ScannableBlockToLinen] if cfg.scan_layers else [qwen3_5.Qwen3_5DecoderLayerToLinen],
      ),
      (DecoderBlockType.SIMPLE, lambda: [simple_layer.SimpleDecoderLayerToLinen]),
      (DecoderBlockType.SIMPLE_MLP, lambda: [simple_layer.SimpleMlpDecoderLayerToLinen]),
      (
          DecoderBlockType.LLAMA4,
          lambda: [llama4.Llama4ScannableBlockToLinen] if cfg.scan_layers else [llama4.Llama4DecoderLayerToLinen],
      ),
      (
          DecoderBlockType.OLMO3,
          lambda: [olmo3.Olmo3ScannableBlockToLinen] if cfg.scan_layers else [olmo3.Olmo3DecoderLayerToLinen],
      ),
  )
  selected_block = cfg.decoder_block
  for block_type, build_layers in layer_table:
    if selected_block == block_type:
      return build_layers()
  # No entry matched: unknown decoder block type.
  raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")
[docs]
def set_remat_policy(self, block_layers, policy):
  """Wrap every layer class in `nn.remat` using the given checkpoint policy.

  When `config.parameter_memory_host_offload` is set, each layer class is first
  wrapped with `nn.map_variables` so that its "params" collection is passed
  through `jax.device_put(..., max_utils.device_space())` before use.

  Args:
    block_layers: iterable of decoder layer classes.
    policy: checkpoint policy forwarded to `nn.remat` (may be None).

  Returns:
    A list with one rematerialized layer class per entry of `block_layers`.
  """

  def _put_on_device(path, value):
    max_logging.log(f"models.py: Moving parameter {path} to device")
    return jax.device_put(value, max_utils.device_space())

  def _params_to_device(variables):
    """Moves every leaf of the parameter tree to device."""
    return jax.tree_util.tree_map_with_path(_put_on_device, variables)

  rematted_layers = []
  for layer_cls in block_layers:
    if self.config.parameter_memory_host_offload:
      # Transform layer class before remat
      layer_cls = nn.map_variables(layer_cls, ["params"], _params_to_device, mutable=True)
    rematted_layers.append(
        nn.remat(
            layer_cls,
            prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config),
            policy=policy,
            static_argnums=(4, 5),  # Deterministic and model mode are static arguments.
        )
    )
  return rematted_layers
[docs]
def get_norm_layer(self, num_features: int):
  """get normalization layer (return type inherits from nn.Module)

  Args:
    num_features: feature size bound into the returned factory.

  Returns:
    A `functools.partial` over the normalization constructor that matches
    `config.decoder_block`, with `num_features` already filled in.

  Raises:
    ValueError: if `config.decoder_block` has no associated normalization.
  """
  rms_norm_blocks = (
      DecoderBlockType.DEFAULT,
      DecoderBlockType.LLAMA2,
      DecoderBlockType.MISTRAL,
      DecoderBlockType.MIXTRAL,
      DecoderBlockType.DEEPSEEK,
      DecoderBlockType.GEMMA,
      DecoderBlockType.GEMMA2,
      DecoderBlockType.GEMMA3,
      DecoderBlockType.GEMMA4,
      DecoderBlockType.QWEN2,
      DecoderBlockType.QWEN3,
      DecoderBlockType.QWEN3_MOE,
      DecoderBlockType.QWEN3_CUSTOM_MOE,
      DecoderBlockType.GPT_OSS,
      DecoderBlockType.SIMPLE,
      DecoderBlockType.SIMPLE_MLP,
      DecoderBlockType.LLAMA4,
      DecoderBlockType.OLMO3,
      DecoderBlockType.LLAMA2LTI,
  )
  block = self.config.decoder_block

  if block in rms_norm_blocks:
    return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode)

  if block == DecoderBlockType.GPT3:
    return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True)

  if block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
    return functools.partial(
        normalizations.Qwen3NextRMSNormLinen,
        num_features=num_features,
        shard_mode=self.config.shard_mode,
    )

  raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")
[docs]
def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs):
  """scan decoder layers, calls `flax.linen.transforms.scan`

  Args:
    cfg: config passed to the scanned layer and used for scan axis / dropout settings.
    decoder_layer: layer class to scan over.
    length: number of scan iterations.
    metadata_axis_name: used both as the partition name and as the module name.
    mesh: mesh passed to the scanned layer.
    in_axes_tuple: `in_axes` specification forwarded to `nn.scan`.
    **kwargs: extra constructor arguments for the scanned layer.

  Returns:
    The instantiated scanned module.
  """
  # While params are mutable (initialization) the raw axis is used; afterwards
  # the params axis is wrapped in `ScanIn`.
  if self.is_mutable_collection("params"):
    params_axis = cfg.param_scan_axis
  else:
    params_axis = ScanIn(cfg.param_scan_axis)

  variable_axes = {
      "params": params_axis,
      "cache": 0,
      "intermediates": 0,
      "aqt": 0,
      "batch_stats": 0,
      "_overwrite_with_gradient": 0,
  }
  split_rngs = {
      "params": True,
      "dropout": cfg.enable_dropout,
  }
  scanned_layer_cls = nn.scan(
      decoder_layer,
      variable_axes=variable_axes,
      split_rngs=split_rngs,
      in_axes=in_axes_tuple,
      length=length,
      metadata_params={nn.PARTITION_NAME: metadata_axis_name},
  )
  return scanned_layer_cls(
      config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs  # pytype: disable=wrong-keyword-args
  )
[docs]
def get_pipeline_stage_module(self, decoder_blocks):
  """get pipeline stage module

  Builds the module executed by one pipeline stage from the decoder layer
  classes returned by `get_decoder_layers`.

  Args:
    decoder_blocks: list of decoder layer classes. For DeepSeek the second
      entry (the sparse block) is pipelined, otherwise the first entry.

  Returns:
    A single layer instance when a stage holds one layer, a scanned module when
    `config.scan_layers_per_stage` is set, and a `SequentialBlockDecoderLayers`
    otherwise. Every variant is constructed with `model_mode=self.model_mode`.
  """
  cfg = self.config

  if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
    base_stage = decoder_blocks[1]  # return the sparse block
  else:
    base_stage = decoder_blocks[0]

  if cfg.set_remat_policy_on_layers_per_stage:
    policy = self.get_remat_policy()
    base_stage = self.set_remat_policy([base_stage], policy)[0]

  if cfg.num_layers_per_pipeline_stage == 1:
    return base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode)

  if cfg.scan_layers_per_stage:
    return self.scan_decoder_layers(
        cfg,
        base_stage,
        cfg.num_layers_per_pipeline_stage,
        "layers_per_stage",
        self.mesh,
        in_axes_tuple=(nn.broadcast,) * 4,
        # Passed like in the other two branches: layers such as DecoderLayer
        # declare `model_mode` as a field without a default.
        model_mode=self.model_mode,
    )

  return SequentialBlockDecoderLayers(
      decoder_layer=base_stage,
      num_decoder_layers=cfg.num_layers_per_pipeline_stage,
      config=cfg,
      mesh=self.mesh,
      quant=self.quant,
      model_mode=self.model_mode,
  )
@nn.compact
def _apply_embedding(
    self,
    shared_embedding: nn.Module | nnx.Module,
    decoder_input_tokens,
    decoder_positions,
    deterministic,
    model_mode,
    multimodal_input=None,
):
  """Applies token and positional embeddings to the input tokens.

  Steps, in order: token embedding lookup, optional merge of image and audio
  embeddings, dropout, cast to `config.dtype`, then the optional untrainable
  and trainable positional embeddings.

  Args:
    shared_embedding: embedding module called on the int32 token ids.
    decoder_input_tokens: token ids to embed.
    decoder_positions: positions used by both positional embedding variants.
    deterministic: forwarded to the embedding dropout.
    model_mode: forwarded to the embedding modules.
    multimodal_input: optional object exposing `image_embeddings`,
      `bidirectional_mask`, `image_masks`, `audio_embeddings` and `audio_masks`.

  Returns:
    The embedded inputs.

  Raises:
    ValueError: if image or audio embeddings are supplied (and enabled in the
      config) for a `model_name` that is not in the supported list.
  """
  cfg = self.config
  image_merge_models = (
      "gemma3-4b",
      "gemma3-12b",
      "gemma3-27b",
      "gemma4-26b",
      "gemma4-31b",
      "llama4-17b-16e",
      "llama4-17b-128e",
      "qwen3-omni-30b-a3b",
  )
  audio_merge_models = ("qwen3-omni-30b-a3b",)

  y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode)

  # Merge the image embeddings with the text embeddings for multimodal models
  if multimodal_input is not None:
    image_embeddings, bidirectional_mask, image_masks, audio_embeddings, audio_masks = (
        multimodal_input.image_embeddings,
        multimodal_input.bidirectional_mask,
        multimodal_input.image_masks,
        multimodal_input.audio_embeddings,
        multimodal_input.audio_masks,
    )

    if image_embeddings is not None and cfg.use_multimodal:
      # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed
      if cfg.model_name not in image_merge_models:
        raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}")
      y = mm_utils.merge_mm_embeddings(
          text_embeddings=y,
          multimodal_embeddings=image_embeddings,
          mask=bidirectional_mask,
          token_masks=image_masks,
      )

    if audio_embeddings is not None and cfg.use_audio:
      if cfg.model_name not in audio_merge_models:
        raise ValueError(f"Unsupported model_name for audio: {cfg.model_name}")
      y = mm_utils.merge_mm_embeddings(
          text_embeddings=y,
          multimodal_embeddings=audio_embeddings,
          mask=audio_masks,
          token_masks=None,
      )

  embedding_dropout = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))
  y = embedding_dropout(y, deterministic=deterministic).astype(cfg.dtype)

  if cfg.use_untrainable_positional_embedding:
    untrainable_positions = positional_embedding_as_linen(embedding_dims=cfg.base_emb_dim)
    y = y + untrainable_positions(y.shape[1], decoder_positions)

  if cfg.trainable_position_size > 0:
    position_embedder = embed_as_linen(
        num_embeddings=cfg.trainable_position_size,
        num_features=cfg.emb_dim,
        dtype=cfg.dtype,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name="position_embedder",
        config=cfg,
        mesh=self.mesh,
    )
    y = y + position_embedder(decoder_positions.astype("int32"), model_mode=model_mode)

  return y
[docs]
@nn.compact
def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode):
  """Applies final normalization and projects hidden states to logits.

  Args:
    shared_embedding: embedding module (Linen or NNX). Its embedding table is
      reused for the projection when `config.logits_via_embedding` is set.
    y: hidden states, [batch, length, emb_dim] per the inline comment below.
    deterministic: forwarded to the dropout applied after the final norm.
    model_mode: selects the output sharding of the logits (prefill and
      autoregressive share one layout, every other mode uses another).

  Returns:
    The logits, [batch, length, vocab_size] per the inline comment below, cast
    to float32 when `config.cast_logits_to_fp32` is set.
  """
  cfg = self.config
  # An explicit output sharding for the norm is only built in EXPLICIT shard mode.
  if cfg.shard_mode == ShardMode.EXPLICIT:
    norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length", "activation_embed"))
  else:
    norm_out_sharding = None
  # Final normalization; the norm type depends on the decoder block (see get_norm_layer).
  y = self.get_norm_layer(num_features=y.shape[-1])(
      dtype=cfg.dtype,
      weight_dtype=cfg.weight_dtype,
      name="decoder_norm",
      epsilon=cfg.normalization_layer_epsilon,
      kernel_axes=("norm",),
      parameter_memory_host_offload=cfg.parameter_memory_host_offload,
  )(y, out_sharding=norm_out_sharding)
  y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
  # Inference modes only shard the vocab axis of the logits; other modes also
  # name the batch and length axes.
  if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
    out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))
  else:
    out_sharding = create_sharding(
        self.mesh, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")
    )
  # [batch, length, emb_dim] -> [batch, length, vocab_size]
  if cfg.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    # The table lives in a different place for NNX and Linen embedding modules.
    if isinstance(shared_embedding, nnx.Module):
      embedding_table = shared_embedding.embedding.value
    else:
      embedding_table = shared_embedding.variables["params"]["embedding"]
    # Strip the logical-partitioning box, if any, to get the raw array.
    if isinstance(embedding_table, nn.spmd.LogicallyPartitioned):
      embedding_table = embedding_table.unbox()
    attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype
    logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding)
    if self.config.normalize_embedding_logits:
      # Correctly normalize pre-softmax logits for this shared case.
      logits = logits / jnp.sqrt(y.shape[-1])
    # Soft cap: scale down, squash with tanh, scale back up. Only applied on
    # this shared-embedding path.
    if cfg.final_logits_soft_cap:
      logits = logits / cfg.final_logits_soft_cap
      logits = jnp.tanh(logits) * cfg.final_logits_soft_cap
  else:
    # Dedicated dense projection to the vocabulary.
    logits = linears.dense_general(
        inputs_shape=y.shape,
        out_features_shape=cfg.vocab_size,
        weight_dtype=cfg.weight_dtype,
        dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
        kernel_axes=("embed_vocab", "vocab"),
        shard_mode=cfg.shard_mode,
        name="logits_dense",
        matmul_precision=self.config.matmul_precision,
        parameter_memory_host_offload=cfg.parameter_memory_host_offload,
    )(
        y,
        out_sharding=out_sharding,
    )  # We do not quantize the logits matmul.
  if self.config.cast_logits_to_fp32:
    logits = logits.astype(jnp.float32)
  return logits
@nn.compact
def __call__(
self,
shared_embedding: nn.Module | nnx.Module,
decoder_input_tokens,
decoder_positions,
decoder_segment_ids=None,
deterministic=False,
model_mode=MODEL_MODE_TRAIN,
previous_chunk=None,
slot: None | int = None,
page_state: None | page_manager.PageState = None,
multimodal_input=None,
kv_caches: list[jax.Array] | None = None,
attention_metadata=None,
deepstack_visual_embeds: None | list[jnp.ndarray] = None,
):
cfg = self.config
mesh = self.mesh
assert decoder_input_tokens.ndim == 2 # [batch, len]
# [batch, length] -> [batch, length, emb_dim]
y = self._apply_embedding(
shared_embedding,
decoder_input_tokens,
decoder_positions,
deterministic,
model_mode,
multimodal_input=multimodal_input,
)
mhc_expand, mhc_reduce = mhc.get_functions(cfg.mhc_expansion_rate)
if cfg.mhc_expansion_rate > 1:
# (batch, length, emb_dim) --> (batch, length, mhc_expansion_rate, emb_dim)
y = mhc_expand(y)
policy = self.get_remat_policy()
RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy)
# scan does not support kwargs in layer call, passing broadcast_args as positional arg
broadcast_args = (
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
)
if cfg.using_pipeline_parallelism:
logical_partition_spec = (
self.pipeline_module.get_weight_sharding(y, decoder_segment_ids, decoder_positions, deterministic, model_mode)
if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat
else None
)
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
dense_layer = RemattedBlockLayers[0]
moe_layer = RemattedBlockLayers[1]
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers
logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
# We chose not to pipeline the dense layers, only sparse for SPMD.
with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
y, _ = self.scan_decoder_layers(
cfg,
dense_layer,
cfg.first_num_dense_layers,
"dense_layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
)(y, *broadcast_args)
if num_moe_layers_outside_pp > 0:
y, _ = self.scan_decoder_layers(
cfg,
moe_layer,
num_moe_layers_outside_pp,
"moe_layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
)(y, *broadcast_args)
y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
else: # Not DeepSeek
y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
if remaining_layers > 0:
logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
y, _ = self.scan_decoder_layers(
cfg,
RemattedBlockLayers[0],
remaining_layers,
"layers_outside_pipeline",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
)(y, *broadcast_args)
else:
if cfg.scan_layers:
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
layer_call_kwargs = {
"page_state": page_state,
"previous_chunk": previous_chunk,
"slot": slot,
}
dense_layer = RemattedBlockLayers[0]
moe_layer = RemattedBlockLayers[1]
if cfg.engram_layers:
original_dense_call = dense_layer.__call__
original_moe_call = moe_layer.__call__
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
common_kwargs = {
"dense_layer": dense_layer,
"moe_layer": moe_layer,
"original_dense_call": original_dense_call,
"original_moe_call": original_moe_call,
"layer_call_kwargs": layer_call_kwargs,
"decoder_segment_ids": decoder_segment_ids,
"decoder_positions": decoder_positions,
"deterministic": deterministic,
"model_mode": model_mode,
"decoder_input_tokens": decoder_input_tokens,
"broadcast_args": broadcast_args,
}
# Apply Dense Layers
y = self._apply_interleaved_scanned_layers(
y,
layer_type="dense",
start_idx=0,
end_idx=cfg.first_num_dense_layers,
engram_indices=cfg.engram_layers,
**common_kwargs,
)
# Apply MoE Layers
y = self._apply_interleaved_scanned_layers(
y,
layer_type="moe",
start_idx=cfg.first_num_dense_layers,
end_idx=cfg.num_decoder_layers,
engram_indices=cfg.engram_layers,
**common_kwargs,
)
else:
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
y, _ = self.scan_decoder_layers(
cfg,
dense_layer,
cfg.first_num_dense_layers,
"dense_layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
)(y, *broadcast_args)
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
# If batch-split schedule is used and initialization is complete,
# as detected by immutable params, use deepseek_batchsplit custom
# scan with initialized parameters.
if cfg.use_batch_split_schedule and not self.is_mutable_collection("params"):
# old version of batch-split that fully uses qwix quantization.
if cfg.use_qwix_quantization and not cfg.use_manual_quantization:
y = deepseek_batchsplit_fp8.scan_batch_split_layers(
y,
self.variables["params"]["moe_layers"],
decoder_positions,
decoder_segment_ids,
model_mode=model_mode,
mesh=mesh,
quant=self.quant,
cfg=cfg,
policy=policy,
)
else:
# bf16 and fp8 code path for pure-JAX batch-split.
# fp8 code path supports both manual quantization and qwix
# quantization.
y = deepseek_batchsplit.scan_batch_split_layers(
y,
self.variables["params"]["moe_layers"],
decoder_positions,
mesh=mesh,
cfg=cfg,
num_layers=num_moe_layers,
)
else:
y, _ = self.scan_decoder_layers(
cfg,
moe_layer,
num_moe_layers,
"moe_layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
)(y, *broadcast_args)
elif cfg.decoder_block == DecoderBlockType.GEMMA3:
bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
y = self._apply_gemma3_scanned_blocks(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
bidirectional_mask_value,
previous_chunk,
page_state,
slot,
)
elif cfg.decoder_block == DecoderBlockType.GEMMA4:
bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
y = self._apply_gemma4_scanned_blocks(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
bidirectional_mask_value,
previous_chunk,
page_state,
slot,
)
else:
RemattedBlockLayer = RemattedBlockLayers[0]
scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
layer_kwargs = {}
if cfg.decoder_block == DecoderBlockType.LLAMA4:
layer_kwargs = {
"nope_layer_interval": self.config.nope_layer_interval,
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
}
# Update broadcast_args and in_axes_tuple for vLLM RPA
in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
current_broadcast_args = list(broadcast_args)
current_in_axes_tuple = list(in_axes_tuple)
if kv_caches is not None:
# Stack kv_caches for scan: [num_layers, ...]
stacked_kv_cache = jnp.stack(kv_caches, axis=0)
# We pass (y, stacked_kv_cache, 0) as the carry
carry = (y, stacked_kv_cache, 0)
# We don't pass kv_cache as a scanned argument anymore
# Pass None for previous_chunk, slot, page_state, kv_cache to align with __call__ signature
current_broadcast_args.extend([None, None, None, None, attention_metadata])
current_in_axes_tuple.extend([nn.broadcast] * 5)
max_logging.info(f"DEBUG: len(current_broadcast_args)={len(current_broadcast_args)}")
max_logging.info(f"DEBUG: current_broadcast_args={[type(a) for a in current_broadcast_args]}")
final_carry, _ = self.scan_decoder_layers(
cfg,
RemattedBlockLayer,
scan_length,
"layers",
mesh,
in_axes_tuple=tuple(current_in_axes_tuple),
model_mode=model_mode,
**layer_kwargs,
)(carry, *current_broadcast_args)
y, returned_kv_cache, _ = final_carry
# Update the list of KV caches from the scanned results
for i in range(cfg.num_decoder_layers):
kv_caches[i] = returned_kv_cache[i]
else:
# Fallback to old behavior if kv_caches is None (not vLLM RPA)
current_broadcast_args.append(None)
current_in_axes_tuple.append(nn.broadcast)
y, _ = self.scan_decoder_layers(
cfg,
RemattedBlockLayer,
scan_length,
"layers",
mesh,
in_axes_tuple=tuple(current_in_axes_tuple),
model_mode=model_mode,
**layer_kwargs,
)(y, *current_broadcast_args)
else:
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."
dense_layer = RemattedBlockLayers[0]
moe_layer = RemattedBlockLayers[1]
layers = [dense_layer, moe_layer]
layer_prefixes = ["dense_layers", "moe_layers"]
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
num_layers_list = [cfg.first_num_dense_layers, num_moe_layers]
# Iterate over the two layer groups (dense and MoE) and apply layer transformation
global_layer_idx_offset = 0
for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
for index in range(num_layers):
global_layer_idx = global_layer_idx_offset + index
kv_cache = kv_caches[index] if kv_caches is not None else None
input_tokens = decoder_input_tokens if cfg.engram_layers else None
y, kv_cache = layer(
config=cfg,
mesh=mesh,
name=f"{layer_prefix}_{index}",
quant=self.quant,
model_mode=self.model_mode,
layer_idx=global_layer_idx,
)(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
decoder_input_tokens=input_tokens,
)
if kv_caches is not None and kv_cache is not None:
kv_caches[index] = kv_cache
global_layer_idx_offset += num_layers
else:
for lyr in range(cfg.num_decoder_layers):
RemattedBlockLayer = RemattedBlockLayers[0]
layer_kwargs = {}
layer_call_kwargs = {}
if cfg.decoder_block == DecoderBlockType.GEMMA3:
# Gemma3 uses both global and sliding window attention depending on the layer index.
bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)}
layer_call_kwargs = {"bidirectional_mask": bidirectional_mask_value}
if cfg.decoder_block == DecoderBlockType.GEMMA4:
# Gemma4 uses both global and sliding window attention depending on the layer index.
bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
layer_kwargs = {"attention_type": gemma4.get_attention_type(layer_id=lyr)}
layer_call_kwargs = {"bidirectional_mask": bidirectional_mask_value}
if cfg.decoder_block == DecoderBlockType.LLAMA4:
layer_kwargs = {
"is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval),
"is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step),
}
if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
layer_kwargs = {"layer_idx": lyr}
kv_cache = None
if kv_caches is not None and cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
kv_cache = kv_caches[lyr]
elif kv_caches is not None and cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
# For Qwen3Next & Qwen3.5, kv_caches is a dictionary of lists of caches.
if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])
if cfg.decoder_block == DecoderBlockType.GPT_OSS:
layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
if cfg.decoder_block == DecoderBlockType.OLMO3:
layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)}
layer = RemattedBlockLayer(
config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
)
y, returned_cache = layer(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
**layer_call_kwargs,
)
if kv_caches is not None and returned_cache is not None:
if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
kv_caches[lyr] = returned_cache
elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
kv_caches["key_cache"][lyr] = returned_cache[0]
kv_caches["value_cache"][lyr] = returned_cache[1]
if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
visual_embeds = deepstack_visual_embeds[lyr]
# Use bidirectional_mask to identify visual token positions
bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
if bidirectional_mask_value is not None and visual_embeds is not None:
y = deepstack_process(y, bidirectional_mask_value, visual_embeds)
assert isinstance(y, jax.Array)
# After the final transformer layer, `y` holds the raw, un-normalized hidden state.
if cfg.mhc_expansion_rate > 1:
# (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim)
hidden_state = mhc_reduce(y)
else:
hidden_state = y
# When initializing with vLLM RPA attention, we need to run the output head to
# initialize any parameters associated with it.
if self.is_initializing() and cfg.attention == "vllm_rpa":
_ = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
# When invoking from vLLM with RPA attention, logit computation is deferred to a later stage.
if cfg.attention == "vllm_rpa":
logits = None
# When in the Indexer Dense Warm-up stage, skip the expensive output head projection
# for efficiency, as the main model is frozen and the LM loss is not needed.
# TODO(b/501446870): Investigate model_mode as train at beginning for decoding stage
elif (
cfg.use_indexer and cfg.indexer_loss_scaling_factor > 0.0 and not cfg.indexer_sparse_training
) and model_mode == MODEL_MODE_TRAIN:
logits = None
# When vocab tiling is enabled in training mode, the full logits are not materialized, to reduce memory.
# Instead, we keep track of the hidden states, which are smaller than the full logits.
elif cfg.num_vocab_tiling > 1 and model_mode == MODEL_MODE_TRAIN:
logits = None
self.sow("intermediates", "hidden_states", hidden_state)
else:
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
# The API of the Decoder is now a tuple, providing both the main output
# and the raw hidden state needed for auxiliary tasks.
return logits, hidden_state, kv_caches
def _apply_gemma3_scanned_blocks(
    self,
    y,
    decoder_segment_ids,
    decoder_positions,
    deterministic,
    model_mode,
    bidirectional_mask,
    previous_chunk,
    page_state,
    slot,
):
  """Runs the Gemma3 decoder stack as scanned pattern blocks plus an optional remainder block.

  The `cfg.num_decoder_layers` layers are grouped into blocks of
  `len(gemma3.GEMMA3_ATTENTION_PATTERN)` layers each. All full blocks are applied
  through `self.scan_decoder_layers` under the name "layers"; layers left over after
  the last full block are applied as one extra, unscanned block named
  "layers_remainder".

  Args:
    y: Input to the first block (presumably the decoder hidden states -- confirm with caller).
    decoder_segment_ids: Forwarded positionally to every block.
    decoder_positions: Forwarded positionally to every block.
    deterministic: Forwarded positionally to every block.
    model_mode: Forwarded positionally to every block. Note that the blocks themselves
      are constructed with `self.model_mode`, not with this argument.
    bidirectional_mask: Passed to every block as the `bidirectional_mask` keyword.
    previous_chunk: Forwarded only to the remainder block.
    page_state: Forwarded only to the remainder block.
    slot: Forwarded only to the remainder block.

  Returns:
    The first element of the pair returned by the last block applied, or `y`
    unchanged when neither a scanned nor a remainder block was applied.
  """
  cfg = self.config
  mesh = self.mesh

  # One scanned step covers one full repetition of the attention pattern.
  pattern_len = len(gemma3.GEMMA3_ATTENTION_PATTERN)
  num_full_blocks, num_leftover_layers = divmod(cfg.num_decoder_layers, pattern_len)

  remat_policy = self.get_remat_policy()
  RemattedGemma3Block = self.set_remat_policy([gemma3.Gemma3ScannableBlockToLinen], remat_policy)[0]

  # Scan over every complete pattern block.
  if num_full_blocks > 0:
    scan_inputs = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
    scanned_blocks = self.scan_decoder_layers(
        cfg,
        RemattedGemma3Block,
        num_full_blocks,
        "layers",
        mesh,
        in_axes_tuple=(nn.broadcast,) * len(scan_inputs),
        model_mode=self.model_mode,
        num_of_layers=pattern_len,
    )
    y, _ = scanned_blocks(y, *scan_inputs, bidirectional_mask=bidirectional_mask)

  # Layers that do not fill a whole pattern block run as one shorter, unscanned block.
  if num_leftover_layers > 0:
    # The distinct "layers_remainder" name avoids parameter name collisions with the scanned block.
    remainder_block = RemattedGemma3Block(
        config=cfg,
        mesh=mesh,
        quant=self.quant,
        model_mode=self.model_mode,
        name="layers_remainder",
        num_of_layers=num_leftover_layers,
    )  # pytype: disable=wrong-keyword-args
    y, _ = remainder_block(
        y,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        model_mode,
        previous_chunk=previous_chunk,
        page_state=page_state,
        slot=slot,
        bidirectional_mask=bidirectional_mask,
    )
  return y
def _apply_gemma4_scanned_blocks(
    self,
    y,
    decoder_segment_ids,
    decoder_positions,
    deterministic,
    model_mode,
    bidirectional_mask,
    previous_chunk,
    page_state,
    slot,
):
  """Runs the Gemma4 decoder stack as scanned pattern blocks followed by individual remainder layers.

  The `cfg.num_decoder_layers` layers are grouped into blocks of
  `len(gemma4.GEMMA4_ATTENTION_PATTERN)` layers each. All full blocks are applied with
  a single `nn.scan` over a rematerialized `gemma4.Gemma4ScannableBlockToLinen` named
  "scanned_blocks". The trailing layers that do not fill a block are applied one at a
  time as `gemma4.Gemma4DecoderLayerToLinen` modules (without the remat wrapper and
  without an explicit name).

  Args:
    y: Input to the first block (presumably the decoder hidden states -- confirm with caller).
    decoder_segment_ids: Broadcast to every block and remainder layer.
    decoder_positions: Broadcast to every block and remainder layer.
    deterministic: Broadcast to every block and remainder layer.
    model_mode: Broadcast to every block and remainder layer, and also used to construct them.
    bidirectional_mask: Broadcast to every block and remainder layer (last positional argument).
    previous_chunk: Broadcast to every block and remainder layer.
    page_state: Broadcast to every block and remainder layer.
    slot: Broadcast to every block and remainder layer.

  Returns:
    The output of the last block or layer applied, or `y` unchanged if none was applied.
  """
  cfg = self.config
  mesh = self.mesh

  # One scanned step covers one full repetition of the attention pattern.
  pattern_len = len(gemma4.GEMMA4_ATTENTION_PATTERN)
  full_block_count, leftover_count = divmod(cfg.num_decoder_layers, pattern_len)

  # Positional order here is the order the block/layer call receives after `y`.
  call_args = (
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      slot,
      page_state,
      previous_chunk,
      bidirectional_mask,
  )

  if full_block_count > 0:
    remat_policy = self.get_remat_policy()
    RemattedGemma4Block = self.set_remat_policy([gemma4.Gemma4ScannableBlockToLinen], remat_policy)[0]

    # Lift the block class into a scan over `full_block_count` repetitions.
    scanned_block_cls = nn.scan(
        RemattedGemma4Block,
        variable_axes={
            "params": cfg.param_scan_axis,
            "cache": 0,
            "intermediates": 0,
            "aqt": 0,
            "_overwrite_with_gradient": 0,
        },
        split_rngs={"params": True, "dropout": cfg.enable_dropout},
        in_axes=(nn.broadcast,) * len(call_args),
        length=full_block_count,
        metadata_params={
            nn.PARTITION_NAME: "layers",
            "abstract_init": False,
        },
    )
    scanned_blocks = scanned_block_cls(
        config=cfg,
        mesh=mesh,
        quant=self.quant,
        model_mode=model_mode,
        num_of_layers=pattern_len,
        name="scanned_blocks",
    )
    y, _ = scanned_blocks(y, *call_args)

  # Trailing layers that do not fill a whole pattern block are applied one by one.
  first_leftover_idx = cfg.num_decoder_layers - leftover_count
  for layer_idx in range(first_leftover_idx, cfg.num_decoder_layers):
    remainder_layer = gemma4.Gemma4DecoderLayerToLinen(
        config=cfg,
        mesh=mesh,
        model_mode=model_mode,
        quant=self.quant,
        attention_type=gemma4.get_attention_type(layer_idx),
        layer_idx=layer_idx,
    )
    layer_out = remainder_layer(y, *call_args)
    # With cfg.scan_layers set, only the first element of the layer's return value is kept.
    y = layer_out[0] if cfg.scan_layers else layer_out

  return y
# TODO(b/490118813): Relocate the following functions to their designated directories
# once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer(),
# _apply_scanned_chunk(), and _apply_interleaved_scanned_layers().
def _find_next_boundary(self, current_idx, end_idx, engram_indices):
"""Finds the next index boundary, either the next Engram layer index or the overall end index."""
next_engrams = [l for l in engram_indices if l > current_idx]
if next_engrams:
return min(end_idx, *next_engrams)
return end_idx
def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
"""Applies a single, unscanned Engram layer."""
layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
original_call = kwargs["original_dense_call"] if layer_type == "dense" else kwargs["original_moe_call"]
layer_call_kwargs = kwargs["layer_call_kwargs"]
layer.__call__ = original_call
y, _ = layer(
config=self.config,
mesh=self.mesh,
name=f"{layer_prefix}_engram_{current_idx}",
quant=self.quant,
model_mode=self.model_mode,
layer_idx=current_idx,
)(
y,
kwargs["decoder_segment_ids"],
kwargs["decoder_positions"],
kwargs["deterministic"],
kwargs["model_mode"],
decoder_input_tokens=kwargs["decoder_input_tokens"],
**layer_call_kwargs,
)
layer.__call__ = functools.partial(original_call, **layer_call_kwargs)
return y
def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_type, **kwargs):
"""Applies a contiguous chunk of layers using the scan operation."""
layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
broadcast_args = kwargs["broadcast_args"]
scan_length = next_boundary - current_idx
if scan_length > 0:
y, _ = self.scan_decoder_layers(
self.config,
layer,
scan_length,
f"{layer_prefix}_{current_idx}_{next_boundary - 1}",
self.mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=kwargs["model_mode"],
)(y, *broadcast_args)
return y
def _apply_interleaved_scanned_layers(self, y, layer_type, start_idx, end_idx, engram_indices, **kwargs):
"""Applies a mix of scanned standard layers and unscanned Engram layers."""
current_idx = start_idx
while current_idx < end_idx:
if current_idx in engram_indices:
# Handle individual unscanned Engram layer
y = self._apply_single_engram_layer(y, current_idx, layer_type, **kwargs)
current_idx += 1
else:
# Find next boundary and scan the chunk
next_boundary = self._find_next_boundary(current_idx, end_idx, engram_indices)
y = self._apply_scanned_chunk(y, current_idx, next_boundary, layer_type, **kwargs)
current_idx = next_boundary
return y