maxtext.layers.normalizations module

Normalization Layers.

class maxtext.layers.normalizations.RMSNorm(*args, **kwargs)

Bases: Module

RMS normalization.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any
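As a rough illustration of what RMS normalization computes, here is a minimal sketch in JAX. It is not the MaxText implementation: the function name rms_norm_sketch, the float32 upcast, and the choice to normalize over the last axis are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def rms_norm_sketch(x, scale, epsilon=1e-6):
    """Hypothetical helper: RMS-normalize x over its last axis, then apply scale."""
    # Assumption: accumulate the mean of squares in float32 for stability.
    x32 = x.astype(jnp.float32)
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    # Multiply by 1 / sqrt(mean_sq + epsilon); epsilon avoids division by zero.
    normed = x32 * jax.lax.rsqrt(mean_sq + epsilon)
    return (normed * scale).astype(x.dtype)


x = jnp.array([[3.0, 4.0], [0.0, 0.0]])
scale = jnp.ones((2,))  # a scale of ones leaves the normalized values unchanged
y = rms_norm_sketch(x, scale)
```

With a scale of ones, each non-zero row of y has a mean square close to 1, and the all-zero row stays zero because epsilon keeps the denominator finite.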

class maxtext.layers.normalizations.GlobalRMSNorm(*args, **kwargs)

Bases: RMSNorm

Applies RMSNorm jointly over the last two dimensions (Heads * HeadDim). Used for Olmo3, which normalizes across all heads combined.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any
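The idea of normalizing across all heads combined can be sketched as follows. This is an illustration only, not the MaxText code: the name global_rms_norm_sketch, the [batch, seq, heads, head_dim] input layout, and the flattened scale of shape [heads * head_dim] are assumptions.

```python
import jax
import jax.numpy as jnp


def global_rms_norm_sketch(x, scale, epsilon=1e-6):
    """Hypothetical helper: RMS-normalize over the last two axes together."""
    # Assumed layout: x is [..., heads, head_dim]; scale is [heads * head_dim].
    *lead, heads, head_dim = x.shape
    flat = x.reshape(*lead, heads * head_dim)
    # One statistic per position, shared by every head at that position.
    mean_sq = jnp.mean(jnp.square(flat), axis=-1, keepdims=True)
    normed = flat * jax.lax.rsqrt(mean_sq + epsilon) * scale
    return normed.reshape(x.shape)


x = jnp.arange(24, dtype=jnp.float32).reshape(1, 2, 3, 4)  # [batch, seq, heads, head_dim]
out = global_rms_norm_sketch(x, jnp.ones((12,)))
```

In contrast to a per-head norm, the mean square here is taken over heads * head_dim values at once, so individual heads can end up with different magnitudes after normalization.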

maxtext.layers.normalizations.Qwen3NextRMSNorm(num_features, eps, dtype, weight_dtype, *, rngs)

Used for input and post attention layernorms in Qwen3NextDecoderLayer.

This normalization layer is specific to Qwen3-Next. Key characteristics:

  1. The learnable scale parameter, scale, is initialized to ZEROS.

  2. The scale is applied as (1.0 + self.scale), making the initial scale effectively 1.0.

This matches the PyTorch implementation of Qwen3NextRMSNorm.

Parameters:
  • num_features (int)

  • eps (float)

  • dtype (dtype)

  • weight_dtype (dtype)

  • rngs (Rngs)
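The zero-initialized scale applied as (1.0 + scale) can be illustrated with a short JAX sketch. It is a simplified stand-in rather than the MaxText layer: the function name zero_centered_rms_norm_sketch and the plain-array scale are assumptions for the example.

```python
import jax
import jax.numpy as jnp


def zero_centered_rms_norm_sketch(x, scale, eps=1e-6):
    """Hypothetical helper: RMS norm whose scale is applied as (1.0 + scale)."""
    mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    normed = x * jax.lax.rsqrt(mean_sq + eps)
    # With scale initialized to zeros, (1.0 + scale) equals 1.0 at the start.
    return normed * (1.0 + scale)


x = jnp.array([[1.0, -2.0, 3.0, -4.0]])
scale_at_init = jnp.zeros((4,))
y_init = zero_centered_rms_norm_sketch(x, scale_at_init)

# Reference: the same normalization with no learnable scale at all.
plain = x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-6)
```

At initialization y_init matches the unscaled normalization, which is the behavior the (1.0 + scale) form is meant to give.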

class maxtext.layers.normalizations.Qwen3NextRMSNormGated(*args, **kwargs)

Bases: Module

Applies RMS normalization, then a gated activation function (SiLU). Used within Qwen3NextGatedDeltaNet.

The normalization is performed by an internal RMSNorm instance (self.rms_norm), which has its own learnable scale parameter, initialized to ONES.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

num_features

The number of features in the input.

eps

A small epsilon value to prevent division by zero in RMSNorm.

dtype

The datatype of the computation.

weight_dtype

The datatype of the internal RMSNorm scale.
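A minimal sketch of "normalize, then gate with SiLU" is shown below. It is not the MaxText module: the function name gated_rms_norm_sketch is hypothetical, and the separate gate input, multiplied in after the norm as silu(gate), is an assumption about how the gating is wired.

```python
import jax
import jax.numpy as jnp


def gated_rms_norm_sketch(x, gate, scale, eps=1e-6):
    """Hypothetical helper: RMS-normalize x, then multiply by silu(gate)."""
    mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    # Internal RMS norm with its own scale (ones at initialization).
    normed = x * jax.lax.rsqrt(mean_sq + eps) * scale
    # Assumption: the gate is a tensor of the same shape as x.
    return normed * jax.nn.silu(gate)


x = jnp.ones((2, 4))
scale = jnp.ones((4,))
out_closed = gated_rms_norm_sketch(x, jnp.zeros((2, 4)), scale)      # silu(0) == 0
out_open = gated_rms_norm_sketch(x, jnp.full((2, 4), 2.0), scale)
```

Because silu(0) is 0, a zero gate suppresses the normalized signal entirely, while positive gate values let it through.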

maxtext.layers.normalizations.rms_norm(num_features, epsilon=1e-06, dtype=<class 'jax.numpy.float32'>, weight_dtype=<class 'jax.numpy.float32'>, shard_mode=ShardMode.AUTO, kernel_axes=(), scale_init=<function ones>, name=None, parameter_memory_host_offload=False, with_scale=True)

Creates an RMSNorm module.

Parameters:
  • num_features (int)

  • epsilon (float)

  • dtype (Any)

  • weight_dtype (Any)

  • shard_mode (ShardMode)

  • kernel_axes (tuple[None | str, ...])

  • scale_init (Callable[[Array, Sequence[int], dtype], Array])

  • name (None | str)

  • parameter_memory_host_offload (bool)

  • with_scale (bool)

maxtext.layers.normalizations.l2norm(x, dim=-1, eps=1e-06)

L2 normalization function. Normalizes a vector to have a length of 1.

Parameters:
  • x (Array) – Input array.

  • dim (int) – The axis or axes along which to normalize. Defaults to the last axis.

  • eps (float) – Small epsilon to prevent division by zero.

Returns:

L2 normalized array with the same shape as x.

Return type:

Array
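For illustration, L2 normalization along an axis can be sketched in JAX as below. The name l2norm_sketch is hypothetical, and adding eps inside the square root is an assumption; the actual function may place eps differently.

```python
import jax
import jax.numpy as jnp


def l2norm_sketch(x, dim=-1, eps=1e-6):
    """Hypothetical helper: scale x to roughly unit L2 length along dim."""
    sum_sq = jnp.sum(jnp.square(x), axis=dim, keepdims=True)
    # Assumption: eps is added to the sum of squares before the inverse sqrt.
    return x * jax.lax.rsqrt(sum_sq + eps)


v = jnp.array([[3.0, 4.0]])
unit = l2norm_sketch(v)
length = jnp.sqrt(jnp.sum(unit ** 2, axis=-1))
```

For the vector (3, 4), whose length is 5, the result is approximately (0.6, 0.8), and the output keeps the shape of the input.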