maxtext.layers.linears module

Linear Layers.

maxtext.layers.linears.normalize_axes(axes, ndim)
Parameters:
  • axes (Iterable[int])

  • ndim (int)

Return type:

tuple[int, …]
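The entry above gives only the signature. Judging from the name and parameters, normalize_axes most likely converts possibly negative axis indices into non-negative ones for an array of rank ndim. The following is a minimal sketch under that assumption. normalize_axes_sketch is a hypothetical name, and this page does not say whether the real function also sorts or validates the axes.

```python
from typing import Iterable


def normalize_axes_sketch(axes: Iterable[int], ndim: int) -> tuple[int, ...]:
    """Hypothetical sketch: map negative axis indices to their non-negative equivalents."""
    # An axis of -1 on a rank-3 array refers to dimension 2, and so on.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


print(normalize_axes_sketch((-1,), 3))    # (2,)
print(normalize_axes_sketch((0, -2), 4))  # (0, 2)
```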

maxtext.layers.linears.canonicalize_tuple(x)
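canonicalize_tuple is also documented without a description. Because parameters such as axis and out_features_shape accept either an int or an iterable, a plausible reading is that it wraps a scalar in a 1-tuple and converts any iterable to a tuple. A sketch under that assumption (canonicalize_tuple_sketch is hypothetical):

```python
from collections.abc import Iterable


def canonicalize_tuple_sketch(x):
    """Hypothetical sketch: return x as a tuple, wrapping scalars in a 1-tuple."""
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


print(canonicalize_tuple_sketch(-1))      # (-1,)
print(canonicalize_tuple_sketch([0, 1]))  # (0, 1)
```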
class maxtext.layers.linears.DenseGeneral(*args, **kwargs)

Bases: Module

A linear transformation with flexible axes.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

property quant_dot_general: ToNNX | None
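DenseGeneral is described as a linear transformation with flexible axes. A minimal JAX sketch of that idea, assuming the kernel is laid out as in_features_shape followed by out_features_shape, contracts the chosen input axes against the leading kernel axes with jax.lax.dot_general. It ignores bias, quantization, sharding, and dtype handling, and dense_general_apply_sketch is a hypothetical helper, not part of MaxText.

```python
import jax
import jax.numpy as jnp


def dense_general_apply_sketch(inputs, kernel, axis):
    """Hypothetical sketch: contract `inputs` over `axis` with the leading axes of `kernel`."""
    # Normalize negative axes against the input rank.
    axis = tuple(ax if ax >= 0 else inputs.ndim + ax for ax in axis)
    # Assumed kernel layout: in_features_shape + out_features_shape.
    kernel_contract = tuple(range(len(axis)))
    # dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
    return jax.lax.dot_general(inputs, kernel, ((axis, kernel_contract), ((), ())))


x = jnp.ones((2, 5, 4, 3))    # e.g. (batch, seq, heads, head_dim)
kernel = jnp.ones((4, 3, 8))  # in_features_shape=(4, 3), out_features_shape=(8,)
y = dense_general_apply_sketch(x, kernel, axis=(-2, -1))
print(y.shape)  # (2, 5, 8)
```

The uncontracted input axes are kept in order and the output feature axes are appended at the end, which is why heads and head_dim collapse into a single output axis of size 8 here.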
maxtext.layers.linears.dense_general(*, inputs_shape=None, in_features_shape=None, out_features_shape, axis=-1, weight_dtype=<class 'jax.numpy.float32'>, dtype=<class 'jax.numpy.float32'>, kernel_init=<function nd_dense_init.<locals>.init_fn>, kernel_axes=(), quant=None, use_bias=False, shard_mode=ShardMode.AUTO, matmul_precision='default', parameter_memory_host_offload=False, name=None)

Creates a DenseGeneral Linen module using nnx.bridge.to_linen.

Parameters:
  • inputs_shape (tuple[int, ...] | None) – shape of the inputs.

  • in_features_shape (tuple[int, ...] | int | None) – number of input features for each axis specified in axis.

  • out_features_shape (Iterable[int] | int) – number of output features, as a tuple or a single int.

  • axis (Iterable[int] | int) – axis or axes of the input to apply the transformation on (default: -1).

  • weight_dtype (dtype) – the dtype of the weights (default: float32).

  • dtype (dtype) – the dtype of the computation (default: float32).

  • kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array]) – initializer function for the weight matrix.

  • kernel_axes (tuple[None | str, ...]) – logical axes for partitioning the kernel.

  • quant (None | AqtQuantization) – quantization config, defaults to None implying no quantization.

  • use_bias (bool) – whether to add a bias term to the linear transformation (default: False).

  • shard_mode (ShardMode) – the sharding mode (default: ShardMode.AUTO).

  • matmul_precision (str) – precision for the matrix multiplication (default: 'default').

  • parameter_memory_host_offload (bool) – whether to offload the parameters to host memory (default: False).

  • name (None | str) – name passed to the ToLinen module.
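The kernel_init parameter takes an initializer with the signature (key, shape, dtype, in_axis, out_axis), so fan-in and fan-out can be computed over several axes of an N-dimensional kernel. The sketch below builds such an initializer on top of jax.nn.initializers.variance_scaling, which accepts in_axis and out_axis. The name nd_dense_init_sketch is hypothetical; the nd_dense_init referenced in the default value may be implemented differently.

```python
import jax
import jax.numpy as jnp


def nd_dense_init_sketch(scale, mode, distribution):
    """Hypothetical sketch: initializer with a (key, shape, dtype, in_axis, out_axis) signature."""
    def init_fn(key, shape, dtype, in_axis, out_axis):
        # variance_scaling computes fan_in/fan_out from the given (possibly multiple) axes.
        base = jax.nn.initializers.variance_scaling(
            scale, mode, distribution, in_axis=in_axis, out_axis=out_axis
        )
        return base(key, shape, dtype)
    return init_fn


init = nd_dense_init_sketch(1.0, "fan_in", "truncated_normal")
# Kernel for in_features_shape=(4, 3) and out_features_shape=(8,): fan_in = 4 * 3 = 12.
kernel = init(jax.random.PRNGKey(0), (4, 3, 8), jnp.float32, (0, 1), (2,))
print(kernel.shape, kernel.dtype)  # (4, 3, 8) float32
```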

class maxtext.layers.linears.Dropout(*args, **kwargs)

Bases: Dropout

A fork of nnx.Dropout that is easier to use with nnx.bridge.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any
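nnx.Dropout applies inverted dropout: during training each element is zeroed with probability rate and the survivors are scaled by 1 / (1 - rate), so the expected activation is unchanged and no rescaling is needed at inference. A minimal functional sketch of that computation follows. dropout_sketch is hypothetical, and this page does not describe what exactly the fork changes for use with the bridge.

```python
import jax
import jax.numpy as jnp


def dropout_sketch(x, rate, key, deterministic=False):
    """Hypothetical sketch of inverted dropout."""
    if deterministic or rate == 0.0:
        return x  # inference mode: identity
    if rate == 1.0:
        return jnp.zeros_like(x)  # everything dropped; avoids dividing by zero
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    # Rescale kept elements so the expected value matches the input.
    return jnp.where(mask, x / keep_prob, jnp.zeros_like(x))


key = jax.random.PRNGKey(0)
x = jnp.ones((2, 4))
y = dropout_sketch(x, rate=0.5, key=key)  # each entry is either 0.0 or 2.0
```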

class maxtext.layers.linears.MlpBlock(*args, **kwargs)

Bases: Module

Transformer MLP / feed-forward block.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

get_norm_layer(num_features)

Return the normalization layer for inputs with num_features features.

Parameters:

num_features (int)
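As a rough picture of what a transformer feed-forward block computes, the sketch below projects each token from in_features up to intermediate_dim, applies an activation, and projects back down. It assumes a single activation, no bias, no dropout, no pre-norm, and an output size equal to in_features; mlp_block_sketch is a hypothetical helper, not MaxText's implementation.

```python
import jax
import jax.numpy as jnp


def mlp_block_sketch(x, w_in, w_out, activation=jax.nn.relu):
    """Hypothetical sketch of a position-wise feed-forward block.

    x:     (..., in_features)
    w_in:  (in_features, intermediate_dim)
    w_out: (intermediate_dim, in_features)
    """
    h = activation(x @ w_in)  # expand to the intermediate dimension
    return h @ w_out          # project back to in_features


x = jnp.ones((2, 3, 4))  # (batch, seq, in_features)
y = mlp_block_sketch(x, jnp.ones((4, 16)), jnp.ones((16, 4)))
print(y.shape)  # (2, 3, 4)
```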

maxtext.layers.linears.mlp_block(*, config, mesh, in_features, intermediate_dim=2048, activations=('relu', ), kernel_init=<function nd_dense_init.<locals>.init_fn>, intermediate_dropout_rate=0.1, dtype=<class 'jax.numpy.float32'>, weight_dtype=<class 'jax.numpy.float32'>, use_bias=False, use_pre_norm=False, quant=None, model_mode=None, name=None)

Creates an MlpBlock Linen module using nnx.bridge.to_linen.

Parameters:
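  • config (Any)

  • mesh (Mesh)

  • in_features (int) – number of input features.

  • intermediate_dim (int) – size of the intermediate (hidden) layer (default: 2048).

  • activations (Sequence[str | Callable[[...], Any]]) – activation functions, given as names or callables (default: ('relu',)).

  • kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array]) – initializer function for the weight matrices.

  • intermediate_dropout_rate (float) – dropout rate applied to the intermediate activations (default: 0.1).

  • dtype (Any) – the dtype of the computation (default: float32).

  • weight_dtype (Any) – the dtype of the weights (default: float32).

  • use_bias (bool) – whether to add bias terms to the linear transformations (default: False).

  • use_pre_norm (bool) – whether to normalize the inputs before the MLP (default: False).

  • quant (None | AqtQuantization) – quantization config, defaults to None implying no quantization.

  • model_mode (None | str)

  • name (None | str) – name passed to the ToLinen module.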
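The activations parameter accepts a sequence of names or callables. A common design, used by gated variants such as SwiGLU, gives each activation its own input projection and multiplies the branches elementwise. The sketch below assumes that behaviour for illustration only. The name lookup table and the helpers (_ACTIVATION_TABLE, resolve_activation, gated_hidden_sketch) are hypothetical, as is mapping "linear" to the identity.

```python
import functools
import operator

import jax
import jax.numpy as jnp

# Hypothetical name-to-function table; the lookup mlp_block actually uses is not shown here.
_ACTIVATION_TABLE = {
    "relu": jax.nn.relu,
    "gelu": jax.nn.gelu,
    "silu": jax.nn.silu,
    "linear": lambda v: v,
}


def resolve_activation(act):
    """Accept either a registered name or a callable."""
    return _ACTIVATION_TABLE[act] if isinstance(act, str) else act


def gated_hidden_sketch(x, w_ins, activations):
    """One input projection per activation, combined by elementwise product."""
    branches = [resolve_activation(act)(x @ w) for act, w in zip(activations, w_ins)]
    return functools.reduce(operator.mul, branches)


x = jnp.ones((2, 4))
w_ins = [jnp.ones((4, 8)), jnp.ones((4, 8))]
h = gated_hidden_sketch(x, w_ins, ("silu", "linear"))  # SwiGLU-style gating
print(h.shape)  # (2, 8)
```

With a single entry such as ('relu',), the product reduces to one ordinary activated projection.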