# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MLA Attention Layer."""
import math
from typing import Any, Optional, Tuple
import copy
import jax
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import layout
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
Layout = layout.Format
if jax.__version_info__ >= (0, 6, 3):
DLL = layout.Layout
else:
DLL = layout.DeviceLocalLayout # type: ignore
from flax import nnx
from maxtext.common.common_types import (
Array,
AxisIdxes,
AxisNames,
BATCH_ATTN,
CACHE_BATCH,
CACHE_BATCH_PREFILL,
CACHE_SEQUENCE,
CACHE_HEADS_NONE,
CACHE_KV,
Config,
DECODE_BATCH,
DECODE_LENGTH,
D_KV,
DType,
EMBED,
HEAD,
Q_LORA_UP_PROJ,
KV_BATCH,
KV_HEAD,
KV_HEAD_DIM,
KV_LORA_UP_PROJ,
LENGTH,
MODEL_MODE_PREFILL,
MODEL_MODE_TRAIN,
PREFILL_KV_BATCH,
PREFILL_LENGTH,
AttentionType,
DEFAULT_MASK_VALUE,
)
from maxtext.layers import nnx_wrappers
from maxtext.layers.attentions import Attention
from maxtext.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned
from maxtext.layers.linears import DenseGeneral
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.inference import kvcache
from maxtext.inference import page_manager
from maxtext.inference import paged_attention
from maxtext.inference.kvcache import KVQuant
from maxtext.utils.sharding import create_sharding
from maxtext.utils.globals import EPS
PLACEHOLDER_SEQ_LEN = 1
[docs]
class Indexer(nnx.Module):
  """Indexer for DeepSeek Sparse Attention (DSA).

  This module implements the sparse attention indexer introduced in DeepSeek
  V3.2. It computes relevance scores to select the top-k most relevant tokens
  for attention.

  References:
    DeepSeek-AI, `DeepSeek-V3.2: Pushing the Frontier of Open Large Language
    Models <https://arxiv.org/pdf/2512.02556>`_, 2026

  Implementation:
    https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/inference/model.py
  """

  def __init__(
      self,
      config: Any,
      rotary_embedding,
      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
      quant: Optional[Quant] = None,
      model_mode: str = MODEL_MODE_TRAIN,
      rngs: Optional[nnx.Rngs] = None,
  ):
    """Creates the indexer projections and the key normalization.

    Args:
      config: Model configuration. Reads `dtype`, `weight_dtype`,
        `max_target_length`, `indexer_n_heads`, `indexer_head_dim`,
        `indexer_topk`, `emb_dim`, `qk_rope_head_dim`, `q_lora_rank`,
        `matmul_precision` and `shard_mode` here (and `use_ragged_attention`
        when the cache is updated).
      rotary_embedding: Callable invoked as `rotary_embedding(x, position=...)`
        on the rotary slice of queries and keys.
      kernel_init: Initializer for the projection kernels.
      quant: Optional quantization for `wq_b` and `wk`. `weights_proj` is
        never quantized.
      model_mode: "train", "prefill", or "autoregressive"; stored on the module.
      rngs: RNG streams used to initialize parameters.
    """
    self.config = config
    self.rotary_embedding = rotary_embedding
    self.quant = quant
    self.kernel_init = kernel_init
    self.model_mode = model_mode
    self.rngs = rngs
    self.dtype = config.dtype
    self.weight_dtype = config.weight_dtype
    self.max_target_length = config.max_target_length
    self.n_heads = config.indexer_n_heads
    self.head_dim = config.indexer_head_dim
    self.indexer_topk = config.indexer_topk
    self.emb_dim = config.emb_dim
    self.rope_head_dim = config.qk_rope_head_dim
    self.q_lora_rank = config.q_lora_rank
    # scale head weights for numerical stability
    self.softmax_scale = self.head_dim**-0.5

    # Query Projection: Latent Query -> Indexer Query
    self.wq_b = DenseGeneral(
        in_features_shape=self.q_lora_rank,
        out_features_shape=(self.n_heads, self.head_dim),
        axis=-1,
        kernel_init=self.kernel_init,
        kernel_axes=("q_lora", "q_heads", "kv"),
        dtype=self.dtype,
        weight_dtype=self.weight_dtype,
        quant=self.quant,
        matmul_precision=self.config.matmul_precision,
        shard_mode=self.config.shard_mode,
        rngs=self.rngs,
    )

    # Key Projection: Input -> Shared Indexer Key (one key per token, shared by all heads)
    self.wk = DenseGeneral(
        in_features_shape=self.emb_dim,
        out_features_shape=self.head_dim,
        axis=-1,
        kernel_init=self.kernel_init,
        kernel_axes=("embed", "kv"),
        dtype=self.dtype,
        weight_dtype=self.weight_dtype,
        quant=self.quant,
        matmul_precision=self.config.matmul_precision,
        shard_mode=self.config.shard_mode,
        rngs=self.rngs,
    )

    # Key Normalization with Bias
    self.k_norm = nnx.LayerNorm(num_features=self.head_dim, use_bias=True, dtype=self.weight_dtype, rngs=rngs)

    # Projection: Input -> Importance Weights for Heads
    # deepseek3.2 enforces FP32 and does not quantize, for precision and stability.
    self.weights_proj = DenseGeneral(
        in_features_shape=self.emb_dim,
        out_features_shape=self.n_heads,
        axis=-1,
        kernel_init=self.kernel_init,
        kernel_axes=("embed", "q_heads"),
        dtype=jnp.float32,
        weight_dtype=jnp.float32,
        quant=None,
        matmul_precision=self.config.matmul_precision,
        shard_mode=self.config.shard_mode,
        rngs=self.rngs,
    )

  def update_indexer_cache(self, kv_cache, k, decoder_segment_ids, model_mode, previous_chunk):
    """Writes indexer keys into `kv_cache` and returns the keys to score against.

    The key is passed as both key and value because the cache API requires a
    value; only keys (result index 0) and segment IDs (result index 2) are
    read back.

    Args:
      kv_cache: Cache module, called with key/value/segment ids.
      k: Indexer keys of shape [b, s, d]; a singleton head axis is added
        before caching and removed again on the way out.
      decoder_segment_ids: Segment IDs forwarded to the cache.
      model_mode: Forwarded to the cache.
      previous_chunk: Previous chunk info for prefill; forwarded to the cache.

    Returns:
      `(keys, segment_ids)` concatenated along the sequence axis over every
      non-None cache result, or `(None, None)` if the cache returned nothing.
    """
    k_expanded = k[:, :, jnp.newaxis, :]
    p_res, a_res = kv_cache(
        key=k_expanded,
        value=k_expanded,
        decoder_segment_ids=decoder_segment_ids,
        model_mode=model_mode,
        use_ragged_attention=self.config.use_ragged_attention,
        previous_chunk=previous_chunk,
    )

    # Filter out None values to handle PREFILL vs AR modes uniformly
    active_results = [res for res in [p_res, a_res] if res is not None]
    if not active_results:
      return None, None

    # Extract keys (index 0) and segment IDs (index 2)
    keys = jnp.concatenate([res[0] for res in active_results], axis=1)
    segs = jnp.concatenate([res[2] for res in active_results], axis=1)

    # squeeze(2) removes the jnp.newaxis added above
    return keys.squeeze(2), segs

  def apply_partial_rope(
      self,
      inputs: Array,
      inputs_positions: Optional[Array] = None,
  ):
    """Applies partial RoPE to the indexer query or key.

    The Indexer's RoPE implementation differs from MLA's in two key aspects:
      1. Split Order: Indexer splits the head dimension into [rope, nope],
         whereas MLA uses [nope, rope].
      2. Input Layout: Indexer uses concatenated layout (interleave=False),
         whereas MLA uses interleaved (interleave=True).

    Args:
      inputs: Input array of shape [batch, seqlen, indexer_n_heads, indexer_head_dim].
      inputs_positions: Position array of shape [batch, seqlen].

    Returns:
      Array with partial RoPE applied, with shape
      [batch, seqlen, indexer_n_heads, indexer_head_dim].
    """
    # indexer_head_dim -> [rope_head_dim, indexer_head_dim - rope_head_dim]
    x_pe, x_nope = jnp.split(inputs, [self.rope_head_dim], axis=-1)
    # x_pe [B, S, H, rope_head_dim], positions [B, S]
    x_pe = self.rotary_embedding(x_pe, position=inputs_positions)
    x = jnp.concatenate([x_pe, x_nope], axis=-1)
    return x

  def generate_mask(self, topk_indices, s):
    """Creates an additive mask that keeps only the top-k indices.

    Args:
      topk_indices: [b, t, k] int - The indices to keep.
      s: int - The total size to select from.

    Returns:
      mask: [b, t, s] in `self.dtype` - `0.0` at topk_indices,
      `DEFAULT_MASK_VALUE` (large negative) elsewhere.
    """
    # 1. Create a range [0, 1, ..., s-1]
    # 2. Broadcast compare against [b, t, k] to get a logical [b, t, k, s]
    # 3. Use .any() over k to see if an s-index is present in any of the k slots
    is_topk = (jnp.arange(s) == topk_indices[..., None]).any(axis=-2)
    # 4. Use where to select between 0.0 and the mask value, cast to dtype
    val_true = jnp.array(0.0, dtype=self.dtype)
    val_false = jnp.array(DEFAULT_MASK_VALUE, dtype=self.dtype)
    return jnp.where(is_topk, val_true, val_false)

  def __call__(
      self,
      inputs_q: Array,
      low_rank_q: Array,
      inputs_kv: Array,
      inputs_positions: Optional[Array] = None,
      attention_mask: Optional[Array] = None,
      decoder_segment_ids: Optional[Array] = None,
      previous_chunk: Any = None,
      kv_cache: Any = None,
      model_mode: str = MODEL_MODE_TRAIN,
  ):
    """Computes the index score and selects the top-k relevant tokens.

    This uses a ReLU-based similarity for QK with MQA-style broadcasting
    (shared K). It uses weighted aggregation over heads to produce a single
    score per token pair.

    Steps:
      1. Q = RoPE(Wq @ q_lora)
      2. K = RoPE(Norm(Wk @ X))
      3. Logits = ReLU(Q @ K.T)  # Pairwise similarity
      4. Head_Weights = (W_proj @ X) * scale  # Dynamic head importance, scale for stability
      5. Score = Logits @ Head_Weights  # Aggregate heads
      6. Indices = ArgTopk(Score)

    Args:
      inputs_q: Input of shape [b, t, embed_dim].
      low_rank_q: Low-rank latent query representations of shape [b, t, q_lora_rank].
      inputs_kv: Input of shape [b, t, embed_dim], same as inputs_q.
      inputs_positions: Position indices of shape [b, t]; applied to both the
        queries and the freshly projected keys.
      attention_mask: Optional additive attention mask of shape [b, t, s].
        Positions with `0.0` allow attention, while positions with
        `DEFAULT_MASK_VALUE` (a large negative number) prevent it. May be
        `None` when no masking is needed.
      decoder_segment_ids: Segment IDs forwarded to the indexer KV cache.
      previous_chunk: Previous chunk info for prefill; forwarded to the cache.
      kv_cache: Indexer key cache used when serving models. Only consulted
        when `model_mode` is not "train".
      model_mode: "train", "prefill", or "autoregressive".

    Returns:
      A tuple `(indexer_mask, topk_indices, indexer_score)`:
        indexer_mask: A sparse additive mask [b, t, s] with 0.0 for top-k
          selected tokens and large negative values otherwise; the attention
          mask and cache padding mask are re-applied on top of it.
        topk_indices: Indices of the top-k selected tokens [b, t, k].
        indexer_score: The computed relevance scores [b, t, s], with the
          attention mask and cache padding mask already added.
      Returns `(None, None, None)` when s <= indexer_topk, because every
      available token would be selected anyway.

    Notation:
      b: Batch size
      t: Query sequence length (target)
      s: Key sequence length (source). Equals t when no cache is used; when a
         KV cache is used, s is the length of the keys returned by the cache.
      h: Number of indexer heads (indexer_n_heads)
      d: Indexer head dimension (indexer_head_dim)
      k: Number of selected tokens (indexer_topk)
    """
    bsz, seqlen, _ = inputs_q.shape  # t = seqlen

    # ==============================================================================
    # Gradient Isolation Strategy: Main Model vs. Indexer
    # ==============================================================================
    # This creates a barrier to train both components independently, and applies
    # for both Dense Warm-up and Sparse Training stages:
    #
    # Forward Pass:
    # - The Indexer receives a detached copy of the inputs (via `stop_gradient`)
    #   to independently calculate its scores and `indexer_loss`.
    #
    # Backward Pass (Main Model):
    # - The main model optimizes its weights based solely on the LM loss.
    # - The `indexer_mask` in the Attention layer prevents gradients from the main
    #   loss from flowing into the Indexer's weights.
    #
    # Backward Pass (Indexer):
    # - Gradients from the `indexer_loss` flow back to update the Indexer's weights.
    # - The `stop_gradient` applied to the inputs acts as a mathematical wall, dropping
    #   gradients to 0.0 and preventing the Indexer loss from altering the main model's
    #   earlier layers.
    inputs_q = jax.lax.stop_gradient(inputs_q)
    low_rank_q = jax.lax.stop_gradient(low_rank_q)
    inputs_kv = jax.lax.stop_gradient(inputs_kv)

    # Query Processing: Project from Latent low_rank_q.
    # DenseGeneral with a tuple out_features_shape already yields [b, t, h, d];
    # the reshape only pins that expected shape.
    q = self.wq_b(low_rank_q)  # [b, t, q_lora_rank] -> [b, t, h, d]
    q = q.reshape(bsz, seqlen, self.n_heads, self.head_dim)  # [b, t, h, d]
    q = self.apply_partial_rope(q, inputs_positions=inputs_positions)

    # Key Processing: Project from Input
    k = self.wk(inputs_kv)  # [b, t, embed_dim] -> [b, t, d]
    k = self.k_norm(k)
    k = k[:, :, None, :]  # [b, t, d] -> [b, t, 1, d]
    k = self.apply_partial_rope(k, inputs_positions=inputs_positions)
    k = k.squeeze(2)  # [b, t, 1, d] -> [b, t, d]

    # Update and retrieve from cache if not training; afterwards k is [b, s, d].
    cached_s = None
    if model_mode != MODEL_MODE_TRAIN:
      k_cached, cached_s = self.update_indexer_cache(kv_cache, k, decoder_segment_ids, model_mode, previous_chunk)
      k = k_cached if k_cached is not None else k

    # NOTE: If the total available sequence length <= topk, indexer always selects all tokens.
    if k.shape[1] <= self.indexer_topk:
      return None, None, None

    # Compute Index Scores
    # QK product: relu(q @ k.T), [b, t, s, h]
    # Similar to MQA, each key is shared by h query head
    logits = jnp.einsum("bthd, bsd -> btsh", q, k, precision=self.config.matmul_precision)
    logits = jax.nn.relu(logits)
    # Compute head weights: project from input, [b, t, embed_dim] -> [b, t, h]
    weights = self.weights_proj(inputs_q)
    # Weights scaling affect indexer_score, but does not affect topk_indices. Keep scaling for numerical stability.
    # https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/87e509a2e5a100d221c97df52c6e8be7835f0057/inference/model.py#L478-L480
    weights = weights * (self.n_heads**-0.5) * self.softmax_scale
    # Aggregate head-wise logits: logits @ weights
    indexer_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision)  # [b, t, s]

    internal_padding_mask = None
    if cached_s is not None:
      # cached_s marks valid tokens from the original prefill step and all subsequent AR steps
      internal_padding_mask = jnp.where(cached_s > 0, 0.0, DEFAULT_MASK_VALUE)
      indexer_score += internal_padding_mask[:, None, :]

    # Apply attention mask before TopK
    if attention_mask is not None:
      indexer_score += attention_mask

    # TopK selection based on index score
    _, topk_indices = jax.lax.top_k(indexer_score, k=self.indexer_topk)  # topk_indices [b, t, k]

    # Create Sparse Index Mask: 0 and large negatives
    indexer_mask = self.generate_mask(topk_indices, k.shape[1])  # [b, t, s]

    # Re-apply attention mask after TopK: in case number of unmasked tokens < TopK
    if attention_mask is not None:
      indexer_mask += attention_mask

    if internal_padding_mask is not None:
      indexer_mask += internal_padding_mask[:, None, :]

    return indexer_mask, topk_indices, indexer_score
