Source code for maxtext.layers.decoders

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for decoder layers"""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

import functools
from typing import Any
import warnings

from flax import linen as nn
from flax import nnx
from flax.linen.partitioning import ScanIn
import jax
from jax.ad_checkpoint import checkpoint_name
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Config, DecoderBlockType, ShardMode
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN
from maxtext.inference import page_manager
from maxtext.layers import linears
from maxtext.layers import mhc
from maxtext.layers import normalizations
from maxtext.layers import pipeline
from maxtext.layers import quantizations
from maxtext.layers.attentions import attention_as_linen
from maxtext.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen
from maxtext.layers.normalizations import rms_norm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.models import (
    deepseek,
    deepseek_batchsplit,
    deepseek_batchsplit_fp8,
    gemma,
    gemma2,
    gemma3,
    gemma4,
    gpt3,
    gpt_oss,
    llama2,
    llama4,
    mistral,
    mixtral,
    olmo3,
    qwen2,
    qwen3,
    qwen3_custom,
    qwen3_5,
    simple_layer,
)
from maxtext.multimodal import utils as mm_utils
from maxtext.utils.sharding import create_sharding
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import sharding

# ------------------------------------------------------------------------------
# The network: Decoder Definitions
# ------------------------------------------------------------------------------


class DecoderLayer(nn.Module):
  """Transformer decoder layer that attends to the encoder.

  This is the core, reusable building block for both the main model's decoder
  stack and the auxiliary MTP layers.

  The layer normalizes its input once, feeds that normalized tensor to both the
  self-attention block and the MLP block, sums the two results, applies dropout
  and finally adds the residual input.
  """

  # Model configuration; read for dtypes, attention settings, dropout, etc.
  config: Config
  # Device mesh forwarded to the sharding helper, attention and MLP blocks.
  mesh: Mesh
  # Mode the module was constructed with; selects the logical axis names below.
  model_mode: str
  # Optional quantization setup forwarded to the attention and MLP blocks.
  quant: None | Quant = None

  @nn.compact
  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
      kv_cache: jax.Array | None = None,
      attention_metadata: dict[str, Any] | None = None,
  ):
    """Run one decoder layer.

    Args:
      inputs: embedded inputs to the decoder with shape [batch, length, emb_dim].
      decoder_segment_ids: forwarded to the attention layer.
      decoder_positions: forwarded to the attention layer.
      deterministic: disables dropout when true (forwarded to attention, MLP and
        the dropout layer).
      model_mode: execution mode forwarded to the attention and MLP blocks.
      previous_chunk: accepted for interface compatibility; not used here.
      slot: accepted for interface compatibility; not used here.
      page_state: accepted for interface compatibility; not used here.
      kv_cache: optional cache forwarded to, and returned by, the attention layer.
      attention_metadata: optional metadata forwarded to the attention layer.

    Returns:
      `(layer_output, None)` when `config.scan_layers` is set, otherwise
      `(layer_output, kv_cache)` with the cache returned by the attention layer.
    """
    cfg = self.config
    mesh = self.mesh
    shard = functools.partial(
        sharding.maybe_shard_with_logical,
        mesh=mesh,
        shard_mode=cfg.shard_mode,
        debug_sharding=cfg.debug_sharding,
    )

    # NOTE(review): the axis names are picked from the module attribute
    # `self.model_mode`, while everything else uses the `model_mode` call
    # argument. Kept as-is; confirm whether the two can ever differ.
    if self.model_mode == MODEL_MODE_PREFILL:
      logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
    else:
      logical_axis_names = ("activation_batch", "activation_length", "activation_embed")

    # Every activation below gets the same sharding constraint regardless of
    # the call-time mode, so no per-mode branching is needed.
    inputs = shard(inputs, logical_axis_names)
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
    lnx = rms_norm(
        num_features=inputs.shape[-1],
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="pre_self_attention_norm",
        epsilon=cfg.normalization_layer_epsilon,
        kernel_axes=("norm",),
    )(inputs)
    lnx = shard(lnx, logical_axis_names)

    attention_layer = attention_as_linen(
        config=self.config,
        num_query_heads=cfg.num_query_heads,
        num_kv_heads=cfg.num_kv_heads,
        head_dim=cfg.head_dim,
        max_target_length=cfg.max_target_length,
        max_prefill_predict_length=cfg.max_prefill_predict_length,
        attention_kernel=cfg.attention,
        inputs_q_shape=lnx.shape,
        inputs_kv_shape=lnx.shape,
        mesh=mesh,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        dropout_rate=cfg.dropout_rate,
        name="self_attention",
        float32_qk_product=cfg.float32_qk_product,
        float32_logits=cfg.float32_logits,
        quant=self.quant,
        kv_quant=quantizations.configure_kv_quant(cfg),
        # The axis orders are stored in the config as comma-separated strings.
        prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
        ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
        compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
        reshape_q=cfg.reshape_q,
        use_mrope=cfg.use_mrope,
        mrope_section=cfg.mrope_section,
        share_kv_projections=cfg.share_kv_projections,
        model_mode=model_mode,
    )

    attention_lnx, kv_cache = attention_layer(
        lnx,
        lnx,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=model_mode,
        kv_cache=kv_cache,
        attention_metadata=attention_metadata,
    )
    attention_lnx = shard(attention_lnx, logical_axis_names)

    # MLP block. It consumes the same normalized input as the attention block.
    mlp_lnx = linears.mlp_block(
        in_features=lnx.shape[-1],
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="mlp",
        model_mode=model_mode,
        config=cfg,
        quant=self.quant,
        mesh=self.mesh,
    )(lnx, deterministic=deterministic)
    mlp_lnx = shard(mlp_lnx, logical_axis_names)

    next_layer_addition = mlp_lnx + attention_lnx
    next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(
        next_layer_addition, deterministic=deterministic
    )

    # Residual connection around the combined attention + MLP update.
    layer_output = next_layer_addition_dropped_out + inputs
    layer_output = shard(layer_output, logical_axis_names)

    if cfg.record_internal_nn_metrics:
      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
      self.sow(
          "intermediates",
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

    if cfg.scan_layers:
      # Under scan the second element of the return is always None.
      return layer_output, None
    return layer_output, kv_cache
class SequentialBlockDecoderLayers(nn.Module):
  """Sequential unscanned series of decoder layers.

  Instantiates `decoder_layer` `num_decoder_layers` times (named `layers_0`,
  `layers_1`, ...) and threads the activations through them in order.
  """

  # Layer class (or factory) called with config/mesh/name/quant/model_mode.
  decoder_layer: Any
  # Number of layers to chain.
  num_decoder_layers: int
  config: Config
  mesh: Mesh
  quant: Quant
  model_mode: str

  @nn.compact
  def __call__(
      self,
      inputs: jnp.ndarray,
      decoder_segment_ids,
      decoder_positions,
      deterministic: bool,
      model_mode,
      slot: None | int = None,
      page_state: None | page_manager.PageState = None,
  ) -> jnp.ndarray:
    """Apply the layers one after another.

    Returns:
      `(outputs, None)` when `config.scan_layers` is set, otherwise whatever the
      last layer returned.
    """
    unwrap_layer_output = self.config.scan_layers
    hidden = inputs

    for layer_index in range(self.num_decoder_layers):
      layer = self.decoder_layer(
          config=self.config,
          mesh=self.mesh,
          name=f"layers_{layer_index}",
          quant=self.quant,
          model_mode=model_mode,
      )
      hidden = layer(
          hidden,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          model_mode,
          slot=slot,
          page_state=page_state,
      )
      # When scan_layers is True the decoder layers return (outputs, None).
      # NOTE(review): when scan_layers is False the layer result is passed on
      # untouched; a layer that returns a tuple in that mode would feed a tuple
      # into the next layer. Confirm against the layers used with this block.
      if unwrap_layer_output:
        hidden = hidden[0]

    if unwrap_layer_output:
      return hidden, None  # pytype: disable=bad-return-type
    return hidden
def deepstack_process(hidden_states, bidirectional_mask, visual_embeds):
  """Add deepstack visual embeddings to the hidden states at visual token positions.

  Args:
    hidden_states: [batch, seq_len, hidden_dim] decoder hidden states
    bidirectional_mask: [batch, seq_len] boolean mask marking visual token positions
    visual_embeds: [batch, num_visual_tokens, hidden_dim] visual features from encoder layer

  Returns:
    Updated hidden_states with visual features added at visual positions
  """
  # The running count of True entries, minus one, gives for every visual
  # position the 0-based index of its embedding in `visual_embeds`.
  # Positions ahead of the first visual token get index -1; they are zeroed by
  # the mask multiplication below.
  token_rank = jnp.cumsum(bidirectional_mask, axis=1) - 1  # [batch, seq_len]

  # Pair each row of indices with its batch entry and gather.
  row_ids = jnp.arange(hidden_states.shape[0])[:, None]  # [batch, 1]
  gathered = visual_embeds[row_ids, token_rank]  # [batch, seq_len, hidden]

  # Broadcast the mask over the hidden dimension so only visual positions change.
  keep = bidirectional_mask[:, :, None]  # [batch, seq_len, 1]
  return hidden_states + gathered * keep
[docs] class Decoder(nn.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" config: Config mesh: Mesh quant: None | Quant = None model_mode: str = MODEL_MODE_TRAIN
def setup(self):
  """Initialize decoder layer.

  Stores the decoder layer classes and the final-norm factory on the module,
  and additionally builds the pipeline module when
  `config.using_pipeline_parallelism` is set.
  """
  cfg = self.config
  self.decoder_layer = self.get_decoder_layers()
  self.norm_layer = self.get_norm_layer(num_features=cfg.emb_dim)

  if not cfg.using_pipeline_parallelism:
    return

  stage_layers = self.get_pipeline_stage_module(self.decoder_layer)
  stage_remat_policy = self.get_remat_policy()
  self.pipeline_module = pipeline.create_pipeline(
      config=cfg,
      mesh=self.mesh,
      layers=stage_layers,
      remat_policy=stage_remat_policy,
  )
def minimal_policy(self, with_context=False, with_quantization=False):
  """Helper for creating minimal checkpoint policies.

  Args:
    with_context: also save the checkpoint named "context".
    with_quantization: also save the checkpoint named "quantization".

  Returns:
    The policy produced by `jax.checkpoint_policies.save_only_these_names` for
    the projection/MLP names plus the optional extras.
  """
  saved_names = (
      "query_proj",
      "value_proj",
      "key_proj",
      "qkv_proj",
      "out_proj",
      "mlpwi_0",
      "mlpwi_1",
      "mlpwi",
      "mlpwo",
  )
  if with_context:
    saved_names += ("context",)
  if with_quantization:
    saved_names += ("quantization",)
  return jax.checkpoint_policies.save_only_these_names(*saved_names)
def get_remat_policy(self):
  """Get remat policy.

  Maps `config.remat_policy` to a `jax.checkpoint` policy.

  Returns:
    A policy callable, or `None` for the "none" and "full" settings.

  Raises:
    AssertionError: if `config.remat_policy` is not one of the known names.
  """
  cfg = self.config
  name = cfg.remat_policy
  save_only = jax.checkpoint_policies.save_only_these_names
  save_and_offload = jax.checkpoint_policies.save_and_offload_only_these_names

  if name == "none":
    return None

  if name in ("minimal_with_context", "minimal_flash"):
    # Minimal set of names plus the attention context.
    if name == "minimal_flash":
      max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.")
    return self.minimal_policy(with_context=True)

  if name == "minimal":
    # Minimal set of names, without the attention context.
    return self.minimal_policy()

  if name in ("minimal_with_quantization", "minimal_with_context_and_quantization"):
    if cfg.scan_layers:
      warnings.warn(
          "Scan layers can introduce overhead to checkpointed values that in some configurations is slower "
          "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
          "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
          "beneficial for performance."
      )
    return self.minimal_policy(
        with_context=(name == "minimal_with_context_and_quantization"),
        with_quantization=True,
    )

  if name == "save_dot_with_context_except_mlp":
    return save_only(
        "query_proj",
        "value_proj",
        "key_proj",
        "qkv_proj",
        "context",
        "out_proj",
    )

  if name == "save_dot_except_mlpwi":
    return save_only(
        "query_proj",
        "value_proj",
        "key_proj",
        "qkv_proj",
        "out_proj",
        "mlpwo",
    )

  if name == "save_dot_except_mlp":
    return save_only(
        "query_proj",
        "value_proj",
        "key_proj",
        "qkv_proj",
        "out_proj",
    )

  if name == "save_qkv_proj":
    return save_only(
        "query_proj",
        "value_proj",
        "key_proj",
        "qkv_proj",
    )

  if name == "qkv_proj_offloaded":
    return save_and_offload(
        names_which_can_be_saved=[],
        names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"],
        offload_src="device",
        offload_dst="pinned_host",
    )

  if name == "minimal_offloaded":
    # offload all except context
    return save_and_offload(
        names_which_can_be_saved=[],
        names_which_can_be_offloaded=[
            "query_proj",
            "value_proj",
            "key_proj",
            "qkv_proj",
            "out_proj",
            "mlpwi_0",
            "mlpwi_1",
            "mlpwi",
            "mlpwo",
        ],
        offload_src="device",
        offload_dst="pinned_host",
    )

  if name == "custom":
    # The saved/offloaded name lists come straight from the config.
    return save_and_offload(
        names_which_can_be_saved=cfg.tensors_on_device,
        names_which_can_be_offloaded=cfg.tensors_to_offload,
        offload_src="device",
        offload_dst="pinned_host",
    )

  if name == "save_out_proj":
    return save_only(
        "out_proj",
    )

  # Anything left must be "full", which uses no policy.
  # NOTE: this check is an assert (kept for compatibility) and is therefore
  # skipped when Python runs with -O.
  assert name == "full", "Remat policy needs to be on list of remat policies"
  return None
def get_decoder_layers(self):
  """Retrieves a list of decoder layer classes based on the `decoder_block` config.

  Several block types return a scannable block class when `config.scan_layers`
  is set and a per-layer class otherwise. DeepSeek returns two classes (dense
  first, then MoE).

  Returns:
    A list containing one or more `nn.Module` classes for the decoder.

  Raises:
    ValueError: if `config.decoder_block` is not a known block type.
  """
  block_type = self.config.decoder_block

  if block_type == DecoderBlockType.DEFAULT:
    return [DecoderLayer]
  if block_type == DecoderBlockType.LLAMA2:
    return [llama2.LlamaDecoderLayerToLinen]
  if block_type == DecoderBlockType.LLAMA2LTI:
    return [llama2.LlamaLTIDecoderLayerToLinen]
  if block_type == DecoderBlockType.MISTRAL:
    # TODO(ranran): update to Mistral with sliding window attention
    return [mistral.MistralDecoderLayerToLinen]
  if block_type == DecoderBlockType.MIXTRAL:
    return [mixtral.MixtralDecoderLayerToLinen]
  if block_type == DecoderBlockType.DEEPSEEK:
    return [
        deepseek.DeepSeekDenseLayerToLinen,
        deepseek.DeepSeekMoELayerToLinen,
    ]
  if block_type == DecoderBlockType.GEMMA:
    return [gemma.GemmaDecoderLayerToLinen]
  if block_type == DecoderBlockType.GEMMA2:
    return [gemma2.Gemma2DecoderLayerToLinen]
  if block_type == DecoderBlockType.GEMMA3:
    return [gemma3.Gemma3DecoderLayerToLinen]
  if block_type == DecoderBlockType.GEMMA4:
    if self.config.scan_layers:
      return [gemma4.Gemma4ScannableBlockToLinen]
    return [gemma4.Gemma4DecoderLayerToLinen]
  if block_type == DecoderBlockType.GPT3:
    return [gpt3.Gpt3DecoderLayerToLinen]
  if block_type == DecoderBlockType.GPT_OSS:
    if self.config.scan_layers:
      return [gpt_oss.GptOssScannableBlockToLinen]
    return [gpt_oss.GptOssDecoderLayerToLinen]
  if block_type == DecoderBlockType.QWEN2:
    return [qwen2.Qwen2DecoderLayerToLinen]
  if block_type == DecoderBlockType.QWEN3:
    return [qwen3.Qwen3DecoderLayerToLinen]
  if block_type == DecoderBlockType.QWEN3_MOE:
    return [qwen3.Qwen3MoeDecoderLayerToLinen]
  if block_type == DecoderBlockType.QWEN3_CUSTOM_MOE:
    return [qwen3_custom.Qwen3CustomMoeDecoderLayerToLinen]
  if block_type == DecoderBlockType.QWEN3_NEXT:
    if self.config.scan_layers:
      return [qwen3.Qwen3NextScannableBlockToLinen]
    return [qwen3.Qwen3NextDecoderLayerToLinen]
  if block_type == DecoderBlockType.QWEN3_5:
    if self.config.scan_layers:
      return [qwen3_5.Qwen3_5ScannableBlockToLinen]
    return [qwen3_5.Qwen3_5DecoderLayerToLinen]
  if block_type == DecoderBlockType.SIMPLE:
    return [simple_layer.SimpleDecoderLayerToLinen]
  if block_type == DecoderBlockType.SIMPLE_MLP:
    return [simple_layer.SimpleMlpDecoderLayerToLinen]
  if block_type == DecoderBlockType.LLAMA4:
    if self.config.scan_layers:
      return [llama4.Llama4ScannableBlockToLinen]
    return [llama4.Llama4DecoderLayerToLinen]
  if block_type == DecoderBlockType.OLMO3:
    if self.config.scan_layers:
      return [olmo3.Olmo3ScannableBlockToLinen]
    return [olmo3.Olmo3DecoderLayerToLinen]

  # Fallback for any unknown decoder block type. The expression inside the
  # f-string is spelled out in full because the `=` specifier echoes it.
  raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")
def set_remat_policy(self, block_layers, policy):
  """Set remat policy.

  Wraps every layer class in `block_layers` with `nn.remat` using `policy`.
  When `config.parameter_memory_host_offload` is set, each class is first
  wrapped with `nn.map_variables` so its "params" are passed through
  `jax.device_put` before use.

  Args:
    block_layers: iterable of layer classes to wrap.
    policy: checkpoint policy handed to `nn.remat` (may be None).

  Returns:
    A list with one rematerialized layer class per input, in the same order.
  """

  def _params_to_device(variables):
    """Put every leaf of `variables` on `max_utils.device_space()`."""

    def _put_leaf(path, value):
      max_logging.log(f"models.py: Moving parameter {path} to device")
      return jax.device_put(value, max_utils.device_space())

    return jax.tree_util.tree_map_with_path(_put_leaf, variables)

  def _rematted(layer_cls):
    if self.config.parameter_memory_host_offload:
      # Transform the layer class before remat is applied.
      layer_cls = nn.map_variables(layer_cls, ["params"], _params_to_device, mutable=True)
    return nn.remat(
        layer_cls,
        prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config),
        policy=policy,
        static_argnums=(4, 5),  # Deterministic and model mode are static arguments.
    )

  return [_rematted(layer_cls) for layer_cls in block_layers]
def get_norm_layer(self, num_features: int):
  """get normalization layer (return type inherits from nn.Module)

  Args:
    num_features: feature size bound into the returned factory.

  Returns:
    A `functools.partial` over the normalization constructor that matches
    `config.decoder_block`.

  Raises:
    ValueError: if `config.decoder_block` has no associated normalization.
  """
  block_type = self.config.decoder_block

  rms_norm_blocks = (
      DecoderBlockType.DEFAULT,
      DecoderBlockType.LLAMA2,
      DecoderBlockType.MISTRAL,
      DecoderBlockType.MIXTRAL,
      DecoderBlockType.DEEPSEEK,
      DecoderBlockType.GEMMA,
      DecoderBlockType.GEMMA2,
      DecoderBlockType.GEMMA3,
      DecoderBlockType.GEMMA4,
      DecoderBlockType.QWEN2,
      DecoderBlockType.QWEN3,
      DecoderBlockType.QWEN3_MOE,
      DecoderBlockType.QWEN3_CUSTOM_MOE,
      DecoderBlockType.GPT_OSS,
      DecoderBlockType.SIMPLE,
      DecoderBlockType.SIMPLE_MLP,
      DecoderBlockType.LLAMA4,
      DecoderBlockType.OLMO3,
      DecoderBlockType.LLAMA2LTI,
  )
  qwen3_next_norm_blocks = (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5)

  if block_type in rms_norm_blocks:
    return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode)

  if block_type == DecoderBlockType.GPT3:
    return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True)

  if block_type in qwen3_next_norm_blocks:
    return functools.partial(
        normalizations.Qwen3NextRMSNormLinen,
        num_features=num_features,
        shard_mode=self.config.shard_mode,
    )

  # The f-string expression is kept verbatim because `=` echoes its text.
  raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")
def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs):
  """scan decoder layers, calls `flax.linen.transforms.scan`

  Args:
    cfg: config; `param_scan_axis` and `enable_dropout` are read from it.
    decoder_layer: layer class to scan over.
    length: number of scan iterations, passed to `nn.scan`.
    metadata_axis_name: used both as the partition name in the scan metadata and
      as the name of the instantiated module.
    mesh: mesh passed to the layer constructor.
    in_axes_tuple: `in_axes` for `nn.scan`.
    **kwargs: extra constructor arguments for the scanned layer.

  Returns:
    The instantiated scanned module.
  """
  # While "params" is mutable the raw axis is used; otherwise it is wrapped in ScanIn.
  if self.is_mutable_collection("params"):
    params_axis = cfg.param_scan_axis
  else:
    params_axis = ScanIn(cfg.param_scan_axis)

  variable_axes = {
      "params": params_axis,
      "cache": 0,
      "intermediates": 0,
      "aqt": 0,
      "batch_stats": 0,
      "_overwrite_with_gradient": 0,
  }
  split_rngs = {
      "params": True,
      "dropout": cfg.enable_dropout,
  }

  scanned_layer_cls = nn.scan(
      decoder_layer,
      variable_axes=variable_axes,
      split_rngs=split_rngs,
      in_axes=in_axes_tuple,
      length=length,
      metadata_params={nn.PARTITION_NAME: metadata_axis_name},
  )
  return scanned_layer_cls(
      config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs  # pytype: disable=wrong-keyword-args
  )
def get_pipeline_stage_module(self, decoder_blocks):
  """get pipeline stage module

  Args:
    decoder_blocks: list of decoder layer classes. For DeepSeek the second
      entry (the sparse block) is pipelined, otherwise the first.

  Returns:
    The module that makes up one pipeline stage: a single layer, a scanned
    stack, or a `SequentialBlockDecoderLayers`, depending on the config.
  """
  cfg = self.config

  # Pick which block is pipelined.
  if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
    stage_layer_cls = decoder_blocks[1]  # return the sparse block
  else:
    stage_layer_cls = decoder_blocks[0]

  if cfg.set_remat_policy_on_layers_per_stage:
    stage_policy = self.get_remat_policy()
    stage_layer_cls = self.set_remat_policy([stage_layer_cls], stage_policy)[0]

  layers_per_stage = cfg.num_layers_per_pipeline_stage

  if layers_per_stage == 1:
    return stage_layer_cls(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode)

  if cfg.scan_layers_per_stage:
    return self.scan_decoder_layers(
        cfg,
        stage_layer_cls,
        cfg.num_layers_per_pipeline_stage,
        "layers_per_stage",
        self.mesh,
        in_axes_tuple=(nn.broadcast,) * 4,
    )

  return SequentialBlockDecoderLayers(
      decoder_layer=stage_layer_cls,
      num_decoder_layers=cfg.num_layers_per_pipeline_stage,
      config=cfg,
      mesh=self.mesh,
      quant=self.quant,
      model_mode=self.model_mode,
  )
@nn.compact
def _apply_embedding(
    self,
    shared_embedding: nn.Module | nnx.Module,
    decoder_input_tokens,
    decoder_positions,
    deterministic,
    model_mode,
    multimodal_input=None,
):
  """Applies token and positional embeddings to the input tokens.

  Args:
    shared_embedding: embedding module called on the int32 token ids.
    decoder_input_tokens: token ids to embed.
    decoder_positions: positions used by the optional positional embeddings.
    deterministic: disables dropout when true.
    model_mode: forwarded to the embedding modules.
    multimodal_input: optional object carrying image/audio embeddings and masks
      that are merged into the text embeddings.

  Returns:
    The embedded inputs, cast to `config.dtype`, with any configured positional
    embeddings added.

  Raises:
    ValueError: if multimodal or audio merging is requested for a model name
      that is not in the supported lists below.
  """
  cfg = self.config

  y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode)

  # Merge the image embeddings with the text embeddings for multimodal models
  if multimodal_input is not None:
    image_embeddings = multimodal_input.image_embeddings
    bidirectional_mask = multimodal_input.bidirectional_mask
    image_masks = multimodal_input.image_masks
    audio_embeddings = multimodal_input.audio_embeddings
    audio_masks = multimodal_input.audio_masks

    if image_embeddings is not None and cfg.use_multimodal:
      # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed
      if cfg.model_name not in [
          "gemma3-4b",
          "gemma3-12b",
          "gemma3-27b",
          "gemma4-26b",
          "gemma4-31b",
          "llama4-17b-16e",
          "llama4-17b-128e",
          "qwen3-omni-30b-a3b",
      ]:
        raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}")
      y = mm_utils.merge_mm_embeddings(
          text_embeddings=y,
          multimodal_embeddings=image_embeddings,
          mask=bidirectional_mask,
          token_masks=image_masks,
      )

    if audio_embeddings is not None and cfg.use_audio:
      if cfg.model_name not in ["qwen3-omni-30b-a3b"]:
        raise ValueError(f"Unsupported model_name for audio: {cfg.model_name}")
      y = mm_utils.merge_mm_embeddings(
          text_embeddings=y,
          multimodal_embeddings=audio_embeddings,
          mask=audio_masks,
          token_masks=None,
      )

  embedding_dropout = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))
  y = embedding_dropout(y, deterministic=deterministic)
  y = y.astype(cfg.dtype)

  if cfg.use_untrainable_positional_embedding:
    fixed_positional = positional_embedding_as_linen(embedding_dims=cfg.base_emb_dim)
    y = y + fixed_positional(y.shape[1], decoder_positions)

  if cfg.trainable_position_size > 0:
    learned_positional = embed_as_linen(
        num_embeddings=cfg.trainable_position_size,
        num_features=cfg.emb_dim,
        dtype=cfg.dtype,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name="position_embedder",
        config=cfg,
        mesh=self.mesh,
    )
    y = y + learned_positional(decoder_positions.astype("int32"), model_mode=model_mode)

  return y
@nn.compact
def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode):
  """Applies final normalization and projects hidden states to logits.

  Args:
    shared_embedding: embedding module whose table is reused for the logits
      when `config.logits_via_embedding` is set.
    y: hidden states; the last dimension is the feature size.
    deterministic: disables dropout when true.
    model_mode: selects the output sharding (prefill/autoregressive vs. other).

  Returns:
    The logits, cast to float32 when `config.cast_logits_to_fp32` is set.
  """
  cfg = self.config

  norm_out_sharding = None
  if cfg.shard_mode == ShardMode.EXPLICIT:
    norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length", "activation_embed"))

  final_norm = self.get_norm_layer(num_features=y.shape[-1])(
      dtype=cfg.dtype,
      weight_dtype=cfg.weight_dtype,
      name="decoder_norm",
      epsilon=cfg.normalization_layer_epsilon,
      kernel_axes=("norm",),
      parameter_memory_host_offload=cfg.parameter_memory_host_offload,
  )
  y = final_norm(y, out_sharding=norm_out_sharding)
  y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)

  if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
    logits_axes = (None, None, "activation_vocab")
  else:
    logits_axes = ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")
  out_sharding = create_sharding(self.mesh, logits_axes)

  # for logit training stability
  logits_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype

  # [batch, length, emb_dim] -> [batch, length, vocab_size]
  if cfg.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    if isinstance(shared_embedding, nnx.Module):
      embedding_table = shared_embedding.embedding.value
    else:
      embedding_table = shared_embedding.variables["params"]["embedding"]
    if isinstance(embedding_table, nn.spmd.LogicallyPartitioned):
      embedding_table = embedding_table.unbox()

    logits = attend_on_embedding(y, embedding_table, logits_dtype, self.config, out_sharding)

    if self.config.normalize_embedding_logits:
      # Correctly normalize pre-softmax logits for this shared case.
      logits = logits / jnp.sqrt(y.shape[-1])

    soft_cap = cfg.final_logits_soft_cap
    if soft_cap:
      # tanh soft-capping: scale down, squash, scale back up.
      logits = jnp.tanh(logits / soft_cap) * soft_cap
  else:
    # We do not quantize the logits matmul.
    logits_projection = linears.dense_general(
        inputs_shape=y.shape,
        out_features_shape=cfg.vocab_size,
        weight_dtype=cfg.weight_dtype,
        dtype=logits_dtype,
        kernel_axes=("embed_vocab", "vocab"),
        shard_mode=cfg.shard_mode,
        name="logits_dense",
        matmul_precision=self.config.matmul_precision,
        parameter_memory_host_offload=cfg.parameter_memory_host_offload,
    )
    logits = logits_projection(
        y,
        out_sharding=out_sharding,
    )

  if self.config.cast_logits_to_fp32:
    logits = logits.astype(jnp.float32)

  return logits
@nn.compact
def __call__(
    self,
    shared_embedding: nn.Module | nnx.Module,
    decoder_input_tokens,
    decoder_positions,
    decoder_segment_ids=None,
    deterministic=False,
    model_mode=MODEL_MODE_TRAIN,
    previous_chunk=None,
    slot: None | int = None,
    page_state: None | page_manager.PageState = None,
    multimodal_input=None,
    kv_caches: list[jax.Array] | None = None,
    attention_metadata=None,
    deepstack_visual_embeds: None | list[jnp.ndarray] = None,
):
  """Embeds the input tokens, applies every decoder layer and (optionally) the output head.

  The layer stack is applied through one of several paths chosen from the config:
  pipeline parallelism, scanned layers (with DeepSeek / Gemma3 / Gemma4 specific
  handling) or a plain Python loop over unscanned layers.

  Args:
    shared_embedding: Embedding module used for the input embedding and, when
      configured, for the logits projection.
    decoder_input_tokens: Token ids; must be rank 2, `[batch, len]`.
    decoder_positions: Positions forwarded to the embedding and to every layer.
    decoder_segment_ids: Optional segment ids forwarded to every layer.
    deterministic: Forwarded to layers and dropout.
    model_mode: One of the `MODEL_MODE_*` constants.
    previous_chunk: Forwarded to layers that accept it.
    slot: Forwarded to layers that accept it.
    page_state: Forwarded to layers that accept it.
    multimodal_input: Optional object; its `bidirectional_mask` attribute is read
      for Gemma3/Gemma4 and for deepstack processing.
    kv_caches: Optional per-layer caches, updated in place. For the Qwen3-Next /
      Qwen3.5 blocks this is a dict holding `"key_cache"` and `"value_cache"` lists.
    attention_metadata: Forwarded to layers that accept it.
    deepstack_visual_embeds: Optional per-layer visual embeddings merged into `y`
      after the corresponding unscanned layer.

  Returns:
    A tuple `(logits, hidden_state, kv_caches)`. `logits` is None when
    `cfg.attention == "vllm_rpa"`, during the indexer dense warm-up in train
    mode, or when vocab tiling is enabled in train mode.
  """
  cfg = self.config
  mesh = self.mesh
  assert decoder_input_tokens.ndim == 2  # [batch, len]

  # [batch, length] -> [batch, length, emb_dim]
  y = self._apply_embedding(
      shared_embedding,
      decoder_input_tokens,
      decoder_positions,
      deterministic,
      model_mode,
      multimodal_input=multimodal_input,
  )

  mhc_expand, mhc_reduce = mhc.get_functions(cfg.mhc_expansion_rate)
  if cfg.mhc_expansion_rate > 1:
    # (batch, length, emb_dim) --> (batch, length, mhc_expansion_rate, emb_dim)
    y = mhc_expand(y)

  policy = self.get_remat_policy()
  RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy)
  # scan does not support kwargs in layer call, passing broadcast_args as positional arg
  broadcast_args = (
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
  )
  if cfg.using_pipeline_parallelism:
    logical_partition_spec = (
        self.pipeline_module.get_weight_sharding(y, decoder_segment_ids, decoder_positions, deterministic, model_mode)
        if cfg.pipeline_fsdp_ag_once or cfg.pipeline_fsdp_ag_per_repeat
        else None
    )
    if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
      assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
      dense_layer = RemattedBlockLayers[0]
      moe_layer = RemattedBlockLayers[1]
      num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
      num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers
      logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
      # We chose not to pipeline the dense layers, only sparse for SPMD.
      with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
        y, _ = self.scan_decoder_layers(
            cfg,
            dense_layer,
            cfg.first_num_dense_layers,
            "dense_layers",
            mesh,
            in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
            model_mode=model_mode,
        )(y, *broadcast_args)
        if num_moe_layers_outside_pp > 0:
          y, _ = self.scan_decoder_layers(
              cfg,
              moe_layer,
              num_moe_layers_outside_pp,
              "moe_layers",
              mesh,
              in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
              model_mode=model_mode,
          )(y, *broadcast_args)
      y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
    else:  # Not DeepSeek
      y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
      remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
      if remaining_layers > 0:
        logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
        with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
          y, _ = self.scan_decoder_layers(
              cfg,
              RemattedBlockLayers[0],
              remaining_layers,
              "layers_outside_pipeline",
              mesh,
              in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
              model_mode=model_mode,
          )(y, *broadcast_args)
  else:
    if cfg.scan_layers:
      if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
        assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
        # These cannot go through scan as kwargs, so they are bound onto the
        # layer's __call__ with functools.partial instead.
        layer_call_kwargs = {
            "page_state": page_state,
            "previous_chunk": previous_chunk,
            "slot": slot,
        }
        dense_layer = RemattedBlockLayers[0]
        moe_layer = RemattedBlockLayers[1]
        if cfg.engram_layers:
          # Keep the unpatched calls so single (unscanned) Engram layers can use them.
          original_dense_call = dense_layer.__call__
          original_moe_call = moe_layer.__call__
          dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
          moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)

          common_kwargs = {
              "dense_layer": dense_layer,
              "moe_layer": moe_layer,
              "original_dense_call": original_dense_call,
              "original_moe_call": original_moe_call,
              "layer_call_kwargs": layer_call_kwargs,
              "decoder_segment_ids": decoder_segment_ids,
              "decoder_positions": decoder_positions,
              "deterministic": deterministic,
              "model_mode": model_mode,
              "decoder_input_tokens": decoder_input_tokens,
              "broadcast_args": broadcast_args,
          }

          # Apply Dense Layers
          y = self._apply_interleaved_scanned_layers(
              y,
              layer_type="dense",
              start_idx=0,
              end_idx=cfg.first_num_dense_layers,
              engram_indices=cfg.engram_layers,
              **common_kwargs,
          )

          # Apply MoE Layers
          y = self._apply_interleaved_scanned_layers(
              y,
              layer_type="moe",
              start_idx=cfg.first_num_dense_layers,
              end_idx=cfg.num_decoder_layers,
              engram_indices=cfg.engram_layers,
              **common_kwargs,
          )
        else:
          dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
          y, _ = self.scan_decoder_layers(
              cfg,
              dense_layer,
              cfg.first_num_dense_layers,
              "dense_layers",
              mesh,
              in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
              model_mode=model_mode,
          )(y, *broadcast_args)
          moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
          num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers

          # If batch-split schedule is used and initialization is complete,
          # as detected by immutable params, use deepseek_batchsplit custom
          # scan with initialized parameters.
          if cfg.use_batch_split_schedule and not self.is_mutable_collection("params"):
            # old version of batch-split that fully uses qwix quantization.
            if cfg.use_qwix_quantization and not cfg.use_manual_quantization:
              y = deepseek_batchsplit_fp8.scan_batch_split_layers(
                  y,
                  self.variables["params"]["moe_layers"],
                  decoder_positions,
                  decoder_segment_ids,
                  model_mode=model_mode,
                  mesh=mesh,
                  quant=self.quant,
                  cfg=cfg,
                  policy=policy,
              )
            else:
              # bf16 and fp8 code path for pure-JAX batch-split.
              # fp8 code path supports both manual quantization and qwix
              # quantization.
              y = deepseek_batchsplit.scan_batch_split_layers(
                  y,
                  self.variables["params"]["moe_layers"],
                  decoder_positions,
                  mesh=mesh,
                  cfg=cfg,
                  num_layers=num_moe_layers,
              )
          else:
            y, _ = self.scan_decoder_layers(
                cfg,
                moe_layer,
                num_moe_layers,
                "moe_layers",
                mesh,
                in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
                model_mode=model_mode,
            )(y, *broadcast_args)
      elif cfg.decoder_block == DecoderBlockType.GEMMA3:
        bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
        y = self._apply_gemma3_scanned_blocks(
            y,
            decoder_segment_ids,
            decoder_positions,
            deterministic,
            model_mode,
            bidirectional_mask_value,
            previous_chunk,
            page_state,
            slot,
        )
      elif cfg.decoder_block == DecoderBlockType.GEMMA4:
        bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
        y = self._apply_gemma4_scanned_blocks(
            y,
            decoder_segment_ids,
            decoder_positions,
            deterministic,
            model_mode,
            bidirectional_mask_value,
            previous_chunk,
            page_state,
            slot,
        )
      else:
        RemattedBlockLayer = RemattedBlockLayers[0]
        # Integer floor division: same result as int(a / b) for these positive
        # counts, without a round trip through float.
        scan_length = cfg.num_decoder_layers // cfg.inhomogeneous_layer_cycle_interval
        layer_kwargs = {}
        if cfg.decoder_block == DecoderBlockType.LLAMA4:
          layer_kwargs = {
              "nope_layer_interval": self.config.nope_layer_interval,
              "interleave_moe_layer_step": self.config.interleave_moe_layer_step,
          }

        # Update broadcast_args and in_axes_tuple for vLLM RPA
        in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
        current_broadcast_args = list(broadcast_args)
        current_in_axes_tuple = list(in_axes_tuple)

        if kv_caches is not None:
          # Stack kv_caches for scan: [num_layers, ...]
          stacked_kv_cache = jnp.stack(kv_caches, axis=0)

          # We pass (y, stacked_kv_cache, 0) as the carry
          carry = (y, stacked_kv_cache, 0)

          # We don't pass kv_cache as a scanned argument anymore
          # Pass None for previous_chunk, slot, page_state, kv_cache to align with __call__ signature
          current_broadcast_args.extend([None, None, None, None, attention_metadata])
          current_in_axes_tuple.extend([nn.broadcast] * 5)

          max_logging.info(f"DEBUG: len(current_broadcast_args)={len(current_broadcast_args)}")
          max_logging.info(f"DEBUG: current_broadcast_args={[type(a) for a in current_broadcast_args]}")

          final_carry, _ = self.scan_decoder_layers(
              cfg,
              RemattedBlockLayer,
              scan_length,
              "layers",
              mesh,
              in_axes_tuple=tuple(current_in_axes_tuple),
              model_mode=model_mode,
              **layer_kwargs,
          )(carry, *current_broadcast_args)

          y, returned_kv_cache, _ = final_carry

          # Update the list of KV caches from the scanned results
          for i in range(cfg.num_decoder_layers):
            kv_caches[i] = returned_kv_cache[i]
        else:
          # Fallback to old behavior if kv_caches is None (not vLLM RPA)
          current_broadcast_args.append(None)
          current_in_axes_tuple.append(nn.broadcast)

          y, _ = self.scan_decoder_layers(
              cfg,
              RemattedBlockLayer,
              scan_length,
              "layers",
              mesh,
              in_axes_tuple=tuple(current_in_axes_tuple),
              model_mode=model_mode,
              **layer_kwargs,
          )(y, *current_broadcast_args)
    else:
      if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
        assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."
        dense_layer = RemattedBlockLayers[0]
        moe_layer = RemattedBlockLayers[1]

        layers = [dense_layer, moe_layer]
        layer_prefixes = ["dense_layers", "moe_layers"]
        num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
        num_layers_list = [cfg.first_num_dense_layers, num_moe_layers]
        # Iterate over the two layer groups (dense and MoE) and apply layer transformation
        global_layer_idx_offset = 0
        for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
          for index in range(num_layers):
            global_layer_idx = global_layer_idx_offset + index
            # `index` restarts at 0 for the MoE group, so the cache list must be
            # addressed with the global layer index; otherwise MoE layers would
            # read and overwrite the dense layers' cache slots.
            kv_cache = kv_caches[global_layer_idx] if kv_caches is not None else None
            input_tokens = decoder_input_tokens if cfg.engram_layers else None
            y, kv_cache = layer(
                config=cfg,
                mesh=mesh,
                name=f"{layer_prefix}_{index}",
                quant=self.quant,
                model_mode=self.model_mode,
                layer_idx=global_layer_idx,
            )(
                y,
                decoder_segment_ids,
                decoder_positions,
                deterministic,
                model_mode,
                previous_chunk=previous_chunk,
                page_state=page_state,
                slot=slot,
                kv_cache=kv_cache,
                attention_metadata=attention_metadata,
                decoder_input_tokens=input_tokens,
            )
            if kv_caches is not None and kv_cache is not None:
              kv_caches[global_layer_idx] = kv_cache
          global_layer_idx_offset += num_layers
      else:
        for lyr in range(cfg.num_decoder_layers):
          RemattedBlockLayer = RemattedBlockLayers[0]
          layer_kwargs = {}
          layer_call_kwargs = {}
          if cfg.decoder_block == DecoderBlockType.GEMMA3:
            # Gemma3 uses both global and sliding window attention depending on the layer index.
            bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
            layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)}
            layer_call_kwargs = {"bidirectional_mask": bidirectional_mask_value}
          if cfg.decoder_block == DecoderBlockType.GEMMA4:
            # Gemma4 uses both global and sliding window attention depending on the layer index.
            bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
            layer_kwargs = {"attention_type": gemma4.get_attention_type(layer_id=lyr)}
            layer_call_kwargs = {"bidirectional_mask": bidirectional_mask_value}
          if cfg.decoder_block == DecoderBlockType.LLAMA4:
            layer_kwargs = {
                "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval),
                "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step),
            }
          if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
            layer_kwargs = {"layer_idx": lyr}
          kv_cache = None
          if kv_caches is not None and cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
            kv_cache = kv_caches[lyr]
          elif kv_caches is not None and cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
            # For Qwen3Next & Qwen3.5, kv_caches is a dictionary of lists of caches.
            if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
              kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])

          if cfg.decoder_block == DecoderBlockType.GPT_OSS:
            layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
          if cfg.decoder_block == DecoderBlockType.OLMO3:
            layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)}
          layer = RemattedBlockLayer(
              config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
          )
          y, returned_cache = layer(
              y,
              decoder_segment_ids,
              decoder_positions,
              deterministic,
              model_mode,
              previous_chunk=previous_chunk,
              page_state=page_state,
              slot=slot,
              kv_cache=kv_cache,
              attention_metadata=attention_metadata,
              **layer_call_kwargs,
          )
          if kv_caches is not None and returned_cache is not None:
            if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
              kv_caches[lyr] = returned_cache
            elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
              kv_caches["key_cache"][lyr] = returned_cache[0]
              kv_caches["value_cache"][lyr] = returned_cache[1]

          if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
            visual_embeds = deepstack_visual_embeds[lyr]
            # Use bidirectional_mask to identify visual token positions
            bidirectional_mask_value = multimodal_input.bidirectional_mask if multimodal_input is not None else None
            if bidirectional_mask_value is not None and visual_embeds is not None:
              y = deepstack_process(y, bidirectional_mask_value, visual_embeds)

  assert isinstance(y, jax.Array)

  # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
  if cfg.mhc_expansion_rate > 1:
    # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim)
    hidden_state = mhc_reduce(y)
  else:
    hidden_state = y

  # When initializing with vLLM RPA attention, we need to run the output head to
  # initialize any parameters associated with it.
  if self.is_initializing() and cfg.attention == "vllm_rpa":
    _ = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)

  # When invoking from vLLM with RPA attention, logit computation is deferred to a later stage.
  if cfg.attention == "vllm_rpa":
    logits = None

  # When in the Indexer Dense Warm-up stage, skip the expensive output head projection
  # for efficiency, as the main model is frozen and the LM loss is not needed.
  # TODO(b/501446870): Investigate model_mode as train at beginning for decoding stage
  elif (
      cfg.use_indexer and cfg.indexer_loss_scaling_factor > 0.0 and not cfg.indexer_sparse_training
  ) and model_mode == MODEL_MODE_TRAIN:
    logits = None

  # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
  # Instead, we keep track on the hidden states, which has smaller size compared to full logits
  elif cfg.num_vocab_tiling > 1 and model_mode == MODEL_MODE_TRAIN:
    logits = None
    self.sow("intermediates", "hidden_states", hidden_state)

  else:
    logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)

  # The API of the Decoder is now a tuple, providing both the main output
  # and the raw hidden state needed for auxiliary tasks.
  return logits, hidden_state, kv_caches

def _apply_gemma3_scanned_blocks(
    self,
    y,
    decoder_segment_ids,
    decoder_positions,
    deterministic,
    model_mode,
    bidirectional_mask,
    previous_chunk,
    page_state,
    slot,
):
  """Applies Gemma3 scanned decoder blocks, handling main scan and remainders.

  The layers are grouped into blocks of `len(gemma3.GEMMA3_ATTENTION_PATTERN)`
  layers. Full blocks are applied through `scan_decoder_layers`; any leftover
  layers are applied as one extra, unscanned block named "layers_remainder".
  `previous_chunk`, `page_state` and `slot` are only passed to the remainder
  block.

  Returns:
    The hidden states after all Gemma3 layers.
  """

  cfg = self.config
  mesh = self.mesh

  # Define the repeating pattern length and calculate how many full blocks to scan
  attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
  scan_length = cfg.num_decoder_layers // attention_pattern_length

  policy = self.get_remat_policy()
  RemattedGemma3Block = self.set_remat_policy([gemma3.Gemma3ScannableBlockToLinen], policy)[0]

  layer_call_kwargs = {"bidirectional_mask": bidirectional_mask}
  layer_kwargs = {"num_of_layers": attention_pattern_length}

  # Apply the main scan over the full blocks
  if scan_length > 0:
    broadcast_args = (
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        model_mode,
    )
    y, _ = self.scan_decoder_layers(
        cfg,
        RemattedGemma3Block,
        scan_length,
        "layers",
        mesh,
        in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
        model_mode=self.model_mode,
        **layer_kwargs,
    )(y, *broadcast_args, **layer_call_kwargs)

  # Apply any remaining layers that did not fit into a full scanned block
  num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length
  if num_remaining_layers > 0:
    # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions
    rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
    layer = RemattedGemma3Block(
        config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs
    )  # pytype: disable=wrong-keyword-args
    y, _ = layer(
        y,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        model_mode,
        previous_chunk=previous_chunk,
        page_state=page_state,
        slot=slot,
        **layer_call_kwargs,
    )
  return y

def _apply_gemma4_scanned_blocks(
    self,
    y,
    decoder_segment_ids,
    decoder_positions,
    deterministic,
    model_mode,
    bidirectional_mask,
    previous_chunk,
    page_state,
    slot,
):
  """Applies Gemma4 scanned decoder blocks, handling main scan and remainders.

  Full blocks of `len(gemma4.GEMMA4_ATTENTION_PATTERN)` layers are applied with
  `nn.scan`; the trailing layers that do not fill a block are applied one by
  one. Every call receives the same positional `broadcast_args`.

  Returns:
    The hidden states after all Gemma4 layers.
  """

  cfg = self.config
  mesh = self.mesh

  # Define the repeating pattern length and calculate how many full blocks to scan
  block_pattern_len = len(gemma4.GEMMA4_ATTENTION_PATTERN)
  num_full_blocks = cfg.num_decoder_layers // block_pattern_len
  remainder_layers = cfg.num_decoder_layers % block_pattern_len

  broadcast_args = (
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      slot,
      page_state,
      previous_chunk,
      bidirectional_mask,
  )

  if num_full_blocks > 0:
    ScannableBlockToLinen = gemma4.Gemma4ScannableBlockToLinen
    policy = self.get_remat_policy()
    RemattedGemma4Block = self.set_remat_policy([ScannableBlockToLinen], policy)[0]

    # For a fully scanned block, apply it inside a nn.scan over the calculated number of full blocks
    y, _ = nn.scan(
        RemattedGemma4Block,
        variable_axes={
            "params": cfg.param_scan_axis,
            "cache": 0,
            "intermediates": 0,
            "aqt": 0,
            "_overwrite_with_gradient": 0,
        },
        split_rngs={"params": True, "dropout": cfg.enable_dropout},
        in_axes=(nn.broadcast,) * len(broadcast_args),
        length=num_full_blocks,
        metadata_params={
            nn.PARTITION_NAME: "layers",
            "abstract_init": False,
        },
    )(
        config=cfg,
        mesh=mesh,
        quant=self.quant,
        model_mode=model_mode,
        num_of_layers=block_pattern_len,
        name="scanned_blocks",
    )(
        y, *broadcast_args
    )

  # Process any remaining layers that don't fit into a full scanned block
  for layer_id in range(cfg.num_decoder_layers - remainder_layers, cfg.num_decoder_layers):
    attention_type = gemma4.get_attention_type(layer_id)
    layer = gemma4.Gemma4DecoderLayerToLinen(
        config=cfg,
        mesh=mesh,
        model_mode=model_mode,
        quant=self.quant,
        attention_type=attention_type,
        layer_idx=layer_id,
    )
    y = layer(y, *broadcast_args)
    # With scan_layers set, the layer output is indexed to keep only its first element.
    if cfg.scan_layers:
      y = y[0]

  return y

# TODO(b/490118813): Relocate the following functions to their designated directories
# once the plug-in strategy is implemented: _find_next_boundary(), _apply_single_engram_layer()
# _apply_scanned_chunk() and _apply_interleaved_scanned_layers().
def _find_next_boundary(self, current_idx, end_idx, engram_indices):
  """Finds the next index boundary, either the next Engram layer index or the overall end index.

  Args:
    current_idx: Index the search starts from (exclusive).
    end_idx: Upper bound returned when no earlier Engram index exists.
    engram_indices: Iterable of Engram layer indices.

  Returns:
    The smallest Engram index strictly greater than `current_idx`, capped at `end_idx`.
  """
  next_engrams = [engram_idx for engram_idx in engram_indices if engram_idx > current_idx]
  if next_engrams:
    return min(end_idx, *next_engrams)
  return end_idx

def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
  """Applies a single, unscanned Engram layer.

  The layer's `__call__` is temporarily reset to the unpatched original so the
  call kwargs can be passed explicitly together with `decoder_input_tokens`;
  the partial-bound `__call__` used by the scanned chunks is restored afterwards.

  Args:
    y: Hidden states entering the layer.
    current_idx: Global layer index, used for the module name and `layer_idx`.
    layer_type: "dense" selects the dense layer; anything else selects the MoE layer.
    **kwargs: The `common_kwargs` dict built by the caller.

  Returns:
    The hidden states produced by the layer.
  """
  layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
  layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
  original_call = kwargs["original_dense_call"] if layer_type == "dense" else kwargs["original_moe_call"]
  layer_call_kwargs = kwargs["layer_call_kwargs"]

  layer.__call__ = original_call
  try:
    y, _ = layer(
        config=self.config,
        mesh=self.mesh,
        name=f"{layer_prefix}_engram_{current_idx}",
        quant=self.quant,
        model_mode=self.model_mode,
        layer_idx=current_idx,
    )(
        y,
        kwargs["decoder_segment_ids"],
        kwargs["decoder_positions"],
        kwargs["deterministic"],
        kwargs["model_mode"],
        decoder_input_tokens=kwargs["decoder_input_tokens"],
        **layer_call_kwargs,
    )
  finally:
    # Always restore the patched call, even if the layer raised, so the layer
    # class is not left in a half-patched state.
    layer.__call__ = functools.partial(original_call, **layer_call_kwargs)
  return y

def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_type, **kwargs):
  """Applies a contiguous chunk of layers using the scan operation.

  Scans layers `[current_idx, next_boundary)`; an empty range leaves `y` untouched.
  The scan is named `"{prefix}_{current_idx}_{next_boundary - 1}"`.
  """
  layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
  layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
  broadcast_args = kwargs["broadcast_args"]
  scan_length = next_boundary - current_idx

  if scan_length > 0:
    y, _ = self.scan_decoder_layers(
        self.config,
        layer,
        scan_length,
        f"{layer_prefix}_{current_idx}_{next_boundary - 1}",
        self.mesh,
        in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
        model_mode=kwargs["model_mode"],
    )(y, *broadcast_args)
  return y

def _apply_interleaved_scanned_layers(self, y, layer_type, start_idx, end_idx, engram_indices, **kwargs):
  """Applies a mix of scanned standard layers and unscanned Engram layers.

  Walks the layer indices `[start_idx, end_idx)`: an index listed in
  `engram_indices` is applied as a single unscanned layer, and each run of
  indices between Engram layers is applied as one scanned chunk.
  """
  current_idx = start_idx
  while current_idx < end_idx:
    if current_idx in engram_indices:
      # Handle individual unscanned Engram layer
      y = self._apply_single_engram_layer(y, current_idx, layer_type, **kwargs)
      current_idx += 1
    else:
      # Find next boundary and scan the chunk
      next_boundary = self._find_next_boundary(current_idx, end_idx, engram_indices)
      y = self._apply_scanned_chunk(y, current_idx, next_boundary, layer_type, **kwargs)
      current_idx = next_boundary
  return y