Source code for maxtext.layers.embeddings

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Embedding Layers."""

import dataclasses
import math

import jax
from jax import lax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding

from flax import nnx

from maxtext.common.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType
from maxtext.layers import nnx_wrappers
from maxtext.layers.initializers import Initializer, default_embed_init, variable_to_logically_partitioned
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils.sharding import logical_to_mesh_axes, create_sharding

_MAX_WAVELENGTH = 10_000


def _maybe_move_embedding_to_device(embedding_table: Array, config: Config) -> Array:
  """Returns the embedding table, transferred to device if offloading is on.

  Args:
    embedding_table: The embedding table to (possibly) transfer.
    config: The model configuration; only `parameter_memory_host_offload` is read.

  Returns:
    `embedding_table` unchanged when `config.parameter_memory_host_offload` is
    falsy; otherwise the result of `jax.device_put` onto
    `max_utils.device_space()`.
  """
  if not config.parameter_memory_host_offload:
    return embedding_table

  max_logging.log("embeddings.py: Moving embedding parameter to device")
  return jax.device_put(embedding_table, max_utils.device_space())


def embed_as_linen(
    *,
    num_embeddings: int,
    num_features: int,
    config: Config,
    mesh: Mesh,
    cast_input_dtype: None | DType = None,
    dtype: DType = jnp.float32,
    attend_dtype: None | DType = None,
    embedding_init: Initializer = default_embed_init,
    name: str | None = None,
):
  """Builds the NNX `Embed` module and exposes it as a Linen module.

  This is a bridge for using the NNX-based `Embed` inside a Linen model: the
  constructor arguments are forwarded to `nnx_wrappers.to_linen`, which wraps
  the `Embed` class.

  Args:
    num_embeddings: The number of embeddings.
    num_features: The number of feature dimensions for each embedding.
    config: The model configuration.
    mesh: The device mesh forwarded to `Embed`.
    cast_input_dtype: The dtype to cast the input to, if any.
    dtype: The dtype of the embedding vectors.
    attend_dtype: The dtype for the `attend` method.
    embedding_init: The initializer for the embedding matrix.
    name: The name of the Linen module.

  Returns:
    A Linen module that wraps the NNX `Embed` module.
  """
  embed_kwargs = {
      "num_embeddings": num_embeddings,
      "num_features": num_features,
      "config": config,
      "mesh": mesh,
      "cast_input_dtype": cast_input_dtype,
      "dtype": dtype,
      "attend_dtype": attend_dtype,
      "embedding_init": embedding_init,
  }
  return nnx_wrappers.to_linen(
      Embed,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
      **embed_kwargs,
  )
class Embed(nnx.Module):
  """A parameterized function from integers [0, n) to d-dimensional vectors."""

  def __init__(
      self,
      num_embeddings: int,
      num_features: int,
      config: Config,
      mesh: Mesh,
      cast_input_dtype: None | DType = None,
      dtype: DType = jnp.float32,
      attend_dtype: None | DType = None,
      embedding_init: Initializer = default_embed_init,
      *,
      rngs: nnx.Rngs,
  ):
    """Initializes the Embed module.

    Args:
      num_embeddings: The number of embeddings.
      num_features: The number of feature dimensions for each embedding.
      config: The model configuration.
      mesh: The device mesh used to resolve the output sharding in `__call__`.
      cast_input_dtype: The dtype to cast the input to, if any.
      dtype: The dtype of the embedding vectors.
      attend_dtype: The dtype for the `attend` method.
      embedding_init: The initializer for the embedding matrix.
      rngs: The random number generators; `rngs.params()` seeds the table.
    """
    self.num_embeddings = num_embeddings
    self.num_features = num_features
    self.config = config
    self.mesh = mesh
    self.cast_input_dtype = cast_input_dtype
    self.dtype = dtype
    self.attend_dtype = attend_dtype

    table_shape = (self.num_embeddings, self.num_features)
    initial_table = embedding_init(rngs.params(), table_shape, self.config.weight_dtype)
    self.embedding = nnx.Param(initial_table, sharding=("vocab", "embed_vocab"))

  def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
    """Embeds the inputs along the last dimension.

    Args:
      inputs: input data, all dimensions are considered batch dimensions.
      model_mode: Selects the logical axis name used for the length dimension
        of the output sharding (prefill vs. everything else).

    Returns:
      Output which is embedded input data. The output shape follows the input,
      with an additional `num_features` dimension appended.

    Raises:
      ValueError: If the (possibly cast) inputs are not of an integer dtype.
    """
    if self.cast_input_dtype:
      inputs = inputs.astype(self.cast_input_dtype)
    if not jnp.issubdtype(inputs.dtype, jnp.integer):
      raise ValueError("Input type must be an integer or unsigned integer.")

    table = _maybe_move_embedding_to_device(self.embedding.get_value(), self.config)
    table = jnp.asarray(table, self.dtype)

    # Only the length axis name depends on the model mode.
    if model_mode == MODEL_MODE_PREFILL:
      length_axis = "prefill_activation_length"
    else:
      length_axis = "activation_length"
    output_axis_names = ("activation_embed_and_logits_batch", length_axis, "activation_embed")

    out_pspec = logical_to_mesh_axes(
        output_axis_names,
        self.mesh,
        rules=getattr(self.config, "logical_axis_rules", None),
    )
    # An explicit output sharding is only supplied under explicit shard mode.
    if self.config.shard_mode == ShardMode.EXPLICIT:
      out_sharding = NamedSharding(self.mesh, out_pspec)
    else:
      out_sharding = None

    if not self.config.use_iota_embed:
      return table.at[inputs].get(out_sharding=out_sharding)

    # One-hot lookup: compare ids against an iota and matmul with the table.
    vocab_ids = lax.iota(jnp.int32, self.num_embeddings)
    one_hot = jnp.array(inputs[..., jnp.newaxis] == vocab_ids, dtype=self.dtype)
    return jnp.dot(one_hot, table, out_sharding=out_sharding)

  def attend(self, query: Array, out_sharding: NamedSharding | None = None) -> Array:
    """Attend over the embedding using a query array.

    Args:
      query: array with last dimension equal the feature depth `num_features`
        of the embedding.
      out_sharding: NamedSharding object indicating how the output gets sharded.

    Returns:
      An array with final dim `num_embeddings` corresponding to the batched
      inner-product of the array of query vectors against each embedding.
      Commonly used for weight-sharing between embeddings and logit transform
      in NLP models.
    """
    # Fall back to the embedding dtype when no dedicated attend dtype was given.
    if self.attend_dtype is not None:
      attend_dtype = self.attend_dtype
    else:
      attend_dtype = self.dtype
    return attend_on_embedding(
        query,
        self.embedding.get_value(),
        attend_dtype,
        self.config,
        out_sharding,
    )
def attend_on_embedding(
    query: Array,
    embedding_table: Array,
    attend_dtype: DType,
    config: Config,
    out_sharding: NamedSharding | None = None,
) -> Array:
  """Attend over an embedding table using a query array.

  TODO: Remove this method when Embed bridge to Linen is no longer needed

  Args:
    query: An array with a last dimension equal to the feature depth of the
      embedding.
    embedding_table: The embedding table to attend over.
    attend_dtype: Passed to `jnp.dot` as `preferred_element_type`.
    config: The model configuration; read for `shard_mode` and for parameter
      offloading.
    out_sharding: NamedSharding object indicating the output sharding. Ignored
      unless `config.shard_mode` is `ShardMode.EXPLICIT`.

  Returns:
    An array with a final dimension equal to `num_embeddings`, corresponding to
    the batched inner-product of the query vectors against each embedding.
  """
  # out_sharding must be None under auto shard_mode
  effective_sharding = out_sharding if config.shard_mode == ShardMode.EXPLICIT else None

  device_table = _maybe_move_embedding_to_device(embedding_table, config)
  # The table is cast to bfloat16 and transposed before the product.
  transposed_table = jnp.asarray(device_table, jnp.bfloat16).T
  return jnp.dot(
      query,
      transposed_table,
      preferred_element_type=attend_dtype,
      out_sharding=effective_sharding,
  )
def rotary_embedding_as_linen(
    *,
    min_timescale: int,
    max_timescale: int,
    embedding_dims: int = 0,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    name: str | None = None,
    mesh: Mesh | None = None,
    shard_mode: ShardMode = ShardMode.AUTO,
):
  """Initializes the RotaryEmbedding module and returns it as a Linen module.

  `RotaryEmbedding.__init__` takes a required `mesh` argument, so it is
  forwarded here; without it the wrapped module cannot be constructed.

  Args:
    min_timescale: Start of the geometric index. Determines the periodicity of
      the added signal.
    max_timescale: End of the geometric index. Determines the frequency of the
      added signal.
    embedding_dims: Dimension of the embedding to be generated.
    cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
    fprop_dtype: The dtype of the output.
    name: Name of the Linen module.
    mesh: The device mesh forwarded to `RotaryEmbedding`. Defaults to None to
      keep existing call sites valid.
    shard_mode: The shard mode forwarded to `RotaryEmbedding`.

  Returns:
    A Linen module that wraps the NNX `RotaryEmbedding` module.
  """
  return nnx_wrappers.to_linen(
      RotaryEmbedding,
      min_timescale=min_timescale,
      max_timescale=max_timescale,
      mesh=mesh,
      embedding_dims=embedding_dims,
      cast_as_fprop_dtype=cast_as_fprop_dtype,
      fprop_dtype=fprop_dtype,
      shard_mode=shard_mode,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
  )
class RotaryEmbedding(nnx.Module):
  """Rotary Position Embedding."""

  def __init__(
      self,
      min_timescale: int,
      max_timescale: int,
      mesh: Mesh,
      embedding_dims: int = 0,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      shard_mode: ShardMode = ShardMode.AUTO,
      rope_linear_scaling_factor: float = 1.0,
      # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen.
      # TODO: Remove when bridge no longer needed
      rngs: nnx.Rngs = None,
  ):
    """Initializes the RotaryEmbedding module.

    Args:
      min_timescale: Start of the geometric index. Determines the periodicity of
        the added signal.
      max_timescale: End of the geometric index. Determines the frequency of the
        added signal.
      mesh: The device mesh; stored on the module.
      embedding_dims: Dimension of the embedding to be generated.
      cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
      fprop_dtype: The dtype of the output.
      shard_mode: The shard mode; stored on the module.
      rope_linear_scaling_factor: Multiplier applied to the timescale when it
        differs from 1.0.
      rngs: rng keys passed in by nnx.bridge.to_linen.

    Raises:
      ValueError: If `embedding_dims` is odd.
    """
    self.min_timescale = min_timescale
    self.max_timescale = max_timescale
    self.mesh = mesh
    self.embedding_dims = embedding_dims
    self.cast_as_fprop_dtype = cast_as_fprop_dtype
    self.fprop_dtype = fprop_dtype
    self.shard_mode = shard_mode
    self.rope_linear_scaling_factor = rope_linear_scaling_factor

    if self.embedding_dims % 2:
      raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.")

  @property
  def timescale(self):
    """Returns the timescale for the rotary embedding."""
    half_embedding_dim = self.embedding_dims // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
    timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
    if self.rope_linear_scaling_factor != 1.0:
      timescale = timescale * self.rope_linear_scaling_factor
    return timescale

  def _rotate_half(self, x: jax.Array) -> jax.Array:
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)

  def apply_rotary(self, inputs: jax.Array, cos: jax.Array, sin: jax.Array) -> jax.Array:
    """Applies the rotary transformation logic."""
    return (inputs * cos) + (self._rotate_half(inputs) * sin)

  def __call__(
      self,  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
      inputs: jax.Array,
      position: None | jax.Array = None,
  ) -> jax.Array:
    """Applies the rotary position embedding to the inputs.

    Args:
      inputs: The input sequence on which to apply the Rotary position
        embedding. Since rotary position embeddings are applied to query and
        keys after projection, it is assumed of shape [B, S, N, H].
      position: Position jax.Array which denotes the position of each token in
        the sequence. It is of shape [B, S]. Although the parameter defaults to
        None for signature compatibility, a value must be supplied.

    Returns:
      a jax.Array of shape [B, S, N, H] which includes the inputs together with
      the rotary position embedding incorporated in it.

    Raises:
      ValueError: If `position` is None, if `inputs` is not rank 4, or if its
        last dimension differs from `embedding_dims`.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if position is None:
      raise ValueError("position must be provided for rotary position embedding.")
    if len(inputs.shape) != 4:
      raise ValueError("Input is assumed to be a rank 4 tensor of shape [batch, sequence, heads, dims].")
    if self.embedding_dims != inputs.shape[3]:
      raise ValueError(
          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
      )

    # [B, S] -> [B, S, 1, 1] so it broadcasts against the per-dimension timescale.
    position = position[:, :, jnp.newaxis, jnp.newaxis]
    sinusoid_inp = position / self.timescale
    sin_half = jnp.sin(sinusoid_inp).astype(inputs.dtype)
    cos_half = jnp.cos(sinusoid_inp).astype(inputs.dtype)
    # Duplicate the half-width tables so they line up with the two halves
    # consumed by `_rotate_half`.
    sin = jnp.concatenate([sin_half, sin_half], axis=-1)
    cos = jnp.concatenate([cos_half, cos_half], axis=-1)

    x_out = self.apply_rotary(inputs, cos, sin)

    if self.cast_as_fprop_dtype:
      x_out = x_out.astype(self.fprop_dtype)
    return x_out
