maxtext.layers.attention_mla module
MLA Attention Layer.
- class maxtext.layers.attention_mla.Indexer(*args, **kwargs)
Bases: Module
Indexer for DeepSeek Sparse Attention (DSA).
This module implements the sparse attention indexer introduced in DeepSeek V3.2. It computes relevance scores to select the top-k most relevant tokens for attention.
References
DeepSeek-AI, "DeepSeek-V3.2: Pushing the Frontier of Open Large Language Models", https://arxiv.org/pdf/2512.02556, 2026
Implementation: deepseek-ai/DeepSeek-V3.2-Exp
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
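The top-k selection described above can be illustrated with a minimal sketch. It assumes the indexer has already produced scores of shape [batch, q_len, kv_len]; the function name `select_top_k` and the mask construction are hypothetical and are not taken from the MaxText source. Causal and segment masking are omitted.

```python
import jax
import jax.numpy as jnp


def select_top_k(indexer_score, k):
  """Hypothetical helper: picks the k highest-scoring keys for each query."""
  # indexer_score: [batch, q_len, kv_len]
  _, top_idx = jax.lax.top_k(indexer_score, k)  # [batch, q_len, k]
  kv_len = indexer_score.shape[-1]
  # Boolean mask [batch, q_len, kv_len]: True where a key was selected.
  mask = (top_idx[..., None] == jnp.arange(kv_len)).any(axis=-2)
  return top_idx, mask


scores = jnp.array([[[0.1, 0.9, 0.3, 0.7]]])
idx, mask = select_top_k(scores, k=2)
```

In a sparse attention layer, a mask of this kind would restrict the main attention to the selected keys.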
- update_indexer_cache(kv_cache, k, decoder_segment_ids, model_mode, previous_chunk)
Updates Indexer buffers by processing KV cache results.
- apply_partial_rope(inputs, inputs_positions=None)
Applies partial RoPE to the indexer query or key.
The Indexer’s RoPE implementation differs from MLA’s in two key aspects:
1. Split Order: Indexer splits the head dimension into [rope, nope], whereas MLA uses [nope, rope].
2. Input Layout: Indexer uses a concatenated layout (interleave=False), whereas MLA uses an interleaved layout (interleave=True).
- Parameters:
inputs (Array) – Input array of shape [batch, seqlen, indexer_n_heads, indexer_head_dim].
inputs_positions (Array | None) – Position array of shape [batch, seqlen].
- Returns:
Array with partial RoPE applied, with shape [batch, seqlen, indexer_n_heads, indexer_head_dim]
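The [rope, nope] split and the concatenated (non-interleaved) layout can be sketched as follows. This is an illustration under stated assumptions, not the MaxText implementation: the function name, the `rope_dim` argument, and the frequency base of 10000 are all hypothetical.

```python
import jax.numpy as jnp


def apply_partial_rope_sketch(inputs, positions, rope_dim, base=10000.0):
  """Rotates the first `rope_dim` features and leaves the rest untouched."""
  # [rope, nope] split order: the rotary part comes first.
  x_rope, x_nope = inputs[..., :rope_dim], inputs[..., rope_dim:]
  half = rope_dim // 2
  inv_freq = 1.0 / (base ** (jnp.arange(half) / half))  # [half]
  # angles: [batch, seqlen, 1, half], broadcast over the head axis.
  angles = positions[..., None, None] * inv_freq
  cos, sin = jnp.cos(angles), jnp.sin(angles)
  # Concatenated layout: feature i is paired with feature i + half.
  x1, x2 = x_rope[..., :half], x_rope[..., half:]
  rotated = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
  return jnp.concatenate([rotated, x_nope], axis=-1).astype(inputs.dtype)


x = jnp.ones((2, 5, 4, 16))  # [batch, seqlen, indexer_n_heads, indexer_head_dim]
pos = jnp.broadcast_to(jnp.arange(5), (2, 5))
y = apply_partial_rope_sketch(x, pos, rope_dim=8)
```

An interleaved layout would instead pair adjacent features (2i, 2i + 1), which is why the two layouts cannot share weights without a permutation.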
- maxtext.layers.attention_mla.mla_as_linen(*, config, num_query_heads, num_kv_heads, head_dim, max_target_length, mesh, attention_kernel, inputs_q_shape, inputs_kv_shape, dtype=<class 'jax.numpy.float32'>, weight_dtype=<class 'jax.numpy.float32'>, max_prefill_predict_length=-1, dropout_rate=0.0, kernel_init=<function nd_dense_init.<locals>.init_fn>, float32_qk_product=False, float32_logits=False, quant=None, kv_quant=None, attention_type=AttentionType.MLA, attn_logits_soft_cap=None, sliding_window_size=None, use_ragged_attention=False, ragged_block_size=256, use_qk_norm=False, query_pre_attn_scalar=None, use_bias_in_projections=False, temperature_tuning=False, temperature_tuning_scale=0.1, temperature_tuning_floor_scale=8192.0, prefill_query_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_key_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_value_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), query_axis_names=('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), key_axis_names=('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), value_axis_names=('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), input_axis_names=('activation_batch_attn', 'activation_length', 'activation_embed'), out_axis_names=('activation_batch_attn', 'activation_length', 'activation_heads', 'activation_kv'), prefill_input_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_embed'), decode_input_axis_names=('decode_batch', 'decode_length', 'activation_embed'), prefill_out_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_heads', 'activation_kv'), 
decode_out_axis_names=('decode_batch', 'decode_length', 'activation_heads', 'activation_kv'), prefill_cache_axis_order=(1, 2, 0, 3), ar_cache_axis_order=(1, 2, 0, 3), compute_axis_order=(0, 1, 2, 3), reshape_q=False, is_nope_layer=False, is_vision=False, model_mode='train', q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, max_position_embeddings=16384, original_max_position_embeddings=4096, mscale=1.0, rope_factor=40.0, name=None)
A factory function to create an MLA as a Linen module.
This function serves as a bridge to use the NNX-based MLA within a Linen model.
- Parameters:
config (Any)
num_query_heads (int)
num_kv_heads (int)
head_dim (int)
max_target_length (int)
mesh (Mesh)
attention_kernel (str)
inputs_q_shape (Tuple)
inputs_kv_shape (Tuple)
dtype (dtype)
weight_dtype (dtype)
max_prefill_predict_length (int)
dropout_rate (float)
kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array])
float32_qk_product (bool)
float32_logits (bool)
quant (AqtQuantization | None)
kv_quant (KVQuant | None)
attention_type (AttentionType)
attn_logits_soft_cap (float | None)
sliding_window_size (int | None)
use_ragged_attention (bool)
ragged_block_size (int)
use_qk_norm (bool)
query_pre_attn_scalar (float | None)
use_bias_in_projections (bool)
temperature_tuning (bool)
temperature_tuning_scale (float)
temperature_tuning_floor_scale (float)
prefill_query_axis_names (tuple[str, ...])
prefill_key_axis_names (tuple[str, ...])
prefill_value_axis_names (tuple[str, ...])
query_axis_names (tuple[str, ...])
key_axis_names (tuple[str, ...])
value_axis_names (tuple[str, ...])
input_axis_names (tuple[str, ...])
out_axis_names (tuple[str, ...])
prefill_input_axis_names (tuple[str, ...])
decode_input_axis_names (tuple[str, ...])
prefill_out_axis_names (tuple[str, ...])
decode_out_axis_names (tuple[str, ...])
prefill_cache_axis_order (tuple[int, ...])
ar_cache_axis_order (tuple[int, ...])
compute_axis_order (tuple[int, ...])
reshape_q (bool)
is_nope_layer (bool)
is_vision (bool)
model_mode (str)
q_lora_rank (int)
kv_lora_rank (int)
qk_nope_head_dim (int)
qk_rope_head_dim (int)
v_head_dim (int)
max_position_embeddings (int)
original_max_position_embeddings (int)
mscale (float)
rope_factor (float)
name (str | None)
- class maxtext.layers.attention_mla.MLA(*args, **kwargs)
Bases: Attention
Multi-Head Latent Attention (MLA) layer.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- init_indexer_cache(inputs_kv_shape)
Initializes Indexer Cache.
- Parameters:
inputs_kv_shape (Tuple)
- property out_head_dim: int
- mla_query_projection(inputs_q, inputs_positions, model_mode)
Query projection for MLA; uses a low-rank (LoRA) projection if q_lora_rank > 0.
- Parameters:
inputs_q (Array)
inputs_positions (Array)
- Return type:
tuple[Array, Array | None]
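A rough sketch of how a low-rank query projection can be switched on by q_lora_rank is given below. The weight names (`wq`, `wq_a`, `wq_b`), the RMS normalization of the latent, and the single-array return value are assumptions made for illustration; the actual method also applies RoPE and returns a tuple, which this sketch does not reproduce.

```python
import jax
import jax.numpy as jnp


def query_projection_sketch(x, params, q_lora_rank, num_heads, head_dim):
  """Projects x [batch, seqlen, embed] to [batch, seqlen, num_heads, head_dim]."""
  if q_lora_rank > 0:
    # Low-rank path: embed -> q_lora_rank -> num_heads * head_dim.
    low_rank = x @ params["wq_a"]
    # RMS normalization of the latent (assumed; learnable scale omitted).
    mean_sq = jnp.mean(low_rank**2, axis=-1, keepdims=True)
    low_rank = low_rank * jax.lax.rsqrt(mean_sq + 1e-6)
    q = low_rank @ params["wq_b"]
  else:
    # Dense path: a single embed -> num_heads * head_dim projection.
    q = x @ params["wq"]
  return q.reshape(*x.shape[:-1], num_heads, head_dim)


key_a, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
embed, rank, heads, dim = 32, 8, 4, 16
params = {
    "wq_a": jax.random.normal(key_a, (embed, rank)),
    "wq_b": jax.random.normal(key_b, (rank, heads * dim)),
}
x = jax.random.normal(key_x, (2, 5, embed))
q = query_projection_sketch(x, params, q_lora_rank=rank, num_heads=heads, head_dim=dim)
```

With these toy sizes the low-rank path holds 32 * 8 + 8 * 64 = 768 weights, compared with 32 * 64 = 2048 for the dense path.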
- init_mla_kv_caches(inputs_kv_shape)
Initializes MlaKVCache.
- Parameters:
inputs_kv_shape (Tuple) – Key/value inputs shape for initialization.
- Returns:
An MlaKVCache module instance.
- Raises:
ValueError – If the configuration is invalid.
- update_mla_kv_caches(low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk=None)
Updates the MLA (Multi-Head Latent Attention) KV caches.
This method is specific to the MLA attention mechanism. It calls the mla_kv_cache_as_linen module to update and retrieve the caches, which store latent representations (low_rank_main) and RoPE-applied keys (key_rope). It then reconstructs the full key and value tensors from the cached components.
- Parameters:
low_rank_main – The main latent component of the key.
key_rope – The RoPE-applied component of the key.
decoder_segment_ids – Segment IDs for decoder masking.
model_mode – The operational mode (‘train’, ‘prefill’, ‘autoregressive’).
previous_chunk – Information about previously processed chunks, for chunked prefill.
- Returns:
A list containing two elements:
The prefill key-value cache, reconstructed from the MLA cache, or None.
The autoregressive key-value cache, reconstructed from the MLA cache, or None.
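The reconstruction of full key and value tensors from the cached components can be sketched as below. The up-projection weights `wk_up` and `wv_up`, their shapes, and the function name are hypothetical; the sketch only assumes that the latent is expanded per head and that the RoPE key is shared across heads and appended in MLA's [nope, rope] order.

```python
import jax
import jax.numpy as jnp


def reconstruct_kv_sketch(low_rank_main, key_rope, wk_up, wv_up):
  """Rebuilds per-head keys and values from a latent and a shared RoPE key."""
  # low_rank_main: [batch, kv_len, kv_lora_rank]
  # key_rope:      [batch, kv_len, 1, qk_rope_head_dim]
  # wk_up:         [kv_lora_rank, heads, qk_nope_head_dim]
  # wv_up:         [kv_lora_rank, heads, v_head_dim]
  key_nope = jnp.einsum("bsr,rhd->bshd", low_rank_main, wk_up)
  value = jnp.einsum("bsr,rhd->bshd", low_rank_main, wv_up)
  heads = key_nope.shape[2]
  # Broadcast the single RoPE key to every head.
  rope = jnp.broadcast_to(key_rope, key_rope.shape[:2] + (heads, key_rope.shape[-1]))
  # MLA split order: [nope, rope].
  key = jnp.concatenate([key_nope, rope], axis=-1)
  return key, value


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
latent = jax.random.normal(k1, (2, 6, 8))
rope_key = jax.random.normal(k2, (2, 6, 1, 8))
wk_up = jax.random.normal(k3, (8, 4, 16))
wv_up = jax.random.normal(k4, (8, 4, 16))
key, value = reconstruct_kv_sketch(latent, rope_key, wk_up, wv_up)
```

Caching only the latent and the shared RoPE key stores 8 + 8 values per token in this toy setup, instead of 4 * (24 + 16) for the full per-head keys and values.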
- mla_kv_projection(inputs, inputs_positions, decoder_segment_ids, model_mode, previous_chunk)
MLA key/value projection with integrated rotary embedding.
- Parameters:
inputs (Array)
inputs_positions (Array)
- calculate_indexer_loss(indexer_score, query, key, attention_mask, indexer_mask, sparse_loss, scaling_factor)
Calculates the indexer KL divergence loss.
This loss trains the indexer to predict which tokens are important by matching the distribution of true attention scores from the main model.
The target distribution is derived through the following steps:
1. Compute raw attention scores via Q @ K^T and apply softmax.
2. Aggregate scores by summing across all attention heads.
3. Apply L1-normalization across the sequence dimension.
target_distribution = L1_Normalize(Sum_h(Softmax(Q @ K^T)))
Reference: DeepSeek-V3.2 - https://arxiv.org/pdf/2512.02556
- Parameters:
indexer_score (Array) – Scores predicted by indexer [batch, q_len, kv_len].
query (Array) – Query tensor from main model [batch, q_len, heads, dim].
key (Array) – Key tensor from main model [batch, kv_len, heads, dim].
attention_mask (Array | None) – Attention mask [batch, q_len, kv_len] or None.
indexer_mask (Array) – Indexer mask [batch, q_len, kv_len].
sparse_loss (bool) – Whether to use sparse loss.
scaling_factor (float) – The scaling factor for the loss.
- Returns:
The computed KL divergence loss.
- Return type:
Array
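The target distribution and KL loss described for calculate_indexer_loss can be sketched as follows. This is a simplified illustration, not the MaxText implementation: it uses a single boolean mask in place of the separate attention_mask and indexer_mask, omits sparse_loss and scaling_factor, and all function names plus the constants NEG and eps are hypothetical.

```python
import jax
import jax.numpy as jnp

NEG = -1e9  # assumed large negative fill value for masked positions


def target_distribution(query, key, mask, eps=1e-9):
  """L1_Normalize(Sum_h(Softmax(Q @ K^T))) -> [batch, q_len, kv_len]."""
  scores = jnp.einsum("bqhd,bkhd->bhqk", query, key)
  scores = jnp.where(mask[:, None], scores, NEG)
  probs = jax.nn.softmax(scores, axis=-1)
  summed = probs.sum(axis=1)  # sum across heads
  return summed / (summed.sum(axis=-1, keepdims=True) + eps)


def indexer_kl_loss_sketch(indexer_score, query, key, mask, eps=1e-9):
  """KL(target || softmax(indexer_score)), averaged over batch and queries."""
  target = target_distribution(query, key, mask, eps)
  log_pred = jax.nn.log_softmax(jnp.where(mask, indexer_score, NEG), axis=-1)
  kl = jnp.where(mask, target * (jnp.log(target + eps) - log_pred), 0.0)
  return kl.sum(axis=-1).mean()


kq, kk, ks = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (2, 4, 3, 8))  # [batch, q_len, heads, dim]
key = jax.random.normal(kk, (2, 4, 3, 8))    # [batch, kv_len, heads, dim]
causal = jnp.arange(4)[:, None] >= jnp.arange(4)[None, :]
mask = jnp.broadcast_to(causal, (2, 4, 4))
loss = indexer_kl_loss_sketch(jax.random.normal(ks, (2, 4, 4)), query, key, mask)
perfect = indexer_kl_loss_sketch(
    jnp.log(target_distribution(query, key, mask) + 1e-9), query, key, mask)
```

When the indexer scores already equal the log of the target distribution, the loss is approximately zero, which is the fixed point the training signal pushes toward.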