[docs]
def mla_as_linen(
    *,
    config: Config,
    num_query_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_target_length: int,
    mesh: Mesh,
    attention_kernel: str,
    inputs_q_shape: Tuple,
    inputs_kv_shape: Tuple,
    dtype: DType = jnp.float32,
    weight_dtype: DType = jnp.float32,
    max_prefill_predict_length: int = -1,
    dropout_rate: float = 0.0,
    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
    float32_qk_product: bool = False,  # computes logits in float32 for stability.
    float32_logits: bool = False,  # cast logits in float32 for stability.
    quant: Optional[Quant] = None,
    kv_quant: Optional[KVQuant] = None,
    attention_type: AttentionType = AttentionType.MLA,  # Default to MLA attention
    attn_logits_soft_cap: float | None = None,
    sliding_window_size: int | None = None,
    use_ragged_attention: bool = False,
    ragged_block_size: int = 256,
    use_qk_norm: bool = False,
    query_pre_attn_scalar: float | None = None,
    use_bias_in_projections: bool = False,  # Set to True will enable bias in q, k, v, o projections
    # Temperature tuning parameters used for Llama4
    temperature_tuning: bool = False,
    temperature_tuning_scale: float = 0.1,
    temperature_tuning_floor_scale: float = 8192.0,
    # Shard the query activation as the same as the key and value.
    # TODO: Find a better sharding axis name.
    # TODO: Further break down the Training and Inference axes for the q, k, v.
    prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
    prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
    prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
    query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
    key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
    value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
    input_axis_names: AxisNames = (BATCH_ATTN, LENGTH, EMBED),
    out_axis_names: AxisNames = (BATCH_ATTN, LENGTH, HEAD, D_KV),
    prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
    decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
    prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
    decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV),
    prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3),
    ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3),
    compute_axis_order: AxisIdxes = (0, 1, 2, 3),
    reshape_q: bool = False,
    is_nope_layer: bool = False,
    is_vision: bool = False,
    model_mode: str = MODEL_MODE_TRAIN,
    q_lora_rank: int = 0,
    kv_lora_rank: int = 512,
    qk_nope_head_dim: int = 128,
    qk_rope_head_dim: int = 64,
    v_head_dim: int = 128,
    max_position_embeddings: int = 4096 * 4,
    original_max_position_embeddings: int = 4096,
    mscale: float = 1.0,  # scaling factor for softmax
    rope_factor: float = 40.0,  # rotary embedding factor
    name: str | None = None,
):
  """Wraps the NNX `MLA` layer so it can be used inside a Linen model.

  Every keyword argument is forwarded unchanged, in declaration order, to
  `nnx_wrappers.to_linen(MLA, ...)`. Variables are annotated through
  `variable_to_logically_partitioned`, and `abstract_init` is disabled.

  Returns:
    The Linen module produced by `nnx_wrappers.to_linen` for `MLA`.
  """
  # Collected in the same order as the signature so the wrapper sees an
  # identical keyword sequence.
  mla_kwargs = dict(
      # Core layer geometry and placement.
      config=config,
      num_query_heads=num_query_heads,
      num_kv_heads=num_kv_heads,
      head_dim=head_dim,
      max_target_length=max_target_length,
      mesh=mesh,
      attention_kernel=attention_kernel,
      inputs_q_shape=inputs_q_shape,
      inputs_kv_shape=inputs_kv_shape,
      # Numerics, quantization and attention behavior.
      dtype=dtype,
      weight_dtype=weight_dtype,
      max_prefill_predict_length=max_prefill_predict_length,
      dropout_rate=dropout_rate,
      kernel_init=kernel_init,
      float32_qk_product=float32_qk_product,
      float32_logits=float32_logits,
      quant=quant,
      kv_quant=kv_quant,
      attention_type=attention_type,
      attn_logits_soft_cap=attn_logits_soft_cap,
      sliding_window_size=sliding_window_size,
      use_ragged_attention=use_ragged_attention,
      ragged_block_size=ragged_block_size,
      use_qk_norm=use_qk_norm,
      query_pre_attn_scalar=query_pre_attn_scalar,
      use_bias_in_projections=use_bias_in_projections,
      temperature_tuning=temperature_tuning,
      temperature_tuning_scale=temperature_tuning_scale,
      temperature_tuning_floor_scale=temperature_tuning_floor_scale,
      # Logical sharding axis names and cache axis orders.
      prefill_query_axis_names=prefill_query_axis_names,
      prefill_key_axis_names=prefill_key_axis_names,
      prefill_value_axis_names=prefill_value_axis_names,
      query_axis_names=query_axis_names,
      key_axis_names=key_axis_names,
      value_axis_names=value_axis_names,
      input_axis_names=input_axis_names,
      out_axis_names=out_axis_names,
      prefill_input_axis_names=prefill_input_axis_names,
      decode_input_axis_names=decode_input_axis_names,
      prefill_out_axis_names=prefill_out_axis_names,
      decode_out_axis_names=decode_out_axis_names,
      prefill_cache_axis_order=prefill_cache_axis_order,
      ar_cache_axis_order=ar_cache_axis_order,
      compute_axis_order=compute_axis_order,
      # Layer flavor flags.
      reshape_q=reshape_q,
      is_nope_layer=is_nope_layer,
      is_vision=is_vision,
      model_mode=model_mode,
      # MLA-specific low-rank and rotary settings.
      q_lora_rank=q_lora_rank,
      kv_lora_rank=kv_lora_rank,
      qk_nope_head_dim=qk_nope_head_dim,
      qk_rope_head_dim=qk_rope_head_dim,
      v_head_dim=v_head_dim,
      max_position_embeddings=max_position_embeddings,
      original_max_position_embeddings=original_max_position_embeddings,
      mscale=mscale,
      rope_factor=rope_factor,
      name=name,
  )
  return nnx_wrappers.to_linen(
      MLA,
      **mla_kwargs,
      metadata_fn=variable_to_logically_partitioned,
      abstract_init=False,
  )
