maxtext.layers.attentions module#

Attention layers.

class maxtext.layers.attentions.L2Norm(*args, **kwargs)[source]#

Bases: Module

Implementation of L2Norm in JAX.

Parameters:
  • eps (float) – Epsilon added for numerical stability; the default should be fine for most cases.

  • args (Any)

  • kwargs (Any)

Return type:

Any

eps: float = 1e-06#
rngs: Rngs = None#
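
To illustrate the role of eps in the L2Norm module above, here is a minimal sketch of an eps-stabilized normalization in JAX. It assumes the reduction runs over the last axis and uses the mean of squares; the function name l2_norm_sketch is hypothetical, and the actual module may reduce or scale differently.

```python
import jax
import jax.numpy as jnp


def l2_norm_sketch(x: jax.Array, eps: float = 1e-6) -> jax.Array:
  """Scales x along its last axis so the mean of squares is about 1.

  Assumption: the reduction is over the last axis and uses the mean of
  squares; the real module may differ.
  """
  mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  # eps keeps rsqrt finite when a row is all zeros.
  return x * jax.lax.rsqrt(mean_sq + eps)


x = jnp.array([[3.0, 4.0], [0.0, 0.0]])
y = l2_norm_sketch(x)
```

The all-zero second row shows why eps matters: without it, the reciprocal square root would be infinite and the product undefined.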
maxtext.layers.attentions.l2_norm_as_linen(self, eps=1e-06)[source]#

Initializes the L2Norm module and returns it as a Linen module.

Parameters:

eps (float) – Epsilon added for numerical stability; the default should be fine for most cases.

maxtext.layers.attentions.attention_as_linen(*, config, num_query_heads, num_kv_heads, head_dim, max_target_length, mesh, attention_kernel, inputs_q_shape, inputs_kv_shape, dtype=<class 'jax.numpy.float32'>, weight_dtype=<class 'jax.numpy.float32'>, max_prefill_predict_length=-1, dropout_rate=0.0, kernel_init=<function nd_dense_init.<locals>.init_fn>, float32_qk_product=False, float32_logits=False, quant=None, kv_quant=None, attention_type=AttentionType.GLOBAL, attn_logits_soft_cap=None, sliding_window_size=None, use_ragged_attention=False, ragged_block_size=256, use_qk_norm=False, query_pre_attn_scalar=None, use_bias_in_projections=False, share_kv_projections=False, temperature_tuning=False, temperature_tuning_scale=0.1, temperature_tuning_floor_scale=8192.0, prefill_query_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_key_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_value_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), query_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), key_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), value_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), input_axis_names=('activation_batch_attn', 'activation_length_attn', 'activation_embed_attn'), out_axis_names=('activation_batch_attn', 'activation_length_attn', 'activation_heads', 'activation_kv'), prefill_input_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_embed_attn'), decode_input_axis_names=('decode_batch', 'decode_length', 'activation_embed_attn'), prefill_out_axis_names=('activation_prefill_kv_batch', 
'prefill_activation_length', 'activation_heads', 'activation_kv'), decode_out_axis_names=('decode_batch', 'decode_length', 'activation_heads', 'activation_kv'), prefill_cache_axis_order=(1, 2, 0, 3), ar_cache_axis_order=(1, 2, 0, 3), compute_axis_order=(0, 1, 2, 3), reshape_q=False, is_nope_layer=False, is_vision=False, model_mode='train', use_mrope=False, mrope_section=None, name=None, rope_type=None)[source]#

A factory function to create an Attention as a Linen module.

This function serves as a bridge to use the NNX-based Attention within a Linen model.

Parameters:
  • config (Any)

  • num_query_heads (int)

  • num_kv_heads (int)

  • head_dim (int)

  • max_target_length (int)

  • mesh (Mesh)

  • attention_kernel (str)

  • inputs_q_shape (Tuple)

  • inputs_kv_shape (Tuple)

  • dtype (dtype)

  • weight_dtype (dtype)

  • max_prefill_predict_length (int)

  • dropout_rate (float)

  • kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array])

  • float32_qk_product (bool)

  • float32_logits (bool)

  • quant (AqtQuantization | None)

  • kv_quant (KVQuant | None)

  • attention_type (AttentionType)

  • attn_logits_soft_cap (float | None)

  • sliding_window_size (int | None)

  • use_ragged_attention (bool)

  • ragged_block_size (int)

  • use_qk_norm (bool)

  • query_pre_attn_scalar (float | None)

  • use_bias_in_projections (bool)

  • share_kv_projections (bool)

  • temperature_tuning (bool)

  • temperature_tuning_scale (float)

  • temperature_tuning_floor_scale (float)

  • prefill_query_axis_names (tuple[str, ...])

  • prefill_key_axis_names (tuple[str, ...])

  • prefill_value_axis_names (tuple[str, ...])

  • query_axis_names (tuple[str, ...])

  • key_axis_names (tuple[str, ...])

  • value_axis_names (tuple[str, ...])

  • input_axis_names (tuple[str, ...])

  • out_axis_names (tuple[str, ...])

  • prefill_input_axis_names (tuple[str, ...])

  • decode_input_axis_names (tuple[str, ...])

  • prefill_out_axis_names (tuple[str, ...])

  • decode_out_axis_names (tuple[str, ...])

  • prefill_cache_axis_order (tuple[int, ...])

  • ar_cache_axis_order (tuple[int, ...])

  • compute_axis_order (tuple[int, ...])

  • reshape_q (bool)

  • is_nope_layer (bool)

  • is_vision (bool)

  • model_mode (str)

  • use_mrope (bool)

  • mrope_section (tuple[int, int, int] | None)

  • name (str | None)

  • rope_type (str | None)

class maxtext.layers.attentions.Attention(*args, **kwargs)[source]#

Bases: Module

Attention Module.

This module implements multi-headed attention as described in the original Transformer paper. It projects the inputs into query, key, and value vectors, applies the attention mechanism, and projects the results to an output vector.
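
The project, attend, project flow described above can be sketched in plain JAX as follows. This is an illustration under stated assumptions, not the module's implementation: the function name mha_sketch, the weight names, and the einsum layouts are hypothetical, and masking, dropout, rotary embeddings, KV caching, and sharding are all omitted. Because the module takes separate num_query_heads and num_kv_heads, the sketch repeats each key-value head across a group of query heads.

```python
import jax
import jax.numpy as jnp


def mha_sketch(x_q, x_kv, wq, wk, wv, wo):
  """Multi-head attention with grouped key-value heads.

  Shapes (hypothetical layout): x_q [B, T, E], x_kv [B, S, E],
  wq [E, Hq, D], wk and wv [E, Hkv, D], wo [Hq, D, E].
  """
  q = jnp.einsum("bte,ehd->bthd", x_q, wq)
  k = jnp.einsum("bse,ehd->bshd", x_kv, wk)
  v = jnp.einsum("bse,ehd->bshd", x_kv, wv)
  # Repeat each KV head so that every query head has a partner.
  group = q.shape[2] // k.shape[2]
  k = jnp.repeat(k, group, axis=2)
  v = jnp.repeat(v, group, axis=2)
  # Scaled dot-product attention over the key/value sequence axis.
  logits = jnp.einsum("bthd,bshd->bhts", q, k) / jnp.sqrt(q.shape[-1])
  weights = jax.nn.softmax(logits, axis=-1)
  out = jnp.einsum("bhts,bshd->bthd", weights, v)
  # Merge the heads back into the embedding dimension.
  return jnp.einsum("bthd,hde->bte", out, wo)


keys = jax.random.split(jax.random.PRNGKey(0), 5)
B, T, E, Hq, Hkv, D = 2, 5, 16, 4, 2, 8
x = jax.random.normal(keys[0], (B, T, E))
wq = jax.random.normal(keys[1], (E, Hq, D))
wk = jax.random.normal(keys[2], (E, Hkv, D))
wv = jax.random.normal(keys[3], (E, Hkv, D))
wo = jax.random.normal(keys[4], (Hq, D, E))
out = mha_sketch(x, x, wq, wk, wv, wo)
```

Passing the same array as x_q and x_kv gives self-attention; the output keeps the input shape [B, T, E].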

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config#

The model configuration.

num_query_heads#

Number of query attention heads.

num_kv_heads#

Number of key-value attention heads.

head_dim#

The dimension of each attention head.

max_target_length#

Maximum sequence length.

mesh#

The device mesh.

attention_kernel#

The attention kernel to use (e.g., ‘dot_product’, ‘flash’).

inputs_q_shape#

Query inputs shape for initialization, required by NNX.

inputs_kv_shape#

Key/value inputs shape for initialization, required by NNX.

dtype#

The data type for computation.

weight_dtype#

The data type for weights.

max_prefill_predict_length#

Maximum length for prefill.

dropout_rate#

The dropout rate.

kernel_init#

Initializer for the kernel of the dense layers.

float32_qk_product#

If True, compute query-key product in float32.

float32_logits#

If True, cast logits to float32 before softmax.

quant#

Quantization configuration.

kv_quant#

KV cache quantization configuration.

attention_type#

The type of attention (e.g., ‘global’, ‘local_sliding’).

attn_logits_soft_cap#

Soft cap for attention logits.

... and other configuration parameters.
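
As a rough illustration of how float32_logits and attn_logits_soft_cap could affect the attention weights, consider the sketch below. It assumes the soft cap takes the common form cap * tanh(logits / cap); whether this module uses exactly that formula is not confirmed here, and the function name attention_weights_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def attention_weights_sketch(q, k, soft_cap=None, float32_logits=False):
  """q: [T, D], k: [S, D]; returns softmax weights of shape [T, S]."""
  logits = q @ k.T
  if float32_logits:
    # Upcast before softmax to limit low-precision rounding error.
    logits = logits.astype(jnp.float32)
  if soft_cap is not None:
    # Assumed form: tanh squashes logits smoothly into (-soft_cap, soft_cap).
    logits = soft_cap * jnp.tanh(logits / soft_cap)
  return jax.nn.softmax(logits, axis=-1)


q = jnp.ones((2, 4), dtype=jnp.bfloat16) * 10
k = jnp.ones((3, 4), dtype=jnp.bfloat16)
w = attention_weights_sketch(q, k, soft_cap=50.0, float32_logits=True)
```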
init_query_w(inputs_q_shape)[source]#

Query projection initialization.

Parameters:

inputs_q_shape (Tuple)

Return type:

Module

query_projection(inputs_q, out_sharding=None)[source]#

Query projection.

Parameters:
  • inputs_q (Array)

  • out_sharding (NamedSharding | None)

Return type:

Array

init_kv_w(inputs_kv_shape)[source]#

Initializes the key or value projection.

Parameters:

inputs_kv_shape (Tuple) – Key/value inputs shape for initialization.

Returns:

A DenseGeneral module that performs the key or value projection.

Return type:

Module

kv_projection(inputs_kv, proj_name, out_sharding=None)[source]#

Applies the key or value projection.

Parameters:
  • inputs_kv (Array) – The input tensor to project.

  • proj_name (str) – The name of the projection (“key” or “value”).

  • out_sharding (NamedSharding | None)

Returns:

The projected key or value tensor.

Raises:

ValueError – If proj_name is not one of the supported values (“key”, “value”).

Return type:

Module
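
The proj_name contract described for kv_projection can be sketched as follows. The kernels dictionary, the einsum layout, and the function name kv_projection_sketch are hypothetical stand-ins for illustration; the actual method dispatches to the module's own projection layers.

```python
import jax.numpy as jnp


def kv_projection_sketch(inputs_kv, kernels, proj_name):
  """Projects inputs_kv with the kernel selected by proj_name.

  kernels is a hypothetical dict mapping "key"/"value" to [E, H, D] arrays.
  """
  if proj_name not in ("key", "value"):
    raise ValueError(f"proj_name must be 'key' or 'value', got {proj_name!r}")
  # [B, S, E] x [E, H, D] -> [B, S, H, D]
  return jnp.einsum("bse,ehd->bshd", inputs_kv, kernels[proj_name])


kernels = {"key": jnp.ones((8, 2, 4)), "value": jnp.zeros((8, 2, 4))}
k = kv_projection_sketch(jnp.ones((1, 3, 8)), kernels, "key")
```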

init_qkv_w(inputs_shape)[source]#
Parameters:

inputs_shape (Tuple)

Return type:

Module

qkv_projection(inputs, proj_name, out_sharding=None)[source]#

Fused QKV projection.

Parameters:
  • inputs (Array)

  • proj_name (str)

  • out_sharding (NamedSharding | None)

property out_head_dim: int#
init_out_w(output_dim)[source]#

Initializes the output projection.

Parameters:

output_dim (int)

Return type:

Module

out_projection(out, out_sharding=None)[source]#

Applies the output projection.

Parameters:
  • out (Array)

  • out_sharding (NamedSharding | None)

Return type:

Array

convert_dense_general_inputs_shape(inputs_shape=None, axis=-1)[source]#
Parameters:
  • inputs_shape (tuple[int, ...] | None)

  • axis (Iterable[int] | int)

Return type:

Iterable[int] | int

init_rotary_embedding()[source]#

Initializes the rotary embeddings, handling different model types.

Returns:

The rotary embedding module that will be used in the model.

apply_rotary_embedding(inputs, inputs_positions=None, rope_kwargs=None)[source]#

Applies rotary embeddings, handling different model types.

Parameters:
  • inputs (Array) – The input tensor to apply rotary embeddings to.

  • inputs_positions (Array | None) – The positions of the inputs.

  • rope_kwargs (dict | None) – A dictionary of keyword arguments for the rotary embedding.

Returns:

The input tensor with rotary embeddings applied.
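
For readers unfamiliar with rotary embeddings, here is a minimal sketch of the standard formulation in JAX. It is not this module's implementation, which handles several model types: the function name rope_sketch, the max_timescale parameter, and the split-halves pairing of features are assumptions chosen for illustration.

```python
import jax.numpy as jnp


def rope_sketch(x, positions, max_timescale=10_000.0):
  """x: [B, T, H, D] with even D; positions: [B, T] integer positions."""
  half = x.shape[-1] // 2
  # One rotation frequency per pair of features.
  freqs = max_timescale ** (-jnp.arange(half) / half)
  angles = positions[:, :, None, None] * freqs  # [B, T, 1, half]
  sin, cos = jnp.sin(angles), jnp.cos(angles)
  x1, x2 = x[..., :half], x[..., half:]
  # Rotate each (x1, x2) pair by its position-dependent angle.
  return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)


x = jnp.ones((1, 4, 2, 8))
positions = jnp.arange(4)[None, :]
y = rope_sketch(x, positions)
```

Because each feature pair is only rotated, the per-head vector norm is unchanged, and position 0 leaves the input as it was.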

init_kv_caches(inputs_kv_shape)[source]#

Initializes KVCache.

Parameters:

inputs_kv_shape (Tuple) – Key/value inputs shape for initialization.

Returns:

A KVCache module instance.

update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)[source]#

Updates the KV caches for prefill and autoregressive modes.

This method uses a kvcache module to update and retrieve the key-value caches based on the current operational mode.

Parameters:
  • key – The key tensor for the current attention computation.

  • value – The value tensor for the current attention computation.

  • decoder_segment_ids – Segment IDs for the decoder, used for masking.

  • model_mode – The operational mode (‘train’, ‘prefill’, ‘autoregressive’).

  • previous_chunk – Information about previously processed chunks, used for chunked prefill.

Returns:

A list containing two elements:

  • The prefill key-value cache, or None.

  • The autoregressive key-value cache, or None.
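
The mode-dependent behaviour described for update_kv_caches can be pictured with a much simpler single-buffer sketch. Everything here is an assumption for illustration: the function name update_cache_sketch, the [B, S_max, H, D] layout, and the explicit index argument are hypothetical, and the real method keeps separate prefill and autoregressive caches and also handles segment IDs and chunked prefill.

```python
import jax
import jax.numpy as jnp


def update_cache_sketch(cache, new, index, model_mode):
  """cache: [B, S_max, H, D]; new: [B, T, H, D]; index: write position."""
  if model_mode == "train":
    return None  # no cache is kept during training
  if model_mode == "prefill":
    index = 0  # prefill writes the prompt from the start of the cache
  # Overwrite the slice of the cache that starts at `index` on the sequence axis.
  return jax.lax.dynamic_update_slice(cache, new, (0, index, 0, 0))


cache = jnp.zeros((1, 6, 2, 4))
prompt = jnp.ones((1, 3, 2, 4))
cache = update_cache_sketch(cache, prompt, 0, "prefill")
token = 2 * jnp.ones((1, 1, 2, 4))
cache = update_cache_sketch(cache, token, 3, "autoregressive")
```

After these two calls, the first three sequence positions hold the prompt, position 3 holds the newly decoded token, and the remaining positions are still zero.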

forward_serve_vllm(query, key, value, rpa_kv_cache=None, rpa_metadata=None)[source]#

Forward function for vLLM serving with RPA attention.

Parameters:
  • query (Array)

  • key (Array)

  • value (Array)

  • rpa_kv_cache (list[Array] | None)

  • rpa_metadata (dict[str, Any] | None)

Return type:

tuple[Array, list[Array]]