maxtext.layers.initializers module
Initializers.
- maxtext.layers.initializers.nd_dense_init(scale, mode, distribution)
Creates a variance-scaling initializer with dynamic in/out axes.
This function is a factory that returns an initializer function. The returned function is a wrapper around jax.nn.initializers.variance_scaling that allows the in_axis and out_axis to be specified at call time, rather than at creation time.
- Parameters:
scale – The scaling factor for the variance.
mode – The mode for variance scaling (‘fan_in’, ‘fan_out’, ‘fan_avg’).
distribution – The distribution to sample from (‘truncated_normal’, ‘normal’, or ‘uniform’).
- Returns:
A function that takes a PRNG key, shape, dtype, in_axis, and out_axis, and returns an initialized array.
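The factory pattern described for nd_dense_init can be sketched as follows. This is a minimal illustration built on jax.nn.initializers.variance_scaling, not the MaxText source: the name nd_dense_init_sketch, the example shape, and the axis choices are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def nd_dense_init_sketch(scale, mode, distribution):
    """Hypothetical stand-in: defers in_axis/out_axis to call time."""

    def init_fn(key, shape, dtype, in_axis, out_axis):
        # Build the variance-scaling initializer only now, when the
        # axes are known, then apply it immediately.
        fn = jax.nn.initializers.variance_scaling(
            scale, mode, distribution, in_axis=in_axis, out_axis=out_axis
        )
        return fn(key, shape, dtype)

    return init_fn


# Usage: one factory call, axes chosen per kernel.
init = nd_dense_init_sketch(1.0, "fan_in", "truncated_normal")
key = jax.random.PRNGKey(0)
# A 3-D kernel: axis 0 is the input axis, axes 1 and 2 are output axes.
kernel = init(key, (8, 4, 16), jnp.float32, in_axis=0, out_axis=(1, 2))
```

Deferring the axes lets one initializer configuration be reused for kernels whose input and output dimensions sit at different positions.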
- maxtext.layers.initializers.variable_to_logically_partitioned(variable)
Wraps an NNX variable’s value in nn.LogicallyPartitioned.
This function inspects the metadata of an nnx.Variable object. If sharding information (‘out_sharding’, ‘sharding’ or ‘sharding_names’) is present, it wraps the variable’s value in nn.LogicallyPartitioned to apply the specified sharding constraints.
It handles special cases for aqt_tensor.QTensor and variables of type _overwrite_with_gradient by returning their values directly without wrapping.
- Parameters:
variable (Variable) – The nnx.Variable object to process.
- Returns:
The variable’s value, potentially wrapped in nn.LogicallyPartitioned.