[docs]
class MLA(Attention):
"""Multi-Head Latent Attention (MLA) layer."""
def __init__(
    self,
    config: Config,
    num_query_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_target_length: int,
    mesh: Mesh,
    attention_kernel: str,
    inputs_q_shape: Tuple,
    inputs_kv_shape: Tuple,
    dtype: DType = jnp.float32,
    weight_dtype: DType = jnp.float32,
    max_prefill_predict_length: int = -1,
    dropout_rate: float = 0.0,
    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
    float32_qk_product: bool = False,  # computes logits in float32 for stability.
    float32_logits: bool = False,  # cast logits in float32 for stability.
    quant: Optional[Quant] = None,
    kv_quant: Optional[KVQuant] = None,
    attention_type: AttentionType = AttentionType.MLA,  # Default to MLA attention
    attn_logits_soft_cap: float | None = None,
    sliding_window_size: int | None = None,
    use_ragged_attention: bool = False,
    ragged_block_size: int = 256,
    use_qk_norm: bool = False,
    query_pre_attn_scalar: float | None = None,
    use_bias_in_projections: bool = False,  # Set to True will enable bias in q, k, v, o projections
    # Temperature tuning parameters used for Llama4
    temperature_tuning: bool = False,
    temperature_tuning_scale: float = 0.1,
    temperature_tuning_floor_scale: float = 8192.0,
    # Shard the query activation as the same as the key and value.
    # TODO: Find a better sharding axis name.
    # TODO: Further break down the Training and Inference axes for the q, k, v.
    prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
    prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
    prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
    query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
    key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
    value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
    input_axis_names: AxisNames = (BATCH_ATTN, LENGTH, EMBED),
    out_axis_names: AxisNames = (BATCH_ATTN, LENGTH, HEAD, D_KV),
    prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
    decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
    prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
    decode_out_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV),
    prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3),
    ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3),
    compute_axis_order: AxisIdxes = (0, 1, 2, 3),
    reshape_q: bool = False,
    is_nope_layer: bool = False,
    is_vision: bool = False,
    model_mode: str = MODEL_MODE_TRAIN,
    q_lora_rank: int = 0,
    kv_lora_rank: int = 512,
    qk_nope_head_dim: int = 128,
    qk_rope_head_dim: int = 64,
    v_head_dim: int = 128,
    max_position_embeddings: int = 4096 * 4,
    original_max_position_embeddings: int = 4096,
    mscale: float = 1.0,  # scaling factor for softmax
    rope_factor: float = 40.0,  # rotary embedding factor
    name: str | None = None,
    rngs: Optional[nnx.Rngs] = None,
):
  """Initializes the MLA module.

  Sets the MLA-specific dimensions, delegates the shared setup to
  `Attention.__init__`, optionally builds the DeepSeek sparse-attention
  Indexer, and creates the MLA (and Indexer) KV caches outside training.

  Args:
    config: The model configuration. Besides whatever the base `Attention`
      class consumes, this constructor reads `config.attention`,
      `config.mla_naive_kvcache` and `config.use_indexer`.
    num_query_heads ... model_mode: Forwarded unchanged to `Attention.__init__`.
    q_lora_rank: Rank of the query low-rank bottleneck; 0 means queries are
      projected directly (see `_init_projections`).
    kv_lora_rank: Width of the compressed KV latent.
    qk_nope_head_dim: Per-head width of the non-rotary part of q/k.
    qk_rope_head_dim: Per-head width of the rotary part of q/k. Together with
      `qk_nope_head_dim` it forms `qk_head_dim`.
    v_head_dim: Per-head value width.
    max_position_embeddings: When larger than
      `original_max_position_embeddings`, an mscale correction is applied to
      the softmax scale (see `_init_projections`).
    original_max_position_embeddings: See `max_position_embeddings`.
    mscale: Scaling factor for softmax.
    rope_factor: Rotary embedding factor; also enters the mscale correction.
    name: Accepted but not used by this constructor.
    rngs: The random number generators for initialization, passed by the
      nnx.to_linen wrapper.
  """
  # Passed to Attention.__init__ as `base_kv_cache`: true only for non-paged
  # attention with `config.mla_naive_kvcache` enabled.
  base_kv_cache = config.attention != "paged" and config.mla_naive_kvcache

  # Setting these before call to super because a field is used in super
  self.q_lora_rank = q_lora_rank
  self.kv_lora_rank = kv_lora_rank
  self.qk_nope_head_dim = qk_nope_head_dim
  self.qk_rope_head_dim = qk_rope_head_dim
  self.v_head_dim = v_head_dim
  self.max_position_embeddings = max_position_embeddings
  self.original_max_position_embeddings = original_max_position_embeddings
  self.mscale = mscale
  self.rope_factor = rope_factor

  # Full per-head q/k width: non-rotary part followed by rotary part.
  self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim

  super().__init__(
      config=config,
      num_query_heads=num_query_heads,
      num_kv_heads=num_kv_heads,
      head_dim=head_dim,
      max_target_length=max_target_length,
      mesh=mesh,
      attention_kernel=attention_kernel,
      inputs_q_shape=inputs_q_shape,
      inputs_kv_shape=inputs_kv_shape,
      dtype=dtype,
      weight_dtype=weight_dtype,
      max_prefill_predict_length=max_prefill_predict_length,
      dropout_rate=dropout_rate,
      kernel_init=kernel_init,
      float32_qk_product=float32_qk_product,
      float32_logits=float32_logits,
      quant=quant,
      kv_quant=kv_quant,
      attention_type=attention_type,
      attn_logits_soft_cap=attn_logits_soft_cap,
      sliding_window_size=sliding_window_size,
      use_ragged_attention=use_ragged_attention,
      ragged_block_size=ragged_block_size,
      use_qk_norm=use_qk_norm,
      query_pre_attn_scalar=query_pre_attn_scalar,
      use_bias_in_projections=use_bias_in_projections,
      temperature_tuning=temperature_tuning,
      temperature_tuning_scale=temperature_tuning_scale,
      temperature_tuning_floor_scale=temperature_tuning_floor_scale,
      prefill_query_axis_names=prefill_query_axis_names,
      prefill_key_axis_names=prefill_key_axis_names,
      prefill_value_axis_names=prefill_value_axis_names,
      query_axis_names=query_axis_names,
      key_axis_names=key_axis_names,
      value_axis_names=value_axis_names,
      input_axis_names=input_axis_names,
      out_axis_names=out_axis_names,
      prefill_input_axis_names=prefill_input_axis_names,
      decode_input_axis_names=decode_input_axis_names,
      prefill_out_axis_names=prefill_out_axis_names,
      decode_out_axis_names=decode_out_axis_names,
      prefill_cache_axis_order=prefill_cache_axis_order,
      ar_cache_axis_order=ar_cache_axis_order,
      compute_axis_order=compute_axis_order,
      reshape_q=reshape_q,
      is_nope_layer=is_nope_layer,
      is_vision=is_vision,
      model_mode=model_mode,
      base_kv_cache=base_kv_cache,
      rngs=rngs,
  )

  # Initialize Indexer
  self.use_indexer = config.use_indexer
  if self.use_indexer:
    # Need two versions of rope.
    # MLA applies yarn with interleave layout.
    # Indexer applies yarn with concatenate layout.
    # copy.copy is a shallow copy, so flipping `interleave` on the copy leaves
    # MLA's own rotary embedding object unchanged.
    indexer_rope = copy.copy(self.rotary_embedding)
    indexer_rope.interleave = False
    self.indexer = Indexer(
        config,
        rngs=rngs,
        rotary_embedding=indexer_rope,
        kernel_init=kernel_init,
        quant=quant,
        model_mode=model_mode,
    )
    # The Indexer key cache exists only outside training.
    self.IndexerKVCache_0 = self.init_indexer_cache(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
  else:
    self.indexer = None
    self.IndexerKVCache_0 = None

  # Module attribute names must match names previously passed to Linen for checkpointing
  self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
[docs]
def init_indexer_cache(self, inputs_kv_shape: Tuple):
  """Builds the KV cache that stores the Indexer's shared keys.

  The Indexer keeps a single key head of width `config.indexer_head_dim` per
  token. `kvcache.KVCache` always manages a value buffer too, so the value
  side is configured identically even though it is never used as a value.

  Args:
    inputs_kv_shape: Key/value input shape; only the leading (batch) entry is
      used, and the tuple must have exactly three entries.

  Returns:
    A `kvcache.KVCache` instance for the Indexer keys (no KV quantization).
  """
  batch_size, _, _ = inputs_kv_shape

  # One shared key "head" (MQA-like); the value side mirrors it for the API.
  single_head = 1
  indexer_width = self.config.indexer_head_dim
  cache_axis_order = (1, 2, 0, 3)

  return kvcache.KVCache(
      max_prefill_length=self.max_prefill_predict_length,
      max_target_length=self.max_target_length,
      batch=batch_size,
      key_seq_len=PLACEHOLDER_SEQ_LEN,
      value_seq_len=PLACEHOLDER_SEQ_LEN,
      key_heads=single_head,
      value_heads=single_head,
      key_head_size=indexer_width,
      value_head_size=indexer_width,
      dtype=self.dtype,
      kv_quant=None,  # Quantization is not yet supported by the indexer.
      prefill_cache_logical_axis_names=(CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV),
      cache_logical_axis_names=(CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV),
      prefill_cache_axis_order=cache_axis_order,
      ar_cache_axis_order=cache_axis_order,
      use_chunked_prefill=self.config.use_chunked_prefill,
      model_mode=self.model_mode,
      rngs=self.rngs,
  )
def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
  """Creates MLA's query, KV and output projections (and the paged-attention op).

  Query path: a single `query` projection when `q_lora_rank == 0`, otherwise
  `wq_a` -> `q_norm` (RMSNorm) -> `wq_b`.
  KV path: `wkv_a` emits `kv_lora_rank + qk_rope_head_dim` features (the
  latent plus, presumably, the rotary key part), `kv_norm` normalizes the
  `kv_lora_rank` latent, and `wkv_b` expands it to per-head
  `qk_nope_head_dim + v_head_dim` features.

  Also sets `self.softmax_scale` and `self.out`, and builds
  `self.ds_paged_attention_op` when `config.attention == "paged"`.
  Submodules are created in a fixed order, so RNG consumption stays stable.

  Args:
    inputs_q_shape: Query input shape; its last entry sizes the output projection.
    inputs_kv_shape: Unused in this override.
  """
  # Validate the configuration MLA relies on.
  assert (
      self.config.attention_type == AttentionType.MLA.value
  ), f"MLA requires MLA attention type {AttentionType.MLA.value}"
  for dim_value, message in (
      (self.kv_lora_rank, "KV LoRA rank must be > 0"),
      (self.qk_nope_head_dim, "QK NoPe head dim must be > 0"),
      (self.qk_rope_head_dim, "QK RoPE head dim must be > 0"),
      (self.v_head_dim, "V head dim must be > 0"),
  ):
    assert dim_value > 0, message
  assert self.num_query_heads == self.num_kv_heads, "MLA requires equal number of query and kv heads"
  assert not self.config.fused_qkv, "Fused QKV is not supported for MLA"

  def dense(in_features_shape, out_features_shape, kernel_axes):
    # All MLA projections share these settings; only shapes and axes vary.
    return DenseGeneral(
        in_features_shape=in_features_shape,
        out_features_shape=out_features_shape,
        axis=-1,
        kernel_init=self.kernel_init,
        kernel_axes=kernel_axes,
        dtype=self.dtype,
        weight_dtype=self.weight_dtype,
        quant=self.quant,
        matmul_precision=self.config.matmul_precision,
        shard_mode=self.config.shard_mode,
        rngs=self.rngs,
    )

  def rms_norm(num_features):
    # Both latent norms use the model-wide dtypes and epsilon.
    return RMSNorm(
        num_features=num_features,
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        epsilon=self.config.normalization_layer_epsilon,
        kernel_axes=("norm",),
        rngs=self.rngs,
    )

  q_out_shape = (self.num_query_heads, self.qk_head_dim)
  if self.q_lora_rank == 0:
    # Standard Q projection (without LoRA).
    self.query = dense(self.config.emb_dim, q_out_shape, ("embed", "q_heads", "kv"))
  else:
    # LoRA path for Q: down-project, normalize, up-project.
    self.wq_a = dense(self.config.emb_dim, self.q_lora_rank, ("embed", "q_lora_up_proj"))
    self.q_norm = rms_norm(self.q_lora_rank)
    self.wq_b = dense(self.q_lora_rank, q_out_shape, ("q_lora", "q_heads", "kv"))

  # KV LoRA path.
  self.wkv_a = dense(
      self.config.emb_dim,
      self.kv_lora_rank + self.qk_rope_head_dim,
      ("embed", "kv_lora_up_proj"),
  )
  self.kv_norm = rms_norm(self.kv_lora_rank)
  self.wkv_b = dense(
      self.kv_lora_rank,
      (self.num_query_heads, (self.qk_nope_head_dim + self.v_head_dim)),
      ("kv_lora", "kv_heads", "kv_head_dim"),
  )

  # Softmax scale, with the mscale correction when the context was extended.
  scale = self.qk_head_dim**-0.5
  if self.max_position_embeddings > self.original_max_position_embeddings:
    yarn_mscale = 0.1 * self.mscale * math.log(self.rope_factor) + 1.0
    scale = scale * yarn_mscale * yarn_mscale
  self.softmax_scale = scale

  self.out = self.init_out_w(output_dim=inputs_q_shape[-1])

  if self.config.attention != "paged":
    return

  # The paged attention kernel needs one head_dim shared by q, k and v, so use
  # the larger of the two, optionally rounded up to the configured alignment.
  paged_head_dim = max(self.qk_head_dim, self.v_head_dim)
  alignment = self.config.pagedattn_head_dim_alignment
  if alignment > 0:
    paged_head_dim = (paged_head_dim + alignment - 1) // alignment * alignment

  tokens_per_page = self.config.pagedattn_tokens_per_page
  self.ds_paged_attention_op = paged_attention.PagedAttentionOp(
      mesh=self.mesh,
      num_pages=self.config.pagedattn_num_pages,
      tokens_per_page=tokens_per_page,
      # Ceil-divisions: pages needed to hold a full slot / a full prefill.
      max_pages_per_slot=(self.config.max_target_length + tokens_per_page - 1) // tokens_per_page,
      max_pages_per_prefill=(self.config.max_prefill_predict_length + tokens_per_page - 1) // tokens_per_page,
      pages_per_compute_block=self.config.pagedattn_pages_per_compute_block,
      num_kv_heads=self.num_kv_heads,
      kv_head_dim_size=paged_head_dim,
      dtype=self.dtype,
      attn_logits_soft_cap=self.attn_logits_soft_cap,
      rngs=self.rngs,
  )
@property
def out_head_dim(self) -> int:
  """Per-head output width of this layer.

  For MLA this is `v_head_dim`, which is configured independently of the
  q/k head width `qk_head_dim`.
  """
  value_width = self.v_head_dim
  return value_width
[docs]
def mla_query_projection(
    self, inputs_q: Array, inputs_positions: Array, model_mode
) -> tuple[jax.Array, Optional[jax.Array]]:
  """Projects inputs to MLA queries and applies RoPE to the rotary slice.

  With `q_lora_rank == 0` a single `self.query` projection is used.
  Otherwise the LoRA path `wq_a -> q_norm -> wq_b` runs, and the normalized
  low-rank latent is also returned (the Indexer consumes it).

  Note: `self.qk_head_dim` and `self.softmax_scale` are recomputed and stored
  on every call, using the same formula as `_init_projections`.

  Args:
    inputs_q: Query-side inputs to project.
    inputs_positions: Positions passed to `apply_rotary_embedding` for the
      rotary slice.
    model_mode: Prefill selects the prefill logical axis names; any other
      mode uses the regular ones.

  Returns:
    `(query, low_rank_q)`: `query` is `[nope, rope]` concatenated on the last
    axis and multiplied by `self.softmax_scale`; `low_rank_q` is the
    `[B, L, q_lora_rank]` latent, or None when `q_lora_rank == 0`.
  """
  is_prefill = model_mode == MODEL_MODE_PREFILL
  query_logical_name = self.prefill_query_axis_names if is_prefill else self.query_axis_names
  latent_logical_name = (
      (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ) if is_prefill else (KV_BATCH, LENGTH, Q_LORA_UP_PROJ)
  )
  query_sharding = create_sharding(self.mesh, query_logical_name)
  latent_sharding = create_sharding(self.mesh, latent_logical_name)

  # Softmax scaling (mirrors _init_projections), stored back on the module.
  self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
  scale = self.qk_head_dim**-0.5
  if self.max_position_embeddings > self.original_max_position_embeddings:
    yarn_mscale = 0.1 * self.mscale * math.log(self.rope_factor) + 1.0
    scale = scale * yarn_mscale * yarn_mscale
  self.softmax_scale = scale

  if self.q_lora_rank == 0:
    low_rank_q = None
    q = self.query(inputs_q, out_sharding=query_sharding)
  else:
    # Low-rank latent for queries; also handed to the indexer by the caller.
    low_rank_q = checkpoint_name(self.wq_a(inputs_q, out_sharding=latent_sharding), "query_wa_proj")  # [B, L, q_lora_rank]
    low_rank_q = checkpoint_name(self.q_norm(low_rank_q), "mla_q")  # RMSNorm on low rank
    q = self.wq_b(low_rank_q, out_sharding=query_sharding)  # [B, L, n_heads, qk_head_dim]

  # Partial RoPE: last axis is [qk_nope_head_dim | qk_rope_head_dim].
  q_nope, q_pe = jnp.split(q, [self.qk_nope_head_dim], axis=-1)
  q_nope = self._maybe_shard_with_logical(q_nope, query_logical_name)
  rotated = self.apply_rotary_embedding(q_pe, inputs_positions=inputs_positions)
  q_pe = self._maybe_shard_with_logical(rotated, query_logical_name)

  # The softmax scale is folded into the query here (MaxText convention);
  # DeepSeek v3 applied it in the attention score computation instead.
  scaled_query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale
  return self._maybe_shard_with_logical(scaled_query, query_logical_name), low_rank_q
[docs]
def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
  """Expands an MLA latent into full per-head key and value tensors.

  `low_rank_main` is projected by `self.wkv_b`. The last axis of that
  projection is split at `self.qk_nope_head_dim`: the first part is the
  non-positional key, the rest is the value. `key_rope` is broadcast across
  `self.num_query_heads` heads and appended to the non-positional key.

  Args:
    low_rank_main: Latent key/value representation fed to `self.wkv_b`.
    key_rope: RoPE-applied key component. Its axis 3 size is kept; axes 0-1
      are taken from the projected key. (Presumably
      [batch, length, 1, qk_rope_head_dim] - confirm against callers.)
    model_mode: Chooses prefill vs. non-prefill logical axis names.

  Returns:
    A `(key, value)` tuple, each constrained to its logical sharding.
  """
  is_prefill = model_mode == MODEL_MODE_PREFILL
  k_axes = self.prefill_key_axis_names if is_prefill else self.key_axis_names
  v_axes = self.prefill_value_axis_names if is_prefill else self.value_axis_names

  # wkv_b emits [key_nope | value] on the last axis; it is sharded with the key axes.
  kv_projected = self.wkv_b(low_rank_main, out_sharding=create_sharding(self.mesh, k_axes))
  nope_part, value = jnp.split(kv_projected, [self.qk_nope_head_dim], axis=-1)

  # Replicate the RoPE key component for every query head.
  batch_dim, seq_dim = nope_part.shape[0], nope_part.shape[1]
  rope_part = jnp.broadcast_to(key_rope, (batch_dim, seq_dim, self.num_query_heads, key_rope.shape[3]))

  nope_part = self._maybe_shard_with_logical(nope_part, k_axes)
  rope_part = self._maybe_shard_with_logical(rope_part, k_axes)
  key = self._maybe_shard_with_logical(jnp.concatenate([nope_part, rope_part], axis=-1), k_axes)
  value = self._maybe_shard_with_logical(value, v_axes)
  return key, value
[docs]
def init_mla_kv_caches(self, inputs_kv_shape: Tuple):
  """Builds the MLA-specific KV cache module.

  The cache is created with key head size `self.kv_lora_rank` and value head
  size `self.qk_rope_head_dim`. Only the batch size is taken from
  `inputs_kv_shape`; the key/value sequence lengths are a placeholder.

  Args:
    inputs_kv_shape: Shape of the key/value inputs. Must have exactly three
      entries; only the first (batch) entry is used.

  Returns:
    A `kvcache.MlaKVCache` instance.

  Raises:
    ValueError: If `inputs_kv_shape` does not unpack into three entries.
      `kvcache.MlaKVCache` may also reject an invalid configuration (not
      visible here).
  """
  batch_size, _, _ = inputs_kv_shape
  # At init time the inputs' sequence length is max_target_length, which some
  # MlaKVCache code paths do not expect. The cache's internal buffers are sized
  # from max_prefill_length / max_target_length rather than the passed seq
  # length, so a placeholder is used here. Refactoring MlaKVCache would be the
  # cleaner fix.
  placeholder_len = PLACEHOLDER_SEQ_LEN
  return kvcache.MlaKVCache(
      batch=batch_size,
      max_prefill_length=self.max_prefill_predict_length,
      max_target_length=self.max_target_length,
      key_seq_len=placeholder_len,
      value_seq_len=placeholder_len,
      key_head_size=self.kv_lora_rank,
      value_head_size=self.qk_rope_head_dim,
      dtype=self.dtype,
      kv_quant=self.kv_quant,
      prefill_cache_axis_order=self.prefill_cache_axis_order,
      ar_cache_axis_order=self.ar_cache_axis_order,
      model_mode=self.model_mode,
      use_chunked_prefill=self.config.use_chunked_prefill,
      rngs=self.rngs,
  )
[docs]
def update_mla_kv_caches(self, low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk=None):
  """Updates the MLA KV cache and rebuilds full keys/values from its contents.

  `self.MlaKVCache_0` is called with the latent (`low_rank_main`) and the
  RoPE-applied key (`key_rope`). It returns a prefill entry and an
  autoregressive entry. Each truthy entry is expanded back into per-head
  key/value tensors via `self.mla_get_key_value`; a falsy entry yields None.

  Args:
    low_rank_main: Latent component of the key.
    key_rope: RoPE-applied component of the key.
    decoder_segment_ids: Segment IDs for decoder masking.
    model_mode: The operational mode ('train', 'prefill', 'autoregressive').
    previous_chunk: Information about previously processed chunks, for
      chunked prefill.

  Returns:
    A two-element list `[prefill_kv_cache, ar_kv_cache]`:
      - prefill_kv_cache: `(key, value, decoder_segment_ids)` or None.
      - ar_kv_cache: `(key, value, decoder_segment_ids, lengths)` or None.
  """
  prefill_entry, ar_entry = self.MlaKVCache_0(
      key_latent=low_rank_main,
      key_rope=key_rope,
      decoder_segment_ids=decoder_segment_ids,
      model_mode=model_mode,
      use_ragged_attention=self.use_ragged_attention,
      previous_chunk=previous_chunk,
  )

  prefill_kv_cache = None
  if prefill_entry:
    cached_latent, cached_rope, cached_segments = prefill_entry
    prefill_key, prefill_value = self.mla_get_key_value(cached_latent, cached_rope, model_mode)
    prefill_kv_cache = (prefill_key, prefill_value, cached_segments)

  ar_kv_cache = None
  if ar_entry:
    cached_latent, cached_rope, cached_segments, cached_lengths = ar_entry
    ar_key, ar_value = self.mla_get_key_value(cached_latent, cached_rope, model_mode)
    ar_kv_cache = (ar_key, ar_value, cached_segments, cached_lengths)

  return [prefill_kv_cache, ar_kv_cache]
[docs]
def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segment_ids, model_mode, previous_chunk):
  """MLA key/value projection with integrated rotary embedding.

  Steps:
    1. `self.wkv_a` projects `inputs` to a low-rank tensor. Its last axis is
       split at `self.kv_lora_rank` into a latent part and a RoPE part.
    2. The latent part is normalized with `self.kv_norm`.
    3. The RoPE part gets a new axis at position 2 and rotary embedding.
    4. `self.mla_get_key_value` expands both into full key/value tensors.
    5. When not training and attention is not "paged", the KV cache is
       updated. Full keys/values are cached if `config.mla_naive_kvcache` is
       set; otherwise the latent and RoPE key are cached.

  Returns:
    `(key, value, cached_values)`. `cached_values` is `[None, None]` when no
    cache update is performed.
  """
  if model_mode == MODEL_MODE_PREFILL:
    latent_axes = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ)
  else:
    latent_axes = (KV_BATCH, LENGTH, KV_LORA_UP_PROJ)
  compressed = self.wkv_a(inputs, out_sharding=create_sharding(self.mesh, latent_axes))
  compressed = checkpoint_name(compressed, "kv_wa_proj")

  latent, rope_input = jnp.split(compressed, [self.kv_lora_rank], axis=-1)
  latent = checkpoint_name(self.kv_norm(latent), "mla_kv")

  # Insert a singleton axis at position 2 (presumably the head axis) before RoPE.
  key_rope = self.apply_rotary_embedding(jnp.expand_dims(rope_input, axis=2), inputs_positions=inputs_positions)
  key, value = self.mla_get_key_value(latent, key_rope, model_mode)

  use_cache = self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN
  if not use_cache:
    return key, value, [None, None]
  if self.config.mla_naive_kvcache:
    cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
  else:
    cached_values = self.update_mla_kv_caches(latent, key_rope, decoder_segment_ids, model_mode, previous_chunk)
  return key, value, cached_values
