# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization Layers."""
from typing import Any
from flax import linen as nn
from flax import nnx
from flax.linen import initializers as linen_initializers
import jax
from jax import lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from maxtext.common.common_types import Array, DType, ShardMode
from maxtext.layers import nnx_wrappers
from maxtext.layers.initializers import Initializer, variable_to_logically_partitioned
from maxtext.utils import max_logging
from maxtext.utils import max_utils
[docs]
class RMSNorm(nnx.Module):
"""RMS normalization."""
def __init__(
self,
num_features: int,
epsilon: float = 1e-6,
dtype: Any = jnp.float32,
weight_dtype: Any = jnp.float32,
shard_mode: ShardMode = ShardMode.AUTO,
kernel_axes: tuple[None | str, ...] = (),
scale_init: Initializer = nn.initializers.ones,
parameter_memory_host_offload: bool = False,
scale_offset: float = 0.0,
with_scale: bool = True,
*,
rngs: nnx.Rngs,
):
self.num_features = num_features
self.epsilon = epsilon
self.dtype = dtype
self.weight_dtype = weight_dtype
self.shard_mode = shard_mode
self.kernel_axes = kernel_axes
self.scale_init = scale_init
self.parameter_memory_host_offload = parameter_memory_host_offload
self.scale_offset = scale_offset
self.with_scale = with_scale
if self.with_scale:
self.scale = nnx.Param(
scale_init(rngs.params(), (num_features,), weight_dtype),
sharding=kernel_axes,
)
else:
self.scale = None
def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> jnp.ndarray:
"""Applies layer normalization on the input."""
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
# out_sharding must be None in auto shard mode
if self.shard_mode != ShardMode.EXPLICIT:
out_sharding = None
if not self.with_scale:
if out_sharding is not None:
y = jax.lax.with_sharding_constraint(y, out_sharding)
return y
scale = self.scale.get_value()
# Move scale to device if parameter offloading is enabled
if self.parameter_memory_host_offload:
max_logging.log("normalizations.py: Moving scale parameter to device")
scale = jax.device_put(scale, max_utils.device_space())
scale = jnp.asarray(scale, self.dtype)
effective_scale = scale + self.scale_offset
return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding)
[docs]
class GlobalRMSNorm(RMSNorm):
"""
Applies RMSNorm over the last two dimensions (Heads * HeadDim).
Used for Olmo3 which normalizes across all heads combined.
"""
def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> jnp.ndarray:
# x shape: [..., Heads, HeadDim]
input_shape = x.shape
# Flatten the last two dimensions: [..., Heads * HeadDim]
# We use -2 and -1 to ensure we capture the last two dims regardless of rank
flattened_shape = input_shape[:-2] + (input_shape[-2] * input_shape[-1],)
x_flat = x.reshape(flattened_shape)
# Apply standard RMSNorm (which normalizes over the last axis)
y_flat = super().__call__(x_flat, out_sharding)
# Reshape back to [..., Heads, HeadDim]
return y_flat.reshape(input_shape)
[docs]
def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
"""
Used for input and post attention layernorms
in Qwen3NextDecoderLayer.
This normalization layer is specific to Qwen3-Next. Key characteristics:
1. The learnable scale parameter `scale` is initialized to ZEROS.
2. The scale is applied as `(1.0 + self.scale)`, making the initial scale effectively 1.0.
This matches the PyTorch implementation of Qwen3NextRMSNorm.
"""
return nnx.data(
RMSNorm(
num_features=num_features,
epsilon=eps,
dtype=dtype,
weight_dtype=weight_dtype,
scale_init=linen_initializers.zeros,
scale_offset=1.0,
rngs=rngs,
)
)
[docs]
class Qwen3NextRMSNormGated(nnx.Module):
"""
This applies RMS Normalization and then a gated activation function (SiLU).
This is used within the Qwen3NextGatedDeltaNet.
The normalization is performed by an internal `RMSNorm` instance (`self.rms_norm`),
which has its own learnable `scale` parameter, initialized to ONES.
Attributes:
num_features: The number of features in the input.
eps: A small epsilon value to prevent division by zero in RMSNorm.
dtype: The datatype of the computation.
weight_dtype: The datatype of the internal RMSNorm scale.
"""
def __init__(self, num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
self.num_features = num_features
self.eps = eps
self.dtype = dtype
self.weight_dtype = weight_dtype
self.rms_norm = nnx.data(
RMSNorm(
num_features=num_features,
epsilon=eps,
dtype=dtype,
weight_dtype=weight_dtype,
scale_init=nnx.initializers.ones,
rngs=rngs,
)
)
def __call__(self, hidden_states: Array, gate: Array) -> Array:
"""
Applies RMSNorm and then a SiLU gate.
Args:
hidden_states: The input array to be normalized (o). Shape: (..., F)
gate: The gating array for the activation (z). Shape: (..., F)
where F is num_features.
Returns:
The normalized and gated output array. Shape: (..., F)
"""
normalized_states = self.rms_norm(hidden_states)
# Gated Activation using SiLU (Sigmoid-weighted Linear Unit)
gated_states = normalized_states * jax.nn.silu(gate.astype(jnp.float32))
return gated_states.astype(self.dtype)
[docs]
def rms_norm(
num_features: int,
epsilon: float = 1e-6,
dtype: Any = jnp.float32,
weight_dtype: Any = jnp.float32,
shard_mode: ShardMode = ShardMode.AUTO,
kernel_axes: tuple[None | str, ...] = (),
scale_init: Initializer = nn.initializers.ones,
name: None | str = None,
parameter_memory_host_offload: bool = False,
with_scale: bool = True,
):
"""Creates a RMSNorm module."""
module = nnx_wrappers.to_linen(
RMSNorm,
num_features=num_features,
epsilon=epsilon,
dtype=dtype,
weight_dtype=weight_dtype,
shard_mode=shard_mode,
kernel_axes=kernel_axes,
scale_init=scale_init,
parameter_memory_host_offload=parameter_memory_host_offload,
with_scale=with_scale,
name=name,
metadata_fn=variable_to_logically_partitioned,
)
return module
[docs]
def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array:
"""L2 normalization function. Normalizes a vector to have a length of 1.
Args:
x: Input array.
dim: The axis or axes along which to normalize. Defaults to the last axis.
eps: Small epsilon to prevent division by zero.
Returns:
L2 normalized array with the same shape as x.
"""
inv_norm = jax.lax.rsqrt((x * x).sum(axis=dim, keepdims=True) + jnp.array(eps, dtype=x.dtype))
return x * inv_norm
Qwen3NextRMSNormLinen = nnx_wrappers.to_linen_class(
RMSNorm,
base_metadata_fn=variable_to_logically_partitioned,
scale_init=linen_initializers.zeros,
scale_offset=1.0,
)