maxtext.kernels.attention.jax_flash_attention module


Flash Attention implemented in pure JAX, without Pallas.

maxtext.kernels.attention.jax_flash_attention.flash_attention_block_masked(q, k, v, segment_ids, block_kv, block_q, mask, mask_value, cap=None, save_residuals=False)

Computes masked flash attention using block-sparse masking.

Parameters:
  • q (Array) – Query tensor with shape (batch_size, num_kv_heads, num_q_heads_per_kv_head, q_seq_len, head_dim).

  • k (Array) – Key tensor with shape (batch_size, num_kv_heads, kv_seq_len, head_dim).

  • v (Array) – Value tensor with shape (batch_size, num_kv_heads, kv_seq_len, v_head_dim).

  • segment_ids (SegmentIds | None) – Segment ids for the query and key/value sequences. They prevent cross-attention between segments (sub-sequences) that have been concatenated into a single sequence. Each array holds integer ids, and only tokens with the same id are allowed to attend to each other.

  • block_kv (int) – Block size for the key/value sequence dimension.

  • block_q (int) – Block size for the query sequence dimension.

  • mask (Array) – The full attention mask with shape (q_seq_len, kv_seq_len). The same mask is applied to every example in the batch.

  • mask_value (float) – The value to use for masked-out attention scores.

  • cap (float | None) – Optional soft cap for attention logits, which prevents extremely large logits. When set, logits are transformed as capped_logits = jnp.tanh(logits / cap) * cap.

  • save_residuals (bool) – Whether to save residuals. If True, returns a tuple of (output, residuals), where residuals is a dict of (logsumexp, max_logits). Both logsumexp and max_logits have shape (batch_size, num_kv_heads, num_q_heads_per_kv_head, q_seq_len), where num_q_heads_per_kv_head = num_q_heads // num_kv_heads.
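The segment_ids and cap parameters can be illustrated with a minimal sketch. The helper names segment_mask and soft_cap are hypothetical and not part of this module, and plain integer arrays stand in for the SegmentIds container, whose exact structure is not shown here.

```python
import jax.numpy as jnp


def segment_mask(q_segment_ids, kv_segment_ids):
  """Hypothetical helper: True where a query and a key share a segment id."""
  # (q_seq_len, 1) == (1, kv_seq_len) broadcasts to (q_seq_len, kv_seq_len).
  return q_segment_ids[:, None] == kv_segment_ids[None, :]


def soft_cap(logits, cap):
  """Hypothetical helper: smoothly squashes logits into the range (-cap, cap)."""
  return jnp.tanh(logits / cap) * cap


# Two segments of two tokens each, packed into one sequence of length 4.
q_ids = jnp.array([0, 0, 1, 1])
kv_ids = jnp.array([0, 0, 1, 1])
allowed = segment_mask(q_ids, kv_ids)

# Large logits saturate near +/- cap; small logits are almost unchanged.
capped = soft_cap(jnp.array([-100.0, 0.0, 100.0]), cap=30.0)
```

In this sketch, tokens 0 and 1 may attend to each other, but neither may attend to tokens 2 or 3, which belong to the second segment.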

Returns:

If save_residuals is True, returns a tuple containing:

  • The output of the attention computation.

  • A dict of (logsumexp, max_logits).

Otherwise, returns the output of the attention computation.
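To make the documented shapes concrete, here is a dense (non-blocked) reference sketch of masked attention that also produces the two residuals. It is not this module's implementation: the function name reference_attention, the dict keys, the absence of any logit scaling, and the order in which the cap and the mask are applied are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, mask, mask_value, cap=None):
  """Dense reference sketch (hypothetical), not the blocked implementation."""
  # logits: (batch, kv_heads, q_heads_per_kv_head, q_seq_len, kv_seq_len)
  logits = jnp.einsum("bhgqd,bhkd->bhgqk", q, k)
  if cap is not None:
    logits = jnp.tanh(logits / cap) * cap  # assumed: cap applied before masking
  # mask has shape (q_seq_len, kv_seq_len) and broadcasts over leading axes.
  logits = jnp.where(mask, logits, mask_value)
  max_logits = jnp.max(logits, axis=-1)
  unnormalized = jnp.exp(logits - max_logits[..., None])
  denom = jnp.sum(unnormalized, axis=-1)
  logsumexp = max_logits + jnp.log(denom)
  weights = unnormalized / denom[..., None]
  out = jnp.einsum("bhgqk,bhke->bhgqe", weights, v)
  # Dict keys are an assumption; the source only names the two quantities.
  return out, {"logsumexp": logsumexp, "max_logits": max_logits}


kq, kk, kv_key = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 2, 3, 8, 4))   # (batch, kv_heads, q_per_kv, q_len, head_dim)
k = jax.random.normal(kk, (2, 2, 8, 4))      # (batch, kv_heads, kv_len, head_dim)
v = jax.random.normal(kv_key, (2, 2, 8, 5))  # (batch, kv_heads, kv_len, v_head_dim)
mask = jnp.tril(jnp.ones((8, 8), dtype=bool))  # causal mask shared by the batch
out, residuals = reference_attention(q, k, v, mask, mask_value=-1e9)
```

The output takes its last dimension from v (v_head_dim), while both residuals drop the head dimension and keep one value per query position.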

maxtext.kernels.attention.jax_flash_attention.mask_blocker(mask, block_q, block_kv)

Creates a blocked mask from a full mask.

Parameters:
  • mask (Array) – The attention mask with shape (batch_size, q_seq_len, kv_seq_len).

  • block_q (int) – Block size for the query sequence dimension.

  • block_kv (int) – Block size for the key/value sequence dimension.

Returns:

A blocked mask where each element indicates the number of non-zero elements in the corresponding block of the original mask.

Return type:

Array
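The blocking described above can be sketched as a reshape followed by a sum over each tile. The function name block_nonzero_counts is hypothetical, and the sketch assumes both sequence lengths are exact multiples of the block sizes; how mask_blocker itself handles other cases is not stated here.

```python
import jax.numpy as jnp


def block_nonzero_counts(mask, block_q, block_kv):
  """Hypothetical sketch: counts non-zero mask entries in each block."""
  q_len, kv_len = mask.shape[-2:]
  # Assumption: sequence lengths divide evenly into blocks.
  assert q_len % block_q == 0 and kv_len % block_kv == 0
  lead = mask.shape[:-2]
  # Split each sequence axis into (number of blocks, block size).
  tiled = mask.reshape(
      *lead, q_len // block_q, block_q, kv_len // block_kv, block_kv
  )
  # Sum over the two within-block axes, leaving one count per block.
  return jnp.sum(tiled != 0, axis=(-3, -1))


causal = jnp.tril(jnp.ones((4, 4), dtype=jnp.int32))
counts = block_nonzero_counts(causal, block_q=2, block_kv=2)
# counts is [[3, 0], [4, 3]]: the upper-right block is fully masked,
# the lower-left block is fully visible, and the diagonal blocks are partial.
```

A count of zero marks a block that a block-sparse kernel can skip entirely, and a count equal to block_q * block_kv marks a block that needs no per-element masking.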