maxtext.kernels.attention.jax_flash_attention module
A JAX implementation of Flash Attention that does not use Pallas.
- maxtext.kernels.attention.jax_flash_attention.flash_attention_block_masked(q, k, v, segment_ids, block_kv, block_q, mask, mask_value, cap=None, save_residuals=False)
Computes masked flash attention using block-sparse masking.
- Parameters:
q (Array) – Query tensor with shape (batch_size, num_kv_heads, num_q_heads_per_kv_head, q_seq_len, head_dim).
k (Array) – Key tensor with shape (batch_size, num_kv_heads, kv_seq_len, head_dim).
v (Array) – Value tensor with shape (batch_size, num_kv_heads, kv_seq_len, v_head_dim).
segment_ids (SegmentIds | None) – Segment ids that prevent cross-attention between segments (portions of a sequence) that have been concatenated into a single sequence. Each array is a list of integer ids, and only tokens with the same id may attend to each other. The object stores the segment ids of both the query and the key/value sequences.
block_kv (int) – Block size for the key/value sequence dimension.
block_q (int) – Block size for the query sequence dimension.
mask (Array) – The full attention mask with shape (q_seq_len, kv_seq_len). The same mask is applied to every batch element.
mask_value (float) – The value to use for masked-out attention scores.
cap (float | None) – Optional soft cap on the attention logits, which keeps them from becoming extremely large: capped_logits = jnp.tanh(logits / cap) * cap
save_residuals (bool) – Whether to save residuals. If True, returns a tuple (output, dict), where the dict holds logsumexp and max_logits. Both have shape (batch_size, num_kv_heads, num_q_heads // num_kv_heads, q_seq_len).
- Returns:
If save_residuals is True, a tuple containing the output of the attention computation and a dict of (logsumexp, max_logits). Otherwise, only the output of the attention computation.
- Return type:
Array | tuple[Array, dict]
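The block-wise computation behind flash attention is easier to follow with a concrete recurrence. The following is a minimal single-head sketch, not this module's implementation. It assumes plain integer vectors for the segment ids (instead of a SegmentIds object), a boolean mask, sequence lengths divisible by the block sizes, and that every query can attend to at least one key. The function name blockwise_masked_attention and the residual dict keys are hypothetical. It applies the soft cap before masking, skips blocks whose mask is entirely zero, and keeps a running max and softmax denominator per query row (the online-softmax recurrence), which also yields logsumexp and max_logits.

```python
import jax.numpy as jnp


def blockwise_masked_attention(q, k, v, q_seg, kv_seg, mask, block_q, block_kv,
                               mask_value=-1e9, cap=None, save_residuals=False):
  """Hypothetical single-head sketch of block-wise (flash-style) attention.

  q: (q_len, d); k: (kv_len, d); v: (kv_len, dv).
  q_seg: (q_len,) and kv_seg: (kv_len,) integer segment ids.
  mask: (q_len, kv_len) bool, True where attention is allowed.
  Assumes q_len % block_q == 0, kv_len % block_kv == 0, and that every
  query row can attend to at least one key.
  """
  q_len, kv_len = mask.shape
  # Combine the attention mask with the "same segment only" rule.
  allowed = jnp.logical_and(mask, q_seg[:, None] == kv_seg[None, :])
  outs, lses, maxes = [], [], []
  for qs in range(0, q_len, block_q):
    q_blk = q[qs:qs + block_q]
    m = jnp.full((block_q,), -jnp.inf)  # running max logit per query row
    l = jnp.zeros((block_q,))  # running softmax denominator
    acc = jnp.zeros((block_q, v.shape[-1]))  # running unnormalized output
    for ks in range(0, kv_len, block_kv):
      blk_mask = allowed[qs:qs + block_q, ks:ks + block_kv]
      if not bool(blk_mask.any()):
        continue  # a fully masked block contributes nothing, so skip it
      s = q_blk @ k[ks:ks + block_kv].T
      if cap is not None:
        s = jnp.tanh(s / cap) * cap  # soft cap, applied before masking
      s = jnp.where(blk_mask, s, mask_value)
      m_new = jnp.maximum(m, s.max(axis=-1))
      p = jnp.exp(s - m_new[:, None])
      scale = jnp.exp(m - m_new)  # rescales sums computed with the old max
      l = l * scale + p.sum(axis=-1)
      acc = acc * scale[:, None] + p @ v[ks:ks + block_kv]
      m = m_new
    outs.append(acc / l[:, None])
    lses.append(m + jnp.log(l))
    maxes.append(m)
  out = jnp.concatenate(outs)
  if save_residuals:
    # Key names are assumed for illustration, not taken from the module.
    return out, {"logsumexp": jnp.concatenate(lses),
                 "max_logits": jnp.concatenate(maxes)}
  return out
```

Because the loops are ordinary Python and the skip test converts a mask slice to a Python bool, this sketch runs eagerly and is not meant to be traced under jax.jit; a compiled variant would express the same recurrence with jax.lax loop primitives such as jax.lax.fori_loop.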
- maxtext.kernels.attention.jax_flash_attention.mask_blocker(mask, block_q, block_kv)
Creates a blocked mask from a full mask.
- Parameters:
mask (Array) – The attention mask with shape (batch_size, q_seq_len, kv_seq_len).
block_q (int) – Block size for the query sequence dimension.
block_kv (int) – Block size for the key/value sequence dimension.
- Returns:
A blocked mask where each element indicates the number of non-zero elements in the corresponding block of the original mask.
- Return type:
Array
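Counting the non-zero entries per block reduces to a reshape and a sum. The sketch below is an illustration under stated assumptions, not the module's code: it assumes the sequence lengths divide evenly by the block sizes and accepts either a 2-D mask or one with leading (for example, batch) dimensions. The names count_block_nonzeros and classify_blocks are hypothetical, and the three-way classification (empty, partial, full) is one common way block-sparse kernels use such counts, not necessarily what this module does.

```python
import jax.numpy as jnp


def count_block_nonzeros(mask, block_q, block_kv):
  """Hypothetical sketch: count non-zero mask entries in each block.

  mask: (..., q_seq_len, kv_seq_len); leading (e.g. batch) dims are kept.
  Assumes q_seq_len % block_q == 0 and kv_seq_len % block_kv == 0.
  Returns an integer array of shape (..., q_seq_len // block_q, kv_seq_len // block_kv).
  """
  *lead, q_len, kv_len = mask.shape
  # Split each sequence axis into (num_blocks, block_size), then sum within blocks.
  blocks = (mask != 0).reshape(*lead, q_len // block_q, block_q, kv_len // block_kv, block_kv)
  return blocks.sum(axis=(-3, -1))


def classify_blocks(counts, block_q, block_kv):
  """Hypothetical: 0 = empty (skippable), 1 = partial (needs masking), 2 = full."""
  full = block_q * block_kv
  return jnp.where(counts == 0, 0, jnp.where(counts == full, 2, 1))


causal = jnp.tril(jnp.ones((4, 4), dtype=jnp.int32))
counts = count_block_nonzeros(causal, 2, 2)
print(counts.tolist())  # [[3, 0], [4, 3]]
print(classify_blocks(counts, 2, 2).tolist())  # [[1, 0], [2, 1]]
```

For a causal mask, the blocks above the diagonal are empty and can be skipped, blocks below it are full and need no element-wise masking, and only the diagonal blocks need the full mask.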