def llama_rotary_embedding_as_linen(
    *,
    min_timescale: int,
    max_timescale: int,
    embedding_dims: int = 0,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    use_scale: bool = True,
    name: str | None = None,
    mesh: Mesh | None = None,
    shard_mode: ShardMode = ShardMode.AUTO,
):
  """Initializes the LLaMARotaryEmbedding module and returns it as a Linen module.

  `LLaMARotaryEmbedding.__init__` takes a required `mesh` argument, so it is
  forwarded here; without it the wrapped module cannot be constructed.

  Args:
    min_timescale: Start of the geometric index. Determines the periodicity of
      the added signal.
    max_timescale: End of the geometric index. Determines the frequency of the
      added signal.
    embedding_dims: Dimension of the embedding to be generated.
    cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
    fprop_dtype: The dtype of the output.
    use_scale: Whether to apply LLaMA3.1 scaling factor.
    name: Name of the Linen module.
    mesh: The device mesh forwarded to `LLaMARotaryEmbedding`. Defaults to None
      to keep existing call sites valid.
    shard_mode: The shard mode forwarded to `LLaMARotaryEmbedding`.

  Returns:
    A Linen module that wraps the NNX `LLaMARotaryEmbedding` module.
  """
  return nnx_wrappers.to_linen(
      LLaMARotaryEmbedding,
      min_timescale=min_timescale,
      max_timescale=max_timescale,
      mesh=mesh,
      embedding_dims=embedding_dims,
      cast_as_fprop_dtype=cast_as_fprop_dtype,
      fprop_dtype=fprop_dtype,
      use_scale=use_scale,
      shard_mode=shard_mode,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
  )
def partial_rotary_embedding_as_linen(
    *,
    min_timescale: int,
    max_timescale: int,
    mesh: Mesh,
    embedding_dims: int = 0,
    partial_rotary_factor: float = 0.25,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    shard_mode: ShardMode = ShardMode.AUTO,
    name: str | None = None,
):
  """Builds the PartialRotaryEmbedding module and exposes it as a Linen module.

  Args:
    min_timescale: Start of the geometric index. Determines the periodicity of
      the added signal.
    max_timescale: End of the geometric index. Determines the frequency of the
      added signal.
    mesh: The device mesh forwarded to `PartialRotaryEmbedding`.
    embedding_dims: Dimension of the embedding to be generated.
    partial_rotary_factor: Ratio of dimensions to apply ROPE to.
    cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
    fprop_dtype: The dtype of the output.
    shard_mode: The shard mode forwarded to `PartialRotaryEmbedding`.
    name: Name of the Linen module.

  Returns:
    The result of `nnx_wrappers.to_linen` for `PartialRotaryEmbedding`.
  """
  module_kwargs = {
      "min_timescale": min_timescale,
      "max_timescale": max_timescale,
      "mesh": mesh,
      "embedding_dims": embedding_dims,
      "partial_rotary_factor": partial_rotary_factor,
      "cast_as_fprop_dtype": cast_as_fprop_dtype,
      "fprop_dtype": fprop_dtype,
      "shard_mode": shard_mode,
  }
  return nnx_wrappers.to_linen(
      PartialRotaryEmbedding,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
      **module_kwargs,
  )
class PartialRotaryEmbedding(RotaryEmbedding):
  """Rotary Position Embedding applied to a partial fraction of dimensions."""

  def __init__(
      self,
      min_timescale: int,
      max_timescale: int,
      mesh: Mesh,
      embedding_dims: int = 0,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      partial_rotary_factor: float = 0.25,
      shard_mode: ShardMode = ShardMode.AUTO,
      rngs: nnx.Rngs = None,
  ):
    """Initializes the PartialRotaryEmbedding module.

    Args:
      min_timescale: Start of the geometric index. Determines the periodicity of
        the added signal.
      max_timescale: End of the geometric index. Determines the frequency of the
        added signal.
      mesh: The device mesh forwarded to the base class.
      embedding_dims: Full head dimension of the inputs; only the leading
        `int(embedding_dims * partial_rotary_factor)` features are rotated.
      cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
      fprop_dtype: The dtype of the output.
      partial_rotary_factor: Ratio of dimensions to apply ROPE to.
      shard_mode: The shard mode forwarded to the base class.
      rngs: rng keys passed in by nnx.bridge.to_linen.
    """
    self.head_dim = embedding_dims
    self.partial_rotary_factor = partial_rotary_factor
    self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)

    # The base class is configured for the rotated slice only, not the full head.
    base_kwargs = {
        "min_timescale": min_timescale,
        "max_timescale": max_timescale,
        "mesh": mesh,
        "embedding_dims": self.rotary_dim,
        "cast_as_fprop_dtype": cast_as_fprop_dtype,
        "fprop_dtype": fprop_dtype,
        "shard_mode": shard_mode,
        "rngs": rngs,
    }
    super().__init__(**base_kwargs)

  def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.Array:
    """Applies Partial variant of rotary position embedding.

    Args:
      inputs: The input sequence on which to apply the Rotary position
        embedding. It is assumed of shape [B, S, H, D].
      position: Position array [B, S], forwarded unchanged to the base
        `RotaryEmbedding.__call__`.

    Returns:
      A jax.Array with the same last-dimension size as `inputs`: the first
      `rotary_dim` features carry the rotary embedding, the remaining features
      are passed through and concatenated back.
    """
    # Rotate the leading `rotary_dim` features and leave the tail untouched.
    rotated_part, passthrough_part = jnp.split(inputs, [self.rotary_dim], axis=-1)
    rotated_part = super().__call__(rotated_part, position)
    return jnp.concatenate([rotated_part, passthrough_part], axis=-1)
