maxtext.layers.initializers module


Initializers.

maxtext.layers.initializers.nd_dense_init(scale, mode, distribution)

Creates a variance-scaling initializer with dynamic in/out axes.

This function is a factory that returns an initializer function. The returned function wraps jax.nn.initializers.variance_scaling so that in_axis and out_axis are supplied when the initializer is called, rather than when it is created.
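For reference, the variance targeted by jax.nn.initializers.variance_scaling (the function the returned initializer wraps) can be written as below. This restates JAX's documented behavior with batch axes omitted; the symbols n_in, n_out, r, and σ are notation introduced here, not names used by MaxText.

```latex
\begin{aligned}
n_{\text{in}} &= \prod_{i \in \text{in\_axis}} \text{shape}_i, \qquad
n_{\text{out}} = \prod_{j \in \text{out\_axis}} \text{shape}_j, \qquad
r = \frac{\prod_k \text{shape}_k}{n_{\text{in}}\, n_{\text{out}}} \\
\text{fan\_in} &= n_{\text{in}}\, r, \qquad
\text{fan\_out} = n_{\text{out}}\, r, \qquad
\text{fan\_avg} = \tfrac{1}{2}\left(\text{fan\_in} + \text{fan\_out}\right) \\
\sigma^2 &= \frac{\text{scale}}{n}, \qquad n \in \{\text{fan\_in}, \text{fan\_out}, \text{fan\_avg}\} \text{ as selected by mode} \\
\text{normal:}\quad W &\sim \mathcal{N}(0, \sigma^2) \\
\text{uniform:}\quad W &\sim \mathcal{U}\!\left(-\sqrt{3\sigma^2},\ \sqrt{3\sigma^2}\right) \\
\text{truncated\_normal:}\quad W &= \frac{\sigma}{0.8796\ldots}\, Z, \qquad Z \sim \mathcal{N}(0, 1) \text{ truncated to } [-2, 2]
\end{aligned}
```

The constant 0.8796… is the standard deviation of a standard normal truncated to [-2, 2], so all three distributions end up with variance σ². Because in_axis and out_axis may each be a sequence of axes, a kernel of shape (embed, heads, head_dim) with in_axis=0 and out_axis=(1, 2) has r = 1 and therefore fan_in = embed.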

Parameters:
  • scale – The scaling factor for the variance.

  • mode – The mode for variance scaling (‘fan_in’, ‘fan_out’, ‘fan_avg’).

  • distribution – The distribution to sample from (‘truncated_normal’, ‘normal’, or ‘uniform’, the values accepted by jax.nn.initializers.variance_scaling).

Returns:

A function that takes a PRNG key, shape, dtype, in_axis, and out_axis, and returns an initialized array.
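A minimal sketch of the factory pattern described above, assuming only that the returned function forwards its arguments to jax.nn.initializers.variance_scaling. The name nd_dense_init_sketch is hypothetical, and this is an illustration of the described behavior, not MaxText's source.

```python
import jax
import jax.numpy as jnp


def nd_dense_init_sketch(scale, mode, distribution):
  """Hypothetical re-creation of the behavior documented for nd_dense_init."""

  def init_fn(key, shape, dtype, in_axis, out_axis):
    # The JAX initializer is built only now, once the fan-in/fan-out axes are known.
    base_init = jax.nn.initializers.variance_scaling(
        scale, mode, distribution, in_axis=in_axis, out_axis=out_axis
    )
    return base_init(key, shape, dtype)

  return init_fn


# Usage: an attention-style kernel of shape (embed, heads, head_dim).
# in_axis=0 makes fan_in = 512; out_axis=(1, 2) treats heads * head_dim as the output.
init = nd_dense_init_sketch(1.0, "fan_in", "normal")
kernel = init(jax.random.PRNGKey(0), (512, 8, 64), jnp.float32, in_axis=0, out_axis=(1, 2))
```

Deferring the axis choice to call time lets one factory instance serve kernels of different ranks and layouts, because each call recomputes fan-in and fan-out from the axes it is given.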

maxtext.layers.initializers.variable_to_logically_partitioned(variable)

Wraps an NNX variable’s value in nn.LogicallyPartitioned.

This function inspects the metadata of an nnx.Variable object. If sharding information (‘out_sharding’, ‘sharding’ or ‘sharding_names’) is present, it wraps the variable’s value in nn.LogicallyPartitioned to apply the specified sharding constraints.

Two cases are special-cased: for aqt_tensor.QTensor and for variables of type _overwrite_with_gradient, the function returns the value directly without wrapping it.

Parameters:

variable (Variable) – The nnx.Variable object to process.

Returns:

The variable’s value, potentially wrapped in nn.LogicallyPartitioned.
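The control flow described above can be sketched in plain Python. Everything here is hypothetical: VariableStandIn, OverwriteWithGradientStandIn, and LogicallyPartitionedStandIn stand in for nnx.Variable, the _overwrite_with_gradient type, and nn.LogicallyPartitioned so the example runs without Flax, and metadata is modeled as a plain dict. The order in which the three sharding keys are checked is an assumption, since the documentation lists the keys but not their precedence. The aqt_tensor.QTensor case is omitted.

```python
from dataclasses import dataclass, field
from typing import Any, Optional, Tuple


@dataclass
class VariableStandIn:
  """Hypothetical stand-in for nnx.Variable: a value plus a metadata dict."""
  value: Any
  metadata: dict = field(default_factory=dict)


class OverwriteWithGradientStandIn(VariableStandIn):
  """Hypothetical stand-in for the _overwrite_with_gradient variable type."""


@dataclass
class LogicallyPartitionedStandIn:
  """Hypothetical stand-in for nn.LogicallyPartitioned: a value plus logical axis names."""
  value: Any
  names: Tuple[Optional[str], ...]


# Assumed lookup order; the real function may check these keys differently.
SHARDING_KEYS = ("out_sharding", "sharding", "sharding_names")


def to_logically_partitioned_sketch(variable):
  # Special case: this variable type is returned unwrapped.
  if isinstance(variable, OverwriteWithGradientStandIn):
    return variable.value
  # Wrap the value using the first sharding annotation found in the metadata.
  for key in SHARDING_KEYS:
    names = variable.metadata.get(key)
    if names is not None:
      return LogicallyPartitionedStandIn(variable.value, tuple(names))
  # No sharding information: return the raw value.
  return variable.value
```

For example, a variable whose metadata holds {"sharding": ("embed", "mlp")} comes back wrapped with those axis names, while a variable with no sharding metadata comes back as its bare value.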