maxtext.kernels.attention.ragged_attention module
Kernels for ragged attention (attention over a batch whose sequences have different valid lengths) for efficient inference.
- maxtext.kernels.attention.ragged_attention.get_mha_cost_estimate(shape_dtype)
Get cost estimate for MHA based on static shape information.
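The cost model itself is not described here. As a rough illustration only, the sketch below derives a cost estimate from static shapes. It assumes shape_dtype is a (q, k, v) tuple of jax.ShapeDtypeStruct using the ragged_mha layouts; the helper name mha_cost_sketch and its FLOP and byte formulas are hypothetical, chosen to mirror the flops, transcendentals and bytes_accessed fields of a Pallas CostEstimate, and are not the MaxText formula.

```python
import math

import jax
import jax.numpy as jnp


def mha_cost_sketch(shape_dtype):
    """Hypothetical cost model for decode-time MHA from static shapes.

    Assumes shape_dtype = (q, k, v) with q: [batch, 1, num_heads, head_dim]
    and k, v: [batch, seq_len, num_heads, head_dim].
    """
    q, k, _ = shape_dtype
    batch, _, num_heads, head_dim = q.shape
    seq_len = k.shape[1]
    # Two matmuls per head (q @ k^T and p @ v), each about 2 * seq_len * head_dim FLOPs.
    flops = batch * num_heads * 2 * (2 * seq_len * head_dim)
    # One exp per attention logit.
    transcendentals = batch * num_heads * seq_len
    # Read every input once and write an output shaped like q.
    bytes_accessed = sum(math.prod(s.shape) * s.dtype.itemsize for s in shape_dtype)
    bytes_accessed += math.prod(q.shape) * q.dtype.itemsize
    return dict(flops=flops, transcendentals=transcendentals, bytes_accessed=bytes_accessed)


# Usage: only shapes and dtypes are needed, no real arrays.
q = jax.ShapeDtypeStruct((2, 1, 8, 128), jnp.bfloat16)
kv = jax.ShapeDtypeStruct((2, 1024, 8, 128), jnp.bfloat16)
est = mha_cost_sketch((q, kv, kv))
```

Because everything is computed from static shapes, an estimate like this can be built once at trace time without touching device data.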
- maxtext.kernels.attention.ragged_attention.reference_mqa(q, k, v, lengths, *, mask_value=-2.381976426469702e+38)
Multi-query attention reference.
- Parameters:
q (Array) – A [batch_size, num_heads, head_dim] jax.Array.
k (Array) – A [batch_size, seq_len, head_dim] jax.Array.
v (Array) – A [batch_size, seq_len, head_dim] jax.Array.
lengths (Array) – An i32[batch_size] jax.Array.
mask_value (float) – The value used to mask out padded key positions (t >= lengths[b] for batch row b) in the attention logits. By default it is a very large negative floating-point number.
- Returns:
The output of attention ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads]) and softmax denominator ([batch_size, num_heads]).
- Return type:
tuple[Array, Array, Array]
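A minimal JAX sketch of what a masked multi-query reference computes, assuming key position t of batch row b is valid only when t < lengths[b]. The function name mqa_reference_sketch and the constant MASK_VALUE are illustrative, not part of the module; MASK_VALUE is -0.7 times the float32 maximum, which is approximately the default mask_value shown above.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical constant: -0.7 * float32 max, approximately the documented default mask_value.
MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)


def mqa_reference_sketch(q, k, v, lengths, mask_value=MASK_VALUE):
    """q: [B, H, D]; k, v: [B, T, D]; lengths: i32[B]. All query heads share one K/V."""
    q = jnp.asarray(q, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    v = jnp.asarray(v, dtype=jnp.float32)
    logits = jnp.einsum("bhd,btd->bht", q, k)                    # [B, H, T]
    # Key position t is valid for batch row b only if t < lengths[b].
    valid = jnp.arange(k.shape[1])[None, :] < lengths[:, None]   # [B, T]
    logits = jnp.where(valid[:, None, :], logits, mask_value)
    m = logits.max(axis=-1)                                      # max logit, [B, H]
    p = jnp.exp(logits - m[..., None])                           # unnormalized probabilities
    l = p.sum(axis=-1)                                           # softmax denominator, [B, H]
    o = jnp.einsum("bht,btd->bhd", p, v) / l[..., None]          # [B, H, D]
    return o, m, l
```

Subtracting the max logit before exponentiating keeps exp from overflowing, and masked logits underflow to exactly zero probability, so padded keys do not affect o or l.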
- maxtext.kernels.attention.ragged_attention.reference_mha(q, k, v, lengths, *, mask_value=-2.381976426469702e+38)
Multi-head attention reference.
- Parameters:
q (Array) – A [batch_size, 1, num_heads, head_dim] jax.Array.
k (Array) – A [batch_size, seq_len, num_heads, head_dim] jax.Array.
v (Array) – A [batch_size, seq_len, num_heads, head_dim] jax.Array.
lengths (Array) – An i32[batch_size] jax.Array.
mask_value (float) – The value used to mask out padded key positions (t >= lengths[b] for batch row b) in the attention logits. By default it is a very large negative floating-point number.
- Returns:
The output of attention ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads]) and softmax denominator ([batch_size, num_heads]).
- Return type:
tuple[Array, Array, Array]
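In multi-head attention each head has its own keys and values, so the head axis appears in both einsums instead of being shared. The sketch below assumes the layouts documented above and squeezes the length-1 query axis so the output matches the documented [batch_size, num_heads, head_dim] shape; mha_reference_sketch and MASK_VALUE are hypothetical names.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical constant: -0.7 * float32 max, approximately the documented default mask_value.
MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)


def mha_reference_sketch(q, k, v, lengths, mask_value=MASK_VALUE):
    """q: [B, 1, H, D]; k, v: [B, T, H, D]; lengths: i32[B]. Each head has its own K/V."""
    q = jnp.asarray(q, dtype=jnp.float32)[:, 0]                  # drop the length-1 query axis: [B, H, D]
    k = jnp.asarray(k, dtype=jnp.float32)
    v = jnp.asarray(v, dtype=jnp.float32)
    logits = jnp.einsum("bhd,bthd->bht", q, k)                   # [B, H, T]
    valid = jnp.arange(k.shape[1])[None, :] < lengths[:, None]   # [B, T]
    logits = jnp.where(valid[:, None, :], logits, mask_value)
    m = logits.max(axis=-1)                                      # max logit, [B, H]
    p = jnp.exp(logits - m[..., None])
    l = p.sum(axis=-1)                                           # softmax denominator, [B, H]
    o = jnp.einsum("bht,bthd->bhd", p, v) / l[..., None]         # [B, H, D]
    return o, m, l
```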
- maxtext.kernels.attention.ragged_attention.reference_gqa(q, k, v, lengths, mask_value=-2.381976426469702e+38)
Vanilla grouped-query attention (GQA) implementation for reference.
- Parameters:
q (Array) – A [batch_size, num_q_heads, head_dim] jax.Array.
k (Array) – A [batch_size, num_kv_heads, max_seq_len, head_dim] jax.Array.
v (Array) – A [batch_size, num_kv_heads, max_seq_len, head_dim] jax.Array.
lengths (Array) – An i32[batch_size] jax.Array.
mask_value (float) – The value used to mask out padded key positions (t >= lengths[b] for batch row b) in the attention logits. By default it is a very large negative floating-point number.
- Returns:
The output of attention ([batch_size, num_q_heads, head_dim]), along with the max logit ([batch_size, num_q_heads]) and softmax denominator ([batch_size, num_q_heads]).
- Return type:
tuple[Array, Array, Array]
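In GQA several query heads share one key/value head. One simple way to write a reference is to repeat each KV head once per query head in its group and then run ordinary per-head attention. The sketch below assumes the k, v layout documented above ([batch_size, num_kv_heads, max_seq_len, head_dim]) and assumes consecutive query heads share a KV head; the grouping convention of the real implementation is not stated here, and gqa_reference_sketch and MASK_VALUE are hypothetical names.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical constant: -0.7 * float32 max, approximately the documented default mask_value.
MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)


def gqa_reference_sketch(q, k, v, lengths, mask_value=MASK_VALUE):
    """q: [B, Hq, D]; k, v: [B, Hkv, T, D]; lengths: i32[B]; Hq must be a multiple of Hkv."""
    group = q.shape[1] // k.shape[1]
    # Assumption: query head h reads KV head h // group (consecutive heads share a KV head).
    k = jnp.repeat(jnp.asarray(k, dtype=jnp.float32), group, axis=1)  # [B, Hq, T, D]
    v = jnp.repeat(jnp.asarray(v, dtype=jnp.float32), group, axis=1)
    logits = jnp.einsum("bhd,bhtd->bht", jnp.asarray(q, dtype=jnp.float32), k)
    valid = jnp.arange(k.shape[2])[None, :] < lengths[:, None]        # [B, T]
    logits = jnp.where(valid[:, None, :], logits, mask_value)
    m = logits.max(axis=-1)                                           # [B, Hq]
    p = jnp.exp(logits - m[..., None])
    l = p.sum(axis=-1)                                                # [B, Hq]
    o = jnp.einsum("bht,bhtd->bhd", p, v) / l[..., None]              # [B, Hq, D]
    return o, m, l
```

Repeating K/V is easy to read but copies memory; it suits a reference check rather than a fast kernel.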
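- maxtext.kernels.attention.ragged_attention.ragged_flash_attention_kernel(lengths_ref, q_ref, k_ref, v_ref, o_ref, m_ref, l_ref, *, block_size, mask_value)
Pallas kernel for ragged flash attention.
- Parameters:
block_size (int) – Value defining the Pallas block length in the seq_len dimension.
mask_value (float) – The value used to mask out padded key positions in the attention logits.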
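Flash attention avoids materializing the full softmax by sweeping over key blocks while carrying a running max logit, a running denominator and a rescaled output accumulator. The pure-JAX loop below is a sketch of that online-softmax recurrence for a single batch row, not the Pallas kernel itself; it also shows where the ragged saving comes from, since blocks that start at or after length are never visited. The name blocked_attention_sketch is hypothetical, and the sketch assumes 1 <= length <= T.

```python
import math

import jax.numpy as jnp
import numpy as np

# Hypothetical constant: -0.7 * float32 max, approximately the documented default mask_value.
MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)


def blocked_attention_sketch(q, k, v, length, block_size, mask_value=MASK_VALUE):
    """One batch row. q: [H, D]; k, v: [T, D]; length: Python int with 1 <= length <= T."""
    q = jnp.asarray(q, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    v = jnp.asarray(v, dtype=jnp.float32)
    h, d = q.shape
    m = jnp.full((h,), -jnp.inf, dtype=jnp.float32)  # running max logit
    l = jnp.zeros((h,), dtype=jnp.float32)           # running softmax denominator
    acc = jnp.zeros((h, d), dtype=jnp.float32)       # running unnormalized output
    # Ragged saving: only ceil(length / block_size) blocks are processed.
    for blk in range(math.ceil(length / block_size)):
        start = blk * block_size
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        s = q @ kb.T                                             # [H, block]
        pos = start + jnp.arange(kb.shape[0])
        s = jnp.where(pos[None, :] < length, s, mask_value)      # mask the tail of the last block
        m_new = jnp.maximum(m, s.max(axis=-1))
        correction = jnp.exp(m - m_new)                          # rescale earlier blocks to the new max
        p = jnp.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ vb
        m = m_new
    return acc / l[:, None], m, l
```

The result equals a single softmax over the first length keys; the kernel's o_ref, m_ref and l_ref outputs presumably hold the same three quantities, though their exact normalization is not documented here.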
- maxtext.kernels.attention.ragged_attention.ragged_mqa(q, k, v, lengths, *, block_size=256, mask_value=-2.381976426469702e+38, cost_estimate=None)
Ragged multi-query attention.
- Parameters:
q (Array) – A [batch_size, 1, head_dim] jax.Array.
k (Array) – A [batch_size, seq_len, head_dim] jax.Array.
v (Array) – A [batch_size, seq_len, head_dim] jax.Array.
lengths (Array) – An i32[batch_size] jax.Array.
block_size (int) – Value defining the Pallas block length in the seq_len dimension.
mask_value (float) – The value used to mask out padded key positions (t >= lengths[b] for batch row b) in the attention logits. By default it is a very large negative floating-point number.
cost_estimate (CostEstimate | None) – A Pallas TPU cost estimate based on a reference implementation.
- Returns:
The output of attention ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads, 1]) and softmax denominator ([batch_size, num_heads, 1]).
- Return type:
tuple[Array, Array, Array]
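Returning the max logit and softmax denominator alongside the output is what makes partial attention results composable: two results computed over disjoint key ranges can be merged exactly by rescaling each denominator to a shared max. The sketch below is a general illustration of that identity, not a function of this module; the name merge_partials is hypothetical, and it assumes each partial output is normalized by its own denominator and that m and l carry the trailing 1 shown in the return shapes above, so they broadcast against o.

```python
import jax.numpy as jnp


def merge_partials(o1, m1, l1, o2, m2, l2):
    """Merge two normalized partial attention results over disjoint key ranges.

    o1, o2: [..., D] normalized outputs; m1, m2, l1, l2: [..., 1] max logits and
    softmax denominators (the trailing 1 lets them broadcast against o).
    """
    m = jnp.maximum(m1, m2)
    # Rescale each part's denominator to the shared max before mixing.
    w1 = l1 * jnp.exp(m1 - m)
    w2 = l2 * jnp.exp(m2 - m)
    l = w1 + w2
    o = (w1 * o1 + w2 * o2) / l
    return o, m, l
```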
- maxtext.kernels.attention.ragged_attention.ragged_mha(query, key, value, lengths, *, block_size=256, mask_value=-2.381976426469702e+38)
Ragged multi-head attention.
- Parameters:
query (Array) – A [batch_size, 1, num_heads, head_dim] jax.Array.
key (Array) – A [batch_size, seq_len, num_heads, head_dim] jax.Array.
value (Array) – A [batch_size, seq_len, num_heads, head_dim] jax.Array.
lengths (Array) – An i32[batch_size] jax.Array.
block_size (int) – Value defining the Pallas block length in the seq_len dimension.
mask_value (float) – The value used to mask out padded key positions (t >= lengths[b] for batch row b) in the attention logits. By default it is a very large negative floating-point number.
- Returns:
The output of attention ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads, 1]) and softmax denominator ([batch_size, num_heads, 1]).
- Return type:
tuple[Array, Array, Array]
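Multi-head attention decomposes into independent single-KV problems, one per head, each with a single query row of shape [batch_size, 1, head_dim]. The sketch below expresses that decomposition with jax.vmap over the head axis using a plain-JAX stand-in for the per-head step; it is one natural way to structure the computation, not necessarily how ragged_mha is implemented, and single_kv_attention, mha_via_vmap and MASK_VALUE are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical constant: -0.7 * float32 max, approximately the documented default mask_value.
MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)


def single_kv_attention(q, k, v, lengths, mask_value=MASK_VALUE):
    """q: [B, G, D] query rows sharing k, v: [B, T, D]. Returns o [B, G, D]; m, l [B, G, 1]."""
    logits = jnp.einsum("bgd,btd->bgt", jnp.asarray(q, jnp.float32), jnp.asarray(k, jnp.float32))
    valid = jnp.arange(k.shape[1])[None, None, :] < lengths[:, None, None]
    logits = jnp.where(valid, logits, mask_value)
    m = logits.max(axis=-1, keepdims=True)
    p = jnp.exp(logits - m)
    l = p.sum(axis=-1, keepdims=True)
    o = jnp.einsum("bgt,btd->bgd", p, jnp.asarray(v, jnp.float32)) / l
    return o, m, l


def mha_via_vmap(query, key, value, lengths):
    """query: [B, 1, H, D]; key, value: [B, T, H, D]; lengths: i32[B]."""
    # Map over the head axis; lengths is shared by all heads.
    per_head = jax.vmap(single_kv_attention, in_axes=(2, 2, 2, None), out_axes=1)
    o, m, l = per_head(query, key, value, lengths)  # o: [B, H, 1, D]; m, l: [B, H, 1, 1]
    return o[:, :, 0], m[:, :, 0], l[:, :, 0]       # [B, H, D], [B, H, 1], [B, H, 1]
```

The returned shapes line up with the documented outputs: [batch_size, num_heads, head_dim] for the output and [batch_size, num_heads, 1] for the max logit and denominator.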
- maxtext.kernels.attention.ragged_attention.ragged_gqa(query, key, value, lengths, *, block_size=256, mask_value=-2.381976426469702e+38)
Ragged grouped-query attention.
- Parameters:
query (Array) – A [batch_size, num_heads_q, head_dim] jax.Array.
key (Array) – A [batch_size, seq_len, num_heads_kv, head_dim] jax.Array.
value (Array) – A [batch_size, seq_len, num_heads_kv, head_dim] jax.Array.
lengths (Array) – An i32[batch_size] jax.Array.
block_size (int) – Value defining the Pallas block length in the seq_len dimension.
mask_value (float) – The value used to mask out padded key positions (t >= lengths[b] for batch row b) in the attention logits. By default it is a very large negative floating-point number.
- Returns:
The output of attention ([batch_size, num_heads_q, head_dim]), along with the max logit ([batch_size, num_heads_q, 1]) and softmax denominator ([batch_size, num_heads_q, 1]).
- Return type:
tuple[Array, Array, Array]
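Grouped-query attention can be reduced to multi-query attention per KV head: reshape the query heads into [batch_size, num_heads_kv, group, head_dim] and run each KV head against its group of query rows, without copying keys or values. The sketch below does this with jax.vmap over the KV-head axis using the key/value layout documented above. It assumes consecutive query heads share a KV head (query head h uses KV head h // group), which may differ from the real convention; single_kv_attention, gqa_via_vmap and MASK_VALUE are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical constant: -0.7 * float32 max, approximately the documented default mask_value.
MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)


def single_kv_attention(q, k, v, lengths, mask_value=MASK_VALUE):
    """q: [B, G, D] query rows sharing k, v: [B, T, D]. Returns o [B, G, D]; m, l [B, G, 1]."""
    logits = jnp.einsum("bgd,btd->bgt", jnp.asarray(q, jnp.float32), jnp.asarray(k, jnp.float32))
    valid = jnp.arange(k.shape[1])[None, None, :] < lengths[:, None, None]
    logits = jnp.where(valid, logits, mask_value)
    m = logits.max(axis=-1, keepdims=True)
    p = jnp.exp(logits - m)
    l = p.sum(axis=-1, keepdims=True)
    o = jnp.einsum("bgt,btd->bgd", p, jnp.asarray(v, jnp.float32)) / l
    return o, m, l


def gqa_via_vmap(query, key, value, lengths):
    """query: [B, Hq, D]; key, value: [B, T, Hkv, D], with Hq divisible by Hkv."""
    b, hq, d = query.shape
    hkv = key.shape[2]
    # Assumption: consecutive query heads share a KV head.
    grouped = query.reshape(b, hkv, hq // hkv, d)                         # [B, Hkv, G, D]
    per_kv_head = jax.vmap(single_kv_attention, in_axes=(1, 2, 2, None), out_axes=1)
    o, m, l = per_kv_head(grouped, key, value, lengths)                  # o: [B, Hkv, G, D]
    return o.reshape(b, hq, d), m.reshape(b, hq, 1), l.reshape(b, hq, 1)
```

Compared with repeating K/V per query head, this keeps one copy of each KV head and turns each group into a small multi-query problem.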