maxtext.layers.attention_mla module

Multi-Head Latent Attention (MLA) layer.

class maxtext.layers.attention_mla.Indexer(*args, **kwargs)

Bases: Module

Indexer for DeepSeek Sparse Attention (DSA).

This module implements the sparse attention indexer introduced in DeepSeek-V3.2. For each query token, it computes relevance scores over the key tokens and selects the top-k most relevant ones, so that the main attention only attends to that subset.
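The following is a minimal NumPy sketch of how top-k selection can work, assuming the index-score formulation described in the DeepSeek-V3.2 paper: I[t, s] = sum over indexer heads j of w[t, j] * ReLU(q[t, j] · k[s]). It is not this module's implementation. The function name `indexer_topk` is hypothetical, the batch dimension is dropped for clarity, and details such as scaling, caching, or quantization in the real Indexer are not modeled.

```python
import numpy as np


def indexer_topk(q_idx, k_idx, w_idx, topk):
    # q_idx: [t, h, d] indexer queries; k_idx: [s, d] indexer keys shared by all
    # indexer heads; w_idx: [t, h] per-query head weights. Batch omitted.
    # Per-head dot products passed through ReLU: [t, h, s]
    logits = np.maximum(np.einsum("thd,sd->ths", q_idx, k_idx), 0.0)
    # Weighted sum over indexer heads gives the index score I[t, s]: [t, s]
    scores = np.einsum("th,ths->ts", w_idx, logits)
    # Causal mask: query t may only select keys s <= t
    t, s = scores.shape
    scores = np.where(np.tril(np.ones((t, s), dtype=bool)), scores, -np.inf)
    # Top-k key indices per query, highest score first. Early queries with fewer
    # than topk visible keys will also pick masked (-inf) positions.
    topk_indices = np.argsort(-scores, axis=-1)[:, :topk]
    return scores, topk_indices


rng = np.random.default_rng(0)
scores, idx = indexer_topk(rng.normal(size=(4, 2, 8)), rng.normal(size=(4, 8)),
                           rng.random((4, 2)), topk=2)
print(idx)
```

The selected indices are what a mask-building step (such as generate_mask) would turn into an attention mask; a JAX version would more likely use `jax.lax.top_k` than a full sort.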

References

DeepSeek-AI, DeepSeek-V3.2: Pushing the Frontier of Open Large Language Models

Implementation: deepseek-ai/DeepSeek-V3.2-Exp

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

update_indexer_cache(kv_cache, k, decoder_segment_ids, model_mode, previous_chunk)

Updates Indexer buffers by processing KV cache results.

apply_partial_rope(inputs, inputs_positions=None)

Applies partial RoPE to the indexer query or key.

The Indexer’s RoPE implementation differs from MLA’s in two key aspects:

  1. Split order: the Indexer splits the head dimension into [rope, nope], whereas MLA uses [nope, rope].

  2. Input layout: the Indexer uses the concatenated layout (interleave=False), whereas MLA uses the interleaved layout (interleave=True).
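To illustrate the two differences, here is a NumPy sketch of partial RoPE in the [rope, nope] split with the concatenated (rotate-half) layout. It assumes standard RoPE frequencies with theta=10000; the function name and the `rope_dim` and `theta` parameters are hypothetical, and any frequency scaling the real module applies is not modeled.

```python
import numpy as np


def apply_partial_rope_sketch(x, positions, rope_dim, theta=10000.0):
    # x: [batch, seqlen, heads, head_dim]; positions: [batch, seqlen].
    # Indexer split order (as documented above): [rope | nope].
    x_rope, x_nope = x[..., :rope_dim], x[..., rope_dim:]
    half = rope_dim // 2
    # Standard RoPE inverse frequencies, one per rotated channel pair
    inv_freq = 1.0 / theta ** (np.arange(half) / half)
    angles = positions[..., None] * inv_freq        # [batch, seqlen, half]
    sin = np.sin(angles)[:, :, None, :]              # broadcast over heads
    cos = np.cos(angles)[:, :, None, :]
    # Concatenated (non-interleaved) layout: channel i pairs with channel i + half.
    # An interleaved layout would instead pair channels 2i and 2i + 1.
    x1, x2 = x_rope[..., :half], x_rope[..., half:]
    rotated = np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    # Non-RoPE channels pass through unchanged
    return np.concatenate([rotated, x_nope], axis=-1)


rng = np.random.default_rng(0)
x = rng.normal(size=(1, 3, 2, 8))
out = apply_partial_rope_sketch(x, np.array([[0, 1, 2]]), rope_dim=4)
```

For MLA's layout, the same rotation would apply to the trailing channels ([nope, rope]) and pair adjacent channels instead.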

Parameters:
  • inputs (Array) – Input array of shape [batch, seqlen, indexer_n_heads, indexer_head_dim].

  • inputs_positions (Array | None) – Position array of shape [batch, seqlen].

Returns:

Array with partial RoPE applied, with shape [batch, seqlen, indexer_n_heads, indexer_head_dim]

generate_mask(topk_indices, s)

Creates an additive attention mask that keeps only the top-k indices.

Parameters:
  • topk_indices – Integer array of shape [b, t, k]; the indices to keep.

  • s (int) – The total size to select from.

Returns:

mask – Array of shape [b, t, s] containing 0.0 at topk_indices and DEFAULT_MASK_VALUE (a large negative value) elsewhere.
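A NumPy sketch of the documented behavior follows. The constant below is only a stand-in for MaxText's DEFAULT_MASK_VALUE (any large negative float has the same effect), and the function name is hypothetical.

```python
import numpy as np

# Stand-in for MaxText's DEFAULT_MASK_VALUE: any large negative float works here.
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)


def generate_mask_sketch(topk_indices, s):
    # topk_indices: int [b, t, k]; returns an additive float32 mask [b, t, s]
    b, t, _ = topk_indices.shape
    mask = np.full((b, t, s), DEFAULT_MASK_VALUE, dtype=np.float32)
    # Scatter 0.0 into each selected position along the last axis
    np.put_along_axis(mask, topk_indices, 0.0, axis=-1)
    return mask


mask = generate_mask_sketch(np.array([[[0, 2], [1, 1]]]), s=4)
```

Because the mask is additive, adding it to the attention logits before the softmax drives the weights of unselected keys to effectively zero. Duplicate indices (as in the second row above) are harmless.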

maxtext.layers.attention_mla.mla_as_linen(*, config, num_query_heads, num_kv_heads, head_dim, max_target_length, mesh, attention_kernel, inputs_q_shape, inputs_kv_shape, dtype=<class 'jax.numpy.float32'>, weight_dtype=<class 'jax.numpy.float32'>, max_prefill_predict_length=-1, dropout_rate=0.0, kernel_init=<function nd_dense_init.<locals>.init_fn>, float32_qk_product=False, float32_logits=False, quant=None, kv_quant=None, attention_type=AttentionType.MLA, attn_logits_soft_cap=None, sliding_window_size=None, use_ragged_attention=False, ragged_block_size=256, use_qk_norm=False, query_pre_attn_scalar=None, use_bias_in_projections=False, temperature_tuning=False, temperature_tuning_scale=0.1, temperature_tuning_floor_scale=8192.0, prefill_query_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_key_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_value_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), query_axis_names=('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), key_axis_names=('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), value_axis_names=('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), input_axis_names=('activation_batch_attn', 'activation_length', 'activation_embed'), out_axis_names=('activation_batch_attn', 'activation_length', 'activation_heads', 'activation_kv'), prefill_input_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_embed'), decode_input_axis_names=('decode_batch', 'decode_length', 'activation_embed'), prefill_out_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_heads', 'activation_kv'), decode_out_axis_names=('decode_batch', 'decode_length', 'activation_heads', 'activation_kv'), prefill_cache_axis_order=(1, 2, 0, 3), ar_cache_axis_order=(1, 2, 0, 3), compute_axis_order=(0, 1, 2, 3), reshape_q=False, is_nope_layer=False, is_vision=False, model_mode='train', q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, max_position_embeddings=16384, original_max_position_embeddings=4096, mscale=1.0, rope_factor=40.0, name=None)

A factory function that creates an MLA layer as a Linen module.

It serves as a bridge for using the NNX-based MLA within a Linen model.

Parameters:
  • config (Any)

  • num_query_heads (int)

  • num_kv_heads (int)

  • head_dim (int)

  • max_target_length (int)

  • mesh (Mesh)

  • attention_kernel (str)

  • inputs_q_shape (Tuple)

  • inputs_kv_shape (Tuple)

  • dtype (dtype)

  • weight_dtype (dtype)

  • max_prefill_predict_length (int)

  • dropout_rate (float)

  • kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array])

  • float32_qk_product (bool)

  • float32_logits (bool)

  • quant (AqtQuantization | None)

  • kv_quant (KVQuant | None)

  • attention_type (AttentionType)

  • attn_logits_soft_cap (float | None)

  • sliding_window_size (int | None)

  • use_ragged_attention (bool)

  • ragged_block_size (int)

  • use_qk_norm (bool)

  • query_pre_attn_scalar (float | None)

  • use_bias_in_projections (bool)

  • temperature_tuning (bool)

  • temperature_tuning_scale (float)

  • temperature_tuning_floor_scale (float)

  • prefill_query_axis_names (tuple[str, ...])

  • prefill_key_axis_names (tuple[str, ...])

  • prefill_value_axis_names (tuple[str, ...])

  • query_axis_names (tuple[str, ...])

  • key_axis_names (tuple[str, ...])

  • value_axis_names (tuple[str, ...])

  • input_axis_names (tuple[str, ...])

  • out_axis_names (tuple[str, ...])

  • prefill_input_axis_names (tuple[str, ...])

  • decode_input_axis_names (tuple[str, ...])

  • prefill_out_axis_names (tuple[str, ...])

  • decode_out_axis_names (tuple[str, ...])

  • prefill_cache_axis_order (tuple[int, ...])

  • ar_cache_axis_order (tuple[int, ...])

  • compute_axis_order (tuple[int, ...])

  • reshape_q (bool)

  • is_nope_layer (bool)

  • is_vision (bool)

  • model_mode (str)

  • q_lora_rank (int)

  • kv_lora_rank (int)

  • qk_nope_head_dim (int)

  • qk_rope_head_dim (int)

  • v_head_dim (int)

  • max_position_embeddings (int)

  • original_max_position_embeddings (int)

  • mscale (float)

  • rope_factor (float)

  • name (str | None)

class maxtext.layers.attention_mla.MLA(*args, **kwargs)

Bases: Attention

Multi-Head Latent Attention (MLA) layer.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

init_indexer_cache(inputs_kv_shape)

Initializes the Indexer cache.

Parameters:

inputs_kv_shape (Tuple)

property out_head_dim: int

mla_query_projection(inputs_q, inputs_positions, model_mode)

Query projection for MLA. When q_lora_rank > 0, the projection includes a low-rank (LoRA-style) factorization.
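As a rough illustration of the q_lora_rank > 0 path, the NumPy sketch below assumes the DeepSeek-style recipe of down-projecting to a rank-r latent, normalizing it, and up-projecting to all heads before splitting each head into [nope, rope] parts. The weight names, the scale-free RMSNorm, and the function name are assumptions, not this module's parameters.

```python
import numpy as np


def rms_norm(x, eps=1e-6):
    # RMSNorm without a learned scale, for brevity
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


def mla_query_sketch(x, w_q_down, w_q_up, n_heads, nope_dim, rope_dim):
    # x: [b, t, emb]; w_q_down: [emb, q_lora_rank];
    # w_q_up: [q_lora_rank, n_heads * (nope_dim + rope_dim)]  (hypothetical names)
    # Low-rank path: down-project, normalize, then up-project to all heads
    q = rms_norm(x @ w_q_down) @ w_q_up
    q = q.reshape(*x.shape[:-1], n_heads, nope_dim + rope_dim)
    # MLA split order is [nope, rope]; RoPE would then be applied to q_rope only
    return q[..., :nope_dim], q[..., nope_dim:]


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 64))
q_nope, q_rope = mla_query_sketch(x, rng.normal(size=(64, 16)),
                                  rng.normal(size=(16, 4 * (8 + 4))), 4, 8, 4)
```

With q_lora_rank == 0, the two matrices would collapse into a single dense projection from the embedding to all heads.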

Parameters:
  • inputs_q (Array)

  • inputs_positions (Array)

Return type:

tuple[Array, Array | None]

mla_get_key_value(low_rank_main, key_rope, model_mode)

Reconstructs the (key, value) pair from the MLA latent representation.
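A NumPy sketch of the reconstruction, assuming the latent low_rank_main is up-projected into per-head no-RoPE keys and values, and a single key_rope shared by all heads is concatenated in MLA's [nope, rope] order. The weight name `w_kv_up` and the function name are hypothetical.

```python
import numpy as np


def mla_key_value_sketch(low_rank_main, key_rope, w_kv_up, n_heads, nope_dim, v_dim):
    # low_rank_main: [b, s, kv_lora_rank] latent (assumed already normalized)
    # key_rope: [b, s, 1, rope_dim], a single RoPE key shared by all heads
    # w_kv_up: [kv_lora_rank, n_heads * (nope_dim + v_dim)]  (hypothetical name)
    kv = low_rank_main @ w_kv_up
    kv = kv.reshape(*low_rank_main.shape[:-1], n_heads, nope_dim + v_dim)
    key_nope, value = kv[..., :nope_dim], kv[..., nope_dim:]
    # Broadcast the shared RoPE key to every head; MLA key layout is [nope, rope]
    key_rope = np.broadcast_to(key_rope, key_nope.shape[:-1] + key_rope.shape[-1:])
    key = np.concatenate([key_nope, key_rope], axis=-1)
    return key, value


rng = np.random.default_rng(0)
latent = rng.normal(size=(1, 6, 32))
k_rope = rng.normal(size=(1, 6, 1, 4))
key, value = mla_key_value_sketch(latent, k_rope, rng.normal(size=(32, 3 * (8 + 8))), 3, 8, 8)
```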

init_mla_kv_caches(inputs_kv_shape)

Initializes MlaKVCache.

Parameters:

inputs_kv_shape (Tuple) – Key/value inputs shape for initialization.

Returns:

An MlaKVCache module instance.

Raises:

ValueError – If the configuration is invalid.

update_mla_kv_caches(low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk=None)

Updates the MLA (Multi-Head Latent Attention) KV caches.

This method is specific to the MLA attention mechanism. It calls the mla_kv_cache_as_linen module to update and retrieve the caches, which store latent representations (low_rank_main) and RoPE-applied keys (key_rope). It then reconstructs the full key and value tensors from the cached components.
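To make the benefit of caching latents concrete, the arithmetic below counts cached elements per token per layer. It uses the defaults from the mla_as_linen signature (kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128, v_head_dim=128). The head count of 128 is an assumption, and the "full per-head keys and values" baseline is a hypothetical comparison point, not something this module computes.

```python
def per_token_cache_elements(n_heads, qk_nope_head_dim=128, qk_rope_head_dim=64,
                             v_head_dim=128, kv_lora_rank=512):
    # MLA caches the latent (low_rank_main) plus one shared RoPE key (key_rope)
    mla = kv_lora_rank + qk_rope_head_dim
    # Baseline: cache full per-head keys ([nope, rope]) and values instead
    mha = n_heads * ((qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)
    return mla, mha


mla, mha = per_token_cache_elements(n_heads=128)
print(mla, mha, round(mha / mla, 1))  # 576 40960 71.1
```

The trade-off is that full keys and values must be reconstructed from the cache, which is the step this method performs after updating it.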

Parameters:
  • low_rank_main – The main latent component of the key.

  • key_rope – The RoPE-applied component of the key.

  • decoder_segment_ids – Segment IDs for decoder masking.

  • model_mode – The operational mode (‘train’, ‘prefill’, ‘autoregressive’).

  • previous_chunk – Information about previously processed chunks, for chunked prefill.

Returns:

A list containing two elements:

  • The prefill key-value cache, reconstructed from the MLA cache, or None.

  • The autoregressive key-value cache, reconstructed from the MLA cache, or None.

mla_kv_projection(inputs, inputs_positions, decoder_segment_ids, model_mode, previous_chunk)

MLA key/value projection with integrated rotary embedding.
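A NumPy sketch of the down-projection side, assuming the DeepSeek-style scheme in which one projection produces both the kv_lora_rank-dimensional latent (low_rank_main) and a single RoPE key per token (key_rope). The weight name, the scale-free normalization, and the function name are assumptions, and the rotary rotation itself is only marked in a comment.

```python
import numpy as np


def mla_kv_down_projection_sketch(x, w_kv_down, kv_lora_rank, eps=1e-6):
    # x: [b, s, emb]; w_kv_down: [emb, kv_lora_rank + rope_dim]  (hypothetical name)
    proj = x @ w_kv_down
    low_rank_main, key_rope = proj[..., :kv_lora_rank], proj[..., kv_lora_rank:]
    # Normalize the latent (RMSNorm without a learned scale, for brevity)
    low_rank_main = low_rank_main / np.sqrt(
        np.mean(low_rank_main**2, axis=-1, keepdims=True) + eps)
    # One RoPE key per token, shared across heads; the positional rotation would
    # be applied to key_rope at this point (omitted here)
    return low_rank_main, key_rope[:, :, None, :]


rng = np.random.default_rng(0)
latent, k_rope = mla_kv_down_projection_sketch(
    rng.normal(size=(2, 5, 64)), rng.normal(size=(64, 32 + 4)), kv_lora_rank=32)
```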

Parameters:
  • inputs (Array)

  • inputs_positions (Array)

calculate_indexer_loss(indexer_score, query, key, attention_mask, indexer_mask, sparse_loss, scaling_factor)

Calculates the indexer KL divergence loss.

This loss trains the indexer to predict which tokens are important by matching its score distribution to the main model's attention distribution.

The target distribution is derived through the following steps:

  1. Compute the main model's attention probabilities, Softmax(Q @ K^T).

  2. Aggregate them by summing across all attention heads.

  3. Apply L1-normalization across the sequence (key) dimension.

target_distribution = L1_Normalize(Sum_h(Softmax(Q @ K^T)))

Reference: DeepSeek-V3.2 - https://arxiv.org/pdf/2512.02556
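A dense NumPy sketch of the loss described above: the target follows the three steps, and the loss is KL(target || Softmax(indexer_score)) averaged over queries. The 1/sqrt(dim) logit scaling is an assumption, and attention_mask, indexer_mask, sparse_loss, and scaling_factor are deliberately not modeled, so this illustrates the formula rather than the method itself.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def indexer_kl_loss_sketch(indexer_score, query, key, eps=1e-9):
    # indexer_score: [b, q, k]; query: [b, q, h, d]; key: [b, k, h, d]. No masking.
    d = query.shape[-1]
    logits = np.einsum("bqhd,bkhd->bhqk", query, key) / np.sqrt(d)
    attn = softmax(logits, axis=-1)                        # per-head attention probs
    summed = attn.sum(axis=1)                              # sum over heads: [b, q, k]
    target = summed / summed.sum(axis=-1, keepdims=True)   # L1-normalize over keys
    pred = softmax(indexer_score, axis=-1)
    # KL(target || pred) per query, then averaged
    kl = np.sum(target * (np.log(target + eps) - np.log(pred + eps)), axis=-1)
    return kl.mean()


rng = np.random.default_rng(0)
q = rng.normal(size=(1, 3, 2, 4))
k = rng.normal(size=(1, 5, 2, 4))
loss = indexer_kl_loss_sketch(rng.normal(size=(1, 3, 5)), q, k)
```

The loss reaches zero when the indexer's softmax reproduces the head-summed, normalized attention distribution exactly.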

Parameters:
  • indexer_score (Array) – Scores predicted by indexer [batch, q_len, kv_len].

  • query (Array) – Query tensor from main model [batch, q_len, heads, dim].

  • key (Array) – Key tensor from main model [batch, kv_len, heads, dim].

  • attention_mask (Array | None) – Attention mask [batch, q_len, kv_len] or None.

  • indexer_mask (Array) – Indexer mask [batch, q_len, kv_len].

  • sparse_loss (bool) – Whether to use sparse loss.

  • scaling_factor (float) – The scaling factor for the loss.

Returns:

The computed KL divergence loss.

Return type:

Array