Source code for maxtext.kernels.attention.jax_flash_attention

#  Copyright 2026 Google LLC
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Flash Attention implemented in plain JAX, without using Pallas kernels."""

from typing import Optional, Tuple, Union

import jax
import jax.numpy as jnp
from maxtext.kernels.attention import splash_attention_kernel

SegmentIds = splash_attention_kernel.SegmentIds


# This function computes masked flash attention using a block-sparse approach.
# The implementation keeps the full batch and head dimensions throughout the
# attention computation. It iterates over blocks of the key/value sequence and,
# within each of them, over blocks of the query sequence. `mask_blocked` is
# used to skip the computation for blocks whose attention scores are all masked
# out, which improves efficiency for sparse masks.
def flash_attention_block_masked(
    q: jnp.ndarray,
    k: jnp.ndarray,
    v: jnp.ndarray,
    segment_ids: SegmentIds | None,
    block_kv: int,
    block_q: int,
    mask: jnp.ndarray,
    mask_value: float,
    cap: Optional[float] = None,
    save_residuals: bool = False,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, dict[str, jnp.ndarray]]]:
  """Computes masked flash attention using block-sparse masking.

  The key/value sequence is walked block by block (outer loop) and, for each
  key/value block, the query sequence is walked block by block (inner loop).
  A running maximum and a running softmax denominator are kept per query row
  (online softmax), and the output is re-normalized after every processed
  block. A (query block, key/value block) pair is skipped entirely when the
  combined mask has no non-zero entry in it for any batch element.

  Args:
    q: Query tensor with shape (batch_size, num_q_heads, q_seq_len, head_dim).
      It is regrouped internally to (batch_size, num_kv_heads,
      num_q_heads // num_kv_heads, q_seq_len, head_dim).
    k: Key tensor with shape (batch_size, num_kv_heads, kv_seq_len, head_dim).
    v: Value tensor with shape (batch_size, num_kv_heads, kv_seq_len,
      v_head_dim).
    segment_ids: SegmentIds are a mechanism to ensure that there is no
      cross-attention between segments (fractions of a sequence) that have
      been concatenated together into a sequence. Only tokens with the same id
      are allowed to attend to each other. `segment_ids.q` and
      `segment_ids.kv` are indexed as 2-D (batch_size, seq_len) arrays. May be
      None, in which case only `mask` is applied.
    block_kv: Block size for the key/value sequence dimension.
    block_q: Block size for the query sequence dimension.
    mask: The full attention mask with shape (q_seq_len, kv_seq_len). The same
      mask is used for every batch element.
    mask_value: The value to use for masked-out attention scores. It is also
      the initial value of the running row maximum.
    cap: Optional soft cap for attention logits:
      capped_logits = jnp.tanh(logits / cap) * cap. The cap is applied before
      masking so that masked-out scores are exactly `mask_value`.
    save_residuals: Whether to also return the softmax statistics.

  Returns:
    The attention output with shape (batch_size, num_q_heads, q_seq_len,
    v_head_dim). If `save_residuals` is True, a tuple of that output and a
    dict with keys "logsumexp" and "max_logits", both of shape
    (batch_size, num_q_heads, q_seq_len) and wrapped in `stop_gradient`.

  Raises:
    ValueError: Raised by `mask_blocker` if `q_seq_len` is not divisible by
      `block_q` or `kv_seq_len` is not divisible by `block_kv`.
  """
  batch_size, num_q_heads, q_seq_len, qk_head_dim_size = q.shape
  _, num_kv_heads, kv_seq_len, _ = k.shape
  v_head_dim_size = v.shape[-1]
  data_type = q.dtype
  q_groups = num_q_heads // num_kv_heads
  # Group the query heads by the key/value head they share so that a single
  # einsum can contract every group against its key/value head.
  q = q.reshape(
      (batch_size, num_kv_heads, q_groups, q_seq_len, qk_head_dim_size)
  )

  # Calculate the number of key/value and query blocks.
  num_kv_blocks = kv_seq_len // block_kv
  num_q_blocks = q_seq_len // block_q

  # The segment mask differs per batch element, so the shared mask is
  # broadcast over the batch dimension before the two are combined.
  mask_full = jnp.broadcast_to(
      mask[None, :, :], (batch_size, q_seq_len, kv_seq_len)
  )
  if segment_ids is not None:
    segment_ids_q = segment_ids.q[:, :, None]
    segment_ids_kv = segment_ids.kv[:, None, :]
    mask_full = jnp.logical_and(mask_full, segment_ids_q == segment_ids_kv)
  # Per-block count of non-zero mask entries, used to skip empty blocks.
  mask_blocked = jax.jit(mask_blocker, static_argnums=[1, 2])(
      mask_full, block_q, block_kv
  )

  stats_shape = (batch_size, num_kv_heads, q_groups, q_seq_len)
  # Running softmax denominator; zero because no block has been seen yet.
  row_sum = jnp.zeros(stats_shape, dtype=data_type)
  # Running row maximum; starts at mask_value so that the first processed
  # block's maximum logit becomes the running maximum.
  row_max = jnp.full(stats_shape, mask_value, dtype=data_type)
  output = jnp.zeros(stats_shape + (v_head_dim_size,), dtype=data_type)

  # Outer loop over the key/value blocks.
  def outer_loop_body(j, carried):
    output, row_sum, row_max = carried
    kv_start = j * block_kv
    k_j_slice = jax.lax.dynamic_slice_in_dim(k, kv_start, block_kv, axis=-2)
    v_j_slice = jax.lax.dynamic_slice_in_dim(v, kv_start, block_kv, axis=-2)

    # Inner loop over the query blocks.
    def inner_loop_body(i, carried_inner):
      output, row_sum, row_max = carried_inner
      q_start = i * block_q
      q_slice = jax.lax.dynamic_slice_in_dim(q, q_start, block_q, axis=-2)

      # Computes (Q @ K.T) @ V with online softmax for the current query and
      # key/value blocks and folds the result into the running state.
      def compute_attention_block(output, row_sum, row_max):
        output_i_slice = jax.lax.dynamic_slice_in_dim(
            output, q_start, block_q, axis=-2
        )
        l_i_slice = jax.lax.dynamic_slice_in_dim(
            row_sum, q_start, block_q, axis=-1
        )
        m_i_slice = jax.lax.dynamic_slice_in_dim(
            row_max, q_start, block_q, axis=-1
        )
        s_i_j = jnp.einsum(
            "bxhqc,bxkc->bxhqk",
            q_slice,
            k_j_slice,
            preferred_element_type=data_type,
        )
        # Soft-cap first, mask second: capping after masking would squash
        # mask_value to -cap and let masked positions leak into the softmax.
        if cap is not None:
          s_i_j = jnp.tanh(s_i_j / cap) * cap
        full_mask_i_j_slice = jax.lax.dynamic_slice(
            mask_full,
            (0, q_start, kv_start),
            (batch_size, block_q, block_kv),
        )
        # The mask has no head dimensions; broadcast it over both of them.
        s_i_j = jnp.where(
            full_mask_i_j_slice[:, None, None, :, :], s_i_j, mask_value
        )

        m_i_j = s_i_j.max(axis=-1)
        p_i_j = jnp.exp(s_i_j - m_i_j[..., None])
        l_i_j = p_i_j.sum(axis=-1)
        assert m_i_j.shape == m_i_slice.shape
        m_i_new = jnp.maximum(m_i_slice, m_i_j)
        # Factors that rescale the old and the new statistics to the new
        # running maximum.
        m_i_difference = jnp.exp(m_i_slice - m_i_new)
        m_i_j_difference = jnp.exp(m_i_j - m_i_new)
        l_i_new = m_i_difference * l_i_slice + m_i_j_difference * l_i_j

        # The stored output is already normalized, so it is multiplied by its
        # old denominator before being merged with the new block.
        numerator = l_i_slice[..., None] * m_i_difference[
            ..., None
        ] * output_i_slice + m_i_j_difference[..., None] * jnp.einsum(
            "bxhqk,bxkc->bxhqc",
            p_i_j,
            v_j_slice,
            preferred_element_type=data_type,
        )
        output_i_slice_new = numerator / l_i_new[..., None]

        output = jax.lax.dynamic_update_slice_in_dim(
            output, output_i_slice_new, q_start, axis=-2
        )
        row_sum = jax.lax.dynamic_update_slice_in_dim(
            row_sum, l_i_new, q_start, axis=-1
        )
        row_max = jax.lax.dynamic_update_slice_in_dim(
            row_max, m_i_new, q_start, axis=-1
        )
        return output, row_sum, row_max

      def identity(output, row_sum, row_max):
        """A no-op identity function."""
        return output, row_sum, row_max

      mask_i_j_slice = jax.lax.dynamic_slice(
          mask_blocked, (0, i, j), (batch_size, 1, 1)
      )
      # The block is computed if at least one batch element has a non-zero
      # mask entry in it; otherwise the running state is passed through.
      return jax.lax.cond(
          jnp.any(jnp.not_equal(mask_i_j_slice, 0)),
          compute_attention_block,
          identity,
          output,
          row_sum,
          row_max,
      )

    return jax.lax.fori_loop(
        0,
        num_q_blocks,
        inner_loop_body,
        (output, row_sum, row_max),
        unroll=True,
    )

  output, row_sum, row_max = jax.lax.fori_loop(
      0,
      num_kv_blocks,
      outer_loop_body,
      (output, row_sum, row_max),
      unroll=True,
  )

  # Merge the (num_kv_heads, q_groups) axes back into num_q_heads. When
  # num_q_heads == num_kv_heads this is the same as dropping the size-one
  # group axis, and it also works for grouped-query attention.
  output = output.reshape(
      (batch_size, num_q_heads, q_seq_len, v_head_dim_size)
  )

  if not save_residuals:
    # To avoid remat of the output, we can use context=hbm remat policy as in
    # maxtext/configs/types.py
    return output

  row_sum = row_sum.reshape((batch_size, num_q_heads, q_seq_len))
  row_max = row_max.reshape((batch_size, num_q_heads, q_seq_len))
  stats = {"logsumexp": row_max + jnp.log(row_sum), "max_logits": row_max}
  stats = jax.tree.map(jax.lax.stop_gradient, stats)
  return output, stats
def mask_blocker(mask: jnp.ndarray, block_q: int, block_kv: int) -> jnp.ndarray:
  """Summarizes a full attention mask as per-block non-zero counts.

  Args:
    mask: The attention mask with shape (batch_size, q_seq_len, kv_seq_len).
    block_q: Block size for the query sequence dimension.
    block_kv: Block size for the key/value sequence dimension.

  Returns:
    An int32 array of shape (batch_size, q_seq_len // block_q,
    kv_seq_len // block_kv) whose entries are the number of non-zero elements
    in the corresponding (block_q, block_kv) tile of `mask`.

  Raises:
    ValueError: If `q_seq_len` is not divisible by `block_q` or `kv_seq_len`
      is not divisible by `block_kv`.
  """
  num_batches, q_len, kv_len = mask.shape

  if q_len % block_q:
    raise ValueError(
        f"q_seq_len {q_len} must be divisible by block_q {block_q}"
    )
  if kv_len % block_kv:
    raise ValueError(
        f"kv_seq_len {kv_len} must be divisible by block_kv {block_kv}"
    )

  # Split each sequence axis into (number of blocks, block size) so that one
  # tile of the mask lives on axes 2 and 4.
  tiled_shape = (
      num_batches,
      q_len // block_q,
      block_q,
      kv_len // block_kv,
      block_kv,
  )
  tiles = jnp.reshape(mask, tiled_shape)
  nonzero_per_tile = jnp.count_nonzero(tiles, axis=(2, 4))
  return nonzero_per_tile.astype(jnp.int32)