maxtext.kernels.attention.splash_attention_kernel module

class maxtext.kernels.attention.splash_attention_kernel.SegmentIds(q, kv)

Bases: NamedTuple

SegmentIds for Q and KV sequences.

SegmentIds are a mechanism to ensure that there is no cross-attention between segments (fractions of a sequence) that have been concatenated together into a sequence. Each array is a list of integer ids. Only tokens with the same id are allowed to attend to each other.

The static mask (e.g. causal) is “and-ed” with the segment id mask to form the actual attention mask. It is important that the resulting mask does not have any all-zero rows (along dimension kv). Otherwise it would result in an invalid softmax (the denominator would be 0). This condition holds for causal self-attention because in this case the segment ids form a block-diagonal matrix, so at least one element in each row is set. It is easy to break this condition with non-self-attention configurations.
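The masking rule above can be illustrated with a small NumPy sketch. It is not the kernel's implementation; the helper name segment_mask and the example ids are made up for illustration. It builds the segment id mask, combines it with a causal mask, and shows how a cross-attention configuration can produce an all-zero row.

```python
import numpy as np

def segment_mask(q_ids, kv_ids):
    """True where the query token and the kv token carry the same segment id."""
    return q_ids[:, None] == kv_ids[None, :]

# Self-attention over two packed sequences of lengths 3 and 2.
ids = np.array([0, 0, 0, 1, 1])
causal = np.tril(np.ones((5, 5), dtype=bool))
self_attn_mask = causal & segment_mask(ids, ids)

# Each row keeps at least its diagonal entry, so the softmax stays well defined.
rows_ok = self_attn_mask.any(axis=1)

# Cross-attention: query segment 2 has no matching kv token, so its row is empty.
q_ids = np.array([0, 1, 2])
kv_ids = np.array([0, 0, 1, 1])
cross_mask = segment_mask(q_ids, kv_ids)
empty_rows = ~cross_mask.any(axis=1)
```

Here rows_ok is True for every query position, while empty_rows flags the last query of the cross-attention case, which is exactly the situation the paragraph above warns about.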

Parameters:
  • q (Array)

  • kv (Array)

q: Array

Alias for field number 0

kv: Array

Alias for field number 1

maxtext.kernels.attention.splash_attention_kernel.get_kernel_name(block_metadata, is_mqa, save_residuals, is_segmented, phase)

Returns a unique name for each SplashAttention kernel variant.

Parameters:
  • block_metadata (Mapping[str, Any])

  • is_mqa (bool)

  • save_residuals (bool)

  • is_segmented (bool)

  • phase (str)

Return type:

str

maxtext.kernels.attention.splash_attention_kernel.attention_reference(mask, q, k, v, segment_ids, *, mask_value=-2.381976426469702e+38, save_residuals=False, custom_type='flash', attn_logits_soft_cap=None)

Reference attention implementation.

Parameters:
  • mask (Array)

  • q (Array)

  • k (Array)

  • v (Array)

  • segment_ids (SegmentIds | None)

  • mask_value (float)

  • save_residuals (bool)

  • custom_type (str)

  • attn_logits_soft_cap (float | None)

Return type:

Array | tuple[Array, tuple[Array]]
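As a rough guide to what a reference implementation of this shape computes, here is a single-head NumPy sketch. It is an illustration under assumptions, not the module's code: the function name reference_attention is hypothetical, the soft cap is assumed to take the common cap * tanh(logits / cap) form, and the residual is assumed to be the per-row log-sum-exp. Segment ids and custom_type are left out.

```python
import numpy as np

# Default taken from the signature above (roughly -0.7 times the float32 maximum).
DEFAULT_MASK_VALUE = -2.381976426469702e+38

def reference_attention(mask, q, k, v, *, mask_value=DEFAULT_MASK_VALUE,
                        attn_logits_soft_cap=None, save_residuals=False):
    """Masked attention for one head; q: [q_len, d], k: [kv_len, d], v: [kv_len, d_v]."""
    logits = q.astype(np.float32) @ k.astype(np.float32).T
    if attn_logits_soft_cap is not None:
        # Assumed soft-cap form: squash logits smoothly into (-cap, cap).
        logits = attn_logits_soft_cap * np.tanh(logits / attn_logits_soft_cap)
    # Masked positions get a large negative finite value instead of -inf.
    logits = np.where(mask, logits, mask_value)
    m = logits.max(axis=-1, keepdims=True)
    unnormalized = np.exp(logits - m)
    l = unnormalized.sum(axis=-1, keepdims=True)
    out = (unnormalized / l) @ v.astype(np.float32)
    if save_residuals:
        logsumexp = (m + np.log(l))[:, 0]
        return out, (logsumexp,)
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(4, 8)) for _ in range(3))
causal = np.tril(np.ones((4, 4), dtype=bool))
out, (lse,) = reference_attention(causal, q, k, v, save_residuals=True)
out_capped = reference_attention(causal, q, k, v, attn_logits_soft_cap=1.0)
```

With a causal mask the first query attends only to the first key, so out[0] equals v[0] up to float32 rounding. The two return shapes mirror the documented return type: a bare array, or the array plus a one-element tuple of residuals.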

maxtext.kernels.attention.splash_attention_kernel.attention_reference_custom(mask, q, k, v, segment_ids, *, mask_value=-2.381976426469702e+38, save_residuals=False, custom_type='flash', attn_logits_soft_cap=None)

Custom reference attention implementation.

Parameters:
  • mask (Array)

  • q (Array)

  • k (Array)

  • v (Array)

  • segment_ids (SegmentIds | None)

  • mask_value (float)

  • save_residuals (bool)

  • custom_type (str)

  • attn_logits_soft_cap (float | None)

maxtext.kernels.attention.splash_attention_kernel.make_attention_reference(mask, is_mqa, backward_impl='vanilla', **params)

Returns a function that computes reference attention.

Parameters:
  • mask (Mask | ndarray)

  • is_mqa (bool)

  • backward_impl (str)

  • params (Any)

Return type:

Callable

maxtext.kernels.attention.splash_attention_kernel.make_masked_mha_reference(mask, *, is_mqa=False, backward_impl='vanilla', **params)

Returns a function that computes reference multi-head attention. Takes the same arguments as make_attention_reference, with is_mqa defaulting to False.

Parameters:
  • mask (mask_lib.Mask | np.ndarray)

  • is_mqa (bool)

  • backward_impl (str)

  • params (Any)

Return type:

Callable

maxtext.kernels.attention.splash_attention_kernel.make_masked_mqa_reference(mask, *, is_mqa=True, backward_impl='vanilla', **params)

Returns a function that computes reference multi-query attention. Takes the same arguments as make_attention_reference, with is_mqa defaulting to True.

Parameters:
  • mask (mask_lib.Mask | np.ndarray)

  • is_mqa (bool)

  • backward_impl (str)

  • params (Any)

Return type:

Callable
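The difference between the two reference variants can be sketched in NumPy. This is an illustration only, assuming is_mqa follows the usual multi-query convention in which all query heads share a single K and V; the functions mha and mqa below are hypothetical stand-ins, not the callables these factories return.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def mha(mask, q, k, v):
    """Multi-head: q, k, v are [heads, seq, d]; mask is [q_len, kv_len]."""
    logits = np.where(mask, np.einsum("hqd,hkd->hqk", q, k), -1e30)
    return np.einsum("hqk,hkd->hqd", softmax(logits), v)

def mqa(mask, q, k, v):
    """Multi-query: q is [heads, q_len, d]; k and v are [kv_len, d], shared by all heads."""
    logits = np.where(mask, np.einsum("hqd,kd->hqk", q, k), -1e30)
    return np.einsum("hqk,kd->hqd", softmax(logits), v)

rng = np.random.default_rng(0)
heads, seq, d = 2, 4, 8
q = rng.normal(size=(heads, seq, d))
k = rng.normal(size=(seq, d))
v = rng.normal(size=(seq, d))
mask = np.tril(np.ones((seq, seq), dtype=bool))

out_mqa = mqa(mask, q, k, v)
# MQA matches MHA when the shared K and V are copied to every head.
out_mha = mha(mask, q,
              np.broadcast_to(k, (heads, seq, d)),
              np.broadcast_to(v, (heads, seq, d)))
```

Under this assumption the multi-query result is the multi-head result with K and V repeated across heads, which is why the two factories can share one underlying implementation.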

class maxtext.kernels.attention.splash_attention_kernel.QKVLayout(*values)

Bases: IntEnum

HEAD_DIM_MINOR = 1

SEQ_MINOR = 2

maxtext.kernels.attention.splash_attention_kernel.from_head_minor(vals, layout)

Parameters:
  • vals

  • layout

class maxtext.kernels.attention.splash_attention_kernel.BlockSizes(block_q, block_kv, block_kv_compute=None, block_q_dkv=None, block_kv_dkv=None, block_kv_dkv_compute=None, block_q_dq=None, block_kv_dq=None, use_fused_bwd_kernel=False, q_layout=QKVLayout.HEAD_DIM_MINOR, k_layout=QKVLayout.HEAD_DIM_MINOR, v_layout=QKVLayout.HEAD_DIM_MINOR)

Bases: object

Tile sizes parameterizing SplashAttention kernels.

These parameters have a negligible effect on numerics but affect performance greatly.

Note that changing the layouts only influences the physical layout that the kernel will enforce. The logical interface to splash attention always takes the head dimension as the minormost one.
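The distinction between logical interface and physical layout can be sketched as follows. This is a NumPy illustration under assumptions, not the kernel's code: the enum is a local stand-in that mirrors the documented values, and to_physical and to_logical are hypothetical helpers.

```python
import enum
import numpy as np

class QKVLayout(enum.IntEnum):
    """Local stand-in that mirrors the documented enum values."""
    HEAD_DIM_MINOR = 1
    SEQ_MINOR = 2

def to_physical(x, layout):
    """Hypothetical helper: store a logical [seq_len, head_dim] array in `layout`."""
    return x if layout == QKVLayout.HEAD_DIM_MINOR else np.swapaxes(x, -1, -2)

def to_logical(x_phys, layout):
    """Inverse of to_physical: recover the head-dim-minor view."""
    return x_phys if layout == QKVLayout.HEAD_DIM_MINOR else np.swapaxes(x_phys, -1, -2)

rng = np.random.default_rng(0)
q = rng.normal(size=(16, 4))  # logical [seq_len, head_dim]
k = rng.normal(size=(16, 4))

k_phys = to_physical(k, QKVLayout.SEQ_MINOR)  # stored as [head_dim, seq_len]
logits_seq_minor = q @ k_phys                 # no transpose needed in this layout
logits_head_minor = q @ to_logical(k_phys, QKVLayout.SEQ_MINOR).T
```

Both products give the same logits. The layout only changes how the operand is stored, which is consistent with the note that callers always pass arrays with the head dimension last.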

Parameters:
  • block_q (int)

  • block_kv (int)

  • block_kv_compute (int | None)

  • block_q_dkv (int | None)

  • block_kv_dkv (int | None)

  • block_kv_dkv_compute (int | None)

  • block_q_dq (int | None)

  • block_kv_dq (int | None)

  • use_fused_bwd_kernel (bool)

  • q_layout (QKVLayout)

  • k_layout (QKVLayout)

  • v_layout (QKVLayout)

block_q: int

block_kv: int

block_kv_compute: int | None

block_q_dkv: int | None

block_kv_dkv: int | None

block_kv_dkv_compute: int | None

block_q_dq: int | None

block_kv_dq: int | None

use_fused_bwd_kernel: bool

q_layout: QKVLayout

k_layout: QKVLayout

v_layout: QKVLayout

property has_backward_blocks: bool

classmethod get_default()

maxtext.kernels.attention.splash_attention_kernel.flash_attention_kernel(data_next_ref, block_mask_ref, mask_next_ref, q_ref, k_ref, v_ref, q_segment_ids_ref, kv_segment_ids_ref, mask_ref, q_sequence_ref, m_scratch_ref, l_scratch_ref, o_scratch_ref, o_ref, logsumexp_ref=None, *, mask_value, grid_width, bq, bkv, bkv_compute, head_dim_v, q_layout, k_layout, v_layout, attn_logits_soft_cap, mask_function)

Flash attention kernel.

Parameters:
  • mask_value (float)

  • grid_width (int)

  • bq (int)

  • bkv (int)

  • bkv_compute (int)

  • head_dim_v (int)

  • q_layout (QKVLayout)

  • k_layout (QKVLayout)

  • v_layout (QKVLayout)

  • attn_logits_soft_cap (float | None)

  • mask_function (Callable[[...], Array] | None)
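For orientation, the general flash attention technique (online softmax over kv tiles) can be sketched in NumPy. This is not the Pallas kernel: tiled_attention is a hypothetical function, and mapping its m, l and o accumulators to m_scratch_ref, l_scratch_ref and o_scratch_ref is an assumption based on the argument names. The sketch also assumes bkv divides the kv length.

```python
import numpy as np

MASK_VALUE = -2.381976426469702e+38  # default mask value used elsewhere in this module

def tiled_attention(q, k, v, mask, bkv):
    """Single-head attention computed one kv tile of width bkv at a time."""
    q_len, kv_len = q.shape[0], k.shape[0]
    assert kv_len % bkv == 0, "this sketch assumes bkv divides the kv length"
    m = np.full((q_len, 1), -np.inf)   # running row maximum
    l = np.zeros((q_len, 1))           # running softmax denominator
    o = np.zeros((q_len, v.shape[1]))  # running unnormalized output
    for start in range(0, kv_len, bkv):
        tile = slice(start, start + bkv)
        s = np.where(mask[:, tile], q @ k[tile].T, MASK_VALUE)
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)      # rescales what was accumulated so far
        p = np.exp(s - m_new)
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        o = alpha * o + p @ v[tile]
        m = m_new
    return o / l

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(8, 4)) for _ in range(3))
causal = np.tril(np.ones((8, 8), dtype=bool))

out_bkv2 = tiled_attention(q, k, v, causal, bkv=2)
out_bkv4 = tiled_attention(q, k, v, causal, bkv=4)
out_full = tiled_attention(q, k, v, causal, bkv=8)  # one tile, i.e. plain softmax
```

The three outputs agree to floating-point tolerance, which illustrates why tile sizes such as bkv are a performance knob rather than a numerical one.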

class maxtext.kernels.attention.splash_attention_kernel.SplashAttentionKernel(fwd_mask_info, dq_mask_info, dkv_mask_info, **kwargs)

Bases: object

Defines a SplashAttention kernel object.

Parameters:
  • fwd_mask_info (mask_info_lib.MaskInfo)

  • dq_mask_info (mask_info_lib.MaskInfo | None)

  • dkv_mask_info (mask_info_lib.MaskInfo | None)

manual_fwd(*args, **kwargs)

Return type:

Array | tuple[Array, tuple[Array]]

manual_bwd(*args, **kwargs)

manual_sharding_spec(sharding)

Returns a value that can be used as a shard_map partition spec for the kernel.

Parameters:
  • sharding (NamedSharding)

tree_flatten()

classmethod tree_unflatten(kwargs, values)

maxtext.kernels.attention.splash_attention_kernel.make_splash_mha(mask, *, block_sizes=None, is_mqa=False, save_residuals=False, mask_value=-2.381976426469702e+38, attn_logits_soft_cap=None, downcast_smem_data=True, head_shards, q_seq_shards, residual_checkpoint_name=None, interpret=False)

Creates a SplashAttentionKernel for multi-head attention (is_mqa defaults to False). head_shards and q_seq_shards are required keyword arguments.

Parameters:
  • mask (np.ndarray | jax.Array | mask_lib.MultiHeadMask)

  • block_sizes (BlockSizes | None)

  • is_mqa (bool)

  • save_residuals (bool)

  • mask_value (float)

  • attn_logits_soft_cap (float | None)

  • downcast_smem_data (bool)

  • head_shards (int)

  • q_seq_shards (int)

  • residual_checkpoint_name (str | None)

  • interpret (bool)

maxtext.kernels.attention.splash_attention_kernel.make_splash_mqa(mask, *, block_sizes=None, is_mqa=True, save_residuals=False, mask_value=-2.381976426469702e+38, attn_logits_soft_cap=None, downcast_smem_data=True, head_shards, q_seq_shards, residual_checkpoint_name=None, interpret=False)

Creates a SplashAttentionKernel for multi-query attention (is_mqa defaults to True). head_shards and q_seq_shards are required keyword arguments.

Parameters:
  • mask (np.ndarray | jax.Array | mask_lib.MultiHeadMask)

  • block_sizes (BlockSizes | None)

  • is_mqa (bool)

  • save_residuals (bool)

  • mask_value (float)

  • attn_logits_soft_cap (float | None)

  • downcast_smem_data (bool)

  • head_shards (int)

  • q_seq_shards (int)

  • residual_checkpoint_name (str | None)

  • interpret (bool)

maxtext.kernels.attention.splash_attention_kernel.make_splash_mha_single_device(mask, *, block_sizes=None, is_mqa=False, save_residuals=False, mask_value=-2.381976426469702e+38, attn_logits_soft_cap=None, downcast_smem_data=True, head_shards=1, q_seq_shards=1, residual_checkpoint_name=None, interpret=False)

Creates a SplashAttentionKernel for multi-head attention on a single device (is_mqa defaults to False; head_shards and q_seq_shards default to 1).

Parameters:
  • mask (np.ndarray | jax.Array | mask_lib.MultiHeadMask)

  • block_sizes (BlockSizes | None)

  • is_mqa (bool)

  • save_residuals (bool)

  • mask_value (float)

  • attn_logits_soft_cap (float | None)

  • downcast_smem_data (bool)

  • head_shards (int)

  • q_seq_shards (int)

  • residual_checkpoint_name (str | None)

  • interpret (bool)

maxtext.kernels.attention.splash_attention_kernel.make_splash_mqa_single_device(mask, *, block_sizes=None, is_mqa=True, save_residuals=False, mask_value=-2.381976426469702e+38, attn_logits_soft_cap=None, downcast_smem_data=True, head_shards=1, q_seq_shards=1, residual_checkpoint_name=None, interpret=False)

Creates a SplashAttentionKernel for multi-query attention on a single device (is_mqa defaults to True; head_shards and q_seq_shards default to 1).

Parameters:
  • mask (np.ndarray | jax.Array | mask_lib.MultiHeadMask)

  • block_sizes (BlockSizes | None)

  • is_mqa (bool)

  • save_residuals (bool)

  • mask_value (float)

  • attn_logits_soft_cap (float | None)

  • downcast_smem_data (bool)

  • head_shards (int)

  • q_seq_shards (int)

  • residual_checkpoint_name (str | None)

  • interpret (bool)