maxtext.layers.normalizations module
Normalization Layers.
- class maxtext.layers.normalizations.RMSNorm(*args, **kwargs)
Bases: Module
RMS normalization.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
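For orientation, RMS normalization can be sketched in a few lines of JAX. This is a minimal illustration, not the MaxText implementation: the function name rms_norm_sketch, its arguments, and the placement of eps inside the square root are assumptions made for the example.

```python
import jax.numpy as jnp


def rms_norm_sketch(x, scale, eps=1e-6):
  """Illustrative RMS normalization over the last axis (hypothetical helper)."""
  # Mean of squares over the feature axis, kept for broadcasting.
  mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  # Divide by the root mean square; eps guards against division by zero.
  normed = x / jnp.sqrt(mean_sq + eps)
  # Apply the learnable per-feature scale.
  return normed * scale


x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
y = rms_norm_sketch(x, scale=jnp.ones(4))
```

With a scale of ones, the mean of the squared outputs along the feature axis is approximately 1.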
- class maxtext.layers.normalizations.GlobalRMSNorm(*args, **kwargs)
Bases: RMSNorm
Applies RMSNorm over the last two dimensions (Heads * HeadDim). Used for Olmo3, which normalizes across all heads combined.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
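The difference from per-head normalization is only in which axes the statistics are taken over. The sketch below is an assumption-laden illustration, not the MaxText code: global_rms_norm_sketch is a hypothetical name, and the scale is assumed to have shape [heads, head_dim].

```python
import jax.numpy as jnp


def global_rms_norm_sketch(x, scale, eps=1e-6):
  """Illustrative RMS normalization over the last two axes (hypothetical helper)."""
  # x is assumed to be [..., heads, head_dim]; one statistic covers all heads.
  mean_sq = jnp.mean(jnp.square(x), axis=(-2, -1), keepdims=True)
  return x / jnp.sqrt(mean_sq + eps) * scale


x = jnp.arange(1.0, 13.0).reshape(1, 3, 4)  # [batch, heads, head_dim]
y = global_rms_norm_sketch(x, scale=jnp.ones((3, 4)))
```

Because the statistic is shared, the combined output has a mean square near 1, while an individual head generally does not.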
- maxtext.layers.normalizations.Qwen3NextRMSNorm(num_features, eps, dtype, weight_dtype, *, rngs)
Used for the input and post-attention layernorms in Qwen3NextDecoderLayer.
This normalization layer is specific to Qwen3-Next. Key characteristics:
1. The learnable scale parameter (scale) is initialized to zeros.
2. The scale is applied as (1.0 + self.scale), making the initial scale effectively 1.0.
This matches the PyTorch implementation of Qwen3NextRMSNorm.
- Parameters:
num_features (int)
eps (float)
dtype (dtype)
weight_dtype (dtype)
rngs (Rngs)
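The zero-initialized, offset-by-one scale described above can be sketched as follows. This is a minimal illustration under stated assumptions: qwen3_next_rms_norm_sketch is a hypothetical standalone function, not the MaxText module, and the normalization axis and eps placement are assumed.

```python
import jax.numpy as jnp


def qwen3_next_rms_norm_sketch(x, scale, eps=1e-6):
  """Illustrative RMS norm with a (1 + scale) multiplier (hypothetical helper)."""
  mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  normed = x / jnp.sqrt(mean_sq + eps)
  # The stored parameter starts at zero, so the effective multiplier starts at 1.
  return normed * (1.0 + scale)


x = jnp.array([[3.0, 4.0]])
zero_scale = jnp.zeros(2)  # mirrors the zero initialization described above
y_init = qwen3_next_rms_norm_sketch(x, zero_scale)

# Reference: the same normalization with no scale applied at all.
plain = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-6)
```

At initialization, y_init equals the unscaled normalization, which is the same starting behavior as a conventional scale initialized to ones.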
- class maxtext.layers.normalizations.Qwen3NextRMSNormGated(*args, **kwargs)
Bases: Module
Applies RMS normalization followed by a gated activation function (SiLU). Used within Qwen3NextGatedDeltaNet.
The normalization is performed by an internal RMSNorm instance (self.rms_norm), which has its own learnable scale parameter, initialized to ones.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- num_features
The number of features in the input.
- eps
A small epsilon value to prevent division by zero in RMSNorm.
- dtype
The datatype of the computation.
- weight_dtype
The datatype of the internal RMSNorm scale.
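One plausible reading of "RMS normalization and then a gated activation" is to multiply the normalized input by SiLU applied to a separate gate tensor. The sketch below assumes exactly that; rms_norm_gated_sketch, its gate argument, and the requirement that gate matches the shape of x are assumptions for illustration and may differ from the MaxText module.

```python
import jax
import jax.numpy as jnp


def rms_norm_gated_sketch(x, gate, scale, eps=1e-6):
  """Illustrative RMS norm followed by SiLU gating (hypothetical helper)."""
  mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  # Normalize and apply a scale that would be initialized to ones.
  normed = x / jnp.sqrt(mean_sq + eps) * scale
  # Assumed gating: elementwise product with SiLU(gate).
  return normed * jax.nn.silu(gate)


x = jnp.array([[1.0, -2.0, 3.0]])
closed = rms_norm_gated_sketch(x, gate=jnp.zeros_like(x), scale=jnp.ones(3))
```

Since SiLU(0) is 0, a zero gate suppresses the output entirely, while large positive gate values pass the normalized signal through scaled by roughly the gate value.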
- maxtext.layers.normalizations.rms_norm(num_features, epsilon=1e-06, dtype=<class 'jax.numpy.float32'>, weight_dtype=<class 'jax.numpy.float32'>, shard_mode=ShardMode.AUTO, kernel_axes=(), scale_init=<function ones>, name=None, parameter_memory_host_offload=False, with_scale=True)
Creates an RMSNorm module.
- Parameters:
num_features (int)
epsilon (float)
dtype (Any)
weight_dtype (Any)
shard_mode (ShardMode)
kernel_axes (tuple[None | str, ...])
scale_init (Callable[[Array, Sequence[int], dtype], Array])
name (None | str)
parameter_memory_host_offload (bool)
with_scale (bool)
- maxtext.layers.normalizations.l2norm(x, dim=-1, eps=1e-06)
L2 normalization function. Normalizes a vector to have a length of 1.
- Parameters:
x (Array) – Input array.
dim (int) – The axis or axes along which to normalize. Defaults to the last axis.
eps (float) – Small epsilon to prevent division by zero.
- Returns:
L2 normalized array with the same shape as x.
- Return type:
Array
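As a rough illustration of the behavior described, L2 normalization along an axis can be sketched as below. The name l2norm_sketch is hypothetical, and adding eps inside the square root is an assumption; the library function may place it differently.

```python
import jax.numpy as jnp


def l2norm_sketch(x, dim=-1, eps=1e-6):
  """Illustrative L2 normalization along `dim` (hypothetical helper)."""
  # Sum of squares along the chosen axis, kept for broadcasting.
  sq_sum = jnp.sum(jnp.square(x), axis=dim, keepdims=True)
  # eps keeps the division finite for all-zero vectors.
  return x / jnp.sqrt(sq_sum + eps)


v = jnp.array([[3.0, 4.0], [0.0, 0.0]])
u = l2norm_sketch(v)
```

The first row becomes approximately [0.6, 0.8], a unit-length vector, and the all-zero row stays zero instead of producing NaN.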