maxtext.models.models module
Transformer models.
- class maxtext.models.models.TransformerLinenPure(config, mesh, quant, model_mode='train', parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
An autoregressive transformer model.
- Parameters:
config (Any)
mesh (Mesh)
quant (AqtQuantization)
model_mode (str)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- config: Any
- mesh: Mesh
- quant: AqtQuantization
- model_mode: str = 'train'
- init(*args, model_mode='train', **kwargs)
Initializes the model.
- Parameters:
model_mode (str)
Compute logits from hidden states (wrapping decoder.apply_output_head). This function is only used for vocabulary tiling.
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
- maxtext.models.models.transformer_as_linen(config, mesh, quant, model_mode='train', *, name=None)
Constructs a Transformer model as a Linen module, backed by either a Linen or an NNX implementation.
This function returns an autoregressive Transformer model as a Flax Linen module in both cases; the config.enable_nnx flag selects which implementation sits behind it. The returned module is suitable for training, evaluation, or decoding.
If config.enable_nnx is True, it returns a TransformerLinen, which wraps the NNX-style Transformer for integration with NNX-specific APIs and workflows. Otherwise, it returns the pure Flax Linen implementation (TransformerLinenPure).
- Parameters:
config (Config) – The configuration object specifying model hyperparameters and options.
mesh (Mesh) – The JAX sharding mesh for device partitioning.
quant (Quant) – The quantization module or configuration to use.
model_mode (str, optional) – The operational mode for the model, e.g. training, prefill, or autoregressive. Defaults to MODEL_MODE_TRAIN ('train').
name (str, optional) – Optional module name for Linen/NNX construction.
- Returns:
A constructed Transformer model compatible with the specified framework (Linen or NNX).
- Return type:
- class maxtext.models.models.TransformerLinen(nnx_class, args=(), kwargs=FrozenDict({}), skip_rng=False, metadata_fn=<function to_linen_var>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: ToLinen
Transformer model as a Linen module.
- Parameters:
nnx_class (Callable[[...], Module])
args (Sequence)
kwargs (Mapping[str, Any])
skip_rng (bool)
metadata_fn (Callable[[Variable], Any] | None)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- init(*args, model_mode='train', **kwargs)
Initializes the model.
- Parameters:
model_mode (str)
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
- class maxtext.models.models.Transformer(*args, **kwargs)
Bases: Module
An autoregressive transformer model.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- init_cache(cache_size, batch_size, dtype=<class 'jax.numpy.float32'>)
Initializes the KV cache for the Transformer.
- Parameters:
cache_size (int) – The maximum size of the KV cache.
batch_size (int) – The batch size for which the cache is initialized.
dtype – Data type for the cache. Defaults to jnp.float32.
- Returns:
True if the cache is successfully initialized.
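To make `cache_size`, `batch_size`, and `dtype` concrete, here is a generic sketch of a preallocated per-layer KV cache in JAX. The helpers `allocate_kv_cache` and `write_step`, the `(batch, cache_size, kv_heads, head_dim)` layout, and the `num_layers`, `num_kv_heads`, and `head_dim` arguments are all assumptions for illustration; they are not MaxText's internal cache structure.

```python
import jax
import jax.numpy as jnp


def allocate_kv_cache(num_layers, batch_size, cache_size, num_kv_heads, head_dim, dtype=jnp.float32):
    """Hypothetical helper: zero-filled key/value buffers, one dict per layer."""
    shape = (batch_size, cache_size, num_kv_heads, head_dim)
    return [
        {"key": jnp.zeros(shape, dtype=dtype), "value": jnp.zeros(shape, dtype=dtype)}
        for _ in range(num_layers)
    ]


def write_step(layer_cache, k_new, v_new, pos):
    """Hypothetical helper: write one decode step (batch, 1, heads, head_dim) at `pos`."""
    start = (0, pos, 0, 0)  # offset only along the sequence (cache_size) axis
    return {
        "key": jax.lax.dynamic_update_slice(layer_cache["key"], k_new, start),
        "value": jax.lax.dynamic_update_slice(layer_cache["value"], v_new, start),
    }


cache = allocate_kv_cache(num_layers=2, batch_size=3, cache_size=16, num_kv_heads=4, head_dim=8)
step = jnp.ones((3, 1, 4, 8), dtype=jnp.float32)
cache[0] = write_step(cache[0], step, step, pos=5)
```

Preallocating to `cache_size` keeps array shapes fixed across decode steps, which avoids recompiling jitted code as the sequence grows.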