maxtext.layers.embeddings module

Embedding Layers.

maxtext.layers.embeddings.embed_as_linen(*, num_embeddings, num_features, config, mesh, cast_input_dtype=None, dtype=<class 'jax.numpy.float32'>, attend_dtype=None, embedding_init=<function variance_scaling.<locals>.init>, name=None)

Initializes the Embed NNX module and returns it as a Linen module.

This function serves as a bridge to use the NNX-based Embed module within a Linen model. It wraps the Embed module using nnx.bridge.to_linen, making it compatible with the Linen API.

Parameters:
  • num_embeddings (int) – The number of embeddings.

  • num_features (int) – The number of feature dimensions for each embedding.

  • config (Any) – The model configuration.

  • cast_input_dtype (None | dtype) – The dtype to cast the input to, if any.

  • dtype (dtype) – The dtype of the embedding vectors.

  • attend_dtype (None | dtype) – The dtype for the attend method.

  • embedding_init (Callable[[Array, Sequence[int], dtype], Array]) – The initializer for the embedding matrix.

  • name (str | None) – The name of the Linen module.

  • mesh (Mesh) – The device mesh used for sharding.

Returns:

A Linen module that wraps the NNX Embed module.

class maxtext.layers.embeddings.Embed(*args, **kwargs)

Bases: Module

A parameterized function from integers [0, n) to d-dimensional vectors.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

attend(query, out_sharding=None)

Attend over the embedding using a query array.

Parameters:
  • query (Array) – An array whose last dimension equals the feature depth (num_features) of the embedding.

  • out_sharding (NamedSharding | None) – A NamedSharding object specifying how the output is sharded.

Returns:

An array whose final dimension is num_embeddings, holding the batched inner product of the query vectors with each embedding. This is commonly used to share weights between the input embedding and the output logit projection in NLP models.

Return type:

Array
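The lookup performed by Embed and the weight-tied attend can be illustrated in plain JAX. This is a minimal sketch, not MaxText's implementation: embed_lookup and attend are hypothetical helper names, and dtype casting, sharding, and config handling are omitted.

```python
import jax
import jax.numpy as jnp


def embed_lookup(table, ids):
    """Map integer ids in [0, num_embeddings) to rows of the embedding table."""
    return jnp.take(table, ids, axis=0)


def attend(table, query):
    """Inner product of each query vector with every embedding row.

    query: (..., num_features) -> result: (..., num_embeddings)
    """
    return jnp.einsum("...d,vd->...v", query, table)


# num_embeddings=10, num_features=4
table = jax.random.normal(jax.random.PRNGKey(0), (10, 4))
ids = jnp.array([[1, 2, 3]])
hidden = embed_lookup(table, ids)  # shape (1, 3, 4)
logits = attend(table, hidden)     # shape (1, 3, 10): the same table acts as the output projection
```

Reusing one table for both directions is what makes the logit projection "tied": no separate (num_features, num_embeddings) output matrix is needed.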

maxtext.layers.embeddings.attend_on_embedding(query, embedding_table, attend_dtype, config, out_sharding=None)

Attend over an embedding table using a query array.

TODO: Remove this function once the Embed bridge to Linen is no longer needed.

Parameters:
  • query (Array) – An array with a last dimension equal to the feature depth of the embedding.

  • embedding_table (Array) – The embedding table to attend over.

  • attend_dtype (dtype) – The data type for the attention computation.

  • config (Any) – The model configuration, used to check for parameter offloading.

  • out_sharding (NamedSharding | None) – A NamedSharding object specifying how the output is sharded.

Returns:

An array with a final dimension equal to num_embeddings, corresponding to the batched inner-product of the query vectors against each embedding.

Return type:

Array
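A minimal sketch of the role of attend_dtype, assuming both operands are cast to it before the contraction. attend_on_embedding_sketch is a hypothetical name; it deliberately ignores the config-driven parameter-offloading check and out_sharding, so it is an illustration rather than the library function.

```python
import jax.numpy as jnp


def attend_on_embedding_sketch(query, embedding_table, attend_dtype):
    """Hypothetical simplification: no offloading check, no output sharding."""
    # Cast both operands to attend_dtype before the contraction.
    query = query.astype(attend_dtype)
    embedding_table = embedding_table.astype(attend_dtype)
    # (..., num_features) x (num_features, num_embeddings) -> (..., num_embeddings)
    return jnp.dot(query, embedding_table.T)


query = jnp.ones((2, 3, 4), dtype=jnp.float32)
table = jnp.ones((10, 4), dtype=jnp.float32)
logits = attend_on_embedding_sketch(query, table, jnp.bfloat16)
```

Choosing a lower-precision attend_dtype such as bfloat16 trades numerical precision in the logits for cheaper matrix multiplication.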

maxtext.layers.embeddings.rotary_embedding_as_linen(*, min_timescale, max_timescale, embedding_dims=0, cast_as_fprop_dtype=True, fprop_dtype=<class 'jax.numpy.bfloat16'>, name=None)

Initializes the RotaryEmbedding module and returns it as a Linen module.

Parameters:
  • min_timescale (int) – Start of the geometric index. Determines the periodicity of the added signal.

  • max_timescale (int) – End of the geometric index. Determines the frequency of the added signal.

  • embedding_dims (int) – Dimension of the embedding to be generated.

  • cast_as_fprop_dtype (bool) – Whether to cast the output to the fprop dtype.

  • fprop_dtype (dtype) – The dtype of the output.

  • name (str | None) – Name of the Linen module.

class maxtext.layers.embeddings.RotaryEmbedding(*args, **kwargs)

Bases: Module

Rotary Position Embedding.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

property timescale

Returns the timescale for the rotary embedding.

apply_rotary(inputs, cos, sin)

Applies the rotary transformation logic.

Parameters:
  • inputs (Array)

  • cos (Array)

  • sin (Array)

Return type:

Array
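A sketch of the timescale and the rotation step, assuming min_timescale = 1 and max_timescale acting as the RoPE base (theta), with the half-split layout in which dimension i is paired with dimension i + d/2 (the layout the Gemma4PartialRotaryEmbedding notes attribute to the base class). rope_timescale and rope are hypothetical helpers; apply_rotary here is an illustration, not the method itself.

```python
import jax
import jax.numpy as jnp


def rope_timescale(embedding_dims, min_timescale=1.0, max_timescale=10_000.0):
    """Geometric progression of embedding_dims // 2 timescales."""
    fraction = 2 * jnp.arange(embedding_dims // 2) / embedding_dims
    return min_timescale * (max_timescale / min_timescale) ** fraction


def apply_rotary(inputs, cos, sin):
    """Rotate dimension pairs (i, i + d/2) by the given angles (half-split layout)."""
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    return jnp.concatenate([first_part, second_part], axis=-1)


def rope(inputs, positions, max_timescale=10_000.0):
    """inputs: (seq_len, d); positions: (seq_len,) -> rotated inputs (seq_len, d)."""
    timescale = rope_timescale(inputs.shape[-1], max_timescale=max_timescale)
    angles = positions[:, None] / timescale[None, :]  # (seq_len, d // 2)
    return apply_rotary(inputs, jnp.cos(angles), jnp.sin(angles))


q = jax.random.normal(jax.random.PRNGKey(1), (1, 8))
k = jax.random.normal(jax.random.PRNGKey(2), (1, 8))
# The q.k score depends only on the offset m - n between the two positions.
score_a = jnp.sum(rope(q, jnp.array([3.0])) * rope(k, jnp.array([1.0])))
score_b = jnp.sum(rope(q, jnp.array([10.0])) * rope(k, jnp.array([8.0])))
```

Because each pair is rotated by an angle proportional to its position, the dot product of a rotated query and key depends only on their relative offset, which is the property RoPE is designed to provide.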

maxtext.layers.embeddings.llama_rotary_embedding_as_linen(*, min_timescale, max_timescale, embedding_dims=0, cast_as_fprop_dtype=True, fprop_dtype=<class 'jax.numpy.bfloat16'>, use_scale=True, name=None)

Initializes the LLaMARotaryEmbedding module and returns it as a Linen module.

Parameters:
  • min_timescale (int) – Start of the geometric index. Determines the periodicity of the added signal.

  • max_timescale (int) – End of the geometric index. Determines the frequency of the added signal.

  • embedding_dims (int) – Dimension of the embedding to be generated.

  • cast_as_fprop_dtype (bool) – Whether to cast the output to the fprop dtype.

  • fprop_dtype (dtype) – The dtype of the output.

  • use_scale (bool) – Whether to apply the LLaMA 3.1 frequency scaling.

  • name (str | None) – Name of the Linen module.
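A sketch of what LLaMA 3.1 frequency scaling (use_scale=True) does, assuming the constants published with LLaMA 3.1 (scale factor 8, low/high frequency factors 1 and 4, original context 8192). These constants and the function name llama31_scale_freqs are assumptions for illustration, not values read from MaxText.

```python
import jax.numpy as jnp

# Assumed LLaMA 3.1 constants (not read from MaxText).
SCALE_FACTOR = 8.0
LOW_FREQ_FACTOR = 1.0
HIGH_FREQ_FACTOR = 4.0
OLD_CONTEXT_LEN = 8192.0


def llama31_scale_freqs(freqs):
    """Rescale RoPE frequencies by wavelength band."""
    low_freq_wavelen = OLD_CONTEXT_LEN / LOW_FREQ_FACTOR
    high_freq_wavelen = OLD_CONTEXT_LEN / HIGH_FREQ_FACTOR
    wavelen = 2 * jnp.pi / freqs
    # Interpolation weight for the middle band (0 -> fully scaled, 1 -> unchanged).
    smooth = (OLD_CONTEXT_LEN / wavelen - LOW_FREQ_FACTOR) / (HIGH_FREQ_FACTOR - LOW_FREQ_FACTOR)
    blended = (1 - smooth) * freqs / SCALE_FACTOR + smooth * freqs
    return jnp.where(
        wavelen < high_freq_wavelen,
        freqs,  # short wavelengths (high frequencies): unchanged
        jnp.where(wavelen > low_freq_wavelen, freqs / SCALE_FACTOR, blended),
    )


head_dim, theta = 128, 500_000.0  # example values
freqs = 1.0 / theta ** (jnp.arange(0, head_dim, 2) / head_dim)
scaled = llama31_scale_freqs(freqs)
```

High-frequency components, which encode local order, are left alone, while low-frequency components are slowed down to cover a longer context.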

maxtext.layers.embeddings.partial_rotary_embedding_as_linen(*, min_timescale, max_timescale, mesh, embedding_dims=0, partial_rotary_factor=0.25, cast_as_fprop_dtype=True, fprop_dtype=<class 'jax.numpy.bfloat16'>, shard_mode=ShardMode.AUTO, name=None)

Initializes the PartialRotaryEmbedding module and returns it as a Linen module.

Parameters:
  • min_timescale (int) – Start of the geometric index. Determines the periodicity of the added signal.

  • max_timescale (int) – End of the geometric index. Determines the frequency of the added signal.

  • embedding_dims (int) – Dimension of the embedding to be generated.

  • partial_rotary_factor (float) – Fraction of the embedding dimensions to which RoPE is applied.

  • cast_as_fprop_dtype (bool) – Whether to cast the output to the fprop dtype.

  • fprop_dtype (dtype) – The dtype of the output.

  • name (str | None) – Name of the Linen module.

  • mesh (Mesh) – The device mesh used for sharding.

  • shard_mode (ShardMode)

class maxtext.layers.embeddings.PartialRotaryEmbedding(*args, **kwargs)

Bases: RotaryEmbedding

Rotary Position Embedding applied to a partial fraction of dimensions.
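A sketch of partial RoPE that produces a [Rotated, Unrotated] layout, assuming the first int(d * partial_rotary_factor) dimensions are rotated with frequencies computed over that rotary width and the rest pass through unchanged. partial_rope is a hypothetical name; MaxText's exact split may differ.

```python
import jax
import jax.numpy as jnp


def partial_rope(inputs, positions, partial_rotary_factor=0.25, max_timescale=10_000.0):
    """inputs: (seq_len, d); only the leading rotary_dim features are rotated."""
    d = inputs.shape[-1]
    rotary_dim = int(d * partial_rotary_factor)
    rotated, passthrough = inputs[..., :rotary_dim], inputs[..., rotary_dim:]
    # Frequencies over the rotary width only (assumption).
    fraction = 2 * jnp.arange(rotary_dim // 2) / rotary_dim
    timescale = max_timescale ** fraction
    angles = positions[:, None] / timescale
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = jnp.split(rotated, 2, axis=-1)
    rotated = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return jnp.concatenate([rotated, passthrough], axis=-1)  # [Rotated, Unrotated]


x = jax.random.normal(jax.random.PRNGKey(0), (3, 16))
out = partial_rope(x, jnp.arange(3.0))  # rotates features 0:4, keeps 4:16
```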

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.layers.embeddings.Gemma4PartialRotaryEmbedding(*args, **kwargs)

Bases: RotaryEmbedding

Gemma 4 Rotary Position Embedding applied to a partial fraction of dimensions.

Unlike the standard PartialRotaryEmbedding, which physically splits and concatenates features (yielding a [Rotated, Unrotated] layout), Gemma 4 computes frequencies using the full embedding dimension as the denominator and pads the unrotated timescales with infinity.

Because x / inf = 0 (so cos = 1 and sin = 0), applying RoPE acts as the identity on the unrotated dimensions. Because the base RotaryEmbedding class splits the full tensor in half, this produces an interleaved feature layout in memory: [Rotated Half 1, Unrotated Half 1, Rotated Half 2, Unrotated Half 2].
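A sketch of the inf-padding trick, assuming half * partial_rotary_factor rotated frequency pairs. gemma4_style_timescale and apply_half_split_rope are hypothetical names, and the exact rotated count in Gemma 4 may be computed differently.

```python
import jax
import jax.numpy as jnp


def gemma4_style_timescale(embedding_dims, partial_rotary_factor, max_timescale=10_000.0):
    """Inf-padded timescale; finite entries use the full embedding_dims denominator."""
    half = embedding_dims // 2
    num_rotated = int(half * partial_rotary_factor)  # hypothetical count of rotated pairs
    fraction = 2 * jnp.arange(num_rotated) / embedding_dims
    rotated = max_timescale ** fraction
    padding = jnp.full((half - num_rotated,), jnp.inf)  # position / inf == 0 -> identity
    return jnp.concatenate([rotated, padding])


def apply_half_split_rope(inputs, positions, timescale):
    """Base-class style rotation: dimension i is paired with dimension i + d/2."""
    angles = positions[:, None] / timescale  # padded entries give angle 0
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = jnp.split(inputs, 2, axis=-1)
    return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)


x = jax.random.normal(jax.random.PRNGKey(0), (3, 16))
timescale = gemma4_style_timescale(16, partial_rotary_factor=0.5)
out = apply_half_split_rope(x, jnp.arange(1.0, 4.0), timescale)
# Layout: [rotated 0:4 | unrotated 4:8 | rotated 8:12 | unrotated 12:16]
```

The unrotated features keep their values bit-for-bit, because cos(0) = 1 and sin(0) = 0 turn the rotation into x1 * 1 - x2 * 0.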

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

property timescale: Array

The inf-padded timescale for Gemma 4 rotary embedding.

class maxtext.layers.embeddings.LLaMARotaryEmbedding(*args, **kwargs)

Bases: RotaryEmbedding

LLaMA variant of RoPE.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

property timescale

Returns the timescale for the rotary embedding.

maxtext.layers.embeddings.yarn_rotary_embedding_as_linen(*, embedding_dims, mesh, max_position_embeddings=16384, original_max_position_embeddings=4096, beta_fast=32, beta_slow=1, rope_theta=10000.0, rope_factor=40, cast_as_fprop_dtype=True, fprop_dtype=<class 'jax.numpy.bfloat16'>, name=None, interleave=True, truncate=True, attention_scaling=False, shard_mode=ShardMode.AUTO)

Initializes the YarnRotaryEmbedding module and returns it as a Linen module.

Parameters:
  • embedding_dims (int) – The dimension of the embeddings.

  • max_position_embeddings (int) – The maximum number of positions.

  • original_max_position_embeddings (int) – The original maximum number of positions.

  • beta_fast (float) – The fast beta parameter for YaRN.

  • beta_slow (float) – The slow beta parameter for YaRN.

  • rope_theta (float) – The base for the rotary frequencies.

  • rope_factor (float) – The scaling factor for RoPE.

  • cast_as_fprop_dtype (bool) – Whether to cast the output to fprop_dtype.

  • fprop_dtype (dtype) – The forward pass dtype.

  • name (str | None) – The name of the module.

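  • mesh (Mesh) – The device mesh used for sharding.

  • interleave (bool) – Whether the complex representation is interleaved or concatenated.

  • truncate (bool) – Whether to floor the lower bound and ceil the upper bound of the correction range.

  • attention_scaling (bool) – Whether to scale the rotary embedding output.

  • shard_mode (ShardMode)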

class maxtext.layers.embeddings.YarnRotaryEmbedding(*args, **kwargs)

Bases: Module

YaRN rotary embedding.

Based on https://arxiv.org/abs/2309.00071. This implementation uses the DeepSeek-V3 PyTorch code (deepseek-ai/DeepSeek-V3) as a reference.

Implementation notes:

  • YaRN vs. standard RoPE:

    1. Frequency initialization: YaRN modifies how frequencies are computed.

    2. Attention scaling: YaRN typically scales embeddings by 0.1 * ln(rope_factor) + 1.0 when rope_factor > 1. This scaling can be applied within this layer (if attention_scaling=True) or externally.

  • RoPE implementation details (general):

    • Arithmetic: Uses complex-number arithmetic. Real-number arithmetic is not implemented here, though the resulting embeddings would be equivalent.

    • Input layout: Supports both interleaved (interleave=True, e.g., [real1, img1, real2, img2]) and concatenated (interleave=False, e.g., [real1, real2, img1, img2]) formats.

    • Output layout: Always returns the concatenated format ([real, imag]). Interleaved output is not implemented; although the embedding differs, attention scores are unchanged as long as Q and K use the same output layout.
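A sketch of the YaRN frequency computation and attention scale, assuming the correction-range and linear-ramp recipe of the DeepSeek-V3 reference cited above, with the defaults from yarn_rotary_embedding_as_linen. find_correction_dim, yarn_freqs, and yarn_mscale are illustrative names; complex arithmetic and the interleaved layout are left out.

```python
import math

import jax.numpy as jnp


def find_correction_dim(num_rotations, dim, base, max_seq_len):
    """Dimension index whose wavelength completes num_rotations over max_seq_len."""
    return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


def yarn_freqs(dim, base=10000.0, factor=40.0, original_max_len=4096,
               beta_fast=32, beta_slow=1, truncate=True):
    freqs = 1.0 / base ** (jnp.arange(0, dim, 2) / dim)  # standard RoPE frequencies
    low = find_correction_dim(beta_fast, dim, base, original_max_len)
    high = find_correction_dim(beta_slow, dim, base, original_max_len)
    if truncate:  # floor the lower bound, ceil the upper bound
        low, high = math.floor(low), math.ceil(high)
    low, high = max(low, 0), min(high, dim - 1)
    if low == high:
        high += 0.001  # avoid division by zero in the ramp
    ramp = jnp.clip((jnp.arange(dim // 2) - low) / (high - low), 0.0, 1.0)
    smooth = 1.0 - ramp  # 1 -> keep original frequency, 0 -> divide by factor
    return freqs / factor * (1.0 - smooth) + freqs * smooth


def yarn_mscale(factor):
    """Attention scaling: 0.1 * ln(factor) + 1.0 when factor > 1."""
    return 0.1 * math.log(factor) + 1.0 if factor > 1 else 1.0


freqs = yarn_freqs(64)
```

Dimensions below the correction range keep their original frequency, those above it are divided by rope_factor, and the ramp blends linearly in between.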

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

embedding_dims

Dimension of the embedding to be generated.

max_position_embeddings

The maximum sequence length that will be encountered.

original_max_position_embeddings

The sequence length for which the base frequencies were defined.

beta_fast

Lower-bound parameter for the correction range.

beta_slow

Upper-bound parameter for the correction range.

rope_theta

The base theta value for the frequency computation.

rope_factor

Factor applied to adjust the frequencies.

cast_as_fprop_dtype

Whether to cast the output to fprop_dtype.

fprop_dtype

The forward pass dtype.

rope_interleave

Whether the complex representation is interleaved or concatenated.

rope_truncate

Whether to floor the lower bound and ceil the upper bound of the correction range.

rope_attention_scaling

Whether to scale the rotary embedding output.

rngs

RNG keys passed in by nnx.bridge.to_linen.

property freqs_cis

Frequencies for the rotary embedding.

maxtext.layers.embeddings.positional_embedding_as_linen(*, embedding_dims, max_wavelength=10000, cast_as_fprop_dtype=False, fprop_dtype=<class 'jax.numpy.bfloat16'>)

Initializes the PositionalEmbedding module and returns it as a Linen module.

Parameters:
  • embedding_dims (int) – The dimension of the embeddings.

  • max_wavelength (int) – The maximum wavelength for the sinusoidal positional embeddings.

  • cast_as_fprop_dtype (bool) – Whether to cast output to fprop_dtype.

  • fprop_dtype (dtype) – The dtype of the output when cast_as_fprop_dtype is True.

class maxtext.layers.embeddings.PositionalEmbedding(*args, **kwargs)

Bases: Module

Sinusoidal positional embeddings supporting both uniform and per-batch positions.

This module computes sinusoidal positional embeddings and supports two use cases:

  1. Uniform positions across the batch: all batch elements share the same position sequence. Pass position as a 1D array of shape (seq_len,), or None for sequential positions [0, 1, 2, …]. Returns an array of shape (seq_len, embedding_dims), which the caller broadcasts over the batch. Example:

       pos_emb = layer(seq_len)               # Sequential positions
       pos_emb = layer(seq_len, position_1d)  # Custom 1D positions

  2. Per-batch positions (packed sequences): each batch element has its own positions. Pass position as a 2D array of shape (batch, seq_len). Returns an array of shape (batch, seq_len, embedding_dims). Example:

       pos_emb = layer(seq_len, position_2d)

The uniform case is also more efficient, since sin/cos are computed once and broadcast rather than recomputed for each batch element.
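A sketch of the two call patterns, assuming the common Transformer convention of log-spaced inverse timescales with sin and cos halves concatenated along the feature axis. sinusoidal_embedding is a hypothetical helper, and MaxText's exact feature ordering may differ.

```python
import jax.numpy as jnp


def sinusoidal_embedding(seq_len, embedding_dims, position=None, max_wavelength=10_000):
    """position: None, (seq_len,) or (batch, seq_len) -> (..., embedding_dims)."""
    if position is None:
        position = jnp.arange(seq_len)  # sequential positions [0, 1, 2, ...]
    position = position.astype(jnp.float32)
    num_timescales = embedding_dims // 2
    log_increment = jnp.log(float(max_wavelength)) / max(num_timescales - 1, 1)
    inv_timescales = jnp.exp(-log_increment * jnp.arange(num_timescales))
    scaled = position[..., None] * inv_timescales  # works for 1D and 2D positions
    return jnp.concatenate([jnp.sin(scaled), jnp.cos(scaled)], axis=-1)


uniform = sinusoidal_embedding(5, 8)              # (5, 8); broadcast over the batch later
packed_positions = jnp.array([[0, 1, 2], [0, 1, 0]])
packed = sinusoidal_embedding(3, 8, packed_positions)  # (2, 3, 8)
```

In the packed case, tokens that restart at position 0 inside a row receive the same embedding as the first token of any sequence.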

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

embedding_dims: int

The dimension of the embeddings.

max_wavelength: int = 10000

The maximum wavelength for the sinusoidal positional embeddings.

cast_as_fprop_dtype: bool = False

Whether to cast output to fprop_dtype.

fprop_dtype

The dtype of the output when cast_as_fprop_dtype is True.

alias of bfloat16

rngs: Rngs = None

RNG state passed in by nnx.bridge.to_linen, not used in this module.

maxtext.layers.embeddings.llama_vision_rotary_embedding_as_linen(*, image_size, patch_size, hidden_size, num_attention_heads, rope_theta=10000.0, cast_as_fprop_dtype=True, fprop_dtype=<class 'jax.numpy.bfloat16'>, name=None)

Initializes the LlamaVisionRotaryEmbedding module and returns it as a Linen module.

Parameters:
  • image_size (int) – The size of the input image.

  • patch_size (int) – The size of the image patches.

  • hidden_size (int) – The size of the hidden dimension.

  • num_attention_heads (int) – The number of attention heads.

  • rope_theta (float) – The base theta value for the frequency computation.

  • cast_as_fprop_dtype (bool) – Whether to cast the output to the fprop dtype.

  • fprop_dtype (dtype) – The dtype of the output.

  • name (str | None) – The name of the Linen module.

class maxtext.layers.embeddings.LlamaVisionRotaryEmbedding(*args, **kwargs)

Bases: Module

Rotary position embedding for the Llama4 vision encoder.

Based on the PyTorch reference implementation in huggingface/transformers. This implementation follows the Llama4 vision encoder’s rotary embedding approach, which uses 2D (x, y) patch coordinates to generate rotary position embeddings.
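A simplified sketch of how 2D (x, y) patch coordinates can drive rotary angles: half of the angles come from the patch column and half from the patch row. vision_2d_rope_angles and the even head_dim // 4 split per axis are assumptions for illustration; the sketch omits the CLS-token handling and exact frequency layout of the reference implementation.

```python
import jax.numpy as jnp


def vision_2d_rope_angles(image_size, patch_size, hidden_size, num_attention_heads,
                          rope_theta=10_000.0):
    """Simplified 2D rotary angles for a square grid of image patches."""
    grid = image_size // patch_size  # patches per side
    idx = jnp.arange(grid * grid)
    x, y = idx % grid, idx // grid   # column and row of each patch
    head_dim = hidden_size // num_attention_heads
    n = head_dim // 4                # frequencies per axis (hypothetical split)
    inv_freq = 1.0 / rope_theta ** (jnp.arange(n) / n)
    angles_x = x[:, None] * inv_freq  # (grid * grid, n)
    angles_y = y[:, None] * inv_freq  # (grid * grid, n)
    return jnp.concatenate([angles_x, angles_y], axis=-1)  # (grid * grid, head_dim // 2)


angles = vision_2d_rope_angles(image_size=64, patch_size=16, hidden_size=64, num_attention_heads=4)
```

Patches in the same column share their x-angles and patches in the same row share their y-angles, so attention scores become sensitive to horizontal and vertical offsets separately.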

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

image_size: int

Size of the input image.

patch_size: int

Size of the image patches.

hidden_size: int

Size of the hidden dimension.

num_attention_heads: int

Number of attention heads.

rope_theta: float = 10000.0

Base theta value for the frequency computation.

cast_as_fprop_dtype: bool = True

Whether to cast the output to the fprop dtype.

fprop_dtype

The dtype of the output.

alias of bfloat16

rngs: Rngs = None

RNG state passed in by nnx.bridge.to_linen, not used in this module.

property freqs_cis

Frequencies for the rotary embedding.

class maxtext.layers.embeddings.Qwen3OmniMoeVisionRotaryEmbedding(*args, **kwargs)

Bases: Module

Rotary position embedding for Qwen3OmniMoe vision encoder.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

hidden_size

Hidden dimension size.

num_attention_heads

Number of attention heads.

spatial_merge_size

Spatial merge block size (e.g., 2 for 2x2 blocks).

rope_theta

Base theta for frequency computation (default 10000.0).

cast_as_fprop_dtype

Whether to cast to the fprop dtype.

fprop_dtype

Output dtype.

rngs

RNG state passed in by nnx.bridge.to_linen, not used in this module.

compute_cos_sin(num_frames, height, width)

Compute cos and sin embeddings for the given static grid dimensions.

Parameters:
  • num_frames (int) – Number of temporal frames

  • height (int) – Height in patches

  • width (int) – Width in patches

Returns:

A tuple (cos_emb, sin_emb), each of shape [num_frames * height * width, head_dim].

Return type:

tuple[Array, Array]
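A sketch of how such cos/sin tables can be built, assuming the Qwen2-VL-style recipe from Hugging Face transformers. In this recipe, (height, width) patch ids are visited in spatial_merge_size x spatial_merge_size block order, half of the angles come from the height id and half from the width id, the angle vector is duplicated to span head_dim, and every frame reuses the same spatial ids. merge_block_positions and compute_cos_sin_sketch are hypothetical names, and MaxText's exact ordering may differ.

```python
import jax.numpy as jnp


def merge_block_positions(height, width, merge):
    """(row, col) ids ordered block by block: (H * W, 2)."""
    hpos = jnp.broadcast_to(jnp.arange(height)[:, None], (height, width))
    wpos = jnp.broadcast_to(jnp.arange(width)[None, :], (height, width))

    def reorder(p):
        # (block_row, inner_row, block_col, inner_col) -> (block_row, block_col, inner_row, inner_col)
        p = p.reshape(height // merge, merge, width // merge, merge)
        return p.transpose(0, 2, 1, 3).reshape(-1)

    return jnp.stack([reorder(hpos), reorder(wpos)], axis=-1)


def compute_cos_sin_sketch(num_frames, height, width, head_dim, merge, theta=10000.0):
    pos = jnp.tile(merge_block_positions(height, width, merge), (num_frames, 1))  # (N, 2)
    dim = head_dim // 2
    inv_freq = 1.0 / theta ** (jnp.arange(0, dim, 2) / dim)  # (head_dim // 4,)
    angles = (pos[..., None] * inv_freq).reshape(pos.shape[0], -1)  # (N, head_dim // 2)
    emb = jnp.concatenate([angles, angles], axis=-1)  # (N, head_dim)
    return jnp.cos(emb), jnp.sin(emb)


cos_emb, sin_emb = compute_cos_sin_sketch(num_frames=2, height=4, width=4, head_dim=16, merge=2)
```

Block ordering keeps the patches that are later merged together adjacent in the sequence.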

maxtext.layers.embeddings.qwen3omnimoe_vision_pos_embed_interpolate_as_linen(*, num_position_embeddings, hidden_size, spatial_merge_size, dtype=<class 'jax.numpy.float32'>, cast_as_fprop_dtype=True, fprop_dtype=<class 'jax.numpy.bfloat16'>, name=None)

Initializes the Qwen3OmniMoe bilinear position embedding interpolation and returns it as a Linen module.

This implements fast bilinear interpolation of learned 2D positional embeddings for dynamic input sizes. The embeddings are learned on a fixed grid and interpolated to match the actual image/video dimensions.

Parameters:
  • num_position_embeddings (int) – Number of position embeddings in the fixed grid (e.g., 1024 for 32x32)

  • hidden_size (int) – Hidden dimension size

  • spatial_merge_size (int) – Size of spatial merging blocks

  • dtype (dtype) – Data type for embeddings

  • cast_as_fprop_dtype (bool) – Whether to cast the output to the fprop dtype

  • fprop_dtype (dtype) – The dtype of the output

  • name (str | None) – Module name

Returns:

A Linen module that wraps the NNX Qwen3OmniMoeVisionPosEmbedInterpolate module.

class maxtext.layers.embeddings.Qwen3OmniMoeVisionPosEmbedInterpolate(*args, **kwargs)

Bases: Module

Bilinear interpolation of learned 2D positional embeddings for Qwen3OmniMoe vision.

This module maintains a fixed grid of learned positional embeddings and interpolates them to match dynamic input dimensions using bilinear interpolation. This allows the model to handle images/videos of varying sizes while using a fixed embedding table.
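A sketch of bilinear interpolation of a square learned grid to an arbitrary (h, w) patch grid, assuming corner-aligned sampling (linspace from 0 to side - 1 on both axes). bilinear_interpolate_pos_embed is a hypothetical name, and the sketch leaves out any reordering into spatial-merge blocks.

```python
import jax.numpy as jnp


def bilinear_interpolate_pos_embed(table, h, w):
    """table: (side * side, hidden) learned grid -> (h * w, hidden)."""
    side = int(round(table.shape[0] ** 0.5))
    grid = table.reshape(side, side, -1)
    ys = jnp.linspace(0.0, side - 1, h)  # corner-aligned sample rows
    xs = jnp.linspace(0.0, side - 1, w)  # corner-aligned sample columns
    y0 = jnp.floor(ys).astype(jnp.int32)
    x0 = jnp.floor(xs).astype(jnp.int32)
    y1 = jnp.minimum(y0 + 1, side - 1)
    x1 = jnp.minimum(x0 + 1, side - 1)
    dy = (ys - y0)[:, None, None]  # fractional offsets, broadcast over columns/features
    dx = (xs - x0)[None, :, None]
    top = grid[y0][:, x0] * (1 - dx) + grid[y0][:, x1] * dx
    bottom = grid[y1][:, x0] * (1 - dx) + grid[y1][:, x1] * dx
    return (top * (1 - dy) + bottom * dy).reshape(h * w, -1)


table = jnp.arange(16 * 3, dtype=jnp.float32).reshape(16, 3)  # 4 x 4 grid, hidden=3
resized = bilinear_interpolate_pos_embed(table, 7, 5)
```

With corner alignment, the four corners of the output always reproduce the four corner embeddings of the learned grid, and requesting the original grid size returns the table unchanged.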

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

num_position_embeddings

Number of position embeddings in the fixed grid.

hidden_size

Hidden dimension size.

spatial_merge_size

Spatial merge block size.

dtype

Data type for embeddings.

cast_as_fprop_dtype

Whether to cast to the fprop dtype.

fprop_dtype

Output dtype.

rngs

RNG state passed in by nnx.bridge.to_linen.

class maxtext.layers.embeddings.Qwen3OmniMoeThinkerTextRotaryEmbedding(*args, **kwargs)

Bases: RotaryEmbedding

Multi-dimensional Rotary Position Embedding (MRoPE) for Qwen3-Omni Thinker.

This implements MRoPE which extends standard RoPE to handle 3D position IDs (temporal, height, width) for multimodal sequences containing text and vision tokens.

For text-only sequences, it uses standard 2D position IDs. For sequences with vision tokens, it uses 3D position IDs where:

  • Dimension 0: Temporal position

  • Dimension 1: Height position (spatial)

  • Dimension 2: Width position (spatial)

The implementation uses an interleaved pattern that reorganizes frequency components from chunked [TTT…HHH…WWW] to interleaved [THWTHW…TT].
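A sketch of the interleaving step, assuming the stride-3 index assignment used in the Hugging Face Qwen3-VL/Qwen3-Omni reference. In that scheme, height angles fill slots 1, 4, 7, … and width angles fill slots 2, 5, 8, … up to 3 * section, and temporal angles fill every remaining slot. interleave_mrope is a hypothetical name, and MaxText may implement this differently.

```python
import jax.numpy as jnp


def interleave_mrope(freqs, mrope_section):
    """Merge per-axis angles into one interleaved frequency layout.

    freqs: (3, ..., head_dim // 2) angles for the temporal, height and width ids.
    mrope_section: (t, h, w) counts with t + h + w == head_dim // 2.
    """
    out = freqs[0]  # start from the temporal angles everywhere
    for axis in (1, 2):  # height at offset 1, width at offset 2, stride 3
        idx = slice(axis, mrope_section[axis] * 3, 3)
        out = out.at[..., idx].set(freqs[axis][..., idx])
    return out


# Tag each axis with a constant (0=T, 1=H, 2=W) so the resulting layout is visible.
tagged = jnp.stack([jnp.full((8,), float(a)) for a in range(3)])
layout = interleave_mrope(tagged, (4, 2, 2))
```

With mrope_section = (4, 2, 2), the eight frequency slots come out as T, H, W, T, H, W, T, T. For text-only inputs, the three position ids coincide, so the result reduces to standard RoPE.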

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

maxtext.layers.embeddings.qwen3_omni_mrope_embedding_as_linen(*, min_timescale, max_timescale, embedding_dims=0, cast_as_fprop_dtype=True, fprop_dtype=<class 'jax.numpy.bfloat16'>, mrope_section=None, name=None)

Initializes Qwen3OmniMoeThinkerTextRotaryEmbedding and returns it as a Linen module.

Parameters:
  • min_timescale (int) – Start of the geometric index.

  • max_timescale (int) – End of the geometric index (rope_theta).

  • embedding_dims (int) – Dimension of the embedding (head_dim).

  • cast_as_fprop_dtype (bool) – Whether to cast output to fprop dtype.

  • fprop_dtype (dtype) – The dtype of the output.

  • mrope_section (tuple[int, int, int] | None) – Tuple of (temporal_dim, height_dim, width_dim) for MRoPE.

  • name (str | None) – Name of the Linen module.