# Source code for maxtext.layers.attention_op

#  Copyright 2023 Google LLC
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#  pytype: disable=module-attr
"""Attentions Ops Layers."""

import dataclasses
import functools
from functools import partial
import math
from typing import Any, Callable, Optional, Tuple

from flax import linen as nn
from flax import nnx
from flax.linen import partitioning
import jax
from jax import lax
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import pallas as pl
from jax.experimental.pallas.ops.gpu import attention as gpu_pallas_attention
from jax.experimental.pallas.ops.gpu import decode_attention as gpu_pallas_decode_attention
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import (
    Array,
    AttentionType,
    AxisIdxes,
    AxisNames,
    BATCH_ATTN,
    CACHE_BATCH,
    CACHE_BATCH_PREFILL,
    CACHE_HEADS,
    CACHE_KV,
    CACHE_SCALE_BATCH,
    CACHE_SCALE_HEADS,
    CACHE_SCALE_KV,
    CACHE_SCALE_SEQUENCE,
    CACHE_SEQUENCE,
    Config,
    DECODE_BATCH,
    DECODE_LENGTH,
    DECODING_ACTIVE_SEQUENCE_INDICATOR,
    DEFAULT_MASK_VALUE,
    DType,
    D_KV,
    HEAD,
    KV_LENGTH,
    LENGTH,
    MODEL_MODE_AUTOREGRESSIVE,
    MODEL_MODE_PREFILL,
    MODEL_MODE_TRAIN,
    PREFILL_LENGTH,
    Q_LENGTH,
)
from maxtext.inference import page_manager
from maxtext.inference.kvcache import KVQuant, KVTensor
from maxtext.kernels.attention import jax_flash_attention
from maxtext.kernels.attention.ragged_attention import ragged_gqa
from maxtext.kernels.attention.ragged_attention import ragged_mha
from maxtext.layers import nnx_wrappers
from maxtext.layers.initializers import variable_to_logically_partitioned
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.utils import max_utils
from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_pspec
import numpy as np
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask
# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
# pytype: disable=attribute-error

# Used to pass in splash attention block sizes from config.
# NOTE(review): only the placeholder defaults are set here; presumably these are
# overwritten from the config elsewhere before the kernels are built - confirm.
global_block_q = 0
global_block_kv = 0
global_block_kv_compute = 0
global_block_q_dkv = 0
global_block_kv_dkv = 0
global_block_kv_dkv_compute = 0
global_block_q_dq = 0
global_block_kv_dq = 0
global_use_fused_bwd_kernel = False
global_q_layout = ""
global_k_layout = ""
global_v_layout = ""

# Batched form of `lax.dynamic_slice_in_dim`: `in_axes=(None, 0, None, None)`
# maps over the second positional argument only, so each batch element gets
# its own value for that argument while the other three arguments are shared.
dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))


def validate_compute_axis_order(s: AxisIdxes) -> None:
  """Checks that `s` is one of the supported compute axis orders.

  Args:
    s: The axis permutation to validate.

  Raises:
    ValueError: If `s` is neither `(0, 1, 2, 3)` nor `(0, 2, 1, 3)`.
  """
  valid_compute_axis_order = ((0, 1, 2, 3), (0, 2, 1, 3))
  if s not in valid_compute_axis_order:  # currently supported compute_axis_order
    # Build one message string. Passing several positional arguments to
    # ValueError makes `str(exc)` print a tuple repr instead of a sentence.
    raise ValueError(
        f"Invalid compute_axis_order {s} was passed. Valid options: {valid_compute_axis_order}"
    )
def apply_mask_to_logits(logits: Array, mask: Array):
  """Overwrites masked-out logits with `DEFAULT_MASK_VALUE`.

  The mask is a floating-point tensor: entries at or above
  `DEFAULT_MASK_VALUE * 0.5` (for example `0`) mean "keep", anything below
  that threshold means "masked". Kept positions return the logit unchanged;
  masked positions return `DEFAULT_MASK_VALUE`.

  A select is used instead of adding the mask to the logits because, per the
  upstream implementation, the addition led to a bad compiler fusion decision
  that stored the values in memory rather than just the predicate.

  Adapted from
  https://github.com/google/praxis/blob/4712a6b9ee13e224b86e235ff55f7c6bab9fbab3/praxis/py_utils.py#L706

  Args:
    logits: Tensor of logit values.
    mask: Tensor of mask values using the encoding described above.

  Returns:
    The masked logits.
  """
  keep_threshold = DEFAULT_MASK_VALUE * 0.5
  is_kept = mask >= keep_threshold
  return jnp.where(is_kept, logits, DEFAULT_MASK_VALUE)
def validate_gpu_flash_attention(sinks: Array | None, record_max_logits: bool) -> None:
  """Rejects options that the GPU flash attention path does not support.

  Args:
    sinks: Attention sinks; must be `None` on GPU.
    record_max_logits: Whether max-logit recording (QK-Clip) was requested;
      must be falsy on GPU.

  Raises:
    ValueError: If `sinks` is provided. This check runs first.
    NotImplementedError: If `record_max_logits` is truthy.
  """
  sinks_requested = sinks is not None
  if sinks_requested:
    raise ValueError("The flash attention with sinks is not supported on GPU yet.")
  if not record_max_logits:
    return
  raise NotImplementedError("record_max_logits (QK-Clip) is not supported for GPU flash attention kernels yet.")
