maxtext.layers.attentions module#
Attention layers.
- class maxtext.layers.attentions.L2Norm(*args, **kwargs)[source]#
Bases: Module
Implementation of L2Norm in JAX.
- Parameters:
eps (float) – Epsilon used for numerical stability (the default should work in most cases).
args (Any)
kwargs (Any)
- Return type:
Any
- eps: float = 1e-06#
- rngs: Rngs = None#
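To illustrate the role of eps, here is a minimal sketch of a normalization of this kind. It assumes the input is scaled by the reciprocal root mean square of its last axis; that choice of reduction and the function name `l2_normalize` are assumptions for illustration, not taken from the module source.

```python
import jax
import jax.numpy as jnp


def l2_normalize(x: jax.Array, eps: float = 1e-6) -> jax.Array:
  """Scales x so its last axis has approximately unit root mean square.

  Assumption: the reduction is a mean of squares over the last axis.
  """
  # eps keeps the reciprocal square root finite when x is all zeros.
  mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(mean_sq + eps)


x = jnp.array([[3.0, 4.0], [0.0, 0.0]])
y = l2_normalize(x)
```

Without eps, the all-zero row would produce NaN values (zero times infinity); with it, the row stays at zero.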
- maxtext.layers.attentions.l2_norm_as_linen(self, eps=1e-06)[source]#
Initializes the L2Norm module and returns it as a Linen module.
- Parameters:
eps (float) – Epsilon used for numerical stability (the default should work in most cases).
- maxtext.layers.attentions.attention_as_linen(*, config, num_query_heads, num_kv_heads, head_dim, max_target_length, mesh, attention_kernel, inputs_q_shape, inputs_kv_shape, dtype=<class 'jax.numpy.float32'>, weight_dtype=<class 'jax.numpy.float32'>, max_prefill_predict_length=-1, dropout_rate=0.0, kernel_init=<function nd_dense_init.<locals>.init_fn>, float32_qk_product=False, float32_logits=False, quant=None, kv_quant=None, attention_type=AttentionType.GLOBAL, attn_logits_soft_cap=None, sliding_window_size=None, use_ragged_attention=False, ragged_block_size=256, use_qk_norm=False, query_pre_attn_scalar=None, use_bias_in_projections=False, share_kv_projections=False, temperature_tuning=False, temperature_tuning_scale=0.1, temperature_tuning_floor_scale=8192.0, prefill_query_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_key_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), prefill_value_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_kv_heads', 'activation_kv_head_dim'), query_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), key_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), value_axis_names=('activation_kv_batch', 'activation_length_attn', 'activation_kv_heads', 'activation_kv_head_dim'), input_axis_names=('activation_batch_attn', 'activation_length_attn', 'activation_embed_attn'), out_axis_names=('activation_batch_attn', 'activation_length_attn', 'activation_heads', 'activation_kv'), prefill_input_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_embed_attn'), decode_input_axis_names=('decode_batch', 'decode_length', 'activation_embed_attn'), prefill_out_axis_names=('activation_prefill_kv_batch', 'prefill_activation_length', 'activation_heads', 'activation_kv'), decode_out_axis_names=('decode_batch', 'decode_length', 'activation_heads', 'activation_kv'), prefill_cache_axis_order=(1, 2, 0, 3), ar_cache_axis_order=(1, 2, 0, 3), compute_axis_order=(0, 1, 2, 3), reshape_q=False, is_nope_layer=False, is_vision=False, model_mode='train', use_mrope=False, mrope_section=None, name=None, rope_type=None)[source]#
A factory function to create an Attention as a Linen module.
This function serves as a bridge to use the NNX-based Attention within a Linen model.
- Parameters:
config (Any)
num_query_heads (int)
num_kv_heads (int)
head_dim (int)
max_target_length (int)
mesh (Mesh)
attention_kernel (str)
inputs_q_shape (Tuple)
inputs_kv_shape (Tuple)
dtype (dtype)
weight_dtype (dtype)
max_prefill_predict_length (int)
dropout_rate (float)
kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array])
float32_qk_product (bool)
float32_logits (bool)
quant (AqtQuantization | None)
kv_quant (KVQuant | None)
attention_type (AttentionType)
attn_logits_soft_cap (float | None)
sliding_window_size (int | None)
use_ragged_attention (bool)
ragged_block_size (int)
use_qk_norm (bool)
query_pre_attn_scalar (float | None)
use_bias_in_projections (bool)
share_kv_projections (bool)
temperature_tuning (bool)
temperature_tuning_scale (float)
temperature_tuning_floor_scale (float)
prefill_query_axis_names (tuple[str, ...])
prefill_key_axis_names (tuple[str, ...])
prefill_value_axis_names (tuple[str, ...])
query_axis_names (tuple[str, ...])
key_axis_names (tuple[str, ...])
value_axis_names (tuple[str, ...])
input_axis_names (tuple[str, ...])
out_axis_names (tuple[str, ...])
prefill_input_axis_names (tuple[str, ...])
decode_input_axis_names (tuple[str, ...])
prefill_out_axis_names (tuple[str, ...])
decode_out_axis_names (tuple[str, ...])
prefill_cache_axis_order (tuple[int, ...])
ar_cache_axis_order (tuple[int, ...])
compute_axis_order (tuple[int, ...])
reshape_q (bool)
is_nope_layer (bool)
is_vision (bool)
model_mode (str)
use_mrope (bool)
mrope_section (tuple[int, int, int] | None)
name (str | None)
rope_type (str | None)
- class maxtext.layers.attentions.Attention(*args, **kwargs)[source]#
Bases: Module
Attention Module.
This module implements multi-headed attention as described in the original Transformer paper. It projects the inputs into query, key, and value vectors, applies the attention mechanism, and projects the results to an output vector.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
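The project, attend, project flow described for this class can be sketched in plain JAX as follows. This is an illustrative sketch only: the function name, the weight layouts, and the causal mask are assumptions, and it omits sharding, caching, quantization, and the other options the class exposes.

```python
import jax
import jax.numpy as jnp


def multi_head_attention(x, wq, wk, wv, wo):
  """Causal multi-head self-attention.

  x: [batch, length, embed]
  wq, wk, wv: [embed, heads, head_dim]
  wo: [heads, head_dim, embed]
  """
  # Project inputs into per-head query, key, and value vectors.
  q = jnp.einsum("ble,ehd->blhd", x, wq)
  k = jnp.einsum("ble,ehd->blhd", x, wk)
  v = jnp.einsum("ble,ehd->blhd", x, wv)

  # Scaled dot-product logits: [batch, heads, q_len, k_len].
  head_dim = q.shape[-1]
  logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * (head_dim ** -0.5)

  # Causal mask: each query position attends only to itself and earlier keys.
  length = x.shape[1]
  causal = jnp.tril(jnp.ones((length, length), dtype=bool))
  logits = jnp.where(causal, logits, -1e30)

  weights = jax.nn.softmax(logits, axis=-1)
  context = jnp.einsum("bhqk,bkhd->bqhd", weights, v)

  # Merge heads and project back to the embedding dimension.
  return jnp.einsum("blhd,hde->ble", context, wo)


batch, length, embed, heads, head_dim = 2, 5, 8, 2, 4
kx, kq, kk, kv, ko = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(kx, (batch, length, embed))
wq = 0.1 * jax.random.normal(kq, (embed, heads, head_dim))
wk = 0.1 * jax.random.normal(kk, (embed, heads, head_dim))
wv = 0.1 * jax.random.normal(kv, (embed, heads, head_dim))
wo = 0.1 * jax.random.normal(ko, (heads, head_dim, embed))
out = multi_head_attention(x, wq, wk, wv, wo)
```

The output has the same shape as the input, `[batch, length, embed]`.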
- config#
The model configuration.
- num_query_heads#
Number of query attention heads.
- num_kv_heads#
Number of key-value attention heads.
- head_dim#
The dimension of each attention head.
- max_target_length#
Maximum sequence length.
- mesh#
The device mesh.
- attention_kernel#
The attention kernel to use (e.g., ‘dot_product’, ‘flash’).
- inputs_q_shape#
Query inputs shape for initialization, required by NNX.
- inputs_kv_shape#
Key/value inputs shape for initialization, required by NNX.
- dtype#
The data type for computation.
- weight_dtype#
The data type for weights.
- max_prefill_predict_length#
Maximum length for prefill.
- dropout_rate#
The dropout rate.
- kernel_init#
Initializer for the kernel of the dense layers.
- float32_qk_product#
If True, compute query-key product in float32.
- float32_logits#
If True, cast logits to float32 before softmax.
- quant#
Quantization configuration.
- kv_quant#
KV cache quantization configuration.
- attention_type#
The type of attention (e.g., ‘global’, ‘local_sliding’).
- attn_logits_soft_cap#
Soft cap for attention logits.
- ... and other configuration parameters.
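Two of the options above, attn_logits_soft_cap and the ‘local_sliding’ attention type, correspond to techniques that are easy to show in isolation. The sketch below uses the common tanh formulation of soft capping and a causal sliding-window mask. The helper names and the exact window convention (a query sees itself plus the `window - 1` previous positions) are assumptions and may not match the library's behavior.

```python
import jax.numpy as jnp


def soft_cap(logits, cap):
  """Smoothly bounds logits to the open interval (-cap, cap) using tanh."""
  return cap * jnp.tanh(logits / cap)


def sliding_window_causal_mask(length, window):
  """True where query i may attend key j: j <= i and i - j < window."""
  i = jnp.arange(length)[:, None]
  j = jnp.arange(length)[None, :]
  return (j <= i) & (i - j < window)


capped = soft_cap(jnp.array([-100.0, 0.0, 1.0, 100.0]), cap=50.0)
mask = sliding_window_causal_mask(length=4, window=2)
```

Small logits pass through soft capping almost unchanged, while large ones saturate near the cap instead of growing without bound.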
- init_query_w(inputs_q_shape)[source]#
Query projection initialization.
- Parameters:
inputs_q_shape (Tuple)
- Return type:
Module
- query_projection(inputs_q, out_sharding=None)[source]#
Query projection.
- Parameters:
inputs_q (Array)
out_sharding (NamedSharding | None)
- Return type:
Array
- init_kv_w(inputs_kv_shape)[source]#
Initializes the key or value projection.
- Parameters:
inputs_kv_shape (Tuple) – Key/value inputs shape for initialization.
- Returns:
A DenseGeneral module that performs the key or value projection.
- Return type:
Module
- kv_projection(inputs_kv, proj_name, out_sharding=None)[source]#
Applies the key or value projection.
- Parameters:
inputs_kv (Array) – The input tensor to project.
proj_name (str) – The name of the projection (“key” or “value”).
out_sharding (NamedSharding | None)
- Returns:
The projected key or value tensor.
- Raises:
ValueError – If proj_name is not one of the supported values (“key”, “value”).
- Return type:
Module
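The contract documented for kv_projection (select a projection by name, reject unknown names) can be sketched as below. The function `kv_projection_sketch`, the `kernels` dictionary, and the kernel layout are hypothetical stand-ins for illustration; the real method uses the module's own DenseGeneral layers and sharding.

```python
import jax
import jax.numpy as jnp


def kv_projection_sketch(inputs_kv, kernels, proj_name):
  """Projects inputs_kv [batch, length, embed] to [batch, length, heads, head_dim].

  kernels maps "key" / "value" to arrays of shape [embed, heads, head_dim].
  """
  if proj_name not in ("key", "value"):
    raise ValueError(f"proj_name must be 'key' or 'value', got {proj_name!r}")
  return jnp.einsum("ble,ehd->blhd", inputs_kv, kernels[proj_name])


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
kernels = {
    "key": jax.random.normal(k1, (8, 2, 4)),
    "value": jax.random.normal(k2, (8, 2, 4)),
}
inputs_kv = jax.random.normal(k3, (2, 3, 8))
projected_key = kv_projection_sketch(inputs_kv, kernels, "key")
```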
- qkv_projection(inputs, proj_name, out_sharding=None)[source]#
Fused QKV projection.
- Parameters:
inputs (Array)
proj_name (str)
out_sharding (NamedSharding | None)
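As a rough illustration of what a fused projection means in general, the sketch below computes query, key, and value with one matrix multiplication against a stacked kernel, then splits the result. The kernel layout `[embed, 3, heads, head_dim]` and the function name are assumptions made for this example; the layout used by qkv_projection is not specified here.

```python
import jax
import jax.numpy as jnp


def fused_qkv_projection(inputs, qkv_kernel):
  """Computes q, k, v with a single einsum.

  inputs: [batch, length, embed]
  qkv_kernel: [embed, 3, heads, head_dim] (assumed layout)
  """
  qkv = jnp.einsum("ble,enhd->nblhd", inputs, qkv_kernel)
  return qkv[0], qkv[1], qkv[2]


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
inputs = jax.random.normal(k1, (2, 3, 8))
qkv_kernel = jax.random.normal(k2, (8, 3, 2, 4))
q, k, v = fused_qkv_projection(inputs, qkv_kernel)
```

Fusing requires query and key/value projections to share the same head count and head dimension, which is why it is typically an option rather than the default.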
- property out_head_dim: int#
- out_projection(out, out_sharding=None)[source]#
Output projection.
- Parameters:
out (Array)
out_sharding (NamedSharding | None)
- Return type:
Array
- convert_dense_general_inputs_shape(inputs_shape=None, axis=-1)[source]#
- Parameters:
inputs_shape (tuple[int, ...] | None)
axis (Iterable[int] | int)
- Return type:
Iterable[int] | int
- init_rotary_embedding()[source]#
Initializes the rotary embeddings, handling different model types.
- Returns:
The rotary embedding module that will be used in the model.
- apply_rotary_embedding(inputs, inputs_positions=None, rope_kwargs=None)[source]#
Applies rotary embeddings, handling different model types.
- Parameters:
inputs (Array) – The input tensor to apply rotary embeddings to.
inputs_positions (Array | None) – The positions of the inputs.
rope_kwargs (dict | None) – A dictionary of keyword arguments for the rotary embedding.
- Returns:
The input tensor with rotary embeddings applied.
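For readers unfamiliar with rotary embeddings, the following is a minimal sketch of the standard formulation, in which the two halves of each head vector are rotated by position-dependent angles. The function name, the `max_timescale` default, and the half-split pairing are assumptions; the model-specific variants this method handles may differ.

```python
import jax.numpy as jnp


def apply_rope(inputs, positions, max_timescale=10_000.0):
  """Applies rotary position embeddings.

  inputs: [batch, length, heads, head_dim], head_dim must be even.
  positions: [batch, length] integer positions.
  """
  half = inputs.shape[-1] // 2
  # One rotation frequency per feature pair, decreasing geometrically.
  freq = max_timescale ** (-jnp.arange(half, dtype=jnp.float32) / half)
  # Angles broadcast over heads: [batch, length, 1, half].
  angles = positions[:, :, None, None].astype(jnp.float32) * freq
  sin, cos = jnp.sin(angles), jnp.cos(angles)
  first, second = inputs[..., :half], inputs[..., half:]
  # 2D rotation of each (first, second) feature pair.
  return jnp.concatenate(
      [first * cos - second * sin, second * cos + first * sin], axis=-1
  )


x = jnp.ones((1, 3, 2, 4))
positions = jnp.arange(3)[None, :]
y = apply_rope(x, positions)
```

Because each pair is only rotated, vector norms are preserved and position 0 leaves the input unchanged.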
- init_kv_caches(inputs_kv_shape)[source]#
Initializes KVCache.
- Parameters:
inputs_kv_shape (Tuple) – Key/value inputs shape for initialization.
- Returns:
A KVCache module instance.
- update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)[source]#
Updates the KV caches for prefill and autoregressive modes.
This method uses a kvcache module to update and retrieve the key-value caches based on the current operational mode.
- Parameters:
key – The key tensor for the current attention computation.
value – The value tensor for the current attention computation.
decoder_segment_ids – Segment IDs for the decoder, used for masking.
model_mode – The operational mode (‘train’, ‘prefill’, ‘autoregressive’).
previous_chunk – Information about previously processed chunks, used for chunked prefill.
- Returns:
A list containing two elements: the prefill key-value cache (or None) and the autoregressive key-value cache (or None).
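The difference between the prefill and autoregressive modes can be sketched with a single fixed-size cache: prefill writes the whole prompt at once, and each autoregressive step appends one position. The helper names and the `[batch, length, heads, head_dim]` layout are assumptions; the actual KVCache module keeps separate prefill and autoregressive caches and supports configurable axis orders and quantization, none of which is modeled here.

```python
import jax
import jax.numpy as jnp


def prefill_cache(key, value, max_length):
  """Writes prompt keys/values into zero-initialized caches of length max_length."""
  batch, length, heads, dim = key.shape
  k_cache = jnp.zeros((batch, max_length, heads, dim), key.dtype)
  v_cache = jnp.zeros((batch, max_length, heads, dim), value.dtype)
  k_cache = jax.lax.dynamic_update_slice(k_cache, key, (0, 0, 0, 0))
  v_cache = jax.lax.dynamic_update_slice(v_cache, value, (0, 0, 0, 0))
  return k_cache, v_cache, length


def autoregressive_update(k_cache, v_cache, key, value, index):
  """Writes a single new step (length-1 key/value) at position index."""
  k_cache = jax.lax.dynamic_update_slice(k_cache, key, (0, index, 0, 0))
  v_cache = jax.lax.dynamic_update_slice(v_cache, value, (0, index, 0, 0))
  return k_cache, v_cache, index + 1


prompt_k = jnp.ones((1, 3, 2, 4))
prompt_v = 2.0 * jnp.ones((1, 3, 2, 4))
k_cache, v_cache, next_index = prefill_cache(prompt_k, prompt_v, max_length=6)

step_k = 5.0 * jnp.ones((1, 1, 2, 4))
step_v = 7.0 * jnp.ones((1, 1, 2, 4))
k_cache, v_cache, next_index = autoregressive_update(
    k_cache, v_cache, step_k, step_v, next_index
)
```

In training mode no cache is needed, which is consistent with either returned element being None.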
- forward_serve_vllm(query, key, value, rpa_kv_cache=None, rpa_metadata=None)[source]#
Forward function for vLLM serving with RPA attention.
- Parameters:
query (Array)
key (Array)
value (Array)
rpa_kv_cache (list[Array] | None)
rpa_metadata (dict[str, Any] | None)
- Return type:
tuple[Array, list[Array]]