# Source code for maxtext.layers.mhc

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""

from typing import Callable
from flax import nnx
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Array, Config
from maxtext.common.common_types import HyperConnectionType
from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init
from maxtext.layers.normalizations import RMSNorm


def get_functions(expansion_rate: int):
  """Builds the stream `expand` / `reduce` helper pair.

  Args:
    expansion_rate: Number of parallel streams created by `expand`.

  Returns:
    A tuple `(expand, reduce)`:
      * `expand` inserts a new axis at position 2 and repeats the input
        `expansion_rate` times along it, keeping the input dtype.
      * `reduce` sums over axis 2, accumulating in the input dtype.
  """

  def expand(x: Array):
    # (batch, length, dim) -> (batch, length, streams, dim)
    with_stream_axis = jnp.expand_dims(x, axis=2)
    replicated = jnp.repeat(with_stream_axis, expansion_rate, axis=2)
    return replicated.astype(x.dtype)

  def reduce(x: Array):
    # (batch, length, streams, dim) -> (batch, length, dim)
    collapsed = jnp.sum(x, axis=2, dtype=x.dtype)
    return collapsed

  return expand, reduce


def sinkhorn(t, iters=20):
  """Applies Sinkhorn normalization over the last two axes of `t`.

  The input is first passed through a softmax over axis -2, which makes every
  entry positive and every column sum to 1. Then `iters` rounds of alternating
  L1 normalization are applied: rows (axis -1) first, columns (axis -2) second.

  Args:
    t: Array whose last two axes hold the matrices to normalize.
    iters: Number of row-then-column normalization rounds.

  Returns:
    An array with the same shape and dtype as `t`.
  """
  source_dtype = t.dtype

  # All arithmetic runs in float32 for numerical stability; the result is cast
  # back to the caller's dtype at the very end.
  # The softmax is equivalent to exp(t) / sum(exp(t), axis=-2).
  work = jax.nn.softmax(t.astype(jnp.float32), axis=-2)

  def _l1_normalize(mat, axis):
    # Divide by the sum along `axis`, clipping the denominator away from zero.
    total = jnp.sum(mat, axis=axis, keepdims=True)
    return mat / jnp.clip(total, min=1e-12)

  def _one_round(_, mat):
    rows_normalized = _l1_normalize(mat, axis=-1)
    return _l1_normalize(rows_normalized, axis=-2)

  # lax.fori_loop keeps the iteration JIT-friendly.
  work = jax.lax.fori_loop(0, iters, _one_round, work)
  return work.astype(source_dtype)


