maxtext.configs.types module

Contents

maxtext.configs.types module#

Pydantic-based configuration system for MaxText, organized into modular classes.

class maxtext.configs.types.XProfTPUPowerTraceMode(*values)[source]#

Bases: IntEnum

Power-trace modes for XProf TPU profiling (none, normal, or SPI).

POWER_TRACE_NONE = 0#
POWER_TRACE_NORMAL = 1#
POWER_TRACE_SPI = 2#
class maxtext.configs.types.DType(*values)[source]#

Bases: str, Enum

Supported data types for weights and activations.

BFLOAT16 = 'bfloat16'#
FLOAT32 = 'float32'#
FLOAT16 = 'float16'#
class maxtext.configs.types.MatmulPrecision(*values)[source]#

Bases: str, Enum

Precision levels for matrix multiplications.

DEFAULT = 'default'#
HIGH = 'high'#
HIGHEST = 'highest'#
BFLOAT16 = 'bfloat16'#
FLOAT32 = 'float32'#
class maxtext.configs.types.QuantizationType(*values)[source]#

Bases: str, Enum

Supported quantization schemes.

NONE = ''#
INT4 = 'int4'#
INT8 = 'int8'#
INTMP = 'intmp'#
FP8 = 'fp8'#
NANOO_FP8 = 'nanoo_fp8'#
FP8_NANO_V2 = 'fp8_nanoo'#
FP8_GPU = 'fp8_gpu'#
FP8_FULL = 'fp8_full'#
TE_FP8_DS = 'te_fp8_delayedscaling'#
TE_FP8_CS = 'te_fp8_currentscaling'#
TE_MXFP8 = 'te_mxfp8'#
TE_NVFP4 = 'te_nvfp4'#
TE_NVFP4_NO_RHT = 'te_nvfp4_no_rht'#
class maxtext.configs.types.KvQuantAxis(*values)[source]#

Bases: str, Enum

Axes to quantize over for the Key-Value cache.

NONE = ''#
DKV = 'dkv'#
HEADS_AND_DKV = 'heads_and_dkv'#
class maxtext.configs.types.RematPolicy(*values)[source]#

Bases: str, Enum

Available rematerialization (gradient checkpointing) policies.

FULL = 'full'#
MINIMAL = 'minimal'#
SAVE_DOT_WITH_CONTEXT_EXCEPT_MLP = 'save_dot_with_context_except_mlp'#
SAVE_DOT_EXCEPT_MLPWI = 'save_dot_except_mlpwi'#
SAVE_DOT_EXCEPT_MLP = 'save_dot_except_mlp'#
SAVE_QKV_PROJ = 'save_qkv_proj'#
QKV_PROJ_OFFLOADED = 'qkv_proj_offloaded'#
CUSTOM = 'custom'#
MINIMAL_OFFLOADED = 'minimal_offloaded'#
SAVE_OUT_PROJ = 'save_out_proj'#
class maxtext.configs.types.RematLocation(*values)[source]#

Bases: str, Enum

Specifies where to store activations for rematerialization.

REMAT = 'remat'#
DEVICE = 'device'#
OFFLOAD = 'offload'#
class maxtext.configs.types.OptimizerType(*values)[source]#

Bases: str, Enum

Supported optimizer algorithms.

ADAMW = 'adamw'#
ADAM_PAX = 'adam_pax'#
SGD = 'sgd'#
MUON = 'muon'#
class maxtext.configs.types.LearningRateScheduleType(*values)[source]#

Bases: str, Enum

Supported learning rate schedule types.

COSINE = 'cosine'#
WSD = 'wsd'#
class maxtext.configs.types.WsdDecayStyle(*values)[source]#

Bases: str, Enum

Supported decay styles for WSD schedule.

LINEAR = 'linear'#
COSINE = 'cosine'#
class maxtext.configs.types.RopeType(*values)[source]#

Bases: str, Enum

Supported Rotary Positional Embedding (RoPE) implementations.

DEFAULT = 'default'#
LLAMA3_1 = 'llama3.1'#
YARN = 'yarn'#
class maxtext.configs.types.TokenizerType(*values)[source]#

Bases: str, Enum

Supported tokenizer libraries.

SENTENCEPIECE = 'sentencepiece'#
HUGGINGFACE = 'huggingface'#
TIKTOKEN = 'tiktoken'#
class maxtext.configs.types.DatasetType(*values)[source]#

Bases: str, Enum

Supported data loading pipelines.

SYNTHETIC = 'synthetic'#
HF = 'hf'#
GRAIN = 'grain'#
TFDS = 'tfds'#
C4MLPERF = 'c4_mlperf'#
OLMO_GRAIN = 'olmo_grain'#
class maxtext.configs.types.SamplingStrategy(*values)[source]#

Bases: str, Enum

Supported decoding and sampling strategies.

GREEDY = 'greedy'#
WEIGHTED = 'weighted'#
NUCLEUS = 'nucleus'#
TOPK = 'topk'#
COMPOSITE = 'composite'#
class maxtext.configs.types.ProfilerType(*values)[source]#

Bases: str, Enum

Supported performance profilers.

NONE = ''#
XPLANE = 'xplane'#
NSYS = 'nsys'#
class maxtext.configs.types.RunInfo(*, base_config=None, run_name='', model_name='default', override_model_config=False, override_logical_axis_rules=False, log_config=True, debug_sharding=False, base_output_directory='', sharding_strategy=None)[source]#

Bases: BaseModel

Configuration for the overall run, model identity, and logging.

Parameters:
  • base_config (None | str)

  • run_name (str)

  • model_name (Literal['default', 'llama2-7b', 'llama2-13b', 'llama2-70b', 'llama3-8b', 'llama3.1-8b-Instruct', 'llama3-70b', 'llama3.1-70b-Instruct', 'llama3.1-8b', 'llama3.1-70b', 'llama3.1-405b', 'llama3.3-70b', 'mistral-7b', 'mixtral-8x7b', 'mixtral-8x22b', 'deepseek2-16b', 'deepseek2-236b', 'deepseek3-671b', 'deepseek3-671b-2dfsdp', 'deepseek3-671b-batchsplit', 'deepseek3-test', 'deepseek3-tiny', 'deepseek3.2-671b', 'deepseek-custom', 'kimi-k2-1t', 'gemma-7b', 'gemma-2b', 'gemma2-2b', 'gemma2-9b', 'gemma2-27b', 'gemma3-4b', 'gemma3-12b', 'gemma3-27b', 'gemma4-26b', 'gemma4-31b', 'qwen2.5-1.5b', 'qwen2.5-7b', 'qwen2.5-14b', 'qwen3-0.6b', 'qwen3-1.7b', 'qwen3-1.7b-base', 'qwen3-4b', 'qwen3-4b-base', 'qwen3-4b-thinking-2507', 'qwen3-8b', 'qwen3-8b-base', 'qwen3-14b', 'qwen3-14b-base', 'qwen3-32b', 'qwen3-235b-a22b', 'qwen3-30b-a3b', 'qwen3-30b-a3b-base', 'qwen3-480b-a35b', 'qwen3-next-80b-a3b', 'qwen3-omni-30b-a3b', 'qwen3-custom-30b-a3b', 'qwen3.5-397b-a17b', 'gpt3-175b', 'gpt3-22b', 'gpt3-6b', 'gpt3-52k', 'gpt-oss-20b', 'gpt-oss-120b', 'llama4-17b-16e', 'llama4-17b-128e', 'olmo3-7b', 'olmo3-7b-pt', 'olmo3-32b'])

  • override_model_config (bool)

  • override_logical_axis_rules (bool)

  • log_config (bool)

  • debug_sharding (bool)

  • base_output_directory (str)

  • sharding_strategy (None | Literal['experimental'])

base_config: None | str#
run_name: str#
model_name: Literal['default', 'llama2-7b', 'llama2-13b', 'llama2-70b', 'llama3-8b', 'llama3.1-8b-Instruct', 'llama3-70b', 'llama3.1-70b-Instruct', 'llama3.1-8b', 'llama3.1-70b', 'llama3.1-405b', 'llama3.3-70b', 'mistral-7b', 'mixtral-8x7b', 'mixtral-8x22b', 'deepseek2-16b', 'deepseek2-236b', 'deepseek3-671b', 'deepseek3-671b-2dfsdp', 'deepseek3-671b-batchsplit', 'deepseek3-test', 'deepseek3-tiny', 'deepseek3.2-671b', 'deepseek-custom', 'kimi-k2-1t', 'gemma-7b', 'gemma-2b', 'gemma2-2b', 'gemma2-9b', 'gemma2-27b', 'gemma3-4b', 'gemma3-12b', 'gemma3-27b', 'gemma4-26b', 'gemma4-31b', 'qwen2.5-1.5b', 'qwen2.5-7b', 'qwen2.5-14b', 'qwen3-0.6b', 'qwen3-1.7b', 'qwen3-1.7b-base', 'qwen3-4b', 'qwen3-4b-base', 'qwen3-4b-thinking-2507', 'qwen3-8b', 'qwen3-8b-base', 'qwen3-14b', 'qwen3-14b-base', 'qwen3-32b', 'qwen3-235b-a22b', 'qwen3-30b-a3b', 'qwen3-30b-a3b-base', 'qwen3-480b-a35b', 'qwen3-next-80b-a3b', 'qwen3-omni-30b-a3b', 'qwen3-custom-30b-a3b', 'qwen3.5-397b-a17b', 'gpt3-175b', 'gpt3-22b', 'gpt3-6b', 'gpt3-52k', 'gpt-oss-20b', 'gpt-oss-120b', 'llama4-17b-16e', 'llama4-17b-128e', 'olmo3-7b', 'olmo3-7b-pt', 'olmo3-32b']#
override_model_config: bool#
override_logical_axis_rules: bool#
log_config: bool#
debug_sharding: bool#
base_output_directory: str#
sharding_strategy: None | Literal['experimental']#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Checkpointing(*, load_parameters_path='', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=True, checkpoint_period=10000, max_num_checkpoints_to_keep=None, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False)[source]#

Bases: BaseModel

Core configuration for checkpointing and run restoration.

Parameters:
  • load_parameters_path (str)

  • lora_input_adapters_path (str)

  • load_full_state_path (str)

  • enable_checkpointing (bool)

  • load_checkpoint_only_once (bool)

  • async_checkpointing (bool)

  • checkpoint_period (int)

  • max_num_checkpoints_to_keep (int | None)

  • enable_single_replica_ckpt_restoring (bool)

  • checkpoint_todelete_subdir (str | None)

  • checkpoint_todelete_full_path (str | None)

  • force_unroll (bool)

  • checkpoint_is_quantized (bool)

  • save_quantized_params_path (str)

  • enable_orbax_v1 (bool)

  • checkpoint_conversion_fn (None | str)

  • source_checkpoint_layout (Literal['orbax', 'safetensors'])

  • save_checkpoint_on_completion (bool)

  • enable_continuous_checkpointing (bool)

  • colocated_python_checkpointing (bool)

  • enable_autocheckpoint (bool)

load_parameters_path: str#
lora_input_adapters_path: str#
load_full_state_path: str#
enable_checkpointing: bool#
load_checkpoint_only_once: bool#
async_checkpointing: bool#
checkpoint_period: int#
max_num_checkpoints_to_keep: int | None#
enable_single_replica_ckpt_restoring: bool#
checkpoint_todelete_subdir: str | None#
checkpoint_todelete_full_path: str | None#
force_unroll: bool#
checkpoint_is_quantized: bool#
save_quantized_params_path: str#
enable_orbax_v1: bool#
checkpoint_conversion_fn: None | str#
source_checkpoint_layout: Literal['orbax', 'safetensors']#
save_checkpoint_on_completion: bool#
enable_continuous_checkpointing: bool#
colocated_python_checkpointing: bool#
enable_autocheckpoint: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.OrbaxStorage(*, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=True, checkpoint_storage_use_zarr3=True, checkpoint_storage_concurrent_gb=96)[source]#

Bases: BaseModel

Configuration for Orbax checkpoint storage options.

Parameters:
  • checkpoint_storage_target_data_file_size_bytes (int)

  • checkpoint_storage_use_ocdbt (bool)

  • checkpoint_storage_use_zarr3 (bool)

  • checkpoint_storage_concurrent_gb (int)

checkpoint_storage_target_data_file_size_bytes: int#
checkpoint_storage_use_ocdbt: bool#
checkpoint_storage_use_zarr3: bool#
checkpoint_storage_concurrent_gb: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.EmergencyCheckpointing(*, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0)[source]#

Bases: BaseModel

Configuration for emergency (local) checkpointing.

Parameters:
  • enable_multi_tier_checkpointing (bool)

  • local_checkpoint_directory (str)

  • local_checkpoint_period (Annotated[int, Ge(ge=0)])

  • multi_tier_checkpointing_backup_interval_minutes (Annotated[int, Ge(ge=0)])

  • mtc_data_parallelism (int)

  • enable_emergency_checkpoint (bool)

  • use_replicator_service (bool)

  • replicator_backup_interval_minutes (Annotated[int, Ge(ge=0)])

enable_multi_tier_checkpointing: bool#
local_checkpoint_directory: str#
local_checkpoint_period: Annotated[int, Ge(ge=0)]#
multi_tier_checkpointing_backup_interval_minutes: Annotated[int, Ge(ge=0)]#
mtc_data_parallelism: int#
enable_emergency_checkpoint: bool#
use_replicator_service: bool#
replicator_backup_interval_minutes: Annotated[int, Ge(ge=0)]#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.DataTypes(*, dtype=DType.BFLOAT16, grad_dtype=DType.FLOAT32, weight_dtype=DType.FLOAT32, matmul_precision=MatmulPrecision.DEFAULT, activations_in_float32=False, dtype_mm='float32')[source]#

Bases: BaseModel

Configuration for data types and precision.

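Parameters:
  • dtype (DType)

  • grad_dtype (DType)

  • weight_dtype (DType)

  • matmul_precision (MatmulPrecision)

  • activations_in_float32 (bool)

  • dtype_mm (str)
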
dtype: DType#
grad_dtype: DType#
weight_dtype: DType#
matmul_precision: MatmulPrecision#
activations_in_float32: bool#
dtype_mm: str#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Quantization(*, quantization=QuantizationType.NONE, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=KvQuantAxis.HEADS_AND_DKV, kv_quant_dtype='int8', quantization_local_shard_count=-1, use_qwix_quantization=False, use_manual_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', weight_sparsity_n=None, weight_sparsity_m=None, weight_sparsity_update_step=10, weight_sparsity_start_step=50)[source]#

Bases: BaseModel

Configuration for model quantization.

Parameters:
  • quantization (None | QuantizationType)

  • replicate_quant_scale (bool)

  • quant_cfg_path (str)

  • quantize_kvcache (bool)

  • kv_quant_axis (KvQuantAxis)

  • kv_quant_dtype (Literal['int8', 'int4'])

  • quantization_local_shard_count (int)

  • use_qwix_quantization (bool)

  • use_manual_quantization (bool)

  • weight_quantization_calibration_method (str)

  • act_quantization_calibration_method (str)

  • bwd_quantization_calibration_method (str)

  • weight_sparsity_n (int | None)

  • weight_sparsity_m (int | None)

  • weight_sparsity_update_step (int)

  • weight_sparsity_start_step (int)

quantization: None | QuantizationType#
replicate_quant_scale: bool#
quant_cfg_path: str#
quantize_kvcache: bool#
kv_quant_axis: KvQuantAxis#
kv_quant_dtype: Literal['int8', 'int4']#
quantization_local_shard_count: int#
use_qwix_quantization: bool#
use_manual_quantization: bool#
weight_quantization_calibration_method: str#
act_quantization_calibration_method: str#
bwd_quantization_calibration_method: str#
weight_sparsity_n: int | None#
weight_sparsity_m: int | None#
weight_sparsity_update_step: int#
weight_sparsity_start_step: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.ModelArchitecture(*, decoder_block='llama2', global_parameter_scale=1, base_emb_dim=2048, base_num_query_heads=16, base_num_kv_heads=16, base_mlp_dim=7168, dense_init_scale=1.0, base_num_decoder_layers=16, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True)[source]#

Bases: BaseModel

Core model architecture parameters.

Parameters:
  • decoder_block (DecoderBlockType)

  • global_parameter_scale (int)

  • base_emb_dim (int)

  • base_num_query_heads (int)

  • base_num_kv_heads (int)

  • base_mlp_dim (int)

  • dense_init_scale (float)

  • base_num_decoder_layers (int)

  • head_dim (int)

  • attention_output_dim (int)

  • global_head_dim (int)

  • mlp_activations (list[str])

  • mlp_activations_limit (float)

  • normalization_layer_epsilon (float)

  • fused_qkv (bool)

  • attention_bias (bool)

  • fused_mlp (bool)

  • qk_norm_with_scale (bool)

  • v_norm_with_scale (bool)

decoder_block: DecoderBlockType#
global_parameter_scale: int#
base_emb_dim: int#
base_num_query_heads: int#
base_num_kv_heads: int#
base_mlp_dim: int#
dense_init_scale: float#
base_num_decoder_layers: int#
head_dim: int#
attention_output_dim: int#
global_head_dim: int#
mlp_activations: list[str]#
mlp_activations_limit: float#
normalization_layer_epsilon: float#
fused_qkv: bool#
attention_bias: bool#
fused_mlp: bool#
qk_norm_with_scale: bool#
v_norm_with_scale: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.MTP(*, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0)[source]#

Bases: BaseModel

Configuration for Multi-Token Prediction (MTP).

Parameters:
  • mtp_num_layers (Annotated[int, Ge(ge=0)])

  • mtp_loss_scaling_factor (Annotated[float, Ge(ge=0)])

  • mtp_eval_target_module (Annotated[int, Ge(ge=0)])

mtp_num_layers: Annotated[int, Ge(ge=0)]#
mtp_loss_scaling_factor: Annotated[float, Ge(ge=0)]#
mtp_eval_target_module: Annotated[int, Ge(ge=0)]#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Logits(*, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0)[source]#

Bases: BaseModel

Configuration for the final logits computation.

Parameters:
  • logits_via_embedding (bool)

  • normalize_embedding_logits (bool)

  • logits_dot_in_fp32 (bool)

  • cast_logits_to_fp32 (bool)

  • final_logits_soft_cap (None | Annotated[float, Ge(ge=0)])

  • z_loss_multiplier (float)

logits_via_embedding: bool#
normalize_embedding_logits: bool#
logits_dot_in_fp32: bool#
cast_logits_to_fp32: bool#
final_logits_soft_cap: None | Annotated[float, Ge(ge=0)]#
z_loss_multiplier: float#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Attention(*, attention='autoselected', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0)[source]#

Bases: BaseModel

General configuration for the attention mechanism.

Parameters:
  • attention (str)

  • attention_type (Literal['global', 'local_sliding', 'chunk', 'mla', 'full'])

  • share_kv_projections (bool)

  • global_num_kv_heads (int)

  • attention_sink (bool)

  • float32_qk_product (bool)

  • float32_logits (bool)

  • sliding_window_size (Annotated[int, Ge(ge=0)])

  • chunk_attn_window_size (Annotated[int, Ge(ge=0)])

  • attn_logits_soft_cap (None | Annotated[float, Ge(ge=0)])

  • use_post_attn_norm (bool)

  • use_post_ffw_norm (bool)

  • use_ragged_attention (bool)

  • use_tokamax_gmm (bool)

  • ragged_block_size (int)

  • enable_padding_causal_mask (bool)

  • use_tokamax_splash (bool)

  • use_jax_splash (bool)

  • force_q_layout (bool)

  • use_qk_clip (bool)

  • qk_clip_threshold (float)

attention: str#
attention_type: Literal['global', 'local_sliding', 'chunk', 'mla', 'full']#
share_kv_projections: bool#
global_num_kv_heads: int#
attention_sink: bool#
float32_qk_product: bool#
float32_logits: bool#
sliding_window_size: Annotated[int, Ge(ge=0)]#
chunk_attn_window_size: Annotated[int, Ge(ge=0)]#
attn_logits_soft_cap: None | Annotated[float, Ge(ge=0)]#
use_post_attn_norm: bool#
use_post_ffw_norm: bool#
use_ragged_attention: bool#
use_tokamax_gmm: bool#
ragged_block_size: int#
enable_padding_causal_mask: bool#
use_tokamax_splash: bool#
use_jax_splash: bool#
force_q_layout: bool#
use_qk_clip: bool#
qk_clip_threshold: float#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.MoBa(*, moba=False, moba_chunk_size=1024, moba_topk=8)[source]#

Bases: BaseModel

Configuration for Mixture of Block Attention (MoBA).

Parameters:
  • moba (bool)

  • moba_chunk_size (int)

  • moba_topk (int)

moba: bool#
moba_chunk_size: int#
moba_topk: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.MlaAttention(*, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)[source]#

Bases: BaseModel

Configuration for Multi-Head Latent Attention (MLA).

Parameters:
  • mla_naive_kvcache (bool)

  • q_lora_rank (Annotated[int, Ge(ge=0)])

  • kv_lora_rank (Annotated[int, Ge(ge=0)])

  • qk_nope_head_dim (Annotated[int, Ge(ge=0)])

  • qk_rope_head_dim (Annotated[int, Ge(ge=0)])

  • v_head_dim (Annotated[int, Ge(ge=0)])

mla_naive_kvcache: bool#
q_lora_rank: Annotated[int, Ge(ge=0)]#
kv_lora_rank: Annotated[int, Ge(ge=0)]#
qk_nope_head_dim: Annotated[int, Ge(ge=0)]#
qk_rope_head_dim: Annotated[int, Ge(ge=0)]#
v_head_dim: Annotated[int, Ge(ge=0)]#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.AttentionIndexer(*, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0)[source]#

Bases: BaseModel

Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer.

Parameters:
  • use_indexer (bool)

  • indexer_head_dim (Annotated[int, Ge(ge=0)])

  • indexer_n_heads (Annotated[int, Ge(ge=0)])

  • indexer_topk (Annotated[int, Ge(ge=0)])

  • indexer_sparse_training (bool)

  • indexer_loss_scaling_factor (float)

use_indexer: bool#
indexer_head_dim: Annotated[int, Ge(ge=0)]#
indexer_n_heads: Annotated[int, Ge(ge=0)]#
indexer_topk: Annotated[int, Ge(ge=0)]#
indexer_sparse_training: bool#
indexer_loss_scaling_factor: float#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Llama4Attention(*, use_qk_norm=False, temperature_tuning=False)[source]#

Bases: BaseModel

Configuration specific to Llama4-style models.

Parameters:
  • use_qk_norm (bool)

  • temperature_tuning (bool)

use_qk_norm: bool#
temperature_tuning: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.SplashAttention(*, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False)[source]#

Bases: BaseModel

Tunable block sizes, layouts, and related kernel options for Splash Attention.

Parameters:
  • sa_block_q (int)

  • sa_block_kv (int)

  • sa_block_kv_compute (int)

  • sa_block_q_dkv (int)

  • sa_block_kv_dkv (int)

  • sa_block_kv_dkv_compute (int)

  • sa_block_q_dq (int)

  • sa_block_kv_dq (int)

  • sa_use_fused_bwd_kernel (bool)

  • sa_q_layout (str)

  • sa_k_layout (str)

  • sa_v_layout (str)

  • use_max_logit_estimate (int)

  • cost_estimate_flops_fwd (int)

  • cost_estimate_flops_bwd (int)

  • dq_reduction_steps (int)

  • use_splash_scheduler (bool)

sa_block_q: int#
sa_block_kv: int#
sa_block_kv_compute: int#
sa_block_q_dkv: int#
sa_block_kv_dkv: int#
sa_block_kv_dkv_compute: int#
sa_block_q_dq: int#
sa_block_kv_dq: int#
sa_use_fused_bwd_kernel: bool#
sa_q_layout: str#
sa_k_layout: str#
sa_v_layout: str#
use_max_logit_estimate: int#
cost_estimate_flops_fwd: int#
cost_estimate_flops_bwd: int#
dq_reduction_steps: int#
use_splash_scheduler: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.PagedAttention(*, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128)[source]#

Bases: BaseModel

Tunable parameters for Paged Attention kernels.

Parameters:
  • pagedattn_num_pages (int)

  • pagedattn_tokens_per_page (int)

  • pagedattn_pages_per_compute_block (int)

  • pagedattn_max_pages_per_group (int)

  • pagedattn_head_dim_alignment (int)

pagedattn_num_pages: int#
pagedattn_tokens_per_page: int#
pagedattn_pages_per_compute_block: int#
pagedattn_max_pages_per_group: int#
pagedattn_head_dim_alignment: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.MoEGeneral(*, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, padded_base_moe_mlp_dim=None, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False)[source]#

Bases: BaseModel

General configuration for Mixture of Experts (MoE) layers.

Parameters:
  • num_experts (Annotated[int, Gt(gt=0)])

  • num_experts_per_tok (Annotated[int, Gt(gt=0)])

  • capacity_factor (float)

  • ragged_buffer_factor (float)

  • moe_expert_input_dim (int)

  • base_moe_mlp_dim (int)

  • padded_base_moe_mlp_dim (int | None)

  • load_balance_loss_weight (Annotated[float, Ge(ge=0)])

  • use_custom_sort_vjp (bool)

  • use_ring_of_experts (bool)

  • use_gather_mosaic_kernel (bool)

  • use_random_routing (bool)

  • interleave_moe_layer_step (int)

  • moe_fsdp_use_two_stage_all_gather (bool)

  • shard_exp_on_fsdp (bool)

  • use_2d_fsdp_sharding (bool)

  • norm_topk_prob (bool)

  • float32_weight_sum (bool)

  • float32_gate_logits (bool)

  • prefuse_moe_weights (bool)

num_experts: Annotated[int, Gt(gt=0)]#
num_experts_per_tok: Annotated[int, Gt(gt=0)]#
capacity_factor: float#
ragged_buffer_factor: float#
moe_expert_input_dim: int#
base_moe_mlp_dim: int#
padded_base_moe_mlp_dim: int | None#
load_balance_loss_weight: Annotated[float, Ge(ge=0)]#
use_custom_sort_vjp: bool#
use_ring_of_experts: bool#
use_gather_mosaic_kernel: bool#
use_random_routing: bool#
interleave_moe_layer_step: int#
moe_fsdp_use_two_stage_all_gather: bool#
shard_exp_on_fsdp: bool#
use_2d_fsdp_sharding: bool#
norm_topk_prob: bool#
float32_weight_sum: bool#
float32_gate_logits: bool#
prefuse_moe_weights: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.MoEKernels(*, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False)[source]#

Bases: BaseModel

Configuration for MoE-specific kernels like Megablox.

Parameters:
  • megablox (bool)

  • sparse_matmul (bool)

  • wi_tile_fwd_batch_seq (int)

  • wi_tile_fwd_embed_dim (int)

  • wi_tile_fwd_mlp_dim (int)

  • wi_tile_dlhs_batch_seq (int)

  • wi_tile_dlhs_embed_dim (int)

  • wi_tile_dlhs_mlp_dim (int)

  • wi_tile_drhs_batch_seq (int)

  • wi_tile_drhs_embed_dim (int)

  • wi_tile_drhs_mlp_dim (int)

  • wo_tile_fwd_batch_seq (int)

  • wo_tile_fwd_embed_dim (int)

  • wo_tile_fwd_mlp_dim (int)

  • wo_tile_dlhs_batch_seq (int)

  • wo_tile_dlhs_embed_dim (int)

  • wo_tile_dlhs_mlp_dim (int)

  • wo_tile_drhs_batch_seq (int)

  • wo_tile_drhs_embed_dim (int)

  • wo_tile_drhs_mlp_dim (int)

  • merge_gating_gmm (bool)

megablox: bool#
sparse_matmul: bool#
wi_tile_fwd_batch_seq: int#
wi_tile_fwd_embed_dim: int#
wi_tile_fwd_mlp_dim: int#
wi_tile_dlhs_batch_seq: int#
wi_tile_dlhs_embed_dim: int#
wi_tile_dlhs_mlp_dim: int#
wi_tile_drhs_batch_seq: int#
wi_tile_drhs_embed_dim: int#
wi_tile_drhs_mlp_dim: int#
wo_tile_fwd_batch_seq: int#
wo_tile_fwd_embed_dim: int#
wo_tile_fwd_mlp_dim: int#
wo_tile_dlhs_batch_seq: int#
wo_tile_dlhs_embed_dim: int#
wo_tile_dlhs_mlp_dim: int#
wo_tile_drhs_batch_seq: int#
wo_tile_drhs_embed_dim: int#
wo_tile_drhs_mlp_dim: int#
merge_gating_gmm: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.DeepSeekMoE(*, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1)[source]#

Bases: BaseModel

Configuration specific to DeepSeek-style MoE layers.

Parameters:
  • first_num_dense_layers (Annotated[int, Ge(ge=0)])

  • shared_experts (Annotated[int, Ge(ge=0)])

  • routed_scaling_factor (float)

  • routed_score_func (str)

  • routed_bias (bool)

  • routed_bias_update_rate (float)

  • mlp_bias (bool)

  • n_routing_groups (int)

  • topk_routing_group (int)

  • use_batch_split_schedule (bool)

  • batch_split_factor (int)

first_num_dense_layers: Annotated[int, Ge(ge=0)]#
shared_experts: Annotated[int, Ge(ge=0)]#
routed_scaling_factor: float#
routed_score_func: str#
routed_bias: bool#
routed_bias_update_rate: float#
mlp_bias: bool#
n_routing_groups: int#
topk_routing_group: int#
use_batch_split_schedule: bool#
batch_split_factor: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Qwen3Next(*, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0)[source]#

Bases: BaseModel

Configuration specific to Qwen3-Next models with Gated Delta Net.

Parameters:
  • gdn_conv_kernel_dim (int)

  • gdn_key_head_dim (int)

  • gdn_value_head_dim (int)

  • gdn_num_key_heads (int)

  • gdn_num_value_heads (int)

  • gdn_chunk_size (int)

  • use_qk_norm_in_gdn (bool)

  • partial_rotary_factor (float)

gdn_conv_kernel_dim: int#
gdn_key_head_dim: int#
gdn_value_head_dim: int#
gdn_num_key_heads: int#
gdn_num_value_heads: int#
gdn_chunk_size: int#
use_qk_norm_in_gdn: bool#
partial_rotary_factor: float#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the Pydantic model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.HardwareAndMesh(*, hardware='tpu', num_slices=-1, mesh_axes=['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode='auto', inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=ReorderStrategy.AUTO, custom_mesh='', custom_mesh_and_rule=CustomRule.DEFAULT, allow_split_physical_axes=False, enable_nnx=False, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=False, pure_nnx=False, remove_size_one_mesh_axis_from_type=True)[source]#

Bases: BaseModel

Configuration for hardware and parallelism mesh.

Parameters:
  • hardware (Literal['tpu', 'gpu', 'gpu_multiprocess', 'cpu'])

  • num_slices (int)

  • mesh_axes (list[str])

  • shard_mode (ShardMode)

  • inhomogeneous_layer_cycle_interval (int)

  • scan_layers (bool)

  • param_scan_axis (int)

  • context_parallel_load_balance (bool)

  • context_parallel_strategy (str)

  • context_parallel_reorder_strategy (ReorderStrategy)

  • custom_mesh (str)

  • custom_mesh_and_rule (CustomRule)

  • allow_split_physical_axes (bool)

  • enable_nnx (bool)

  • optimize_mesh_for_tpu_v6e (bool)

  • shardy (bool)

  • pure_nnx_decoder (bool)

  • pure_nnx (bool)

  • remove_size_one_mesh_axis_from_type (bool)

hardware: Literal['tpu', 'gpu', 'gpu_multiprocess', 'cpu']#
num_slices: int#
mesh_axes: list[str]#
shard_mode: ShardMode#
inhomogeneous_layer_cycle_interval: int#
scan_layers: bool#
param_scan_axis: int#
context_parallel_load_balance: bool#
context_parallel_strategy: str#
context_parallel_reorder_strategy: ReorderStrategy#
custom_mesh: str#
custom_mesh_and_rule: CustomRule#
allow_split_physical_axes: bool#
enable_nnx: bool#
optimize_mesh_for_tpu_v6e: bool#
shardy: bool#
pure_nnx_decoder: bool#
pure_nnx: bool#
remove_size_one_mesh_axis_from_type: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.LayoutAndSharding(*, logical_axis_rules=[], data_sharding=[], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='')[source]#

Bases: BaseModel

Configuration for data and model sharding rules.

Parameters:
  • logical_axis_rules (Any)

  • data_sharding (Any)

  • context_sharding (str)

  • input_data_sharding_logical_axes (list[str])

  • sharding_tolerance (float)

  • shard_optimizer_over_data (bool)

  • internal_compile (bool)

  • internal_compile_num_devices (int)

  • compile_xla_flags (str)

logical_axis_rules: Any#
data_sharding: Any#
context_sharding: str#
input_data_sharding_logical_axes: list[str]#
sharding_tolerance: float#
shard_optimizer_over_data: bool#
internal_compile: bool#
internal_compile_num_devices: int#
compile_xla_flags: str#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.DcnParallelism(*, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1)[source]#

Bases: BaseModel

Parallelism dimensions across the DCN (Data Center Network).

Parameters:
  • dcn_diloco_parallelism (int)

  • dcn_data_parallelism (int)

  • dcn_fsdp_parallelism (int)

  • dcn_fsdp_transpose_parallelism (int)

  • dcn_sequence_parallelism (int)

  • dcn_context_parallelism (int)

  • dcn_context_autoregressive_parallelism (int)

  • dcn_tensor_parallelism (int)

  • dcn_tensor_transpose_parallelism (int)

  • dcn_tensor_sequence_parallelism (int)

  • dcn_pipeline_parallelism (int)

  • dcn_expert_parallelism (int)

  • dcn_autoregressive_parallelism (int)

dcn_diloco_parallelism: int#
dcn_data_parallelism: int#
dcn_fsdp_parallelism: int#
dcn_fsdp_transpose_parallelism: int#
dcn_sequence_parallelism: int#
dcn_context_parallelism: int#
dcn_context_autoregressive_parallelism: int#
dcn_tensor_parallelism: int#
dcn_tensor_transpose_parallelism: int#
dcn_tensor_sequence_parallelism: int#
dcn_pipeline_parallelism: int#
dcn_expert_parallelism: int#
dcn_autoregressive_parallelism: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.IciParallelism(*, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1)[source]#

Bases: BaseModel

Parallelism dimensions within the ICI (Inter-Chip Interconnect).

Parameters:
  • ici_diloco_parallelism (int)

  • ici_data_parallelism (int)

  • ici_fsdp_parallelism (int)

  • ici_fsdp_transpose_parallelism (int)

  • ici_sequence_parallelism (int)

  • ici_context_parallelism (int)

  • ici_context_autoregressive_parallelism (int)

  • ici_tensor_parallelism (int)

  • ici_tensor_transpose_parallelism (int)

  • ici_tensor_sequence_parallelism (int)

  • ici_autoregressive_parallelism (int)

  • ici_pipeline_parallelism (int)

  • ici_expert_parallelism (int)

ici_diloco_parallelism: int#
ici_data_parallelism: int#
ici_fsdp_parallelism: int#
ici_fsdp_transpose_parallelism: int#
ici_sequence_parallelism: int#
ici_context_parallelism: int#
ici_context_autoregressive_parallelism: int#
ici_tensor_parallelism: int#
ici_tensor_transpose_parallelism: int#
ici_tensor_sequence_parallelism: int#
ici_autoregressive_parallelism: int#
ici_pipeline_parallelism: int#
ici_expert_parallelism: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.PipelineParallelism(*, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=-1, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=True, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False)[source]#

Bases: BaseModel

Configuration for pipeline parallelism.

Parameters:
  • pipeline_fsdp_ag_per_repeat (bool)

  • num_layers_per_pipeline_stage (int)

  • num_pipeline_repeats (int)

  • pipeline_parallel_layers (int)

  • num_pipeline_microbatches (int)

  • pipeline_delay_activation_forwarding (bool)

  • pipeline_fsdp_ag_once (bool)

  • scan_pipeline_iterations (bool)

  • scan_pipeline_repeats (bool)

  • scan_layers_per_stage (bool)

  • set_remat_policy_on_pipeline_iterations (bool)

  • set_remat_policy_on_layers_per_stage (bool)

pipeline_fsdp_ag_per_repeat: bool#
num_layers_per_pipeline_stage: int#
num_pipeline_repeats: int#
pipeline_parallel_layers: int#
num_pipeline_microbatches: int#
pipeline_delay_activation_forwarding: bool#
pipeline_fsdp_ag_once: bool#
scan_pipeline_iterations: bool#
scan_pipeline_repeats: bool#
scan_layers_per_stage: bool#
set_remat_policy_on_pipeline_iterations: bool#
set_remat_policy_on_layers_per_stage: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.RematAndOffload(*, remat_policy='full', remat_policy_for_vit='minimal', decoder_layer_input=RematLocation.DEVICE, context=RematLocation.REMAT, mlpwi=RematLocation.REMAT, mlpwi_0=RematLocation.REMAT, mlpwi_1=RematLocation.REMAT, mlpwo=RematLocation.REMAT, moe_mlpwi_0=RematLocation.REMAT, moe_mlpwi_1=RematLocation.REMAT, moe_mlpwo=RematLocation.REMAT, query_proj=RematLocation.REMAT, key_proj=RematLocation.REMAT, value_proj=RematLocation.REMAT, query_wa_proj=RematLocation.REMAT, kv_wa_proj=RematLocation.REMAT, qkv_proj=RematLocation.REMAT, out_proj=RematLocation.REMAT, mla_q=RematLocation.REMAT, mla_kv=RematLocation.REMAT, attention_out=RematLocation.REMAT, engram=RematLocation.REMAT, optimizer_memory_host_offload=False, parameter_memory_host_offload=False)[source]#

Bases: BaseModel

Configuration for gradient checkpointing (rematerialization) and offloading.

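Parameters:
  • remat_policy (str)

  • remat_policy_for_vit (str)

  • decoder_layer_input (RematLocation)

  • context (RematLocation)

  • mlpwi (RematLocation)

  • mlpwi_0 (RematLocation)

  • mlpwi_1 (RematLocation)

  • mlpwo (RematLocation)

  • moe_mlpwi_0 (RematLocation)

  • moe_mlpwi_1 (RematLocation)

  • moe_mlpwo (RematLocation)

  • query_proj (RematLocation)

  • key_proj (RematLocation)

  • value_proj (RematLocation)

  • query_wa_proj (RematLocation)

  • kv_wa_proj (RematLocation)

  • qkv_proj (RematLocation)

  • out_proj (RematLocation)

  • mla_q (RematLocation)

  • mla_kv (RematLocation)

  • attention_out (RematLocation)

  • engram (RematLocation)

  • optimizer_memory_host_offload (bool)

  • parameter_memory_host_offload (bool)
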
remat_policy: str#
remat_policy_for_vit: str#
decoder_layer_input: RematLocation#
context: RematLocation#
mlpwi: RematLocation#
mlpwi_0: RematLocation#
mlpwi_1: RematLocation#
mlpwo: RematLocation#
moe_mlpwi_0: RematLocation#
moe_mlpwi_1: RematLocation#
moe_mlpwo: RematLocation#
query_proj: RematLocation#
key_proj: RematLocation#
value_proj: RematLocation#
query_wa_proj: RematLocation#
kv_wa_proj: RematLocation#
qkv_proj: RematLocation#
out_proj: RematLocation#
mla_q: RematLocation#
mla_kv: RematLocation#
attention_out: RematLocation#
engram: RematLocation#
optimizer_memory_host_offload: bool#
parameter_memory_host_offload: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Tokenizer(*, vocab_size=32000, tokenizer_path=None, tokenizer_type=TokenizerType.SENTENCEPIECE, use_chat_template=False, chat_template_path='', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1)[source]#

Bases: BaseModel

Configuration for the tokenizer.

Parameters:
  • vocab_size (int)

  • tokenizer_path (None | str)

  • tokenizer_type (TokenizerType)

  • use_chat_template (bool)

  • chat_template_path (str)

  • chat_template (str)

  • tokenize_train_data (bool)

  • tokenize_eval_data (bool)

  • add_bos (bool)

  • add_eos (bool)

  • use_truncation (bool)

  • num_vocab_tiling (int)

vocab_size: int#
tokenizer_path: None | str#
tokenizer_type: TokenizerType#
use_chat_template: bool#
chat_template_path: str#
chat_template: str#
tokenize_train_data: bool#
tokenize_eval_data: bool#
add_bos: bool#
add_eos: bool#
use_truncation: bool#
num_vocab_tiling: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.DatasetGeneral(*, dataset_type=DatasetType.TFDS, per_device_batch_size=12, eval_per_device_batch_size=0.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False)[source]#

Bases: BaseModel

General configuration for dataset and data loading.

Parameters:
  • dataset_type (DatasetType)

  • per_device_batch_size (int | float)

  • eval_per_device_batch_size (int | float)

  • max_corpus_chars (int)

  • train_data_columns (list[str])

  • train_image_column (str | list[str])

  • eval_data_columns (list[str])

  • eval_image_column (str | list[str])

  • packing (bool)

  • grain_packing_type (Literal['first_fit', 'best_fit', 'concat_then_split'])

  • max_segments_per_seq (int)

  • num_epoch (int)

  • expansion_factor_real_data (float)

  • reuse_example_batch (int)

  • generate_padding_batch_train (bool)

  • generate_padding_batch_eval (bool)

  • enable_rampup_batch_size (bool)

  • per_device_batch_size_start (float)

  • per_device_batch_size_increment (float)

  • global_rampup_samples (int)

  • colocated_python_data_input (bool)

dataset_type: DatasetType#
per_device_batch_size: int | float#
eval_per_device_batch_size: int | float#
max_corpus_chars: int#
train_data_columns: list[str]#
train_image_column: str | list[str]#
eval_data_columns: list[str]#
eval_image_column: str | list[str]#
packing: bool#
grain_packing_type: Literal['first_fit', 'best_fit', 'concat_then_split']#
max_segments_per_seq: int#
num_epoch: int#
expansion_factor_real_data: float#
reuse_example_batch: int#
generate_padding_batch_train: bool#
generate_padding_batch_eval: bool#
enable_rampup_batch_size: bool#
per_device_batch_size_start: float#
per_device_batch_size_increment: float#
global_rampup_samples: int#
colocated_python_data_input: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.TfdsDataset(*, dataset_path='', dataset_name='c4/en:3.0.1', eval_dataset_name='c4/en:3.0.1', train_split='train', eval_split='validation')[source]#

Bases: BaseModel

Configuration specific to TFDS datasets.

Parameters:
  • dataset_path (str)

  • dataset_name (str)

  • eval_dataset_name (str)

  • train_split (str)

  • eval_split (str)

dataset_path: str#
dataset_name: str#
eval_dataset_name: str#
train_split: str#
eval_split: str#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.HfDataset(*, hf_path='', hf_name=None, hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token=None)[source]#

Bases: BaseModel

Configuration specific to HuggingFace datasets.

Parameters:
  • hf_path (str)

  • hf_name (None | str)

  • hf_data_dir (None | str)

  • hf_train_files (None | str)

  • hf_eval_split (None | str)

  • hf_eval_files (None | str)

  • hf_access_token (None | str)

hf_path: str#
hf_name: None | str#
hf_data_dir: None | str#
hf_train_files: None | str#
hf_eval_split: None | str#
hf_eval_files: None | str#
hf_access_token: None | str#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.GrainDataset(*, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_use_elastic_iterator=False, grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100)[source]#

Bases: BaseModel

Configuration specific to Grain datasets.

Parameters:
  • grain_train_files (str)

  • grain_eval_files (str)

  • grain_train_mixture_config_path (str)

  • grain_file_type (str)

  • grain_use_elastic_iterator (bool)

  • grain_worker_count (int)

  • grain_per_worker_buffer_size (int)

  • grain_worker_count_eval (int)

  • grain_per_worker_buffer_size_eval (int)

  • grain_ram_budget_mb (int)

  • grain_num_threads (int)

  • grain_prefetch_buffer_size (int)

  • grain_num_threads_eval (int)

  • grain_prefetch_buffer_size_eval (int)

  • grain_data_source_max_workers (int)

  • grain_shuffle_buffer_size (int)

grain_train_files: str#
grain_eval_files: str#
grain_train_mixture_config_path: str#
grain_file_type: str#
grain_use_elastic_iterator: bool#
grain_worker_count: int#
grain_per_worker_buffer_size: int#
grain_worker_count_eval: int#
grain_per_worker_buffer_size_eval: int#
grain_ram_budget_mb: int#
grain_num_threads: int#
grain_prefetch_buffer_size: int#
grain_num_threads_eval: int#
grain_prefetch_buffer_size_eval: int#
grain_data_source_max_workers: int#
grain_shuffle_buffer_size: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.OlmoGrainDataset(*, olmo_index_path='', olmo_path_remap_from='', olmo_path_remap_to='', olmo_apply_ngram_filter=True)[source]#

Bases: BaseModel

Configuration for the OLMo numpy fixed-seq-length input pipeline (dataset_type=olmo_grain).

Separate from the standard grain config because this pipeline reads pre-tokenized fixed-length sequences from raw npy files (one int32 token per element, sequence_length from an index JSON), not arrayrecord/tfds shards — so flags like grain_train_files / packing don’t apply.

Worker count, per-worker buffer size, and shuffle seed reuse the standard grain flags (grain_worker_count, grain_per_worker_buffer_size, data_shuffle_seed); only OLMo-specific fields are listed here.

Parameters:
  • olmo_index_path (str)

  • olmo_path_remap_from (str)

  • olmo_path_remap_to (str)

  • olmo_apply_ngram_filter (bool)

olmo_index_path: str#
olmo_path_remap_from: str#
olmo_path_remap_to: str#
olmo_apply_ngram_filter: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.FineTuning(*, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs=<factory>, use_grpo=None)[source]#

Bases: BaseModel

Configuration for fine-tuning methods like DPO, SFT, and GRPO.

Parameters:
  • use_dpo (bool)

  • dpo_label_smoothing (float)

  • dpo_beta (float)

  • use_sft (bool)

  • sft_train_on_completion_only (bool)

  • formatting_func_path (str)

  • formatting_func_kwargs (dict)

  • use_grpo (None | bool)

use_dpo: bool#
dpo_label_smoothing: float#
dpo_beta: float#
use_sft: bool#
sft_train_on_completion_only: bool#
formatting_func_path: str#
formatting_func_kwargs: dict#
use_grpo: None | bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Distillation(*, student_overrides=<factory>, teacher_overrides=<factory>, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', learn_to_init_mode=False, lti_use_general_linear_map=False, distill_weights_copy_map=<factory>, distill_student_weights_share_map=<factory>, student_params_to_update=None)[source]#

Bases: BaseModel

Configuration for Knowledge Distillation.

Parameters:
  • student_overrides (dict[str, Any])

  • teacher_overrides (dict[str, Any])

  • offline_data_dir (str | None)

  • distill_alpha (float)

  • distill_temperature (float)

  • distill_beta (float)

  • distill_feature_loss_type (Literal['cosine', 'l2'])

  • distill_layer_indices (None | list)

  • distill_alpha_end (float | None)

  • distill_alpha_schedule (Literal['constant', 'linear', 'cosine'])

  • distill_temperature_end (float | None)

  • distill_temperature_schedule (Literal['constant', 'linear', 'cosine'])

  • distill_beta_end (float | None)

  • distill_beta_schedule (Literal['constant', 'linear', 'cosine'])

  • learn_to_init_mode (bool)

  • lti_use_general_linear_map (bool)

  • distill_weights_copy_map (dict[str, Any])

  • distill_student_weights_share_map (dict[str, Any])

  • student_params_to_update (None | list)

student_overrides: dict[str, Any]#
teacher_overrides: dict[str, Any]#
offline_data_dir: str | None#
distill_alpha: float#
distill_temperature: float#
distill_beta: float#
distill_feature_loss_type: Literal['cosine', 'l2']#
distill_layer_indices: None | list#
distill_alpha_end: float | None#
distill_alpha_schedule: Literal['constant', 'linear', 'cosine']#
distill_temperature_end: float | None#
distill_temperature_schedule: Literal['constant', 'linear', 'cosine']#
distill_beta_end: float | None#
distill_beta_schedule: Literal['constant', 'linear', 'cosine']#
learn_to_init_mode: bool#
lti_use_general_linear_map: bool#
distill_weights_copy_map: dict[str, Any]#
distill_student_weights_share_map: dict[str, Any]#
student_params_to_update: None | list#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.TrainingLoop(*, steps=150001, log_period=100, eval_interval=-1, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=True, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=0, init_weights_seed=0)[source]#

Bases: BaseModel

Configuration for the main training loop, evaluation, and reproducibility.

Parameters:
  • steps (int)

  • log_period (int)

  • eval_interval (int)

  • eval_steps (int)

  • target_eval_loss (float)

  • abort_on_nan_loss (bool)

  • abort_on_inf_loss (bool)

  • enable_dropout (bool)

  • dropout_rate (float)

  • enable_data_shuffling (bool)

  • data_shuffle_seed (int)

  • init_weights_seed (int)

steps: int#
log_period: int#
eval_interval: int#
eval_steps: int#
target_eval_loss: float#
abort_on_nan_loss: bool#
abort_on_inf_loss: bool#
enable_dropout: bool#
dropout_rate: float#
enable_data_shuffling: bool#
data_shuffle_seed: int#
init_weights_seed: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.ManifoldConstrainedHyperConnections(*, mhc_expansion_rate=1, sinkhorn_iterations=20)[source]#

Bases: BaseModel

Configuration for DeepSeek Manifold-Constrained Hyper Connections (mHC).

Parameters:
  • mhc_expansion_rate (Annotated[int, Gt(gt=0)])

  • sinkhorn_iterations (Annotated[int, Gt(gt=0)])

mhc_expansion_rate: Annotated[int, Gt(gt=0)]#
sinkhorn_iterations: Annotated[int, Gt(gt=0)]#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.DilocoParams(*, enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9)[source]#

Bases: BaseModel

Hyperparameters for DiLoCo (Distributed Low-Communication) training.

Parameters:
  • enable_diloco (bool)

  • diloco_sync_period (int)

  • diloco_outer_lr (float)

  • diloco_outer_momentum (float)

enable_diloco: bool#
diloco_sync_period: int#
diloco_outer_lr: float#
diloco_outer_momentum: float#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Optimizer(*, opt_type=OptimizerType.ADAMW, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=1.0, learning_rate=3e-05, lr_schedule_type=LearningRateScheduleType.COSINE, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=WsdDecayStyle.LINEAR, warmup_steps_fraction=0.1, learning_rate_schedule_steps=-1, trainable_parameters_mask=<factory>)[source]#

Bases: BaseModel

Configuration for the optimizer and learning rate schedule.

Parameters:
  • opt_type (OptimizerType)

  • skip_step_on_spikes (bool)

  • skip_step_interval (Annotated[int, Gt(gt=0)])

  • skip_step_scaling_factor (float)

  • gradient_accumulation_steps (Annotated[int, Gt(gt=0)])

  • use_tunix_gradient_accumulation (bool)

  • gradient_clipping_threshold (Annotated[float, Ge(ge=0)])

  • learning_rate (Annotated[float, Ge(ge=0)])

  • lr_schedule_type (LearningRateScheduleType)

  • learning_rate_final_fraction (float)

  • wsd_decay_steps_fraction (float)

  • wsd_decay_style (WsdDecayStyle)

  • warmup_steps_fraction (float)

  • learning_rate_schedule_steps (int)

  • trainable_parameters_mask (list[str])

opt_type: OptimizerType#
skip_step_on_spikes: bool#
skip_step_interval: Annotated[int, Gt(gt=0)]#
skip_step_scaling_factor: float#
gradient_accumulation_steps: Annotated[int, Gt(gt=0)]#
use_tunix_gradient_accumulation: bool#
gradient_clipping_threshold: Annotated[float, Ge(ge=0)]#
learning_rate: Annotated[float, Ge(ge=0)]#
lr_schedule_type: LearningRateScheduleType#
learning_rate_final_fraction: float#
wsd_decay_steps_fraction: float#
wsd_decay_style: WsdDecayStyle#
warmup_steps_fraction: float#
learning_rate_schedule_steps: int#
trainable_parameters_mask: list[str]#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.AdamW(*, adam_b1=0.9, adam_b2=0.95, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=<factory>, mu_dtype='')[source]#

Bases: BaseModel

Configuration specific to the AdamW optimizer.

Parameters:
  • adam_b1 (float)

  • adam_b2 (float)

  • adam_eps (float)

  • adam_eps_root (float)

  • adam_weight_decay (float)

  • adamw_mask (list[str])

  • mu_dtype (str)

adam_b1: float#
adam_b2: float#
adam_eps: float#
adam_eps_root: float#
adam_weight_decay: float#
adamw_mask: list[str]#
mu_dtype: str#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Muon(*, muon_beta=0.95, muon_weight_decay=0, muon_consistent_rms=None)[source]#

Bases: BaseModel

Configuration specific to the Muon optimizer.

Parameters:
  • muon_beta (float)

  • muon_weight_decay (float)

  • muon_consistent_rms (float | None)

muon_beta: float#
muon_weight_decay: float#
muon_consistent_rms: float | None#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.PositionalEmbedding(*, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1)[source]#

Bases: BaseModel

General configuration for positional embeddings.

Parameters:
  • use_iota_embed (bool)

  • use_untrainable_positional_embedding (bool)

  • trainable_position_size (int)

  • nope_layer_interval (int)

use_iota_embed: bool#
use_untrainable_positional_embedding: bool#
trainable_position_size: int#
nope_layer_interval: int#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Rope(*, rope_type=RopeType.DEFAULT, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=10000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0)[source]#

Bases: BaseModel

Configuration for Rotary Positional Embedding (RoPE).

Parameters:
  • rope_type (RopeType)

  • rope_use_scale (bool)

  • rope_min_timescale (int)

  • rope_max_timescale (int)

  • rope_linear_scaling_factor (float)

  • local_rope_max_timescale (int)

  • global_rope_max_timescale (int)

  • global_rope_proportion (float)

  • local_rope_proportion (float)

rope_type: RopeType#
rope_use_scale: bool#
rope_min_timescale: int#
rope_max_timescale: int#
rope_linear_scaling_factor: float#
local_rope_max_timescale: int#
global_rope_max_timescale: int#
global_rope_proportion: float#
local_rope_proportion: float#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.YarnRope(*, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False)[source]#

Bases: BaseModel

Configuration specific to YaRN (Yet another RoPE extensioN) scaling.

Parameters:
  • max_position_embeddings (int)

  • original_max_position_embeddings (int)

  • rope_factor (int)

  • beta_fast (int)

  • beta_slow (int)

  • mscale (float)

  • rope_interleave (bool)

  • rope_truncate (bool)

  • rope_attention_scaling (bool)

max_position_embeddings: int#
original_max_position_embeddings: int#
rope_factor: int#
beta_fast: int#
beta_slow: int#
mscale: float#
rope_interleave: bool#
rope_truncate: bool#
rope_attention_scaling: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.InferenceGeneral(*, max_target_length=2048, max_prefill_predict_length=64, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False)[source]#

Bases: BaseModel

General configuration for inference.

Parameters:
  • max_target_length (int)

  • max_prefill_predict_length (int)

  • prompt (str)

  • load_from_prefill_dir (bool)

  • prefill_cache_dir (str)

  • autoregressive_decode_assert (str)

  • model_call_mode (str)

  • use_chunked_prefill (bool)

  • prefill_chunk_size (int)

  • enable_model_warmup (bool)

  • enable_llm_inference_pool (bool)

  • multi_sampling (bool)

  • return_log_prob (bool)

max_target_length: int#
max_prefill_predict_length: int#
prompt: str#
load_from_prefill_dir: bool#
prefill_cache_dir: str#
autoregressive_decode_assert: str#
model_call_mode: str#
use_chunked_prefill: bool#
prefill_chunk_size: int#
enable_model_warmup: bool#
enable_llm_inference_pool: bool#
multi_sampling: bool#
return_log_prob: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Decoding(*, decode_sampling_strategy=SamplingStrategy.GREEDY, decode_sampling_nucleus_p=-1.0, decode_sampling_top_k=0, decode_sampling_temperature=1.0)[source]#

Bases: BaseModel

Configuration for decoding and sampling strategies.

Parameters:
  • decode_sampling_strategy (SamplingStrategy)

  • decode_sampling_nucleus_p (int | float)

  • decode_sampling_top_k (int)

  • decode_sampling_temperature (float)

decode_sampling_strategy: SamplingStrategy#
decode_sampling_nucleus_p: int | float#
decode_sampling_top_k: int#
decode_sampling_temperature: float#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.InferenceLayout(*, stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False)[source]#

Bases: BaseModel

Configuration for KV cache and compute layouts during inference.

Parameters:
  • stack_prefill_result_cache (bool)

  • prefill_cache_axis_order (str)

  • ar_cache_axis_order (str)

  • compute_axis_order (str)

  • reshape_q (bool)

stack_prefill_result_cache: bool#
prefill_cache_axis_order: str#
ar_cache_axis_order: str#
compute_axis_order: str#
reshape_q: bool#
model_config: ClassVar[ConfigDict] = {}#

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.InferenceServer(*, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16')[source]#

Bases: BaseModel

Configuration for running as an inference server.

Parameters:
  • inference_server (str)

  • prefill_slice (str)

  • generate_slice (str)

inference_server: str#
prefill_slice: str#
generate_slice: str#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.InferenceBenchmark(*, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False)[source]#

Bases: BaseModel

Configuration for running inference microbenchmarks.

Parameters:
  • inference_microbenchmark_prefill_lengths (str)

  • inference_microbenchmark_stages (str)

  • inference_microbenchmark_loop_iters (int)

  • inference_microbenchmark_log_file_path (str)

  • inference_microbenchmark_num_samples (list[int])

  • inference_metadata_file (str)

  • inference_benchmark_test (bool)

inference_microbenchmark_prefill_lengths: str#
inference_microbenchmark_stages: str#
inference_microbenchmark_loop_iters: int#
inference_microbenchmark_log_file_path: str#
inference_microbenchmark_num_samples: list[int]#
inference_metadata_file: str#
inference_benchmark_test: bool#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.PrefixCaching(*, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000)[source]#

Bases: BaseModel

Configuration for Prefix Caching in JetStream.

Parameters:
  • enable_prefix_caching (bool)

  • prefix_caching_hbm_byte (int)

  • prefix_caching_dram_byte (int)

enable_prefix_caching: bool#
prefix_caching_hbm_byte: int#
prefix_caching_dram_byte: int#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.AOT(*, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1)[source]#

Bases: BaseModel

Ahead of Time (AOT) Compilation settings.

Parameters:
  • compiled_trainstep_file (str)

  • compile_topology (str)

  • compile_topology_num_slices (int)

compiled_trainstep_file: str#
compile_topology: str#
compile_topology_num_slices: int#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.DevelopmentAndDebugging(*, constant_bound_config=[], jax_cache_dir='/home/docs/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=False, enable_single_controller=False, subslice_shape='', max_checkify=False)[source]#

Bases: BaseModel

General settings for development and debugging.

Parameters:
  • constant_bound_config (list)

  • jax_cache_dir (str | None)

  • jax_distributed_initialization_timeout (int)

  • jax_debug_log_modules (str)

  • skip_jax_distributed_system (bool)

  • enable_single_controller (bool)

  • subslice_shape (str)

  • max_checkify (bool)

constant_bound_config: list#
jax_cache_dir: str | None#
jax_distributed_initialization_timeout: int#
jax_debug_log_modules: str#
skip_jax_distributed_system: bool#
enable_single_controller: bool#
subslice_shape: str#
max_checkify: bool#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Profiling(*, profiler=ProfilerType.NONE, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, enable_tpu_profiling_options=False, tpu_num_chips_to_profile_per_task=1, tpu_num_sparse_cores_to_trace=2, tpu_num_sparse_core_tiles_to_trace=1, xprof_tpu_power_trace_level=XProfTPUPowerTraceMode.POWER_TRACE_NONE, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False)[source]#

Bases: BaseModel

Configuration for performance profiling.

Parameters:
  • profiler (ProfilerType)

  • upload_all_profiler_results (bool)

  • skip_first_n_steps_for_profiler (int)

  • profiler_steps (int)

  • profile_cleanly (bool)

  • profile_periodically_period (int)

  • hide_profiler_step_metric (bool)

  • enable_jax_profiler (bool)

  • jax_profiler_port (int)

  • enable_tpu_profiling_options (bool)

  • tpu_num_chips_to_profile_per_task (int)

  • tpu_num_sparse_cores_to_trace (int)

  • tpu_num_sparse_core_tiles_to_trace (int)

  • xprof_tpu_power_trace_level (XProfTPUPowerTraceMode)

  • xprof_e2e_enable_fw_throttle_event (bool)

  • xprof_e2e_enable_fw_power_level_event (bool)

  • xprof_e2e_enable_fw_thermal_event (bool)

  • profile_power_events (bool)

profiler: ProfilerType#
upload_all_profiler_results: bool#
skip_first_n_steps_for_profiler: int#
profiler_steps: int#
profile_cleanly: bool#
profile_periodically_period: int#
hide_profiler_step_metric: bool#
enable_jax_profiler: bool#
jax_profiler_port: int#
enable_tpu_profiling_options: bool#
tpu_num_chips_to_profile_per_task: int#
tpu_num_sparse_cores_to_trace: int#
tpu_num_sparse_core_tiles_to_trace: int#
xprof_tpu_power_trace_level: XProfTPUPowerTraceMode#
xprof_e2e_enable_fw_throttle_event: bool#
xprof_e2e_enable_fw_power_level_event: bool#
xprof_e2e_enable_fw_thermal_event: bool#
profile_power_events: bool#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.HloDump(*, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='')[source]#

Bases: BaseModel

Configuration for dumping HLO modules and jaxprs for debugging.

Parameters:
  • dump_hlo (bool)

  • dump_step (int)

  • dump_hlo_local_dir (str)

  • dump_hlo_delete_local_after (bool)

  • dump_hlo_gcs_dir (str)

  • dump_hlo_module_name (str)

  • dump_hlo_local_module_name (str)

  • dump_hlo_xla_flags (str)

  • dump_hlo_upload_all (bool)

  • dump_jaxpr (bool)

  • dump_jaxpr_local_dir (str)

  • dump_jaxpr_delete_local_after (bool)

  • dump_jaxpr_gcs_dir (str)

dump_hlo: bool#
dump_step: int#
dump_hlo_local_dir: str#
dump_hlo_delete_local_after: bool#
dump_hlo_gcs_dir: str#
dump_hlo_module_name: str#
dump_hlo_local_module_name: str#
dump_hlo_xla_flags: str#
dump_hlo_upload_all: bool#
dump_jaxpr: bool#
dump_jaxpr_local_dir: str#
dump_jaxpr_delete_local_after: bool#
dump_jaxpr_gcs_dir: str#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.StackTrace(*, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600)[source]#

Bases: BaseModel

Configuration for collecting and logging stack traces.

Parameters:
  • collect_stack_trace (bool)

  • stack_trace_to_cloud (bool)

  • stack_trace_interval_seconds (int)

collect_stack_trace: bool#
stack_trace_to_cloud: bool#
stack_trace_interval_seconds: int#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Metrics(*, metrics_file=None, gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False)[source]#

Bases: BaseModel

General configuration for metrics and monitoring.

Parameters:
  • metrics_file (None | str)

  • gcs_metrics (bool)

  • save_config_to_gcs (bool)

  • record_internal_nn_metrics (int)

  • prometheus_port (int)

  • enable_checkpoint_cloud_logger (bool)

  • enable_tunix_perf_metrics (bool)

metrics_file: None | str#
gcs_metrics: bool#
save_config_to_gcs: bool#
record_internal_nn_metrics: int#
prometheus_port: int#
enable_checkpoint_cloud_logger: bool#
enable_tunix_perf_metrics: bool#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.ManagedMLDiagnostics(*, managed_mldiagnostics=False, managed_mldiagnostics_run_group='')[source]#

Bases: BaseModel

Configuration for managed mldiagnostics.

Parameters:
  • managed_mldiagnostics (bool)

  • managed_mldiagnostics_run_group (str)

managed_mldiagnostics: bool#
managed_mldiagnostics_run_group: str#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Goodput(*, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True)[source]#

Bases: BaseModel

Configuration for goodput monitoring.

Parameters:
  • enable_goodput_recording (bool)

  • monitor_goodput (bool)

  • goodput_upload_interval_seconds (int)

  • enable_pathways_goodput (bool)

  • monitor_step_time_deviation (bool)

  • step_deviation_interval_seconds (int)

  • enable_gcp_goodput_metrics (bool)

  • enable_gcp_step_deviation_metrics (bool)

enable_goodput_recording: bool#
monitor_goodput: bool#
goodput_upload_interval_seconds: int#
enable_pathways_goodput: bool#
monitor_step_time_deviation: bool#
step_deviation_interval_seconds: int#
enable_gcp_goodput_metrics: bool#
enable_gcp_step_deviation_metrics: bool#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.ElasticTraining(*, elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, elastic_min_slice_count=-1)[source]#

Bases: BaseModel

Configuration for elastic training and fault tolerance.

Elastic training is Pathways-specific and does not work on McJAX.

Parameters:
  • elastic_enabled (bool)

  • elastic_timeout_seconds (int)

  • elastic_max_retries (int)

  • elastic_min_slice_count (int)

elastic_enabled: bool#
elastic_timeout_seconds: int#
elastic_max_retries: int#
elastic_min_slice_count: int#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.GcpMonitoring(*, report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False)[source]#

Bases: BaseModel

Configuration for GCP-specific workload monitoring.

Parameters:
  • report_heartbeat_metric_for_gcp_monitoring (bool)

  • heartbeat_reporting_interval_in_seconds (int)

  • report_performance_metric_for_gcp_monitoring (bool)

report_heartbeat_metric_for_gcp_monitoring: bool#
heartbeat_reporting_interval_in_seconds: int#
report_performance_metric_for_gcp_monitoring: bool#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Tensorboard(*, enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='')[source]#

Bases: BaseModel

Configuration for Tensorboard logging.

Parameters:
  • enable_tensorboard (bool)

  • use_vertex_tensorboard (bool)

  • vertex_tensorboard_project (str | None)

  • vertex_tensorboard_region (str | None)

enable_tensorboard: bool#
use_vertex_tensorboard: bool#
vertex_tensorboard_project: str | None#
vertex_tensorboard_region: str | None#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.MultimodalGeneral(*, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25)[source]#

Bases: BaseModel

General configuration for Multimodal models.

Parameters:
  • use_multimodal (bool)

  • freeze_vision_encoder_params (bool)

  • freeze_audio_encoder_params (bool)

  • use_audio (bool)

  • image_size_for_vit (int | list[int])

  • image_path (str)

  • image_placeholder (str)

  • posemb_type_for_vit (str)

  • max_num_images_per_example (int)

  • video_path (str)

  • audio_path (str)

  • video_placeholder (str)

  • audio_placeholder (str)

  • use_audio_in_video (bool)

  • use_mrope (bool)

  • mrope_section (list[int])

  • position_id_per_seconds (int)

use_multimodal: bool#
freeze_vision_encoder_params: bool#
freeze_audio_encoder_params: bool#
use_audio: bool#
image_size_for_vit: int | list[int]#
image_path: str#
image_placeholder: str#
posemb_type_for_vit: str#
max_num_images_per_example: int#
video_path: str#
audio_path: str#
video_placeholder: str#
audio_placeholder: str#
use_audio_in_video: bool#
use_mrope: bool#
mrope_section: list[int]#
position_id_per_seconds: int#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.VisionTower(*, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1)[source]#

Bases: BaseModel

Configuration for the Vision Tower (Encoder) in a multimodal model.

Parameters:
  • hidden_size_for_vit (int)

  • intermediate_size_for_vit (int)

  • num_attention_heads_for_vit (int)

  • num_channels_for_vit (int)

  • tile_size_for_vit (int)

  • patch_size_for_vit (int)

  • conv_stride_for_vit (int)

  • num_hidden_layers_for_vit (int)

  • rope_theta_for_vit (int)

  • vision_output_dim_for_vit (int)

  • spatial_merge_size_for_vit (int)

  • out_hidden_size_for_vit (int)

  • temporal_patch_size_for_vit (int)

  • num_position_embeddings_for_vit (int)

  • deepstack_visual_indexes_for_vit (list[int])

  • vision_output_length (int)

hidden_size_for_vit: int#
intermediate_size_for_vit: int#
num_attention_heads_for_vit: int#
num_channels_for_vit: int#
tile_size_for_vit: int#
patch_size_for_vit: int#
conv_stride_for_vit: int#
num_hidden_layers_for_vit: int#
rope_theta_for_vit: int#
vision_output_dim_for_vit: int#
spatial_merge_size_for_vit: int#
out_hidden_size_for_vit: int#
temporal_patch_size_for_vit: int#
num_position_embeddings_for_vit: int#
deepstack_visual_indexes_for_vit: list[int]#
vision_output_length: int#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.VisionProjector(*, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0)[source]#

Bases: BaseModel

Configuration for the Vision Projector in a multimodal model.

Parameters:
  • projector_input_dim_for_vit (int)

  • projector_output_dim_for_vit (int)

  • pixel_shuffle_ratio_for_vit (float)

  • projector_dropout_for_vit (float)

projector_input_dim_for_vit: int#
projector_output_dim_for_vit: int#
pixel_shuffle_ratio_for_vit: float#
projector_dropout_for_vit: float#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.AudioEncoder(*, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000)[source]#

Bases: BaseModel

Configuration for the Audio Encoder in a multimodal model.

Parameters:
  • d_model_for_audio (int)

  • encoder_attention_heads_for_audio (int)

  • encoder_ffn_dim_for_audio (int)

  • encoder_layers_for_audio (int)

  • attention_dropout_for_audio (float)

  • activation_dropout_for_audio (float)

  • activation_function_for_audio (str)

  • num_mel_bins_for_audio (int)

  • max_source_positions_for_audio (int)

  • scale_embedding_for_audio (bool)

  • n_window_for_audio (int)

  • n_window_infer_for_audio (int)

  • conv_chunksize_for_audio (int)

  • downsample_hidden_size_for_audio (int)

  • output_dim_for_audio (int)

  • num_conv_layers_for_audio (int)

  • max_timescale_for_audio (float)

  • max_sample_len_for_audio (int)

d_model_for_audio: int#
encoder_attention_heads_for_audio: int#
encoder_ffn_dim_for_audio: int#
encoder_layers_for_audio: int#
attention_dropout_for_audio: float#
activation_dropout_for_audio: float#
activation_function_for_audio: str#
num_mel_bins_for_audio: int#
max_source_positions_for_audio: int#
scale_embedding_for_audio: bool#
n_window_for_audio: int#
n_window_infer_for_audio: int#
conv_chunksize_for_audio: int#
downsample_hidden_size_for_audio: int#
output_dim_for_audio: int#
num_conv_layers_for_audio: int#
max_timescale_for_audio: float#
max_sample_len_for_audio: int#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Debug(*, rl=False)[source]#

Bases: BaseModel

Configuration for debugging options.

Parameters:

rl (bool)

rl: bool#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.RLHardware(*, trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=True, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=-1, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1)[source]#

Bases: BaseModel

Hardware settings specific to RL training.

Parameters:
  • trainer_devices_fraction (float)

  • sampler_devices_fraction (float)

  • chips_per_vm (int)

  • use_pathways (bool)

  • num_trainer_slices (int)

  • num_samplers_slices (int)

  • rollout_data_parallelism (int)

  • rollout_tensor_parallelism (int)

  • rollout_expert_parallelism (int)

trainer_devices_fraction: float#
sampler_devices_fraction: float#
chips_per_vm: int#
use_pathways: bool#
num_trainer_slices: int#
num_samplers_slices: int#
rollout_data_parallelism: int#
rollout_tensor_parallelism: int#
rollout_expert_parallelism: int#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.VLLM(*, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config=<factory>, vllm_hf_overrides=<factory>, vllm_hf_config_path='', use_standalone_converter=False, vllm_load_format='dummy', debug_converter=False, gcs_debug_path='')[source]#

Bases: BaseModel

vLLM-specific configuration for rollouts.

Parameters:
  • kv_cache_buffer (int)

  • hbm_utilization_vllm (float)

  • swap_space_vllm_gb (int)

  • enable_dp_attention (bool)

  • enable_expert_parallel (bool)

  • async_scheduling (bool)

  • max_num_batched_tokens (int | None)

  • max_num_seqs (int | None)

  • stop_strings (list[str] | None)

  • vllm_additional_config (dict[str, Any])

  • vllm_hf_overrides (dict[str, Any])

  • vllm_hf_config_path (str)

  • use_standalone_converter (bool)

  • vllm_load_format (str)

  • debug_converter (bool)

  • gcs_debug_path (str)

kv_cache_buffer: int#
hbm_utilization_vllm: float#
swap_space_vllm_gb: int#
enable_dp_attention: bool#
enable_expert_parallel: bool#
async_scheduling: bool#
max_num_batched_tokens: int | None#
max_num_seqs: int | None#
stop_strings: list[str] | None#
vllm_additional_config: dict[str, Any]#
vllm_hf_overrides: dict[str, Any]#
vllm_hf_config_path: str#
use_standalone_converter: bool#
vllm_load_format: str#
debug_converter: bool#
gcs_debug_path: str#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.RL(*, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, reshard_chunk_size=None)[source]#

Bases: BaseModel

Configuration for RL algorithms such as Group Relative Policy Optimization (GRPO).

Parameters:
  • num_generations (int)

  • num_iterations (int)

  • grpo_beta (float)

  • grpo_epsilon (float)

  • loss_algo (Literal['grpo', 'gspo-token'])

  • use_agentic_rollout (bool)

  • max_concurrency (int)

  • off_policy_steps (int)

  • system_prompt (str)

  • degenerate_group_masking (bool)

  • epsilon_high (float | None)

  • reshard_chunk_size (int | None)

num_generations: int#
num_iterations: int#
grpo_beta: float#
grpo_epsilon: float#
loss_algo: Literal['grpo', 'gspo-token']#
use_agentic_rollout: bool#
max_concurrency: int#
off_policy_steps: int#
system_prompt: str#
degenerate_group_masking: bool#
epsilon_high: float | None#
reshard_chunk_size: int | None#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.RLDataset(*, batch_size=1, num_batches=4, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1)[source]#

Bases: BaseModel

Dataset settings for RL training.

Parameters:
  • batch_size (int)

  • num_batches (int)

  • num_test_batches (int)

  • test_batch_start_index (int)

  • train_fraction (float)

  • train_micro_batch_size (int)

  • rollout_micro_batch_size (int)

batch_size: int#
num_batches: int#
num_test_batches: int#
test_batch_start_index: int#
train_fraction: float#
train_micro_batch_size: int#
rollout_micro_batch_size: int#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.RLEvaluation(*, eval_sampling_strategy='greedy', generation_configs=<factory>, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass')[source]#

Bases: BaseModel

Settings for RL evaluation.

Parameters:
  • eval_sampling_strategy (str)

  • generation_configs (dict[str, Any])

  • num_eval_passes (int)

  • eval_corr_lst (bool)

  • eval_make_lst (bool)

  • eval_mode (Literal['pass', 'maj', 'pass_at_1'])

eval_sampling_strategy: str#
generation_configs: dict[str, Any]#
num_eval_passes: int#
eval_corr_lst: bool#
eval_make_lst: bool#
eval_mode: Literal['pass', 'maj', 'pass_at_1']#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Reward(*, reward_exact_answer=5.0, reward_exact_format_match=3.0, reward_white_space_format_match=1.5, reward_partial_format_match=0.5, reward_ratio_guess_to_answer_high=0.5, reward_ratio_guess_to_answer_low=0.25, penalty_incorrect_format=-0.5, penalty_incorrect_answer=-1.0, math_verify_timeout=300, math_verify_num_procs=None)[source]#

Bases: BaseModel

Configuration for the reward/penalty model in RL.

Parameters:
  • reward_exact_answer (float)

  • reward_exact_format_match (float)

  • reward_white_space_format_match (float)

  • reward_partial_format_match (float)

  • reward_ratio_guess_to_answer_high (float)

  • reward_ratio_guess_to_answer_low (float)

  • penalty_incorrect_format (float)

  • penalty_incorrect_answer (float)

  • math_verify_timeout (int)

  • math_verify_num_procs (int | None)

reward_exact_answer: float#
reward_exact_format_match: float#
reward_white_space_format_match: float#
reward_partial_format_match: float#
reward_ratio_guess_to_answer_high: float#
reward_ratio_guess_to_answer_low: float#
penalty_incorrect_format: float#
penalty_incorrect_answer: float#
math_verify_timeout: int#
math_verify_num_procs: int | None#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.SpecialTokens(*, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>')[source]#

Bases: BaseModel

Special tokens used for formatting prompts and responses in RL.

Parameters:
  • reasoning_start_token (str)

  • reasoning_end_token (str)

  • solution_start_token (str)

  • solution_end_token (str)

reasoning_start_token: str#
reasoning_end_token: str#
solution_start_token: str#
solution_end_token: str#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.Engram(*, engram_layers=<factory>, engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=<factory>, engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0)[source]#

Bases: BaseModel

Configuration for DeepSeek Engram (https://www.arxiv.org/pdf/2601.07372).

Parameters:
  • engram_layers (list[int])

  • engram_num_heads (int)

  • engram_head_dim (int)

  • engram_vocab_bases (list[int])

  • engram_max_ngram_size (int)

  • engram_kernel_size (int)

  • engram_seed (int)

engram_layers: list[int]#
engram_num_heads: int#
engram_head_dim: int#
engram_vocab_bases: list[int]#
engram_max_ngram_size: int#
engram_kernel_size: int#
engram_seed: int#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

class maxtext.configs.types.DerivedValues(*, emb_dim=None, mlp_dim=None, moe_mlp_dim=None, num_decoder_layers=None, num_kv_heads=None, num_query_heads=None, num_diloco_replicas=None, ici_parallelism=None, dcn_parallelism=None, using_pipeline_parallelism=None, context_parallel_size=None, num_target_devices=None, global_batch_size_to_train_on=None, global_batch_size_to_eval_on=None, global_batch_size_to_load=None, global_batch_size_to_load_eval=None, micro_batch_size_to_train_on=None, micro_batch_size_to_eval_on=None, checkpoint_dir=None, convert_checkpoint_if_possible=False, metrics_dir=None, tensorboard_dir=None, managed_mldiagnostics_dir=None, rampup_end_step=None, tensors_on_device=None, tensors_to_offload=None, global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None)[source]#

Bases: BaseModel

Holds all fields that are derived from other config values; they are kept here for full legacy compatibility.

Parameters:
  • emb_dim (None | int)

  • mlp_dim (None | int)

  • moe_mlp_dim (None | int)

  • num_decoder_layers (None | int)

  • num_kv_heads (None | int)

  • num_query_heads (None | int)

  • num_diloco_replicas (None | int)

  • ici_parallelism (None | list[int])

  • dcn_parallelism (None | list[int])

  • using_pipeline_parallelism (None | bool)

  • context_parallel_size (None | int)

  • num_target_devices (None | int)

  • global_batch_size_to_train_on (None | int)

  • global_batch_size_to_eval_on (None | int)

  • global_batch_size_to_load (None | int)

  • global_batch_size_to_load_eval (None | int)

  • micro_batch_size_to_train_on (None | int)

  • micro_batch_size_to_eval_on (None | int)

  • checkpoint_dir (None | str)

  • convert_checkpoint_if_possible (bool)

  • metrics_dir (None | str)

  • tensorboard_dir (None | str)

  • managed_mldiagnostics_dir (None | str)

  • rampup_end_step (None | int)

  • tensors_on_device (None | list[str])

  • tensors_to_offload (None | list[str])

  • global_batch_size_to_load_start (None | int)

  • global_batch_size_to_load_increment (None | int)

  • rampup_samples_per_increment_to_load (None | float)

emb_dim: None | int#
mlp_dim: None | int#
moe_mlp_dim: None | int#
num_decoder_layers: None | int#
num_kv_heads: None | int#
num_query_heads: None | int#
num_diloco_replicas: None | int#
ici_parallelism: None | list[int]#
dcn_parallelism: None | list[int]#
using_pipeline_parallelism: None | bool#
context_parallel_size: None | int#
num_target_devices: None | int#
global_batch_size_to_train_on: None | int#
global_batch_size_to_eval_on: None | int#
global_batch_size_to_load: None | int#
global_batch_size_to_load_eval: None | int#
micro_batch_size_to_train_on: None | int#
micro_batch_size_to_eval_on: None | int#
checkpoint_dir: None | str#
convert_checkpoint_if_possible: bool#
metrics_dir: None | str#
tensorboard_dir: None | str#
managed_mldiagnostics_dir: None | str#
rampup_end_step: None | int#
tensors_on_device: None | list[str]#
tensors_to_offload: None | list[str]#
global_batch_size_to_load_start: None | int#
global_batch_size_to_load_increment: None | int#
rampup_samples_per_increment_to_load: None | float#
model_config: ClassVar[ConfigDict] = {}#

Pydantic model configuration for this class; should be a dictionary conforming to pydantic.config.ConfigDict.

maxtext.configs.types.get_individual_scales(scale)[source]#

Choose appropriate scales for individual dimensions based on global scale.

Parameters:

scale (int)

Return type:

tuple[int, int, int, int]

class maxtext.configs.types.MaxTextConfig(*, emb_dim=None, mlp_dim=None, moe_mlp_dim=None, num_decoder_layers=None, num_kv_heads=None, num_query_heads=None, num_diloco_replicas=None, ici_parallelism=None, dcn_parallelism=None, using_pipeline_parallelism=None, context_parallel_size=None, num_target_devices=None, global_batch_size_to_train_on=None, global_batch_size_to_eval_on=None, global_batch_size_to_load=None, global_batch_size_to_load_eval=None, micro_batch_size_to_train_on=None, micro_batch_size_to_eval_on=None, checkpoint_dir=None, convert_checkpoint_if_possible=False, metrics_dir=None, tensorboard_dir=None, managed_mldiagnostics_dir=None, rampup_end_step=None, tensors_on_device=None, tensors_to_offload=None, global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], 
vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file=None, gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='', profiler=ProfilerType.NONE, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, enable_tpu_profiling_options=False, 
tpu_num_chips_to_profile_per_task=1, tpu_num_sparse_cores_to_trace=2, tpu_num_sparse_core_tiles_to_trace=1, xprof_tpu_power_trace_level=XProfTPUPowerTraceMode.POWER_TRACE_NONE, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='/home/docs/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=False, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=SamplingStrategy.GREEDY, decode_sampling_nucleus_p=-1.0, decode_sampling_top_k=0, decode_sampling_temperature=1.0, max_target_length=2048, max_prefill_predict_length=64, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=32000, tokenizer_path=None, tokenizer_type=TokenizerType.SENTENCEPIECE, use_chat_template=False, chat_template_path='', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, 
add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, olmo_index_path='', olmo_path_remap_from='', olmo_path_remap_to='', olmo_apply_ngram_filter=True, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_use_elastic_iterator=False, grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name=None, hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token=None, dataset_path='', dataset_name='c4/en:3.0.1', eval_dataset_name='c4/en:3.0.1', train_split='train', eval_split='validation', dataset_type=DatasetType.TFDS, per_device_batch_size=12, eval_per_device_batch_size=0.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=RopeType.DEFAULT, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=10000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, 
trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=5.0, reward_exact_format_match=3.0, reward_white_space_format_match=1.5, reward_partial_format_match=0.5, reward_ratio_guess_to_answer_high=0.5, reward_ratio_guess_to_answer_low=0.25, penalty_incorrect_format=-0.5, penalty_incorrect_answer=-1.0, math_verify_timeout=300, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs=<factory>, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=4, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, reshard_chunk_size=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config=<factory>, vllm_hf_overrides=<factory>, vllm_hf_config_path='', use_standalone_converter=False, vllm_load_format='dummy', debug_converter=False, gcs_debug_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=True, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=-1, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides=<factory>, teacher_overrides=<factory>, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, 
distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', learn_to_init_mode=False, lti_use_general_linear_map=False, distill_weights_copy_map=<factory>, distill_student_weights_share_map=<factory>, student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs=<factory>, use_grpo=None, muon_beta=0.95, muon_weight_decay=0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.95, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=<factory>, mu_dtype='', opt_type=OptimizerType.ADAMW, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=1.0, learning_rate=3e-05, lr_schedule_type=LearningRateScheduleType.COSINE, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=WsdDecayStyle.LINEAR, warmup_steps_fraction=0.1, learning_rate_schedule_steps=-1, trainable_parameters_mask=<factory>, enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=100, eval_interval=-1, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=True, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=0, init_weights_seed=0, remat_policy='full', remat_policy_for_vit='minimal', decoder_layer_input=RematLocation.DEVICE, context=RematLocation.REMAT, mlpwi=RematLocation.REMAT, mlpwi_0=RematLocation.REMAT, mlpwi_1=RematLocation.REMAT, mlpwo=RematLocation.REMAT, moe_mlpwi_0=RematLocation.REMAT, moe_mlpwi_1=RematLocation.REMAT, moe_mlpwo=RematLocation.REMAT, query_proj=RematLocation.REMAT, key_proj=RematLocation.REMAT, value_proj=RematLocation.REMAT, query_wa_proj=RematLocation.REMAT, kv_wa_proj=RematLocation.REMAT, 
qkv_proj=RematLocation.REMAT, out_proj=RematLocation.REMAT, mla_q=RematLocation.REMAT, mla_kv=RematLocation.REMAT, attention_out=RematLocation.REMAT, engram=RematLocation.REMAT, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=-1, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=True, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[], data_sharding=[], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=-1, mesh_axes=['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode='auto', 
inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=ReorderStrategy.AUTO, custom_mesh='', custom_mesh_and_rule=CustomRule.DEFAULT, allow_split_physical_axes=False, enable_nnx=False, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=False, pure_nnx=False, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, padded_base_moe_mlp_dim=None, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, 
pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='autoselected', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=<factory>, engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=<factory>, engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block='llama2', global_parameter_scale=1, base_emb_dim=2048, base_num_query_heads=16, base_num_kv_heads=16, base_mlp_dim=7168, dense_init_scale=1.0, 
base_num_decoder_layers=16, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=QuantizationType.NONE, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=KvQuantAxis.HEADS_AND_DKV, kv_quant_dtype='int8', quantization_local_shard_count=-1, use_qwix_quantization=False, use_manual_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', weight_sparsity_n=None, weight_sparsity_m=None, weight_sparsity_update_step=10, weight_sparsity_start_step=50, dtype=DType.BFLOAT16, grad_dtype=DType.FLOAT32, weight_dtype=DType.FLOAT32, matmul_precision=MatmulPrecision.DEFAULT, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, elastic_min_slice_count=-1, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=True, checkpoint_storage_use_zarr3=True, checkpoint_storage_concurrent_gb=96, load_parameters_path='', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=True, checkpoint_period=10000, max_num_checkpoints_to_keep=None, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, 
source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config=None, run_name='', model_name='default', override_model_config=False, override_logical_axis_rules=False, log_config=True, debug_sharding=False, base_output_directory='', sharding_strategy=None, debug=<factory>, rl=<factory>)[source]#

Bases: RunInfo, Checkpointing, OrbaxStorage, EmergencyCheckpointing, ElasticTraining, DataTypes, Quantization, ModelArchitecture, Engram, MTP, Logits, Attention, MlaAttention, MoBa, AttentionIndexer, Llama4Attention, SplashAttention, PagedAttention, MoEGeneral, MoEKernels, DeepSeekMoE, Qwen3Next, HardwareAndMesh, LayoutAndSharding, DcnParallelism, IciParallelism, PipelineParallelism, RematAndOffload, TrainingLoop, ManifoldConstrainedHyperConnections, DilocoParams, Optimizer, AdamW, Muon, FineTuning, Distillation, RLHardware, VLLM, RL, RLDataset, RLEvaluation, Reward, SpecialTokens, PositionalEmbedding, Rope, YarnRope, DatasetGeneral, TfdsDataset, HfDataset, GrainDataset, OlmoGrainDataset, Tokenizer, InferenceGeneral, Decoding, InferenceLayout, InferenceServer, InferenceBenchmark, PrefixCaching, AOT, DevelopmentAndDebugging, Profiling, HloDump, StackTrace, Metrics, Goodput, GcpMonitoring, Tensorboard, ManagedMLDiagnostics, MultimodalGeneral, VisionTower, VisionProjector, AudioEncoder, DerivedValues

The main configuration object for MaxText.

This class aggregates all configuration options from modular BaseModel classes into a single, validated object. It is populated by the initialize function. Every field is explicitly defined to prevent misconfigurations (extra='forbid').

Parameters:
  • emb_dim (None | int)

  • mlp_dim (None | int)

  • moe_mlp_dim (None | int)

  • num_decoder_layers (None | int)

  • num_kv_heads (None | int)

  • num_query_heads (None | int)

  • num_diloco_replicas (None | int)

  • ici_parallelism (None | list[int])

  • dcn_parallelism (None | list[int])

  • using_pipeline_parallelism (None | bool)

  • context_parallel_size (None | int)

  • num_target_devices (None | int)

  • global_batch_size_to_train_on (None | int)

  • global_batch_size_to_eval_on (None | int)

  • global_batch_size_to_load (None | int)

  • global_batch_size_to_load_eval (None | int)

  • micro_batch_size_to_train_on (None | int)

  • micro_batch_size_to_eval_on (None | int)

  • checkpoint_dir (None | str)

  • convert_checkpoint_if_possible (bool)

  • metrics_dir (None | str)

  • tensorboard_dir (None | str)

  • managed_mldiagnostics_dir (None | str)

  • rampup_end_step (None | int)

  • tensors_on_device (None | list[str])

  • tensors_to_offload (None | list[str])

  • global_batch_size_to_load_start (None | int)

  • global_batch_size_to_load_increment (None | int)

  • rampup_samples_per_increment_to_load (None | float)

  • d_model_for_audio (int)

  • encoder_attention_heads_for_audio (int)

  • encoder_ffn_dim_for_audio (int)

  • encoder_layers_for_audio (int)

  • attention_dropout_for_audio (float)

  • activation_dropout_for_audio (float)

  • activation_function_for_audio (str)

  • num_mel_bins_for_audio (int)

  • max_source_positions_for_audio (int)

  • scale_embedding_for_audio (bool)

  • n_window_for_audio (int)

  • n_window_infer_for_audio (int)

  • conv_chunksize_for_audio (int)

  • downsample_hidden_size_for_audio (int)

  • output_dim_for_audio (int)

  • num_conv_layers_for_audio (int)

  • max_timescale_for_audio (float)

  • max_sample_len_for_audio (int)

  • projector_input_dim_for_vit (int)

  • projector_output_dim_for_vit (int)

  • pixel_shuffle_ratio_for_vit (float)

  • projector_dropout_for_vit (float)

  • hidden_size_for_vit (int)

  • intermediate_size_for_vit (int)

  • num_attention_heads_for_vit (int)

  • num_channels_for_vit (int)

  • tile_size_for_vit (int)

  • patch_size_for_vit (int)

  • conv_stride_for_vit (int)

  • num_hidden_layers_for_vit (int)

  • rope_theta_for_vit (int)

  • vision_output_dim_for_vit (int)

  • spatial_merge_size_for_vit (int)

  • out_hidden_size_for_vit (int)

  • temporal_patch_size_for_vit (int)

  • num_position_embeddings_for_vit (int)

  • deepstack_visual_indexes_for_vit (list[int])

  • vision_output_length (int)

  • use_multimodal (bool)

  • freeze_vision_encoder_params (bool)

  • freeze_audio_encoder_params (bool)

  • use_audio (bool)

  • image_size_for_vit (int | list[int])

  • image_path (str)

  • image_placeholder (str)

  • posemb_type_for_vit (str)

  • max_num_images_per_example (int)

  • video_path (str)

  • audio_path (str)

  • video_placeholder (str)

  • audio_placeholder (str)

  • use_audio_in_video (bool)

  • use_mrope (bool)

  • mrope_section (list[int])

  • position_id_per_seconds (int)

  • managed_mldiagnostics (bool)

  • managed_mldiagnostics_run_group (str)

  • enable_tensorboard (bool)

  • use_vertex_tensorboard (bool)

  • vertex_tensorboard_project (str | None)

  • vertex_tensorboard_region (str | None)

  • report_heartbeat_metric_for_gcp_monitoring (bool)

  • heartbeat_reporting_interval_in_seconds (int)

  • report_performance_metric_for_gcp_monitoring (bool)

  • enable_goodput_recording (bool)

  • monitor_goodput (bool)

  • goodput_upload_interval_seconds (int)

  • enable_pathways_goodput (bool)

  • monitor_step_time_deviation (bool)

  • step_deviation_interval_seconds (int)

  • enable_gcp_goodput_metrics (bool)

  • enable_gcp_step_deviation_metrics (bool)

  • metrics_file (None | str)

  • gcs_metrics (bool)

  • save_config_to_gcs (bool)

  • record_internal_nn_metrics (int)

  • prometheus_port (int)

  • enable_checkpoint_cloud_logger (bool)

  • enable_tunix_perf_metrics (bool)

  • collect_stack_trace (bool)

  • stack_trace_to_cloud (bool)

  • stack_trace_interval_seconds (int)

  • dump_hlo (bool)

  • dump_step (int)

  • dump_hlo_local_dir (str)

  • dump_hlo_delete_local_after (bool)

  • dump_hlo_gcs_dir (str)

  • dump_hlo_module_name (str)

  • dump_hlo_local_module_name (str)

  • dump_hlo_xla_flags (str)

  • dump_hlo_upload_all (bool)

  • dump_jaxpr (bool)

  • dump_jaxpr_local_dir (str)

  • dump_jaxpr_delete_local_after (bool)

  • dump_jaxpr_gcs_dir (str)

  • profiler (ProfilerType)

  • upload_all_profiler_results (bool)

  • skip_first_n_steps_for_profiler (int)

  • profiler_steps (int)

  • profile_cleanly (bool)

  • profile_periodically_period (int)

  • hide_profiler_step_metric (bool)

  • enable_jax_profiler (bool)

  • jax_profiler_port (int)

  • enable_tpu_profiling_options (bool)

  • tpu_num_chips_to_profile_per_task (int)

  • tpu_num_sparse_cores_to_trace (int)

  • tpu_num_sparse_core_tiles_to_trace (int)

  • xprof_tpu_power_trace_level (XProfTPUPowerTraceMode)

  • xprof_e2e_enable_fw_throttle_event (bool)

  • xprof_e2e_enable_fw_power_level_event (bool)

  • xprof_e2e_enable_fw_thermal_event (bool)

  • profile_power_events (bool)

  • constant_bound_config (list)

  • jax_cache_dir (str | None)

  • jax_distributed_initialization_timeout (int)

  • jax_debug_log_modules (str)

  • skip_jax_distributed_system (bool)

  • enable_single_controller (bool)

  • subslice_shape (str)

  • max_checkify (bool)

  • compiled_trainstep_file (str)

  • compile_topology (str)

  • compile_topology_num_slices (int)

  • enable_prefix_caching (bool)

  • prefix_caching_hbm_byte (int)

  • prefix_caching_dram_byte (int)

  • inference_microbenchmark_prefill_lengths (str)

  • inference_microbenchmark_stages (str)

  • inference_microbenchmark_loop_iters (int)

  • inference_microbenchmark_log_file_path (str)

  • inference_microbenchmark_num_samples (list[int])

  • inference_metadata_file (str)

  • inference_benchmark_test (bool)

  • inference_server (str)

  • prefill_slice (str)

  • generate_slice (str)

  • stack_prefill_result_cache (bool)

  • prefill_cache_axis_order (str)

  • ar_cache_axis_order (str)

  • compute_axis_order (str)

  • reshape_q (bool)

  • decode_sampling_strategy (SamplingStrategy)

  • decode_sampling_nucleus_p (int | float)

  • decode_sampling_top_k (int)

  • decode_sampling_temperature (float)

  • max_target_length (int)

  • max_prefill_predict_length (int)

  • prompt (str)

  • load_from_prefill_dir (bool)

  • prefill_cache_dir (str)

  • autoregressive_decode_assert (str)

  • model_call_mode (str)

  • use_chunked_prefill (bool)

  • prefill_chunk_size (int)

  • enable_model_warmup (bool)

  • enable_llm_inference_pool (bool)

  • multi_sampling (bool)

  • return_log_prob (bool)

  • vocab_size (int)

  • tokenizer_path (None | str)

  • tokenizer_type (TokenizerType)

  • use_chat_template (bool)

  • chat_template_path (str)

  • chat_template (str)

  • tokenize_train_data (bool)

  • tokenize_eval_data (bool)

  • add_bos (bool)

  • add_eos (bool)

  • use_truncation (bool)

  • num_vocab_tiling (int)

  • olmo_index_path (str)

  • olmo_path_remap_from (str)

  • olmo_path_remap_to (str)

  • olmo_apply_ngram_filter (bool)

  • grain_train_files (str)

  • grain_eval_files (str)

  • grain_train_mixture_config_path (str)

  • grain_file_type (str)

  • grain_use_elastic_iterator (bool)

  • grain_worker_count (int)

  • grain_per_worker_buffer_size (int)

  • grain_worker_count_eval (int)

  • grain_per_worker_buffer_size_eval (int)

  • grain_ram_budget_mb (int)

  • grain_num_threads (int)

  • grain_prefetch_buffer_size (int)

  • grain_num_threads_eval (int)

  • grain_prefetch_buffer_size_eval (int)

  • grain_data_source_max_workers (int)

  • grain_shuffle_buffer_size (int)

  • hf_path (str)

  • hf_name (None | str)

  • hf_data_dir (None | str)

  • hf_train_files (None | str)

  • hf_eval_split (None | str)

  • hf_eval_files (None | str)

  • hf_access_token (None | str)

  • dataset_path (str)

  • dataset_name (str)

  • eval_dataset_name (str)

  • train_split (str)

  • eval_split (str)

  • dataset_type (DatasetType)

  • per_device_batch_size (int | float)

  • eval_per_device_batch_size (int | float)

  • max_corpus_chars (int)

  • train_data_columns (list[str])

  • train_image_column (str | list[str])

  • eval_data_columns (list[str])

  • eval_image_column (str | list[str])

  • packing (bool)

  • grain_packing_type (Literal['first_fit', 'best_fit', 'concat_then_split'])

  • max_segments_per_seq (int)

  • num_epoch (int)

  • expansion_factor_real_data (float)

  • reuse_example_batch (int)

  • generate_padding_batch_train (bool)

  • generate_padding_batch_eval (bool)

  • enable_rampup_batch_size (bool)

  • per_device_batch_size_start (float)

  • per_device_batch_size_increment (float)

  • global_rampup_samples (int)

  • colocated_python_data_input (bool)

  • max_position_embeddings (int)

  • original_max_position_embeddings (int)

  • rope_factor (int)

  • beta_fast (int)

  • beta_slow (int)

  • mscale (float)

  • rope_interleave (bool)

  • rope_truncate (bool)

  • rope_attention_scaling (bool)

  • rope_type (RopeType)

  • rope_use_scale (bool)

  • rope_min_timescale (int)

  • rope_max_timescale (int)

  • rope_linear_scaling_factor (float)

  • local_rope_max_timescale (int)

  • global_rope_max_timescale (int)

  • global_rope_proportion (float)

  • local_rope_proportion (float)

  • use_iota_embed (bool)

  • use_untrainable_positional_embedding (bool)

  • trainable_position_size (int)

  • nope_layer_interval (int)

  • reasoning_start_token (str)

  • reasoning_end_token (str)

  • solution_start_token (str)

  • solution_end_token (str)

  • reward_exact_answer (float)

  • reward_exact_format_match (float)

  • reward_white_space_format_match (float)

  • reward_partial_format_match (float)

  • reward_ratio_guess_to_answer_high (float)

  • reward_ratio_guess_to_answer_low (float)

  • penalty_incorrect_format (float)

  • penalty_incorrect_answer (float)

  • math_verify_timeout (int)

  • math_verify_num_procs (int | None)

  • eval_sampling_strategy (str)

  • generation_configs (dict[str, Any])

  • num_eval_passes (int)

  • eval_corr_lst (bool)

  • eval_make_lst (bool)

  • eval_mode (Literal['pass', 'maj', 'pass_at_1'])

  • batch_size (int)

  • num_batches (int)

  • num_test_batches (int)

  • test_batch_start_index (int)

  • train_fraction (float)

  • train_micro_batch_size (int)

  • rollout_micro_batch_size (int)

  • num_generations (int)

  • num_iterations (int)

  • grpo_beta (float)

  • grpo_epsilon (float)

  • loss_algo (Literal['grpo', 'gspo-token'])

  • use_agentic_rollout (bool)

  • max_concurrency (int)

  • off_policy_steps (int)

  • system_prompt (str)

  • degenerate_group_masking (bool)

  • epsilon_high (float | None)

  • reshard_chunk_size (int | None)

  • kv_cache_buffer (int)

  • hbm_utilization_vllm (float)

  • swap_space_vllm_gb (int)

  • enable_dp_attention (bool)

  • enable_expert_parallel (bool)

  • async_scheduling (bool)

  • max_num_batched_tokens (int | None)

  • max_num_seqs (int | None)

  • stop_strings (list[str] | None)

  • vllm_additional_config (dict[str, Any])

  • vllm_hf_overrides (dict[str, Any])

  • vllm_hf_config_path (str)

  • use_standalone_converter (bool)

  • vllm_load_format (str)

  • debug_converter (bool)

  • gcs_debug_path (str)

  • trainer_devices_fraction (float)

  • sampler_devices_fraction (float)

  • chips_per_vm (int)

  • use_pathways (bool)

  • num_trainer_slices (int)

  • num_samplers_slices (int)

  • rollout_data_parallelism (int)

  • rollout_tensor_parallelism (int)

  • rollout_expert_parallelism (int)

  • student_overrides (dict[str, Any])

  • teacher_overrides (dict[str, Any])

  • offline_data_dir (str | None)

  • distill_alpha (float)

  • distill_temperature (float)

  • distill_beta (float)

  • distill_feature_loss_type (Literal['cosine', 'l2'])

  • distill_layer_indices (None | list)

  • distill_alpha_end (float | None)

  • distill_alpha_schedule (Literal['constant', 'linear', 'cosine'])

  • distill_temperature_end (float | None)

  • distill_temperature_schedule (Literal['constant', 'linear', 'cosine'])

  • distill_beta_end (float | None)

  • distill_beta_schedule (Literal['constant', 'linear', 'cosine'])

  • learn_to_init_mode (bool)

  • lti_use_general_linear_map (bool)

  • distill_weights_copy_map (dict[str, Any])

  • distill_student_weights_share_map (dict[str, Any])

  • student_params_to_update (None | list)

  • use_dpo (bool)

  • dpo_label_smoothing (Annotated[float, Ge(ge=0.0), Le(le=1.0)])

  • dpo_beta (float)

  • use_sft (bool)

  • sft_train_on_completion_only (bool)

  • formatting_func_path (str)

  • formatting_func_kwargs (dict)

  • use_grpo (None | bool)

  • muon_beta (float)

  • muon_weight_decay (float)

  • muon_consistent_rms (float | None)

  • adam_b1 (float)

  • adam_b2 (float)

  • adam_eps (float)

  • adam_eps_root (float)

  • adam_weight_decay (float)

  • adamw_mask (list[str])

  • mu_dtype (str)

  • opt_type (OptimizerType)

  • skip_step_on_spikes (bool)

  • skip_step_interval (Annotated[int, Gt(gt=0)])

  • skip_step_scaling_factor (float)

  • gradient_accumulation_steps (Annotated[int, Gt(gt=0)])

  • use_tunix_gradient_accumulation (bool)

  • gradient_clipping_threshold (Annotated[float, Ge(ge=0)])

  • learning_rate (Annotated[float, Ge(ge=0)])

  • lr_schedule_type (LearningRateScheduleType)

  • learning_rate_final_fraction (float)

  • wsd_decay_steps_fraction (Annotated[float, Ge(ge=0.0), Le(le=1.0)])

  • wsd_decay_style (WsdDecayStyle)

  • warmup_steps_fraction (Annotated[float, Ge(ge=0.0), Le(le=1.0)])

  • learning_rate_schedule_steps (Annotated[int, Ge(ge=-1)])

  • trainable_parameters_mask (list[str])

  • enable_diloco (bool)

  • diloco_sync_period (int)

  • diloco_outer_lr (float)

  • diloco_outer_momentum (float)

  • mhc_expansion_rate (Annotated[int, Gt(gt=0)])

  • sinkhorn_iterations (Annotated[int, Gt(gt=0)])

  • steps (Annotated[int, Ge(ge=-1)])

  • log_period (int)

  • eval_interval (int)

  • eval_steps (int)

  • target_eval_loss (float)

  • abort_on_nan_loss (bool)

  • abort_on_inf_loss (bool)

  • enable_dropout (bool)

  • dropout_rate (Annotated[float, Ge(ge=0.0), Le(le=1.0)])

  • enable_data_shuffling (bool)

  • data_shuffle_seed (int)

  • init_weights_seed (int)

  • remat_policy (str)

  • remat_policy_for_vit (str)

  • decoder_layer_input (RematLocation)

  • context (RematLocation)

  • mlpwi (RematLocation)

  • mlpwi_0 (RematLocation)

  • mlpwi_1 (RematLocation)

  • mlpwo (RematLocation)

  • moe_mlpwi_0 (RematLocation)

  • moe_mlpwi_1 (RematLocation)

  • moe_mlpwo (RematLocation)

  • query_proj (RematLocation)

  • key_proj (RematLocation)

  • value_proj (RematLocation)

  • query_wa_proj (RematLocation)

  • kv_wa_proj (RematLocation)

  • qkv_proj (RematLocation)

  • out_proj (RematLocation)

  • mla_q (RematLocation)

  • mla_kv (RematLocation)

  • attention_out (RematLocation)

  • engram (RematLocation)

  • optimizer_memory_host_offload (bool)

  • parameter_memory_host_offload (bool)

  • pipeline_fsdp_ag_per_repeat (bool)

  • num_layers_per_pipeline_stage (int)

  • num_pipeline_repeats (int)

  • pipeline_parallel_layers (int)

  • num_pipeline_microbatches (int)

  • pipeline_delay_activation_forwarding (bool)

  • pipeline_fsdp_ag_once (bool)

  • scan_pipeline_iterations (bool)

  • scan_pipeline_repeats (bool)

  • scan_layers_per_stage (bool)

  • set_remat_policy_on_pipeline_iterations (bool)

  • set_remat_policy_on_layers_per_stage (bool)

  • ici_diloco_parallelism (int)

  • ici_data_parallelism (int)

  • ici_fsdp_parallelism (int)

  • ici_fsdp_transpose_parallelism (int)

  • ici_sequence_parallelism (int)

  • ici_context_parallelism (int)

  • ici_context_autoregressive_parallelism (int)

  • ici_tensor_parallelism (int)

  • ici_tensor_transpose_parallelism (int)

  • ici_tensor_sequence_parallelism (int)

  • ici_autoregressive_parallelism (int)

  • ici_pipeline_parallelism (int)

  • ici_expert_parallelism (int)

  • dcn_diloco_parallelism (int)

  • dcn_data_parallelism (int)

  • dcn_fsdp_parallelism (int)

  • dcn_fsdp_transpose_parallelism (int)

  • dcn_sequence_parallelism (int)

  • dcn_context_parallelism (int)

  • dcn_context_autoregressive_parallelism (int)

  • dcn_tensor_parallelism (int)

  • dcn_tensor_transpose_parallelism (int)

  • dcn_tensor_sequence_parallelism (int)

  • dcn_pipeline_parallelism (int)

  • dcn_expert_parallelism (int)

  • dcn_autoregressive_parallelism (int)

  • logical_axis_rules (Any)

  • data_sharding (Any)

  • context_sharding (str)

  • input_data_sharding_logical_axes (list[str])

  • sharding_tolerance (Annotated[float, Ge(ge=0.0), Le(le=1.0)])

  • shard_optimizer_over_data (bool)

  • internal_compile (bool)

  • internal_compile_num_devices (int)

  • compile_xla_flags (str)

  • hardware (Literal['tpu', 'gpu', 'gpu_multiprocess', 'cpu'])

  • num_slices (int)

  • mesh_axes (list[str])

  • shard_mode (ShardMode)

  • inhomogeneous_layer_cycle_interval (int)

  • scan_layers (bool)

  • param_scan_axis (int)

  • context_parallel_load_balance (bool)

  • context_parallel_strategy (str)

  • context_parallel_reorder_strategy (ReorderStrategy)

  • custom_mesh (str)

  • custom_mesh_and_rule (CustomRule)

  • allow_split_physical_axes (bool)

  • enable_nnx (bool)

  • optimize_mesh_for_tpu_v6e (bool)

  • shardy (bool)

  • pure_nnx_decoder (bool)

  • pure_nnx (bool)

  • remove_size_one_mesh_axis_from_type (bool)

  • gdn_conv_kernel_dim (int)

  • gdn_key_head_dim (int)

  • gdn_value_head_dim (int)

  • gdn_num_key_heads (int)

  • gdn_num_value_heads (int)

  • gdn_chunk_size (int)

  • use_qk_norm_in_gdn (bool)

  • partial_rotary_factor (float)

  • first_num_dense_layers (Annotated[int, Ge(ge=0)])

  • shared_experts (Annotated[int, Ge(ge=0)])

  • routed_scaling_factor (float)

  • routed_score_func (str)

  • routed_bias (bool)

  • routed_bias_update_rate (float)

  • mlp_bias (bool)

  • n_routing_groups (int)

  • topk_routing_group (int)

  • use_batch_split_schedule (bool)

  • batch_split_factor (int)

  • megablox (bool)

  • sparse_matmul (bool)

  • wi_tile_fwd_batch_seq (int)

  • wi_tile_fwd_embed_dim (int)

  • wi_tile_fwd_mlp_dim (int)

  • wi_tile_dlhs_batch_seq (int)

  • wi_tile_dlhs_embed_dim (int)

  • wi_tile_dlhs_mlp_dim (int)

  • wi_tile_drhs_batch_seq (int)

  • wi_tile_drhs_embed_dim (int)

  • wi_tile_drhs_mlp_dim (int)

  • wo_tile_fwd_batch_seq (int)

  • wo_tile_fwd_embed_dim (int)

  • wo_tile_fwd_mlp_dim (int)

  • wo_tile_dlhs_batch_seq (int)

  • wo_tile_dlhs_embed_dim (int)

  • wo_tile_dlhs_mlp_dim (int)

  • wo_tile_drhs_batch_seq (int)

  • wo_tile_drhs_embed_dim (int)

  • wo_tile_drhs_mlp_dim (int)

  • merge_gating_gmm (bool)

  • num_experts (Annotated[int, Gt(gt=0)])

  • num_experts_per_tok (Annotated[int, Gt(gt=0)])

  • capacity_factor (float)

  • ragged_buffer_factor (float)

  • moe_expert_input_dim (int)

  • base_moe_mlp_dim (int)

  • padded_base_moe_mlp_dim (int | None)

  • load_balance_loss_weight (Annotated[float, Ge(ge=0)])

  • use_custom_sort_vjp (bool)

  • use_ring_of_experts (bool)

  • use_gather_mosaic_kernel (bool)

  • use_random_routing (bool)

  • interleave_moe_layer_step (int)

  • moe_fsdp_use_two_stage_all_gather (bool)

  • shard_exp_on_fsdp (bool)

  • use_2d_fsdp_sharding (bool)

  • norm_topk_prob (bool)

  • float32_weight_sum (bool)

  • float32_gate_logits (bool)

  • prefuse_moe_weights (bool)

  • pagedattn_num_pages (int)

  • pagedattn_tokens_per_page (int)

  • pagedattn_pages_per_compute_block (int)

  • pagedattn_max_pages_per_group (int)

  • pagedattn_head_dim_alignment (int)

  • sa_block_q (int)

  • sa_block_kv (int)

  • sa_block_kv_compute (int)

  • sa_block_q_dkv (int)

  • sa_block_kv_dkv (int)

  • sa_block_kv_dkv_compute (int)

  • sa_block_q_dq (int)

  • sa_block_kv_dq (int)

  • sa_use_fused_bwd_kernel (bool)

  • sa_q_layout (str)

  • sa_k_layout (str)

  • sa_v_layout (str)

  • use_max_logit_estimate (int)

  • cost_estimate_flops_fwd (int)

  • cost_estimate_flops_bwd (int)

  • dq_reduction_steps (int)

  • use_splash_scheduler (bool)

  • use_qk_norm (bool)

  • temperature_tuning (bool)

  • use_indexer (bool)

  • indexer_head_dim (Annotated[int, Ge(ge=0)])

  • indexer_n_heads (Annotated[int, Ge(ge=0)])

  • indexer_topk (Annotated[int, Ge(ge=0)])

  • indexer_sparse_training (bool)

  • indexer_loss_scaling_factor (float)

  • moba (bool)

  • moba_chunk_size (int)

  • moba_topk (int)

  • mla_naive_kvcache (bool)

  • q_lora_rank (Annotated[int, Ge(ge=0)])

  • kv_lora_rank (Annotated[int, Ge(ge=0)])

  • qk_nope_head_dim (Annotated[int, Ge(ge=0)])

  • qk_rope_head_dim (Annotated[int, Ge(ge=0)])

  • v_head_dim (Annotated[int, Ge(ge=0)])

  • attention (str)

  • attention_type (Literal['global', 'local_sliding', 'chunk', 'mla', 'full'])

  • share_kv_projections (bool)

  • global_num_kv_heads (int)

  • attention_sink (bool)

  • float32_qk_product (bool)

  • float32_logits (bool)

  • sliding_window_size (Annotated[int, Ge(ge=0)])

  • chunk_attn_window_size (Annotated[int, Ge(ge=0)])

  • attn_logits_soft_cap (None | Annotated[float, Ge(ge=0)])

  • use_post_attn_norm (bool)

  • use_post_ffw_norm (bool)

  • use_ragged_attention (bool)

  • use_tokamax_gmm (bool)

  • ragged_block_size (int)

  • enable_padding_causal_mask (bool)

  • use_tokamax_splash (bool)

  • use_jax_splash (bool)

  • force_q_layout (bool)

  • use_qk_clip (bool)

  • qk_clip_threshold (float)

  • logits_via_embedding (bool)

  • normalize_embedding_logits (bool)

  • logits_dot_in_fp32 (bool)

  • cast_logits_to_fp32 (bool)

  • final_logits_soft_cap (None | Annotated[float, Ge(ge=0)])

  • z_loss_multiplier (float)

  • mtp_num_layers (Annotated[int, Ge(ge=0)])

  • mtp_loss_scaling_factor (Annotated[float, Ge(ge=0)])

  • mtp_eval_target_module (Annotated[int, Ge(ge=0)])

  • engram_layers (list[int])

  • engram_num_heads (int)

  • engram_head_dim (int)

  • engram_vocab_bases (list[int])

  • engram_max_ngram_size (int)

  • engram_kernel_size (int)

  • engram_seed (int)

  • decoder_block (DecoderBlockType)

  • global_parameter_scale (int)

  • base_emb_dim (int)

  • base_num_query_heads (int)

  • base_num_kv_heads (int)

  • base_mlp_dim (int)

  • dense_init_scale (float)

  • base_num_decoder_layers (int)

  • head_dim (int)

  • attention_output_dim (int)

  • global_head_dim (int)

  • mlp_activations (list[str])

  • mlp_activations_limit (float)

  • normalization_layer_epsilon (float)

  • fused_qkv (bool)

  • attention_bias (bool)

  • fused_mlp (bool)

  • qk_norm_with_scale (bool)

  • v_norm_with_scale (bool)

  • quantization (None | QuantizationType)

  • replicate_quant_scale (bool)

  • quant_cfg_path (str)

  • quantize_kvcache (bool)

  • kv_quant_axis (KvQuantAxis)

  • kv_quant_dtype (Literal['int8', 'int4'])

  • quantization_local_shard_count (int)

  • use_qwix_quantization (bool)

  • use_manual_quantization (bool)

  • weight_quantization_calibration_method (str)

  • act_quantization_calibration_method (str)

  • bwd_quantization_calibration_method (str)

  • weight_sparsity_n (int | None)

  • weight_sparsity_m (int | None)

  • weight_sparsity_update_step (int)

  • weight_sparsity_start_step (int)

  • dtype (DType)

  • grad_dtype (DType)

  • weight_dtype (DType)

  • matmul_precision (MatmulPrecision)

  • activations_in_float32 (bool)

  • dtype_mm (str)

  • elastic_enabled (bool)

  • elastic_timeout_seconds (int)

  • elastic_max_retries (int)

  • elastic_min_slice_count (int)

  • enable_multi_tier_checkpointing (bool)

  • local_checkpoint_directory (str)

  • local_checkpoint_period (Annotated[int, Ge(ge=0)])

  • multi_tier_checkpointing_backup_interval_minutes (Annotated[int, Ge(ge=0)])

  • mtc_data_parallelism (int)

  • enable_emergency_checkpoint (bool)

  • use_replicator_service (bool)

  • replicator_backup_interval_minutes (Annotated[int, Ge(ge=0)])

  • checkpoint_storage_target_data_file_size_bytes (int)

  • checkpoint_storage_use_ocdbt (bool)

  • checkpoint_storage_use_zarr3 (bool)

  • checkpoint_storage_concurrent_gb (int)

  • load_parameters_path (str)

  • lora_input_adapters_path (str)

  • load_full_state_path (str)

  • enable_checkpointing (bool)

  • load_checkpoint_only_once (bool)

  • async_checkpointing (bool)

  • checkpoint_period (int)

  • max_num_checkpoints_to_keep (int | None)

  • enable_single_replica_ckpt_restoring (bool)

  • checkpoint_todelete_subdir (str | None)

  • checkpoint_todelete_full_path (str | None)

  • force_unroll (bool)

  • checkpoint_is_quantized (bool)

  • save_quantized_params_path (str)

  • enable_orbax_v1 (bool)

  • checkpoint_conversion_fn (None | str)

  • source_checkpoint_layout (Literal['orbax', 'safetensors'])

  • save_checkpoint_on_completion (bool)

  • enable_continuous_checkpointing (bool)

  • colocated_python_checkpointing (bool)

  • enable_autocheckpoint (bool)

  • base_config (None | str)

  • run_name (str)

  • model_name (Literal['default', 'llama2-7b', 'llama2-13b', 'llama2-70b', 'llama3-8b', 'llama3.1-8b-Instruct', 'llama3-70b', 'llama3.1-70b-Instruct', 'llama3.1-8b', 'llama3.1-70b', 'llama3.1-405b', 'llama3.3-70b', 'mistral-7b', 'mixtral-8x7b', 'mixtral-8x22b', 'deepseek2-16b', 'deepseek2-236b', 'deepseek3-671b', 'deepseek3-671b-2dfsdp', 'deepseek3-671b-batchsplit', 'deepseek3-test', 'deepseek3-tiny', 'deepseek3.2-671b', 'deepseek-custom', 'kimi-k2-1t', 'gemma-7b', 'gemma-2b', 'gemma2-2b', 'gemma2-9b', 'gemma2-27b', 'gemma3-4b', 'gemma3-12b', 'gemma3-27b', 'gemma4-26b', 'gemma4-31b', 'qwen2.5-1.5b', 'qwen2.5-7b', 'qwen2.5-14b', 'qwen3-0.6b', 'qwen3-1.7b', 'qwen3-1.7b-base', 'qwen3-4b', 'qwen3-4b-base', 'qwen3-4b-thinking-2507', 'qwen3-8b', 'qwen3-8b-base', 'qwen3-14b', 'qwen3-14b-base', 'qwen3-32b', 'qwen3-235b-a22b', 'qwen3-30b-a3b', 'qwen3-30b-a3b-base', 'qwen3-480b-a35b', 'qwen3-next-80b-a3b', 'qwen3-omni-30b-a3b', 'qwen3-custom-30b-a3b', 'qwen3.5-397b-a17b', 'gpt3-175b', 'gpt3-22b', 'gpt3-6b', 'gpt3-52k', 'gpt-oss-20b', 'gpt-oss-120b', 'llama4-17b-16e', 'llama4-17b-128e', 'olmo3-7b', 'olmo3-7b-pt', 'olmo3-32b'])

  • override_model_config (bool)

  • override_logical_axis_rules (bool)

  • log_config (bool)

  • debug_sharding (bool)

  • base_output_directory (str)

  • sharding_strategy (None | Literal['experimental'])

  • debug (Debug)

  • rl (RL)

debug: Debug#
rl: RL#
model_config: ClassVar[ConfigDict] = {'extra': 'forbid', 'protected_namespaces': ()}#

Configuration for the model; should be a dictionary conforming to pydantic's ConfigDict (pydantic.config.ConfigDict).

classmethod load_model_specific_defaults(values)[source]#

This method is a no-op because pyconfig handles model-specific config loading.

Parameters:

values (dict[str, Any])

Return type:

dict[str, Any]

validate_ragged_buffer_factor()[source]#
set_derived_and_validate_values()[source]#

Computes all derived values and runs all cross-field validations after initial parsing. This logic is ported from the legacy pyconfig_deprecated.py system and adapted for Pydantic.

Return type:

MaxTextConfig