class Gemma4PartialRotaryEmbedding(RotaryEmbedding):
  """Gemma 4 Rotary Position Embedding applied to a partial fraction of dimensions.

  Unlike standard PartialRotaryEmbedding which physically splits and
  concatenates features (resulting in a [Rotated, Unrotated] layout), Gemma 4
  computes frequencies using the full embedding dimension denominator and pads
  the unrotated timescales with infinity. Because x / inf = 0, applying RoPE
  mathematically acts as an identity function on those unrotated dimensions.

  Because the base Rotary class splits the full tensor in half, this creates an
  interleaved feature layout in memory:
  [Rotated Half 1, Unrotated Half 1, Rotated Half 2, Unrotated Half 2].
  """

  def __init__(
      self,
      min_timescale: int,
      max_timescale: int,
      mesh: Mesh,
      embedding_dims: int = 0,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      partial_rotary_factor: float = 0.25,
      shard_mode: ShardMode = ShardMode.AUTO,
      rngs: nnx.Rngs = None,
  ):
    """Initializes the instance.

    Args:
      min_timescale: Start of the geometric index.
      max_timescale: End of the geometric index.
      mesh: The device mesh forwarded to the base class.
      embedding_dims: Full head dimension; also passed to the base class.
      cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
      fprop_dtype: The dtype of the output.
      partial_rotary_factor: Ratio of dimensions to apply ROPE to.
      shard_mode: The shard mode forwarded to the base class.
      rngs: rng keys passed in by nnx.bridge.to_linen.
    """
    self.head_dim = embedding_dims
    self.partial_rotary_factor = partial_rotary_factor
    self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)

    # Pass the full head_dim to the base class so it splits at head_dim / 2,
    # ensuring the unrotated dimensions get correctly mixed into the center.
    base_kwargs = {
        "min_timescale": min_timescale,
        "max_timescale": max_timescale,
        "mesh": mesh,
        "embedding_dims": self.head_dim,
        "cast_as_fprop_dtype": cast_as_fprop_dtype,
        "fprop_dtype": fprop_dtype,
        "shard_mode": shard_mode,
        "rngs": rngs,
    }
    super().__init__(**base_kwargs)

  @property
  def timescale(self) -> jax.Array:
    """The inf-padded timescale for Gemma 4 rotary embedding."""
    num_rotated_angles = self.rotary_dim // 2

    # Gemma 4 uniquely uses the full head_dim as the denominator
    exponents = 2 * jnp.arange(0, num_rotated_angles) / self.head_dim
    rotated_timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** exponents

    if getattr(self, "rope_linear_scaling_factor", 1.0) != 1.0:
      rotated_timescale = rotated_timescale * self.rope_linear_scaling_factor

    # Pad the remaining angles with jnp.inf.
    # When position is divided by inf, the angle becomes 0.
    # sin(0)=0 and cos(0)=1, which acts as a passthrough for unrotated dims.
    num_unrotated_angles = (self.head_dim // 2) - num_rotated_angles
    return jnp.pad(
        rotated_timescale,
        pad_width=(0, num_unrotated_angles),
        mode="constant",
        constant_values=(0.0, jnp.inf),
    )

  # Note: No __call__ override is required. The base RotaryEmbedding.__call__
  # handles the rotation perfectly using the padded self.timescale.
class LLaMARotaryEmbedding(RotaryEmbedding):
  """LLaMA variant of ROPE."""

  def __init__(
      self,
      min_timescale: int,
      max_timescale: int,
      mesh: Mesh,
      embedding_dims: int = 0,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      use_scale: bool = True,
      shard_mode: ShardMode = ShardMode.AUTO,
      # Not used in LLaMARotaryEmbedding but passed in by nnx.bridge.to_linen.
      # TODO: Remove when bridge no longer needed
      rngs: nnx.Rngs = None,
  ):
    """Initializes the LLaMARotaryEmbedding module.

    Args:
      min_timescale: Start of the geometric index. Determines the periodicity of
        the added signal.
      max_timescale: End of the geometric index. Determines the frequency of the
        added signal.
      mesh: The device mesh forwarded to the base class.
      embedding_dims: Dimension of the embedding to be generated.
      cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
      fprop_dtype: The dtype of the output.
      use_scale: Whether to apply LLaMA3.1 scaling factor.
      shard_mode: The shard mode forwarded to the base class.
      rngs: rng keys passed in by nnx.bridge.to_linen.
    """
    base_kwargs = {
        "min_timescale": min_timescale,
        "max_timescale": max_timescale,
        "mesh": mesh,
        "embedding_dims": embedding_dims,
        "cast_as_fprop_dtype": cast_as_fprop_dtype,
        "fprop_dtype": fprop_dtype,
        "shard_mode": shard_mode,
        "rngs": rngs,
    }
    super().__init__(**base_kwargs)

    # LLaMA3.1 ROPE scaling, see the original pytorch implementation:
    # https://github.com/meta-llama/llama-models/blob/301ca3a2b3b10e94ddcd1fdd2c57e52f812e1cac/models/llama3/reference_impl/model.py#L45C5-L45C18
    self.use_scale = use_scale

  @property
  def timescale(self):
    """Per-feature timescale, shaped [1, 1, 1, embedding_dims] for broadcasting."""
    num_pairs = self.embedding_dims // 2
    # Each exponent is repeated so adjacent feature pairs share a timescale.
    exponents = jnp.repeat(2 * jnp.arange(0, num_pairs) / self.embedding_dims, 2)
    timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** exponents

    # Apply scaling factor if enabled
    if self.use_scale:
      scaled_freqs = jax.vmap(self._apply_scaling_factor)(1.0 / timescale)
      timescale = 1.0 / scaled_freqs

    # Expand timescale dimensions for broadcasting
    return timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]

  def _apply_scaling_factor(self, freq):
    """Apply the LLaMA3.1 scaling factor to one rotary frequency."""
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * jnp.pi / freq

    def keep_freq(f):
      # Short wavelengths are left as they are.
      return f

    def fully_scaled(f):
      # Long wavelengths are divided by the scale factor.
      return f / scale_factor

    def smoothly_scaled(f):
      # In between, interpolate between the scaled and unscaled frequency.
      smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
      return (1 - smooth) * f / scale_factor + smooth * f

    def scaled_or_smoothed(f):
      return jax.lax.cond(wavelen > low_freq_wavelen, fully_scaled, smoothly_scaled, f)

    return jax.lax.cond(wavelen < high_freq_wavelen, keep_freq, scaled_or_smoothed, freq)

  def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.Array:
    """Applies LLaMA variant of rotary position embedding.

    Args:
      inputs: The input sequence on which to apply the Rotary position
        embedding. It is assumed of shape [B, S, N, H].
      position: Optional position array [B, S]. Only needed when the sequence
        is packed; when omitted, positions 0..S-1 are used.

    Returns:
      A jax.Array of shape [B, S, N, H] with rotary position embeddings applied.

    Raises:
      ValueError: If `inputs` is not rank 4 or its last dimension differs from
        `embedding_dims`.
    """
    # Ensure input is 4D
    if len(inputs.shape) != 4:
      raise ValueError("Input is assumed to be a rank 4 tensor of shape [B, S, N, H].")
    if self.embedding_dims != inputs.shape[3]:
      raise ValueError(
          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
      )

    # Shift the inputs left and right as per LLaMA's specific behavior
    shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
    shifted_right = jnp.concatenate([inputs[..., -1:], inputs[..., :-1]], axis=-1)
    # Odd feature indices take the right-shifted value, even ones the left-shifted.
    feature_parity = jnp.mod(jnp.arange(self.embedding_dims, dtype=jnp.int32), 2)
    parity_mask = jnp.tile(feature_parity, inputs.shape[:-1] + (1,))
    inputs_shifted = jax.lax.select(parity_mask, shifted_right, shifted_left)

    # Determine positions if not provided
    if position is None:
      position = jnp.arange(inputs.shape[1], dtype=jnp.float32)[jnp.newaxis, :]

    # Calculate sinusoidal input
    sinusoid_inp = position[:, :, jnp.newaxis, jnp.newaxis] / self.timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)

    # Apply alternating sign
    sign = jnp.tile(jnp.array([-1, 1]), self.embedding_dims // 2)

    # Combine original inputs with sinusoidal information
    outputs = inputs * cos + inputs_shifted * sin * sign

    if not self.cast_as_fprop_dtype:
      return outputs
    return outputs.astype(self.fprop_dtype)
def yarn_rotary_embedding_as_linen(
    *,
    embedding_dims: int,
    mesh: Mesh,
    max_position_embeddings: int = 4096 * 4,
    original_max_position_embeddings: int = 4096,
    beta_fast: float = 32,
    beta_slow: float = 1,
    rope_theta: float = 10000.0,
    rope_factor: float = 40,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    name: str | None = None,
    interleave: bool = True,
    truncate: bool = True,
    attention_scaling: bool = False,
    shard_mode: ShardMode = ShardMode.AUTO,
):
  """Builds the YarnRotaryEmbedding module and exposes it as a Linen module.

  Args:
    embedding_dims: The dimension of the embeddings.
    mesh: The device mesh forwarded to `YarnRotaryEmbedding`.
    max_position_embeddings: The maximum number of positions.
    original_max_position_embeddings: The original maximum number of positions.
    beta_fast: The fast beta parameter for YaRN.
    beta_slow: The slow beta parameter for YaRN.
    rope_theta: The base for the rotary frequencies.
    rope_factor: The scaling factor for RoPE.
    cast_as_fprop_dtype: Whether to cast the output to `fprop_dtype`.
    fprop_dtype: The forward pass dtype.
    name: The name of the module.
    interleave: Forwarded to `YarnRotaryEmbedding` (input layout selector).
    truncate: Forwarded to `YarnRotaryEmbedding` (floor/ceil the correction range).
    attention_scaling: Forwarded to `YarnRotaryEmbedding` (scale the output).
    shard_mode: The shard mode forwarded to `YarnRotaryEmbedding`.

  Returns:
    The result of `nnx_wrappers.to_linen` for `YarnRotaryEmbedding`.
  """
  module_kwargs = {
      "embedding_dims": embedding_dims,
      "max_position_embeddings": max_position_embeddings,
      "mesh": mesh,
      "original_max_position_embeddings": original_max_position_embeddings,
      "beta_fast": beta_fast,
      "beta_slow": beta_slow,
      "rope_theta": rope_theta,
      "rope_factor": rope_factor,
      "cast_as_fprop_dtype": cast_as_fprop_dtype,
      "fprop_dtype": fprop_dtype,
      "interleave": interleave,
      "truncate": truncate,
      "attention_scaling": attention_scaling,
      "shard_mode": shard_mode,
  }
  return nnx_wrappers.to_linen(
      YarnRotaryEmbedding,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
      **module_kwargs,
  )
class YarnRotaryEmbedding(nnx.Module):
  """Yarn rotary embedding.

  Based on https://arxiv.org/abs/2309.00071
  This implementation uses DeepSeek-v3 PyTorch as reference
  https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/model.py#L294

  Implementation Notes:
  - YaRN vs. Standard RoPE:
    1. Frequency Initialization: YaRN modifies how frequencies are computed.
    2. Attention Scaling: YaRN typically scales embeddings by
       `0.1 * ln(rope_factor) + 1.0` when `rope_factor > 1`. This scaling can
       be applied within this layer (if `attention_scaling=True`) or externally.
  - RoPE Implementation Details (General):
    - Arithmetic: Uses complex number arithmetic. Real number arithmetic is not
      implemented here, though the resulting embeddings would be equivalent.
    - Input Layout: Supports both interleaved (`interleave=True`, e.g.,
      [real1, img1, real2, img2]) and concatenated (`interleave=False`, e.g.,
      [real1, real2, img1, img2]) formats.
    - Output Layout: Always returns concatenated format ([real, imag]).
      Interleaved output is not implemented: While the embedding is different,
      attention scores are invariant, as long as we apply the same output
      layout for Q and K.

  Attributes:
    embedding_dims: Dimension of the embedding to be generated.
    max_position_embeddings: The maximum sequence length that will be encountered.
    original_max_position_embeddings: The sequence length for which the base
      frequencies were defined.
    beta_fast: Lower bound parameter for correction.
    beta_slow: Upper bound parameter for correction.
    rope_theta: The base theta value for the frequency computation.
    rope_factor: Factor applied to adjust the frequencies.
    cast_as_fprop_dtype: Whether to cast the output to `fprop_dtype`.
    fprop_dtype: The forward pass dtype.
    interleave: Whether complex representation is interleaved or concatenated.
    truncate: Whether or not to floor lower bound and ceil upper bound for
      correction range.
    attention_scaling: Whether or not to scale the rotary embedding output.
    mesh: The device mesh used to build shardings under explicit shard mode.
    shard_mode: The shard mode; explicit mode enables the output shardings.
    freqs_sharding: Sharding for the looked-up frequencies, or None outside
      explicit shard mode.
  """

  def __init__(
      self,
      embedding_dims: int,
      mesh: Mesh,
      max_position_embeddings: int = 4096 * 4,
      original_max_position_embeddings: int = 4096,
      beta_fast: float = 32,
      beta_slow: float = 1,
      rope_theta: float = 10000.0,
      rope_factor: float = 40,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      shard_mode: ShardMode = ShardMode.AUTO,
      interleave=True,
      truncate=True,
      attention_scaling=False,
      # Not used in YarnRotaryEmbedding but passed in by nnx.bridge.to_linen.
      # TODO: Remove when bridge no longer needed
      rngs: nnx.Rngs = None,
  ):
    """Initializes the YarnRotaryEmbedding module.

    Raises:
      ValueError: If `embedding_dims` is odd.
    """
    self.embedding_dims = embedding_dims
    self.max_position_embeddings = max_position_embeddings
    self.original_max_position_embeddings = original_max_position_embeddings
    self.beta_fast = beta_fast
    self.beta_slow = beta_slow
    self.rope_theta = rope_theta
    self.rope_factor = rope_factor
    self.cast_as_fprop_dtype = cast_as_fprop_dtype
    self.fprop_dtype = fprop_dtype
    self.interleave = interleave
    self.truncate = truncate
    self.mesh = mesh
    self.shard_mode = shard_mode
    self.attention_scaling = attention_scaling

    # A sharding for the gathered frequencies is only built in explicit mode.
    if shard_mode == ShardMode.EXPLICIT:
      self.freqs_sharding = create_sharding(mesh, ("activation_batch", "activation_length", "q_heads"))
    else:
      self.freqs_sharding = None

    if self.embedding_dims % 2:
      raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.")

  @property
  def freqs_cis(self):
    """Frequencies for rotary embedding."""
    half_dim = self.embedding_dims // 2

    # Compute base frequencies for each (even-indexed) dimension.
    # (Note: We use jnp.arange with float32 for precision.)
    exponents = 2.0 * jnp.arange(0, half_dim, dtype=jnp.float32) / self.embedding_dims
    base_freqs = 1.0 / (self.rope_theta**exponents)

    low, high = self._find_correction_range(
        self.beta_fast,
        self.beta_slow,
        self.embedding_dims,
        self.rope_theta,
        self.original_max_position_embeddings,
        self.truncate,
    )
    smooth = 1 - self._linear_ramp_factor(low, high, half_dim)
    # The corrected frequency is a weighted mix of the scaled and base values.
    corrected_freqs = base_freqs / self.rope_factor * (1 - smooth) + base_freqs * smooth

    # Precompute frequencies for all positions by taking the outer product.
    # This gives a [max_position_embeddings, half_dim] tensor with rows as time steps.
    steps = jnp.arange(self.max_position_embeddings, dtype=jnp.float32)
    angles = jnp.outer(steps, corrected_freqs)

    # Compute the complex "cis" values: exp(i * theta).
    return jnp.exp(1j * angles)  # shape [max_position_embeddings, half_dim]

  def _find_correction_dim(self, num_rotations: float, dim: int, base: float, max_position_embeddings: int) -> float:
    """Compute the correction dimension for a given number of rotations."""
    log_term = math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    return dim * log_term / (2 * math.log(base))

  def _find_correction_range(
      self,
      low_rot: float,
      high_rot: float,
      dim: int,
      base: float,
      max_position_embeddings: int,
      truncate: bool,
  ):
    """Computes the range of correction dimensions for rotary positional embeddings.

    Args:
      low_rot (float): Lower bound for the number of rotations.
      high_rot (float): Upper bound for the number of rotations.
      dim (int): Dimensionality of the embedding space.
      base (float): Base value for the exponential computation.
      max_position_embeddings (int): Maximum sequence length.
      truncate (bool): Whether to floor lower bound and ceil upper bound.

    Returns:
      The range of correction dimensions (low, high), with `low` clamped to be
      at least 0 and `high` clamped to be at most `dim - 1`. The values are
      integers when `truncate` is set.
    """
    low, high = (
        self._find_correction_dim(rotations, dim, base, max_position_embeddings) for rotations in (low_rot, high_rot)
    )
    if truncate:
      low, high = math.floor(low), math.ceil(high)
    return max(low, 0), min(high, dim - 1)

  def _linear_ramp_factor(self, min_val: float, max_val: float, dim: int) -> Array:
    """Computes a linear ramp over the dimension.

    Returns a jax.Array of shape (dim,) with values between 0 and 1.
    """
    # Nudge the upper bound when the range is empty to avoid division by zero.
    upper = max_val + 0.001 if min_val == max_val else max_val
    ramp = (jnp.arange(dim, dtype=jnp.float32) - min_val) / (upper - min_val)
    return jnp.clip(ramp, 0, 1)

  def __call__(self, inputs: Array, position: None | Array = None) -> Array:
    """Applies the rotary positional embedding using the precomputed complex frequencies.

    Args:
      inputs: jax.Array of shape [B, S, N, H]. (H must equal self.embedding_dims.)
      position: jax.Array of shape [B, S] with positions (cast to int32 and
        used as indexes into the precomputed freqs). When omitted, positions
        0..S-1 are used.

    Returns:
      jax.Array of shape [B, S, N, H] with the rotary embedding applied.

    Raises:
      ValueError: If `inputs` is not rank 4 or its last dimension differs from
        `embedding_dims`.
    """
    if len(inputs.shape) != 4:
      raise ValueError("Input is assumed to be a rank 4 tensor of shape [batch, sequence, heads, dims].")
    if self.embedding_dims != inputs.shape[3]:
      raise ValueError(
          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
      )

    # Determine positions if not provided
    if position is not None:
      position = position.astype(jnp.int32)
    else:
      position = jnp.arange(inputs.shape[1], dtype=jnp.int32)[jnp.newaxis, :]

    # Lookup the precomputed frequencies using the position indices.
    # self.freqs_cis has shape [max_position_embeddings, half_dim]; indexing by
    # position gives [B, S, half_dim], then an axis is added for the heads.
    freqs = self.freqs_cis.at[position].get(out_sharding=self.freqs_sharding)  # shape: [B, S, half_dim]
    freqs = freqs[:, :, jnp.newaxis, :]  # shape: [B, S, 1, half_dim]

    if not self.interleave:
      # Inputs with concatenated format [real1, real2, ..., img1, img2, ...] at last dimension
      real_part, imag_part = jnp.split(inputs, 2, axis=-1)
    else:
      # Inputs with interleaved format [real1, img1, real2, img2, ...] at last dimension
      # Reshape so that each trailing pair holds the real and imaginary parts.
      batch, seq, heads, feat = inputs.shape
      paired = inputs.reshape(batch, seq, heads, feat // 2, 2)
      real_part, imag_part = paired[..., 0], paired[..., 1]

    inputs_complex = real_part + 1j * imag_part  # shape: [B, S, N, half_dim]

    # Apply the rotary transformation via complex multiplication.
    if self.shard_mode == ShardMode.EXPLICIT:
      rotated_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length", None, None))
    else:
      rotated_sharding = None
    freqs = jnp.broadcast_to(freqs, inputs_complex.shape, out_sharding=rotated_sharding)
    rotated = jnp.multiply(inputs_complex, freqs)  # shape: [B, S, N, half_dim]

    # Convert the complex result back to a real tensor.
    # Split the complex number into its real and imaginary parts.
    # [real1, real2, ..., img1, img2, ...]
    output = jnp.concatenate([jnp.real(rotated), jnp.imag(rotated)], axis=-1)

    if self.attention_scaling:
      if self.rope_factor <= 1:
        scale = 1.0
      else:
        scale = 0.1 * math.log(self.rope_factor) + 1.0
      output = output * scale

    if self.cast_as_fprop_dtype:
      output = output.astype(self.fprop_dtype)
    return output
def positional_embedding_as_linen(
    *,
    embedding_dims: int,
    max_wavelength: int = _MAX_WAVELENGTH,
    cast_as_fprop_dtype: bool = False,
    fprop_dtype: DType = jnp.bfloat16,
    name: str | None = None,
):
  """Initializes the PositionalEmbedding module and returns it as a Linen module.

  Args:
    embedding_dims: The dimension of the embeddings.
    max_wavelength: The maximum wavelength for the sinusoidal positional embeddings.
    cast_as_fprop_dtype: Whether to cast output to fprop_dtype.
    fprop_dtype: The dtype of the output when cast_as_fprop_dtype is True.
    name: The name of the Linen module. Defaults to None, matching the other
      `*_as_linen` helpers in this file.

  Returns:
    A Linen module that wraps the NNX `PositionalEmbedding` module.
  """
  return nnx_wrappers.to_linen(
      PositionalEmbedding,
      embedding_dims=embedding_dims,
      max_wavelength=max_wavelength,
      cast_as_fprop_dtype=cast_as_fprop_dtype,
      fprop_dtype=fprop_dtype,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
  )
@dataclasses.dataclass(repr=False)
class PositionalEmbedding(nnx.Module):
  """Sinusoidal positional embeddings supporting both uniform and per-batch positions.

  The embeddings are computed from the positions on every call; this class
  creates no parameters.

  Two use cases are supported:

  1. Uniform positions across batch: All batch elements share the same position sequence.
     Pass position as 1D array (seq_len,) or None for sequential [0,1,2,...].
     Returns (seq_len, features), caller broadcasts to batch.
     Example:
       pos_emb = layer(seq_len)               # Sequential positions
       pos_emb = layer(seq_len, position_1d)  # Custom 1D positions

  2. Per-batch positions (packed sequences): Each batch element has different positions.
     Pass position as 2D array (batch, seq_len).
     Returns (batch, seq_len, features).
     Example:
       pos_emb = layer(seq_len, position_2d)

  `features` is `2 * (embedding_dims // 2)`: it equals `embedding_dims` when
  that is even and `embedding_dims - 1` when it is odd.

  As a side effect, the uniform case is more efficient since sin/cos are computed
  once and broadcasted, rather than per batch element.
  """

  #: The dimension of the embeddings.
  embedding_dims: int
  #: The maximum wavelength for the sinusoidal positional embeddings.
  max_wavelength: int = _MAX_WAVELENGTH
  #: Whether to cast output to fprop_dtype.
  cast_as_fprop_dtype: bool = False
  #: The dtype of the output when cast_as_fprop_dtype is True.
  fprop_dtype: DType = jnp.bfloat16
  #: RNG state passed in by nnx.bridge.to_linen, not used in this module.
  rngs: nnx.Rngs = None  # Not used in PositionalEmbedding but passed in by nnx.bridge.to_linen

  def _compute_embeddings(self, position: Array) -> Array:
    """Compute sinusoidal embeddings for given positions.

    Args:
      position: Array of positions, typically (seq_len,) or (batch, seq_len).
        Any rank is accepted; a trailing feature axis is appended to whatever
        shape is given.

    Returns:
      Array of shape `position.shape + (2 * (embedding_dims // 2),)` holding
      the sines followed by the cosines along the last axis. The dtype is
      fprop_dtype when cast_as_fprop_dtype is set, float32 otherwise.
    """
    num_timescales = self.embedding_dims // 2
    # The denominator is clamped to 1 so a single timescale does not divide by zero.
    log_timescale_increment = jnp.log(float(self.max_wavelength)) / jnp.maximum(
        jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1
    )
    inv_timescales = jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)

    # Appending the feature axis with an ellipsis gives the same result as the
    # explicit 1D and 2D indexing for those ranks, and broadcasts correctly
    # for any other rank instead of failing or mis-aligning axes.
    scaled_time = position[..., jnp.newaxis] * inv_timescales

    signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1)
    out_dtype = self.fprop_dtype if self.cast_as_fprop_dtype else jnp.float32
    return signal.astype(out_dtype)

  def __call__(
      self,
      seq_len: int,
      position: Array | None = None,
  ) -> Array:
    """Compute positional embeddings.

    Args:
      seq_len: Sequence length. Only read when `position` is None, to build
        the sequential positions [0, 1, ..., seq_len - 1].
      position: Optional position array. If None, uses sequential [0,1,2,...].
        Shape can be (seq_len,) or (batch, seq_len) for packed sequences.

    Returns:
      Positional embeddings whose leading axes match `position` (or
      (seq_len,) when it is None), followed by a feature axis of size
      `2 * (embedding_dims // 2)`.
    """
    if position is None:
      position = jnp.arange(seq_len, dtype=jnp.float32)

    return self._compute_embeddings(position)
