maxtext.models.qwen3 module

Decoder layers for the Qwen3 family of models.

maxtext.models.qwen3.naive_jax_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, use_qk_norm_in_gdn=False)

Naive implementation of the chunked Gated Delta Rule in JAX.

maxtext.models.qwen3.jax_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, use_qk_norm_in_gdn=False, compute_dtype=jax.numpy.bfloat16)

Optimized JAX implementation of the chunked Gated Delta Rule.

Parameters:
  • query (Array)

  • key (Array)

  • value (Array)

  • g (Array)

  • beta (Array)

  • chunk_size (int)

  • initial_state (None | Array)

  • use_qk_norm_in_gdn (bool)

  • compute_dtype (dtype)

Return type:

tuple[Array, None | Array]
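Both functions are, as their names suggest, chunked evaluations of a per-token recurrence. As a minimal reference sketch of that recurrence, assuming the standard gated delta rule formulation (memory decay exp(g_t), write strength β_t), the hypothetical function below processes a single head token by token with jax.lax.scan. It is not MaxText's implementation: batch and head axes, query scaling, and the optional q/k normalization controlled by use_qk_norm_in_gdn are omitted.

```python
import jax
import jax.numpy as jnp


def recurrent_gated_delta_rule(q, k, v, g, beta, initial_state=None):
  """Token-by-token gated delta rule for one head (illustrative sketch).

  Assumed shapes: q, k: [T, Dk]; v: [T, Dv]; g, beta: [T].
  Returns (outputs [T, Dv], final_state [Dk, Dv]).
  """
  d_k, d_v = k.shape[-1], v.shape[-1]
  state = jnp.zeros((d_k, d_v), v.dtype) if initial_state is None else initial_state

  def step(S, inputs):
    q_t, k_t, v_t, g_t, b_t = inputs
    S = S * jnp.exp(g_t)            # gated decay of the memory (g_t <= 0)
    v_pred = k_t @ S                # value currently associated with key k_t
    delta = b_t * (v_t - v_pred)    # delta-rule correction scaled by beta_t
    S = S + jnp.outer(k_t, delta)   # rank-1 update of the memory
    return S, q_t @ S               # read the updated memory with the query

  final_state, outputs = jax.lax.scan(step, state, (q, k, v, g, beta))
  return outputs, final_state


# Example with random inputs (unit-norm keys, negative g, beta in (0, 1)).
T, Dk, Dv = 8, 4, 3
keys = jax.random.split(jax.random.PRNGKey(0), 5)
q = jax.random.normal(keys[0], (T, Dk))
k = jax.random.normal(keys[1], (T, Dk))
k = k / jnp.linalg.norm(k, axis=-1, keepdims=True)
v = jax.random.normal(keys[2], (T, Dv))
g = -jax.nn.softplus(jax.random.normal(keys[3], (T,)))
beta = jax.nn.sigmoid(jax.random.normal(keys[4], (T,)))
out, final_state = recurrent_gated_delta_rule(q, k, v, g, beta)
```

The sketch returns the final memory state as its second value, which appears to correspond to the None | Array element of the documented return type; passing it back as initial_state continues the recurrence over the next segment.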

class maxtext.models.qwen3.Qwen3NextGatedDeltaNet(*args, **kwargs)

Bases: Module

This module implements the full end-to-end logic of a Gated Delta Network layer.

End-to-end equations implemented (let x be the input hidden_states):

Step A: Input Projections
  1. (q_raw, k_raw, v_raw, z) = Linear_qkvz(x)
  2. (b, a) = Linear_ba(x)

Step B: 1D Convolution
  1. qkv_conv = silu(Conv1D(concatenate(q_raw, k_raw, v_raw)))
  2. (q, k, v) = split(qkv_conv)

Step C: Gated Delta Rule (Recurrent Core)
  1. Gates: β = sigmoid(b), g = -exp(A_log) * softplus(a + dt_bias)
  2. Core calculation: core_attn_out = jax_chunk_gated_delta_rule(q, k, v, g, β)

Step D: Final Output Stage
  1. y = RMSNorm(core_attn_out) * silu(z)
  2. output = Linear_out(y)
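The element-wise parts of Steps C.1 and D.1 can be written directly in JAX. This is a sketch under stated assumptions: a and b carry one value per head, A_log and dt_bias are per-head parameters, and the RMSNorm in Step D normalizes over the last (head) dimension with a learnable weight and a hypothetical eps of 1e-6. The function names are illustrative only. Because softplus is positive, g is non-positive, so exp(g) lies in (0, 1] and acts as a decay factor in the recurrence.

```python
import jax
import jax.numpy as jnp


def gdn_gates(a, b, A_log, dt_bias):
  """Step C.1: beta = sigmoid(b), g = -exp(A_log) * softplus(a + dt_bias)."""
  beta = jax.nn.sigmoid(b)                               # write strength in (0, 1)
  g = -jnp.exp(A_log) * jax.nn.softplus(a + dt_bias)     # log-decay, never positive
  return g, beta


def gated_rms_norm(core_attn_out, z, weight, eps=1e-6):
  """Step D.1: y = RMSNorm(core_attn_out) * silu(z), normalized over the last axis."""
  x = core_attn_out.astype(jnp.float32)
  x = x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
  return weight * x * jax.nn.silu(z.astype(jnp.float32))


num_heads, head_dim = 4, 8
g, beta = gdn_gates(a=jnp.linspace(-2.0, 2.0, num_heads), b=jnp.zeros(num_heads),
                    A_log=jnp.zeros(num_heads), dt_bias=jnp.zeros(num_heads))
y = gated_rms_norm(3.0 * jnp.ones((num_heads, head_dim)), jnp.ones((num_heads, head_dim)),
                   weight=jnp.ones(head_dim))
```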

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.qwen3.Qwen3NextFullAttention(*args, **kwargs)

Bases: Module

Qwen3-Next Full Attention Layer.

This module implements the full self-attention mechanism as used in Qwen3-Next models for layers that do not use the Gated Delta Network. It wraps the main attentions.Attention class, which handles the core attention operation, including the query, key, value, and output projections.

Qwen3-Next attention differs from standard attention in the following ways:
  • The query and a gate are split from a single query projection.

  • A sigmoid gate is applied to the attention output.

  • Qwen3NextRMSNorm is used for query and key normalization.

  • PartialRotaryEmbedding provides partial rotary position embeddings: RoPE is applied only to the first 25% of the head dimensions.
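The three mechanisms above can be illustrated in isolation. The sketch assumes the fused query projection is laid out per head as [query | gate], uses the rotate-half RoPE convention with a base of 10000, and takes 25% as the rotary fraction; query/key normalization is omitted, all function names are hypothetical, and MaxText's PartialRotaryEmbedding and attention code may differ in layout and details.

```python
import jax
import jax.numpy as jnp


def split_query_and_gate(q_proj_out, num_heads, head_dim):
  """Split a fused projection [..., num_heads * 2 * head_dim] into (query, gate).

  Assumes a per-head layout of [query | gate] along the last axis.
  """
  x = q_proj_out.reshape(*q_proj_out.shape[:-1], num_heads, 2 * head_dim)
  return x[..., :head_dim], x[..., head_dim:]


def apply_output_gate(attn_out, gate):
  """Scale the attention output element-wise by sigmoid(gate)."""
  return attn_out * jax.nn.sigmoid(gate)


def partial_rope(x, positions, rotary_fraction=0.25, base=10000.0):
  """Apply rotate-half RoPE to the first rotary_fraction of head dims only.

  x: [seq, num_heads, head_dim]; positions: [seq].
  """
  rot_dim = int(x.shape[-1] * rotary_fraction)
  x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
  half = rot_dim // 2
  inv_freq = 1.0 / (base ** (jnp.arange(half, dtype=jnp.float32) / half))
  angles = positions[:, None].astype(jnp.float32) * inv_freq[None, :]  # [seq, half]
  cos, sin = jnp.cos(angles)[:, None, :], jnp.sin(angles)[:, None, :]  # broadcast over heads
  x1, x2 = x_rot[..., :half], x_rot[..., half:]
  rotated = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
  return jnp.concatenate([rotated, x_pass], axis=-1)  # remaining dims pass through unchanged


x = jax.random.normal(jax.random.PRNGKey(0), (5, 2, 16))  # [seq, heads, head_dim]
y = partial_rope(x, jnp.arange(5))
query, gate = split_query_and_gate(jnp.arange(48.0).reshape(3, 16), num_heads=2, head_dim=4)
gated = apply_output_gate(jnp.ones_like(query), gate)
```

With head_dim = 16 and a 25% fraction, only the first 4 dimensions of each head are rotated; the remaining 12 carry no positional rotation.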

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

MaxText configuration object.

mesh

The device mesh for sharding.

model_mode

The operational mode (e.g., ‘train’, ‘prefill’).

layer_idx

The index of the current layer.

quant

Optional quantization configuration.

attention

An instance of attentions.Attention that holds the learnable parameters for the query, key, value, and output projections (e.g., attention.query, attention.key) and performs the attention calculation.

class maxtext.models.qwen3.Qwen3NextSparseMoeBlock(*args, **kwargs)

Bases: Module

This module encapsulates the unique MoE structure of Qwen3-Next, which includes:
  1. A set of routed experts, where each token is sent to a subset of experts.
  2. A single shared expert, which all tokens pass through.
  3. A learnable gate that determines the contribution of the shared expert.
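A dense reference sketch of this three-part structure follows. It assumes softmax routing with the selected top-k weights renormalized to sum to 1, SwiGLU experts, and a per-token scalar sigmoid gate on the shared expert; the parameter names are hypothetical. For clarity every expert runs on every token and non-selected experts receive zero weight, which is far less efficient than the sparse dispatch a real MoE layer uses.

```python
import jax
import jax.numpy as jnp


def swiglu(x, w_gate, w_up, w_down):
  """SwiGLU feed-forward network (assumed form of each expert)."""
  return (jax.nn.silu(x @ w_gate) * (x @ w_up)) @ w_down


def sparse_moe_with_shared_expert(x, params, top_k):
  """Dense reference for routed experts plus a gated shared expert.

  x: [tokens, d_model]. Returns (output [tokens, d_model], combine weights [tokens, E]).
  """
  probs = jax.nn.softmax(x @ params["router"], axis=-1)            # [tokens, E]
  top_vals, top_idx = jax.lax.top_k(probs, top_k)                  # [tokens, k]
  top_vals = top_vals / jnp.sum(top_vals, axis=-1, keepdims=True)  # renormalize
  # Scatter the top-k weights into a dense [tokens, E] matrix (zeros elsewhere).
  combine = jnp.sum(jax.nn.one_hot(top_idx, probs.shape[-1]) * top_vals[..., None], axis=1)

  e = params["experts"]  # each leaf is stacked along a leading expert axis
  expert_out = jax.vmap(swiglu, in_axes=(None, 0, 0, 0))(
      x, e["w_gate"], e["w_up"], e["w_down"])                      # [E, tokens, d_model]
  routed = jnp.einsum("te,etd->td", combine, expert_out)

  s = params["shared"]
  shared_out = swiglu(x, s["w_gate"], s["w_up"], s["w_down"])
  shared_weight = jax.nn.sigmoid(x @ params["shared_gate"])        # [tokens, 1]
  return routed + shared_weight * shared_out, combine


def init_params(key, d_model=8, num_experts=4, d_ff=16):
  """Random hypothetical parameters, for illustration only."""
  ks = jax.random.split(key, 8)

  def rand(k, shape):
    return 0.1 * jax.random.normal(k, shape)

  return {
      "router": rand(ks[0], (d_model, num_experts)),
      "experts": {"w_gate": rand(ks[1], (num_experts, d_model, d_ff)),
                  "w_up": rand(ks[2], (num_experts, d_model, d_ff)),
                  "w_down": rand(ks[3], (num_experts, d_ff, d_model))},
      "shared": {"w_gate": rand(ks[4], (d_model, d_ff)),
                 "w_up": rand(ks[5], (d_model, d_ff)),
                 "w_down": rand(ks[6], (d_ff, d_model))},
      "shared_gate": rand(ks[7], (d_model, 1)),
  }


x = jax.random.normal(jax.random.PRNGKey(1), (5, 8))
out, combine = sparse_moe_with_shared_expert(x, init_params(jax.random.PRNGKey(0)), top_k=2)
```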

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

The model configuration object.

mesh

The device mesh for sharding.

quant

Optional quantization configuration.

class maxtext.models.qwen3.Qwen3NextScannableBlock(*args, **kwargs)

Bases: Module

A scannable block of Qwen3-Next decoder layers.

This module contains a fixed number of heterogeneous decoder layers that form a repeating pattern, as defined by config.inhomogeneous_layer_cycle_interval. It is intended to be the body of an nn.scan transformation to construct the full decoder stack efficiently.
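The scan pattern can be sketched with jax.lax.scan (used here in place of Flax's nn.scan): the parameters of all cycles are stacked along a leading axis, and each application of the block body consumes one slice. The two layer functions are hypothetical stand-ins with deliberately different parameter structures, and the example assumes a cycle of three linear-attention layers followed by one full-attention layer.

```python
import jax
import jax.numpy as jnp


def linear_attention_layer(h, w):
  """Hypothetical stand-in for a linear-attention (Gated Delta Network) layer."""
  return h + jnp.tanh(h @ w)


def full_attention_layer(h, p):
  """Hypothetical stand-in for a full-attention layer with different parameters."""
  return h + jax.nn.relu(h @ p["w"] + p["b"])


def scannable_block(h, block_params):
  """One repeating cycle of heterogeneous layers; signature matches lax.scan's body."""
  for w in block_params["linear"]:
    h = linear_attention_layer(h, w)
  h = full_attention_layer(h, block_params["full"])
  return h, None  # (new carry, no per-cycle output)


def decoder_stack(h, stacked_params):
  """Apply the block once per cycle; every leaf has a leading num_cycles axis."""
  h, _ = jax.lax.scan(scannable_block, h, stacked_params)
  return h


d_model, num_cycles, cycle_interval = 6, 2, 4
keys = jax.random.split(jax.random.PRNGKey(0), cycle_interval)
stacked = {
    "linear": [0.1 * jax.random.normal(keys[i], (num_cycles, d_model, d_model))
               for i in range(cycle_interval - 1)],
    "full": {"w": 0.1 * jax.random.normal(keys[-1], (num_cycles, d_model, d_model)),
             "b": jnp.zeros((num_cycles, d_model))},
}
h0 = jnp.ones((3, d_model))
out = decoder_stack(h0, stacked)
```

Because the block is traced once and reused for every cycle, compile time stays roughly constant as the number of cycles grows, which is the usual motivation for scanning over layers.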

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

The model configuration object.

mesh

The device mesh for sharding.

model_mode

The operational mode (e.g., ‘train’, ‘prefill’).

quant

Optional quantization configuration.

class maxtext.models.qwen3.Qwen3NextDecoderLayer(*args, **kwargs)

Bases: Module

This layer is a hybrid that can function as either:
  1. A standard (full) attention + MoE layer.
  2. A linear attention (Gated Delta Network) + MoE layer.

NOTE: This implementation assumes every layer contains a MoE block, which is true for models like Qwen3-Next-80B-A3B where decoder_sparse_step=1. Models that interleave dense and sparse MLP layers would need additional conditional logic in this layer.
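A minimal sketch of how such a hybrid layer might pick its token mixer. It assumes the convention used by the Hugging Face Qwen3-Next configuration, in which the last layer of every full_attention_interval-layer cycle uses full attention and the others use linear attention, together with a pre-norm residual layout. The function and argument names are hypothetical, and MaxText's actual selection logic may differ.

```python
import jax.numpy as jnp


def layer_type(layer_idx: int, full_attention_interval: int) -> str:
  """Assumed convention: the last layer of each cycle uses full attention."""
  if (layer_idx + 1) % full_attention_interval == 0:
    return "full_attention"
  return "linear_attention"


def hybrid_decoder_layer(x, layer_idx, interval, input_norm, post_norm,
                         full_attention, gated_delta_net, moe):
  """Pre-norm residual layer whose token mixer depends on the layer index."""
  if layer_type(layer_idx, interval) == "full_attention":
    mixer = full_attention
  else:
    mixer = gated_delta_net
  h = x + mixer(input_norm(x))   # token mixing (full or linear attention)
  return h + moe(post_norm(h))   # every layer has a MoE block (decoder_sparse_step=1)


types = [layer_type(i, 4) for i in range(8)]
```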

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

The model configuration object.

mesh

The device mesh for sharding.

model_mode

The operational mode (e.g., ‘train’, ‘prefill’).

layer_idx

The index of the current layer in the transformer stack.

quant

Optional quantization configuration.

class maxtext.models.qwen3.AttentionWithNorm(*args, **kwargs)

Bases: Module

Base class providing the shared components: a self-attention block with normalization.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

apply_attention_with_norm(inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, kv_cache=None, attention_metadata=None)

Applies self-attention with pre- and post-layer normalization.

Parameters:
  • inputs (Array)

  • decoder_segment_ids (None | Array)

  • decoder_positions (None | Array)

  • deterministic (bool)

  • model_mode (str)

  • kv_cache (None | Array)

  • attention_metadata (None | dict[str, Any])
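The following sketch shows one common reading of pre- and post-layer normalization in Qwen3-style decoders: RMSNorm before attention, a residual add, and a second RMSNorm whose output feeds the MLP or MoE block. This arrangement, the eps value, and the function names are assumptions rather than a description of apply_attention_with_norm itself; attention_fn stands in for the attention module.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
  """RMSNorm over the last axis, computed in float32."""
  x32 = x.astype(jnp.float32)
  normed = x32 * jax.lax.rsqrt(jnp.mean(x32 * x32, axis=-1, keepdims=True) + eps)
  return (normed * weight).astype(x.dtype)


def attention_with_norm(x, attention_fn, pre_norm_weight, post_norm_weight):
  """Pre-norm attention plus residual, followed by the post-attention norm.

  Returns the residual stream and its normalized copy (the input to the MLP/MoE).
  """
  h = x + attention_fn(rms_norm(x, pre_norm_weight))
  return h, rms_norm(h, post_norm_weight)


x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 16))  # [batch, seq, d_model]
residual, mlp_input = attention_with_norm(
    x, lambda t: jnp.zeros_like(t), jnp.ones(16), jnp.ones(16))
```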

class maxtext.models.qwen3.Qwen3DecoderLayer(*args, **kwargs)

Bases: AttentionWithNorm

Qwen3 Transformer decoder layer (dense).

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.qwen3.Qwen3MoeDecoderLayer(*args, **kwargs)

Bases: AttentionWithNorm

Qwen3 Transformer decoder layer (MoE).

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.qwen3.Qwen3OmniMoeVisionPatchMerger(*args, **kwargs)

Bases: Module

Vision patch merger that spatially merges patches using an MLP.
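A sketch of the merge, assuming patches arrive ordered so that each consecutive group of spatial_merge_size² rows covers one spatial window. The group is concatenated into a single vector and passed through LayerNorm and a Linear, GELU, Linear MLP; use_postshuffle_norm chooses whether the LayerNorm sees individual patches or the concatenated group. The params layout borrows the ln_q, mlp_0, and mlp_2 attribute names but is hypothetical, and linear biases are omitted.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, scale, bias, eps=1e-6):
  mean = jnp.mean(x, axis=-1, keepdims=True)
  var = jnp.var(x, axis=-1, keepdims=True)
  return (x - mean) * jax.lax.rsqrt(var + eps) * scale + bias


def patch_merger(x, params, spatial_merge_size=2, use_postshuffle_norm=False):
  """Merge each run of spatial_merge_size**2 patch features into one token.

  x: [num_patches, context_dim], assumed pre-ordered so that every consecutive
  group of spatial_merge_size**2 rows holds spatially adjacent patches.
  """
  merged_dim = x.shape[-1] * spatial_merge_size ** 2
  ln = params["ln_q"]
  if use_postshuffle_norm:   # normalize after concatenating the group
    x = layer_norm(x.reshape(-1, merged_dim), ln["scale"], ln["bias"])
  else:                      # normalize each patch, then concatenate
    x = layer_norm(x, ln["scale"], ln["bias"]).reshape(-1, merged_dim)
  h = jax.nn.gelu(x @ params["mlp_0"], approximate=False)   # first MLP layer
  return h @ params["mlp_2"]                                # second MLP layer


context_dim, out_dim, merge = 4, 6, 2
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
patches = jax.random.normal(k0, (8, context_dim))  # 8 patches -> 2 merged tokens
merged_dim = context_dim * merge ** 2
params = {
    "ln_q": {"scale": jnp.ones(context_dim), "bias": jnp.zeros(context_dim)},
    "mlp_0": 0.1 * jax.random.normal(k1, (merged_dim, merged_dim)),
    "mlp_2": 0.1 * jax.random.normal(k2, (merged_dim, out_dim)),
}
tokens = patch_merger(patches, params, spatial_merge_size=merge)
```

Note that the LayerNorm parameters must match the normalized width: context_dim when use_postshuffle_norm is False, and context_dim * spatial_merge_size² when it is True.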

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

Config containing model parameters

hidden_size

Hidden dimension after spatial merging

use_postshuffle_norm

Whether to apply normalization after spatial shuffle

dtype

Data type for computation

weight_dtype

Data type for weights

kernel_init

Initializer for kernel weights

rngs

RNG state for initialization

ln_q

LayerNorm before MLP

mlp_0

First MLP layer

mlp_2

Second MLP layer

class maxtext.models.qwen3.Qwen3OmniMoeVisionMLP(*args, **kwargs)

Bases: Module

Vision MLP block with GELU activation.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

Config containing model parameters

hidden_size

Hidden dimension size

intermediate_size

Intermediate dimension size

dtype

Data type for computation

weight_dtype

Data type for weights

kernel_init

Initializer for kernel weights

rngs

RNG state for initialization

linear_fc1

First linear layer

linear_fc2

Second linear layer

class maxtext.models.qwen3.Qwen3OmniMoeVisionPatchEmbed(*args, **kwargs)

Bases: Module

3D convolution-based patch embedding for vision inputs.
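Assuming, as is usual for patch embeddings, that the convolution's kernel size equals its stride, each output position sees exactly one non-overlapping temporal-spatial patch. The sketch below, which also assumes a channels-last [N, T, H, W, C] input and omits any bias term, computes the embedding with jax.lax.conv_general_dilated and, for comparison, as a reshape followed by one shared linear map; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

HIGHEST = jax.lax.Precision.HIGHEST


def patch_embed_3d(video, kernel, temporal_patch_size, patch_size):
  """3D convolution with kernel size == stride and no padding.

  video: [N, T, H, W, C]; kernel: [temporal_patch_size, patch_size, patch_size, C, embed_dim].
  Returns [N, T // temporal_patch_size, H // patch_size, W // patch_size, embed_dim].
  """
  strides = (temporal_patch_size, patch_size, patch_size)
  return jax.lax.conv_general_dilated(
      video, kernel, window_strides=strides, padding="VALID",
      dimension_numbers=("NDHWC", "DHWIO", "NDHWC"), precision=HIGHEST)


def patch_embed_as_matmul(video, kernel, temporal_patch_size, patch_size):
  """Same result: cut the input into patches and apply one shared linear map."""
  n, t, h, w, c = video.shape
  tp, p = temporal_patch_size, patch_size
  patches = video.reshape(n, t // tp, tp, h // p, p, w // p, p, c)
  return jnp.einsum("ntahbwci,abcio->nthwo", patches, kernel, precision=HIGHEST)


video = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8, 8, 3))    # [N, T, H, W, C]
kernel = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 4, 3, 5))   # [tps, ps, ps, C, embed_dim]
emb = patch_embed_3d(video, kernel, temporal_patch_size=2, patch_size=4)
```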

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

Config containing model parameters

patch_size

Spatial patch size

temporal_patch_size

Temporal patch size

in_channels

Number of input channels

embed_dim

Embedding dimension

dtype

Data type for computation

weight_dtype

Data type for weights

rngs

RNG state for initialization

proj

Convolution projection layer

class maxtext.models.qwen3.Qwen3OmniMoeVisionAttention(*args, **kwargs)

Bases: Module

Vision attention layer wrapper.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

Config containing model parameters

attn

Underlying attention module

class maxtext.models.qwen3.Qwen3OmniMoeVisionBlock(*args, **kwargs)

Bases: Module

Vision transformer block with attention and MLP.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

Config containing model parameters

ln1

LayerNorm before attention

ln2

LayerNorm before MLP

attn

Attention module

mlp

First MLP layer

mlp_out

Second MLP layer

class maxtext.models.qwen3.Qwen3OmniMoeVisionEncoder(*args, **kwargs)

Bases: Module

Vision encoder with patch embedding, positional embedding, and transformer blocks.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

Config containing model parameters

patch_embed

Patch embedding module

pos_embed_interpolate

Position embedding interpolation module

blocks

List of transformer blocks

merger_list

List of patch mergers, one per entry in deep_idx, applied to the intermediate features extracted from those layers

spatial_merge_size

Size of spatial merging

deep_idx

Indices of layers to extract deep features from
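A sketch of how deep_idx and merger_list plausibly fit together, assuming one merger per entry of deep_idx that is applied to the output of the corresponding block. The function name is hypothetical, and the blocks and mergers in the example are trivial stand-ins rather than real modules.

```python
import jax.numpy as jnp


def run_blocks_with_deep_features(h, blocks, deep_idx, merger_list):
  """Apply the transformer blocks in order, merging selected intermediate features.

  After block number i, if i is listed in deep_idx, the merger at the same
  position in merger_list is applied to that block's output.
  """
  deep_idx = list(deep_idx)
  deep_features = []
  for layer_num, block in enumerate(blocks):
    h = block(h)
    if layer_num in deep_idx:
      deep_features.append(merger_list[deep_idx.index(layer_num)](h))
  return h, deep_features


# Hypothetical stand-ins: each "block" adds 1, each "merger" scales by a constant.
blocks = [lambda t: t + 1.0 for _ in range(4)]
mergers = [lambda t: 10.0 * t, lambda t: 100.0 * t]
h, deep = run_blocks_with_deep_features(jnp.zeros(2), blocks, deep_idx=(1, 3), merger_list=mergers)
```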

class maxtext.models.qwen3.Qwen3OmniMoeVisionProjector(*args, **kwargs)

Bases: Module

Projection layer that converts vision encoder output to model embedding space.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

Config containing model parameters

merger

Patch merger for spatial reduction

maxtext.models.qwen3.qwen3omni_visionencoder_as_linen(config, mesh)

Convert Qwen3OmniMoeVisionEncoder to a Linen module.

Parameters:
  • config (Any)

  • mesh (Mesh)

Return type:

Module

maxtext.models.qwen3.qwen3omni_visionprojector_as_linen(config, mesh)

Convert Qwen3OmniMoeVisionProjector to a Linen module.

Parameters:
  • config (Any)

  • mesh (Mesh)

Return type:

Module

class maxtext.models.qwen3.Qwen3OmniAudioEncoderLayer(*args, **kwargs)

Bases: Module

Transformer encoder layer for the audio model.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.qwen3.Qwen3OmniAudioEncoder(*args, **kwargs)

Bases: Module

Full audio encoder with convolutional layers, positional embeddings, and transformer layers.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config

Config containing model parameters

mesh

JAX device mesh used for sharding

class maxtext.models.qwen3.Qwen3OmniAudioProjector(*args, **kwargs)

Bases: Module

Projection layer that converts audio encoder output to model embedding space.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

maxtext.models.qwen3.qwen3omni_audioencoder_as_linen(config, mesh)

Convert AudioEncoder (convolutions and transformer layers, without the projector) to a Linen module.

Parameters:
  • config (Any)

  • mesh (Mesh)

maxtext.models.qwen3.qwen3omni_audioprojector_as_linen(config, mesh)

Convert AudioProjector to a Linen module.

Parameters:
  • config (Any)

  • mesh (Mesh)