maxtext.layers.linears module
Linear Layers.
- maxtext.layers.linears.normalize_axes(axes, ndim)
- Parameters:
axes (Iterable[int])
ndim (int)
- Return type:
tuple[int, …]
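The entry above gives only the signature. As a rough illustration, a helper with this signature typically converts negative axis indices into their non-negative equivalents for an array of rank ndim. The sketch below assumes that behavior and uses a hypothetical name, normalize_axes_sketch; it is not the MaxText implementation.

```python
from typing import Iterable


def normalize_axes_sketch(axes: Iterable[int], ndim: int) -> tuple[int, ...]:
    """Hypothetical stand-in: wrap negative axes so every index is in [0, ndim)."""
    # Assumption: negative indices count from the end, as in NumPy and JAX.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


# For a rank-3 array, axis -1 refers to axis 2 and axis 0 is unchanged.
normalized = normalize_axes_sketch((-1, 0), ndim=3)
print(normalized)  # (2, 0)
```

Returning a tuple keeps the result hashable and makes its length available as the number of contracted axes.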
- class maxtext.layers.linears.DenseGeneral(*args, **kwargs)
Bases: Module
A linear transformation with flexible axes.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- maxtext.layers.linears.dense_general(*, inputs_shape=None, in_features_shape=None, out_features_shape, axis=-1, weight_dtype=<class 'jax.numpy.float32'>, dtype=<class 'jax.numpy.float32'>, kernel_init=<function nd_dense_init.<locals>.init_fn>, kernel_axes=(), quant=None, use_bias=False, shard_mode=ShardMode.AUTO, matmul_precision='default', parameter_memory_host_offload=False, name=None)
Creates a DenseGeneral Linen module using nnx.bridge.to_linen.
- Parameters:
inputs_shape (tuple[int, ...] | None) – tuple with the shape of the inputs
in_features_shape (tuple[int, ...] | int | None) – tuple with numbers of input features for axes specified in ‘axis’.
out_features_shape (Iterable[int] | int) – tuple with numbers of output features.
axis (Iterable[int] | int) – tuple with axes to apply the transformation on.
weight_dtype (dtype) – the dtype of the weights (default: float32).
dtype (dtype) – the dtype of the computation (default: float32).
kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array]) – initializer function for the weight matrix.
kernel_axes (tuple[None | str, ...]) – logical axes for partitioning the kernel.
quant (None | AqtQuantization) – quantization config, defaults to None implying no quantization.
use_bias (bool) – whether to add bias in linear transformation.
shard_mode (ShardMode) – the sharding mode to use (default: ShardMode.AUTO).
matmul_precision (str) – precision for matrix multiplication (default: 'default').
parameter_memory_host_offload (bool) – whether to offload parameters to host memory.
name (None | str) – name passed to the ToLinen Module
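To make the relationship between axis, in_features_shape and out_features_shape concrete, here is a minimal NumPy sketch of a linear transformation over several axes. It assumes the kernel has shape in_features_shape + out_features_shape and that the input axes named by axis are contracted against the kernel's leading dimensions. The function dense_general_sketch is hypothetical and ignores bias, quantization, sharding and dtype handling.

```python
import numpy as np


def dense_general_sketch(inputs, kernel, axis=(-1,)):
    """Hypothetical sketch: contract `axis` of `inputs` with the kernel's leading dims."""
    # Wrap negative axes so they index into inputs.shape directly.
    axis = tuple(a if a >= 0 else inputs.ndim + a for a in axis)
    # Assumption: the first len(axis) kernel dims are the input features,
    # and the remaining kernel dims are the output features.
    kernel_axes = tuple(range(len(axis)))
    return np.tensordot(inputs, kernel, axes=(axis, kernel_axes))


x = np.ones((2, 5, 4, 8))   # e.g. (batch, sequence, heads, head_dim)
w = np.ones((4, 8, 16))     # in_features_shape=(4, 8), out_features_shape=(16,)
y = dense_general_sketch(x, w, axis=(-2, -1))
print(y.shape)  # (2, 5, 16)
```

The two trailing input axes are collapsed into a single output axis of size 16, while the leading batch and sequence axes pass through unchanged.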
- class maxtext.layers.linears.Dropout(*args, **kwargs)
Bases: Dropout
A fork of nnx.Dropout that is easier to use with nnx.bridge.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- class maxtext.layers.linears.MlpBlock(*args, **kwargs)
Bases: Module
Transformer MLP / feed-forward block.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- maxtext.layers.linears.mlp_block(*, config, mesh, in_features, intermediate_dim=2048, activations=('relu',), kernel_init=<function nd_dense_init.<locals>.init_fn>, intermediate_dropout_rate=0.1, dtype=<class 'jax.numpy.float32'>, weight_dtype=<class 'jax.numpy.float32'>, use_bias=False, use_pre_norm=False, quant=None, model_mode=None, name=None)
Creates a MlpBlock Linen module using nnx.bridge.to_linen.
- Parameters:
config (Any)
mesh (Mesh)
in_features (int)
intermediate_dim (int)
activations (Sequence[str | Callable[[...], Any]])
kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array])
intermediate_dropout_rate (float)
dtype (Any)
weight_dtype (Any)
use_bias (bool)
use_pre_norm (bool)
quant (None | AqtQuantization)
model_mode (None | str)
name (None | str)
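The parameters above are listed without descriptions. As a rough guide to what a transformer feed-forward block computes, the NumPy sketch below projects in_features up to intermediate_dim, applies a single 'relu' activation, optionally applies dropout, and projects back down. It assumes the output width equals in_features and omits pre-normalization, bias, quantization and sharding. The name mlp_block_sketch and its arguments are hypothetical, not the MaxText API.

```python
import numpy as np


def mlp_block_sketch(x, wi, wo, dropout_rate=0.0, rng=None):
    """Hypothetical sketch of a feed-forward block: dense -> relu -> dropout -> dense."""
    hidden = np.maximum(x @ wi, 0.0)  # up-projection followed by relu
    if rng is not None and dropout_rate > 0.0:
        # Inverted dropout: zero some units and rescale the rest so the
        # expected activation is unchanged.
        keep_prob = 1.0 - dropout_rate
        mask = rng.random(hidden.shape) < keep_prob
        hidden = np.where(mask, hidden / keep_prob, 0.0)
    return hidden @ wo  # down-projection (assumed to return to in_features)


rng = np.random.default_rng(0)
in_features, intermediate_dim = 16, 2048
x = rng.normal(size=(2, 5, in_features))
wi = rng.normal(size=(in_features, intermediate_dim)) * 0.02
wo = rng.normal(size=(intermediate_dim, in_features)) * 0.02

y_eval = mlp_block_sketch(x, wi, wo)                             # no dropout
y_train = mlp_block_sketch(x, wi, wo, dropout_rate=0.1, rng=rng)  # with dropout
print(y_eval.shape, y_train.shape)
```

Passing no rng disables dropout, which mirrors the usual distinction between training and inference.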