class ManifoldConstrainedHyperConnections(nnx.Module):
  """Implements Manifold-Constrained Hyper-Connections (mHC).

  Reference: https://arxiv.org/pdf/2512.24880

  The module wraps a branch (attention or MLP) that operates on a single
  stream of width `dim`, while the residual state carries `k` parallel streams
  (`k = config.mhc_expansion_rate`). Three input-dependent mappings are
  computed from the RMS-normalized, flattened streams:
    * pre mapping:  sigmoid weights that mix the `k` streams into the branch
      input.
    * post mapping: `2 * sigmoid` weights that spread the branch output back
      over the `k` streams.
    * res mapping:  a `(k, k)` matrix passed through `sinkhorn`, used to mix
      the residual streams.

  Args:
    config: Configuration object containing hyperparameters.
    dim: The feature dimensionality.
    mesh: The hardware mesh for sharding.
    rngs: Random number generation in NNX.
  """

  def __init__(
      self,
      config: Config,
      dim: int,
      mesh: Mesh,
      rngs: nnx.Rngs,
  ):
    self.config = config
    self.sinkhorn_iterations = config.sinkhorn_iterations
    self.k = config.mhc_expansion_rate
    self.dim = dim
    self.rngs = rngs
    self.mesh = mesh
    self.dtype = self.config.dtype
    self.weight_dtype = self.config.weight_dtype
    self.matmul_precision = jax.lax.Precision(self.config.matmul_precision)

    # NOTE: parameters below are created in a fixed order because every
    # `self.rngs.params()` call consumes a fresh key; reordering them would
    # change the initial values.

    # Norm layer over the flattened (k * dim) stream features.
    self.mhc_norm = RMSNorm(
        num_features=self.k * self.dim,
        dtype=self.config.dtype,
        weight_dtype=self.weight_dtype,
        kernel_axes=("norm",),
        epsilon=self.config.normalization_layer_epsilon,
        rngs=self.rngs,
    )

    # Scalars
    self.res_alpha_scale = nnx.Param(
        default_scalar_init(self.rngs.params(), (1,), self.weight_dtype),
        out_sharding=(None,),
    )
    self.pre_alpha_scale = nnx.Param(
        default_scalar_init(self.rngs.params(), (1,), self.weight_dtype),
        out_sharding=(None,),
    )
    self.post_alpha_scale = nnx.Param(
        default_scalar_init(self.rngs.params(), (1,), self.weight_dtype),
        out_sharding=(None,),
    )

    # Weight matrices
    scale_init = nd_dense_init(1.0, "fan_in", "normal")
    in_axis = 0
    out_axis = 1
    weight_sharding_axis_name = ("activation_embed", None)
    self.res_alpha = nnx.Param(
        scale_init(
            self.rngs.params(),
            (self.k * self.dim, self.k * self.k),
            self.weight_dtype,
            in_axis=in_axis,
            out_axis=out_axis,
        ),
        out_sharding=weight_sharding_axis_name,
    )
    self.pre_alpha = nnx.Param(
        scale_init(
            self.rngs.params(),
            (self.k * self.dim, self.k),
            self.weight_dtype,
            in_axis=in_axis,
            out_axis=out_axis,
        ),
        out_sharding=weight_sharding_axis_name,
    )
    self.post_alpha = nnx.Param(
        scale_init(
            self.rngs.params(),
            (self.k * self.dim, self.k),
            self.weight_dtype,
            in_axis=in_axis,
            out_axis=out_axis,
        ),
        out_sharding=weight_sharding_axis_name,
    )

    # Biases
    self.res_beta = nnx.Param(
        default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype),
        out_sharding=(None, None),
    )
    self.pre_beta = nnx.Param(
        default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
        out_sharding=(None,),
    )
    self.post_beta = nnx.Param(
        default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
        out_sharding=(None,),
    )

  def res_mapping(self, x: Array):
    """Helper function for residual mapping.

    Args:
      x: Normalized, flattened streams of shape `(batch, seq, k * dim)`.

    Returns:
      Array of shape `(batch, seq, k, k)` produced by running `sinkhorn` for
      `self.sinkhorn_iterations` rounds on the scaled and biased projection.
    """
    # In MaxText, we match weight precision to activations before Matmul
    res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
    res_beta = jnp.asarray(self.res_beta[...], self.dtype)
    res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
    # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
    h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
    b, s, _ = h_res.shape
    h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
    intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
    output = sinkhorn(intermediate, self.sinkhorn_iterations)
    return output

  def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: float):
    """Helper function for both pre and post mappings.

    Args:
      x: Normalized, flattened streams of shape `(batch, seq, k * dim)`.
      alpha_scale: Scalar gate multiplying the projection.
      alpha: Projection weights of shape `(k * dim, k)`.
      beta: Bias of shape `(k,)`.
      scale: Multiplier applied to the sigmoid output, so every entry of the
        result lies in `(0, scale)`.

    Returns:
      Array of shape `(batch, seq, k)`.
    """
    # In MaxText, we match weight precision to activations before Matmul
    alpha = jnp.asarray(alpha, self.dtype)
    beta = jnp.asarray(beta, self.dtype)
    alpha_scale = jnp.asarray(alpha_scale, self.dtype)
    # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k)
    h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.matmul_precision)
    intermediate = alpha_scale * h + beta[None, None, :]
    output = scale * jax.nn.sigmoid(intermediate)
    return output

  def __call__(
      self,
      norm_fn: Callable,
      branch_fn: Callable,
      x: Array,
      mhc_type: HyperConnectionType,
      **kwargs,
  ) -> tuple[Array, dict]:
    """Applying manifold-constrained hyper connection based on callable function.

    Args:
      norm_fn: The pre-normalization function applied to the mixed branch
        input of shape `(batch, seq, emb)`.
      branch_fn: The function to be wrapped by the hyper-connection. It is
        called with `inputs_q` / `inputs_kv` for `ATTENTION` and with `inputs`
        for `MLP_DENSE` / `MLP_MOE`.
      x: Input tensor of shape `(batch, seq, expansion_rate, emb)`.
      mhc_type: The variant of the connection to apply.
      **kwargs: Additional context passed to the branch function.

    Returns:
      A tuple `(output, metadata)`. `output` has the same shape as `x`.
      `metadata` is empty except for `MLP_MOE`, where it carries
      `"load_balance_loss"` and `"moe_bias_updates"` as returned by
      `branch_fn`.

    Raises:
      ValueError: If `x` is not rank 4, or if `mhc_type` is not supported.
    """
    # x shape: [batch, seq, expansion_rate, emb]
    if len(x.shape) != 4:
      raise ValueError(f"Expected x of shape [batch, seq, expansion_rate, emb], got shape {tuple(x.shape)}")
    b, s, k, d = x.shape

    # 1. Flatten the tensor, and RMS normalization
    norm_x = self.mhc_norm(jnp.reshape(x, (b, s, k * d)))

    # 2. Pre mapping
    pre_mapping = self.mapping(
        norm_x,
        self.pre_alpha_scale[...],
        self.pre_alpha[...],
        self.pre_beta[...],
        1.0,
    )
    layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.matmul_precision)

    # 3. Pre-norm
    layer_input = norm_fn(layer_input)

    # 4. Attention or MLP
    metadata = {}
    if mhc_type == HyperConnectionType.ATTENTION:
      # The second element of the attention result is not used here.
      layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs)
    elif mhc_type == HyperConnectionType.MLP_DENSE:
      layer_out = branch_fn(inputs=layer_input, **kwargs)
    elif mhc_type == HyperConnectionType.MLP_MOE:
      layer_out, load_balance_loss, moe_bias_updates = branch_fn(inputs=layer_input, **kwargs)
      metadata["load_balance_loss"] = load_balance_loss
      metadata["moe_bias_updates"] = moe_bias_updates
    else:
      raise ValueError(f"Unsupported type: {mhc_type}")

    # 5. Post mapping
    post_mapping = self.mapping(
        norm_x,
        self.post_alpha_scale[...],
        self.post_alpha[...],
        self.post_beta[...],
        2.0,
    )
    post_out = jnp.einsum(
        "bsd,bsk -> bskd",
        layer_out,
        post_mapping,
        precision=self.matmul_precision,
    )

    # 6. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb]
    res_mapping = self.res_mapping(norm_x)
    res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.matmul_precision)
    return res_out + post_out, metadata