[docs]
def calculate_indexer_loss(
    self,
    indexer_score: Array,
    query: Array,
    key: Array,
    attention_mask: Optional[Array | None],
    indexer_mask: Array,
    sparse_loss: bool,
    scaling_factor: float,
) -> Array:
  """KL-divergence loss that trains the indexer to mimic main-model attention.

  Target distribution, per batch element and query position:
    1. Score every key with Q @ K^T, per head.
    2. Add the mask: `indexer_mask` if `sparse_loss`, else `attention_mask`
       (if given).
    3. Softmax over keys, per head, in float32.
    4. Sum the per-head probabilities.
    5. L1-normalize over keys.
  That is, target = L1_Normalize(Sum_h(Softmax(Q @ K^T + mask))).

  The indexer distribution is Softmax(indexer_score) in float32. When
  `sparse_loss` is set, `indexer_mask` is added to the score first. The loss
  is KL(target || indexer), summed over keys, averaged over all remaining
  entries, and multiplied by `scaling_factor`.

  Gradients are stopped on `query` and `key`. The indexer should follow the
  main model, not the other way round; this also saves compute and memory
  during the dense warm-up stage.

  Reference:
    DeepSeek-V3.2 - https://arxiv.org/pdf/2512.02556

  Args:
    indexer_score: Indexer logits [batch, q_len, kv_len].
    query: Main-model query [batch, q_len, heads, dim].
    key: Main-model key [batch, kv_len, heads, dim].
    attention_mask: Additive mask [batch, q_len, kv_len] or None. Only used
      when `sparse_loss` is False.
    indexer_mask: Additive indexer mask [batch, q_len, kv_len]. Only used
      when `sparse_loss` is True.
    sparse_loss: Selects which mask is applied (see above).
    scaling_factor: Multiplier applied to the final loss.

  Returns:
    The scaled scalar KL divergence loss.
  """
  # The main model acts only as a teacher here.
  q_fixed = jax.lax.stop_gradient(query)
  k_fixed = jax.lax.stop_gradient(key)

  # [b, t, h, d] x [b, s, h, d] -> [b, h, t, s]
  logits = jnp.einsum("bthd, bshd -> bhts", q_fixed, k_fixed, precision=self.config.matmul_precision)
  if sparse_loss:
    # Per the caller contract, indexer_mask already folds in attention_mask (if any).
    logits = logits + indexer_mask[:, None, :, :]
    indexer_score = indexer_score + indexer_mask
  elif attention_mask is not None:
    # Per the caller contract, indexer_score is already masked; only the teacher needs it.
    logits = logits + attention_mask[:, None, :, :]

  # Softmax in float32 for numerical stability.
  teacher = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
  student = jax.nn.softmax(indexer_score.astype(jnp.float32), axis=-1)

  # Collapse heads [b, h, t, s] -> [b, t, s], then L1-normalize over keys.
  teacher = jnp.sum(teacher, axis=1)
  teacher = teacher / (jnp.sum(teacher, axis=-1, keepdims=True) + EPS)

  # KL(teacher || student), EPS-guarded logs.
  kl_terms = teacher * (jnp.log(teacher + EPS) - jnp.log(student + EPS))
  return jnp.mean(jnp.sum(kl_terms, axis=-1)) * scaling_factor
