maxtext.models.qwen3 module#
Qwen3 family of model decoder layers.
- maxtext.models.qwen3.naive_jax_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, use_qk_norm_in_gdn=False)[source]#
Naive implementation of the Gated Delta Rule in JAX.
- maxtext.models.qwen3.jax_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, use_qk_norm_in_gdn=False, compute_dtype=<class 'jax.numpy.bfloat16'>)[source]#
Optimized JAX implementation of Gated Delta Rule.
- Parameters:
query (Array)
key (Array)
value (Array)
g (Array)
beta (Array)
chunk_size (int)
initial_state (None | Array)
use_qk_norm_in_gdn (bool)
compute_dtype (dtype)
- Return type:
tuple[Array, None | Array]
- class maxtext.models.qwen3.Qwen3NextGatedDeltaNet(*args, **kwargs)[source]#
Bases:
Module
This module implements the full end-to-end logic of a Gated Delta Network layer.
End-to-End Equations Implemented:
Let x be the input hidden_states.
Step A: Input Projections
1. (q_raw, k_raw, v_raw, z) = Linear_qkvz(x)
2. (b, a) = Linear_ba(x)
Step B: 1D Convolution
1. qkv_conv = silu(Conv1D(concatenate(q_raw, k_raw, v_raw)))
2. (q, k, v) = split(qkv_conv)
Step C: Gated Delta Rule (Recurrent Core)
1. Gates: β = sigmoid(b), g = -exp(A_log) * softplus(a + dt_bias)
2. Core Calculation: core_attn_out = jax_chunk_gated_delta_rule(q, k, v, g, β)
Step D: Final Output Stage
1. y = RMSNorm(core_attn_out) * silu(z)
2. output = Linear_out(y)
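The gate formulas in Step C and the recurrence they feed can be sketched for a single head as follows. This is a token-by-token reference written in NumPy under stated assumptions: the function names `compute_gates` and `recurrent_gated_delta_rule` are hypothetical, the state is assumed to have shape (d_k, d_v), and no query/key normalization or scaling is applied. It is not the chunked algorithm that jax_chunk_gated_delta_rule uses, only the sequential form that such an algorithm is meant to reproduce.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def compute_gates(a, b, A_log, dt_bias):
    """Step C.1: beta lies in (0, 1); g is a log-decay and is never positive."""
    beta = sigmoid(b)
    g = -np.exp(A_log) * softplus(a + dt_bias)
    return beta, g


def recurrent_gated_delta_rule(q, k, v, g, beta):
    """Sequential single-head reference (hypothetical helper).

    q, k: (seq, d_k); v: (seq, d_v); g, beta: (seq,).
    Returns the outputs (seq, d_v) and the final state (d_k, d_v).
    """
    seq_len, d_k = k.shape
    d_v = v.shape[1]
    state = np.zeros((d_k, d_v))
    outputs = np.zeros((seq_len, d_v))
    for t in range(seq_len):
        state = state * np.exp(g[t])           # decay the existing memory
        predicted = k[t] @ state               # value the memory returns for k_t
        delta = beta[t] * (v[t] - predicted)   # gated correction toward v_t
        state = state + np.outer(k[t], delta)  # rank-1 delta-rule update
        outputs[t] = q[t] @ state              # read out with the query
    return outputs, state
```

With g = 0 (no decay) and β = 1, writing a unit-norm key and reading it back with the same vector returns the stored value exactly, which is a convenient sanity check for any faster implementation.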
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- class maxtext.models.qwen3.Qwen3NextFullAttention(*args, **kwargs)[source]#
Bases:
Module
Qwen3-Next Full Attention Layer.
This module implements the full self-attention mechanism as used in Qwen3-Next models for layers that do not use the Gated Delta Network. It wraps the main attentions.Attention class, which handles the core attention operation, including the query, key, value, and output projections.
Qwen3-Next attention differs from standard attention in the following ways:
- Query and gate are split from a single q projection.
- A sigmoid gate is applied to the attention output.
- Qwen3NextRMSNorm is used for query and key normalization.
- PartialRotaryEmbedding is used for partial rotary position embeddings: partial RoPE is applied to the first 25% of head dimensions.
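Three of the features listed above (query/gate splitting, the sigmoid output gate, and partial RoPE) can be illustrated with a small NumPy sketch. All function names here are hypothetical and do not come from this module. The sketch assumes the projection output stores query and gate as two equal halves of the last axis, and it assumes the "rotate-half" RoPE convention with base 10000; the actual layout and convention in PartialRotaryEmbedding may differ.

```python
import numpy as np


def split_query_and_gate(q_and_gate):
    """Split a (seq, heads, 2 * head_dim) projection into query and gate halves.

    The half/half layout on the last axis is an assumption of this sketch.
    """
    query, gate = np.split(q_and_gate, 2, axis=-1)
    return query, gate


def partial_rope(x, positions, rotary_fraction=0.25, base=10000.0):
    """Rotate the leading `rotary_fraction` of head_dim and pass the rest through.

    x: (seq, heads, head_dim); positions: (seq,).
    """
    head_dim = x.shape[-1]
    rot_dim = int(head_dim * rotary_fraction)
    half = rot_dim // 2
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    inv_freq = 1.0 / (base ** (np.arange(half) / half))
    angles = positions[:, None] * inv_freq[None, :]   # (seq, half)
    cos = np.cos(angles)[:, None, :]                  # broadcast over heads
    sin = np.sin(angles)[:, None, :]
    x1, x2 = x_rot[..., :half], x_rot[..., half:]
    rotated = np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return np.concatenate([rotated, x_pass], axis=-1)


def gate_attention_output(attn_out, gate):
    """Scale the attention output elementwise by sigmoid(gate)."""
    return attn_out * (1.0 / (1.0 + np.exp(-gate)))
```

With head_dim = 16, only the first 4 dimensions of each head are rotated; the remaining 12 pass through unchanged, and the per-head vector norm is preserved because each rotated pair undergoes a plain 2D rotation.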
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
MaxText configuration object.
- mesh#
The device mesh for sharding.
- model_mode#
The operational mode (e.g., ‘train’, ‘prefill’).
- layer_idx#
The index of the current layer.
- quant#
Optional quantization configuration.
- attention#
An instance of attentions.Attention which contains the learnable parameters for query, key, value, and output projections (e.g., attention.query, attention.key, etc.), and performs the attention calculation.
- class maxtext.models.qwen3.Qwen3NextSparseMoeBlock(*args, **kwargs)[source]#
Bases:
Module
This module encapsulates the unique MoE structure of Qwen3-Next, which includes:
1. A set of routed experts, where each token is sent to a subset of experts.
2. A single shared expert, which all tokens pass through.
3. A learnable gate that determines the contribution of the shared expert.
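The three components above can be combined as in the following NumPy sketch. It is illustrative only: the function name and argument names are hypothetical, each expert is reduced to a single linear map, the router is assumed to use softmax probabilities renormalized over the top-k experts, and the shared-expert gate is assumed to be a sigmoid of a linear projection. The module's real experts, routing, and sharding are more involved.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def sparse_moe_with_shared_expert(x, router_w, expert_ws, shared_w, shared_gate_w, top_k=2):
    """Routed experts plus a gated shared expert (hypothetical sketch).

    x: (tokens, d); router_w: (d, n_experts); expert_ws: (n_experts, d, d);
    shared_w: (d, d); shared_gate_w: (d, 1).
    """
    probs = softmax(x @ router_w)                      # (tokens, n_experts)
    top_idx = np.argsort(-probs, axis=-1)[:, :top_k]   # chosen experts per token
    top_p = np.take_along_axis(probs, top_idx, axis=-1)
    top_p = top_p / top_p.sum(axis=-1, keepdims=True)  # renormalize over chosen experts

    routed = np.zeros_like(x)
    for t in range(x.shape[0]):
        for slot in range(top_k):
            expert = top_idx[t, slot]
            routed[t] += top_p[t, slot] * (x[t] @ expert_ws[expert])

    # Every token passes through the shared expert; a per-token scalar gate
    # in (0, 1) decides how much of it is added.
    shared_gate = 1.0 / (1.0 + np.exp(-(x @ shared_gate_w)))  # (tokens, 1)
    return routed + shared_gate * (x @ shared_w)
```

The explicit Python loops keep the routing logic readable; a production implementation would typically dispatch tokens to experts in batches instead.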
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
The model configuration object.
- mesh#
The device mesh for sharding.
- quant#
Optional quantization configuration.
- class maxtext.models.qwen3.Qwen3NextScannableBlock(*args, **kwargs)[source]#
Bases:
Module
A scannable block of Qwen3-Next decoder layers.
This module contains a fixed number of heterogeneous decoder layers that form a repeating pattern, as defined by config.inhomogeneous_layer_cycle_interval. It is intended to be the body of an nn.scan transformation to construct the full decoder stack efficiently.
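To make the repeating pattern concrete, the sketch below builds the sequence of layer kinds for a given cycle interval and groups it into identical blocks, which is the property a scan body needs. The helper names are hypothetical, and the rule that the last layer of each cycle is the full-attention layer is an assumption of this sketch, not something this page states.

```python
def layer_kinds(num_layers, cycle_interval):
    """Return the assumed kind of each decoder layer.

    Assumption: the last layer of every cycle uses full attention and the
    others use the Gated Delta Network.
    """
    return [
        "full_attention" if (i + 1) % cycle_interval == 0 else "gated_delta_net"
        for i in range(num_layers)
    ]


def group_into_blocks(kinds, cycle_interval):
    """Group layers into blocks of `cycle_interval` and check they repeat exactly."""
    if len(kinds) % cycle_interval != 0:
        raise ValueError("num_layers must be a multiple of the cycle interval")
    blocks = [
        tuple(kinds[i:i + cycle_interval])
        for i in range(0, len(kinds), cycle_interval)
    ]
    # A scan applies the same body at every step, so all blocks must match.
    if any(block != blocks[0] for block in blocks):
        raise ValueError("blocks are not structurally identical")
    return blocks
```

For example, `group_into_blocks(layer_kinds(8, 4), 4)` yields two identical blocks of three Gated Delta Network layers followed by one full-attention layer.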
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
The model configuration object.
- mesh#
The device mesh for sharding.
- model_mode#
The operational mode (e.g., ‘train’, ‘prefill’).
- quant#
Optional quantization configuration.
- class maxtext.models.qwen3.Qwen3NextDecoderLayer(*args, **kwargs)[source]#
Bases:
Module
This layer is a hybrid, capable of functioning as either:
1. A standard attention + MoE layer.
2. A linear attention + MoE layer.
NOTE: This implementation assumes every layer contains an MoE block, which is true for models like Qwen3-Next-80B-A3B where decoder_sparse_step=1. Models that interleave dense and sparse MLP layers would need conditional logic here.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
The model configuration object.
- mesh#
The device mesh for sharding.
- model_mode#
The operational mode (e.g., ‘train’, ‘prefill’).
- layer_idx#
The index of the current layer in the transformer stack.
- quant#
Optional quantization configuration.
- class maxtext.models.qwen3.AttentionWithNorm(*args, **kwargs)[source]#
Bases:
Module
Base class with shared common components: self-attention block with normalization.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- apply_attention_with_norm(inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, kv_cache=None, attention_metadata=None)[source]#
Applies self-attention with pre and post-layer normalization.
- Parameters:
inputs (Array)
decoder_segment_ids (None | Array)
decoder_positions (None | Array)
deterministic (bool)
model_mode (str)
kv_cache (None | Array)
attention_metadata (None | dict[str, Any])
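The general shape of apply_attention_with_norm, attention wrapped by a normalization before and after, can be sketched as below. This is a hedged illustration rather than the method's actual body: `rms_norm` and `attention_with_norm` are hypothetical names, RMS normalization is assumed, the residual connection and the two-value return are assumptions, and `attention_fn` stands in for the real attention module.

```python
import numpy as np


def rms_norm(x, scale, eps=1e-6):
    """Root-mean-square normalization over the last axis."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * scale


def attention_with_norm(inputs, attention_fn, pre_scale, post_scale):
    """Pre-norm, attention, residual add, then post-norm (assumed wiring).

    Returns the post-normalized hidden states and the un-normalized residual
    sum, so that a following MLP block can add its own residual.
    """
    normed = rms_norm(inputs, pre_scale)          # pre-attention normalization
    attn_out = attention_fn(normed)               # stand-in for self-attention
    intermediate = inputs + attn_out              # residual connection
    hidden = rms_norm(intermediate, post_scale)   # post-attention normalization
    return hidden, intermediate
```

Passing an `attention_fn` that returns zeros makes `intermediate` equal to `inputs`, which isolates the effect of the two normalization steps when debugging.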
- class maxtext.models.qwen3.Qwen3DecoderLayer(*args, **kwargs)[source]#
Bases:
AttentionWithNorm
Qwen3 Transformer decoder layer (dense).
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- class maxtext.models.qwen3.Qwen3MoeDecoderLayer(*args, **kwargs)[source]#
Bases:
AttentionWithNorm
Qwen3 Transformer decoder layer (MoE).
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- class maxtext.models.qwen3.Qwen3OmniMoeVisionPatchMerger(*args, **kwargs)[source]#
Bases:
Module
Vision patch merger that spatially merges patches using an MLP.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
Config containing model parameters
Hidden dimension after spatial merging
- use_postshuffle_norm#
Whether to apply normalization after spatial shuffle
- dtype#
Data type for computation
- weight_dtype#
Data type for weights
- kernel_init#
Initializer for kernel weights
- rngs#
RNG state for initialization
- ln_q#
LayerNorm before MLP
- mlp_0#
First MLP layer
- mlp_2#
Second MLP layer
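The data flow that the Qwen3OmniMoeVisionPatchMerger attributes suggest (ln_q, then mlp_0, then mlp_2) can be sketched in NumPy as follows. The function names are hypothetical, and several details are assumptions of this sketch: the patches to be merged are taken to be adjacent in the sequence, the activation between the two layers is a tanh-approximated GELU, biases and learnable norm parameters are omitted, and use_postshuffle_norm is interpreted as normalizing the merged vector instead of each patch.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    """LayerNorm over the last axis, without learnable scale or bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def gelu_tanh(x):
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def patch_merger(x, w0, w2, spatial_merge_size=2, use_postshuffle_norm=False):
    """Merge groups of spatial_merge_size**2 patches, then apply a two-layer MLP.

    x: (num_patches, hidden); w0: (group * hidden, group * hidden);
    w2: (group * hidden, out_dim).
    """
    group = spatial_merge_size ** 2
    num_patches, hidden = x.shape
    if use_postshuffle_norm:
        # Normalize after concatenating each group of patches.
        merged = layer_norm(x.reshape(num_patches // group, group * hidden))
    else:
        # Normalize each patch, then concatenate each group.
        merged = layer_norm(x).reshape(num_patches // group, group * hidden)
    return gelu_tanh(merged @ w0) @ w2
```

With spatial_merge_size = 2, every four patch vectors become one output vector, so the sequence length shrinks by a factor of four.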
- class maxtext.models.qwen3.Qwen3OmniMoeVisionMLP(*args, **kwargs)[source]#
Bases:
Module
Vision MLP block with GELU activation.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
Config containing model parameters
Hidden dimension size
- intermediate_size#
Intermediate dimension size
- dtype#
Data type for computation
- weight_dtype#
Data type for weights
- kernel_init#
Initializer for kernel weights
- rngs#
RNG state for initialization
- linear_fc1#
First linear layer
- linear_fc2#
Second linear layer
- class maxtext.models.qwen3.Qwen3OmniMoeVisionPatchEmbed(*args, **kwargs)[source]#
Bases:
Module
3D convolution-based patch embedding for vision inputs.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
Config containing model parameters
- patch_size#
Spatial patch size
- temporal_patch_size#
Temporal patch size
- in_channels#
Number of input channels
- embed_dim#
Embedding dimension
- dtype#
Data type for computation
- weight_dtype#
Data type for weights
- rngs#
RNG state for initialization
- proj#
Convolution projection layer
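When a 3D convolution uses a kernel equal to its stride, as a patch embedding typically does, it is equivalent to cutting the input into non-overlapping patches and applying one linear projection to each flattened patch. The NumPy sketch below shows that equivalent form for Qwen3OmniMoeVisionPatchEmbed-style inputs. The function name, the (C, T, H, W) input layout, and the (embed_dim, C, temporal_patch_size, patch_size, patch_size) weight layout are assumptions of this sketch, and the bias is omitted.

```python
import numpy as np


def patch_embed(video, weight, patch_size, temporal_patch_size):
    """Embed non-overlapping 3D patches with a single linear projection.

    video: (C, T, H, W); weight: (embed_dim, C, tp, p, p).
    Returns (num_patches, embed_dim), with patches ordered by (time, row, column).
    """
    c, t, h, w = video.shape
    tp, p = temporal_patch_size, patch_size
    # Split each of T, H, W into (number of patches, size within patch).
    x = video.reshape(c, t // tp, tp, h // p, p, w // p, p)
    # Bring the patch-grid axes to the front and the within-patch axes to the back.
    x = x.transpose(1, 3, 5, 0, 2, 4, 6)
    x = x.reshape(-1, c * tp * p * p)
    # One matrix multiply replaces the strided convolution.
    return x @ weight.reshape(weight.shape[0], -1).T
```

A (3, 4, 8, 8) input with patch_size = 4 and temporal_patch_size = 2 produces a 2 x 2 x 2 grid, that is, 8 patch embeddings.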
- class maxtext.models.qwen3.Qwen3OmniMoeVisionAttention(*args, **kwargs)[source]#
Bases:
Module
Vision attention layer wrapper.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
Config containing model parameters
- attn#
Underlying attention module
- class maxtext.models.qwen3.Qwen3OmniMoeVisionBlock(*args, **kwargs)[source]#
Bases:
Module
Vision transformer block with attention and MLP.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
Config containing model parameters
- ln1#
LayerNorm before attention
- ln2#
LayerNorm before MLP
- attn#
Attention module
- mlp#
First MLP layer
- mlp_out#
Second MLP layer
- class maxtext.models.qwen3.Qwen3OmniMoeVisionEncoder(*args, **kwargs)[source]#
Bases:
Module
Vision encoder with patch embedding, positional embedding, and transformer blocks.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
Config containing model parameters
- patch_embed#
Patch embedding module
- pos_embed_interpolate#
Position embedding interpolation module
- blocks#
List of transformer blocks
- merger_list#
List of patch mergers for deep supervision
- spatial_merge_size#
Size of spatial merging
- deep_idx#
Indices of layers to extract deep features from
- class maxtext.models.qwen3.Qwen3OmniMoeVisionProjector(*args, **kwargs)[source]#
Bases:
Module
Projection layer that converts vision encoder output to model embedding space.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
Config containing model parameters
- merger#
Patch merger for spatial reduction
- maxtext.models.qwen3.qwen3omni_visionencoder_as_linen(config, mesh)[source]#
Convert Qwen3OmniMoeVisionEncoder to Linen module.
- Parameters:
config (Any)
mesh (Mesh)
- Return type:
Module
- maxtext.models.qwen3.qwen3omni_visionprojector_as_linen(config, mesh)[source]#
Convert Qwen3OmniMoeVisionProjector to Linen module.
- Parameters:
config (Any)
mesh (Mesh)
- Return type:
Module
- class maxtext.models.qwen3.Qwen3OmniAudioEncoderLayer(*args, **kwargs)[source]#
Bases:
Module
Transformer encoder layer for audio model.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- class maxtext.models.qwen3.Qwen3OmniAudioEncoder(*args, **kwargs)[source]#
Bases:
Module
Full audio encoder with convs, positional embeddings, and transformer layers.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config#
Config containing model parameters
- mesh#
JAX device mesh used for sharding.
- class maxtext.models.qwen3.Qwen3OmniAudioProjector(*args, **kwargs)[source]#
Bases:
Module
Projection layer that converts audio encoder output to model embedding space.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any