maxtext.layers.attention_op module#

Attention Ops Layers.

maxtext.layers.attention_op.validate_compute_axis_order(s)[source]#
Parameters:

s (tuple[int, ...])

Return type:

None

maxtext.layers.attention_op.apply_mask_to_logits(logits, mask)[source]#

Applies a floating-point mask to a set of logits.

The mask is represented as a tensor with some dtype where 0 represents true and values below a large negative number (here set to get_large_negative_number(logits.dtype) / 2) represent false. Applying the mask leaves the logits alone in the true case and replaces them by get_large_negative_number(logits.dtype) in the false case. Previously, this was done by adding the logits to the mask; however, this leads to a bad fusion decision in the compiler that saves the values in memory rather than just the predicate. This implementation avoids that problem.

from google/praxis

Parameters:
  • logits (Array) – A JTensor of logit values.

  • mask (Array) – A JTensor of mask values with the encoding described in the function documentation.

Returns:

Masked logits.
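The select-based masking described above can be sketched as follows. This is a minimal illustration, not the library's implementation: `large_negative_number` is a hypothetical stand-in for get_large_negative_number, and its 0.7 scaling factor is an assumption.

```python
import jax.numpy as jnp


def large_negative_number(dtype):
  """Hypothetical stand-in for get_large_negative_number (scaling is assumed)."""
  return jnp.asarray(-0.7 * float(jnp.finfo(dtype).max), dtype=dtype)


def apply_mask_sketch(logits, mask):
  """Keeps logits where the mask is 'true' (0); writes a large negative elsewhere."""
  min_value = large_negative_number(logits.dtype)
  # A select only needs the boolean predicate, unlike `logits + mask`,
  # which needs the floating-point mask values themselves.
  return jnp.where(mask >= min_value * 0.5, logits, min_value)


logits = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
keep = jnp.array([True, False, True])
mask = jnp.where(keep, 0.0, large_negative_number(jnp.float32))
masked = apply_mask_sketch(logits, mask)
```

Entries whose mask value is 0 pass through unchanged, while the masked entry is replaced by the large negative value rather than summed with it.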

maxtext.layers.attention_op.validate_gpu_flash_attention(sinks, record_max_logits)[source]#

Helper function to check for unsupported features with flash attention on GPU.

Parameters:
  • sinks (Array | None)

  • record_max_logits (bool)

Return type:

None

class maxtext.layers.attention_op.ChunkedCausalMask(shape, chunk_size, shard_count=1)[source]#

Bases: _ComputableMask

Lazy chunked causal mask.

Attention is causal within each chunk: for chunks (0, K), (K, 2K), (2K, 3K), …, tokens attend to one another only within the same chunk, never across chunks. Llama4 models use interleaved chunk attention along with global attention.

This mask class inherits from splash_attention_mask._ComputableMask and is designed to be used with Splash Attention. It allows the mask logic to be computed on-the-fly or fused into the attention kernel, avoiding the memory cost of materializing the full (sequence_length, sequence_length) boolean mask array, which can be prohibitive for long sequences.

Parameters:
  • shape (tuple[int, int])

  • chunk_size (int)

  • shard_count (int)

chunk_size: int#

The size of each attention chunk.
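To make the chunked pattern concrete, the sketch below materializes a small dense version of the mask. It is only an illustration of the rule (same chunk and causal); the real class computes the mask lazily, and the function name here is hypothetical.

```python
import jax.numpy as jnp


def chunked_causal_mask_sketch(shape, chunk_size):
  """Materializes a small boolean mask; True means the query may attend."""
  q_len, kv_len = shape
  q_idx = jnp.arange(q_len)[:, None]
  kv_idx = jnp.arange(kv_len)[None, :]
  same_chunk = (q_idx // chunk_size) == (kv_idx // chunk_size)
  causal = q_idx >= kv_idx
  return same_chunk & causal


dense = chunked_causal_mask_sketch((4, 4), chunk_size=2)
```

With chunk_size=2, tokens 0 and 1 form one causal block and tokens 2 and 3 form another. Materializing the mask this way costs O(sequence_length²) memory, which is the cost the lazy mask avoids.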

maxtext.layers.attention_op.attention_op_as_linen(*, config, mesh, attention_kernel, max_target_length, num_query_heads, num_kv_heads, float32_qk_product=False, max_prefill_predict_length=-1, float32_logits=False, flash_axis_names_q=('activation_batch_attn', 'activation_heads', 'activation_length', 'activation_kv'), flash_axis_names_kv=('activation_batch_attn', 'activation_heads', 'activation_kv_length', 'activation_kv'), flash_axis_names_splash_kernel=('activation_heads', 'activation_length'), prefill_cache_logical_axis_names=('cache_batch_prefill', 'cache_sequence', 'cache_heads', 'cache_kv'), cache_logical_axis_names=('cache_batch', 'cache_sequence', 'cache_heads', 'cache_kv'), cache_scale_logical_axis_names=('cache_scale_batch', 'cache_scale_sequence', 'cache_scale_heads', 'cache_scale_kv'), ragged_qkv_axis_names=('cache_batch', 'cache_heads', 'cache_sequence', 'cache_kv'), ragged_lengths_names=('cache_batch', ), compute_axis_order=(0, 1, 2, 3), key_axis_order=(2, 0, 1, 3), reshape_q=False, dropout_rate=0.0, dtype=<class 'jax.numpy.float32'>, quant=None, kv_quant=None, attention_type=AttentionType.GLOBAL, attn_logits_soft_cap=None, sliding_window_size=None, chunk_attn_window_size=None, use_ragged_attention=False, ragged_block_size=256)[source]#

A factory function to create an AttentionOp as a Linen module.

This function serves as a bridge to use the NNX-based AttentionOp within a Linen model.

Parameters:
  • config (Any)

  • mesh (Mesh)

  • attention_kernel (str)

  • max_target_length (int)

  • num_query_heads (int)

  • num_kv_heads (int)

  • float32_qk_product (bool)

  • max_prefill_predict_length (int)

  • float32_logits (bool)

  • flash_axis_names_q (tuple[str, ...])

  • flash_axis_names_kv (tuple[str, ...])

  • flash_axis_names_splash_kernel (tuple[str, ...])

  • prefill_cache_logical_axis_names (tuple[str, ...])

  • cache_logical_axis_names (tuple[str, ...])

  • cache_scale_logical_axis_names (tuple[str, ...])

  • ragged_qkv_axis_names (tuple[str, ...])

  • ragged_lengths_names (tuple[str, ...])

  • compute_axis_order (tuple[int, ...])

  • key_axis_order (tuple[int, ...])

  • reshape_q (bool)

  • dropout_rate (float)

  • dtype (dtype)

  • quant (AqtQuantization | None)

  • kv_quant (KVQuant | None)

  • attention_type (AttentionType)

  • attn_logits_soft_cap (float | None)

  • sliding_window_size (int | None)

  • chunk_attn_window_size (int | None)

  • use_ragged_attention (bool)

  • ragged_block_size (int)

class maxtext.layers.attention_op.AttentionOp(*args, **kwargs)[source]#

Bases: Module

Attention operation

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

check_attention_inputs(query, key, value)[source]#

Check attention inputs.

Parameters:
  • query (Array)

  • key (Array | QTensor)

  • value (Array | QTensor)

Return type:

None

generate_attention_mask(query, key, decoder_segment_ids, model_mode, previous_chunk=None, bidirectional_mask=None)[source]#

Generates a combined attention mask for Transformer models.

This function constructs an attention mask by potentially combining several types of masks based on the input parameters and model configuration. The generated mask dictates which query-key pairs are allowed to attend to each other.

The masking logic can enforce:

  1. Sequence Separation: Using decoder_segment_ids, attention is confined within distinct sequences in a batch. This is crucial when multiple unrelated sequences are packed together.

  2. Causality: Preventing attention to future positions. This is standard for autoregressive decoding. For chunked prefill, as described in the SARATHI paper [2], causality is adjusted based on previous_chunk information.

  3. Specialized Attention Patterns: Depending on self.attention_type, it can apply:

    • Local Sliding Window Attention: Restricts attention to a fixed-size window around each query position.

    • Chunk Attention: Divides sequences into chunks and applies masking at the chunk level.

  4. Bidirectional Attention for Sub-sequences: If bidirectional_mask is provided (e.g., for image tokens in a multimodal model), those parts of the sequence can attend bidirectionally, and this mask is OR-ed with other generated masks.

The overall approach and specific masking techniques are influenced by efficient attention mechanisms like those found in the Pallas MHA Flash Attention reference [1].
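The way these masks combine can be sketched as below. This is a simplified illustration under stated assumptions, not the method's actual code: `NEG` stands in for DEFAULT_MASK_VALUE, the bidirectional mask is a plain outer product instead of a block-wise mask, the order of combination (OR with the causal mask, then AND with the segment mask) is assumed, and sliding-window, chunk, and previous_chunk handling are omitted.

```python
import jax.numpy as jnp

# Assumed stand-in for DEFAULT_MASK_VALUE (a large negative number).
NEG = -0.7 * float(jnp.finfo(jnp.float32).max)


def combined_mask_sketch(segment_ids, bidirectional=None):
  """segment_ids: [batch, seq]; bidirectional: optional boolean [batch, seq]."""
  seq = segment_ids.shape[1]
  # Sequence separation: only tokens with the same segment ID may interact.
  same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]
  # Causality: a query at position i sees keys at positions <= i.
  pos = jnp.arange(seq)
  allowed = (pos[:, None] >= pos[None, :])[None]
  if bidirectional is not None:
    # Simplification: every flagged token may see every other flagged token.
    allowed = allowed | (bidirectional[:, :, None] & bidirectional[:, None, :])
  allowed = allowed & same_segment
  # Add the two broadcast head axes: [batch, 1, 1, q_len, kv_len].
  return jnp.where(allowed[:, None, None, :, :], 0.0, NEG)


segment_ids = jnp.array([[1, 1, 2, 2]])
plain = combined_mask_sketch(segment_ids)
with_images = combined_mask_sketch(
    segment_ids, jnp.array([[True, True, False, False]]))
```

In `plain`, position 0 cannot see position 1 (future) and position 2 cannot see position 1 (different segment). In `with_images`, the first two tokens see each other in both directions.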

Parameters:
  • query – The query tensor, typically of shape [batch_size, q_sequence_length, num_heads, head_dim]. Used primarily for deriving sequence length.

  • key – The key tensor, typically of shape [batch_size, kv_sequence_length, num_heads, head_dim]. Used primarily for deriving sequence length.

  • decoder_segment_ids (Array | None) – Optional Array of shape [batch_size, q_sequence_length]. Identifies distinct sequences within the batch. Attention is restricted to elements within the same segment ID. In autoregressive mode, specific values (e.g., common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR) can mark the currently active sequence for decoding.

  • model_mode (str) – A string (e.g., common_types.MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL) indicating the operational mode. This significantly influences mask generation, particularly how causality and segment separation are handled.

  • previous_chunk (Any) – Optional. Information about previously processed key/value chunks, often a tensor representing the previous keys/values. Used to correctly offset causal masks in chunked attention or streaming scenarios. Its shape might be [batch_size, prev_kv_sequence_length, …].

  • bidirectional_mask (Any) – Optional Array of shape [batch_size, kv_sequence_length]. If provided, this boolean mask indicates tokens (e.g., image tokens) that are allowed to attend bidirectionally. The resulting block-wise bidirectional mask is combined with other masks using a logical OR.

Returns:

An Array representing the attention mask, with shape [batch_size, 1, 1, q_sequence_length, kv_sequence_length]. It is broadcastable to the shape [batch_size, num_kv_heads, group_size=n_q // n_kv, q_sequence_length, kv_sequence_length]. Positions with 0.0 allow attention, while positions with DEFAULT_MASK_VALUE (a large negative number) prevent it. Returns None if no masking is determined to be necessary based on the inputs and configuration.

Return type:

Array | None

References

[1] JAX Pallas MHA Flash Attention: jax-ml/jax

[2] SARATHI: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefills - ArXiv:2308.16369 (https://arxiv.org/abs/2308.16369)

calculate_moba_gate_logic(q_item, k_item, q_pos_item)[source]#

Computes the block-level MoBA gating intermediates for one batch item.

Parameters:
  • q_item – Query tensor shaped [q_len, n_q_heads, head_dim].

  • k_item – Key tensor shaped [kv_len, n_kv_heads, head_dim].

  • q_pos_item – Absolute query positions shaped [q_len], used to derive the chunk index for each query. For example, during prefill after 128 tokens have been processed q_pos_item is jnp.arange(128, 128 + q_len), while in autoregressive decode with a single query token it is jnp.array([kv_len - 1]).

Returns:

need_attend, a boolean mask of shape [n_kv_heads, g, q_len, num_block] indicating which key blocks each query should attend to. The additional values in the returned tuple are debug intermediates used for logging and diagnostics when inspecting the gating behaviour.
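The two q_pos_item cases from the parameter description can be written out as below. The block size and the integer-division step for deriving a chunk index are assumptions made for illustration; the name `block_size` is hypothetical.

```python
import jax.numpy as jnp

q_len, kv_len = 4, 132
block_size = 64  # hypothetical MoBA block size, chosen for illustration

# Prefill after 128 tokens have already been processed.
prefill_q_pos = jnp.arange(128, 128 + q_len)
# Autoregressive decode with a single query token.
decode_q_pos = jnp.array([kv_len - 1])

# Assumed derivation of the chunk index each query falls into.
prefill_block_idx = prefill_q_pos // block_size
decode_block_idx = decode_q_pos // block_size
```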

generate_moba_mask_single_item(q_item, k_item, q_positions)[source]#

Generates the token-level MoBA additive mask for a single batch item.

apply_attention(query, key, value, decoder_segment_ids, segment_positions, lengths, model_mode, use_ragged_attention=False, previous_chunk=None, bidirectional_mask=None, sinks=None, indexer_mask=None, record_max_logits=False, *, qk_product_einsum, wv_product_einsum)[source]#

Apply attention

Parameters:
  • query (Array)

  • key (Array | QTensor)

  • value (Array | QTensor)

  • decoder_segment_ids (Array | None)

  • segment_positions (Array | None)

  • lengths (Array | None)

  • model_mode (str)

  • use_ragged_attention (bool)

  • previous_chunk (Any)

  • bidirectional_mask (Any)

  • sinks (Array | None)

  • indexer_mask (Array | None)

  • record_max_logits (bool)

  • qk_product_einsum (Callable[[...], Array])

  • wv_product_einsum (Callable[[...], Array])

gpu_ragged_attention(q, k, v, lengths, block_size)[source]#

GPU ragged attention.

Parameters:
  • q (Array)

  • k (Array | QTensor)

  • v (Array | QTensor)

  • lengths (Array)

  • block_size (int)

tpu_ragged_attention(query, key, value, lengths, block_size)[source]#

Ragged Attention.

Parameters:
  • query (Array)

  • key (Array | QTensor)

  • value (Array | QTensor)

  • lengths (Array)

  • block_size (int)

Return type:

tuple[Array, Array, Array]

tpu_flash_attention(query, key, value, decoder_segment_ids, attn_logits_soft_cap=None, sinks=None, indexer_mask=None, record_max_logits=False)[source]#

TPU Flash Attention.

Parameters:
  • query (Array)

  • key (Array)

  • value (Array)

  • decoder_segment_ids (Array | None)

  • attn_logits_soft_cap (float | None)

  • sinks (Array | None)

  • indexer_mask (Array | None)

  • record_max_logits (bool)

Return type:

tuple[Array, Array]

cudnn_flash_attention(query, key, value, decoder_segment_ids, segment_positions, model_mode='train')[source]#

CUDNN Flash Attention with Transformer Engine.

  1. Stable API, supports MHA, GQA, SWA, Packing and Context Parallelism.

  2. Context Parallelism currently only supports causal masking.

  3. Only Ring attention has packing support with striped load balancing (context_parallel_strategy="ring" and context_parallel_load_balance=true).

  4. Breaks with TE 2.12 and 2.13 (known bug); works with TE stable release <=2.11 or >=2.14.

Parameters:
  • query (Array)

  • key (Array)

  • value (Array)

  • decoder_segment_ids (Array | None)

  • segment_positions (Array | None)

  • model_mode (str)

Return type:

Array

cudnn_jax_flash_attention(query, key, value, decoder_segment_ids, model_mode='train')[source]#

CUDNN Flash Attention with JAX SDPA API.

Parameters:
  • query (Array)

  • key (Array)

  • value (Array)

  • decoder_segment_ids (Array | None)

  • model_mode (str)

Return type:

tuple[Array, Array]

compute_local_attention(attn_weights, value, q_seq_len, model_mode, wv_product_einsum, sinks=None)[source]#

Computes the attention of a local subset of the kv cache.

Local attention results will need to be combined with any other local attentions and normalized. Based on google-research/google-research.

Parameters:
  • attn_weights (Array) – Product of query and key

  • value (Array) – Current value

  • aqt_rng (PRNGKey | None) – Optional rng

  • q_seq_len (int)

  • model_mode (str)

  • wv_product_einsum (Callable[[...], Array])

  • sinks (Array | None)

Returns:

A tuple (local_out, local_max, local_sum), where:

  • local_out is the local unnormalized output.

  • local_max is the local max of exponentials.

  • local_sum is the sum of exponentials for this chunk, divided by exp(local_max).

is_partition_in_decode(seq_len)[source]#

apply_attention_dot(query, key, value, decoder_segment_ids, model_mode='train', previous_chunk=None, bidirectional_mask=None, sinks=None, indexer_mask=None, record_max_logits=False, *, qk_product_einsum, wv_product_einsum)[source]#

Apply Attention.

Parameters:
  • query (Array)

  • key (Array | QTensor)

  • value (Array | QTensor)

  • decoder_segment_ids (Array | None)

  • model_mode (str)

  • previous_chunk (Any)

  • bidirectional_mask (Any)

  • sinks (Array | None)

  • indexer_mask (Array | None)

  • record_max_logits (bool)

  • qk_product_einsum (Callable[[...], Array])

  • wv_product_einsum (Callable[[...], Array])

qk_product(query, key, q_seq_len, model_mode, einsum)[source]#

Query-Key product.

Parameters:
  • query (Array) – Query projection, in shape of [b, t, n, d]

  • key (Array | QTensor) – Key projection in shape of [b, s, n_kv, d]

  • q_seq_len (int)

  • model_mode (str)

  • einsum (Callable[[...], Array])

Returns:

results in shape [b, n_kv, n // n_kv, t, s].

Return type:

Array

Annotations:

  • b: batch size

  • t: query length

  • s: key / value length

  • d: head / kv dimension

  • n: number of query heads

  • n_kv: number of kv heads, sometimes annotated as k

  • n // n_kv: number of query groups, sometimes annotated as g
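The shapes above can be checked with a plain einsum. This is a sketch under assumptions, not the method itself: the einsum string and the way query heads are grouped (kv head major, group minor) are illustrative choices, and the function name is hypothetical.

```python
import jax.numpy as jnp


def qk_product_sketch(query, key):
  """query: [b, t, n, d], key: [b, s, n_kv, d] -> [b, n_kv, n // n_kv, t, s]."""
  b, t, n, d = query.shape
  n_kv = key.shape[2]
  # Split the n query heads into n_kv groups of g = n // n_kv heads each.
  grouped_query = query.reshape(b, t, n_kv, n // n_kv, d)
  return jnp.einsum("btkgd,bskd->bkgts", grouped_query, key)


query = jnp.ones((2, 5, 8, 16))  # b=2, t=5, n=8, d=16
key = jnp.ones((2, 7, 2, 16))    # b=2, s=7, n_kv=2, d=16
logits = qk_product_sketch(query, key)
```

Each kv head is shared by g = 4 query heads here, so the key tensor is never repeated along the head axis.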

wv_product(attn_weights, value, model_mode, einsum)[source]#

Weighted value product.

Parameters:
  • attn_weights (Array) – Computed results of qk_einsum, in shape [b, n_kv, n // n_kv, t, s]

  • value (Array | QTensor) – Value projection, in shape of [b, s, n_kv, d]

  • model_mode (str)

  • einsum (Callable[[...], Array])

Returns:

result in shape [b, t, n, d]

Return type:

Array

Annotations:

  • b: batch size

  • t: query length

  • s: key / value length

  • d: head / kv dimension

  • n: number of query heads

  • n_kv: number of kv heads, sometimes annotated as k

  • n // n_kv: number of query groups, sometimes annotated as g
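A matching sketch for the weighted value product, again with an assumed einsum string, an assumed head ordering, and a hypothetical function name:

```python
import jax.numpy as jnp


def wv_product_sketch(attn_weights, value):
  """attn_weights: [b, n_kv, g, t, s], value: [b, s, n_kv, d] -> [b, t, n, d]."""
  out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value)
  b, t, n_kv, g, d = out.shape
  # Merge the (n_kv, g) axes back into the n query heads.
  return out.reshape(b, t, n_kv * g, d)


# Uniform weights over s=7 keys, so each output is the mean of the values.
attn_weights = jnp.full((2, 2, 4, 5, 7), 1.0 / 7.0)
value = jnp.ones((2, 7, 2, 16))
weighted = wv_product_sketch(attn_weights, value)
```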

reverse_transepose(transposed_array, transpose_axis_order)[source]#

normalize_cudnn_attention(local_outs, local_stats)[source]#

Normalize across two cuDNN attentions

Parameters:
  • local_outs (list) – List of output entries for each cuDNN attention in shape [b, t, n, d].

  • local_stats (list) – List of logsumexp entries for each cudnn attention in shape [b, n, t].

Returns:

Combined attention that has been normalized in shape [b, t, n, d].

Return type:

Array
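One standard way to merge attention outputs using their logsumexp statistics is sketched below. It assumes each local output is already softmax-normalized over its own keys. Whether the method uses exactly this formula is not stated here, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def merge_with_logsumexp_sketch(local_outs, local_stats):
  """Merges outputs [b, t, n, d] using per-part logsumexp stats [b, n, t]."""
  total = jax.nn.logsumexp(jnp.stack(local_stats), axis=0)  # [b, n, t]
  merged = 0.0
  for out, lse in zip(local_outs, local_stats):
    weight = jnp.exp(lse - total)  # this part's share of the softmax mass
    merged = merged + jnp.transpose(weight, (0, 2, 1))[..., None] * out
  return merged


def attend(q, k, v):
  """Reference attention: (output [b, t, n, d], logsumexp [b, n, t])."""
  logits = jnp.einsum("btnd,bsnd->bnts", q, k)
  out = jnp.einsum("bnts,bsnd->btnd", jax.nn.softmax(logits, axis=-1), v)
  return out, jax.nn.logsumexp(logits, axis=-1)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (1, 3, 2, 4))
k = jax.random.normal(kk, (1, 6, 2, 4))
v = jax.random.normal(kv, (1, 6, 2, 4))

full, _ = attend(q, k, v)
out_a, lse_a = attend(q, k[:, :3], v[:, :3])
out_b, lse_b = attend(q, k[:, 3:], v[:, 3:])
merged = merge_with_logsumexp_sketch([out_a, out_b], [lse_a, lse_b])
```

Attending to the two key halves separately and merging should reproduce attention over all keys, up to floating-point error.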

normalize_attention(local_outs, local_maxes, local_sums)[source]#

Normalize across multiple localized attentions

Parameters:
  • local_outs (list) – List of unnormalized output entries for each local attention

  • local_maxes (list) – List of max exponential entries for each local attention

  • local_sums (list) – List of exponential sum entries for each local attention

Returns:

Combined attention that has been normalized

Return type:

Array
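The (local_out, local_max, local_sum) triples can be combined with the usual running-max rescaling. The sketch below works on a single head with 2-D arrays for brevity; the function names are hypothetical, and the code is an illustration of the technique, not the method's source.

```python
import jax
import jax.numpy as jnp


def local_attention_sketch(logits, value):
  """logits: [t, s_chunk], value: [s_chunk, d] -> (out, max, sum)."""
  local_max = jnp.max(logits, axis=-1, keepdims=True)
  local_exps = jnp.exp(logits - local_max)
  local_sum = jnp.sum(local_exps, axis=-1, keepdims=True)
  return local_exps @ value, local_max, local_sum  # output is unnormalized


def normalize_sketch(local_outs, local_maxes, local_sums):
  """Rescales every chunk to a shared max, then divides by the global sum."""
  global_max = jnp.max(jnp.stack(local_maxes), axis=0)
  scales = [jnp.exp(m - global_max) for m in local_maxes]
  global_sum = sum(s * c for s, c in zip(local_sums, scales))
  return sum(o * c for o, c in zip(local_outs, scales)) / global_sum


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (3, 4))
k = jax.random.normal(kk, (6, 4))
v = jax.random.normal(kv, (6, 4))
logits = q @ k.T

full = jax.nn.softmax(logits, axis=-1) @ v
parts = [local_attention_sketch(logits[:, :2], v[:2]),
         local_attention_sketch(logits[:, 2:], v[2:])]
combined = normalize_sketch(*zip(*parts))
```

Subtracting each chunk's own max keeps the exponentials finite, and the `exp(local_max - global_max)` factor puts all chunks on a common scale before they are summed.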

class maxtext.layers.attention_op.LoadBalancedCausalMask(shape, offset=0, shard_count=1, cp_size=4)[source]#

Bases: _ComputableMask

Lazy causal mask, prevents the model from attending to future tokens.

Parameters:
  • shape (tuple[int, int])

  • offset (int)

  • shard_count (int)

  • cp_size (int)

offset#

Offset of the q start with respect to kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset makes the first |offset| rows of the attention matrix all 0s, which leads to an undefined softmax.

Type:

int
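The effect of offset can be seen on a small dense causal mask. The sketch below assumes the usual rule `q_idx + offset >= kv_idx` and does not model the load-balanced reordering implied by cp_size; the function name is hypothetical.

```python
import jax.numpy as jnp


def causal_mask_with_offset_sketch(shape, offset=0):
  """Dense boolean causal mask; True means the query may attend."""
  q_len, kv_len = shape
  q_idx = jnp.arange(q_len)[:, None]
  kv_idx = jnp.arange(kv_len)[None, :]
  return q_idx + offset >= kv_idx


no_shift = causal_mask_with_offset_sketch((3, 3))
shifted_up = causal_mask_with_offset_sketch((3, 3), offset=1)
shifted_down = causal_mask_with_offset_sketch((3, 3), offset=-1)
```

With offset=-1 the first row contains no allowed positions, which is the case where softmax over that row is undefined.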

cp_size: int#
offset: int#