maxtext.layers.mhc module

DeepSeek Manifold-Constrained Hyper-Connections (mHC) layer.

maxtext.layers.mhc.get_functions(expansion_rate)

Creates functions to broadcast a single feature stream into multiple parallel paths (expand) and aggregate them back (reduce).

Parameters:

expansion_rate (int)
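
The expand/reduce pair described above can be illustrated with a minimal sketch. The helper name `make_expand_reduce`, the position of the stream axis, and the use of summation for aggregation are all assumptions made for illustration; the actual behavior of `get_functions` may differ.

```python
import jax.numpy as jnp


def make_expand_reduce(expansion_rate: int, stream_axis: int = 2):
    """Hypothetical stand-in for get_functions; not the MaxText implementation."""

    def expand(x):
        # Insert a new stream axis and copy the features along it.
        x = jnp.expand_dims(x, axis=stream_axis)
        return jnp.repeat(x, expansion_rate, axis=stream_axis)

    def reduce(x):
        # Aggregate the parallel streams by summation (an assumption;
        # a mean or a learned weighting would be equally plausible).
        return jnp.sum(x, axis=stream_axis)

    return expand, reduce


expand, reduce = make_expand_reduce(expansion_rate=4)
x = jnp.ones((2, 3, 8))    # (batch, length, dim)
streams = expand(x)        # (batch, length, streams, dim)
merged = reduce(streams)   # back to (batch, length, dim)
```

Returning the two closures together keeps the stream count consistent between the expand and reduce steps.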

maxtext.layers.mhc.sinkhorn(t, iters=20)

Applies Sinkhorn normalization to the matrix t for iters iterations, so that its rows and columns each sum to (approximately) 1.
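
As a rough illustration of Sinkhorn normalization, the sketch below alternates row and column normalization on a strictly positive matrix. The function name `sinkhorn_sketch` is hypothetical, and exponentiating the input first is an assumption; the MaxText `sinkhorn` function may treat its input differently.

```python
import jax.numpy as jnp


def sinkhorn_sketch(t, iters=20):
    """Illustrative Sinkhorn normalization; not the MaxText implementation."""
    # Assumption: t holds unconstrained scores, so exponentiate to make
    # every entry strictly positive (shifted by the max for stability).
    p = jnp.exp(t - jnp.max(t, axis=(-2, -1), keepdims=True))
    for _ in range(iters):
        p = p / jnp.sum(p, axis=-1, keepdims=True)  # make each row sum to 1
        p = p / jnp.sum(p, axis=-2, keepdims=True)  # make each column sum to 1
    return p


t = jnp.array([[0.0, 1.0, 2.0],
               [2.0, 0.5, 1.0],
               [1.0, 2.0, 0.0]])
p = sinkhorn_sketch(t)
row_sums = jnp.sum(p, axis=-1)
col_sums = jnp.sum(p, axis=-2)
```

Because the loop ends on a column step, the column sums are exact up to floating-point error, while the row sums only approach 1 as `iters` grows.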

class maxtext.layers.mhc.ManifoldConstrainedHyperConnections(*args, **kwargs)

Bases: Module

Implements Manifold-Constrained Hyper-Connections (mHC).

Reference: https://arxiv.org/pdf/2512.24880

Parameters:
  • config – Configuration object containing hyperparameters.

  • dim – The feature dimensionality.

  • mesh – The hardware mesh for sharding.

  • rngs – The NNX random number generators.

  • args (Any)

  • kwargs (Any)

Return type:

Any

res_mapping(x)

Helper function for residual mapping.

Parameters:

x (Array)

mapping(x, alpha_scale, alpha, beta, scale)

Helper function for both pre and post mappings.

Parameters:
  • x (Array)

  • alpha_scale (Array)

  • alpha (Array)

  • beta (Array)

  • scale (int)