maxtext.models.models module

Transformer models.

class maxtext.models.models.TransformerLinenPure(config, mesh, quant, model_mode='train', parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

An autoregressive transformer model.

Parameters:
  • config (Any)

  • mesh (Mesh)

  • quant (AqtQuantization)

  • model_mode (str)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

config: Any
mesh: Mesh
quant: AqtQuantization
model_mode: str = 'train'
init(*args, model_mode='train', **kwargs)

Initializes the model.

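Parameters:
  • model_mode (str) – The operational mode for the model, e.g. training, prefill, or autoregressive. Defaults to 'train'.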

apply(*args, model_mode='train', **kwargs)

Applies the model.

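Parameters:
  • model_mode (str) – The operational mode for the model, e.g. training, prefill, or autoregressive. Defaults to 'train'.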
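Both init and apply take model_mode as a keyword argument that defaults to 'train'. The stdlib sketch below illustrates one common way such an override can work, assuming the module is an immutable dataclass that is cloned with the requested mode before the call is delegated; ToyModule and its methods are hypothetical stand-ins, not the MaxText or Flax implementation.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyModule:
    """Hypothetical stand-in for a Linen module that carries a model_mode field."""

    model_mode: str = "train"

    def _run(self, x):
        # Placeholder for a real forward pass; it only reports the active mode.
        return {"mode": self.model_mode, "input": x}

    def apply(self, x, *, model_mode="train"):
        # Clone the immutable module with the requested mode, then delegate,
        # so the caller's module object is never mutated.
        module = replace(self, model_mode=model_mode)
        return module._run(x)


base = ToyModule()
print(base.apply(3))                        # {'mode': 'train', 'input': 3}
print(base.apply(3, model_mode="prefill"))  # {'mode': 'prefill', 'input': 3}
print(base.model_mode)                      # train
```

Cloning rather than mutating keeps the original module reusable for other modes, which matters when the same model definition serves training, prefill, and decoding.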

setup()

Initializes the shared_embedding and the decoder layers.

logits_from_hidden_states(hidden_states, deterministic, model_mode)

Computes logits from hidden states by wrapping decoder.apply_output_head. This method is used only for vocabulary tiling.
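Tiling of this kind is typically used to lower peak memory by never materializing the full [tokens, vocab] logits array at once. As a rough illustration only, the stdlib sketch below computes a mean cross-entropy loss by calling a logits function on fixed-size tiles of tokens and checks that the result matches the untiled computation. The function names, the choice to tile along the token axis, and the toy projection are all assumptions; this is not MaxText's tiling scheme.

```python
import math


def logits_fn(hidden, weights):
    """Hypothetical output head: project one hidden vector to vocab logits."""
    # weights is a [vocab][hidden_dim] nested list; returns `vocab` logits.
    return [sum(h * w for h, w in zip(hidden, row)) for row in weights]


def token_loss(logits, target):
    """Cross-entropy for one token via a numerically stable log-sum-exp."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


def full_mean_loss(hidden_states, targets, weights):
    """Reference: materializes logits for every token at once."""
    all_logits = [logits_fn(h, weights) for h in hidden_states]
    return sum(token_loss(z, t) for z, t in zip(all_logits, targets)) / len(targets)


def tiled_mean_loss(hidden_states, targets, weights, tile_size):
    """Same loss, but only one tile of logits exists at any time."""
    total = 0.0
    for start in range(0, len(hidden_states), tile_size):
        tile_h = hidden_states[start:start + tile_size]
        tile_t = targets[start:start + tile_size]
        tile_logits = [logits_fn(h, weights) for h in tile_h]
        total += sum(token_loss(z, t) for z, t in zip(tile_logits, tile_t))
    return total / len(targets)


weights = [[0.1, -0.2], [0.3, 0.05], [-0.4, 0.2]]  # vocab=3, hidden_dim=2
hidden_states = [[1.0, 0.5], [-0.3, 2.0], [0.7, -1.1], [0.0, 0.4]]
targets = [0, 2, 1, 1]

print(math.isclose(full_mean_loss(hidden_states, targets, weights),
                   tiled_mean_loss(hidden_states, targets, weights, tile_size=3)))
# True
```

In a JAX program the per-tile work would operate on array slices rather than Python lists; the point here is only that summing per-tile losses reproduces the untiled loss.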

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
maxtext.models.models.transformer_as_linen(config, mesh, quant, model_mode='train', *, name=None)

Constructs a Transformer model as a Linen module, backed by either a Linen or an NNX implementation.

This function returns an autoregressive Transformer model as either a pure Linen module or an NNX module wrapped for Linen, depending on the config.enable_nnx flag. The returned module is suitable for training, evaluation, or decoding.

If config.enable_nnx is True, it returns a TransformerLinen, which wraps the NNX-style Transformer so that it can be used through the Linen API. Otherwise, it returns a pure Flax Linen implementation (TransformerLinenPure).

Parameters:
  • config (Config) – The configuration object specifying model hyperparameters and options.

  • mesh (Mesh) – The JAX sharding mesh for device partitioning.

  • quant (Quant) – The quantization module or configuration to use.

  • model_mode (str, optional) – The operational mode for the model, e.g. training, prefill, or autoregressive. Defaults to MODEL_MODE_TRAIN ('train').

  • name (str, optional) – Optional module name for Linen/NNX construction.

Returns:

A constructed Transformer model compatible with the specified framework (Linen or NNX).

Return type:

nnx_wrappers.ToLinen | TransformerLinenPure
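The selection logic described above comes down to a single branch on config.enable_nnx. A minimal sketch of that dispatch, assuming hypothetical stub classes and a toy config in place of the real MaxText types; the stubs only record their constructor arguments and build no model.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    """Hypothetical config holding only the flag that drives the dispatch."""
    enable_nnx: bool = False


class _StubModel:
    """Records constructor arguments; no real model is built."""

    def __init__(self, config, mesh, quant, model_mode, name):
        self.config, self.mesh, self.quant = config, mesh, quant
        self.model_mode, self.name = model_mode, name


class PureLinenStub(_StubModel):
    """Stands in for TransformerLinenPure."""


class WrappedNnxStub(_StubModel):
    """Stands in for TransformerLinen (NNX Transformer wrapped for Linen)."""


def build_transformer(config, mesh, quant, model_mode="train", *, name=None):
    """Hypothetical analogue of transformer_as_linen's branch on enable_nnx."""
    cls = WrappedNnxStub if config.enable_nnx else PureLinenStub
    return cls(config, mesh, quant, model_mode, name)


print(type(build_transformer(ToyConfig(), None, None)).__name__)
# PureLinenStub
print(type(build_transformer(ToyConfig(enable_nnx=True), None, None)).__name__)
# WrappedNnxStub
```

Because both branches return something usable as a Linen module, callers can switch implementations through configuration alone.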

class maxtext.models.models.TransformerLinen(nnx_class, args=(), kwargs=FrozenDict({}), skip_rng=False, metadata_fn=<function to_linen_var>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: ToLinen

NNX Transformer model wrapped as a Linen module.

Parameters:
  • nnx_class (Callable[[...], Module])

  • args (Sequence)

  • kwargs (Mapping[str, Any])

  • skip_rng (bool)

  • metadata_fn (Callable[[Variable], Any] | None)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)
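The parameters above show that the wrapper receives the NNX class together with its constructor arguments (args, kwargs) rather than an already-built instance. The stdlib sketch below illustrates that deferred-construction idea in isolation; LazyWrapper and Scale are hypothetical, and the sketch omits everything Flax's ToLinen does with variables, RNGs (skip_rng), and metadata_fn.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Mapping, Sequence


@dataclass
class LazyWrapper:
    """Hypothetical adapter holding a class plus its constructor arguments."""

    nnx_class: Callable[..., Any]
    args: Sequence = ()
    kwargs: Mapping[str, Any] = field(default_factory=dict)

    def init(self):
        # The wrapped object is constructed only when init is called.
        return self.nnx_class(*self.args, **dict(self.kwargs))

    def apply(self, instance, *call_args):
        # Run the forward pass of an already-constructed instance.
        return instance(*call_args)


class Scale:
    """Toy callable standing in for an NNX module."""

    def __init__(self, factor, offset=0):
        self.factor, self.offset = factor, offset

    def __call__(self, x):
        return self.factor * x + self.offset


wrapper = LazyWrapper(Scale, args=(3,), kwargs={"offset": 1})
model = wrapper.init()
print(wrapper.apply(model, 2))  # 7
```

Storing the class and arguments instead of an instance lets the wrapper decide when (and under which framework context) construction happens.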

init(*args, model_mode='train', **kwargs)

Initializes the model.

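Parameters:
  • model_mode (str) – The operational mode for the model, e.g. training, prefill, or autoregressive. Defaults to 'train'.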

apply(*args, model_mode='train', **kwargs)

Applies the model.

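Parameters:
  • model_mode (str) – The operational mode for the model, e.g. training, prefill, or autoregressive. Defaults to 'train'.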

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
class maxtext.models.models.Transformer(*args, **kwargs)

Bases: Module

An autoregressive transformer model.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

no_op(*args, **kwargs)

A no-op method that allows the model to be used in a lazy context.

init_cache(cache_size, batch_size, dtype=<class 'jax.numpy.float32'>)

Initializes the KV cache for the Transformer.

Parameters:
  • cache_size (int) – The maximum size of the KV cache.

  • batch_size (int) – The batch size for which the cache is initialized.

  • dtype – Data type for the cache. Defaults to jnp.float32.

Returns:

True if the cache is successfully initialized.
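When choosing cache_size and batch_size, a back-of-the-envelope memory estimate is useful. The sketch assumes one key tensor and one value tensor per layer, each of shape [batch_size, cache_size, num_kv_heads, head_dim]; num_layers, num_kv_heads, and head_dim are hypothetical parameters that would come from the model config, and the actual cache may be laid out differently or hold extra buffers.

```python
def kv_cache_bytes(num_layers, batch_size, cache_size, num_kv_heads, head_dim,
                   dtype_bytes=4):
    """Estimate KV-cache memory as keys plus values for every layer.

    dtype_bytes=4 matches the float32 default of init_cache; use 2 for bfloat16.
    """
    per_tensor = batch_size * cache_size * num_kv_heads * head_dim * dtype_bytes
    return 2 * num_layers * per_tensor  # factor 2: one key and one value tensor


# Hypothetical config: 32 layers, 8 KV heads of size 128, 4 sequences of 2048 tokens.
size = kv_cache_bytes(num_layers=32, batch_size=4, cache_size=2048,
                      num_kv_heads=8, head_dim=128)
print(size / 2**30, "GiB")  # 2.0 GiB
```

The estimate scales linearly in both cache_size and batch_size, and switching dtype from float32 to a 2-byte type halves it.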