def llama_vision_rotary_embedding_as_linen(
    *,
    image_size: int,
    patch_size: int,
    hidden_size: int,
    num_attention_heads: int,
    rope_theta: float = 10000.0,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    name: str | None = None,
):
  """Builds a LlamaVisionRotaryEmbedding and exposes it as a Linen module.

  Args:
    image_size: The size of the input image.
    patch_size: The size of the image patches.
    hidden_size: The size of the hidden dimension.
    num_attention_heads: The number of attention heads.
    rope_theta: The base theta value for the frequency computation.
    cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
    fprop_dtype: The dtype of the output.
    name: The name of the Linen module.

  Returns:
    The result of `nnx_wrappers.to_linen` applied to `LlamaVisionRotaryEmbedding`.
  """
  # Constructor arguments forwarded verbatim to the NNX module.
  module_kwargs = {
      "image_size": image_size,
      "patch_size": patch_size,
      "hidden_size": hidden_size,
      "num_attention_heads": num_attention_heads,
      "rope_theta": rope_theta,
      "cast_as_fprop_dtype": cast_as_fprop_dtype,
      "fprop_dtype": fprop_dtype,
  }
  return nnx_wrappers.to_linen(
      LlamaVisionRotaryEmbedding,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
      **module_kwargs,
  )
