maxtext.kernels.attention.ragged_attention module

Kernels for ragged attention for efficient inference.

maxtext.kernels.attention.ragged_attention.get_mha_cost_estimate(shape_dtype)

Get cost estimate for MHA based on static shape information.

maxtext.kernels.attention.ragged_attention.reference_mqa(q, k, v, lengths, *, mask_value=-2.381976426469702e+38)

Multi query attention reference.

Parameters:
  • q (Array) – A [batch_size, num_heads, head_dim] jax.Array.

  • k (Array) – A [batch_size, seq_len, head_dim] jax.Array.

  • v (Array) – A [batch_size, seq_len, head_dim] jax.Array.

  • lengths (Array) – An i32[batch_size] jax.Array.

  • mask_value (float) – The value used to mask padded positions in attention. Defaults to a very large negative floating-point number.

Returns:

The attention output ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads]) and the softmax denominator ([batch_size, num_heads]).

Return type:

tuple[Array, Array, Array]
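
The documented shapes and return values can be illustrated with a minimal jax.numpy sketch. This is not the library's implementation: `mqa_sketch` is a hypothetical name, and the sketch assumes that no 1/sqrt(head_dim) scaling is applied (queries are taken to be pre-scaled) and that positions at or beyond `lengths[b]` are the masked ones.

```python
import jax
import jax.numpy as jnp

DEFAULT_MASK_VALUE = -2.381976426469702e38  # default shown in the signature above


def mqa_sketch(q, k, v, lengths, mask_value=DEFAULT_MASK_VALUE):
  """Hypothetical sketch. q: [B, H, D]; k, v: [B, S, D]; lengths: i32[B]."""
  q, k, v = (x.astype(jnp.float32) for x in (q, k, v))
  # All heads share the same keys and values (multi-query attention).
  logits = jnp.einsum("bhd,bsd->bhs", q, k)
  # Assumption: position s of example b is valid iff s < lengths[b].
  valid = jnp.arange(k.shape[1])[None, :] < lengths[:, None]  # [B, S]
  logits = jnp.where(valid[:, None, :], logits, mask_value)
  m = logits.max(axis=-1)             # max logit, [B, H]
  p = jnp.exp(logits - m[..., None])  # masked positions underflow to 0
  l = p.sum(axis=-1)                  # softmax denominator, [B, H]
  out = jnp.einsum("bhs,bsd->bhd", p, v) / l[..., None]
  return out, m, l


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 4, 8))
k = jax.random.normal(kk, (2, 16, 8))
v = jax.random.normal(kv, (2, 16, 8))
lengths = jnp.array([5, 16], dtype=jnp.int32)
out, m, l = mqa_sketch(q, k, v, lengths)
```

Because masked logits are replaced before the softmax, whatever is stored in the padded tail of `k` and `v` does not affect the result.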

maxtext.kernels.attention.ragged_attention.reference_mha(q, k, v, lengths, *, mask_value=-2.381976426469702e+38)

Multi head attention reference.

Parameters:
  • q (Array) – A [batch_size, 1, num_heads, head_dim] jax.Array.

  • k (Array) – A [batch_size, seq_len, num_heads, head_dim] jax.Array.

  • v (Array) – A [batch_size, seq_len, num_heads, head_dim] jax.Array.

  • lengths (Array) – An i32[batch_size] jax.Array.

  • mask_value (float) – The value used to mask padded positions in attention. Defaults to a very large negative floating-point number.

Returns:

The attention output ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads]) and the softmax denominator ([batch_size, num_heads]).

Return type:

tuple[Array, Array, Array]
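
For the multi-head layout, where every head has its own keys and values, the same computation might be sketched as follows. `mha_sketch` is a hypothetical name, and the absence of logit scaling is an assumption rather than something the documentation above states.

```python
import jax
import jax.numpy as jnp


def mha_sketch(q, k, v, lengths, mask_value=-2.381976426469702e38):
  """Hypothetical sketch. q: [B, 1, H, D]; k, v: [B, S, H, D]; lengths: i32[B]."""
  q = q[:, 0].astype(jnp.float32)  # drop the length-1 query axis -> [B, H, D]
  k = k.astype(jnp.float32)
  v = v.astype(jnp.float32)
  # Head h of the query only attends to head h of the keys.
  logits = jnp.einsum("bhd,bshd->bhs", q, k)
  valid = jnp.arange(k.shape[1])[None, None, :] < lengths[:, None, None]
  logits = jnp.where(valid, logits, mask_value)
  m = logits.max(axis=-1)             # [B, H]
  p = jnp.exp(logits - m[..., None])
  l = p.sum(axis=-1)                  # [B, H]
  out = jnp.einsum("bhs,bshd->bhd", p, v) / l[..., None]
  return out, m, l


kq, kk, kv = jax.random.split(jax.random.PRNGKey(1), 3)
q = jax.random.normal(kq, (2, 1, 4, 8))
k = jax.random.normal(kk, (2, 16, 4, 8))
v = jax.random.normal(kv, (2, 16, 4, 8))
lengths = jnp.array([3, 10], dtype=jnp.int32)
out, m, l = mha_sketch(q, k, v, lengths)
```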

maxtext.kernels.attention.ragged_attention.reference_gqa(q, k, v, lengths, mask_value=-2.381976426469702e+38)

Reference (vanilla) grouped-query attention (GQA) implementation.

Parameters:
  • q (Array) – A [batch_size, num_q_heads, head_dim] jax.Array.

  • k (Array) – A [batch_size, num_kv_heads, max_seq_len, head_dim] jax.Array.

  • v (Array) – A [batch_size, num_kv_heads, max_seq_len, head_dim] jax.Array.

  • lengths (Array) – An i32[batch_size] jax.Array.

  • mask_value (float) – The value used to mask padded positions in attention. Defaults to a very large negative floating-point number.

Returns:

The attention output ([batch_size, num_q_heads, head_dim]), along with the max logit ([batch_size, num_q_heads]) and the softmax denominator ([batch_size, num_q_heads]).

Return type:

tuple[Array, Array, Array]
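
A grouped-query variant can be sketched by folding the query heads into groups that share one KV head. The sketch below is illustrative only: `gqa_sketch` is a hypothetical name, and it assumes that consecutive query heads belong to the same KV head and that `num_q_heads` is a multiple of `num_kv_heads`.

```python
import jax
import jax.numpy as jnp


def gqa_sketch(q, k, v, lengths, mask_value=-2.381976426469702e38):
  """Hypothetical sketch. q: [B, Hq, D]; k, v: [B, Hkv, S, D]; lengths: i32[B]."""
  batch, num_q_heads, head_dim = q.shape
  num_kv_heads, seq_len = k.shape[1], k.shape[2]
  group = num_q_heads // num_kv_heads  # query heads per KV head
  # Assumption: query heads [g*group, (g+1)*group) share KV head g.
  qg = q.reshape(batch, num_kv_heads, group, head_dim).astype(jnp.float32)
  logits = jnp.einsum("bkgd,bksd->bkgs", qg, k.astype(jnp.float32))
  valid = jnp.arange(seq_len)[None, None, None, :] < lengths[:, None, None, None]
  logits = jnp.where(valid, logits, mask_value)
  m = logits.max(axis=-1)
  p = jnp.exp(logits - m[..., None])
  l = p.sum(axis=-1)
  out = jnp.einsum("bkgs,bksd->bkgd", p, v.astype(jnp.float32)) / l[..., None]
  # Unfold the groups back into a flat query-head axis.
  return (out.reshape(batch, num_q_heads, head_dim),
          m.reshape(batch, num_q_heads),
          l.reshape(batch, num_q_heads))


kq, kk, kv = jax.random.split(jax.random.PRNGKey(2), 3)
q = jax.random.normal(kq, (2, 8, 4))
k = jax.random.normal(kk, (2, 2, 12, 4))
v = jax.random.normal(kv, (2, 2, 12, 4))
lengths = jnp.array([7, 12], dtype=jnp.int32)
out, m, l = gqa_sketch(q, k, v, lengths)
```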

maxtext.kernels.attention.ragged_attention.ragged_flash_attention_kernel(lengths_ref, q_ref, k_ref, v_ref, o_ref, m_ref, l_ref, *, block_size, mask_value)

Pallas kernel for flash attention.

Parameters:
  • block_size (int)

  • mask_value (float)
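
The idea behind a blockwise (flash-style) kernel with a running max, denominator, and output can be sketched in plain jax.numpy for a single example. This is not the Pallas kernel itself: `blockwise_attention_sketch` is a hypothetical name, the Python loop stands in for the kernel grid, and skipping blocks that start at or beyond the valid length is an assumption about how raggedness is exploited. The sketch also assumes `length >= 1`.

```python
import jax
import jax.numpy as jnp


def blockwise_attention_sketch(q, k, v, length, block_size,
                               mask_value=-2.381976426469702e38):
  """Hypothetical sketch. q: [H, D]; k, v: [S, D]; length: Python int >= 1."""
  num_heads, head_dim = q.shape
  m = jnp.full((num_heads, 1), mask_value, dtype=jnp.float32)  # running max
  l = jnp.zeros((num_heads, 1), dtype=jnp.float32)             # running denominator
  o = jnp.zeros((num_heads, head_dim), dtype=jnp.float32)      # unnormalized output
  for start in range(0, k.shape[0], block_size):
    if start >= length:
      break  # blocks entirely past the valid length contribute nothing
    k_blk = k[start:start + block_size]
    v_blk = v[start:start + block_size]
    logits = q @ k_blk.T  # [H, block]
    pos = start + jnp.arange(k_blk.shape[0])
    logits = jnp.where(pos[None, :] < length, logits, mask_value)
    m_new = jnp.maximum(m, logits.max(axis=-1, keepdims=True))
    p = jnp.exp(logits - m_new)
    scale = jnp.exp(m - m_new)  # rescale earlier blocks to the new max
    l = scale * l + p.sum(axis=-1, keepdims=True)
    o = scale * o + p @ v_blk
    m = m_new
  return o / l, m, l


kq, kk, kv = jax.random.split(jax.random.PRNGKey(3), 3)
q = jax.random.normal(kq, (2, 8))
k = jax.random.normal(kk, (10, 8))
v = jax.random.normal(kv, (10, 8))
out, m, l = blockwise_attention_sketch(q, k, v, length=7, block_size=4)
```

Only one block of keys and values is needed at a time, which is what lets a kernel of this shape keep its working set small and stop early for short sequences.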

maxtext.kernels.attention.ragged_attention.ragged_mqa(q, k, v, lengths, *, block_size=256, mask_value=-2.381976426469702e+38, cost_estimate=None)

Ragged multi query attention.

Parameters:
  • q (Array) – A [batch_size, 1, head_dim] jax.Array.

  • k (Array) – A [batch_size, seq_len, head_dim] jax.Array.

  • v (Array) – A [batch_size, seq_len, head_dim] jax.Array.

  • lengths (Array) – An i32[batch_size] jax.Array.

  • block_size (int) – Value defining the Pallas block length in the seq_len dimension.

  • mask_value (float) – The value used to mask padded positions in attention. Defaults to a very large negative floating-point number.

  • cost_estimate (CostEstimate | None) – A Pallas TPU cost estimate based on a reference implementation.

Returns:

The attention output ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads, 1]) and the softmax denominator ([batch_size, num_heads, 1]).

Return type:

tuple[Array, Array, Array]
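
One common use for a returned max logit and softmax denominator is to merge attention results computed over disjoint sets of keys. The documentation above does not say that callers do this, so the following is only a sketch of the general technique; `attend` and `merge_partials` are hypothetical helpers operating on a single example.

```python
import jax
import jax.numpy as jnp


def attend(q, k, v):
  """Unmasked attention for one example. q: [H, D]; k, v: [S, D]."""
  logits = q @ k.T
  m = logits.max(axis=-1, keepdims=True)  # [H, 1]
  p = jnp.exp(logits - m)
  l = p.sum(axis=-1, keepdims=True)       # [H, 1]
  return (p @ v) / l, m, l


def merge_partials(o1, m1, l1, o2, m2, l2):
  """Combine two normalized partial outputs into the output over all keys."""
  m = jnp.maximum(m1, m2)
  w1 = l1 * jnp.exp(m1 - m)  # denominator of part 1, re-expressed relative to m
  w2 = l2 * jnp.exp(m2 - m)
  l = w1 + w2
  return (o1 * w1 + o2 * w2) / l, m, l


kq, kk, kv = jax.random.split(jax.random.PRNGKey(4), 3)
q = jax.random.normal(kq, (4, 8))
k = jax.random.normal(kk, (10, 8))
v = jax.random.normal(kv, (10, 8))

full_o, full_m, full_l = attend(q, k, v)
merged_o, merged_m, merged_l = merge_partials(
    *attend(q, k[:6], v[:6]), *attend(q, k[6:], v[6:]))
```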

maxtext.kernels.attention.ragged_attention.ragged_mha(query, key, value, lengths, *, block_size=256, mask_value=-2.381976426469702e+38)

Ragged multi head attention.

Parameters:
  • query (Array) – A [batch_size, 1, num_heads, head_dim] jax.Array.

  • key (Array) – A [batch_size, seq_len, num_heads, head_dim] jax.Array.

  • value (Array) – A [batch_size, seq_len, num_heads, head_dim] jax.Array.

  • lengths (Array) – An i32[batch_size] jax.Array.

  • block_size (int) – Value defining the Pallas block length in the seq_len dimension.

  • mask_value (float) – The value used to mask padded positions in attention. Defaults to a very large negative floating-point number.

Returns:

The attention output ([batch_size, num_heads, head_dim]), along with the max logit ([batch_size, num_heads, 1]) and the softmax denominator ([batch_size, num_heads, 1]).

Return type:

tuple[Array, Array, Array]
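
Preparing inputs in the documented layout means padding variable-length sequences to a common seq_len and recording the true lengths as an i32 vector. The sketch below shows one way to do that; `pad_ragged` is a hypothetical helper, and rounding seq_len up to a multiple of block_size is a conservative assumption, since the documentation above does not state whether that is required.

```python
import jax.numpy as jnp
import numpy as np


def pad_ragged(seqs, block_size):
  """Hypothetical helper. seqs: list of [len_i, num_heads, head_dim] arrays."""
  lengths = np.array([s.shape[0] for s in seqs], dtype=np.int32)
  max_len = int(lengths.max())
  # Assumption: round seq_len up to a whole number of blocks.
  seq_len = -(-max_len // block_size) * block_size
  padded = np.zeros((len(seqs), seq_len) + seqs[0].shape[1:], dtype=seqs[0].dtype)
  for i, s in enumerate(seqs):
    padded[i, : s.shape[0]] = s  # the tail stays zero and is masked via lengths
  return jnp.asarray(padded), jnp.asarray(lengths)


seqs = [np.ones((n, 2, 4), dtype=np.float32) for n in (3, 300, 17)]
padded, lengths = pad_ragged(seqs, block_size=256)
```

The resulting `padded` array has the [batch_size, seq_len, num_heads, head_dim] layout documented for key and value, and `lengths` has the documented i32[batch_size] form.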

maxtext.kernels.attention.ragged_attention.ragged_gqa(query, key, value, lengths, *, block_size=256, mask_value=-2.381976426469702e+38)

Ragged grouped-query attention.

Parameters:
  • query (Array) – A [batch_size, num_heads_q, head_dim] jax.Array.

  • key (Array) – A [batch_size, seq_len, num_heads_kv, head_dim] jax.Array.

  • value (Array) – A [batch_size, seq_len, num_heads_kv, head_dim] jax.Array.

  • lengths (Array) – An i32[batch_size] jax.Array.

  • block_size (int) – Value defining the Pallas block length in the seq_len dimension.

  • mask_value (float) – The value used to mask padded positions in attention. Defaults to a very large negative floating-point number.

Returns:

The attention output ([batch_size, num_heads_q, head_dim]), along with the max logit ([batch_size, num_heads_q, 1]) and the softmax denominator ([batch_size, num_heads_q, 1]).

Return type:

tuple[Array, Array, Array]
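
The key and value layout documented here ([batch_size, seq_len, num_heads_kv, head_dim]) differs from the one documented for reference_gqa ([batch_size, num_kv_heads, max_seq_len, head_dim]). When comparing the two, the arrays can be converted with an axis swap. `to_reference_gqa_layout` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def to_reference_gqa_layout(x):
  """[batch, seq_len, num_heads_kv, head_dim] -> [batch, num_heads_kv, seq_len, head_dim]."""
  return jnp.swapaxes(x, 1, 2)


key = jax.random.normal(jax.random.PRNGKey(5), (2, 16, 4, 8))
key_ref_layout = to_reference_gqa_layout(key)
```

The max logit and softmax denominator also differ in rank between the two families as documented ([batch_size, num_heads, 1] here versus [batch_size, num_heads] for the reference functions), so a comparison may need to drop or add the trailing axis.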