maxtext.kernels.attention.splash_attention_kernel module
- class maxtext.kernels.attention.splash_attention_kernel.SegmentIds(q, kv)
Bases: NamedTuple
SegmentIds for Q and KV sequences.
SegmentIds are a mechanism to ensure that there is no cross-attention between segments (fractions of a sequence) that have been concatenated together into a sequence. Each array is a list of ids (integers). Only tokens with the same id are allowed to attend to each other.
The static mask (e.g. causal) is “and-ed” with the segment id mask to form the actual attention mask. The actual attention mask must not have any all-zero rows (along the kv dimension); otherwise the softmax would be invalid because its denominator would be 0. This condition holds for causal self-attention because in that case the segment ids form a block-diagonal matrix, so at least one element in each row is set. It is easy to break this condition with non-self-attention configurations.
- Parameters:
q (Array)
kv (Array)
- q: Array
Alias for field number 0
- kv: Array
Alias for field number 1
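The segment id rule described above can be illustrated with a minimal NumPy sketch. This is not the kernel's own code: the helper names `segment_mask` and `has_empty_rows` are hypothetical, and the example only shows how a segment id mask combines with a causal mask and how a cross-attention setup can produce an all-zero row.

```python
import numpy as np

def segment_mask(q_ids, kv_ids):
    """True where the query token and the kv token share a segment id."""
    q_ids = np.asarray(q_ids)
    kv_ids = np.asarray(kv_ids)
    return q_ids[:, None] == kv_ids[None, :]

def has_empty_rows(mask):
    """True if some query row cannot attend to any kv position."""
    return bool((~mask.any(axis=1)).any())

# Two segments of lengths 3 and 2 packed into one sequence of length 5.
ids = [0, 0, 0, 1, 1]
causal = np.tril(np.ones((5, 5), dtype=bool))

# Self-attention: the segment mask is block diagonal, so AND-ing it with
# the causal mask still leaves the diagonal set in every row.
self_attn_mask = causal & segment_mask(ids, ids)

# Cross-attention where the kv side holds no token from segment 1:
# the query rows of segment 1 end up with nothing to attend to.
cross_mask = segment_mask(ids, [0, 0, 0, 0])
```

Here `has_empty_rows(self_attn_mask)` is false while `has_empty_rows(cross_mask)` is true, which is the situation the note above warns about.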
- maxtext.kernels.attention.splash_attention_kernel.get_kernel_name(block_metadata, is_mqa, save_residuals, is_segmented, phase)
Returns a unique name for all SplashAttention kernel variants.
- Parameters:
block_metadata (Mapping[str, Any])
is_mqa (bool)
save_residuals (bool)
is_segmented (bool)
phase (str)
- Return type:
str
- maxtext.kernels.attention.splash_attention_kernel.attention_reference(mask, q, k, v, segment_ids, *, mask_value=-2.381976426469702e+38, save_residuals=False, custom_type='flash', attn_logits_soft_cap=None)
Reference attention implementation.
- Parameters:
mask (Array)
q (Array)
k (Array)
v (Array)
segment_ids (SegmentIds | None)
mask_value (float)
save_residuals (bool)
custom_type (str)
attn_logits_soft_cap (float | None)
- Return type:
Array | tuple[Array, tuple[Array]]
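To make the roles of `mask_value` and `attn_logits_soft_cap` concrete, here is a minimal single-head NumPy sketch of masked attention. It is an illustration under stated assumptions, not the module's implementation: the function name `reference_attention` is hypothetical, the soft cap is assumed to be the common `cap * tanh(logits / cap)` form, the default `mask_value` is assumed to equal `-0.7 * max(float32)`, and the returned log-sum-exp is assumed to correspond to the residual kept when `save_residuals=True`.

```python
import numpy as np

# Assumption: the default shown in the signature equals -0.7 * max(float32).
MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)

def reference_attention(mask, q, k, v, mask_value=MASK_VALUE, soft_cap=None):
    """Masked attention for one head.

    q: [q_seq, head_dim], k: [kv_seq, head_dim], v: [kv_seq, head_dim_v].
    mask: boolean [q_seq, kv_seq], True where attention is allowed.
    """
    logits = q.astype(np.float32) @ k.astype(np.float32).T
    if soft_cap is not None:
        # Assumed soft-cap form: squash logits into (-soft_cap, soft_cap).
        logits = soft_cap * np.tanh(logits / soft_cap)
    # Masked positions get a large negative value instead of -inf.
    logits = np.where(mask, logits, mask_value)
    m = logits.max(axis=-1, keepdims=True)
    p = np.exp(logits - m)
    l = p.sum(axis=-1, keepdims=True)
    out = (p / l) @ v.astype(np.float32)
    logsumexp = (m + np.log(l))[:, 0]
    return out, logsumexp

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8)).astype(np.float32)
k = rng.normal(size=(4, 8)).astype(np.float32)
v = rng.normal(size=(4, 8)).astype(np.float32)
causal = np.tril(np.ones((4, 4), dtype=bool))
out, lse = reference_attention(causal, q, k, v, soft_cap=30.0)
```

With a causal mask the first query row can only attend to the first kv position, so `out[0]` matches `v[0]`.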
- maxtext.kernels.attention.splash_attention_kernel.attention_reference_custom(mask, q, k, v, segment_ids, *, mask_value=-2.381976426469702e+38, save_residuals=False, custom_type='flash', attn_logits_soft_cap=None)
Reference attention custom implementation.
- Parameters:
mask (Array)
q (Array)
k (Array)
v (Array)
segment_ids (SegmentIds | None)
mask_value (float)
save_residuals (bool)
custom_type (str)
attn_logits_soft_cap (float | None)
- maxtext.kernels.attention.splash_attention_kernel.make_attention_reference(mask, is_mqa, backward_impl='vanilla', **params)
Returns a function that computes reference attention.
- Parameters:
mask (Mask | ndarray)
is_mqa (bool)
backward_impl (str)
params (Any)
- Return type:
Callable
- maxtext.kernels.attention.splash_attention_kernel.make_masked_mha_reference(mask, *, is_mqa=False, backward_impl='vanilla', **params)
Returns a function that computes reference attention.
- Parameters:
mask (mask_lib.Mask | np.ndarray)
is_mqa (bool)
backward_impl (str)
params (Any)
- Return type:
Callable
- maxtext.kernels.attention.splash_attention_kernel.make_masked_mqa_reference(mask, *, is_mqa=True, backward_impl='vanilla', **params)
Returns a function that computes reference attention.
- Parameters:
mask (mask_lib.Mask | np.ndarray)
is_mqa (bool)
backward_impl (str)
params (Any)
- Return type:
Callable
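The MHA and MQA reference factories differ only in `is_mqa`. As a hedged sketch of what that distinction usually means, the NumPy code below assumes that in multi-head attention `k` and `v` carry a leading head axis, while in multi-query attention a single K/V head is shared by all query heads. The shapes and the helper names `mha` and `mqa` are assumptions for illustration, not taken from this module.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def mha(mask, q, k, v):
    """q, k, v: [heads, seq, head_dim]; each query head has its own K/V head."""
    logits = np.einsum("hqd,hkd->hqk", q, k)
    logits = np.where(mask, logits, -1e30)
    return np.einsum("hqk,hkd->hqd", softmax(logits), v)

def mqa(mask, q, k, v):
    """q: [heads, seq, head_dim]; k, v: [seq, head_dim], shared by all heads."""
    logits = np.einsum("hqd,kd->hqk", q, k)
    logits = np.where(mask, logits, -1e30)
    return np.einsum("hqk,kd->hqd", softmax(logits), v)

rng = np.random.default_rng(0)
heads, seq, dim = 2, 4, 8
q = rng.normal(size=(heads, seq, dim))
k = rng.normal(size=(seq, dim))
v = rng.normal(size=(seq, dim))
mask = np.tril(np.ones((seq, seq), dtype=bool))

out_mqa = mqa(mask, q, k, v)
# MHA with the shared K/V head replicated per query head gives the same result.
k_rep = np.broadcast_to(k, (heads, seq, dim))
v_rep = np.broadcast_to(v, (heads, seq, dim))
out_mha = mha(mask, q, k_rep, v_rep)
```

The two outputs agree, which shows that MQA is the special case of MHA in which every head reads the same keys and values.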
- class maxtext.kernels.attention.splash_attention_kernel.QKVLayout(*values)
Bases: IntEnum
- HEAD_DIM_MINOR = 1
- SEQ_MINOR = 2
- maxtext.kernels.attention.splash_attention_kernel.from_head_minor(vals, layout)
- Parameters:
vals (tuple[Any, ...])
layout (QKVLayout)
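The entry above carries no description, so the following is only a plausible sketch of how a layout enum and a helper such as `from_head_minor` could interact. The assumed behavior is that `vals` (for example a shape tuple) is given with the head dimension last, and that `SEQ_MINOR` swaps the last two entries; the real function may differ.

```python
import enum
from typing import Any

class QKVLayout(enum.IntEnum):
    HEAD_DIM_MINOR = 1  # [..., seq_len, head_dim]
    SEQ_MINOR = 2       # [..., head_dim, seq_len]

def from_head_minor(vals: tuple[Any, ...], layout: QKVLayout) -> tuple[Any, ...]:
    """Reorders a head-dim-minor tuple into the requested layout.

    Assumed behavior: HEAD_DIM_MINOR is the identity, SEQ_MINOR swaps the
    last two entries.
    """
    if layout == QKVLayout.HEAD_DIM_MINOR:
        return vals
    return (*vals[:-2], vals[-1], vals[-2])

shape = (8, 1024, 128)  # (heads, seq_len, head_dim)
head_minor = from_head_minor(shape, QKVLayout.HEAD_DIM_MINOR)
seq_minor = from_head_minor(shape, QKVLayout.SEQ_MINOR)
```

Under these assumptions `head_minor` is `(8, 1024, 128)` and `seq_minor` is `(8, 128, 1024)`.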
- class maxtext.kernels.attention.splash_attention_kernel.BlockSizes(block_q, block_kv, block_kv_compute=None, block_q_dkv=None, block_kv_dkv=None, block_kv_dkv_compute=None, block_q_dq=None, block_kv_dq=None, use_fused_bwd_kernel=False, q_layout=QKVLayout.HEAD_DIM_MINOR, k_layout=QKVLayout.HEAD_DIM_MINOR, v_layout=QKVLayout.HEAD_DIM_MINOR)
Bases: object
Tile sizes parameterizing SplashAttention kernels.
These parameters have a negligible effect on numerics but affect performance greatly.
Note that changing the layouts only influences the physical layout that the kernel will enforce. The logical interface to splash attention always takes the head dimension as the minormost one.
- Parameters:
- block_q: int
- block_kv: int
- block_kv_compute: int | None
- block_q_dkv: int | None
- block_kv_dkv: int | None
- block_kv_dkv_compute: int | None
- block_q_dq: int | None
- block_kv_dq: int | None
- use_fused_bwd_kernel: bool
- property has_backward_blocks: bool
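One way to see why tile sizes matter for performance but not for numerics is to look at how a mask decomposes into `block_q` x `block_kv` tiles. The NumPy sketch below is an illustration under assumptions: the helper `block_mask` and its 0/1/2 encoding (empty, partial, full) are hypothetical and are not taken from this module. The idea is that empty tiles can be skipped entirely and full tiles need no per-element masking.

```python
import numpy as np

def block_mask(mask, block_q, block_kv):
    """Classifies each (block_q x block_kv) tile: 0 = empty, 1 = partial, 2 = full."""
    q_len, kv_len = mask.shape
    assert q_len % block_q == 0 and kv_len % block_kv == 0
    tiles = mask.reshape(q_len // block_q, block_q, kv_len // block_kv, block_kv)
    any_set = tiles.any(axis=(1, 3))
    all_set = tiles.all(axis=(1, 3))
    return any_set.astype(np.int32) + all_set.astype(np.int32)

causal = np.tril(np.ones((8, 8), dtype=bool))

# Coarse tiles: 1 of the 4 tiles is empty and can be skipped.
coarse = block_mask(causal, block_q=4, block_kv=4)

# Finer tiles: 6 of the 16 tiles are empty, so a larger share of work is skipped,
# at the cost of more, smaller tiles to iterate over.
fine = block_mask(causal, block_q=2, block_kv=2)
```

Both tilings describe the same mask, so the attention result is unchanged; only the amount of skipped and unmasked work differs.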
- maxtext.kernels.attention.splash_attention_kernel.flash_attention_kernel(data_next_ref, block_mask_ref, mask_next_ref, q_ref, k_ref, v_ref, q_segment_ids_ref, kv_segment_ids_ref, mask_ref, q_sequence_ref, m_scratch_ref, l_scratch_ref, o_scratch_ref, o_ref, logsumexp_ref=None, *, mask_value, grid_width, bq, bkv, bkv_compute, head_dim_v, q_layout, k_layout, v_layout, attn_logits_soft_cap, mask_function)
Flash attention kernel.
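The argument names `m_scratch_ref`, `l_scratch_ref` and `o_scratch_ref` suggest the usual flash attention accumulators: a running row maximum, a running softmax denominator, and a running unnormalized output. The NumPy sketch below shows that online-softmax recurrence for a single block of queries. It is a simplified illustration, not the kernel itself: the function name `blockwise_attention` is hypothetical and the mapping of `m`, `l`, `o` to the scratch buffers is an assumption.

```python
import numpy as np

def blockwise_attention(q, k, v, mask, bkv, mask_value=-1e30):
    """Attention for one block of queries, visiting kv in chunks of size bkv."""
    q_len, kv_len = mask.shape
    m = np.full((q_len, 1), -np.inf)    # running row maximum
    l = np.zeros((q_len, 1))            # running softmax denominator
    o = np.zeros((q_len, v.shape[-1]))  # running unnormalized output
    for start in range(0, kv_len, bkv):
        sl = slice(start, start + bkv)
        s = np.where(mask[:, sl], q @ k[sl].T, mask_value)
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)       # rescales the earlier partial sums
        p = np.exp(s - m_new)
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        o = alpha * o + p @ v[sl]
        m = m_new
    # Correct only if every row has at least one unmasked kv position.
    return o / l, (m + np.log(l))[:, 0]

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(6, 8))
v = rng.normal(size=(6, 8))
mask = np.tril(np.ones((4, 6), dtype=bool))
out, lse = blockwise_attention(q, k, v, mask, bkv=2)

# Dense softmax over the full kv axis, for comparison.
logits = np.where(mask, q @ k.T, -1e30)
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
dense = (probs / probs.sum(axis=-1, keepdims=True)) @ v
```

The blockwise result matches the dense one, and it never materializes the full `[q_seq, kv_seq]` probability matrix.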
- class maxtext.kernels.attention.splash_attention_kernel.SplashAttentionKernel(fwd_mask_info, dq_mask_info, dkv_mask_info, **kwargs)
Bases: object
Defines a SplashAttention kernel object.
- Parameters:
fwd_mask_info (mask_info_lib.MaskInfo)
dq_mask_info (mask_info_lib.MaskInfo | None)
dkv_mask_info (mask_info_lib.MaskInfo | None)
- maxtext.kernels.attention.splash_attention_kernel.make_splash_mha(mask, *, block_sizes=None, is_mqa=False, save_residuals=False, mask_value=-2.381976426469702e+38, attn_logits_soft_cap=None, downcast_smem_data=True, head_shards, q_seq_shards, residual_checkpoint_name=None, interpret=False)
Creates a SplashAttentionKernel.
- Parameters:
mask (np.ndarray | jax.Array | mask_lib.MultiHeadMask)
block_sizes (BlockSizes | None)
is_mqa (bool)
save_residuals (bool)
mask_value (float)
attn_logits_soft_cap (float | None)
downcast_smem_data (bool)
head_shards (int)
q_seq_shards (int)
residual_checkpoint_name (str | None)
interpret (bool)
- maxtext.kernels.attention.splash_attention_kernel.make_splash_mqa(mask, *, block_sizes=None, is_mqa=True, save_residuals=False, mask_value=-2.381976426469702e+38, attn_logits_soft_cap=None, downcast_smem_data=True, head_shards, q_seq_shards, residual_checkpoint_name=None, interpret=False)
Creates a SplashAttentionKernel.
- Parameters:
mask (np.ndarray | jax.Array | mask_lib.MultiHeadMask)
block_sizes (BlockSizes | None)
is_mqa (bool)
save_residuals (bool)
mask_value (float)
attn_logits_soft_cap (float | None)
downcast_smem_data (bool)
head_shards (int)
q_seq_shards (int)
residual_checkpoint_name (str | None)
interpret (bool)
- maxtext.kernels.attention.splash_attention_kernel.make_splash_mha_single_device(mask, *, block_sizes=None, is_mqa=False, save_residuals=False, mask_value=-2.381976426469702e+38, attn_logits_soft_cap=None, downcast_smem_data=True, head_shards=1, q_seq_shards=1, residual_checkpoint_name=None, interpret=False)
Creates a SplashAttentionKernel.
- Parameters:
mask (np.ndarray | jax.Array | mask_lib.MultiHeadMask)
block_sizes (BlockSizes | None)
is_mqa (bool)
save_residuals (bool)
mask_value (float)
attn_logits_soft_cap (float | None)
downcast_smem_data (bool)
head_shards (int)
q_seq_shards (int)
residual_checkpoint_name (str | None)
interpret (bool)
- maxtext.kernels.attention.splash_attention_kernel.make_splash_mqa_single_device(mask, *, block_sizes=None, is_mqa=True, save_residuals=False, mask_value=-2.381976426469702e+38, attn_logits_soft_cap=None, downcast_smem_data=True, head_shards=1, q_seq_shards=1, residual_checkpoint_name=None, interpret=False)
Creates a SplashAttentionKernel.
- Parameters:
mask (np.ndarray | jax.Array | mask_lib.MultiHeadMask)
block_sizes (BlockSizes | None)
is_mqa (bool)
save_residuals (bool)
mask_value (float)
attn_logits_soft_cap (float | None)
downcast_smem_data (bool)
head_shards (int)
q_seq_shards (int)
residual_checkpoint_name (str | None)
interpret (bool)