# TODO(agagik): change splash_attention_mask._ComputableMask to be non protected
class ChunkedCausalMask(splash_attention_mask._ComputableMask):  # pylint: disable=protected-access
  """Lazily evaluated chunked causal mask for Splash Attention.

  The sequence is cut into consecutive chunks (0, K), (K, 2K), (2K, 3K), ...
  A query may attend to a key only if both fall in the same chunk and the key
  is not ahead of the query. Llama4 models interleave this chunked attention
  with global attention.

  Because the mask is expressed as a function of (query index, key index) and
  handed to `splash_attention_mask._ComputableMask`, it can be computed
  on the fly or fused into the kernel instead of materialising the full
  (sequence_length, sequence_length) boolean array, which can be prohibitive
  for long sequences.
  """

  #: The size of each attention chunk.
  chunk_size: int

  def __init__(
      self,
      shape: tuple[int, int],
      chunk_size: int,
      shard_count: int = 1,
  ):
    """Builds the mask.

    Args:
      shape: Mask shape, forwarded to the parent class.
      chunk_size: Number of positions per chunk; must be positive.
      shard_count: Forwarded to the parent class.

    Raises:
      ValueError: If `chunk_size` is not positive.
    """
    if chunk_size <= 0:
      raise ValueError("chunk_size must be positive")
    self.chunk_size = chunk_size

    def same_chunk_and_causal(q_ids, kv_ids):
      """Evaluates the chunked causal rule for the given index slices."""
      if 0 in (q_ids.size, kv_ids.size):
        # Nothing to compute for an empty slice; only the shape matters.
        return np.empty((q_ids.shape[0], kv_ids.shape[1]), dtype=np.bool_)
      # Rule 1: query and key belong to the same chunk.
      in_same_chunk = (q_ids // self.chunk_size) == (kv_ids // self.chunk_size)
      # Rule 2: the key is not ahead of the query.
      not_future = q_ids >= kv_ids
      return in_same_chunk & not_future

    # The parent class stores the function and evaluates it per slice.
    super().__init__(
        shape=shape,
        mask_function=same_chunk_and_causal,
        shard_count=shard_count,
    )

  def __eq__(self, other: object):
    """Two masks are equal when shape, chunk size and q_sequence all agree."""
    if not isinstance(other, type(self)):
      return NotImplemented
    if not self.shape == other.shape:
      return False
    if not self.chunk_size == other.chunk_size:
      return False
    return np.array_equal(self.q_sequence, other.q_sequence)

  def __hash__(self):
    """Hashes the same attributes that `__eq__` compares, plus the type."""
    q_sequence_bytes = self.q_sequence.tobytes() if self.q_sequence is not None else None
    identity = (type(self), self.shape, self.chunk_size, q_sequence_bytes)
    return hash(identity)
def _generate_chunk_attention_mask(mask_shape: tuple[int, int], chunk_size: int, q_offset: int = 0) -> jax.Array:
  """Generates an explicit boolean mask for chunked causal attention.

  A position is True when the query and key fall in the same chunk and the key
  is not ahead of the query.

  Args:
    mask_shape: The desired shape of the mask (q_seq_len, kv_seq_len).
    chunk_size: The size of the attention chunks.
    q_offset: Value added to every query row index before the chunk and
      causality tests, i.e. the absolute position of the first query row.

  Returns:
    A boolean mask of shape `mask_shape` where True indicates attention is
    allowed according to chunked causal rules, and False otherwise.

  Raises:
    ValueError: If `chunk_size` is None or not positive.
  """
  # Validate before building any arrays. The explicit None test keeps the
  # documented ValueError instead of the TypeError that `None <= 0` raises.
  if chunk_size is None or chunk_size <= 0:
    raise ValueError("chunk_size must be positive")

  row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) + q_offset
  col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)

  # chunk mask calculation
  same_chunk = (row_ids // chunk_size) == (col_ids // chunk_size)
  chunk_mask = same_chunk & (row_ids >= col_ids)
  return chunk_mask


def _make_block_mask_indices(bidirectional_mask):
  """Creates block mask identifying segments based on a bidirectional mask.

  Each contiguous run of set entries along the last axis gets its own 1-based
  id; unset entries stay 0.

  Args:
    bidirectional_mask: 2-D boolean (or 0/1) mask, e.g. [011110011010]. The
      padding spec used below fixes the rank at two.

  Returns:
    block mask for segments, e.g. [011110022030].
  """
  # Left pad 0 so a run that starts at index 0 still produces a rising edge.
  padded_mask = jnp.pad(bidirectional_mask, [(0, 0), (1, 0)], constant_values=0)
  # True exactly where a run of set entries begins.
  boundary = padded_mask[..., 1:] > padded_mask[..., :-1]
  # Running count of run starts = id of the current/most recent run.
  numbered_boundary = jnp.cumsum(boundary, axis=-1)
  # Zero out the ids at positions that are not part of any run.
  return bidirectional_mask * numbered_boundary


def _make_bidirectional_block_mask(bidirectional_mask):
  """Creates bidirectional block mask from bidirectional_mask, where True corresponds to image tokens.

  bidirectional_mask shape: [B, L]
  bidirectional_block_mask shape: [B, L, L]

  A pair (q, kv) is True when both positions belong to the same run of set
  entries in `bidirectional_mask`.

  Examples:
    bidirectional_mask = [[0, 1, 1, 1, 0, 0]]
    bidirectional_block_mask = [[
      [False, False, False, False, False, False],
      [False,  True,  True,  True, False, False],
      [False,  True,  True,  True, False, False],
      [False,  True,  True,  True, False, False],
      [False, False, False, False, False, False],
      [False, False, False, False, False, False],
    ]]
  """
  q_block_indices = _make_block_mask_indices(bidirectional_mask)
  kv_block_indices = q_block_indices
  # Same run id on both sides, and the id is a real run (> 0), not background.
  bidirectional_block_mask = (kv_block_indices[:, None, :] == q_block_indices[..., None]) & (
      q_block_indices[..., None] > 0
  )
  return bidirectional_block_mask
def attention_op_as_linen(
    *,
    config: Config,
    mesh: Mesh,
    attention_kernel: str,
    max_target_length: int,
    num_query_heads: int,
    num_kv_heads: int,
    float32_qk_product: bool = False,
    max_prefill_predict_length: int = -1,
    float32_logits: bool = False,
    flash_axis_names_q: AxisNames = (BATCH_ATTN, HEAD, LENGTH, D_KV),
    flash_axis_names_kv: AxisNames = (BATCH_ATTN, HEAD, KV_LENGTH, D_KV),
    flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
    prefill_cache_logical_axis_names: AxisNames = (
        CACHE_BATCH_PREFILL,
        CACHE_SEQUENCE,
        CACHE_HEADS,
        CACHE_KV,
    ),
    cache_logical_axis_names: AxisNames = (
        CACHE_BATCH,
        CACHE_SEQUENCE,
        CACHE_HEADS,
        CACHE_KV,
    ),
    cache_scale_logical_axis_names: AxisNames = (
        CACHE_SCALE_BATCH,
        CACHE_SCALE_SEQUENCE,
        CACHE_SCALE_HEADS,
        CACHE_SCALE_KV,
    ),
    ragged_qkv_axis_names: AxisNames = (
        CACHE_BATCH,
        CACHE_HEADS,
        CACHE_SEQUENCE,
        CACHE_KV,
    ),
    ragged_lengths_names: AxisNames = (CACHE_BATCH,),
    compute_axis_order: AxisIdxes = (0, 1, 2, 3),
    key_axis_order: AxisIdxes = (2, 0, 1, 3),
    reshape_q: bool = False,
    dropout_rate: float = 0.0,
    dtype: DType = jnp.float32,
    quant: Optional[Quant] = None,
    kv_quant: Optional[KVQuant] = None,
    attention_type: AttentionType = AttentionType.GLOBAL,  # Default to global attention
    attn_logits_soft_cap: float | None = None,
    sliding_window_size: int | None = None,
    chunk_attn_window_size: int | None = None,
    use_ragged_attention: bool = False,
    ragged_block_size: int = 256,
):
  """Builds a Linen-compatible wrapper around the NNX `AttentionOp`.

  Every keyword argument is forwarded unchanged to `AttentionOp` through
  `nnx_wrappers.to_linen`, with `variable_to_logically_partitioned` supplied
  as the `metadata_fn`. This lets a Linen model use the NNX-based attention
  op.

  Returns:
    Whatever `nnx_wrappers.to_linen` returns for `AttentionOp`.
  """
  # Gather the constructor arguments once, in the same order as the signature.
  attention_op_kwargs = dict(
      config=config,
      mesh=mesh,
      attention_kernel=attention_kernel,
      max_target_length=max_target_length,
      num_query_heads=num_query_heads,
      num_kv_heads=num_kv_heads,
      float32_qk_product=float32_qk_product,
      max_prefill_predict_length=max_prefill_predict_length,
      float32_logits=float32_logits,
      flash_axis_names_q=flash_axis_names_q,
      flash_axis_names_kv=flash_axis_names_kv,
      flash_axis_names_splash_kernel=flash_axis_names_splash_kernel,
      prefill_cache_logical_axis_names=prefill_cache_logical_axis_names,
      cache_logical_axis_names=cache_logical_axis_names,
      cache_scale_logical_axis_names=cache_scale_logical_axis_names,
      ragged_qkv_axis_names=ragged_qkv_axis_names,
      ragged_lengths_names=ragged_lengths_names,
      compute_axis_order=compute_axis_order,
      key_axis_order=key_axis_order,
      reshape_q=reshape_q,
      dropout_rate=dropout_rate,
      dtype=dtype,
      quant=quant,
      kv_quant=kv_quant,
      attention_type=attention_type,
      attn_logits_soft_cap=attn_logits_soft_cap,
      sliding_window_size=sliding_window_size,
      chunk_attn_window_size=chunk_attn_window_size,
      use_ragged_attention=use_ragged_attention,
      ragged_block_size=ragged_block_size,
  )
  return nnx_wrappers.to_linen(
      AttentionOp,
      **attention_op_kwargs,
      metadata_fn=variable_to_logically_partitioned,
  )
class AttentionOp(nnx.Module):
  """Attention operation.

  NNX module that stores the attention configuration and, when KV-cache
  quantization is enabled, pre-builds the einsum callables used for the
  query-key and weights-value products in prefill and autoregressive modes.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      attention_kernel: str,
      max_target_length: int,
      num_query_heads: int,
      num_kv_heads: int,
      float32_qk_product: bool = False,
      max_prefill_predict_length: int = -1,
      float32_logits: bool = False,
      flash_axis_names_q: AxisNames = (BATCH_ATTN, HEAD, LENGTH, D_KV),
      flash_axis_names_kv: AxisNames = (BATCH_ATTN, HEAD, KV_LENGTH, D_KV),
      flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
      prefill_cache_logical_axis_names: AxisNames = (
          CACHE_BATCH_PREFILL,
          CACHE_SEQUENCE,
          CACHE_HEADS,
          CACHE_KV,
      ),
      cache_logical_axis_names: AxisNames = (
          CACHE_BATCH,
          CACHE_SEQUENCE,
          CACHE_HEADS,
          CACHE_KV,
      ),
      cache_scale_logical_axis_names: AxisNames = (
          CACHE_SCALE_BATCH,
          CACHE_SCALE_SEQUENCE,
          CACHE_SCALE_HEADS,
          CACHE_SCALE_KV,
      ),
      ragged_qkv_axis_names: AxisNames = (
          CACHE_BATCH,
          CACHE_HEADS,
          CACHE_SEQUENCE,
          CACHE_KV,
      ),
      ragged_lengths_names: AxisNames = (CACHE_BATCH,),
      compute_axis_order: AxisIdxes = (0, 1, 2, 3),
      key_axis_order: AxisIdxes = (2, 0, 1, 3),
      reshape_q: bool = False,
      dropout_rate: float = 0.0,
      dtype: DType = jnp.float32,
      quant: Optional[Quant] = None,
      kv_quant: Optional[KVQuant] = None,
      attention_type: AttentionType = AttentionType.GLOBAL,  # Default to global attention
      attn_logits_soft_cap: float | None = None,
      sliding_window_size: int | None = None,
      chunk_attn_window_size: int | None = None,
      use_ragged_attention: bool = False,
      ragged_block_size: int = 256,
      rngs: nnx.Rngs | None = None,
  ):
    """Initializes the AttentionOp module.

    Every argument is stored on the instance under the same name. When
    `kv_quant` is truthy, four einsum callables (`AqtEinsum_0` .. `AqtEinsum_3`)
    are additionally built from it; otherwise all four are `jnp.einsum`.

    Args:
      config: The configuration for the model.
      mesh: The device mesh.
      attention_kernel: The attention kernel to use.
      max_target_length: The maximum target length. Also used as the key/value
        length of the autoregressive dummy inputs.
      num_query_heads: The number of query heads.
      num_kv_heads: The number of key/value heads.
      float32_qk_product: Whether to compute qk_product in float32.
      max_prefill_predict_length: The maximum prefill predict length. Also used
        as both the query and key/value length of the prefill dummy inputs.
      float32_logits: Whether to compute logits in float32.
      flash_axis_names_kv: The logical axis names for the KV cache in flash attention.
      flash_axis_names_q: The logical axis names for the query in flash attention.
      flash_axis_names_splash_kernel: The logical axis names for the splash attention kernel.
      prefill_cache_logical_axis_names: The logical axis names for the prefill cache.
      cache_logical_axis_names: The logical axis names for the cache.
      cache_scale_logical_axis_names: The logical axis names for the cache scale.
      ragged_qkv_axis_names: The logical axis names for ragged QKV tensors.
      ragged_lengths_names: The logical axis names for ragged lengths.
      compute_axis_order: The order of axes for computation.
      key_axis_order: The order of axes for the key.
      reshape_q: Stored as-is; its use is not visible in this part of the file.
      dropout_rate: Stored as-is; its use is not visible in this part of the file.
      dtype: Dtype of the dummy query/key/value tensors created for lazy
        initialization of the quantized einsums.
      quant: Stored as-is; its use is not visible in this part of the file.
      kv_quant: Optional KV-cache quantization helper. When truthy its
        `einsum_fn_with_rhs_qtensor()` and
        `einsum_fn_with_rhs_qtensor_and_dequant()` results become the einsum
        callables.
      attention_type: Attention pattern; defaults to `AttentionType.GLOBAL`.
      attn_logits_soft_cap: Stored as-is; its use is not visible in this part
        of the file.
      sliding_window_size: Window size for local sliding attention.
      chunk_attn_window_size: Chunk size for chunked attention.
      use_ragged_attention: Stored as-is.
      ragged_block_size: Block size handed to the ragged attention kernels.
      rngs: The random number generators for initialization, passed by the nnx.to_linen wrapper.
    """
    self.config = config
    self.mesh = mesh
    self.attention_kernel = attention_kernel
    self.max_target_length = max_target_length
    self.num_query_heads = num_query_heads
    self.num_kv_heads = num_kv_heads
    self.float32_qk_product = float32_qk_product
    self.max_prefill_predict_length = max_prefill_predict_length
    self.float32_logits = float32_logits
    self.flash_axis_names_q = flash_axis_names_q
    self.flash_axis_names_kv = flash_axis_names_kv
    self.flash_axis_names_splash_kernel = flash_axis_names_splash_kernel
    self.prefill_cache_logical_axis_names = prefill_cache_logical_axis_names
    self.cache_logical_axis_names = cache_logical_axis_names
    self.cache_scale_logical_axis_names = cache_scale_logical_axis_names
    self.ragged_qkv_axis_names = ragged_qkv_axis_names
    self.ragged_lengths_names = ragged_lengths_names
    self.compute_axis_order = compute_axis_order
    self.key_axis_order = key_axis_order
    self.reshape_q = reshape_q
    self.dropout_rate = dropout_rate
    self.dtype = dtype
    self.quant = quant
    self.kv_quant = kv_quant
    self.attention_type = attention_type
    self.attn_logits_soft_cap = attn_logits_soft_cap
    self.sliding_window_size = sliding_window_size
    self.chunk_attn_window_size = chunk_attn_window_size
    self.use_ragged_attention = use_ragged_attention
    self.ragged_block_size = ragged_block_size
    self.rngs = rngs

    def maybe_create_nnx(einsum, *args):
      # A Linen module is wrapped into NNX and lazily initialized with the
      # given example arguments (using the `rngs` passed to __init__); any
      # other callable is returned untouched.
      if isinstance(einsum, nn.Module):
        return nnx_wrappers.ToNNX(einsum, rngs=rngs).lazy_init(*args)
      return einsum

    # qk_product
    if self.kv_quant:
      # Dummy inputs for lazy initialization
      b = 1
      t_prefill = self.max_prefill_predict_length
      t_ar = 1  # Autoregressive mode has a query length of 1
      n = self.num_query_heads
      n_kv = self.num_kv_heads
      d = self.config.head_dim
      g = n // n_kv  # query heads per KV head
      s_prefill = self.max_prefill_predict_length
      s_ar = self.max_target_length

      # Dummy query/key/value tensors. Queries are laid out as
      # (b, t, n_kv, g, d), keys/values as (b, s, n_kv, d) and attention
      # weights as (b, n_kv, g, t, s), matching the einsum specs below.
      dummy_query_prefill = jnp.zeros((b, t_prefill, n_kv, g, d), dtype=self.dtype)
      dummy_key_prefill = jnp.zeros((b, s_prefill, n_kv, d), dtype=self.dtype)
      dummy_query_ar = jnp.zeros((b, t_ar, n_kv, g, d), dtype=self.dtype)
      dummy_key_ar = jnp.zeros((b, s_ar, n_kv, d), dtype=self.dtype)

      dummy_attn_weights_prefill = jnp.zeros((b, n_kv, g, t_prefill, s_prefill), dtype=jnp.float32)
      dummy_value_prefill = jnp.zeros((b, s_prefill, n_kv, d), dtype=self.dtype)
      dummy_attn_weights_ar = jnp.zeros((b, n_kv, g, t_ar, s_ar), dtype=jnp.float32)
      dummy_value_ar = jnp.zeros((b, s_ar, n_kv, d), dtype=self.dtype)

      # Prefill AqtEinsum instances
      # NOTE(review): the creation order below may matter for RNG consumption
      # during lazy init - keep it unless verified otherwise.
      self.AqtEinsum_0 = maybe_create_nnx(
          self.kv_quant.einsum_fn_with_rhs_qtensor(),
          "btkgd,bskd->bkgts",
          dummy_query_prefill,
          dummy_key_prefill,
      )
      self.AqtEinsum_1 = maybe_create_nnx(
          self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(),
          "bkgts,bskd->btkgd",
          dummy_attn_weights_prefill,
          dummy_value_prefill,
      )
      # Autoregressive AqtEinsum instances
      self.AqtEinsum_2 = maybe_create_nnx(
          self.kv_quant.einsum_fn_with_rhs_qtensor(),
          "btkgd,bskd->bkgts",
          dummy_query_ar,
          dummy_key_ar,
      )
      self.AqtEinsum_3 = maybe_create_nnx(
          self.kv_quant.einsum_fn_with_rhs_qtensor_and_dequant(),
          "bkgts,bskd->btkgd",
          dummy_attn_weights_ar,
          dummy_value_ar,
      )
    else:
      # No KV quantization: plain einsum everywhere.
      self.AqtEinsum_0 = jnp.einsum
      self.AqtEinsum_1 = jnp.einsum
      self.AqtEinsum_2 = jnp.einsum
      self.AqtEinsum_3 = jnp.einsum

  def _logical_to_mesh_axes(self, logical_name):
    """Maps logical axis names to mesh axes for this module's mesh.

    The config's `logical_axis_rules` are passed along unless
    `config.using_pipeline_parallelism` is set, in which case `rules=None` is
    passed instead.
    """
    logical_rules = None if self.config.using_pipeline_parallelism else self.config.logical_axis_rules
    return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules)
def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None:
  """Asserts that query, key and value have mutually compatible shapes.

  Checks, in order: key and value have the same rank; all three share the
  leading (batch) dimensions, i.e. everything before the last three axes; key
  and value agree on axis -2 (num_kv_heads) and axis -3 (length); query and
  key agree on the last axis (depth).

  Raises:
    AssertionError: On the first check that fails.
  """
  k_rank, v_rank = key.ndim, value.ndim
  assert k_rank == v_rank, f"k (dim {k_rank}), v (dim {v_rank}) must have same rank."

  batch_dims_agree = query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
  assert batch_dims_agree, f"{query.shape[:-3]=}, {key.shape[:-3]=}, {value.shape[:-3]=} batch dims must match."

  k_shape, v_shape = key.shape, value.shape
  assert k_shape[-2] == v_shape[-2], "k, v num_kv_heads must match."
  assert k_shape[-3] == v_shape[-3], "k, v lengths must match."
  assert query.shape[-1] == k_shape[-1], "q, k depths must match."
def _maybe_shard_with_pspec(self, inputs, pspec: jax.sharding.PartitionSpec | None):
  """Forwards `inputs` and `pspec` to `maybe_shard_with_pspec` using this module's settings.

  The mesh comes from `self.mesh`; the shard mode and the debug flag come from
  `self.config`.
  """
  cfg = self.config
  return maybe_shard_with_pspec(
      inputs,
      pspec,
      mesh=self.mesh,
      shard_mode=cfg.shard_mode,
      debug_sharding=cfg.debug_sharding,
      # NOTE(review): presumably makes debug output point at this method's
      # caller rather than at this wrapper - confirm against the helper.
      extra_stack_level=1,
  )
def generate_attention_mask(
    self,
    query,
    key,
    decoder_segment_ids: Array | None,
    model_mode: str,
    previous_chunk: Any = None,
    bidirectional_mask: Any = None,
) -> Array | None:
  """Generates a combined attention mask for Transformer models.

  The mask is assembled from up to four ingredients, depending on the inputs
  and on `self.attention_type`:

  1. **Sequence separation.** With `decoder_segment_ids`, attention is confined
     to positions sharing a segment id. In autoregressive mode the mask instead
     selects the key positions whose id equals
     `DECODING_ACTIVE_SEQUENCE_INDICATOR`.
  2. **Causality.** Enforced in every mode except autoregressive, unless the
     attention type is `AttentionType.FULL`. For chunked prefill the query rows
     are offset by the length of `previous_chunk`, as in SARATHI [2].
  3. **Specialized patterns.** `LOCAL_SLIDING` restricts each query to the last
     `sliding_window_size` key positions; `CHUNK` restricts attention to the
     query's own chunk of `chunk_attn_window_size` positions. Both are only
     applied when a mask already exists from steps 1-2.
  4. **Bidirectional blocks.** If `bidirectional_mask` is given (e.g. image
     tokens in a multimodal model), tokens inside the same marked run may
     attend to each other in both directions; this is OR-ed into the mask.

  The overall approach follows the Pallas MHA Flash Attention reference [1].

  Args:
    query: Query tensor of rank 4, `[batch, q_sequence_length, num_heads,
      head_dim]`; only the sequence length is read.
    key: Key tensor of rank 4, `[batch, kv_sequence_length, num_heads,
      head_dim]`; only the sequence length is read.
    decoder_segment_ids: Optional `[batch, sequence_length]` array of segment
      ids.
    model_mode: Operational mode, e.g. `MODEL_MODE_AUTOREGRESSIVE` or
      `MODEL_MODE_PREFILL`.
    previous_chunk: Optional previously processed chunk; only `shape[1]` is
      read and used as the absolute position of the first query row.
    bidirectional_mask: Optional `[batch, kv_sequence_length]` boolean array
      marking tokens that may attend bidirectionally within their run.

  Returns:
    An additive mask broadcastable to `[batch, num_kv_heads, group_size,
    q_sequence_length, kv_sequence_length]` holding `0.0` where attention is
    allowed and `DEFAULT_MASK_VALUE` where it is not, or `None` when no
    masking is needed.

  Raises:
    ValueError: If the attention type is `LOCAL_SLIDING` but
      `sliding_window_size` is unset, or `CHUNK` but `chunk_attn_window_size`
      is unset.

  References:
    [1] JAX Pallas MHA Flash Attention:
        https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py
    [2] SARATHI: Efficient LLM Inference by Piggybacking Decodes with Chunked
        Prefills - ArXiv:2308.16369 (https://arxiv.org/abs/2308.16369)
  """
  # Step 1: segment mask.
  mask = None
  if model_mode == MODEL_MODE_AUTOREGRESSIVE and decoder_segment_ids is not None:
    mask = decoder_segment_ids[:, None, None, None, :] == DECODING_ACTIVE_SEQUENCE_INDICATOR
  elif decoder_segment_ids is not None:
    mask = decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :]
    mask = mask[:, None, None, :, :]

  _, q_seq_len, _, _ = query.shape
  _, kv_seq_len, _, _ = key.shape

  # Absolute position of the first query row within the KV sequence.
  next_pos = 0
  if previous_chunk is not None:
    next_pos = previous_chunk.shape[1]
    if mask is not None:
      # Keep only the query rows that belong to the current chunk.
      mask = mask[:, :, :, next_pos : next_pos + q_seq_len, :]
  elif model_mode == MODEL_MODE_AUTOREGRESSIVE and q_seq_len == 1:
    # In autoregression, the query position is the last position in the KV sequence.
    next_pos = kv_seq_len - 1

  # Step 2: causal mask. We enforce causality except for AUTOREGRESSION.
  causal_mask = None
  if model_mode != MODEL_MODE_AUTOREGRESSIVE and self.attention_type != AttentionType.FULL:
    mask_shape = (q_seq_len, kv_seq_len)
    # row_ids indicates the position of query
    # col_ids indicates the position of kv
    row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
    col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
    # Attention mask for chunked prefill is generated in the same way
    # as mentioned in SARATHI - https://arxiv.org/abs/2308.16369
    causal_mask = (col_ids <= row_ids + next_pos)[None, None, None, :, :]

  output_mask = None
  if (mask is not None) and (causal_mask is not None):
    output_mask = jnp.logical_and(mask, causal_mask)
  elif mask is not None:
    output_mask = mask
  elif causal_mask is not None:
    output_mask = causal_mask

  # Step 3: attention-type specific restriction.
  if self.attention_type == AttentionType.LOCAL_SLIDING and output_mask is not None:
    if self.sliding_window_size is None:
      raise ValueError("Sliding_window_size must be set if Local Sliding attention type")

    row_ids_sliding = jax.lax.broadcasted_iota(jnp.int32, (q_seq_len, 1), 0) + next_pos
    col_ids_sliding = jax.lax.broadcasted_iota(jnp.int32, (1, kv_seq_len), 1)
    sliding_mask = (col_ids_sliding > (row_ids_sliding - self.sliding_window_size)) & (
        col_ids_sliding <= row_ids_sliding
    )
    output_mask = sliding_mask * output_mask
  elif self.attention_type == AttentionType.CHUNK and output_mask is not None:
    # Mirror the sliding-window branch: fail with a clear ValueError instead
    # of a TypeError from comparing None inside the helper.
    if self.chunk_attn_window_size is None:
      raise ValueError("chunk_attn_window_size must be set if Chunk attention type")
    mask_shape = (q_seq_len, kv_seq_len)
    chunk_mask = _generate_chunk_attention_mask(
        mask_shape=mask_shape,
        chunk_size=self.chunk_attn_window_size,
        q_offset=next_pos,
    )
    output_mask = chunk_mask * output_mask

  # Step 4: bidirectional blocks. If nothing is masked so far (output_mask is
  # None) OR-ing in extra allowed positions changes nothing, so the mask stays
  # None instead of failing on `None | array`.
  if bidirectional_mask is not None and output_mask is not None:
    image_mask = _make_bidirectional_block_mask(bidirectional_mask)
    output_mask = output_mask | image_mask[:, None, None, ...]

  if output_mask is None:
    return None
  return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE)
def calculate_moba_gate_logic(self, q_item, k_item, q_pos_item):
  """Computes the block-level MoBA gating intermediates for one batch item.

  Keys are mean-pooled per chunk of `config.moba_chunk_size` positions, each
  query is scored against every pooled chunk, blocks after the query's own
  block are excluded, the query's own block is forced in, and the
  `config.moba_topk` best-scoring blocks are selected.

  Args:
    q_item: Query tensor shaped `[q_len, n_q_heads, head_dim]`.
    k_item: Key tensor shaped `[kv_len, n_kv_heads, head_dim]`.
    q_pos_item: Absolute query positions shaped `[q_len]`, used to derive the
      chunk index of each query. For example, during prefill after 128 tokens
      have been processed it is `jnp.arange(128, 128 + q_len)`, while in
      autoregressive decode with a single query token it is
      `jnp.array([kv_len - 1])`.

  Returns:
    A 9-tuple whose last element, `need_attend`, is a boolean mask of shape
    `[n_kv_heads, g, q_len, num_block]` telling which key blocks each query
    should attend to. The preceding elements are debug intermediates, in
    order: pooled keys, scores before masking, scores after masking, top-k
    values, top-k indices, the smallest top-k value, the threshold mask and
    the index mask.
  """
  chunk = self.config.moba_chunk_size
  topk = self.config.moba_topk

  q_len, n_q_heads, head_dim = q_item.shape
  kv_len, n_kv_heads, _ = k_item.shape
  group = n_q_heads // n_kv_heads
  num_block = math.ceil(kv_len / chunk)

  # Split the query heads into (kv head, group member) for grouped-query attention.
  q_grouped = q_item.astype(jnp.float32).reshape(q_len, n_kv_heads, group, head_dim)

  # Chunk id of every key position.
  key_chunk_ids = jnp.arange(kv_len, dtype=jnp.int32) // chunk

  # Mean-pool the keys of each chunk: per-chunk sum divided by per-chunk count.
  pooled_sum = jax.ops.segment_sum(
      k_item.astype(jnp.float32), key_chunk_ids, num_segments=num_block
  )  # [num_block, n_kv_heads, head_dim]
  tokens_per_block = jax.ops.segment_sum(
      jnp.ones((kv_len,), dtype=jnp.float32),
      key_chunk_ids,
      num_segments=num_block,
  )  # [num_block]
  # Clamp the count to 1 so an empty block cannot divide by zero.
  pooled_keys = pooled_sum / jnp.maximum(tokens_per_block[:, None, None], 1)

  # Score of every query against every pooled chunk.
  raw_scores = jnp.einsum("skgd,Nkd->kgsN", q_grouped, pooled_keys)  # [n_kv_heads, g, q_len, num_block]

  query_chunk = jnp.expand_dims(q_pos_item // chunk, axis=-1)  # [q_len, 1]
  all_chunks = jnp.arange(num_block)[None, :]  # [1, num_block]

  # Block-causal masking: only blocks strictly before the query's block keep
  # their score; the query's own block is then forced to +inf so it is
  # always picked.
  strictly_earlier = query_chunk > all_chunks
  own_block = query_chunk == all_chunks
  masked_scores = jnp.where(strictly_earlier, raw_scores, -float("inf"))
  masked_scores = jnp.where(own_block, float("inf"), masked_scores)

  n_select = min(topk, num_block)
  top_vals, top_idx = jax.lax.top_k(masked_scores, k=n_select)  # [n_kv_heads, g, q_len, n_select]
  smallest_top_val = jnp.min(top_vals, axis=-1, keepdims=True)  # [n_kv_heads, g, q_len, 1]
  above_threshold = masked_scores >= smallest_top_val  # [n_kv_heads, g, q_len, num_block]

  # Tie-breaking: when several blocks share the k-th score, keep only those
  # whose index actually appears among the top-k indices.
  picked_by_index = jnp.sum(
      jax.nn.one_hot(top_idx, num_block, dtype=jnp.bool_), axis=-2
  )  # [n_kv_heads, g, q_len, num_block]
  need_attend = jnp.logical_and(above_threshold, picked_by_index)

  return (
      pooled_keys,
      raw_scores,
      masked_scores,
      top_vals,
      top_idx,
      smallest_top_val,
      above_threshold,
      picked_by_index,
      need_attend,  # [n_kv_heads, g, q_len, num_block]
  )
def generate_moba_mask_single_item(self, q_item, k_item, q_positions):
  """Builds the token-level MoBA additive mask for one batch item.

  The block-level selection from `calculate_moba_gate_logic` is expanded to
  individual key positions and intersected with a per-token causal rule.

  Args:
    q_item: Rank-3 query tensor; its first axis is the query length.
    k_item: Rank-3 key tensor; its first axis is the key length.
    q_positions: Absolute position of each query, shaped `[q_len]`.

  Returns:
    A mask holding `0.0` where attention is allowed and `-inf` elsewhere, with
    the block axis of `need_attend` replaced by the key length.
  """
  q_len, _q_heads, _q_dim = q_item.shape
  kv_len, _kv_heads, _kv_dim = k_item.shape
  chunk = self.config.moba_chunk_size

  # The block-level decision is the last element of the gating result.
  block_need_attend = self.calculate_moba_gate_logic(q_item, k_item, q_positions)[-1]

  # Every key position inherits the decision made for its block.
  kv_positions = jnp.arange(kv_len, dtype=jnp.int32)
  token_need_attend = block_need_attend[..., kv_positions // chunk]

  # Per-token causality keeps the selected blocks causal inside a chunk too.
  key_index_grid = jax.lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 1)
  not_future = q_positions[:, None] >= key_index_grid

  allowed = jnp.logical_and(token_need_attend, not_future)
  return jnp.where(allowed, 0.0, -float("inf"))
def _generate_moba_mask(self, query: Array, key: Array, q_positions: Array) -> Array: """Builds the token-level MoBA additive mask for the whole batch. Args: query: Query tensor shaped `[batch, q_len, n_q_heads, head_dim]`. key: Key tensor shaped `[batch, kv_len, n_kv_heads, head_dim]`. q_positions: Absolute query positions shaped `[q_len]`, shared across the batch, identifying the starting offset of each query token. For example, in prefill after 128 tokens we pass `jnp.arange(128, 128 + q_len)`, while in autoregressive decode with a single new token the vector is `[kv_len - 1]` for each batch element. Returns: Additive attention mask with shape `[batch, n_kv_heads, n_q_heads // n_kv_heads, q_len, kv_len]` containing `0.` for permitted positions and `-inf` for masked ones. """ # vmap over the batch dimension of query and key. q_positions is constant across the batch. moba_mask = jax.vmap(self.generate_moba_mask_single_item, in_axes=(0, 0, None))(query, key, q_positions) return moba_mask
def apply_attention(
    self,
    query: Array,
    key: Array | KVTensor,
    value: Array | KVTensor,
    decoder_segment_ids: Array | None,
    segment_positions: Array | None,
    lengths: Array | None,
    model_mode: str,
    use_ragged_attention: bool = False,
    previous_chunk: Any = None,
    bidirectional_mask: Any = None,
    sinks: Array | None = None,
    indexer_mask: Array | None = None,
    record_max_logits: bool = False,
    *,
    qk_product_einsum: Callable[..., Array],
    wv_product_einsum: Callable[..., Array],
):
  """Routes an attention call to the kernel selected by config, mode and hardware.

  Dispatch order:
    1. Ragged decode (`use_ragged_attention` in autoregressive mode), on TPU or GPU.
    2. Dot-product attention for the `dot_product`, `paged` and `vllm_rpa`
       kernels, and for `autoselected` when decoding or when
       `query.shape[-3] < 128`.
    3. Flash attention (`flash`, or the remaining `autoselected` cases): the
       TPU flash kernel on TPU; otherwise the Pallas GPU MHA kernel, with a
       dot-product fallback for decode.
    4. `cudnn_flash_te` and `cudnn_flash_jax`.

  Args:
    query: Query tensor.
    key: Key tensor, possibly a quantized `KVTensor`.
    value: Value tensor, possibly a quantized `KVTensor`.
    decoder_segment_ids: Optional segment ids; also used to derive `lengths`
      for ragged attention when `lengths` is None.
    segment_positions: Optional positions, used only by `cudnn_flash_te`.
    lengths: Optional per-sequence lengths for ragged attention.
    model_mode: One of the `MODEL_MODE_*` constants.
    use_ragged_attention: Enables the ragged decode path.
    previous_chunk: Forwarded to dot-product attention.
    bidirectional_mask: Forwarded to dot-product attention.
    sinks: Optional attention-sink logits.
    indexer_mask: Optional mask forwarded to the dot-product and TPU flash paths.
    record_max_logits: Whether to record max logits in `self.max_logits`.
    qk_product_einsum: Einsum used for the query-key product.
    wv_product_einsum: Einsum used for the weights-value product.

  Returns:
    A 3-tuple whose first element is the attention output. The other two are
    kernel dependent: `(local_max, local_sum)` for ragged attention,
    `(None, None)` for the flash and `cudnn_flash_te` paths, and the second
    value returned by `cudnn_jax_flash_attention` followed by `None` for
    `cudnn_flash_jax`. The dot-product path returns whatever
    `apply_attention_dot` returns.

  Raises:
    NotImplementedError: Ragged attention requested on hardware other than TPU/GPU.
    ValueError: Decode requested with TPU flash or `cudnn_flash_te`, query
      heads not divisible by KV heads on the GPU flash path, or an unknown
      attention kernel.
  """
  self.check_attention_inputs(query, key, value)
  length = query.shape[-3]
  target_hardware = self.mesh.devices[(0,) * self.mesh.devices.ndim].platform

  if use_ragged_attention and model_mode == MODEL_MODE_AUTOREGRESSIVE:
    if lengths is None:
      lengths = jnp.sum(decoder_segment_ids, axis=-1)

    if target_hardware == "tpu":
      impl = self.tpu_ragged_attention
    elif target_hardware == "gpu":
      impl = self.gpu_ragged_attention
    else:
      raise NotImplementedError(target_hardware)
    local_out, local_max, local_sum = impl(query, key, value, lengths, self.ragged_block_size)
    if record_max_logits:
      self.max_logits = nnx.Intermediate(local_max)
    return local_out, local_max, local_sum

  # 'vllm_rpa' uses the same dot-attention wrapper but routes to the vLLM
  # ragged paged attention kernel in `Attention.__call__`.
  elif (
      self.attention_kernel == "dot_product"
      or (self.attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
      or (self.attention_kernel == "autoselected" and length < 128)
      or (self.attention_kernel == "paged")
      or (self.attention_kernel == "vllm_rpa")
  ):
    return self.apply_attention_dot(
        query,
        key,
        value,
        decoder_segment_ids,
        model_mode,
        previous_chunk,
        bidirectional_mask=bidirectional_mask,
        sinks=sinks,
        indexer_mask=indexer_mask,
        record_max_logits=record_max_logits,
        qk_product_einsum=qk_product_einsum,
        wv_product_einsum=wv_product_einsum,
    )
  elif self.attention_kernel in ("flash", "autoselected"):
    if target_hardware == "tpu":
      if isinstance(key, KVTensor):
        key = key.dequant()
      if isinstance(value, KVTensor):
        value = value.dequant()

      if model_mode == MODEL_MODE_AUTOREGRESSIVE:
        raise ValueError(
            """Decode not supported with flash attention. Use `dot_product` instead."""
        )
      # Forward `indexer_mask` as well: `tpu_flash_attention` accepts it and has a
      # dedicated dynamic-mask path that would otherwise never be reached from here.
      out, max_logits = self.tpu_flash_attention(
          query,
          key,
          value,
          decoder_segment_ids,
          self.attn_logits_soft_cap,
          sinks,
          indexer_mask=indexer_mask,
          record_max_logits=record_max_logits,
      )
      if max_logits is not None:
        self.max_logits = nnx.Intermediate(max_logits)
      return out, None, None
    else:
      if model_mode == MODEL_MODE_AUTOREGRESSIVE:
        # fallback to dot_product as pallas gpu flash attention doesn't support decode stage.
        # Pass the same optional inputs as the primary dot-product branch so the
        # fallback behaves exactly like selecting `dot_product` directly.
        return self.apply_attention_dot(
            query,
            key,
            value,
            decoder_segment_ids,
            model_mode,
            previous_chunk,
            bidirectional_mask=bidirectional_mask,
            sinks=sinks,
            indexer_mask=indexer_mask,
            record_max_logits=record_max_logits,
            qk_product_einsum=qk_product_einsum,
            wv_product_einsum=wv_product_einsum,
        )
      else:
        validate_gpu_flash_attention(sinks, record_max_logits)
        head_axis = -2
        num_query_heads = query.shape[head_axis]
        num_kv_heads = key.shape[head_axis]
        if num_query_heads != num_kv_heads:
          # Handle cases where the number of query heads is different from the number of key/value heads.
          if num_query_heads % num_kv_heads != 0:
            raise ValueError(
                f"Number of query heads ({num_query_heads}) must be divisible"
                f" by number of key/value heads ({num_kv_heads})."
            )
          # TODO Investigate if the KV copy can be eliminated. It's likely redundant.
          q_heads_per_kv_head = num_query_heads // num_kv_heads

          # Repeat each KV head so the head axis grows from num_kv_heads to num_query_heads.
          key = jnp.repeat(key, q_heads_per_kv_head, axis=head_axis)
          value = jnp.repeat(value, q_heads_per_kv_head, axis=head_axis)
        out = gpu_pallas_attention.mha(query, key, value, decoder_segment_ids, sm_scale=1.0, causal=True)
        return out, None, None
  elif self.attention_kernel == "cudnn_flash_te":
    validate_gpu_flash_attention(sinks, record_max_logits)
    if isinstance(key, KVTensor):
      key = key.dequant()
    if isinstance(value, KVTensor):
      value = value.dequant()
    if model_mode == MODEL_MODE_AUTOREGRESSIVE:
      raise ValueError(
          """Decode not supported with flash attention. Use `dot_product` instead."""
      )
    return (
        self.cudnn_flash_attention(query, key, value, decoder_segment_ids, segment_positions, model_mode),
        None,
        None,
    )
  elif self.attention_kernel == "cudnn_flash_jax":
    validate_gpu_flash_attention(sinks, record_max_logits)
    if isinstance(key, KVTensor):
      key = key.dequant()
    if isinstance(value, KVTensor):
      value = value.dequant()
    return (
        *self.cudnn_jax_flash_attention(query, key, value, decoder_segment_ids, model_mode),
        None,
    )
  else:
    raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")
def gpu_ragged_attention(
    self,
    q: Array,
    k: Array | KVTensor,
    v: Array | KVTensor,
    lengths: Array,
    block_size: int,
):
  """Runs the Pallas GPU decode (GQA) kernel for ragged attention under `shard_map`.

  Args:
    q: Rank-4 query tensor. Axis 1 is squeezed away before the kernel runs, so
      it must have size 1.
    k: Key tensor, sharded with the cache logical axis names.
    v: Value tensor, sharded like `k`.
    lengths: Per-sequence lengths, passed to the kernel as `kv_seq_len`.
    block_size: Passed to the kernel as `block_k`.

  Returns:
    `(local_out, local_max, local_sum)`: the kernel output (computed with
    `normalize_output=False`) reshaped back to `q`'s shape, and the max / sum
    residuals reshaped to `q.shape[:3] + (1,)`.
  """
  batch_size, q_length, q_heads, head_dim = q.shape
  output_shape = (batch_size, q_length, q_heads, head_dim)
  residual_shape = (batch_size, q_length, q_heads, 1)

  # The gqa kernel expects queries without the length axis.
  q_for_gqa = q.squeeze(axis=1)

  # Resolve the logical axis names to mesh axes once, up front.
  lengths_spec = self._logical_to_mesh_axes(self.ragged_lengths_names)
  cache_spec = self._logical_to_mesh_axes(self.cache_logical_axis_names)
  output_spec = self._logical_to_mesh_axes((CACHE_BATCH, CACHE_HEADS, CACHE_KV))
  residual_spec = self._logical_to_mesh_axes((CACHE_BATCH, CACHE_HEADS))

  def wrap_ragged_attention(
      q: Array, k: Array, v: Array, lengths: Array, block_size: int
  ) -> Tuple[Array, Array, Array]:
    """Calls the gqa decode kernel on one shard and reorders its residuals.

    Args:
      q: Query tensor.
      k: Key tensor.
      v: Value tensor.
      lengths: Sequence lengths.
      block_size: Block size for attention.

    Returns:
      A tuple containing the output, max, and sum tensors.
    """
    # The kernel returns its residuals as (sum, max); callers expect (max, sum).
    shard_out, (shard_sum, shard_max) = gpu_pallas_decode_attention.gqa(
        q=q,
        k=k,
        v=v,
        kv_seq_len=lengths,
        block_k=block_size,
        sm_scale=1.0,
        return_residuals=True,
        normalize_output=False,
    )
    return shard_out, shard_max, shard_sum

  sharded_attention = jax.shard_map(
      wrap_ragged_attention,
      mesh=self.mesh,
      in_specs=(output_spec, cache_spec, cache_spec, lengths_spec, None),
      out_specs=(output_spec, residual_spec, residual_spec),
      check_vma=False,
  )
  local_out, local_max, local_sum = sharded_attention(q_for_gqa, k, v, lengths, block_size)

  # Restore the shapes expected by the rest of MaxText.
  return (
      local_out.reshape(output_shape),
      local_max.reshape(residual_shape),
      local_sum.reshape(residual_shape),
  )
def tpu_ragged_attention(
    self,
    query: Array,
    key: Array | KVTensor,
    value: Array | KVTensor,
    lengths: Array,
    block_size: int,
) -> tuple[Array, Array, Array]:
  """Runs ragged attention on TPU under `shard_map`.

  Chooses `ragged_mha` when query and key have the same size on axis -2 and
  `ragged_gqa` otherwise.

  Args:
    query: Query tensor.
    key: Key tensor; must not be a quantized `KVTensor`.
    value: Value tensor; must not be a quantized `KVTensor`.
    lengths: Per-sequence lengths forwarded to the ragged kernel.
    block_size: Block size forwarded to the ragged kernel.

  Returns:
    Whatever the selected ragged kernel returns, sharded with the cache
    logical axis names.

  Raises:
    TypeError: If any of `query`, `key` or `value` is a `KVTensor`.
  """
  # Check all three inputs: `key` and `value` are the arguments that can arrive
  # quantized, so inspecting `query` alone would let a KVTensor slip through.
  if any(isinstance(tensor, KVTensor) for tensor in (query, key, value)):
    raise TypeError("Ragged attention does not currently support quantized tensors.")

  b = self._logical_to_mesh_axes(self.ragged_lengths_names)
  bsnd = self._logical_to_mesh_axes(self.cache_logical_axis_names)

  @functools.partial(
      jax.shard_map,
      mesh=self.mesh,
      in_specs=(
          bsnd,
          bsnd,
          bsnd,
          b,
          None,
      ),
      out_specs=bsnd,
      check_vma=False,
  )
  def wrap_ragged_attention(query, key, value, lengths, block_size):
    # Same size on axis -2 for query and key -> MHA kernel, otherwise GQA kernel.
    if query.shape[-2] == key.shape[-2]:
      return ragged_mha(query, key, value, lengths, block_size=block_size)
    else:
      return ragged_gqa(query, key, value, lengths, block_size=block_size)

  return wrap_ragged_attention(query, key, value, lengths, block_size)
def tpu_flash_attention(
    self,
    query: Array,
    key: Array,
    value: Array,
    decoder_segment_ids: Array | None,
    attn_logits_soft_cap: float | None = None,
    sinks: Array | None = None,
    indexer_mask: Array | None = None,
    record_max_logits: bool = False,
) -> tuple[Array, Array | None]:
  """Runs splash (flash) attention on TPU under `shard_map`.

  The method (1) transposes q/k/v with axes `(0, 2, 1, 3)`, (2) copies the
  `sa_*` block-size settings from the config into module-level globals and
  builds the kernel config, (3) builds the attention mask from
  `self.attention_type` and the context-parallel settings, (4) builds the
  splash kernel for one of three backends (tokamax, native JAX splash, or the
  legacy Pallas splash kernel), and (5) calls the kernel inside a `shard_map`.

  Args:
    query: Query tensor; transposed with axes `(0, 2, 1, 3)` on entry.
    key: Key tensor; transposed like `query`.
    value: Value tensor; transposed like `query`.
    decoder_segment_ids: Optional segment ids used for both the q and kv side.
    attn_logits_soft_cap: Optional soft cap passed to the kernel config.
    sinks: Optional sink logits, sharded along the head axis.
    indexer_mask: Optional mask; only used on the tokamax path when
      `config.use_indexer` is set.
    record_max_logits: Whether to also return the per-head max logits. Only
      supported on the tokamax path.

  Returns:
    `(output, max_logits)`. `output` is transposed back with axes
    `(0, 2, 1, 3)`. `max_logits` is the kernel's max logits reduced over
    axis 2 when `record_max_logits` is True, otherwise None.

  Raises:
    AssertionError: If the batch size is not divisible by the product of the
      `data` and `fsdp` mesh axis sizes.
    ValueError: If a sliding-window or chunk attention type is selected
      without its window size.
    NotImplementedError: If `record_max_logits` is requested on the native JAX
      splash or legacy splash path.
  """
  cp_size = self.config.context_parallel_size
  load_balanced_context_parallel = self.config.context_parallel_load_balance

  # Transpose to ('batch', 'heads', 'length', 'kv')
  query = jnp.transpose(query, axes=(0, 2, 1, 3))
  key = jnp.transpose(key, axes=(0, 2, 1, 3))
  value = jnp.transpose(value, axes=(0, 2, 1, 3))

  # Segment-id specs stay None when no segment ids are supplied.
  segment_axis_names_q = None
  segment_axis_names_kv = None
  sink_axis_names = self._logical_to_mesh_axes((HEAD,))
  if decoder_segment_ids is not None:
    segment_axis_names_q = self._logical_to_mesh_axes((BATCH_ATTN, Q_LENGTH))
    segment_axis_names_kv = self._logical_to_mesh_axes((BATCH_ATTN, KV_LENGTH))

  axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
  axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
  axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
  indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH_ATTN, Q_LENGTH, KV_LENGTH))

  # The block-size settings are published as module-level globals; `create_sa_config`
  # below reads them from there. Note this mutates module state on every call.
  global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
  global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
  global global_q_layout, global_k_layout, global_v_layout
  global_block_q = self.config.sa_block_q
  global_block_kv = self.config.sa_block_kv
  global_block_kv_compute = self.config.sa_block_kv_compute
  global_block_q_dkv = self.config.sa_block_q_dkv
  global_block_kv_dkv = self.config.sa_block_kv_dkv
  global_block_kv_dkv_compute = self.config.sa_block_kv_dkv_compute
  global_block_q_dq = self.config.sa_block_q_dq
  global_block_kv_dq = self.config.sa_block_kv_dq
  global_use_fused_bwd_kernel = self.config.sa_use_fused_bwd_kernel
  global_q_layout = self.config.sa_q_layout
  global_k_layout = self.config.sa_k_layout
  global_v_layout = self.config.sa_v_layout

  devices_in_data_fsdp = self.mesh.shape.get("data", 1) * self.mesh.shape.get("fsdp", 1)
  assert (query.shape[0] / devices_in_data_fsdp).is_integer(), (
      "Batch dimension should be shardable among the devices in data and fsdp"
      " axis"
      f" got {query.shape[0]=}/{devices_in_data_fsdp=}"
  )

  # create_splash_attention config
  def create_sa_config(config, query, key, attn_logits_soft_cap):
    """Builds the kernel block-size config, clamping each block to the sequence length."""
    if config.use_tokamax_splash:
      sa_config = tokamax_splash_kernel.SplashConfig(
          block_q=min(global_block_q, query.shape[2]),
          block_kv=min(global_block_kv, key.shape[2]),
          block_kv_compute=min(global_block_kv_compute, key.shape[2]),
          block_q_dkv=min(global_block_q_dkv, query.shape[2]),
          block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
          # NOTE(review): this KV block size is clamped by the query length, not the
          # key length - confirm whether that is intentional when q_len != kv_len.
          block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
          use_fused_bwd_kernel=True,  # tokamax only supports fused bwd kernel
          q_layout=tokamax_splash_kernel.QKVLayout[global_q_layout],
          k_layout=tokamax_splash_kernel.QKVLayout[global_k_layout],
          v_layout=tokamax_splash_kernel.QKVLayout[global_v_layout],
          attn_logits_soft_cap=attn_logits_soft_cap,
          residual_checkpoint_name="context",
          # A negative flops setting disables the corresponding cost estimate.
          fwd_cost_estimate=pl.CostEstimate(
              flops=config.cost_estimate_flops_fwd,
              transcendentals=0,
              bytes_accessed=0,
          )
          if config.cost_estimate_flops_fwd >= 0
          else None,
          bwd_cost_estimate=pl.CostEstimate(
              flops=config.cost_estimate_flops_bwd,
              transcendentals=0,
              bytes_accessed=0,
          )
          if config.cost_estimate_flops_bwd >= 0
          else None,
          dq_reduction_steps=config.dq_reduction_steps if config.dq_reduction_steps > 0 else None,
          use_experimental_scheduler=config.use_splash_scheduler,
      )
    else:
      sa_config = splash_attention_kernel.BlockSizes(
          block_q=min(global_block_q, query.shape[2]),
          block_kv=min(global_block_kv, key.shape[2]),
          block_kv_compute=min(global_block_kv_compute, key.shape[2]),
          block_q_dkv=min(global_block_q_dkv, query.shape[2]),
          block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
          # NOTE(review): as above, the two KV block sizes below are clamped by the
          # query length - confirm.
          block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
          # The separate dq block sizes are only set when the fused backward kernel is off.
          block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
          block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
          use_fused_bwd_kernel=global_use_fused_bwd_kernel,
          q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
          k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
          v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
      )
    return sa_config

  sa_config = create_sa_config(self.config, query, key, attn_logits_soft_cap)
  mask_shape = (query.shape[2], key.shape[2])  # (q_seq_len, kv_seq_len)
  mask_module = tokamax_splash_mask if self.config.use_tokamax_splash else splash_attention_mask
  if self.attention_type == AttentionType.FULL:
    mask = mask_module.FullMask(mask_shape)
  else:
    mask = mask_module.CausalMask(shape=mask_shape)

  # Create LoadBalancedCausalMask if cp and load_balancing.
  # This replaces whichever mask was chosen above, including the full mask.
  if cp_size > 1 and load_balanced_context_parallel:
    mask = LoadBalancedCausalMask(shape=mask_shape, cp_size=cp_size)

  # TODO: figure out local_sliding attention + load_balancing, default is global
  # Apply local masking if local sliding attention is enabled.
  if self.attention_type == AttentionType.LOCAL_SLIDING:
    if self.sliding_window_size is None:
      raise ValueError("Sliding_window_size must be set if Local Sliding attention type")
    mask &= mask_module.LocalMask(
        shape=(query.shape[2], key.shape[2]),
        window_size=(self.sliding_window_size, self.sliding_window_size),
        offset=0,
    )
  elif self.attention_type == AttentionType.CHUNK:
    if self.chunk_attn_window_size is None:
      raise ValueError("chunk_attn_window_size must be set for chunk attention type")
    mask &= ChunkedCausalMask(
        shape=(query.shape[2], key.shape[2]),
        chunk_size=self.chunk_attn_window_size,
    )

  # Never reassigned below, so the tokamax kernels are always called with max_logit_value=None.
  max_logit_value = None
  if self.config.use_tokamax_splash:
    # Create mask
    single_head_mask = mask  # tokamax now just uses a single mask and assumes broadcast to all heads
    if self.config.use_max_logit_estimate > 0:
      sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)

    # Create the splash attention kernel object separately, jit it for performance
    @partial(
        jax.jit,
        static_argnames=[
            "single_head_mask",
        ],
    )
    def wrap_splash_kernel(single_head_mask):
      splash_kernel = tokamax_splash_kernel.make_splash_mha(
          mask=single_head_mask,
          config=sa_config,
          q_seq_shards=cp_size,  # axis for sequence sharding,
      )
      return splash_kernel

    splash_kernel = wrap_splash_kernel(single_head_mask)
    segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
    splash_kernel = self._maybe_shard_with_pspec(splash_kernel, segment_axis_names_splash_kernel)
  elif self.config.use_jax_splash:
    # No kernel object is built here; the native JAX implementation is called
    # directly inside `wrap_flash_attention` and receives None for the kernel.
    if self.config.use_max_logit_estimate > 0:
      sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)
    segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,))
  else:
    # Create multi-head mask
    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])

    # Create the splash attention kernel object separately, jit it for performance
    @partial(
        jax.jit,
        static_argnames=[
            "multi_head_mask",
            "shard_head_size",
        ],
    )
    def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
      splash_kernel = splash_attention_kernel.make_splash_mha(
          mask=multi_head_mask,
          head_shards=shard_head_size,  # the size of the axis if sharding over heads
          q_seq_shards=cp_size,  # axis for sequence sharding
          block_sizes=sa_config,
          attn_logits_soft_cap=attn_logits_soft_cap,
          residual_checkpoint_name="context",
      )
      return splash_kernel

    # Number of head shards = product of the sizes of the mesh axes the HEAD axis maps to.
    head_physical_axes = logical_to_mesh_axes((HEAD,), self.mesh)[0]
    head_physical_axes = (head_physical_axes,) if isinstance(head_physical_axes, str) else (head_physical_axes or ())
    shard_head_size = math.prod(self.mesh.shape.get(ax, 1) for ax in head_physical_axes)
    splash_kernel = wrap_splash_kernel(multi_head_mask, shard_head_size)
    named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel)
    segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
    # Shard every non-None leaf of the kernel pytree with its matching spec.
    splash_kernel = jax.tree.map(
        lambda arr, spec: None if arr is None else self._maybe_shard_with_pspec(arr, spec),
        splash_kernel,
        segment_axis_names_splash_kernel,
        is_leaf=lambda x: x is None,
    )

  # Now call the function wrap_flash_attention which does the actual computation.
  # The splash kernel is passed as a parameter to the function. Since we have the shard map
  # decorating the wrap_flash_attention function, the data will be correctly sharded
  # meaning q will be sharded over sequence aka context length but K and V will be duplicated
  # The shardings are specified in the in_specs and out_specs of the shard_map decorator:
  # 'segment_axis_names_q' maps to ['activation_q_length', ['context']] meaning that q is sharded over the context axis
  # 'segment_axis_names_kv' maps to ['activation_kv_length', []] meaning that K and V are not sharded
  # splash_kernel is sharded over (HEAD, LENGTH)
  if record_max_logits:
    # max_logits will share similar sharding as query but last dim is unrelated to model
    # Using None for the last dimension of max_logits sharding
    if isinstance(axis_names_q, jax.sharding.PartitionSpec):
      # max_logits is Rank 3 (batch, heads, seq), so drop the last dimension (d_kv)
      max_logits_spec = jax.sharding.PartitionSpec(*axis_names_q[:-1])
    else:
      # Fallback if axis_names_q is not a PartitionSpec (unlikely in this context)
      max_logits_spec = axis_names_q[:-1]
    out_specs = (axis_names_q, max_logits_spec)
  else:
    # Safely pad the specs with None to match the (attention_output, None) return type
    out_specs = (axis_names_q, None)

  @functools.partial(
      jax.shard_map,
      mesh=self.mesh,
      in_specs=(
          axis_names_q,
          axis_names_kv,
          axis_names_kv,
          segment_axis_names_q,
          segment_axis_names_kv,
          None,  # no sharding for config
          segment_axis_names_splash_kernel,
          None,  # no sharding for cp_size
          None,  # no sharding for load_balanced_context_parallel
          sink_axis_names,  # sharding align with query heads
          indexer_mask_axis_names,
      ),
      out_specs=out_specs,
      check_vma=False,
  )
  def wrap_flash_attention(
      query,
      key,
      value,
      decoder_segment_ids_q,
      decoder_segment_ids_kv,
      sa_config,
      splash_kernel,
      cp_size,
      load_balanced_context_parallel,
      sinks,
      indexer_mask,
  ):
    """Per-shard attention call; returns `(attention_output, max_logits_or_None)`."""
    # If load_balanced_context_parallel is enabled, reorder the key and value tensors
    # to ensure that they are contiguous in memory.
    # This is necessary for the splash attention kernel to work correctly because it expects
    # the K and V to be contiguous. Note that K and V are not sharded over the sequence aka context axis
    # This way we get the unsharded unpermuted key and value tensors
    if cp_size > 1 and load_balanced_context_parallel:
      key = max_utils.reorder_sequence(tensor=key, cp_size=cp_size, seq_dim=2, to_contiguous=True)
      value = max_utils.reorder_sequence(tensor=value, cp_size=cp_size, seq_dim=2, to_contiguous=True)
      decoder_segment_ids_unpermuted = max_utils.reorder_sequence(
          tensor=decoder_segment_ids_kv,
          cp_size=cp_size,
          seq_dim=1,
          to_contiguous=True,
      )

    if decoder_segment_ids_q is not None:
      if cp_size > 1 and load_balanced_context_parallel:
        decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(
            decoder_segment_ids_q, decoder_segment_ids_unpermuted
        )
      else:
        # if cp=1, decoder_segment_ids_q is the same as decoder_segment_ids_kv
        decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(decoder_segment_ids_q, decoder_segment_ids_kv)
    else:
      decoder_segment_ids_tuple = None

    if self.config.use_tokamax_splash:
      if self.config.use_indexer and indexer_mask is not None:
        # Construct the splash kernel call with dynamic mask
        def dynamic_mask_splash_kernel(q, k, v, segment, sinks, indexer_mask):
          splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha(
              mask=indexer_mask,
              config=sa_config,
          )
          kernel = partial(splash_kernel, max_logit_value=max_logit_value)
          if record_max_logits:
            out, stats = kernel(q, k, v, segment, sinks=sinks, save_residuals=True)
            return out, stats["max_logits"]
          else:
            return kernel(q, k, v, segment, sinks=sinks), None

        # Iterate over batch dimension for (query, key, value, segment, sinks, mask)
        attn_fn = jax.vmap(dynamic_mask_splash_kernel, (0, 0, 0, 0, None, 0))
        # Turn the mask into booleans: entries close to 0.0 become True.
        # presumably the incoming mask is additive (0 = attend) - TODO confirm
        indexer_mask = jnp.isclose(indexer_mask, 0.0)

        if record_max_logits:
          attention_output, max_logits = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, indexer_mask)
          return attention_output, max_logits
        else:
          attention_output, _ = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, indexer_mask)
          return attention_output, None
      else:
        kernel = partial(splash_kernel, max_logit_value=max_logit_value)

        if record_max_logits:

          def kernel_fn(q, k, v, d, s):
            # Pass save_residuals=True to force stats generation
            out, stats = kernel(q, k, v, d, sinks=s, save_residuals=True)
            return out, stats["max_logits"]

          attention_output, max_logits = jax.vmap(kernel_fn, in_axes=(0, 0, 0, 0, None))(
              query, key, value, decoder_segment_ids_tuple, sinks
          )
          return attention_output, max_logits
        else:
          attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))(
              query, key, value, decoder_segment_ids_tuple, sinks
          )
          return attention_output, None
    elif self.config.use_jax_splash:
      # `mask` is captured from the enclosing scope and materialized as a dense array.
      materialized_mask = jnp.asarray(mask[:, :])
      attention_output = jax_flash_attention.flash_attention_block_masked(
          query,
          key,
          value,
          decoder_segment_ids_tuple,
          block_kv=self.config.sa_block_kv,
          block_q=self.config.sa_block_q,
          mask=materialized_mask,
          mask_value=DEFAULT_MASK_VALUE,
      )
      if record_max_logits:
        # The native JAX splash attention implementation does not currently expose the softmax statistics
        # (e.g., max_logits) required for QK-Clip. Use tokamax splash attention if max logit recording is needed.
        raise NotImplementedError("record_max_logits not supported for jax_splash")
    else:
      attention_output = jax.vmap(splash_kernel, in_axes=(0, 0, 0, 0, None))(
          query, key, value, decoder_segment_ids_tuple, sinks
      )
      if record_max_logits:
        raise NotImplementedError("record_max_logits not supported for legacy splash")

    return attention_output, None

  # Constrain every input to the sharding that the shard_map in_specs expect.
  query = self._maybe_shard_with_pspec(query, axis_names_q)
  key = self._maybe_shard_with_pspec(key, axis_names_kv)
  value = self._maybe_shard_with_pspec(value, axis_names_kv)
  decoder_segment_ids_q = self._maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
  decoder_segment_ids_kv = self._maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
  sinks = self._maybe_shard_with_pspec(sinks, sink_axis_names)
  indexer_mask = self._maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names)

  ret = wrap_flash_attention(
      query,
      key,
      value,
      decoder_segment_ids_q,
      decoder_segment_ids_kv,
      sa_config,
      None if self.config.use_jax_splash else splash_kernel,
      cp_size,
      load_balanced_context_parallel,
      sinks,
      indexer_mask,
  )

  x, max_logits = ret

  # Undo the transpose applied on entry.
  x = jnp.transpose(x, axes=(0, 2, 1, 3))

  if record_max_logits:
    # Max over sequence length (dim 2 of max_logits)
    # max_logits from kernel is (batch, heads, q_len)
    # output needs to be (batch, heads)
    # Note: q_len is sharded. We first reduce locally.
    max_logits_local = jnp.max(max_logits, axis=2)
    return x, max_logits_local

  return x, None
def cudnn_flash_attention(
    self,
    query: Array,
    key: Array,
    value: Array,
    decoder_segment_ids: Array | None,
    segment_positions: Array | None,
    model_mode: str = MODEL_MODE_TRAIN,
) -> Array:
  """CUDNN Flash Attention with Transformer Engine.

  1. Stable API, supports MHA, GQA, SWA, Packing and Context Parallelism
  2. Context Parallelism currently only supports causal masking
  3. Only Ring attention has packing support with striped load balancing
     (context_parallel_strategy="ring" and context_parallel_load_balance=true)
  4. Breaks with TE 2.12 and 2.13 (known bug); works with TE stable release <=2.11 or >=2.14.

  The mask handling is chosen from four mutually exclusive cases: packing
  (sequence descriptors built from segment ids and positions), context
  parallelism without packing (plain causal), prefill with the `"cudnn"`
  kernel (plain causal), and the default case (an explicit uint8 mask built
  from `generate_attention_mask`).

  Args:
    query: Rank-4 query tensor; its last axis is the head dimension.
    key: Key tensor.
    value: Value tensor.
    decoder_segment_ids: Optional segment ids. On the packing path a tensor of
      ones shaped `query.shape[:2]` is used when this is None.
    segment_positions: Optional positions, used only on the packing path.
    model_mode: One of the `MODEL_MODE_*` constants.

  Returns:
    The output of the Transformer Engine `DotProductAttention` layer.

  Raises:
    ValueError: Packing combined with context parallelism that is not
      load-balanced ring attention.
    AssertionError: Sliding-window attention with context parallelism but
      without packing.
  """
  # These imports are only meant to work in a GPU build.
  # pylint: disable=import-outside-toplevel
  from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
  from transformer_engine.jax.attention import SequenceDescriptor  # pytype: disable=import-error

  # Also enforces that `query` is rank 4; `head_dim` is passed to the TE layer below.
  _, _, _, head_dim = query.shape  # pylint: disable=unused-variable

  using_context_parallelism = self.mesh.shape[self.config.context_sharding] > 1
  using_load_balanced_ring_cp = (
      using_context_parallelism
      and self.config.context_parallel_strategy == "ring"
      and self.config.context_parallel_load_balance
  )

  # Initialize default attention configuration
  sliding_window_size = None
  mask_type = "padding_causal"
  qkv_layout = "BSHD_BSHD_BSHD"  # Non-packed format: 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
  max_segments_per_seq = 1  # max number of segments per sequence; for non-packed its 1

  # Handle local sliding window attention if configured
  if self.attention_type == AttentionType.LOCAL_SLIDING:
    # Passed to the TE layer as `window_size` (two-element list with a trailing 0).
    sliding_window_size = [self.sliding_window_size, 0]

  # Handle packing configurations
  if self.config.packing and self.config.dataset_type != "synthetic":
    if using_context_parallelism and not using_load_balanced_ring_cp:
      raise ValueError("Packing is only supported for load balanced ring attention with context parallelism.")
    qkv_layout = "THD_THD_THD"  # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD'
    if decoder_segment_ids is None:
      decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32)
    attn_mask = SequenceDescriptor.from_segment_ids_and_pos(
        segment_ids=decoder_segment_ids, segment_pos=segment_positions
    )
    # Create dummy SequenceDescriptor for lazy_init
    dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32)
    dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(
        segment_ids=dummy_segment_ids, segment_pos=segment_positions
    )
    max_segments_per_seq = self.config.max_segments_per_seq
  elif using_context_parallelism:
    if self.attention_type == AttentionType.LOCAL_SLIDING:
      raise AssertionError(
          "Sliding window attention requires context parallelism with load-balanced ring strategy "
          "and packing enabled."
      )
    # Context parallelism without packing: only supports causal masking, but not sliding window attention
    attn_mask = None
    dummy_attn_mask = None
    mask_type = "causal"
  # NOTE(review): this compares `self.config.attention_kernel` against "cudnn" - confirm
  # that this is the kernel name actually used when this method is reached.
  elif model_mode == MODEL_MODE_PREFILL and self.config.attention_kernel == "cudnn":
    # Prefill with CUDNN attention does not support packing or context parallelism.
    attn_mask = None
    dummy_attn_mask = None
    mask_type = "causal"
  else:
    # Default case: no packing, no context parallelism
    dummy_attn_mask = jnp.zeros(
        (1, 1, 1, self.max_target_length, self.max_target_length),
        dtype=jnp.uint8,
    )
    attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
    # Convert the additive mask to uint8: 0 where the value is >= DEFAULT_MASK_VALUE * 0.5, 1 elsewhere.
    attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)

  dpa_layer = DotProductAttention(
      head_dim=head_dim,
      num_attention_heads=self.num_query_heads,
      num_gqa_groups=self.num_kv_heads,
      attn_mask_type=mask_type,  # 'no_mask', 'padding', 'causal', or 'padding_causal'
      attn_bias_type="no_bias",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
      attention_dropout=self.dropout_rate,
      dropout_rng_name="aqt",
      dtype=self.dtype,
      float32_logits=self.float32_logits,
      qkv_layout=qkv_layout,
      scale_factor=1.0,
      transpose_batch_sequence=False,
      window_size=sliding_window_size,
      context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
      context_parallel_axis=self.config.context_sharding,
      context_parallel_strategy=self.config.context_parallel_strategy,
      max_segments_per_seq=max_segments_per_seq,
  )

  # Wrap the Linen layer for NNX and initialise it with dummy inputs
  # (batch size 1, sequence length `max_target_length`) before the real call.
  dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=self.rngs)
  dummy_query_prefill = jnp.zeros(
      (1, self.max_target_length, self.num_query_heads, self.config.head_dim),
      dtype=self.dtype,
  )
  dummy_key_prefill = jnp.zeros(
      (1, self.max_target_length, self.num_kv_heads, self.config.head_dim),
      dtype=self.dtype,
  )
  dummy_value_prefill = jnp.zeros(
      (1, self.max_target_length, self.num_kv_heads, self.config.head_dim),
      dtype=self.dtype,
  )

  dpa_layer.lazy_init(
      dummy_query_prefill,
      dummy_key_prefill,
      dummy_value_prefill,
      sequence_descriptor=dummy_attn_mask,
  )
  return dpa_layer(query, key, value, sequence_descriptor=attn_mask)
def cudnn_jax_flash_attention(
    self,
    query: Array,
    key: Array,
    value: Array,
    decoder_segment_ids: Array | None,
    model_mode: str = MODEL_MODE_TRAIN,
) -> tuple[Array, Array]:
  """Runs cuDNN flash attention through JAX's fused SDPA entry point.

  Autoregressive mode uses a padding mask whose per-sequence lengths are the
  sums of `decoder_segment_ids` over the last axis; every other mode uses a
  causal mask.

  Args:
    query: Rank-4 query tensor.
    key: Key tensor.
    value: Value tensor.
    decoder_segment_ids: Segment ids; read only in autoregressive mode.
    model_mode: One of the `MODEL_MODE_*` constants.

  Returns:
    `(output, lse)`: the attention output and the residual returned by the
    kernel, both tagged with the `"context"` checkpoint name.
  """
  # These imports are only meant to work in a GPU build.
  # pylint: disable=import-outside-toplevel
  from jax._src.cudnn.fused_attention_stablehlo import (
      dot_product_attention,
      MaskType,
  )

  # Unpacking enforces a rank-4 query; the individual sizes are not needed here.
  _batch, _seq_len, _num_heads, _head_dim = query.shape

  # Arguments common to both modes.
  sdpa_kwargs = {
      "scale": 1.0,
      "dropout_rate": self.dropout_rate,
      "qkv_layout": "BTNH",
      "return_residual": True,
  }
  if model_mode == MODEL_MODE_AUTOREGRESSIVE:
    valid_lengths = jnp.sum(decoder_segment_ids, axis=-1)
    sdpa_kwargs.update(
        q_seqlen=valid_lengths,
        kv_seqlen=valid_lengths,
        mask_type=MaskType.PADDING,
    )
  else:
    sdpa_kwargs["mask_type"] = MaskType.CAUSAL

  output, lse = dot_product_attention(query, key, value, **sdpa_kwargs)
  return checkpoint_name(output, "context"), checkpoint_name(lse, "context")
def compute_local_attention(
    self,
    attn_weights: Array,
    value: Array | KVTensor,
    q_seq_len: int,
    model_mode: str,
    wv_product_einsum: Callable[..., Array],
    sinks: Array | None = None,
) -> tuple[Array, Array, Array]:
  """Computes un-normalized softmax attention over one local slice of the KV cache.

  The result still has to be combined with the results of any other local
  slices and normalized by the caller.

  Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py

  Args:
    attn_weights: Query-key logits shaped `(b, n_kv, g, t, s)`.
    value: Values for this slice, possibly a quantized `KVTensor`.
    q_seq_len: Query sequence length; selects the decode-specific sharding and slicing.
    model_mode: One of the `MODEL_MODE_*` constants.
    wv_product_einsum: Einsum used for the weights-value product.
    sinks: Optional per-query-head sink logits shaped `(n_kv * g,)`. They take
      part in the max and the sum but contribute nothing to the output.

  Returns:
    `(local_out, local_max, local_sum)` where `local_out` is the un-normalized
    output, `local_max` is the per-row maximum logit and `local_sum` is the
    sum of `exp(logit - local_max)`. `local_max` and `local_sum` are shaped
    `(b, t, n_q, 1)`.
  """
  batch, num_kv_heads, group_size, q_len, kv_len = attn_weights.shape
  num_q_heads = num_kv_heads * group_size
  scores = jnp.reshape(attn_weights, (batch, num_q_heads, q_len, kv_len))

  if sinks is not None:
    # Append the sink logit as one extra key column per (head, query) row.
    sink_per_head = sinks.astype(attn_weights.dtype)  # (n_q,)
    sink_column = jnp.broadcast_to(
        sink_per_head[jnp.newaxis, :, jnp.newaxis, jnp.newaxis],
        (batch, num_q_heads, q_len, 1),
    )
    scores = jnp.concatenate([scores, sink_column], axis=-1)

  # Numerically stable softmax numerators and their row sums.
  row_max = jnp.max(scores, axis=-1, keepdims=True)
  shifted_exps = jnp.exp(scores - row_max)
  row_sum = jnp.sum(shifted_exps, axis=-1, keepdims=True)

  # Drop the sink column (if present) and restore the grouped-head layout.
  kv_exps = jnp.reshape(shifted_exps[..., :kv_len], (batch, num_kv_heads, group_size, q_len, kv_len))
  local_max = jnp.transpose(row_max, (0, 2, 1, 3))  # (b, t, n_q, 1)
  local_sum = jnp.transpose(row_sum, (0, 2, 1, 3))  # (b, t, n_q, 1)

  local_out = self.wv_product(kv_exps, value, model_mode, wv_product_einsum)

  decode_sharding = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)
  decode_partitioned = model_mode == MODEL_MODE_AUTOREGRESSIVE and self.is_partition_in_decode(q_seq_len)
  if decode_partitioned:
    local_out = partitioning.with_sharding_constraint(local_out, decode_sharding)
  elif model_mode == MODEL_MODE_PREFILL:
    local_out = partitioning.with_sharding_constraint(local_out, (BATCH_ATTN, KV_LENGTH, HEAD, D_KV))

  if self.reshape_q and q_seq_len == 1:
    # Keep only the first query position of every result.
    local_max, local_sum, local_out = (
        tensor[:, 0:1, :, :] for tensor in (local_max, local_sum, local_out)
    )

  if decode_partitioned:
    local_max = partitioning.with_sharding_constraint(local_max, decode_sharding)
    local_sum = partitioning.with_sharding_constraint(local_sum, decode_sharding)
    local_out = partitioning.with_sharding_constraint(local_out, decode_sharding)

  return local_out, local_max, local_sum
def is_partition_in_decode(self, seq_len):
  """Returns whether the decode-specific partitioning applies.

  True only when `config.ici_context_autoregressive_parallelism` is positive
  and the sequence length is exactly 1.
  """
  context_ar_enabled = self.config.ici_context_autoregressive_parallelism > 0
  return context_ar_enabled and seq_len == 1
def apply_attention_dot(
    self,
    query: Array,
    key: Array | KVTensor,
    value: Array | KVTensor,
    decoder_segment_ids: Array | None,
    model_mode: str = MODEL_MODE_TRAIN,
    previous_chunk: Any = None,
    bidirectional_mask: Any = None,
    sinks: Array | None = None,
    indexer_mask: Array | None = None,
    record_max_logits: bool = False,
    *,
    qk_product_einsum: Callable[..., Array],
    wv_product_einsum: Callable[..., Array],
):
  """Applies dot-product attention and returns unnormalized local results.

  Steps, in order: optional float32 cast of query/key, sharding constraints
  for decode/prefill, query-key product, optional logit soft-capping,
  optional float32 cast of the logits, masking (MoBA mask, indexer mask,
  attention mask), optional recording of the max logits, and finally
  `compute_local_attention`.

  Args:
    query: Query projection; `query.shape[1]` is used as the query length.
    key: Key projection, possibly a quantized `KVTensor`.
    value: Value projection, possibly a quantized `KVTensor`.
    decoder_segment_ids: Forwarded to `generate_attention_mask`.
    model_mode: One of the MODEL_MODE_* constants.
    previous_chunk: Optional previous chunk; when given, `shape[1]` of it is
      the starting position of the query for the MoBA mask.
    bidirectional_mask: Forwarded to `generate_attention_mask`.
    sinks: Optional sink logits, forwarded to `compute_local_attention`.
    indexer_mask: Optional mask in shape [b, q_len, kv_len] that is broadcast
      over the head axes and applied to the logits.
    record_max_logits: When True, stores the per-head max logit (after
      soft-capping and masking) on `self.max_logits`.
    qk_product_einsum: Einsum callable used for the query-key product.
    wv_product_einsum: Einsum callable used for the weights-value product.

  Returns:
    The `(local_out, local_max, local_sum)` tuple produced by
    `compute_local_attention`.
  """
  validate_compute_axis_order(self.compute_axis_order)
  # Casting qk_product and softmax computation to float32 for model stability.
  if self.float32_qk_product:
    if isinstance(key, KVTensor):
      key = key.dequant()
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)

  # special sharding for decode
  q_seq_len = query.shape[1]
  prefill_qkv_sharding = (BATCH_ATTN, PREFILL_LENGTH, HEAD, D_KV)
  decode_qkv_sharding = (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)
  if self.is_partition_in_decode(q_seq_len):
    query = partitioning.with_sharding_constraint(query, decode_qkv_sharding)
    # avoid sharding scale tensor when using kv cache quantization
    # (only `qvalue` is constrained, and it is reassigned on the passed-in KVTensor)
    if self.kv_quant and isinstance(key, KVTensor) and isinstance(value, KVTensor):
      key.qvalue = partitioning.with_sharding_constraint(key.qvalue, decode_qkv_sharding)
      value.qvalue = partitioning.with_sharding_constraint(value.qvalue, decode_qkv_sharding)
    else:
      key = partitioning.with_sharding_constraint(key, decode_qkv_sharding)
      value = partitioning.with_sharding_constraint(value, decode_qkv_sharding)
  elif model_mode == MODEL_MODE_PREFILL:
    query = partitioning.with_sharding_constraint(query, prefill_qkv_sharding)
    # avoid sharding scale tensor when using kv cache quantization
    if self.kv_quant and isinstance(key, KVTensor) and isinstance(value, KVTensor):
      key.qvalue = partitioning.with_sharding_constraint(key.qvalue, prefill_qkv_sharding)
      value.qvalue = partitioning.with_sharding_constraint(value.qvalue, prefill_qkv_sharding)
    else:
      key = partitioning.with_sharding_constraint(key, prefill_qkv_sharding)
      value = partitioning.with_sharding_constraint(value, prefill_qkv_sharding)

  attn_weights = self.qk_product(query, key, q_seq_len, model_mode, qk_product_einsum)
  if self.is_partition_in_decode(q_seq_len):
    attn_weights = partitioning.with_sharding_constraint(attn_weights, (KV_LENGTH, HEAD, None, None, None))
  elif model_mode == MODEL_MODE_PREFILL:
    attn_weights = partitioning.with_sharding_constraint(
        attn_weights, (BATCH_ATTN, HEAD, None, PREFILL_LENGTH, KV_LENGTH)
    )

  # Soft-capping: squash the logits into (-cap, cap) via cap * tanh(logits / cap).
  if self.attn_logits_soft_cap:
    attn_weights = jnp.tanh(attn_weights / self.attn_logits_soft_cap)
    attn_weights = attn_weights * self.attn_logits_soft_cap

  # Casting softmax computation to float32 for model stability.
  if self.float32_logits:
    attn_weights = attn_weights.astype(jnp.float32)
  attn_mask = self.generate_attention_mask(
      query,
      key,
      decoder_segment_ids,
      model_mode,
      previous_chunk,
      bidirectional_mask,
  )

  if self.config.moba:
    kv_seq_len = key.shape[1]
    # This logic for `next_pos` is duplicated from `generate_attention_mask`.
    # It determines the starting position of the query sequence.
    next_pos = 0
    if previous_chunk is not None:
      next_pos = previous_chunk.shape[1]
    elif model_mode == MODEL_MODE_AUTOREGRESSIVE and q_seq_len == 1:
      next_pos = kv_seq_len - 1
    q_positions = jnp.arange(next_pos, next_pos + q_seq_len)

    # The gate calculation in MoBA uses the unscaled query.
    # With scaled query, the gate values are scaled, but since the top-k selection
    # is scale-invariant, we can use the scaled query directly.
    moba_mask = self._generate_moba_mask(query, key, q_positions)
    attn_weights += moba_mask

  # Apply index mask, deepseek sparse attention
  # index mask contains 0.0 for kept tokens and large negative for masked tokens.
  if indexer_mask is not None:
    # indexer_mask: from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len]
    indexer_mask = indexer_mask[:, None, None, :, :]
    # attn_weights: [b, n_kv, n_q // n_kv, q_len, kv_len]
    attn_weights = apply_mask_to_logits(attn_weights, indexer_mask)

  # NOTE(review): the sharding constraint is applied before the `None` check
  # below — confirm `with_sharding_constraint` tolerates a `None` mask.
  if self.is_partition_in_decode(q_seq_len):
    attn_mask = partitioning.with_sharding_constraint(attn_mask, (KV_LENGTH, HEAD, None, None, None))
  elif model_mode == MODEL_MODE_PREFILL:
    attn_mask = partitioning.with_sharding_constraint(attn_mask, (BATCH_ATTN, HEAD, None, PREFILL_LENGTH, KV_LENGTH))
  if attn_mask is not None:
    attn_weights = apply_mask_to_logits(attn_weights, attn_mask)

  # We record max logits AFTER soft-capping and masking to match Flash/Splash attention behavior.
  if record_max_logits:
    # attn_weights shape: [b, n_kv, g, t, s]
    # Max over t (query len) and s (key len)
    # Result shape: [b, n_kv, g] -> reshape to [b, n_heads]
    # Note: Masked values are large negatives (DEFAULT_MASK_VALUE), so max() correctly ignores them.
    max_logits_per_group = jnp.max(attn_weights, axis=(-2, -1))
    b, n_kv, g = max_logits_per_group.shape
    max_logits = max_logits_per_group.reshape(b, n_kv * g)
    self.max_logits = nnx.Intermediate(max_logits)

  return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode, wv_product_einsum, sinks)
def qk_product(
    self,
    query: Array,
    key: Array | KVTensor,
    q_seq_len: int,
    model_mode: str,
    einsum: Callable[..., Array],
) -> Array:
  """Query-Key product.

  Args:
    query: Query projection, in shape of [b, t, n, d]
    key: Key projection in shape of [b, s, n_kv, d]
    q_seq_len: Query sequence length; with `self.reshape_q` and a length of 1
      the query is broadcast to two positions along the query-length axis.
    model_mode: One of the MODEL_MODE_* constants. Training always uses the
      (0, 1, 2, 3) layout regardless of `self.compute_axis_order`.
    einsum: Einsum callable. `config.matmul_precision` is only passed along
      when this is `jnp.einsum` itself.

  Returns:
    results in shape [b, n_kv, n // n_kv, t, s].

  Raises:
    NotImplementedError: If `self.compute_axis_order` is not supported.

  Annotations:
    b: batch size
    t: query length
    s: key / value length
    d: head / kv dimension
    n: number of query heads
    n_kv: number of kv heads, sometimes annotated as k
    n // n_kv: number of group for query, sometimes annotated with g
  """
  batch, q_len, num_q_heads, head_dim = query.shape
  num_kv_heads = key.shape[-2]
  assert num_kv_heads == self.num_kv_heads

  einsum_kwargs = {}
  if einsum is jnp.einsum:
    einsum_kwargs["precision"] = self.config.matmul_precision

  axis_order = self.compute_axis_order
  sequence_major = model_mode == MODEL_MODE_TRAIN or axis_order == (0, 1, 2, 3)
  if not sequence_major and not axis_order == (0, 2, 1, 3):
    raise NotImplementedError(axis_order)

  group_size = num_q_heads // num_kv_heads
  duplicate_query = self.reshape_q and q_seq_len == 1

  if sequence_major:
    query = jnp.reshape(query, (batch, q_len, num_kv_heads, group_size, head_dim))
    if duplicate_query:
      query = jnp.broadcast_to(query, (batch, 2, num_kv_heads, group_size, head_dim))
    return einsum("btkgd,bskd->bkgts", query, key, **einsum_kwargs)

  # Head-major layout: move the head axis in front of the sequence axis for
  # both operands (every leaf of a KVTensor key is transposed).
  query = jnp.transpose(query, axes=axis_order)
  key = jax.tree.map(lambda leaf: jnp.transpose(leaf, axes=axis_order), key)
  query = jnp.reshape(query, (batch, num_kv_heads, group_size, q_len, head_dim))
  if duplicate_query:
    query = jnp.broadcast_to(query, (batch, num_kv_heads, group_size, 2, head_dim))
  return einsum("bkgtd,bksd->bkgts", query, key, **einsum_kwargs)
def wv_product(
    self,
    attn_weights: Array,
    value: Array | KVTensor,
    model_mode: str,
    einsum: Callable[..., Array],
) -> Array:
  """weighted value product.

  Args:
    attn_weights: Computed results of qk_einsum, in shape [b, n_kv, n // n_kv, t, s]
    value: Value projection, in shape of [b, s, n_kv, d]
    model_mode: One of the MODEL_MODE_* constants. Training always uses the
      (0, 1, 2, 3) layout regardless of `self.compute_axis_order`.
    einsum: Einsum callable. `config.matmul_precision` is only passed along
      when this is `jnp.einsum` itself.

  Returns:
    result in shape [b, t, n, d]

  Raises:
    NotImplementedError: If `self.compute_axis_order` is not supported
      (mirrors `qk_product`).

  Annotations:
    b: batch size
    t: query length
    s: key / value length
    d: head / kv dimension
    n: number of query heads
    n_kv: number of kv heads, sometimes annotated as k
    n // n_kv: number of group for query, sometimes annotated with g
  """
  precision_kwargs = {"precision": self.config.matmul_precision} if einsum is jnp.einsum else {}
  if self.kv_quant:
    # manually cast to bf16 to avoid the fp32 XLA ops for speedup
    # (note: this reassigns `qvalue` on the KVTensor that was passed in)
    if isinstance(value, KVTensor) and self.kv_quant.dtype == jnp.float8_e4m3fn:
      value.qvalue = value.qvalue.astype(jnp.bfloat16)

  if model_mode == MODEL_MODE_TRAIN or self.compute_axis_order == (0, 1, 2, 3):
    out = einsum("bkgts,bskd->btkgd", attn_weights, value, **precision_kwargs)
    b, t, n_kv, g, d = out.shape
    result = jnp.reshape(out, (b, t, n_kv * g, d))
  elif self.compute_axis_order == (0, 2, 1, 3):
    value = jax.tree.map(lambda x: jnp.transpose(x, axes=self.compute_axis_order), value)
    out = einsum("bkgts,bksd->bkgtd", attn_weights, value, **precision_kwargs)
    b, n_kv, g, t, d = out.shape
    result = jnp.reshape(out, (b, n_kv * g, t, d))
    # Back from the head-major layout to [b, t, n, d].
    result = self.reverse_transepose(result, self.compute_axis_order)
  else:
    # Previously this fell through to `return result` with `result` unbound.
    raise NotImplementedError(self.compute_axis_order)
  return result
def reverse_transepose(self, transposed_array, transpose_axis_order):
  """Undoes a 4-D transpose that was done with `transpose_axis_order`.

  Axis `i` of `transposed_array` is moved to position `transpose_axis_order[i]`,
  which is the inverse of `jnp.transpose(x, transpose_axis_order)`.
  """
  source_axes = (0, 1, 2, 3)
  return jnp.moveaxis(transposed_array, source=source_axes, destination=transpose_axis_order)
def normalize_cudnn_attention(self, local_outs, local_stats):
  """Normalize across two cuDNN attentions

  Each local output is re-weighted by exp(local_stat - global_stat), where
  global_stat is the logsumexp over both chunks.

  Args:
    local_outs (list): List of outputs entries for each cudnn attention
      in shape [b, t, n, d]. Only the first two entries are used.
    local_stats (list): List of logsumexp entries for each cudnn attention
      in shape [b, n, t]. Only the first two entries are used.

  Returns:
    Array: Combined attention that has been normalized in shape [b, t, n, d].
      The result is cast to the dtype of `local_stats[0]`.
  """
  # reshape stat to have shape [b, n, t, 1]
  stat0 = local_stats[0].reshape((*local_stats[0].shape, 1))
  stat1 = local_stats[1].reshape((*local_stats[1].shape, 1))
  # logaddexp(a, b) == log(exp(a) + exp(b)) but does not overflow when the
  # statistics are large (exp() of a float32 above ~88 is inf, which used to
  # turn both weights below into exactly 0).
  global_stat = jnp.logaddexp(stat0, stat1)

  # transpose stat to have shape [b, t, n, 1] for elementwise multiplication
  weight0 = jnp.exp(stat0 - global_stat).transpose((0, 2, 1, 3))
  weight1 = jnp.exp(stat1 - global_stat).transpose((0, 2, 1, 3))
  attn_out = local_outs[0].astype(jnp.float32) * weight0 + local_outs[1].astype(jnp.float32) * weight1
  # NOTE(review): the result takes the dtype of the statistics, not of the
  # outputs — kept as in the original; confirm this is intended.
  return attn_out.astype(local_stats[0].dtype)
def normalize_attention(self, local_outs, local_maxes, local_sums):
  """Combines several local attention results into one normalized output.

  Every chunk is rescaled from its own max to the global max, the rescaled
  sums are added up into the global softmax denominator, and the rescaled
  outputs are divided by it and accumulated.

  Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py

  Args:
    local_outs (list): List of unnormalized outputs entries for each local attention
    local_maxes (list): List of max exponentials entries for each local attention
    local_sums (list): List of exponential sum entries for each local attention

  Returns:
    Array: Combined attention that has been normalized
  """
  global_max = functools.reduce(jnp.maximum, local_maxes)
  # Factor that moves each chunk from its local max to the global max.
  rescales = [jnp.exp(chunk_max - global_max) for chunk_max in local_maxes]
  global_sum = sum(scale * chunk_sum for scale, chunk_sum in zip(rescales, local_sums))
  return sum((scale / global_sum) * chunk_out for scale, chunk_out in zip(rescales, local_outs))
def __call__(
    self,
    query,
    key,
    value,
    decoder_segment_ids,
    inputs_positions,
    model_mode,
    cached_values=None,
    previous_chunk=None,
    bidirectional_mask=None,
    sinks=None,
    indexer_mask: Optional[Array] = None,
    slot: Optional[int] = None,
    page_state: Optional[page_manager.PageState] = None,
    record_max_logits: bool = False,
):
  """Runs attention over the prefill cache and, if present, the AR cache.

  Attention is computed separately against the prefill kv cache and the
  autoregressive (AR) kv cache; the two local results are then merged and
  normalized.

  Args:
    query: Query projection.
    key: Key projection; replaced by the prefill cache's key outside of
      training.
    value: Value projection; replaced by the prefill cache's value outside of
      training.
    decoder_segment_ids: Segment ids; replaced by the prefill cache's ids
      outside of training.
    inputs_positions: Forwarded to `apply_attention` as `segment_positions`.
    model_mode: One of the MODEL_MODE_* constants.
    cached_values: Optional indexable whose entries 0 and 1 are the prefill
      and AR kv caches. The prefill cache is required outside of training.
    previous_chunk: Forwarded to the prefill attention call only.
    bidirectional_mask: Forwarded to both attention calls.
    sinks: Forwarded to the prefill attention call only.
    indexer_mask: Optional mask, split along its last axis at the prefill key
      length into a prefill part and an AR part.
    slot: Not used in this method.
    page_state: Not used in this method.
    record_max_logits: Forwarded to the prefill attention call only.

  Returns:
    The normalized attention output.
  """
  if cached_values is None:
    prefill_kv_cache, ar_kv_cache = None, None
  else:
    prefill_kv_cache, ar_kv_cache = cached_values[0], cached_values[1]
  if model_mode != MODEL_MODE_TRAIN:
    assert prefill_kv_cache
    key, value, decoder_segment_ids = prefill_kv_cache

  indexer_mask_prefill = None
  indexer_mask_ar = None
  if indexer_mask is not None:
    # The mask's last axis covers prefill keys first, then AR keys.
    prefill_len = key.shape[1]
    indexer_mask_prefill = indexer_mask[:, :, :prefill_len]
    if ar_kv_cache is not None:
      indexer_mask_ar = indexer_mask[:, :, prefill_len:]

  prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention(
      query=query,
      key=key,
      value=value,
      decoder_segment_ids=decoder_segment_ids,
      segment_positions=inputs_positions,
      lengths=None,
      model_mode=model_mode,
      use_ragged_attention=self.use_ragged_attention,
      previous_chunk=previous_chunk,
      bidirectional_mask=bidirectional_mask,
      sinks=sinks,
      indexer_mask=indexer_mask_prefill,
      record_max_logits=record_max_logits,
      qk_product_einsum=self.AqtEinsum_0,
      wv_product_einsum=self.AqtEinsum_1,
  )

  # Return the "prefill" cache if it actually the combined prefill+ar kv cache
  if ar_kv_cache is None:
    # A missing exponent sum means there is nothing left to divide by.
    if prefill_exponentials_sum is not None:
      return prefill_unnormalized_output / prefill_exponentials_sum
    return prefill_unnormalized_output

  key, value, decoder_segment_ids, lengths = ar_kv_cache
  ar_unnormalized_output, ar_exponentials_max, ar_exponentials_sum = self.apply_attention(
      query=query,
      key=key,
      value=value,
      decoder_segment_ids=decoder_segment_ids,
      segment_positions=inputs_positions,
      lengths=lengths,
      model_mode=model_mode,
      use_ragged_attention=self.use_ragged_attention,
      bidirectional_mask=bidirectional_mask,
      indexer_mask=indexer_mask_ar,
      qk_product_einsum=self.AqtEinsum_2,
      wv_product_einsum=self.AqtEinsum_3,
  )

  if ar_unnormalized_output is not None:
    unnormalized_outputs = [
        prefill_unnormalized_output,
        ar_unnormalized_output,
    ]
    exponentials_maxes = [prefill_exponentials_max, ar_exponentials_max]
    exponentials_sums = [prefill_exponentials_sum, ar_exponentials_sum]
    if prefill_exponentials_max is not None and prefill_exponentials_sum is None:
      # Only a single statistic per chunk is available: it is handed to the
      # cuDNN combiner as the logsumexp.
      prefill_stat = prefill_exponentials_max
      ar_stat = ar_exponentials_max
      stats = [prefill_stat, ar_stat]
      return self.normalize_cudnn_attention(unnormalized_outputs, stats)
    else:
      return self.normalize_attention(unnormalized_outputs, exponentials_maxes, exponentials_sums)
  else:
    # No AR result: fall back to the prefill result, with the same `None`
    # guard as the no-AR-cache path above (this used to divide by the sum
    # unconditionally).
    if prefill_exponentials_sum is not None:
      return prefill_unnormalized_output / prefill_exponentials_sum
    return prefill_unnormalized_output
# pylint: disable=protected-access
class LoadBalancedCausalMask(splash_attention_mask._ComputableMask):
  """Lazy causal mask, prevents the model from attending to future tokens.

  The query positions are reordered with `max_utils.reorder_mask_load_balancing`
  so that the mask follows the same load-balanced order as the input tokens.

  Attributes:
    offset: Offset of q start wrt kv. A positive offset shifts the bottom
      triangle upward, a negative one shifts it downward. A negative offset
      makes the first 'offset' rows of the attention matrix all 0s which leads
      to undefined softmax.
    shape: Shape of the mask; `shape[0]` is the number of query positions that
      get reordered.
    cp_size: Value passed to the load-balancing reorder as its second argument.
    q_sequence: The reordered query positions.
  """

  offset: int
  shape: tuple[int, int]
  cp_size: int

  def __init__(
      self,
      shape: tuple[int, int],
      offset: int = 0,
      shard_count: int = 1,
      cp_size: int = 4,
  ):
    self.offset = offset
    # `cp_size` is declared on the class, so store it; it used to be dropped,
    # which made `mask.cp_size` raise AttributeError.
    self.cp_size = cp_size

    def causal_mask_function(q_ids, kv_ids):
      if self.offset == 0:
        return q_ids >= kv_ids
      else:
        return q_ids + self.offset >= kv_ids

    arr = np.arange(shape[0])
    # we reorder the mask to be load balanced following the same approach as
    # used to reorder the input tokens
    out = max_utils.reorder_mask_load_balancing(arr[None, :, None, None], cp_size, 1)
    q_sequence = out[0, :, 0, 0]

    super().__init__(
        shape=shape,
        mask_function=causal_mask_function,
        shard_count=shard_count,
    )
    # Assigned after the base initializer, as in the original — presumably so
    # the reordered sequence replaces whatever the base class sets; confirm.
    self.q_sequence = q_sequence

  def __eq__(self, other: object):
    if not isinstance(other, type(self)):
      return NotImplemented

    return self.shape == other.shape and self.offset == other.offset and np.array_equal(self.q_sequence, other.q_sequence)

  def __hash__(self):
    return hash(
        (
            type(self),
            self.shape,
            self.offset,
            self.q_sequence.tobytes() if self.q_sequence is not None else None,
        )
    )