@dataclasses.dataclass(repr=False)
class LlamaVisionRotaryEmbedding(nnx.Module):
  """Rotary position embedding for Llama4 vision encoder.

  Based on Pytorch Reference
  https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py

  Each patch is assigned an (x, y) coordinate on the patch grid and the rotary
  angles are derived from those two coordinates. One extra entry is appended
  for the CLS token and receives a zero angle (i.e. no rotation).
  """

  #: size of the input image
  image_size: int
  #: size of the image patches
  patch_size: int
  #: size of the hidden dimension
  hidden_size: int
  #: number of attention heads
  num_attention_heads: int
  #: base theta value for the frequency computation
  rope_theta: float = 10000.0
  #: whether to cast the output to the fprop dtype
  cast_as_fprop_dtype: bool = True
  #: the dtype of the output
  fprop_dtype: DType = jnp.bfloat16
  # Not used in LlamaVisionRotaryEmbedding but passed in by nnx.bridge.to_linen.
  # TODO: Remove when bridge no longer needed
  #: RNG state passed in by nnx.bridge.to_linen, not used in this module
  rngs: nnx.Rngs = None

  @property
  def freqs_cis(self):
    """Complex rotation factors exp(i * angle) for every patch plus the CLS slot.

    Recomputed on each access. The result has shape
    [num_patches + 1, 1, 2 * (freq_dim // 2)] where
    freq_dim = hidden_size // num_attention_heads // 2; the first half of the
    last axis carries the x-coordinate angles and the second half the
    y-coordinate angles.
    """
    grid = self.image_size // self.patch_size
    num_patches = grid**2

    patch_ids = jnp.arange(num_patches, dtype=jnp.int32).reshape(num_patches, 1)
    # Append one more row for the CLS token and tag it with a negative id so it
    # can be masked out below.
    patch_ids = jnp.concatenate([patch_ids, patch_ids[:1]], axis=0)
    patch_ids = patch_ids.at[-1, -1].set(-2)  # ID_CLS_TOKEN

    # Grid coordinates of each patch.
    x_coords = patch_ids % grid
    y_coords = patch_ids // grid

    freq_dim = self.hidden_size // self.num_attention_heads // 2
    exponents = jnp.arange(0, freq_dim, 2)[: (freq_dim // 2)].astype(jnp.float32) / freq_dim
    rope_freq = 1.0 / (self.rope_theta**exponents)

    def _axis_angles(coords):
      # Coordinates are shifted by one before scaling; each angle is then
      # duplicated along the last axis.
      angles = (coords + 1)[..., None] * rope_freq[None, None, :]
      return jnp.repeat(angles, 2, axis=-1)

    freqs = jnp.concatenate([_axis_angles(x_coords), _axis_angles(y_coords)], axis=-1).astype(jnp.float32)
    # Keeping every other entry drops the duplicates introduced above.
    freqs = freqs[..., ::2]

    # Entries with a negative id (the CLS slot) get a zero angle.
    freqs = jnp.where(patch_ids.reshape(-1, 1, 1) < 0, 0, freqs)

    return jnp.exp(1j * freqs)

  def __call__(self, inputs: Array, position: None | Array = None) -> Array:
    """Applies rotary embeddings to the input tensor for Llama4 vision encoder.

    Args:
      inputs: Input tensor of shape [batch_size_times_tiles, num_patches_incl_cls, num_heads, head_dim].
        Adjacent pairs along the last axis are treated as (real, imaginary).
      position: Not read by this method.

    Returns:
      Tensor with rotary embeddings applied, maintaining the same shape as input.

    Raises:
      ValueError: If `inputs` is not rank 4.
    """
    if len(inputs.shape) != 4:
      raise ValueError(
          """Input is assumed to be a rank 4 tensor of shape [batch_size_times_tiles, num_patches_incl_cls, num_heads, head_dim]."""
      )

    batch, seq, heads, head_dim = inputs.shape

    # View the last axis as (head_dim // 2) pairs and turn each pair into one complex number.
    pairs = inputs.reshape(batch, seq, heads, head_dim // 2, 2)
    as_complex = pairs[..., 0] + 1j * pairs[..., 1]

    # Leading axis of size one lets the factors broadcast over the batch.
    rotation = jnp.expand_dims(self.freqs_cis, axis=0)
    rotated = as_complex * rotation

    # Back to interleaved (real, imaginary) pairs, then to the input layout.
    re_im = jnp.stack([jnp.real(rotated), jnp.imag(rotated)], axis=-1)
    output = re_im.reshape(batch, seq, heads, head_dim)

    if self.cast_as_fprop_dtype:
      output = output.astype(self.fprop_dtype)

    return output
class Qwen3OmniMoeVisionRotaryEmbedding(nnx.Module):
  """Rotary position embedding for Qwen3OmniMoe vision encoder.

  Attributes:
    hidden_size: Hidden dimension size
    num_attention_heads: Number of attention heads
    spatial_merge_size: Spatial merge block size (e.g., 2 for 2x2 blocks)
    rope_theta: Base theta for frequency computation (default 10000.0)
    cast_as_fprop_dtype: Whether to cast to fprop dtype
    fprop_dtype: Output dtype
    rngs: RNG state passed in by nnx.bridge.to_linen, not used in this module
    head_dim: hidden_size // num_attention_heads
  """

  def __init__(
      self,
      hidden_size: int,
      num_attention_heads: int,
      spatial_merge_size: int,
      rope_theta: float = 10000.0,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      rngs: nnx.Rngs = None,
  ):
    """Stores the configuration and derives the per-head dimension.

    Args:
      hidden_size: Hidden dimension size
      num_attention_heads: Number of attention heads
      spatial_merge_size: Spatial merge block size (e.g., 2 for 2x2 blocks)
      rope_theta: Base theta for frequency computation (default 10000.0)
      cast_as_fprop_dtype: Whether to cast to fprop dtype
      fprop_dtype: Output dtype
      rngs: RNG state passed in by nnx.bridge.to_linen, not used in this module
    """
    self.hidden_size = hidden_size
    self.num_attention_heads = num_attention_heads
    self.spatial_merge_size = spatial_merge_size
    self.rope_theta = rope_theta
    self.cast_as_fprop_dtype = cast_as_fprop_dtype
    self.fprop_dtype = fprop_dtype
    self.rngs = rngs
    self.head_dim = self.hidden_size // self.num_attention_heads

  def _compute_freq_table(self, max_hw: int) -> Array:
    """Builds the angle table for positions 0 .. max_hw - 1.

    Args:
      max_hw: Maximum height or width dimension

    Returns:
      Array of shape [max_hw, head_dim//4]; row p holds p * inv_freq.
    """
    rotary_dim = self.head_dim // 2
    exponents = jnp.arange(0, rotary_dim, 2, dtype=jnp.float32) / rotary_dim
    inv_freq = 1.0 / (self.rope_theta**exponents)

    positions = jnp.arange(max_hw, dtype=jnp.float32)
    return jnp.outer(positions, inv_freq)  # [max_hw, head_dim//4]

  def _generate_position_ids_single(self, num_frames: int, height: int, width: int) -> Array:
    """Generate 2D position IDs for a single image or video.

    Positions are emitted block by block: all patches of one
    merge_size x merge_size block come before the next block's patches.

    Args:
      num_frames: Number of temporal frames (1 for images, >1 for videos)
      height: Height in patches
      width: Width in patches

    Returns:
      Array of shape [num_frames * height * width, 2] with (row_id, col_id)
      when height and width are multiples of spatial_merge_size.
    """
    merge = self.spatial_merge_size
    blocks_h = height // merge
    blocks_w = width // merge
    grid_shape = (blocks_h, blocks_w, merge, merge)

    # Entry [b, i] of arange(blocks * merge).reshape(blocks, merge) is
    # b * merge + i, i.e. block index times block size plus intra-block offset.
    rows = jnp.arange(blocks_h * merge).reshape(blocks_h, 1, merge, 1)
    cols = jnp.arange(blocks_w * merge).reshape(1, blocks_w, 1, merge)

    flat_rows = jnp.broadcast_to(rows, grid_shape).reshape(-1)
    flat_cols = jnp.broadcast_to(cols, grid_shape).reshape(-1)
    coords = jnp.stack([flat_rows, flat_cols], axis=-1)  # [h*w, 2]

    # Every frame of a video reuses the same spatial coordinates.
    if num_frames > 1:
      coords = jnp.tile(coords, (num_frames, 1))

    return coords

  def compute_cos_sin(self, num_frames: int, height: int, width: int) -> tuple[Array, Array]:
    """Compute cos and sin embeddings for given static grid dimensions.

    Args:
      num_frames: Number of temporal frames
      height: Height in patches
      width: Width in patches

    Returns:
      Tuple of (cos_emb, sin_emb) each of shape [num_frames * height * width, head_dim]
    """
    table = self._compute_freq_table(max(height, width))  # [max_hw, head_dim//4]
    coords = self._generate_position_ids_single(num_frames, height, width)  # [T*H*W, 2]

    row_angles = table[coords[:, 0]]  # [T*H*W, head_dim//4]
    col_angles = table[coords[:, 1]]  # [T*H*W, head_dim//4]

    # (row, col) fills half of head_dim; the same pair is repeated for the other half.
    angles = jnp.concatenate([row_angles, col_angles, row_angles, col_angles], axis=-1)  # [T*H*W, head_dim]

    cos_emb = jnp.cos(angles)
    sin_emb = jnp.sin(angles)

    if self.cast_as_fprop_dtype:
      cos_emb = cos_emb.astype(self.fprop_dtype)
      sin_emb = sin_emb.astype(self.fprop_dtype)

    return cos_emb, sin_emb

  def _rotate_half(self, x: Array) -> Array:
    """Rotates half the hidden dims of the input.

    Args:
      x: Input tensor of any shape with last dimension divisible by 2

    Returns:
      Rotated tensor where (x1, x2) -> (-x2, x1)
    """
    split_at = x.shape[-1] // 2
    return jnp.concatenate([-x[..., split_at:], x[..., :split_at]], axis=-1)

  def __call__(self, inputs: Array, num_frames: int, height: int, width: int) -> Array:
    """Apply rotary position embeddings directly to inputs (Q or K tensors).

    Args:
      inputs: Input tensor of shape [B, T*H*W, N, head_dim] (batch, sequence, heads, head_dim)
              where T=num_frames, H=height, W=width (all static). A rank-3
              [T*H*W, N, head_dim] input is also handled; for any other rank the
              [T*H*W, head_dim] tables are broadcast against `inputs` as they are.
      num_frames: Number of temporal frames (static)
      height: Height in patches (static)
      width: Width in patches (static)

    Returns:
      Rotated inputs with same shape [B, T*H*W, N, head_dim]
    """
    cos_emb, sin_emb = self.compute_cos_sin(num_frames, height, width)

    rank = len(inputs.shape)
    if rank == 4:
      # [S, H] -> [1, S, 1, H]
      cos_emb = jnp.expand_dims(cos_emb, axis=(0, 2))
      sin_emb = jnp.expand_dims(sin_emb, axis=(0, 2))
    elif rank == 3:
      # [S, H] -> [S, 1, H]
      cos_emb = jnp.expand_dims(cos_emb, axis=1)
      sin_emb = jnp.expand_dims(sin_emb, axis=1)

    return inputs * cos_emb + self._rotate_half(inputs) * sin_emb
def qwen3omnimoe_vision_pos_embed_interpolate_as_linen(
    *,
    num_position_embeddings: int,
    hidden_size: int,
    spatial_merge_size: int,
    dtype: DType = jnp.float32,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    name: str | None = None,
):
  """Builds the Qwen3OmniMoe position-embedding interpolation module as a Linen module.

  The wrapped module holds a table of learned 2D positional embeddings on a
  fixed grid and bilinearly interpolates it to the requested image/video
  dimensions.

  Args:
    num_position_embeddings: Number of position embeddings in the fixed grid (e.g., 1024 for 32x32)
    hidden_size: Hidden dimension size
    spatial_merge_size: Size of spatial merging blocks
    dtype: Data type for embeddings
    cast_as_fprop_dtype: Whether to cast the output to the fprop dtype
    fprop_dtype: The dtype of the output
    name: Module name

  Returns:
    A Linen module that wraps the NNX Qwen3OmniMoeVisionPosEmbedInterpolate module.
  """
  # Constructor arguments forwarded verbatim to the NNX module.
  module_kwargs = {
      "num_position_embeddings": num_position_embeddings,
      "hidden_size": hidden_size,
      "spatial_merge_size": spatial_merge_size,
      "dtype": dtype,
      "cast_as_fprop_dtype": cast_as_fprop_dtype,
      "fprop_dtype": fprop_dtype,
  }
  return nnx_wrappers.to_linen(
      Qwen3OmniMoeVisionPosEmbedInterpolate,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
      **module_kwargs,
  )
class Qwen3OmniMoeVisionPosEmbedInterpolate(nnx.Module):
  """Bilinear interpolation of learned 2D positional embeddings for Qwen3OmniMoe vision.

  A fixed num_grid_per_side x num_grid_per_side table of learned embeddings is
  resampled with bilinear interpolation to whatever height/width is requested,
  so one table serves images and videos of different sizes.

  Attributes:
    num_position_embeddings: Number of position embeddings in the fixed grid
    hidden_size: Hidden dimension size
    spatial_merge_size: Spatial merge block size
    dtype: Data type for embeddings
    cast_as_fprop_dtype: Whether to cast to fprop dtype
    fprop_dtype: Output dtype
    rngs: RNG state passed in by nnx.bridge.to_linen
    pos_embed: Learned table of shape [num_position_embeddings, hidden_size];
      only created when `rngs` is not None.
    num_grid_per_side: int(num_position_embeddings ** 0.5)
  """

  def __init__(
      self,
      num_position_embeddings: int,
      hidden_size: int,
      spatial_merge_size: int,
      dtype: DType = jnp.float32,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      rngs: nnx.Rngs = None,
  ):
    """Stores the configuration and, when RNGs are supplied, creates the embedding table.

    Args:
      num_position_embeddings: Number of position embeddings in the fixed grid
      hidden_size: Hidden dimension size
      spatial_merge_size: Spatial merge block size
      dtype: Data type for embeddings
      cast_as_fprop_dtype: Whether to cast to fprop dtype
      fprop_dtype: Output dtype
      rngs: RNG state passed in by nnx.bridge.to_linen
    """
    self.num_position_embeddings = num_position_embeddings
    self.hidden_size = hidden_size
    self.spatial_merge_size = spatial_merge_size
    self.dtype = dtype
    self.cast_as_fprop_dtype = cast_as_fprop_dtype
    self.fprop_dtype = fprop_dtype
    self.rngs = rngs

    # The table is drawn from a normal distribution with stddev hidden_size^(-0.5).
    # Without RNGs no table is created and `pos_embed` stays undefined.
    if self.rngs is not None:
      table_shape = (self.num_position_embeddings, self.hidden_size)
      normal_init = nnx.initializers.normal(stddev=self.hidden_size**-0.5)
      self.pos_embed = nnx.Param(normal_init(self.rngs.params(), table_shape, self.dtype))

    self.num_grid_per_side = int(self.num_position_embeddings**0.5)

  def _interpolate_single(self, t: int, h: int, w: int) -> tuple[Array, Array]:
    """Compute bilinear interpolation indices and weights for a single image/video.

    Args:
      t: Number of temporal frames (not read by this method)
      h: Target height in patches
      w: Target width in patches

    Returns:
      Tuple of (indices, weights) where:
        - indices: [4, h*w] indices into pos_embed for 4 corners
        - weights: [4, h*w] bilinear weights for 4 corners
      Corners are ordered (floor_h, floor_w), (floor_h, ceil_w),
      (ceil_h, floor_w), (ceil_h, ceil_w).
    """
    side = self.num_grid_per_side

    def _axis_samples(count):
      # Sample `count` evenly spaced coordinates across the learned grid and
      # split each into its lower cell, upper cell (clamped to the last cell)
      # and fractional offset.
      coords = jnp.linspace(0, side - 1, count)
      lower = jnp.floor(coords).astype(jnp.int32)
      upper = jnp.minimum(lower + 1, side - 1)
      return lower, upper, coords - lower

    h_lo, h_hi, frac_h = _axis_samples(h)
    w_lo, w_hi, frac_w = _axis_samples(w)

    # Row offsets into the flattened [side * side] table.
    row_offsets = (h_lo * side, h_hi * side)
    col_offsets = (w_lo, w_hi)
    indices = jnp.stack(
        [(row[:, None] + col[None, :]).reshape(-1) for row in row_offsets for col in col_offsets],
        axis=0,
    )  # [4, h*w]

    row_weights = (1 - frac_h, frac_h)
    col_weights = (1 - frac_w, frac_w)
    weights = jnp.stack(
        [(rw[:, None] * cw[None, :]).reshape(-1) for rw in row_weights for cw in col_weights],
        axis=0,
    )  # [4, h*w]

    return indices, weights

  def __call__(self, num_frames: int, height: int, width: int) -> Array:
    """Interpolate positional embeddings for given static grid dimensions.

    Args:
      num_frames: Number of temporal frames (static)
      height: Height in patches (static)
      width: Width in patches (static)

    Returns:
      Interpolated positional embeddings of shape [num_frames * height * width, hidden_size],
      ordered so that the patches of each merge_size x merge_size block are contiguous.
    """
    indices, weights = self._interpolate_single(num_frames, height, width)  # [4, h*w], [4, h*w]

    # Gather the four corner embeddings, weight them and sum over the corners.
    corners = self.pos_embed.value[indices]  # [4, h*w, hidden_size]
    blended = jnp.sum(corners * weights[..., None], axis=0)  # [h*w, hidden_size]

    # Every frame reuses the same spatial embeddings.
    if num_frames > 1:
      blended = jnp.tile(blended, (num_frames, 1))  # [t*h*w, hidden_size]

    merge = self.spatial_merge_size
    blocks_h = height // merge
    blocks_w = width // merge

    # [t*h*w, hidden] -> [t, h, w, hidden]
    blended = blended.reshape(num_frames, height, width, self.hidden_size)
    # Split each spatial axis into (block, offset-in-block).
    blended = blended.reshape(num_frames, blocks_h, merge, blocks_w, merge, self.hidden_size)
    # Bring the two block axes together: [t, blocks_h, blocks_w, merge, merge, hidden]
    blended = jnp.transpose(blended, (0, 1, 3, 2, 4, 5))
    # Flatten back to [t*blocks_h*blocks_w*merge*merge, hidden]
    blended = blended.reshape(-1, self.hidden_size)

    if self.cast_as_fprop_dtype:
      blended = blended.astype(self.fprop_dtype)

    return blended
class Qwen3OmniMoeThinkerTextRotaryEmbedding(RotaryEmbedding):
  """Multi-dimensional Rotary Position Embedding (MRoPE) for Qwen3-Omni Thinker.

  MRoPE extends standard RoPE to 3D position IDs (temporal, height, width) for
  multimodal sequences containing text and vision tokens.

  Text-only sequences pass 2D position IDs, which are replicated across the
  three axes. Sequences with vision tokens pass 3D position IDs where:
  - Dimension 0: Temporal position
  - Dimension 1: Height position (spatial)
  - Dimension 2: Width position (spatial)

  The implementation uses an interleaved pattern that reorganizes frequency
  components from chunked [TTT...HHH...WWW] to interleaved [THTHWHTHW...].

  NOTE(review): `self.timescale` and `self.apply_rotary` are expected to come
  from `RotaryEmbedding`, which is defined outside this block.
  """

  def __init__(
      self,
      min_timescale: int,
      max_timescale: int,
      embedding_dims: int = 0,
      cast_as_fprop_dtype: bool = True,
      fprop_dtype: DType = jnp.bfloat16,
      mrope_section: tuple[int, int, int] | None = None,
      attention_scaling: float = 1.0,
      rngs: nnx.Rngs = None,
  ):
    """Initializes the Qwen3OmniMoeThinkerTextRotaryEmbedding module.

    Args:
      min_timescale: Start of the geometric index (typically 1).
      max_timescale: End of the geometric index (rope_theta, e.g., 1000000).
      embedding_dims: Dimension of the embedding (head_dim).
      cast_as_fprop_dtype: Whether to cast output to fprop dtype.
      fprop_dtype: The dtype of the output.
      mrope_section: Tuple of (temporal_dim, height_dim, width_dim) for MRoPE.
        Defaults to (24, 20, 20) if None.
      attention_scaling: Scaling factor applied to cos/sin embeddings. Defaults to 1.0.
      rngs: rng keys passed in by nnx.bridge.to_linen.

    Raises:
      ValueError: If `embedding_dims` is odd.
    """
    super().__init__(
        min_timescale=min_timescale,
        max_timescale=max_timescale,
        mesh=None,
        embedding_dims=embedding_dims,
        cast_as_fprop_dtype=cast_as_fprop_dtype,
        fprop_dtype=fprop_dtype,
        rngs=rngs,
    )
    if mrope_section is None:
      mrope_section = (24, 20, 20)
    self.mrope_section = mrope_section
    self.attention_scaling = attention_scaling

    if self.embedding_dims % 2:
      raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.")

  def _apply_interleaved_mrope(self, freqs: jax.Array) -> jax.Array:
    """Apply interleaved MRoPE pattern to 3D rotary embeddings.

    Starts from the temporal frequencies and overwrites every third slot with
    the height (offset 1) or width (offset 2) frequencies, up to
    3 * mrope_section[axis] for that axis.

    Args:
      freqs: Shape (3, batch, seq_len, head_dim // 2)
        Dimension 0: temporal frequencies
        Dimension 1: height frequencies
        Dimension 2: width frequencies

    Returns:
      Shape (batch, seq_len, head_dim // 2) with interleaved pattern
    """
    interleaved = freqs[0]  # (batch, seq_len, head_dim // 2)

    # For the spatial axes the axis index doubles as the starting offset:
    # height = 1, width = 2, each written with stride 3.
    for axis in (1, 2):
      stop = self.mrope_section[axis] * 3
      slots = slice(axis, stop, 3)
      interleaved = interleaved.at[..., slots].set(freqs[axis, ..., slots])

    return interleaved

  def __call__(
      self,
      inputs: jax.Array,
      position: jax.Array,
  ) -> jax.Array:
    """Generates rotary position embeddings for multimodal sequences.

    Args:
      inputs: Input tensor of shape [batch, sequence, heads, head_dim].
      position: Position IDs with shape:
        - [batch, sequence] for text-only (2D)
        - [3, batch, sequence] for multimodal with vision (3D)
          where dim 0 = temporal, dim 1 = height, dim 2 = width

    Returns:
      Tensor of shape [batch, sequence, heads, head_dim] with RoPE applied.

    Raises:
      ValueError: If `inputs` is not rank 4, if its last dimension differs from
        `embedding_dims`, or if `position` has an unsupported shape.
    """
    if len(inputs.shape) != 4:
      raise ValueError("Input is assumed to be a rank 4 tensor of shape [batch, sequence, heads, head_dim].")
    if self.embedding_dims != inputs.shape[3]:
      raise ValueError(
          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
      )

    text_only = position.ndim == 2
    if not text_only and (position.ndim != 3 or position.shape[0] != 3):
      raise ValueError(f"Position IDs must be 2D (batch, seq) or 3D (3, batch, seq), got shape {position.shape}")
    if text_only:
      # Replicate (batch, seq) across the temporal/height/width axes.
      position = jnp.broadcast_to(position[jnp.newaxis, ...], (3,) + position.shape)

    # Per-axis angles: (3, batch, seq, 1) * (1, 1, 1, head_dim//2) -> (3, batch, seq, head_dim//2)
    inv_freq = (1.0 / self.timescale)[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
    angles = position[..., jnp.newaxis] * inv_freq

    # Merge the three axes into one interleaved layout: (batch, seq, head_dim//2)
    angles = self._apply_interleaved_mrope(angles)

    # Duplicate along the last axis so both halves of head_dim get the same angles.
    full_angles = jnp.concatenate([angles, angles], axis=-1)  # (batch, seq, head_dim)
    cos_emb = jnp.cos(full_angles) * self.attention_scaling
    sin_emb = jnp.sin(full_angles) * self.attention_scaling

    # Insert the heads axis: (batch, seq, head_dim) -> (batch, seq, 1, head_dim)
    cos_emb = jnp.expand_dims(cos_emb, axis=2)
    sin_emb = jnp.expand_dims(sin_emb, axis=2)

    x_out = self.apply_rotary(inputs, cos_emb, sin_emb)

    if self.cast_as_fprop_dtype:
      x_out = x_out.astype(self.fprop_dtype)

    return x_out
def qwen3_omni_mrope_embedding_as_linen(
    *,
    min_timescale: int,
    max_timescale: int,
    embedding_dims: int = 0,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    mrope_section: tuple[int, int, int] | None = None,
    name: str | None = None,
    attention_scaling: float = 1.0,
):
  """Initializes Qwen3OmniMoeThinkerTextRotaryEmbedding and returns it as a Linen module.

  Args:
    min_timescale: Start of the geometric index.
    max_timescale: End of the geometric index (rope_theta).
    embedding_dims: Dimension of the embedding (head_dim).
    cast_as_fprop_dtype: Whether to cast output to fprop dtype.
    fprop_dtype: The dtype of the output.
    mrope_section: Tuple of (temporal_dim, height_dim, width_dim) for MRoPE.
    name: Name of the Linen module.
    attention_scaling: Scaling factor applied to the cos/sin embeddings.
      Defaults to 1.0, the same default as the wrapped module's constructor.

  Returns:
    A Linen module that wraps the NNX `Qwen3OmniMoeThinkerTextRotaryEmbedding` module.
  """
  return nnx_wrappers.to_linen(
      Qwen3OmniMoeThinkerTextRotaryEmbedding,
      min_timescale=min_timescale,
      max_timescale=max_timescale,
      embedding_dims=embedding_dims,
      cast_as_fprop_dtype=cast_as_fprop_dtype,
      fprop_dtype=fprop_dtype,
      mrope_section=mrope_section,
      attention_scaling=attention_scaling,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
  )