# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Initializers."""
from typing import Callable
import jax
from flax import linen as nn
from flax import nnx
from aqt.jax.v2 import aqt_tensor
from maxtext.common.common_types import Array, DType, Shape, PRNGKey
# Signature of a standard initializer: called with (PRNGKey, Shape, DType) and
# returns an Array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# An axis specification: either a single axis index or a tuple of axis indices.
InitializerAxis = int | tuple[int, ...]
# Signature of an initializer that additionally receives two axis
# specifications at call time (see `nd_dense_init`, whose returned function
# takes them as `in_axis` and `out_axis`).
NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array]
# Default embedding initializer: variance scaling with scale 1.0, "fan_in" mode,
# "normal" distribution, and axis 0 as the output axis.
default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
# Default bias initializer: fills with the constant 0.0.
default_bias_init = jax.nn.initializers.constant(0.0)
# Default scalar initializer: fills with the constant 0.01.
default_scalar_init = jax.nn.initializers.constant(0.01)
def nd_dense_init(scale, mode, distribution):
  """Builds a variance-scaling initializer whose axes are chosen per call.

  `jax.nn.initializers.variance_scaling` fixes `in_axis` and `out_axis` when
  the initializer is created. The function returned here defers that choice:
  it receives both axes together with the usual (key, shape, dtype) arguments
  and constructs the underlying variance-scaling initializer on every call.

  Args:
    scale: Scaling factor forwarded to `variance_scaling`.
    mode: Variance-scaling mode forwarded to `variance_scaling` (for example
      'fan_in', 'fan_out' or 'fan_avg').
    distribution: Sampling distribution forwarded to `variance_scaling` (for
      example 'normal' or 'uniform').

  Returns:
    A callable `init_fn(key, shape, dtype, in_axis, out_axis)` that returns the
    initialized array.
  """

  def init_fn(key, shape, dtype, in_axis, out_axis):
    """Samples an array with variance scaling over the given in/out axes."""
    return jax.nn.initializers.variance_scaling(
        scale,
        mode,
        distribution,
        in_axis=in_axis,
        out_axis=out_axis,
    )(key, shape, dtype)

  return init_fn
def variable_to_logically_partitioned(variable: nnx.Variable):
  """Returns an NNX variable's value, boxed with its logical sharding if any.

  The value is returned unchanged when:
    * it is an `aqt_tensor.QTensor`,
    * the variable's type is named `_overwrite_with_gradient`, or
    * the metadata carries no sharding entry (or the entry found is `None`).

  Otherwise the sharding is read from the first of the metadata keys
  'out_sharding', 'sharding_names', 'sharding' that is present, and the value
  is wrapped in `nn.LogicallyPartitioned` together with the optional 'mesh'
  and 'rules' metadata entries.

  If the metadata also contains `nnx.PARTITION_NAME`, that name is inserted
  into the sharding names unless it is already there. The insertion index is
  the 'param_scan_axis' metadata entry (default 0) for `nnx.Param` variables
  and 0 for every other variable type.

  Args:
    variable: The `nnx.Variable` to convert.

  Returns:
    The raw value, or an `nn.LogicallyPartitioned` box around it.
  """
  value = variable.get_value()

  # Quantized tensors and overwrite-with-gradient variables pass through as-is.
  if isinstance(value, aqt_tensor.QTensor) or variable.type.__name__ == "_overwrite_with_gradient":
    return value

  meta = variable.get_metadata()

  # The first key that is present wins, even when its value is None.
  axis_names = None
  for sharding_key in ("out_sharding", "sharding_names", "sharding"):
    if sharding_key in meta:
      axis_names = meta[sharding_key]
      break

  if axis_names is None:
    return value

  if nnx.PARTITION_NAME in meta:
    stacked_name = meta[nnx.PARTITION_NAME]
    # Only parameters carry a configurable scan axis; everything else uses 0.
    if variable.type == nnx.Param:
      insert_at = meta.get("param_scan_axis", 0)
    else:
      insert_at = 0
    # A lone string is one axis name, not a sequence of characters.
    names = [axis_names] if isinstance(axis_names, str) else list(axis_names)
    if stacked_name not in names:
      names.insert(insert_at, stacked_name)
    axis_names = tuple(names)

  return nn.LogicallyPartitioned(  # type: ignore[wrong-keyword-args]
      value,
      axis_names,  # type: ignore[arg-type]
      mesh=meta.get("mesh"),
      rules=meta.get("rules"),
  )