def __call__(
    self,
    inputs_q: Array,
    inputs_kv: Array,
    inputs_positions: Array | None = None,
    decoder_segment_ids: Array | None = None,
    out_sharding: NamedSharding | None = None,
    *,
    model_mode: str = MODEL_MODE_TRAIN,
    deterministic: bool = False,
    previous_chunk: Any = None,
    slot: Optional[int] = None,
    page_state: Optional[page_manager.PageState] = None,
    bidirectional_mask: Optional[Any] = None,
    rope_kwargs: dict | None = None,
    kv_cache: Optional[Array] = None,
    attention_metadata: Optional[dict[str, Any]] = None,
) -> tuple[Array, Optional[Array]]:
  """Forward pass for MLA, reusing `AttentionOp` for the actual attention.

  The pass runs in this order:
    1. Shard the inputs.
    2. Run the MLA query projection.
    3. Run the MLA key/value projection, which also updates the KV cache
       outside training.
    4. Optionally compute the indexer mask and indexer loss.
    5. Run paged or regular attention.
    6. Apply the output projection.

  Args:
    inputs_q: Query input [batch, q_length, embed_dim].
    inputs_kv: KV input [batch, kv_length, embed_dim].
    inputs_positions: Positions for rotary embeddings or similar.
    decoder_segment_ids: Segment IDs for masking, if any. When not training
      and None, defaults to ones of shape `inputs_q.shape[:2]`.
    out_sharding: Not used. The output-projection sharding is always derived
      from the logical axes chosen for `model_mode`.
    model_mode: "train", "prefill", or "autoregressive".
    deterministic: Not referenced in this method body.
    previous_chunk: Information about previously processed chunks for chunked
      prefill.
    slot: The batch slot index for paged attention (paged path only).
    page_state: The current state of the paged attention manager (paged path
      only).
    bidirectional_mask: A mask for bidirectional attention, used in
      multimodal models. Here it is only passed to the mask used by the
      indexer.
    rope_kwargs: Not referenced in this method body.
    kv_cache: Optional key-value cache used when serving models with vLLM.
      Returned unchanged.
    attention_metadata: Optional attention-related metadata used when serving
      models with vLLM. Not referenced in this method body.

  Returns:
    A tuple `(out, kv_cache)`:
      - out: The output projection of the attention result (documented as
        [batch, length, embed_dim]).
      - kv_cache: The `kv_cache` argument, passed through unchanged.
  """
  if model_mode == MODEL_MODE_PREFILL:
    inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names)
    inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names)
    out_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV)
  else:
    inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
    inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
    out_logical_name = (BATCH_ATTN, LENGTH, HEAD, D_KV)

  if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None:
    # Outside training, missing segment IDs mean every token is in segment 1.
    decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32)

  query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
  if self.config.force_q_layout:
    query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
  key, value, cached_values = self.mla_kv_projection(
      inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
  )

  query = checkpoint_name(query, "query_proj")
  key = checkpoint_name(key, "key_proj")
  value = checkpoint_name(value, "value_proj")

  # Indexer logic: build the indexer mask and, if configured, its training loss.
  indexer_mask = None
  if self.use_indexer:
    # Additive mask of 0 / large negative; squeeze [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len].
    attention_mask = self.attention_op.generate_attention_mask(
        query, key, decoder_segment_ids, model_mode, previous_chunk, bidirectional_mask
    )
    if attention_mask is not None:
      attention_mask = attention_mask.squeeze(axis=(1, 2))
    # Apply the indexer; indexer_mask is [b, q_len, kv_len].
    indexer_mask, _, indexer_score = self.indexer(
        inputs_q=inputs_q,
        low_rank_q=low_rank_q,
        inputs_kv=inputs_kv,
        inputs_positions=inputs_positions,
        attention_mask=attention_mask,
        decoder_segment_ids=decoder_segment_ids,
        previous_chunk=previous_chunk,
        kv_cache=self.IndexerKVCache_0,
        model_mode=model_mode,
    )
    if indexer_mask is not None and self.config.indexer_loss_scaling_factor > 0.0:
      indexer_loss = self.calculate_indexer_loss(
          indexer_score=indexer_score,
          query=query,
          key=key,
          attention_mask=attention_mask,
          indexer_mask=indexer_mask,
          sparse_loss=self.config.indexer_sparse_training,
          scaling_factor=self.config.indexer_loss_scaling_factor,
      )
      self.indexer_loss = nnx.Intermediate(indexer_loss)

  # Record QK-Clip statistics only for training forward passes. This is gated
  # on the call-time `model_mode`, like every other mode-dependent branch in
  # this method, rather than on the mode the module was constructed with.
  use_qk_clip = model_mode == MODEL_MODE_TRAIN and self.config.use_qk_clip

  if self.config.attention == "paged" and model_mode != MODEL_MODE_TRAIN:
    unnormalized_out, _, exp_sum = self.ds_paged_attention_op(
        query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
    )
    # Keep only the first `self.v_head_dim` entries of the last axis.
    unnormalized_out = unnormalized_out[..., : self.v_head_dim]
    # Divide by exp_sum when the kernel returns it (presumably the softmax
    # denominator); 1e-9 guards against division by zero.
    out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
  else:
    out = self.attention_op(
        query,
        key,
        value,
        decoder_segment_ids,
        inputs_positions,
        model_mode,
        cached_values,
        indexer_mask=indexer_mask,
        record_max_logits=use_qk_clip,
    )

  out = self._maybe_shard_with_logical(out, self.out_axis_names)
  out = checkpoint_name(out, "attention_out")
  # NOTE: the `out_sharding` argument is not used; the projection's output
  # sharding always comes from `out_logical_name`.
  proj_out_sharding = create_sharding(self.mesh, out_logical_name)
  out = self.out_projection(out, out_sharding=proj_out_sharding)
  out = checkpoint_name(out, "out_proj")
  return out, kv_cache