maxtext.configs.types module#
Pydantic-based configuration system for MaxText, organized into modular classes.
- class maxtext.configs.types.XProfTPUPowerTraceMode(*values)[source]#
Bases: IntEnum
Enum for XProfTPUPowerTraceMode.
- POWER_TRACE_NONE = 0#
- POWER_TRACE_NORMAL = 1#
- POWER_TRACE_SPI = 2#
- class maxtext.configs.types.DType(*values)[source]#
Bases: str, Enum
Supported data types for weights and activations.
- BFLOAT16 = 'bfloat16'#
- FLOAT32 = 'float32'#
- FLOAT16 = 'float16'#
- class maxtext.configs.types.MatmulPrecision(*values)[source]#
Bases: str, Enum
Precision levels for matrix multiplications.
- DEFAULT = 'default'#
- HIGH = 'high'#
- HIGHEST = 'highest'#
- BFLOAT16 = 'bfloat16'#
- FLOAT32 = 'float32'#
- class maxtext.configs.types.QuantizationType(*values)[source]#
Bases: str, Enum
Supported quantization schemes.
- NONE = ''#
- INT4 = 'int4'#
- INT8 = 'int8'#
- INTMP = 'intmp'#
- FP8 = 'fp8'#
- NANOO_FP8 = 'nanoo_fp8'#
- FP8_NANO_V2 = 'fp8_nanoo'#
- FP8_GPU = 'fp8_gpu'#
- FP8_FULL = 'fp8_full'#
- TE_FP8_DS = 'te_fp8_delayedscaling'#
- TE_FP8_CS = 'te_fp8_currentscaling'#
- TE_MXFP8 = 'te_mxfp8'#
- TE_NVFP4 = 'te_nvfp4'#
- TE_NVFP4_NO_RHT = 'te_nvfp4_no_rht'#
- class maxtext.configs.types.KvQuantAxis(*values)[source]#
Bases: str, Enum
Axes to quantize over for the Key-Value cache.
- NONE = ''#
- DKV = 'dkv'#
- HEADS_AND_DKV = 'heads_and_dkv'#
- class maxtext.configs.types.RematPolicy(*values)[source]#
Bases: str, Enum
Available rematerialization (gradient checkpointing) policies.
- FULL = 'full'#
- MINIMAL = 'minimal'#
- SAVE_DOT_WITH_CONTEXT_EXCEPT_MLP = 'save_dot_with_context_except_mlp'#
- SAVE_DOT_EXCEPT_MLPWI = 'save_dot_except_mlpwi'#
- SAVE_DOT_EXCEPT_MLP = 'save_dot_except_mlp'#
- SAVE_QKV_PROJ = 'save_qkv_proj'#
- QKV_PROJ_OFFLOADED = 'qkv_proj_offloaded'#
- CUSTOM = 'custom'#
- MINIMAL_OFFLOADED = 'minimal_offloaded'#
- SAVE_OUT_PROJ = 'save_out_proj'#
- class maxtext.configs.types.RematLocation(*values)[source]#
Bases: str, Enum
Specifies where to store activations for rematerialization.
- REMAT = 'remat'#
- DEVICE = 'device'#
- OFFLOAD = 'offload'#
- class maxtext.configs.types.OptimizerType(*values)[source]#
Bases: str, Enum
Supported optimizer algorithms.
- ADAMW = 'adamw'#
- ADAM_PAX = 'adam_pax'#
- SGD = 'sgd'#
- MUON = 'muon'#
- class maxtext.configs.types.LearningRateScheduleType(*values)[source]#
Bases: str, Enum
Supported learning rate schedule types.
- COSINE = 'cosine'#
- WSD = 'wsd'#
- class maxtext.configs.types.WsdDecayStyle(*values)[source]#
Bases: str, Enum
Supported decay styles for WSD schedule.
- LINEAR = 'linear'#
- COSINE = 'cosine'#
- class maxtext.configs.types.RopeType(*values)[source]#
Bases: str, Enum
Supported Rotary Positional Embedding (RoPE) implementations.
- DEFAULT = 'default'#
- LLAMA3_1 = 'llama3.1'#
- YARN = 'yarn'#
- class maxtext.configs.types.TokenizerType(*values)[source]#
Bases: str, Enum
Supported tokenizer libraries.
- SENTENCEPIECE = 'sentencepiece'#
- HUGGINGFACE = 'huggingface'#
- TIKTOKEN = 'tiktoken'#
- class maxtext.configs.types.DatasetType(*values)[source]#
Bases: str, Enum
Supported data loading pipelines.
- SYNTHETIC = 'synthetic'#
- HF = 'hf'#
- GRAIN = 'grain'#
- TFDS = 'tfds'#
- C4MLPERF = 'c4_mlperf'#
- OLMO_GRAIN = 'olmo_grain'#
- class maxtext.configs.types.SamplingStrategy(*values)[source]#
Bases: str, Enum
Supported decoding and sampling strategies.
- GREEDY = 'greedy'#
- WEIGHTED = 'weighted'#
- NUCLEUS = 'nucleus'#
- TOPK = 'topk'#
- COMPOSITE = 'composite'#
- class maxtext.configs.types.ProfilerType(*values)[source]#
Bases: str, Enum
Supported performance profilers.
- NONE = ''#
- XPLANE = 'xplane'#
- NSYS = 'nsys'#
- class maxtext.configs.types.RunInfo(*, base_config=None, run_name='', model_name='default', override_model_config=False, override_logical_axis_rules=False, log_config=True, debug_sharding=False, base_output_directory='', sharding_strategy=None)[source]#
Bases: BaseModel
Configuration for the overall run, model identity, and logging.
- Parameters:
base_config (None | str)
run_name (str)
model_name (Literal['default', 'llama2-7b', 'llama2-13b', 'llama2-70b', 'llama3-8b', 'llama3.1-8b-Instruct', 'llama3-70b', 'llama3.1-70b-Instruct', 'llama3.1-8b', 'llama3.1-70b', 'llama3.1-405b', 'llama3.3-70b', 'mistral-7b', 'mixtral-8x7b', 'mixtral-8x22b', 'deepseek2-16b', 'deepseek2-236b', 'deepseek3-671b', 'deepseek3-671b-2dfsdp', 'deepseek3-671b-batchsplit', 'deepseek3-test', 'deepseek3-tiny', 'deepseek3.2-671b', 'deepseek-custom', 'kimi-k2-1t', 'gemma-7b', 'gemma-2b', 'gemma2-2b', 'gemma2-9b', 'gemma2-27b', 'gemma3-4b', 'gemma3-12b', 'gemma3-27b', 'gemma4-26b', 'gemma4-31b', 'qwen2.5-1.5b', 'qwen2.5-7b', 'qwen2.5-14b', 'qwen3-0.6b', 'qwen3-1.7b', 'qwen3-1.7b-base', 'qwen3-4b', 'qwen3-4b-base', 'qwen3-4b-thinking-2507', 'qwen3-8b', 'qwen3-8b-base', 'qwen3-14b', 'qwen3-14b-base', 'qwen3-32b', 'qwen3-235b-a22b', 'qwen3-30b-a3b', 'qwen3-30b-a3b-base', 'qwen3-480b-a35b', 'qwen3-next-80b-a3b', 'qwen3-omni-30b-a3b', 'qwen3-custom-30b-a3b', 'qwen3.5-397b-a17b', 'gpt3-175b', 'gpt3-22b', 'gpt3-6b', 'gpt3-52k', 'gpt-oss-20b', 'gpt-oss-120b', 'llama4-17b-16e', 'llama4-17b-128e', 'olmo3-7b', 'olmo3-7b-pt', 'olmo3-32b'])
override_model_config (bool)
override_logical_axis_rules (bool)
log_config (bool)
debug_sharding (bool)
base_output_directory (str)
sharding_strategy (None | Literal['experimental'])
- base_config: None | str#
- run_name: str#
- model_name: Literal['default', 'llama2-7b', 'llama2-13b', 'llama2-70b', 'llama3-8b', 'llama3.1-8b-Instruct', 'llama3-70b', 'llama3.1-70b-Instruct', 'llama3.1-8b', 'llama3.1-70b', 'llama3.1-405b', 'llama3.3-70b', 'mistral-7b', 'mixtral-8x7b', 'mixtral-8x22b', 'deepseek2-16b', 'deepseek2-236b', 'deepseek3-671b', 'deepseek3-671b-2dfsdp', 'deepseek3-671b-batchsplit', 'deepseek3-test', 'deepseek3-tiny', 'deepseek3.2-671b', 'deepseek-custom', 'kimi-k2-1t', 'gemma-7b', 'gemma-2b', 'gemma2-2b', 'gemma2-9b', 'gemma2-27b', 'gemma3-4b', 'gemma3-12b', 'gemma3-27b', 'gemma4-26b', 'gemma4-31b', 'qwen2.5-1.5b', 'qwen2.5-7b', 'qwen2.5-14b', 'qwen3-0.6b', 'qwen3-1.7b', 'qwen3-1.7b-base', 'qwen3-4b', 'qwen3-4b-base', 'qwen3-4b-thinking-2507', 'qwen3-8b', 'qwen3-8b-base', 'qwen3-14b', 'qwen3-14b-base', 'qwen3-32b', 'qwen3-235b-a22b', 'qwen3-30b-a3b', 'qwen3-30b-a3b-base', 'qwen3-480b-a35b', 'qwen3-next-80b-a3b', 'qwen3-omni-30b-a3b', 'qwen3-custom-30b-a3b', 'qwen3.5-397b-a17b', 'gpt3-175b', 'gpt3-22b', 'gpt3-6b', 'gpt3-52k', 'gpt-oss-20b', 'gpt-oss-120b', 'llama4-17b-16e', 'llama4-17b-128e', 'olmo3-7b', 'olmo3-7b-pt', 'olmo3-32b']#
- override_model_config: bool#
- override_logical_axis_rules: bool#
- log_config: bool#
- debug_sharding: bool#
- base_output_directory: str#
- sharding_strategy: None | Literal['experimental']#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Checkpointing(*, load_parameters_path='', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=True, checkpoint_period=10000, max_num_checkpoints_to_keep=None, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False)[source]#
Bases: BaseModel
Core configuration for checkpointing and run restoration.
- Parameters:
load_parameters_path (str)
lora_input_adapters_path (str)
load_full_state_path (str)
enable_checkpointing (bool)
load_checkpoint_only_once (bool)
async_checkpointing (bool)
checkpoint_period (int)
max_num_checkpoints_to_keep (int | None)
enable_single_replica_ckpt_restoring (bool)
checkpoint_todelete_subdir (str | None)
checkpoint_todelete_full_path (str | None)
force_unroll (bool)
checkpoint_is_quantized (bool)
save_quantized_params_path (str)
enable_orbax_v1 (bool)
checkpoint_conversion_fn (None | str)
source_checkpoint_layout (Literal['orbax', 'safetensors'])
save_checkpoint_on_completion (bool)
enable_continuous_checkpointing (bool)
colocated_python_checkpointing (bool)
enable_autocheckpoint (bool)
- load_parameters_path: str#
- lora_input_adapters_path: str#
- load_full_state_path: str#
- enable_checkpointing: bool#
- load_checkpoint_only_once: bool#
- async_checkpointing: bool#
- checkpoint_period: int#
- max_num_checkpoints_to_keep: int | None#
- enable_single_replica_ckpt_restoring: bool#
- checkpoint_todelete_subdir: str | None#
- checkpoint_todelete_full_path: str | None#
- force_unroll: bool#
- checkpoint_is_quantized: bool#
- save_quantized_params_path: str#
- enable_orbax_v1: bool#
- checkpoint_conversion_fn: None | str#
- source_checkpoint_layout: Literal['orbax', 'safetensors']#
- save_checkpoint_on_completion: bool#
- enable_continuous_checkpointing: bool#
- colocated_python_checkpointing: bool#
- enable_autocheckpoint: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.OrbaxStorage(*, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=True, checkpoint_storage_use_zarr3=True, checkpoint_storage_concurrent_gb=96)[source]#
Bases: BaseModel
Configuration for Orbax checkpoint storage options.
- Parameters:
checkpoint_storage_target_data_file_size_bytes (int)
checkpoint_storage_use_ocdbt (bool)
checkpoint_storage_use_zarr3 (bool)
checkpoint_storage_concurrent_gb (int)
- checkpoint_storage_target_data_file_size_bytes: int#
- checkpoint_storage_use_ocdbt: bool#
- checkpoint_storage_use_zarr3: bool#
- checkpoint_storage_concurrent_gb: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.EmergencyCheckpointing(*, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0)[source]#
Bases: BaseModel
Configuration for emergency (local) checkpointing.
- Parameters:
enable_multi_tier_checkpointing (bool)
local_checkpoint_directory (str)
local_checkpoint_period (Annotated[int, Ge(ge=0)])
multi_tier_checkpointing_backup_interval_minutes (Annotated[int, Ge(ge=0)])
mtc_data_parallelism (int)
enable_emergency_checkpoint (bool)
use_replicator_service (bool)
replicator_backup_interval_minutes (Annotated[int, Ge(ge=0)])
- enable_multi_tier_checkpointing: bool#
- local_checkpoint_directory: str#
- local_checkpoint_period: Annotated[int, Ge(ge=0)]#
- multi_tier_checkpointing_backup_interval_minutes: Annotated[int, Ge(ge=0)]#
- mtc_data_parallelism: int#
- enable_emergency_checkpoint: bool#
- use_replicator_service: bool#
- replicator_backup_interval_minutes: Annotated[int, Ge(ge=0)]#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.DataTypes(*, dtype=DType.BFLOAT16, grad_dtype=DType.FLOAT32, weight_dtype=DType.FLOAT32, matmul_precision=MatmulPrecision.DEFAULT, activations_in_float32=False, dtype_mm='float32')[source]#
Bases: BaseModel
Configuration for data types and precision.
- Parameters:
dtype (DType)
grad_dtype (DType)
weight_dtype (DType)
matmul_precision (MatmulPrecision)
activations_in_float32 (bool)
dtype_mm (str)
- dtype: DType#
- grad_dtype: DType#
- weight_dtype: DType#
- matmul_precision: MatmulPrecision#
- activations_in_float32: bool#
- dtype_mm: str#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Quantization(*, quantization=QuantizationType.NONE, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=KvQuantAxis.HEADS_AND_DKV, kv_quant_dtype='int8', quantization_local_shard_count=-1, use_qwix_quantization=False, use_manual_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', weight_sparsity_n=None, weight_sparsity_m=None, weight_sparsity_update_step=10, weight_sparsity_start_step=50)[source]#
Bases: BaseModel
Configuration for model quantization.
- Parameters:
quantization (None | QuantizationType)
replicate_quant_scale (bool)
quant_cfg_path (str)
quantize_kvcache (bool)
kv_quant_axis (KvQuantAxis)
kv_quant_dtype (Literal['int8', 'int4'])
quantization_local_shard_count (int)
use_qwix_quantization (bool)
use_manual_quantization (bool)
weight_quantization_calibration_method (str)
act_quantization_calibration_method (str)
bwd_quantization_calibration_method (str)
weight_sparsity_n (int | None)
weight_sparsity_m (int | None)
weight_sparsity_update_step (int)
weight_sparsity_start_step (int)
- quantization: None | QuantizationType#
- replicate_quant_scale: bool#
- quant_cfg_path: str#
- quantize_kvcache: bool#
- kv_quant_axis: KvQuantAxis#
- kv_quant_dtype: Literal['int8', 'int4']#
- quantization_local_shard_count: int#
- use_qwix_quantization: bool#
- use_manual_quantization: bool#
- weight_quantization_calibration_method: str#
- act_quantization_calibration_method: str#
- bwd_quantization_calibration_method: str#
- weight_sparsity_n: int | None#
- weight_sparsity_m: int | None#
- weight_sparsity_update_step: int#
- weight_sparsity_start_step: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.ModelArchitecture(*, decoder_block='llama2', global_parameter_scale=1, base_emb_dim=2048, base_num_query_heads=16, base_num_kv_heads=16, base_mlp_dim=7168, dense_init_scale=1.0, base_num_decoder_layers=16, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True)[source]#
Bases: BaseModel
Core model architecture parameters.
- Parameters:
decoder_block (DecoderBlockType)
global_parameter_scale (int)
base_emb_dim (int)
base_num_query_heads (int)
base_num_kv_heads (int)
base_mlp_dim (int)
dense_init_scale (float)
base_num_decoder_layers (int)
head_dim (int)
attention_output_dim (int)
global_head_dim (int)
mlp_activations (list[str])
mlp_activations_limit (float)
normalization_layer_epsilon (float)
fused_qkv (bool)
attention_bias (bool)
fused_mlp (bool)
qk_norm_with_scale (bool)
v_norm_with_scale (bool)
- decoder_block: DecoderBlockType#
- global_parameter_scale: int#
- base_emb_dim: int#
- base_num_query_heads: int#
- base_num_kv_heads: int#
- base_mlp_dim: int#
- dense_init_scale: float#
- base_num_decoder_layers: int#
- head_dim: int#
- attention_output_dim: int#
- global_head_dim: int#
- mlp_activations: list[str]#
- mlp_activations_limit: float#
- normalization_layer_epsilon: float#
- fused_qkv: bool#
- attention_bias: bool#
- fused_mlp: bool#
- qk_norm_with_scale: bool#
- v_norm_with_scale: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.MTP(*, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0)[source]#
Bases: BaseModel
Multi-Token Prediction Configs.
- Parameters:
mtp_num_layers (Annotated[int, Ge(ge=0)])
mtp_loss_scaling_factor (Annotated[float, Ge(ge=0)])
mtp_eval_target_module (Annotated[int, Ge(ge=0)])
- mtp_num_layers: Annotated[int, Ge(ge=0)]#
- mtp_loss_scaling_factor: Annotated[float, Ge(ge=0)]#
- mtp_eval_target_module: Annotated[int, Ge(ge=0)]#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Logits(*, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0)[source]#
Bases: BaseModel
Configuration for the final logits computation.
- Parameters:
logits_via_embedding (bool)
normalize_embedding_logits (bool)
logits_dot_in_fp32 (bool)
cast_logits_to_fp32 (bool)
final_logits_soft_cap (None | Annotated[float, Ge(ge=0)])
z_loss_multiplier (float)
- logits_via_embedding: bool#
- normalize_embedding_logits: bool#
- logits_dot_in_fp32: bool#
- cast_logits_to_fp32: bool#
- final_logits_soft_cap: None | Annotated[float, Ge(ge=0)]#
- z_loss_multiplier: float#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Attention(*, attention='autoselected', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0)[source]#
Bases: BaseModel
General configuration for the attention mechanism.
- Parameters:
attention (str)
attention_type (Literal['global', 'local_sliding', 'chunk', 'mla', 'full'])
share_kv_projections (bool)
global_num_kv_heads (int)
attention_sink (bool)
float32_qk_product (bool)
float32_logits (bool)
sliding_window_size (Annotated[int, Ge(ge=0)])
chunk_attn_window_size (Annotated[int, Ge(ge=0)])
attn_logits_soft_cap (None | Annotated[float, Ge(ge=0)])
use_post_attn_norm (bool)
use_post_ffw_norm (bool)
use_ragged_attention (bool)
use_tokamax_gmm (bool)
ragged_block_size (int)
enable_padding_causal_mask (bool)
use_tokamax_splash (bool)
use_jax_splash (bool)
force_q_layout (bool)
use_qk_clip (bool)
qk_clip_threshold (float)
- attention: str#
- attention_type: Literal['global', 'local_sliding', 'chunk', 'mla', 'full']#
- share_kv_projections: bool#
- global_num_kv_heads: int#
- attention_sink: bool#
- float32_qk_product: bool#
- float32_logits: bool#
- sliding_window_size: Annotated[int, Ge(ge=0)]#
- chunk_attn_window_size: Annotated[int, Ge(ge=0)]#
- attn_logits_soft_cap: None | Annotated[float, Ge(ge=0)]#
- use_post_attn_norm: bool#
- use_post_ffw_norm: bool#
- use_ragged_attention: bool#
- use_tokamax_gmm: bool#
- ragged_block_size: int#
- enable_padding_causal_mask: bool#
- use_tokamax_splash: bool#
- use_jax_splash: bool#
- force_q_layout: bool#
- use_qk_clip: bool#
- qk_clip_threshold: float#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.MoBa(*, moba=False, moba_chunk_size=1024, moba_topk=8)[source]#
Bases: BaseModel
Configuration for Mixture of Block Attention (MoBA).
- Parameters:
moba (bool)
moba_chunk_size (int)
moba_topk (int)
- moba: bool#
- moba_chunk_size: int#
- moba_topk: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.MlaAttention(*, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)[source]#
Bases: BaseModel
Configuration for Multi-Head Latent Attention (MLA).
- Parameters:
mla_naive_kvcache (bool)
q_lora_rank (Annotated[int, Ge(ge=0)])
kv_lora_rank (Annotated[int, Ge(ge=0)])
qk_nope_head_dim (Annotated[int, Ge(ge=0)])
qk_rope_head_dim (Annotated[int, Ge(ge=0)])
v_head_dim (Annotated[int, Ge(ge=0)])
- mla_naive_kvcache: bool#
- q_lora_rank: Annotated[int, Ge(ge=0)]#
- kv_lora_rank: Annotated[int, Ge(ge=0)]#
- qk_nope_head_dim: Annotated[int, Ge(ge=0)]#
- qk_rope_head_dim: Annotated[int, Ge(ge=0)]#
- v_head_dim: Annotated[int, Ge(ge=0)]#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.AttentionIndexer(*, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0)[source]#
Bases: BaseModel
Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer.
- Parameters:
use_indexer (bool)
indexer_head_dim (Annotated[int, Ge(ge=0)])
indexer_n_heads (Annotated[int, Ge(ge=0)])
indexer_topk (Annotated[int, Ge(ge=0)])
indexer_sparse_training (bool)
indexer_loss_scaling_factor (float)
- use_indexer: bool#
- indexer_head_dim: Annotated[int, Ge(ge=0)]#
- indexer_n_heads: Annotated[int, Ge(ge=0)]#
- indexer_topk: Annotated[int, Ge(ge=0)]#
- indexer_sparse_training: bool#
- indexer_loss_scaling_factor: float#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Llama4Attention(*, use_qk_norm=False, temperature_tuning=False)[source]#
Bases: BaseModel
Configuration specific to Llama4-style models.
- Parameters:
use_qk_norm (bool)
temperature_tuning (bool)
- use_qk_norm: bool#
- temperature_tuning: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.SplashAttention(*, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False)[source]#
Bases: BaseModel
Tunable block sizes for Splash Attention kernels.
- Parameters:
sa_block_q (int)
sa_block_kv (int)
sa_block_kv_compute (int)
sa_block_q_dkv (int)
sa_block_kv_dkv (int)
sa_block_kv_dkv_compute (int)
sa_block_q_dq (int)
sa_block_kv_dq (int)
sa_use_fused_bwd_kernel (bool)
sa_q_layout (str)
sa_k_layout (str)
sa_v_layout (str)
use_max_logit_estimate (int)
cost_estimate_flops_fwd (int)
cost_estimate_flops_bwd (int)
dq_reduction_steps (int)
use_splash_scheduler (bool)
- sa_block_q: int#
- sa_block_kv: int#
- sa_block_kv_compute: int#
- sa_block_q_dkv: int#
- sa_block_kv_dkv: int#
- sa_block_kv_dkv_compute: int#
- sa_block_q_dq: int#
- sa_block_kv_dq: int#
- sa_use_fused_bwd_kernel: bool#
- sa_q_layout: str#
- sa_k_layout: str#
- sa_v_layout: str#
- use_max_logit_estimate: int#
- cost_estimate_flops_fwd: int#
- cost_estimate_flops_bwd: int#
- dq_reduction_steps: int#
- use_splash_scheduler: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.PagedAttention(*, pagedattn_num_pages=64, pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128)[source]#
Bases: BaseModel
Tunable parameters for Paged Attention kernels.
- Parameters:
pagedattn_num_pages (int)
pagedattn_tokens_per_page (int)
pagedattn_pages_per_compute_block (int)
pagedattn_max_pages_per_group (int)
pagedattn_head_dim_alignment (int)
- pagedattn_num_pages: int#
- pagedattn_tokens_per_page: int#
- pagedattn_pages_per_compute_block: int#
- pagedattn_max_pages_per_group: int#
- pagedattn_head_dim_alignment: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.MoEGeneral(*, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, padded_base_moe_mlp_dim=None, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False)[source]#
Bases: BaseModel
General configuration for Mixture of Experts (MoE) layers.
- Parameters:
num_experts (Annotated[int, Gt(gt=0)])
num_experts_per_tok (Annotated[int, Gt(gt=0)])
capacity_factor (float)
ragged_buffer_factor (float)
moe_expert_input_dim (int)
base_moe_mlp_dim (int)
padded_base_moe_mlp_dim (int | None)
load_balance_loss_weight (Annotated[float, Ge(ge=0)])
use_custom_sort_vjp (bool)
use_ring_of_experts (bool)
use_gather_mosaic_kernel (bool)
use_random_routing (bool)
interleave_moe_layer_step (int)
moe_fsdp_use_two_stage_all_gather (bool)
shard_exp_on_fsdp (bool)
use_2d_fsdp_sharding (bool)
norm_topk_prob (bool)
float32_weight_sum (bool)
float32_gate_logits (bool)
prefuse_moe_weights (bool)
- num_experts: Annotated[int, Gt(gt=0)]#
- num_experts_per_tok: Annotated[int, Gt(gt=0)]#
- capacity_factor: float#
- ragged_buffer_factor: float#
- moe_expert_input_dim: int#
- base_moe_mlp_dim: int#
- padded_base_moe_mlp_dim: int | None#
- load_balance_loss_weight: Annotated[float, Ge(ge=0)]#
- use_custom_sort_vjp: bool#
- use_ring_of_experts: bool#
- use_gather_mosaic_kernel: bool#
- use_random_routing: bool#
- interleave_moe_layer_step: int#
- moe_fsdp_use_two_stage_all_gather: bool#
- shard_exp_on_fsdp: bool#
- use_2d_fsdp_sharding: bool#
- norm_topk_prob: bool#
- float32_weight_sum: bool#
- float32_gate_logits: bool#
- prefuse_moe_weights: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.MoEKernels(*, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False)[source]#
Bases: BaseModel
Configuration for MoE-specific kernels like Megablox.
- Parameters:
megablox (bool)
sparse_matmul (bool)
wi_tile_fwd_batch_seq (int)
wi_tile_fwd_embed_dim (int)
wi_tile_fwd_mlp_dim (int)
wi_tile_dlhs_batch_seq (int)
wi_tile_dlhs_embed_dim (int)
wi_tile_dlhs_mlp_dim (int)
wi_tile_drhs_batch_seq (int)
wi_tile_drhs_embed_dim (int)
wi_tile_drhs_mlp_dim (int)
wo_tile_fwd_batch_seq (int)
wo_tile_fwd_embed_dim (int)
wo_tile_fwd_mlp_dim (int)
wo_tile_dlhs_batch_seq (int)
wo_tile_dlhs_embed_dim (int)
wo_tile_dlhs_mlp_dim (int)
wo_tile_drhs_batch_seq (int)
wo_tile_drhs_embed_dim (int)
wo_tile_drhs_mlp_dim (int)
merge_gating_gmm (bool)
- megablox: bool#
- sparse_matmul: bool#
- wi_tile_fwd_batch_seq: int#
- wi_tile_fwd_embed_dim: int#
- wi_tile_fwd_mlp_dim: int#
- wi_tile_dlhs_batch_seq: int#
- wi_tile_dlhs_embed_dim: int#
- wi_tile_dlhs_mlp_dim: int#
- wi_tile_drhs_batch_seq: int#
- wi_tile_drhs_embed_dim: int#
- wi_tile_drhs_mlp_dim: int#
- wo_tile_fwd_batch_seq: int#
- wo_tile_fwd_embed_dim: int#
- wo_tile_fwd_mlp_dim: int#
- wo_tile_dlhs_batch_seq: int#
- wo_tile_dlhs_embed_dim: int#
- wo_tile_dlhs_mlp_dim: int#
- wo_tile_drhs_batch_seq: int#
- wo_tile_drhs_embed_dim: int#
- wo_tile_drhs_mlp_dim: int#
- merge_gating_gmm: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.DeepSeekMoE(*, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1)[source]#
Bases: BaseModel
Configuration specific to DeepSeek-style MoE layers.
- Parameters:
first_num_dense_layers (Annotated[int, Ge(ge=0)])
shared_experts (Annotated[int, Ge(ge=0)])
routed_scaling_factor (float)
routed_score_func (str)
routed_bias (bool)
routed_bias_update_rate (float)
mlp_bias (bool)
n_routing_groups (int)
topk_routing_group (int)
use_batch_split_schedule (bool)
batch_split_factor (int)
- first_num_dense_layers: Annotated[int, Ge(ge=0)]#
- shared_experts: Annotated[int, Ge(ge=0)]#
- routed_scaling_factor: float#
- routed_score_func: str#
- routed_bias: bool#
- routed_bias_update_rate: float#
- mlp_bias: bool#
- n_routing_groups: int#
- topk_routing_group: int#
- use_batch_split_schedule: bool#
- batch_split_factor: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Qwen3Next(*, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0)[source]#
Bases: BaseModel
Configuration specific to Qwen3-Next models with Gated Delta Net.
- Parameters:
gdn_conv_kernel_dim (int)
gdn_key_head_dim (int)
gdn_value_head_dim (int)
gdn_num_key_heads (int)
gdn_num_value_heads (int)
gdn_chunk_size (int)
use_qk_norm_in_gdn (bool)
partial_rotary_factor (float)
- gdn_conv_kernel_dim: int#
- gdn_key_head_dim: int#
- gdn_value_head_dim: int#
- gdn_num_key_heads: int#
- gdn_num_value_heads: int#
- gdn_chunk_size: int#
- use_qk_norm_in_gdn: bool#
- partial_rotary_factor: float#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.HardwareAndMesh(*, hardware='tpu', num_slices=-1, mesh_axes=['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode='auto', inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=ReorderStrategy.AUTO, custom_mesh='', custom_mesh_and_rule=CustomRule.DEFAULT, allow_split_physical_axes=False, enable_nnx=False, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=False, pure_nnx=False, remove_size_one_mesh_axis_from_type=True)[source]#
Bases: BaseModel
Configuration for hardware and parallelism mesh.
- Parameters:
hardware (Literal['tpu', 'gpu', 'gpu_multiprocess', 'cpu'])
num_slices (int)
mesh_axes (list[str])
shard_mode (ShardMode)
inhomogeneous_layer_cycle_interval (int)
scan_layers (bool)
param_scan_axis (int)
context_parallel_load_balance (bool)
context_parallel_strategy (str)
context_parallel_reorder_strategy (ReorderStrategy)
custom_mesh (str)
custom_mesh_and_rule (CustomRule)
allow_split_physical_axes (bool)
enable_nnx (bool)
optimize_mesh_for_tpu_v6e (bool)
shardy (bool)
pure_nnx_decoder (bool)
pure_nnx (bool)
remove_size_one_mesh_axis_from_type (bool)
- hardware: Literal['tpu', 'gpu', 'gpu_multiprocess', 'cpu']#
- num_slices: int#
- mesh_axes: list[str]#
- shard_mode: ShardMode#
- inhomogeneous_layer_cycle_interval: int#
- scan_layers: bool#
- param_scan_axis: int#
- context_parallel_load_balance: bool#
- context_parallel_strategy: str#
- context_parallel_reorder_strategy: ReorderStrategy#
- custom_mesh: str#
- custom_mesh_and_rule: CustomRule#
- allow_split_physical_axes: bool#
- enable_nnx: bool#
- optimize_mesh_for_tpu_v6e: bool#
- shardy: bool#
- pure_nnx_decoder: bool#
- pure_nnx: bool#
- remove_size_one_mesh_axis_from_type: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.LayoutAndSharding(*, logical_axis_rules=[], data_sharding=[], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='')[source]#
Bases:
BaseModel
Configuration for data and model sharding rules.
- Parameters:
logical_axis_rules (Any)
data_sharding (Any)
context_sharding (str)
input_data_sharding_logical_axes (list[str])
sharding_tolerance (float)
shard_optimizer_over_data (bool)
internal_compile (bool)
internal_compile_num_devices (int)
compile_xla_flags (str)
- logical_axis_rules: Any#
- data_sharding: Any#
- context_sharding: str#
- input_data_sharding_logical_axes: list[str]#
- sharding_tolerance: float#
- shard_optimizer_over_data: bool#
- internal_compile: bool#
- internal_compile_num_devices: int#
- compile_xla_flags: str#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.DcnParallelism(*, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1)[source]#
Bases:
BaseModel
Parallelism dimensions across the DCN (Data Center Network).
- Parameters:
dcn_diloco_parallelism (int)
dcn_data_parallelism (int)
dcn_fsdp_parallelism (int)
dcn_fsdp_transpose_parallelism (int)
dcn_sequence_parallelism (int)
dcn_context_parallelism (int)
dcn_context_autoregressive_parallelism (int)
dcn_tensor_parallelism (int)
dcn_tensor_transpose_parallelism (int)
dcn_tensor_sequence_parallelism (int)
dcn_pipeline_parallelism (int)
dcn_expert_parallelism (int)
dcn_autoregressive_parallelism (int)
- dcn_diloco_parallelism: int#
- dcn_data_parallelism: int#
- dcn_fsdp_parallelism: int#
- dcn_fsdp_transpose_parallelism: int#
- dcn_sequence_parallelism: int#
- dcn_context_parallelism: int#
- dcn_context_autoregressive_parallelism: int#
- dcn_tensor_parallelism: int#
- dcn_tensor_transpose_parallelism: int#
- dcn_tensor_sequence_parallelism: int#
- dcn_pipeline_parallelism: int#
- dcn_expert_parallelism: int#
- dcn_autoregressive_parallelism: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.IciParallelism(*, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1)[source]#
Bases:
BaseModel
Parallelism dimensions within the ICI (Inter-Chip Interconnect).
- Parameters:
ici_diloco_parallelism (int)
ici_data_parallelism (int)
ici_fsdp_parallelism (int)
ici_fsdp_transpose_parallelism (int)
ici_sequence_parallelism (int)
ici_context_parallelism (int)
ici_context_autoregressive_parallelism (int)
ici_tensor_parallelism (int)
ici_tensor_transpose_parallelism (int)
ici_tensor_sequence_parallelism (int)
ici_autoregressive_parallelism (int)
ici_pipeline_parallelism (int)
ici_expert_parallelism (int)
- ici_diloco_parallelism: int#
- ici_data_parallelism: int#
- ici_fsdp_parallelism: int#
- ici_fsdp_transpose_parallelism: int#
- ici_sequence_parallelism: int#
- ici_context_parallelism: int#
- ici_context_autoregressive_parallelism: int#
- ici_tensor_parallelism: int#
- ici_tensor_transpose_parallelism: int#
- ici_tensor_sequence_parallelism: int#
- ici_autoregressive_parallelism: int#
- ici_pipeline_parallelism: int#
- ici_expert_parallelism: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.PipelineParallelism(*, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=-1, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=True, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False)[source]#
Bases:
BaseModel
Configuration for pipeline parallelism.
- Parameters:
pipeline_fsdp_ag_per_repeat (bool)
num_layers_per_pipeline_stage (int)
num_pipeline_repeats (int)
pipeline_parallel_layers (int)
num_pipeline_microbatches (int)
pipeline_delay_activation_forwarding (bool)
pipeline_fsdp_ag_once (bool)
scan_pipeline_iterations (bool)
scan_pipeline_repeats (bool)
scan_layers_per_stage (bool)
set_remat_policy_on_pipeline_iterations (bool)
set_remat_policy_on_layers_per_stage (bool)
- pipeline_fsdp_ag_per_repeat: bool#
- num_layers_per_pipeline_stage: int#
- num_pipeline_repeats: int#
- pipeline_parallel_layers: int#
- num_pipeline_microbatches: int#
- pipeline_delay_activation_forwarding: bool#
- pipeline_fsdp_ag_once: bool#
- scan_pipeline_iterations: bool#
- scan_pipeline_repeats: bool#
- scan_layers_per_stage: bool#
- set_remat_policy_on_pipeline_iterations: bool#
- set_remat_policy_on_layers_per_stage: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.RematAndOffload(*, remat_policy='full', remat_policy_for_vit='minimal', decoder_layer_input=RematLocation.DEVICE, context=RematLocation.REMAT, mlpwi=RematLocation.REMAT, mlpwi_0=RematLocation.REMAT, mlpwi_1=RematLocation.REMAT, mlpwo=RematLocation.REMAT, moe_mlpwi_0=RematLocation.REMAT, moe_mlpwi_1=RematLocation.REMAT, moe_mlpwo=RematLocation.REMAT, query_proj=RematLocation.REMAT, key_proj=RematLocation.REMAT, value_proj=RematLocation.REMAT, query_wa_proj=RematLocation.REMAT, kv_wa_proj=RematLocation.REMAT, qkv_proj=RematLocation.REMAT, out_proj=RematLocation.REMAT, mla_q=RematLocation.REMAT, mla_kv=RematLocation.REMAT, attention_out=RematLocation.REMAT, engram=RematLocation.REMAT, optimizer_memory_host_offload=False, parameter_memory_host_offload=False)[source]#
Bases:
BaseModel
Configuration for gradient checkpointing (rematerialization) and offloading.
- Parameters:
remat_policy (str)
remat_policy_for_vit (str)
decoder_layer_input (RematLocation)
context (RematLocation)
mlpwi (RematLocation)
mlpwi_0 (RematLocation)
mlpwi_1 (RematLocation)
mlpwo (RematLocation)
moe_mlpwi_0 (RematLocation)
moe_mlpwi_1 (RematLocation)
moe_mlpwo (RematLocation)
query_proj (RematLocation)
key_proj (RematLocation)
value_proj (RematLocation)
query_wa_proj (RematLocation)
kv_wa_proj (RematLocation)
qkv_proj (RematLocation)
out_proj (RematLocation)
mla_q (RematLocation)
mla_kv (RematLocation)
attention_out (RematLocation)
engram (RematLocation)
optimizer_memory_host_offload (bool)
parameter_memory_host_offload (bool)
- remat_policy: str#
- remat_policy_for_vit: str#
- decoder_layer_input: RematLocation#
- context: RematLocation#
- mlpwi: RematLocation#
- mlpwi_0: RematLocation#
- mlpwi_1: RematLocation#
- mlpwo: RematLocation#
- moe_mlpwi_0: RematLocation#
- moe_mlpwi_1: RematLocation#
- moe_mlpwo: RematLocation#
- query_proj: RematLocation#
- key_proj: RematLocation#
- value_proj: RematLocation#
- query_wa_proj: RematLocation#
- kv_wa_proj: RematLocation#
- qkv_proj: RematLocation#
- out_proj: RematLocation#
- mla_q: RematLocation#
- mla_kv: RematLocation#
- attention_out: RematLocation#
- engram: RematLocation#
- optimizer_memory_host_offload: bool#
- parameter_memory_host_offload: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Tokenizer(*, vocab_size=32000, tokenizer_path=None, tokenizer_type=TokenizerType.SENTENCEPIECE, use_chat_template=False, chat_template_path='', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1)[source]#
Bases:
BaseModel
Configuration for the tokenizer.
- Parameters:
vocab_size (int)
tokenizer_path (None | str)
tokenizer_type (TokenizerType)
use_chat_template (bool)
chat_template_path (str)
chat_template (str)
tokenize_train_data (bool)
tokenize_eval_data (bool)
add_bos (bool)
add_eos (bool)
use_truncation (bool)
num_vocab_tiling (int)
- vocab_size: int#
- tokenizer_path: None | str#
- tokenizer_type: TokenizerType#
- use_chat_template: bool#
- chat_template_path: str#
- chat_template: str#
- tokenize_train_data: bool#
- tokenize_eval_data: bool#
- add_bos: bool#
- add_eos: bool#
- use_truncation: bool#
- num_vocab_tiling: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.DatasetGeneral(*, dataset_type=DatasetType.TFDS, per_device_batch_size=12, eval_per_device_batch_size=0.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False)[source]#
Bases:
BaseModel
General configuration for dataset and data loading.
- Parameters:
dataset_type (DatasetType)
per_device_batch_size (int | float)
eval_per_device_batch_size (int | float)
max_corpus_chars (int)
train_data_columns (list[str])
train_image_column (str | list[str])
eval_data_columns (list[str])
eval_image_column (str | list[str])
packing (bool)
grain_packing_type (Literal['first_fit', 'best_fit', 'concat_then_split'])
max_segments_per_seq (int)
num_epoch (int)
expansion_factor_real_data (float)
reuse_example_batch (int)
generate_padding_batch_train (bool)
generate_padding_batch_eval (bool)
enable_rampup_batch_size (bool)
per_device_batch_size_start (float)
per_device_batch_size_increment (float)
global_rampup_samples (int)
colocated_python_data_input (bool)
- dataset_type: DatasetType#
- per_device_batch_size: int | float#
- eval_per_device_batch_size: int | float#
- max_corpus_chars: int#
- train_data_columns: list[str]#
- train_image_column: str | list[str]#
- eval_data_columns: list[str]#
- eval_image_column: str | list[str]#
- packing: bool#
- grain_packing_type: Literal['first_fit', 'best_fit', 'concat_then_split']#
- max_segments_per_seq: int#
- num_epoch: int#
- expansion_factor_real_data: float#
- reuse_example_batch: int#
- generate_padding_batch_train: bool#
- generate_padding_batch_eval: bool#
- enable_rampup_batch_size: bool#
- per_device_batch_size_start: float#
- per_device_batch_size_increment: float#
- global_rampup_samples: int#
- colocated_python_data_input: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.TfdsDataset(*, dataset_path='', dataset_name='c4/en:3.0.1', eval_dataset_name='c4/en:3.0.1', train_split='train', eval_split='validation')[source]#
Bases:
BaseModel
Configuration specific to TFDS datasets.
- Parameters:
dataset_path (str)
dataset_name (str)
eval_dataset_name (str)
train_split (str)
eval_split (str)
- dataset_path: str#
- dataset_name: str#
- eval_dataset_name: str#
- train_split: str#
- eval_split: str#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.HfDataset(*, hf_path='', hf_name=None, hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token=None)[source]#
Bases:
BaseModel
Configuration specific to HuggingFace datasets.
- Parameters:
hf_path (str)
hf_name (None | str)
hf_data_dir (None | str)
hf_train_files (None | str)
hf_eval_split (None | str)
hf_eval_files (None | str)
hf_access_token (None | str)
- hf_path: str#
- hf_name: None | str#
- hf_data_dir: None | str#
- hf_train_files: None | str#
- hf_eval_split: None | str#
- hf_eval_files: None | str#
- hf_access_token: None | str#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.GrainDataset(*, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_use_elastic_iterator=False, grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100)[source]#
Bases:
BaseModel
Configuration specific to Grain datasets.
- Parameters:
grain_train_files (str)
grain_eval_files (str)
grain_train_mixture_config_path (str)
grain_file_type (str)
grain_use_elastic_iterator (bool)
grain_worker_count (int)
grain_per_worker_buffer_size (int)
grain_worker_count_eval (int)
grain_per_worker_buffer_size_eval (int)
grain_ram_budget_mb (int)
grain_num_threads (int)
grain_prefetch_buffer_size (int)
grain_num_threads_eval (int)
grain_prefetch_buffer_size_eval (int)
grain_data_source_max_workers (int)
grain_shuffle_buffer_size (int)
- grain_train_files: str#
- grain_eval_files: str#
- grain_train_mixture_config_path: str#
- grain_file_type: str#
- grain_use_elastic_iterator: bool#
- grain_worker_count: int#
- grain_per_worker_buffer_size: int#
- grain_worker_count_eval: int#
- grain_per_worker_buffer_size_eval: int#
- grain_ram_budget_mb: int#
- grain_num_threads: int#
- grain_prefetch_buffer_size: int#
- grain_num_threads_eval: int#
- grain_prefetch_buffer_size_eval: int#
- grain_data_source_max_workers: int#
- grain_shuffle_buffer_size: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.OlmoGrainDataset(*, olmo_index_path='', olmo_path_remap_from='', olmo_path_remap_to='', olmo_apply_ngram_filter=True)[source]#
Bases:
BaseModel
Configuration for the OLMo numpy fixed-seq-length input pipeline (dataset_type=olmo_grain).
Separate from the standard grain config because this pipeline reads pre-tokenized fixed-length sequences from raw npy files (one int32 token per element, sequence_length from an index JSON), not arrayrecord/tfds shards — so flags like grain_train_files / packing don’t apply.
Worker count, per-worker buffer size, and shuffle seed reuse the standard grain flags (grain_worker_count, grain_per_worker_buffer_size, data_shuffle_seed); only OLMo-specific fields are listed here.
- Parameters:
olmo_index_path (str)
olmo_path_remap_from (str)
olmo_path_remap_to (str)
olmo_apply_ngram_filter (bool)
- olmo_index_path: str#
- olmo_path_remap_from: str#
- olmo_path_remap_to: str#
- olmo_apply_ngram_filter: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.FineTuning(*, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs=<factory>, use_grpo=None)[source]#
Bases:
BaseModel
Configuration for fine-tuning methods like DPO, SFT, and GRPO.
- Parameters:
use_dpo (bool)
dpo_label_smoothing (float)
dpo_beta (float)
use_sft (bool)
sft_train_on_completion_only (bool)
formatting_func_path (str)
formatting_func_kwargs (dict)
use_grpo (None | bool)
- use_dpo: bool#
- dpo_label_smoothing: float#
- dpo_beta: float#
- use_sft: bool#
- sft_train_on_completion_only: bool#
- formatting_func_path: str#
- formatting_func_kwargs: dict#
- use_grpo: None | bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Distillation(*, student_overrides=<factory>, teacher_overrides=<factory>, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', learn_to_init_mode=False, lti_use_general_linear_map=False, distill_weights_copy_map=<factory>, distill_student_weights_share_map=<factory>, student_params_to_update=None)[source]#
Bases:
BaseModel
Configuration for Knowledge Distillation.
- Parameters:
student_overrides (dict[str, Any])
teacher_overrides (dict[str, Any])
offline_data_dir (str | None)
distill_alpha (float)
distill_temperature (float)
distill_beta (float)
distill_feature_loss_type (Literal['cosine', 'l2'])
distill_layer_indices (None | list)
distill_alpha_end (float | None)
distill_alpha_schedule (Literal['constant', 'linear', 'cosine'])
distill_temperature_end (float | None)
distill_temperature_schedule (Literal['constant', 'linear', 'cosine'])
distill_beta_end (float | None)
distill_beta_schedule (Literal['constant', 'linear', 'cosine'])
learn_to_init_mode (bool)
lti_use_general_linear_map (bool)
distill_weights_copy_map (dict[str, Any])
distill_student_weights_share_map (dict[str, Any])
student_params_to_update (None | list)
- student_overrides: dict[str, Any]#
- teacher_overrides: dict[str, Any]#
- offline_data_dir: str | None#
- distill_alpha: float#
- distill_temperature: float#
- distill_beta: float#
- distill_feature_loss_type: Literal['cosine', 'l2']#
- distill_layer_indices: None | list#
- distill_alpha_end: float | None#
- distill_alpha_schedule: Literal['constant', 'linear', 'cosine']#
- distill_temperature_end: float | None#
- distill_temperature_schedule: Literal['constant', 'linear', 'cosine']#
- distill_beta_end: float | None#
- distill_beta_schedule: Literal['constant', 'linear', 'cosine']#
- learn_to_init_mode: bool#
- lti_use_general_linear_map: bool#
- distill_weights_copy_map: dict[str, Any]#
- distill_student_weights_share_map: dict[str, Any]#
- student_params_to_update: None | list#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.TrainingLoop(*, steps=150001, log_period=100, eval_interval=-1, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=True, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=0, init_weights_seed=0)[source]#
Bases:
BaseModel
Configuration for the main training loop, evaluation, and reproducibility.
- Parameters:
steps (int)
log_period (int)
eval_interval (int)
eval_steps (int)
target_eval_loss (float)
abort_on_nan_loss (bool)
abort_on_inf_loss (bool)
enable_dropout (bool)
dropout_rate (float)
enable_data_shuffling (bool)
data_shuffle_seed (int)
init_weights_seed (int)
- steps: int#
- log_period: int#
- eval_interval: int#
- eval_steps: int#
- target_eval_loss: float#
- abort_on_nan_loss: bool#
- abort_on_inf_loss: bool#
- enable_dropout: bool#
- dropout_rate: float#
- enable_data_shuffling: bool#
- data_shuffle_seed: int#
- init_weights_seed: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.ManifoldConstrainedHyperConnections(*, mhc_expansion_rate=1, sinkhorn_iterations=20)[source]#
Bases:
BaseModel
Configuration for DeepSeek Manifold-Constrained Hyper Connections (mHC).
- Parameters:
mhc_expansion_rate (Annotated[int, Gt(gt=0)])
sinkhorn_iterations (Annotated[int, Gt(gt=0)])
- mhc_expansion_rate: Annotated[int, Gt(gt=0)]#
- sinkhorn_iterations: Annotated[int, Gt(gt=0)]#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.DilocoParams(*, enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9)[source]#
Bases:
BaseModel
Diloco Hyperparameters
- Parameters:
enable_diloco (bool)
diloco_sync_period (int)
diloco_outer_lr (float)
diloco_outer_momentum (float)
- enable_diloco: bool#
- diloco_sync_period: int#
- diloco_outer_lr: float#
- diloco_outer_momentum: float#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Optimizer(*, opt_type=OptimizerType.ADAMW, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=1.0, learning_rate=3e-05, lr_schedule_type=LearningRateScheduleType.COSINE, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=WsdDecayStyle.LINEAR, warmup_steps_fraction=0.1, learning_rate_schedule_steps=-1, trainable_parameters_mask=<factory>)[source]#
Bases:
BaseModel
Configuration for the optimizer and learning rate schedule.
- Parameters:
opt_type (OptimizerType)
skip_step_on_spikes (bool)
skip_step_interval (Annotated[int, Gt(gt=0)])
skip_step_scaling_factor (float)
gradient_accumulation_steps (Annotated[int, Gt(gt=0)])
use_tunix_gradient_accumulation (bool)
gradient_clipping_threshold (Annotated[float, Ge(ge=0)])
learning_rate (Annotated[float, Ge(ge=0)])
lr_schedule_type (LearningRateScheduleType)
learning_rate_final_fraction (float)
wsd_decay_steps_fraction (float)
wsd_decay_style (WsdDecayStyle)
warmup_steps_fraction (float)
learning_rate_schedule_steps (int)
trainable_parameters_mask (list[str])
- opt_type: OptimizerType#
- skip_step_on_spikes: bool#
- skip_step_interval: Annotated[int, Gt(gt=0)]#
- skip_step_scaling_factor: float#
- gradient_accumulation_steps: Annotated[int, Gt(gt=0)]#
- use_tunix_gradient_accumulation: bool#
- gradient_clipping_threshold: Annotated[float, Ge(ge=0)]#
- learning_rate: Annotated[float, Ge(ge=0)]#
- lr_schedule_type: LearningRateScheduleType#
- learning_rate_final_fraction: float#
- wsd_decay_steps_fraction: float#
- wsd_decay_style: WsdDecayStyle#
- warmup_steps_fraction: float#
- learning_rate_schedule_steps: int#
- trainable_parameters_mask: list[str]#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.AdamW(*, adam_b1=0.9, adam_b2=0.95, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=<factory>, mu_dtype='')[source]#
Bases:
BaseModel
Configuration specific to the AdamW optimizer.
- Parameters:
adam_b1 (float)
adam_b2 (float)
adam_eps (float)
adam_eps_root (float)
adam_weight_decay (float)
adamw_mask (list[str])
mu_dtype (str)
- adam_b1: float#
- adam_b2: float#
- adam_eps: float#
- adam_eps_root: float#
- adam_weight_decay: float#
- adamw_mask: list[str]#
- mu_dtype: str#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Muon(*, muon_beta=0.95, muon_weight_decay=0, muon_consistent_rms=None)[source]#
Bases:
BaseModel
Configuration specific to the Muon optimizer.
- Parameters:
muon_beta (float)
muon_weight_decay (float)
muon_consistent_rms (float | None)
- muon_beta: float#
- muon_weight_decay: float#
- muon_consistent_rms: float | None#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.PositionalEmbedding(*, use_iota_embed=False, use_untrainable_positional_embedding=False, trainable_position_size=-1, nope_layer_interval=-1)[source]#
Bases:
BaseModel
General configuration for positional embeddings.
- Parameters:
use_iota_embed (bool)
use_untrainable_positional_embedding (bool)
trainable_position_size (int)
nope_layer_interval (int)
- use_iota_embed: bool#
- use_untrainable_positional_embedding: bool#
- trainable_position_size: int#
- nope_layer_interval: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Rope(*, rope_type=RopeType.DEFAULT, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=10000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0)[source]#
Bases:
BaseModel
Configuration for Rotary Positional Embedding (RoPE).
- Parameters:
rope_type (RopeType)
rope_use_scale (bool)
rope_min_timescale (int)
rope_max_timescale (int)
rope_linear_scaling_factor (float)
local_rope_max_timescale (int)
global_rope_max_timescale (int)
global_rope_proportion (float)
local_rope_proportion (float)
- rope_type: RopeType#
- rope_use_scale: bool#
- rope_min_timescale: int#
- rope_max_timescale: int#
- rope_linear_scaling_factor: float#
- local_rope_max_timescale: int#
- global_rope_max_timescale: int#
- global_rope_proportion: float#
- local_rope_proportion: float#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.YarnRope(*, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False)[source]#
Bases:
BaseModel
Configuration specific to YaRN (Yet another RoPE extensioN) scaling.
- Parameters:
max_position_embeddings (int)
original_max_position_embeddings (int)
rope_factor (int)
beta_fast (int)
beta_slow (int)
mscale (float)
rope_interleave (bool)
rope_truncate (bool)
rope_attention_scaling (bool)
- max_position_embeddings: int#
- original_max_position_embeddings: int#
- rope_factor: int#
- beta_fast: int#
- beta_slow: int#
- mscale: float#
- rope_interleave: bool#
- rope_truncate: bool#
- rope_attention_scaling: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.InferenceGeneral(*, max_target_length=2048, max_prefill_predict_length=64, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False)[source]#
Bases:
BaseModel
General configuration for inference.
- Parameters:
max_target_length (int)
max_prefill_predict_length (int)
prompt (str)
load_from_prefill_dir (bool)
prefill_cache_dir (str)
autoregressive_decode_assert (str)
model_call_mode (str)
use_chunked_prefill (bool)
prefill_chunk_size (int)
enable_model_warmup (bool)
enable_llm_inference_pool (bool)
multi_sampling (bool)
return_log_prob (bool)
- max_target_length: int#
- max_prefill_predict_length: int#
- prompt: str#
- load_from_prefill_dir: bool#
- prefill_cache_dir: str#
- autoregressive_decode_assert: str#
- model_call_mode: str#
- use_chunked_prefill: bool#
- prefill_chunk_size: int#
- enable_model_warmup: bool#
- enable_llm_inference_pool: bool#
- multi_sampling: bool#
- return_log_prob: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Decoding(*, decode_sampling_strategy=SamplingStrategy.GREEDY, decode_sampling_nucleus_p=-1.0, decode_sampling_top_k=0, decode_sampling_temperature=1.0)[source]#
Bases:
BaseModel
Configuration for decoding and sampling strategies.
- Parameters:
decode_sampling_strategy (SamplingStrategy)
decode_sampling_nucleus_p (int | float)
decode_sampling_top_k (int)
decode_sampling_temperature (float)
- decode_sampling_strategy: SamplingStrategy#
- decode_sampling_nucleus_p: int | float#
- decode_sampling_top_k: int#
- decode_sampling_temperature: float#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.InferenceLayout(*, stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False)[source]#
Bases:
BaseModel
Configuration for KV cache and compute layouts during inference.
- Parameters:
stack_prefill_result_cache (bool)
prefill_cache_axis_order (str)
ar_cache_axis_order (str)
compute_axis_order (str)
reshape_q (bool)
- stack_prefill_result_cache: bool#
- prefill_cache_axis_order: str#
- ar_cache_axis_order: str#
- compute_axis_order: str#
- reshape_q: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.InferenceServer(*, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16')[source]#
Bases:
BaseModel
Configuration for running as an inference server.
- Parameters:
inference_server (str)
prefill_slice (str)
generate_slice (str)
- inference_server: str#
- prefill_slice: str#
- generate_slice: str#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.InferenceBenchmark(*, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False)[source]#
Bases:
BaseModel
Configuration for running inference microbenchmarks.
- Parameters:
inference_microbenchmark_prefill_lengths (str)
inference_microbenchmark_stages (str)
inference_microbenchmark_loop_iters (int)
inference_microbenchmark_log_file_path (str)
inference_microbenchmark_num_samples (list[int])
inference_metadata_file (str)
inference_benchmark_test (bool)
- inference_microbenchmark_prefill_lengths: str#
- inference_microbenchmark_stages: str#
- inference_microbenchmark_loop_iters: int#
- inference_microbenchmark_log_file_path: str#
- inference_microbenchmark_num_samples: list[int]#
- inference_metadata_file: str#
- inference_benchmark_test: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.PrefixCaching(*, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000)[source]#
Bases: BaseModel
Configuration for Prefix Caching in JetStream.
- Parameters:
enable_prefix_caching (bool)
prefix_caching_hbm_byte (int)
prefix_caching_dram_byte (int)
- enable_prefix_caching: bool#
- prefix_caching_hbm_byte: int#
- prefix_caching_dram_byte: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.AOT(*, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1)[source]#
Bases: BaseModel
Ahead of Time (AOT) Compilation settings.
- Parameters:
compiled_trainstep_file (str)
compile_topology (str)
compile_topology_num_slices (int)
- compiled_trainstep_file: str#
- compile_topology: str#
- compile_topology_num_slices: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.DevelopmentAndDebugging(*, constant_bound_config=[], jax_cache_dir='/home/docs/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=False, enable_single_controller=False, subslice_shape='', max_checkify=False)[source]#
Bases: BaseModel
General settings for development and debugging.
- Parameters:
constant_bound_config (list)
jax_cache_dir (str | None)
jax_distributed_initialization_timeout (int)
jax_debug_log_modules (str)
skip_jax_distributed_system (bool)
enable_single_controller (bool)
subslice_shape (str)
max_checkify (bool)
- constant_bound_config: list#
- jax_cache_dir: str | None#
- jax_distributed_initialization_timeout: int#
- jax_debug_log_modules: str#
- skip_jax_distributed_system: bool#
- enable_single_controller: bool#
- subslice_shape: str#
- max_checkify: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Profiling(*, profiler=ProfilerType.NONE, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, enable_tpu_profiling_options=False, tpu_num_chips_to_profile_per_task=1, tpu_num_sparse_cores_to_trace=2, tpu_num_sparse_core_tiles_to_trace=1, xprof_tpu_power_trace_level=XProfTPUPowerTraceMode.POWER_TRACE_NONE, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False)[source]#
Bases: BaseModel
Configuration for performance profiling.
- Parameters:
profiler (ProfilerType)
upload_all_profiler_results (bool)
skip_first_n_steps_for_profiler (int)
profiler_steps (int)
profile_cleanly (bool)
profile_periodically_period (int)
hide_profiler_step_metric (bool)
enable_jax_profiler (bool)
jax_profiler_port (int)
enable_tpu_profiling_options (bool)
tpu_num_chips_to_profile_per_task (int)
tpu_num_sparse_cores_to_trace (int)
tpu_num_sparse_core_tiles_to_trace (int)
xprof_tpu_power_trace_level (XProfTPUPowerTraceMode)
xprof_e2e_enable_fw_throttle_event (bool)
xprof_e2e_enable_fw_power_level_event (bool)
xprof_e2e_enable_fw_thermal_event (bool)
profile_power_events (bool)
- profiler: ProfilerType#
- upload_all_profiler_results: bool#
- skip_first_n_steps_for_profiler: int#
- profiler_steps: int#
- profile_cleanly: bool#
- profile_periodically_period: int#
- hide_profiler_step_metric: bool#
- enable_jax_profiler: bool#
- jax_profiler_port: int#
- enable_tpu_profiling_options: bool#
- tpu_num_chips_to_profile_per_task: int#
- tpu_num_sparse_cores_to_trace: int#
- tpu_num_sparse_core_tiles_to_trace: int#
- xprof_tpu_power_trace_level: XProfTPUPowerTraceMode#
- xprof_e2e_enable_fw_throttle_event: bool#
- xprof_e2e_enable_fw_power_level_event: bool#
- xprof_e2e_enable_fw_thermal_event: bool#
- profile_power_events: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.HloDump(*, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='')[source]#
Bases: BaseModel
Configuration for dumping HLO modules for debugging.
- Parameters:
dump_hlo (bool)
dump_step (int)
dump_hlo_local_dir (str)
dump_hlo_delete_local_after (bool)
dump_hlo_gcs_dir (str)
dump_hlo_module_name (str)
dump_hlo_local_module_name (str)
dump_hlo_xla_flags (str)
dump_hlo_upload_all (bool)
dump_jaxpr (bool)
dump_jaxpr_local_dir (str)
dump_jaxpr_delete_local_after (bool)
dump_jaxpr_gcs_dir (str)
- dump_hlo: bool#
- dump_step: int#
- dump_hlo_local_dir: str#
- dump_hlo_delete_local_after: bool#
- dump_hlo_gcs_dir: str#
- dump_hlo_module_name: str#
- dump_hlo_local_module_name: str#
- dump_hlo_xla_flags: str#
- dump_hlo_upload_all: bool#
- dump_jaxpr: bool#
- dump_jaxpr_local_dir: str#
- dump_jaxpr_delete_local_after: bool#
- dump_jaxpr_gcs_dir: str#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.StackTrace(*, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600)[source]#
Bases: BaseModel
Configuration for collecting and logging stack traces.
- Parameters:
collect_stack_trace (bool)
stack_trace_to_cloud (bool)
stack_trace_interval_seconds (int)
- collect_stack_trace: bool#
- stack_trace_to_cloud: bool#
- stack_trace_interval_seconds: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Metrics(*, metrics_file=None, gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False)[source]#
Bases: BaseModel
General configuration for metrics and monitoring.
- Parameters:
metrics_file (None | str)
gcs_metrics (bool)
save_config_to_gcs (bool)
record_internal_nn_metrics (int)
prometheus_port (int)
enable_checkpoint_cloud_logger (bool)
enable_tunix_perf_metrics (bool)
- metrics_file: None | str#
- gcs_metrics: bool#
- save_config_to_gcs: bool#
- record_internal_nn_metrics: int#
- prometheus_port: int#
- enable_checkpoint_cloud_logger: bool#
- enable_tunix_perf_metrics: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.ManagedMLDiagnostics(*, managed_mldiagnostics=False, managed_mldiagnostics_run_group='')[source]#
Bases: BaseModel
Configuration for managed mldiagnostics.
- Parameters:
managed_mldiagnostics (bool)
managed_mldiagnostics_run_group (str)
- managed_mldiagnostics: bool#
- managed_mldiagnostics_run_group: str#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Goodput(*, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True)[source]#
Bases: BaseModel
Configuration for goodput monitoring.
- Parameters:
enable_goodput_recording (bool)
monitor_goodput (bool)
goodput_upload_interval_seconds (int)
enable_pathways_goodput (bool)
monitor_step_time_deviation (bool)
step_deviation_interval_seconds (int)
enable_gcp_goodput_metrics (bool)
enable_gcp_step_deviation_metrics (bool)
- enable_goodput_recording: bool#
- monitor_goodput: bool#
- goodput_upload_interval_seconds: int#
- enable_pathways_goodput: bool#
- monitor_step_time_deviation: bool#
- step_deviation_interval_seconds: int#
- enable_gcp_goodput_metrics: bool#
- enable_gcp_step_deviation_metrics: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.ElasticTraining(*, elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, elastic_min_slice_count=-1)[source]#
Bases: BaseModel
Configuration for elastic training and fault tolerance.
Elastic training is Pathways-specific and does not work on McJAX.
- Parameters:
elastic_enabled (bool)
elastic_timeout_seconds (int)
elastic_max_retries (int)
elastic_min_slice_count (int)
- elastic_enabled: bool#
- elastic_timeout_seconds: int#
- elastic_max_retries: int#
- elastic_min_slice_count: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.GcpMonitoring(*, report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False)[source]#
Bases: BaseModel
Configuration for GCP-specific workload monitoring.
- Parameters:
report_heartbeat_metric_for_gcp_monitoring (bool)
heartbeat_reporting_interval_in_seconds (int)
report_performance_metric_for_gcp_monitoring (bool)
- report_heartbeat_metric_for_gcp_monitoring: bool#
- heartbeat_reporting_interval_in_seconds: int#
- report_performance_metric_for_gcp_monitoring: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Tensorboard(*, enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='')[source]#
Bases: BaseModel
Configuration for Tensorboard logging.
- Parameters:
enable_tensorboard (bool)
use_vertex_tensorboard (bool)
vertex_tensorboard_project (str | None)
vertex_tensorboard_region (str | None)
- enable_tensorboard: bool#
- use_vertex_tensorboard: bool#
- vertex_tensorboard_project: str | None#
- vertex_tensorboard_region: str | None#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.MultimodalGeneral(*, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25)[source]#
Bases: BaseModel
General configuration for Multimodal models.
- Parameters:
use_multimodal (bool)
freeze_vision_encoder_params (bool)
freeze_audio_encoder_params (bool)
use_audio (bool)
image_size_for_vit (int | list[int])
image_path (str)
image_placeholder (str)
posemb_type_for_vit (str)
max_num_images_per_example (int)
video_path (str)
audio_path (str)
video_placeholder (str)
audio_placeholder (str)
use_audio_in_video (bool)
use_mrope (bool)
mrope_section (list[int])
position_id_per_seconds (int)
- use_multimodal: bool#
- freeze_vision_encoder_params: bool#
- freeze_audio_encoder_params: bool#
- use_audio: bool#
- image_size_for_vit: int | list[int]#
- image_path: str#
- image_placeholder: str#
- posemb_type_for_vit: str#
- max_num_images_per_example: int#
- video_path: str#
- audio_path: str#
- video_placeholder: str#
- audio_placeholder: str#
- use_audio_in_video: bool#
- use_mrope: bool#
- mrope_section: list[int]#
- position_id_per_seconds: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.VisionTower(*, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], vision_output_length=-1)[source]#
Bases: BaseModel
Configuration for the Vision Tower (Encoder) in a multimodal model.
- Parameters:
hidden_size_for_vit (int)
intermediate_size_for_vit (int)
num_attention_heads_for_vit (int)
num_channels_for_vit (int)
tile_size_for_vit (int)
patch_size_for_vit (int)
conv_stride_for_vit (int)
num_hidden_layers_for_vit (int)
rope_theta_for_vit (int)
vision_output_dim_for_vit (int)
spatial_merge_size_for_vit (int)
out_hidden_size_for_vit (int)
temporal_patch_size_for_vit (int)
num_position_embeddings_for_vit (int)
deepstack_visual_indexes_for_vit (list[int])
vision_output_length (int)
- intermediate_size_for_vit: int#
- num_attention_heads_for_vit: int#
- num_channels_for_vit: int#
- tile_size_for_vit: int#
- patch_size_for_vit: int#
- conv_stride_for_vit: int#
- rope_theta_for_vit: int#
- vision_output_dim_for_vit: int#
- spatial_merge_size_for_vit: int#
- temporal_patch_size_for_vit: int#
- num_position_embeddings_for_vit: int#
- deepstack_visual_indexes_for_vit: list[int]#
- vision_output_length: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.VisionProjector(*, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0)[source]#
Bases: BaseModel
Configuration for the Vision Projector in a multimodal model.
- Parameters:
projector_input_dim_for_vit (int)
projector_output_dim_for_vit (int)
pixel_shuffle_ratio_for_vit (float)
projector_dropout_for_vit (float)
- projector_input_dim_for_vit: int#
- projector_output_dim_for_vit: int#
- pixel_shuffle_ratio_for_vit: float#
- projector_dropout_for_vit: float#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.AudioEncoder(*, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000)[source]#
Bases: BaseModel
Configuration for the Audio Encoder in a multimodal model.
- Parameters:
d_model_for_audio (int)
encoder_attention_heads_for_audio (int)
encoder_ffn_dim_for_audio (int)
encoder_layers_for_audio (int)
attention_dropout_for_audio (float)
activation_dropout_for_audio (float)
activation_function_for_audio (str)
num_mel_bins_for_audio (int)
max_source_positions_for_audio (int)
scale_embedding_for_audio (bool)
n_window_for_audio (int)
n_window_infer_for_audio (int)
conv_chunksize_for_audio (int)
downsample_hidden_size_for_audio (int)
output_dim_for_audio (int)
num_conv_layers_for_audio (int)
max_timescale_for_audio (float)
max_sample_len_for_audio (int)
- d_model_for_audio: int#
- encoder_attention_heads_for_audio: int#
- encoder_ffn_dim_for_audio: int#
- encoder_layers_for_audio: int#
- attention_dropout_for_audio: float#
- activation_dropout_for_audio: float#
- activation_function_for_audio: str#
- num_mel_bins_for_audio: int#
- max_source_positions_for_audio: int#
- scale_embedding_for_audio: bool#
- n_window_for_audio: int#
- n_window_infer_for_audio: int#
- conv_chunksize_for_audio: int#
- output_dim_for_audio: int#
- num_conv_layers_for_audio: int#
- max_timescale_for_audio: float#
- max_sample_len_for_audio: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Debug(*, rl=False)[source]#
Bases: BaseModel
Configuration for debugging options.
- Parameters:
rl (bool)
- rl: bool#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.RLHardware(*, trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=True, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=-1, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1)[source]#
Bases: BaseModel
Hardware settings specific to RL training.
- Parameters:
trainer_devices_fraction (float)
sampler_devices_fraction (float)
chips_per_vm (int)
use_pathways (bool)
num_trainer_slices (int)
num_samplers_slices (int)
rollout_data_parallelism (int)
rollout_tensor_parallelism (int)
rollout_expert_parallelism (int)
- trainer_devices_fraction: float#
- sampler_devices_fraction: float#
- chips_per_vm: int#
- use_pathways: bool#
- num_trainer_slices: int#
- num_samplers_slices: int#
- rollout_data_parallelism: int#
- rollout_tensor_parallelism: int#
- rollout_expert_parallelism: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.VLLM(*, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config=<factory>, vllm_hf_overrides=<factory>, vllm_hf_config_path='', use_standalone_converter=False, vllm_load_format='dummy', debug_converter=False, gcs_debug_path='')[source]#
Bases: BaseModel
vLLM-specific configuration for rollouts.
- Parameters:
kv_cache_buffer (int)
hbm_utilization_vllm (float)
swap_space_vllm_gb (int)
enable_dp_attention (bool)
enable_expert_parallel (bool)
async_scheduling (bool)
max_num_batched_tokens (int | None)
max_num_seqs (int | None)
stop_strings (list[str] | None)
vllm_additional_config (dict[str, Any])
vllm_hf_overrides (dict[str, Any])
vllm_hf_config_path (str)
use_standalone_converter (bool)
vllm_load_format (str)
debug_converter (bool)
gcs_debug_path (str)
- kv_cache_buffer: int#
- hbm_utilization_vllm: float#
- swap_space_vllm_gb: int#
- enable_dp_attention: bool#
- enable_expert_parallel: bool#
- async_scheduling: bool#
- max_num_batched_tokens: int | None#
- max_num_seqs: int | None#
- stop_strings: list[str] | None#
- vllm_additional_config: dict[str, Any]#
- vllm_hf_overrides: dict[str, Any]#
- vllm_hf_config_path: str#
- use_standalone_converter: bool#
- vllm_load_format: str#
- debug_converter: bool#
- gcs_debug_path: str#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.RL(*, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, reshard_chunk_size=None)[source]#
Bases: BaseModel
Configuration for RL algorithms such as Group Relative Policy Optimization (GRPO).
- Parameters:
num_generations (int)
num_iterations (int)
grpo_beta (float)
grpo_epsilon (float)
loss_algo (Literal['grpo', 'gspo-token'])
use_agentic_rollout (bool)
max_concurrency (int)
off_policy_steps (int)
system_prompt (str)
degenerate_group_masking (bool)
epsilon_high (float | None)
reshard_chunk_size (int | None)
- num_generations: int#
- num_iterations: int#
- grpo_beta: float#
- grpo_epsilon: float#
- loss_algo: Literal['grpo', 'gspo-token']#
- use_agentic_rollout: bool#
- max_concurrency: int#
- off_policy_steps: int#
- system_prompt: str#
- degenerate_group_masking: bool#
- epsilon_high: float | None#
- reshard_chunk_size: int | None#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.RLDataset(*, batch_size=1, num_batches=4, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1)[source]#
Bases: BaseModel
Dataset settings for RL training.
- Parameters:
batch_size (int)
num_batches (int)
num_test_batches (int)
test_batch_start_index (int)
train_fraction (float)
train_micro_batch_size (int)
rollout_micro_batch_size (int)
- batch_size: int#
- num_batches: int#
- num_test_batches: int#
- test_batch_start_index: int#
- train_fraction: float#
- train_micro_batch_size: int#
- rollout_micro_batch_size: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.RLEvaluation(*, eval_sampling_strategy='greedy', generation_configs=<factory>, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass')[source]#
Bases: BaseModel
Settings for RL evaluation.
- Parameters:
eval_sampling_strategy (str)
generation_configs (dict[str, Any])
num_eval_passes (int)
eval_corr_lst (bool)
eval_make_lst (bool)
eval_mode (Literal['pass', 'maj', 'pass_at_1'])
- eval_sampling_strategy: str#
- generation_configs: dict[str, Any]#
- num_eval_passes: int#
- eval_corr_lst: bool#
- eval_make_lst: bool#
- eval_mode: Literal['pass', 'maj', 'pass_at_1']#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Reward(*, reward_exact_answer=5.0, reward_exact_format_match=3.0, reward_white_space_format_match=1.5, reward_partial_format_match=0.5, reward_ratio_guess_to_answer_high=0.5, reward_ratio_guess_to_answer_low=0.25, penalty_incorrect_format=-0.5, penalty_incorrect_answer=-1.0, math_verify_timeout=300, math_verify_num_procs=None)[source]#
Bases: BaseModel
Configuration for the reward/penalty model in RL.
- Parameters:
reward_exact_answer (float)
reward_exact_format_match (float)
reward_white_space_format_match (float)
reward_partial_format_match (float)
reward_ratio_guess_to_answer_high (float)
reward_ratio_guess_to_answer_low (float)
penalty_incorrect_format (float)
penalty_incorrect_answer (float)
math_verify_timeout (int)
math_verify_num_procs (int | None)
- reward_exact_answer: float#
- reward_exact_format_match: float#
- reward_white_space_format_match: float#
- reward_partial_format_match: float#
- reward_ratio_guess_to_answer_high: float#
- reward_ratio_guess_to_answer_low: float#
- penalty_incorrect_format: float#
- penalty_incorrect_answer: float#
- math_verify_timeout: int#
- math_verify_num_procs: int | None#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.SpecialTokens(*, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>')[source]#
Bases: BaseModel
Special tokens used for formatting prompts and responses in RL.
- Parameters:
reasoning_start_token (str)
reasoning_end_token (str)
solution_start_token (str)
solution_end_token (str)
- reasoning_start_token: str#
- reasoning_end_token: str#
- solution_start_token: str#
- solution_end_token: str#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.Engram(*, engram_layers=<factory>, engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=<factory>, engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0)[source]#
Bases: BaseModel
Configuration for DeepSeek Engram (https://www.arxiv.org/pdf/2601.07372).
- Parameters:
engram_layers (list[int])
engram_num_heads (int)
engram_head_dim (int)
engram_vocab_bases (list[int])
engram_max_ngram_size (int)
engram_kernel_size (int)
engram_seed (int)
- engram_layers: list[int]#
- engram_num_heads: int#
- engram_head_dim: int#
- engram_vocab_bases: list[int]#
- engram_max_ngram_size: int#
- engram_kernel_size: int#
- engram_seed: int#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class maxtext.configs.types.DerivedValues(*, emb_dim=None, mlp_dim=None, moe_mlp_dim=None, num_decoder_layers=None, num_kv_heads=None, num_query_heads=None, num_diloco_replicas=None, ici_parallelism=None, dcn_parallelism=None, using_pipeline_parallelism=None, context_parallel_size=None, num_target_devices=None, global_batch_size_to_train_on=None, global_batch_size_to_eval_on=None, global_batch_size_to_load=None, global_batch_size_to_load_eval=None, micro_batch_size_to_train_on=None, micro_batch_size_to_eval_on=None, checkpoint_dir=None, convert_checkpoint_if_possible=False, metrics_dir=None, tensorboard_dir=None, managed_mldiagnostics_dir=None, rampup_end_step=None, tensors_on_device=None, tensors_to_offload=None, global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None)[source]#
Bases: BaseModel
Holds all fields that are derived from other config values for perfect legacy compatibility.
- Parameters:
emb_dim (None | int)
mlp_dim (None | int)
moe_mlp_dim (None | int)
num_decoder_layers (None | int)
num_kv_heads (None | int)
num_query_heads (None | int)
num_diloco_replicas (None | int)
ici_parallelism (None | list[int])
dcn_parallelism (None | list[int])
using_pipeline_parallelism (None | bool)
context_parallel_size (None | int)
num_target_devices (None | int)
global_batch_size_to_train_on (None | int)
global_batch_size_to_eval_on (None | int)
global_batch_size_to_load (None | int)
global_batch_size_to_load_eval (None | int)
micro_batch_size_to_train_on (None | int)
micro_batch_size_to_eval_on (None | int)
checkpoint_dir (None | str)
convert_checkpoint_if_possible (bool)
metrics_dir (None | str)
tensorboard_dir (None | str)
managed_mldiagnostics_dir (None | str)
rampup_end_step (None | int)
tensors_on_device (None | list[str])
tensors_to_offload (None | list[str])
global_batch_size_to_load_start (None | int)
global_batch_size_to_load_increment (None | int)
rampup_samples_per_increment_to_load (None | float)
- emb_dim: None | int#
- mlp_dim: None | int#
- moe_mlp_dim: None | int#
- num_decoder_layers: None | int#
- num_kv_heads: None | int#
- num_query_heads: None | int#
- num_diloco_replicas: None | int#
- ici_parallelism: None | list[int]#
- dcn_parallelism: None | list[int]#
- using_pipeline_parallelism: None | bool#
- context_parallel_size: None | int#
- num_target_devices: None | int#
- global_batch_size_to_train_on: None | int#
- global_batch_size_to_eval_on: None | int#
- global_batch_size_to_load: None | int#
- global_batch_size_to_load_eval: None | int#
- micro_batch_size_to_train_on: None | int#
- micro_batch_size_to_eval_on: None | int#
- checkpoint_dir: None | str#
- convert_checkpoint_if_possible: bool#
- metrics_dir: None | str#
- tensorboard_dir: None | str#
- managed_mldiagnostics_dir: None | str#
- rampup_end_step: None | int#
- tensors_on_device: None | list[str]#
- tensors_to_offload: None | list[str]#
- global_batch_size_to_load_start: None | int#
- global_batch_size_to_load_increment: None | int#
- rampup_samples_per_increment_to_load: None | float#
- model_config: ClassVar[ConfigDict] = {}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- maxtext.configs.types.get_individual_scales(scale)[source]#
Choose appropriate scales for individual dimensions based on global scale.
- Parameters:
scale (int)
- Return type:
tuple[int, int, int, int]
- class maxtext.configs.types.MaxTextConfig(*, emb_dim=None, mlp_dim=None, moe_mlp_dim=None, num_decoder_layers=None, num_kv_heads=None, num_query_heads=None, num_diloco_replicas=None, ici_parallelism=None, dcn_parallelism=None, using_pipeline_parallelism=None, context_parallel_size=None, num_target_devices=None, global_batch_size_to_train_on=None, global_batch_size_to_eval_on=None, global_batch_size_to_load=None, global_batch_size_to_load_eval=None, micro_batch_size_to_train_on=None, micro_batch_size_to_eval_on=None, checkpoint_dir=None, convert_checkpoint_if_possible=False, metrics_dir=None, tensorboard_dir=None, managed_mldiagnostics_dir=None, rampup_end_step=None, tensors_on_device=None, tensors_to_offload=None, global_batch_size_to_load_start=None, global_batch_size_to_load_increment=None, rampup_samples_per_increment_to_load=None, d_model_for_audio=256, encoder_attention_heads_for_audio=4, encoder_ffn_dim_for_audio=512, encoder_layers_for_audio=2, attention_dropout_for_audio=0.0, activation_dropout_for_audio=0.0, activation_function_for_audio='gelu', num_mel_bins_for_audio=128, max_source_positions_for_audio=1500, scale_embedding_for_audio=True, n_window_for_audio=50, n_window_infer_for_audio=800, conv_chunksize_for_audio=500, downsample_hidden_size_for_audio=256, output_dim_for_audio=512, num_conv_layers_for_audio=3, max_timescale_for_audio=10000.0, max_sample_len_for_audio=10000, projector_input_dim_for_vit=4096, projector_output_dim_for_vit=4096, pixel_shuffle_ratio_for_vit=0.5, projector_dropout_for_vit=0.0, hidden_size_for_vit=1408, intermediate_size_for_vit=5632, num_attention_heads_for_vit=16, num_channels_for_vit=3, tile_size_for_vit=336, patch_size_for_vit=14, conv_stride_for_vit=14, num_hidden_layers_for_vit=34, rope_theta_for_vit=10000, vision_output_dim_for_vit=4096, spatial_merge_size_for_vit=2, out_hidden_size_for_vit=512, temporal_patch_size_for_vit=2, num_position_embeddings_for_vit=1024, deepstack_visual_indexes_for_vit=[], 
vision_output_length=-1, use_multimodal=False, freeze_vision_encoder_params=True, freeze_audio_encoder_params=True, use_audio=False, image_size_for_vit=896, image_path='', image_placeholder='<|image|>', posemb_type_for_vit='learn', max_num_images_per_example=-1, video_path='', audio_path='', video_placeholder='<|video|>', audio_placeholder='<|audio|>', use_audio_in_video=False, use_mrope=False, mrope_section=[24, 20, 20], position_id_per_seconds=25, managed_mldiagnostics=False, managed_mldiagnostics_run_group='', enable_tensorboard=True, use_vertex_tensorboard=False, vertex_tensorboard_project='', vertex_tensorboard_region='', report_heartbeat_metric_for_gcp_monitoring=False, heartbeat_reporting_interval_in_seconds=5, report_performance_metric_for_gcp_monitoring=False, enable_goodput_recording=False, monitor_goodput=False, goodput_upload_interval_seconds=30, enable_pathways_goodput=False, monitor_step_time_deviation=True, step_deviation_interval_seconds=30, enable_gcp_goodput_metrics=True, enable_gcp_step_deviation_metrics=True, metrics_file=None, gcs_metrics=False, save_config_to_gcs=False, record_internal_nn_metrics=0, prometheus_port=0, enable_checkpoint_cloud_logger=False, enable_tunix_perf_metrics=False, collect_stack_trace=False, stack_trace_to_cloud=False, stack_trace_interval_seconds=600, dump_hlo=False, dump_step=-1, dump_hlo_local_dir='/tmp/xla_dump/', dump_hlo_delete_local_after=True, dump_hlo_gcs_dir='', dump_hlo_module_name='jit_train_step', dump_hlo_local_module_name='jit_train_step', dump_hlo_xla_flags='', dump_hlo_upload_all=False, dump_jaxpr=False, dump_jaxpr_local_dir='/tmp/jaxpr_dump/', dump_jaxpr_delete_local_after=True, dump_jaxpr_gcs_dir='', profiler=ProfilerType.NONE, upload_all_profiler_results=False, skip_first_n_steps_for_profiler=1, profiler_steps=5, profile_cleanly=True, profile_periodically_period=-1, hide_profiler_step_metric=False, enable_jax_profiler=False, jax_profiler_port=9999, enable_tpu_profiling_options=False, 
tpu_num_chips_to_profile_per_task=1, tpu_num_sparse_cores_to_trace=2, tpu_num_sparse_core_tiles_to_trace=1, xprof_tpu_power_trace_level=XProfTPUPowerTraceMode.POWER_TRACE_NONE, xprof_e2e_enable_fw_throttle_event=False, xprof_e2e_enable_fw_power_level_event=False, xprof_e2e_enable_fw_thermal_event=False, profile_power_events=False, constant_bound_config=[], jax_cache_dir='/home/docs/jax_cache', jax_distributed_initialization_timeout=300, jax_debug_log_modules='', skip_jax_distributed_system=False, enable_single_controller=False, subslice_shape='', max_checkify=False, compiled_trainstep_file='', compile_topology='', compile_topology_num_slices=-1, enable_prefix_caching=False, prefix_caching_hbm_byte=10000000000, prefix_caching_dram_byte=100000000000, inference_microbenchmark_prefill_lengths='64,128,256,512,1024', inference_microbenchmark_stages='prefill,generate', inference_microbenchmark_loop_iters=10, inference_microbenchmark_log_file_path='', inference_microbenchmark_num_samples=[1, 2, 3, 4, 5], inference_metadata_file='', inference_benchmark_test=False, inference_server='MaxtextInterleavedServer', prefill_slice='v5e-16', generate_slice='v5e-16', stack_prefill_result_cache=False, prefill_cache_axis_order='1,2,0,3', ar_cache_axis_order='1,2,0,3', compute_axis_order='0,1,2,3', reshape_q=False, decode_sampling_strategy=SamplingStrategy.GREEDY, decode_sampling_nucleus_p=-1.0, decode_sampling_top_k=0, decode_sampling_temperature=1.0, max_target_length=2048, max_prefill_predict_length=64, prompt='I love to', load_from_prefill_dir=False, prefill_cache_dir='', autoregressive_decode_assert='', model_call_mode='', use_chunked_prefill=False, prefill_chunk_size=256, enable_model_warmup=False, enable_llm_inference_pool=False, multi_sampling=False, return_log_prob=False, vocab_size=32000, tokenizer_path=None, tokenizer_type=TokenizerType.SENTENCEPIECE, use_chat_template=False, chat_template_path='', chat_template='', tokenize_train_data=True, tokenize_eval_data=True, 
add_bos=True, add_eos=True, use_truncation=True, num_vocab_tiling=1, olmo_index_path='', olmo_path_remap_from='', olmo_path_remap_to='', olmo_apply_ngram_filter=True, grain_train_files='', grain_eval_files='', grain_train_mixture_config_path='', grain_file_type='arrayrecord', grain_use_elastic_iterator=False, grain_worker_count=1, grain_per_worker_buffer_size=1, grain_worker_count_eval=1, grain_per_worker_buffer_size_eval=1, grain_ram_budget_mb=1024, grain_num_threads=16, grain_prefetch_buffer_size=500, grain_num_threads_eval=16, grain_prefetch_buffer_size_eval=500, grain_data_source_max_workers=16, grain_shuffle_buffer_size=100, hf_path='', hf_name=None, hf_data_dir=None, hf_train_files=None, hf_eval_split=None, hf_eval_files=None, hf_access_token=None, dataset_path='', dataset_name='c4/en:3.0.1', eval_dataset_name='c4/en:3.0.1', train_split='train', eval_split='validation', dataset_type=DatasetType.TFDS, per_device_batch_size=12, eval_per_device_batch_size=0.0, max_corpus_chars=10000000, train_data_columns=['text'], train_image_column='image', eval_data_columns=['text'], eval_image_column='image', packing=True, grain_packing_type='first_fit', max_segments_per_seq=-1, num_epoch=1, expansion_factor_real_data=-1.0, reuse_example_batch=0, generate_padding_batch_train=False, generate_padding_batch_eval=False, enable_rampup_batch_size=False, per_device_batch_size_start=4.0, per_device_batch_size_increment=2.0, global_rampup_samples=500, colocated_python_data_input=False, max_position_embeddings=163840, original_max_position_embeddings=4096, rope_factor=40, beta_fast=32, beta_slow=1, mscale=1.0, rope_interleave=True, rope_truncate=True, rope_attention_scaling=False, rope_type=RopeType.DEFAULT, rope_use_scale=True, rope_min_timescale=1, rope_max_timescale=10000, rope_linear_scaling_factor=1.0, local_rope_max_timescale=-1, global_rope_max_timescale=-1, global_rope_proportion=0.25, local_rope_proportion=1.0, use_iota_embed=False, use_untrainable_positional_embedding=False, 
trainable_position_size=-1, nope_layer_interval=-1, reasoning_start_token='<reasoning>', reasoning_end_token='</reasoning>', solution_start_token='<answer>', solution_end_token='</answer>', reward_exact_answer=5.0, reward_exact_format_match=3.0, reward_white_space_format_match=1.5, reward_partial_format_match=0.5, reward_ratio_guess_to_answer_high=0.5, reward_ratio_guess_to_answer_low=0.25, penalty_incorrect_format=-0.5, penalty_incorrect_answer=-1.0, math_verify_timeout=300, math_verify_num_procs=None, eval_sampling_strategy='greedy', generation_configs=<factory>, num_eval_passes=1, eval_corr_lst=False, eval_make_lst=False, eval_mode='pass', batch_size=1, num_batches=4, num_test_batches=5, test_batch_start_index=0, train_fraction=1.0, train_micro_batch_size=-1, rollout_micro_batch_size=-1, num_generations=2, num_iterations=1, grpo_beta=0.08, grpo_epsilon=0.2, loss_algo='grpo', use_agentic_rollout=False, max_concurrency=256, off_policy_steps=0, system_prompt='', degenerate_group_masking=True, epsilon_high=None, reshard_chunk_size=None, kv_cache_buffer=256, hbm_utilization_vllm=0.72, swap_space_vllm_gb=2, enable_dp_attention=False, enable_expert_parallel=False, async_scheduling=False, max_num_batched_tokens=None, max_num_seqs=None, stop_strings=None, vllm_additional_config=<factory>, vllm_hf_overrides=<factory>, vllm_hf_config_path='', use_standalone_converter=False, vllm_load_format='dummy', debug_converter=False, gcs_debug_path='', trainer_devices_fraction=0.5, sampler_devices_fraction=0.5, chips_per_vm=4, use_pathways=True, num_trainer_slices=-1, num_samplers_slices=-1, rollout_data_parallelism=-1, rollout_tensor_parallelism=-1, rollout_expert_parallelism=1, student_overrides=<factory>, teacher_overrides=<factory>, offline_data_dir=None, distill_alpha=0.5, distill_temperature=1.0, distill_beta=0.0, distill_feature_loss_type='cosine', distill_layer_indices=None, distill_alpha_end=None, distill_alpha_schedule='constant', distill_temperature_end=None, 
distill_temperature_schedule='constant', distill_beta_end=None, distill_beta_schedule='constant', learn_to_init_mode=False, lti_use_general_linear_map=False, distill_weights_copy_map=<factory>, distill_student_weights_share_map=<factory>, student_params_to_update=None, use_dpo=False, dpo_label_smoothing=0.0, dpo_beta=0.1, use_sft=False, sft_train_on_completion_only=False, formatting_func_path='', formatting_func_kwargs=<factory>, use_grpo=None, muon_beta=0.95, muon_weight_decay=0, muon_consistent_rms=None, adam_b1=0.9, adam_b2=0.95, adam_eps=1e-08, adam_eps_root=0.0, adam_weight_decay=0.1, adamw_mask=<factory>, mu_dtype='', opt_type=OptimizerType.ADAMW, skip_step_on_spikes=False, skip_step_interval=128, skip_step_scaling_factor=6.0, gradient_accumulation_steps=1, use_tunix_gradient_accumulation=False, gradient_clipping_threshold=1.0, learning_rate=3e-05, lr_schedule_type=LearningRateScheduleType.COSINE, learning_rate_final_fraction=0.1, wsd_decay_steps_fraction=0.1, wsd_decay_style=WsdDecayStyle.LINEAR, warmup_steps_fraction=0.1, learning_rate_schedule_steps=-1, trainable_parameters_mask=<factory>, enable_diloco=False, diloco_sync_period=36, diloco_outer_lr=0.3, diloco_outer_momentum=0.9, mhc_expansion_rate=1, sinkhorn_iterations=20, steps=150001, log_period=100, eval_interval=-1, eval_steps=-1, target_eval_loss=0.0, abort_on_nan_loss=True, abort_on_inf_loss=True, enable_dropout=True, dropout_rate=0.0, enable_data_shuffling=True, data_shuffle_seed=0, init_weights_seed=0, remat_policy='full', remat_policy_for_vit='minimal', decoder_layer_input=RematLocation.DEVICE, context=RematLocation.REMAT, mlpwi=RematLocation.REMAT, mlpwi_0=RematLocation.REMAT, mlpwi_1=RematLocation.REMAT, mlpwo=RematLocation.REMAT, moe_mlpwi_0=RematLocation.REMAT, moe_mlpwi_1=RematLocation.REMAT, moe_mlpwo=RematLocation.REMAT, query_proj=RematLocation.REMAT, key_proj=RematLocation.REMAT, value_proj=RematLocation.REMAT, query_wa_proj=RematLocation.REMAT, kv_wa_proj=RematLocation.REMAT, 
qkv_proj=RematLocation.REMAT, out_proj=RematLocation.REMAT, mla_q=RematLocation.REMAT, mla_kv=RematLocation.REMAT, attention_out=RematLocation.REMAT, engram=RematLocation.REMAT, optimizer_memory_host_offload=False, parameter_memory_host_offload=False, pipeline_fsdp_ag_per_repeat=False, num_layers_per_pipeline_stage=1, num_pipeline_repeats=-1, pipeline_parallel_layers=-1, num_pipeline_microbatches=-1, pipeline_delay_activation_forwarding=False, pipeline_fsdp_ag_once=False, scan_pipeline_iterations=True, scan_pipeline_repeats=True, scan_layers_per_stage=False, set_remat_policy_on_pipeline_iterations=True, set_remat_policy_on_layers_per_stage=False, ici_diloco_parallelism=1, ici_data_parallelism=1, ici_fsdp_parallelism=-1, ici_fsdp_transpose_parallelism=1, ici_sequence_parallelism=1, ici_context_parallelism=1, ici_context_autoregressive_parallelism=1, ici_tensor_parallelism=1, ici_tensor_transpose_parallelism=1, ici_tensor_sequence_parallelism=1, ici_autoregressive_parallelism=1, ici_pipeline_parallelism=1, ici_expert_parallelism=1, dcn_diloco_parallelism=1, dcn_data_parallelism=-1, dcn_fsdp_parallelism=1, dcn_fsdp_transpose_parallelism=1, dcn_sequence_parallelism=1, dcn_context_parallelism=1, dcn_context_autoregressive_parallelism=1, dcn_tensor_parallelism=1, dcn_tensor_transpose_parallelism=1, dcn_tensor_sequence_parallelism=1, dcn_pipeline_parallelism=1, dcn_expert_parallelism=1, dcn_autoregressive_parallelism=1, logical_axis_rules=[], data_sharding=[], context_sharding='context', input_data_sharding_logical_axes=['activation_embed_and_logits_batch', 'activation_norm_length'], sharding_tolerance=0.02, shard_optimizer_over_data=False, internal_compile=False, internal_compile_num_devices=-1, compile_xla_flags='', hardware='tpu', num_slices=-1, mesh_axes=['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'], shard_mode='auto', 
inhomogeneous_layer_cycle_interval=1, scan_layers=True, param_scan_axis=1, context_parallel_load_balance=True, context_parallel_strategy='all_gather', context_parallel_reorder_strategy=ReorderStrategy.AUTO, custom_mesh='', custom_mesh_and_rule=CustomRule.DEFAULT, allow_split_physical_axes=False, enable_nnx=False, optimize_mesh_for_tpu_v6e=False, shardy=True, pure_nnx_decoder=False, pure_nnx=False, remove_size_one_mesh_axis_from_type=True, gdn_conv_kernel_dim=4, gdn_key_head_dim=128, gdn_value_head_dim=128, gdn_num_key_heads=16, gdn_num_value_heads=32, gdn_chunk_size=64, use_qk_norm_in_gdn=True, partial_rotary_factor=1.0, first_num_dense_layers=0, shared_experts=0, routed_scaling_factor=1.0, routed_score_func='', routed_bias=False, routed_bias_update_rate=0.0, mlp_bias=False, n_routing_groups=-1, topk_routing_group=-1, use_batch_split_schedule=False, batch_split_factor=1, megablox=True, sparse_matmul=True, wi_tile_fwd_batch_seq=512, wi_tile_fwd_embed_dim=1024, wi_tile_fwd_mlp_dim=1024, wi_tile_dlhs_batch_seq=512, wi_tile_dlhs_embed_dim=1024, wi_tile_dlhs_mlp_dim=1024, wi_tile_drhs_batch_seq=512, wi_tile_drhs_embed_dim=1024, wi_tile_drhs_mlp_dim=1024, wo_tile_fwd_batch_seq=512, wo_tile_fwd_embed_dim=1024, wo_tile_fwd_mlp_dim=1024, wo_tile_dlhs_batch_seq=512, wo_tile_dlhs_embed_dim=1024, wo_tile_dlhs_mlp_dim=1024, wo_tile_drhs_batch_seq=512, wo_tile_drhs_embed_dim=1024, wo_tile_drhs_mlp_dim=1024, merge_gating_gmm=False, num_experts=1, num_experts_per_tok=1, capacity_factor=-1.0, ragged_buffer_factor=-1.0, moe_expert_input_dim=-1, base_moe_mlp_dim=-1, padded_base_moe_mlp_dim=None, load_balance_loss_weight=0.0, use_custom_sort_vjp=True, use_ring_of_experts=False, use_gather_mosaic_kernel=False, use_random_routing=False, interleave_moe_layer_step=1, moe_fsdp_use_two_stage_all_gather=False, shard_exp_on_fsdp=False, use_2d_fsdp_sharding=False, norm_topk_prob=False, float32_weight_sum=True, float32_gate_logits=False, prefuse_moe_weights=False, pagedattn_num_pages=64, 
pagedattn_tokens_per_page=32, pagedattn_pages_per_compute_block=4, pagedattn_max_pages_per_group=-1, pagedattn_head_dim_alignment=128, sa_block_q=512, sa_block_kv=512, sa_block_kv_compute=512, sa_block_q_dkv=512, sa_block_kv_dkv=512, sa_block_kv_dkv_compute=512, sa_block_q_dq=512, sa_block_kv_dq=512, sa_use_fused_bwd_kernel=False, sa_q_layout='HEAD_DIM_MINOR', sa_k_layout='HEAD_DIM_MINOR', sa_v_layout='HEAD_DIM_MINOR', use_max_logit_estimate=-1, cost_estimate_flops_fwd=-1, cost_estimate_flops_bwd=-1, dq_reduction_steps=0, use_splash_scheduler=False, use_qk_norm=False, temperature_tuning=False, use_indexer=False, indexer_head_dim=128, indexer_n_heads=64, indexer_topk=2048, indexer_sparse_training=False, indexer_loss_scaling_factor=0.0, moba=False, moba_chunk_size=1024, moba_topk=8, mla_naive_kvcache=True, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, attention='autoselected', attention_type='global', share_kv_projections=False, global_num_kv_heads=0, attention_sink=False, float32_qk_product=False, float32_logits=False, sliding_window_size=0, chunk_attn_window_size=0, attn_logits_soft_cap=None, use_post_attn_norm=False, use_post_ffw_norm=False, use_ragged_attention=False, use_tokamax_gmm=False, ragged_block_size=256, enable_padding_causal_mask=True, use_tokamax_splash=False, use_jax_splash=False, force_q_layout=False, use_qk_clip=False, qk_clip_threshold=100.0, logits_via_embedding=False, normalize_embedding_logits=True, logits_dot_in_fp32=False, cast_logits_to_fp32=True, final_logits_soft_cap=None, z_loss_multiplier=0.0, mtp_num_layers=0, mtp_loss_scaling_factor=0.1, mtp_eval_target_module=0, engram_layers=<factory>, engram_num_heads=8, engram_head_dim=1280, engram_vocab_bases=<factory>, engram_max_ngram_size=3, engram_kernel_size=4, engram_seed=0, decoder_block='llama2', global_parameter_scale=1, base_emb_dim=2048, base_num_query_heads=16, base_num_kv_heads=16, base_mlp_dim=7168, dense_init_scale=1.0, 
base_num_decoder_layers=16, head_dim=128, attention_output_dim=-1, global_head_dim=0, mlp_activations=['silu', 'linear'], mlp_activations_limit=-1.0, normalization_layer_epsilon=1e-05, fused_qkv=False, attention_bias=False, fused_mlp=False, qk_norm_with_scale=True, v_norm_with_scale=True, quantization=QuantizationType.NONE, replicate_quant_scale=False, quant_cfg_path='', quantize_kvcache=False, kv_quant_axis=KvQuantAxis.HEADS_AND_DKV, kv_quant_dtype='int8', quantization_local_shard_count=-1, use_qwix_quantization=False, use_manual_quantization=False, weight_quantization_calibration_method='absmax', act_quantization_calibration_method='absmax', bwd_quantization_calibration_method='absmax', weight_sparsity_n=None, weight_sparsity_m=None, weight_sparsity_update_step=10, weight_sparsity_start_step=50, dtype=DType.BFLOAT16, grad_dtype=DType.FLOAT32, weight_dtype=DType.FLOAT32, matmul_precision=MatmulPrecision.DEFAULT, activations_in_float32=False, dtype_mm='float32', elastic_enabled=False, elastic_timeout_seconds=300, elastic_max_retries=10, elastic_min_slice_count=-1, enable_multi_tier_checkpointing=False, local_checkpoint_directory='', local_checkpoint_period=0, multi_tier_checkpointing_backup_interval_minutes=0, mtc_data_parallelism=0, enable_emergency_checkpoint=False, use_replicator_service=False, replicator_backup_interval_minutes=0, checkpoint_storage_target_data_file_size_bytes=2147483648, checkpoint_storage_use_ocdbt=True, checkpoint_storage_use_zarr3=True, checkpoint_storage_concurrent_gb=96, load_parameters_path='', lora_input_adapters_path='', load_full_state_path='', enable_checkpointing=True, load_checkpoint_only_once=False, async_checkpointing=True, checkpoint_period=10000, max_num_checkpoints_to_keep=None, enable_single_replica_ckpt_restoring=False, checkpoint_todelete_subdir=None, checkpoint_todelete_full_path=None, force_unroll=False, checkpoint_is_quantized=False, save_quantized_params_path='', enable_orbax_v1=False, checkpoint_conversion_fn=None, 
source_checkpoint_layout='orbax', save_checkpoint_on_completion=True, enable_continuous_checkpointing=False, colocated_python_checkpointing=False, enable_autocheckpoint=False, base_config=None, run_name='', model_name='default', override_model_config=False, override_logical_axis_rules=False, log_config=True, debug_sharding=False, base_output_directory='', sharding_strategy=None, debug=<factory>, rl=<factory>)[source]#
Bases:
RunInfo,Checkpointing,OrbaxStorage,EmergencyCheckpointing,ElasticTraining,DataTypes,Quantization,ModelArchitecture,Engram,MTP,Logits,Attention,MlaAttention,MoBa,AttentionIndexer,Llama4Attention,SplashAttention,PagedAttention,MoEGeneral,MoEKernels,DeepSeekMoE,Qwen3Next,HardwareAndMesh,LayoutAndSharding,DcnParallelism,IciParallelism,PipelineParallelism,RematAndOffload,TrainingLoop,ManifoldConstrainedHyperConnections,DilocoParams,Optimizer,AdamW,Muon,FineTuning,Distillation,RLHardware,VLLM,RL,RLDataset,RLEvaluation,Reward,SpecialTokens,PositionalEmbedding,Rope,YarnRope,DatasetGeneral,TfdsDataset,HfDataset,GrainDataset,OlmoGrainDataset,Tokenizer,InferenceGeneral,Decoding,InferenceLayout,InferenceServer,InferenceBenchmark,PrefixCaching,AOT,DevelopmentAndDebugging,Profiling,HloDump,StackTrace,Metrics,Goodput,GcpMonitoring,Tensorboard,ManagedMLDiagnostics,MultimodalGeneral,VisionTower,VisionProjector,AudioEncoder,DerivedValues
The main configuration object for MaxText.
This class aggregates all configuration options from modular BaseModel classes into a single, validated object. It is populated by the `initialize` function. Every field is explicitly defined to prevent misconfigurations (`extra='forbid'`).
- Parameters:
emb_dim (None | int)
mlp_dim (None | int)
moe_mlp_dim (None | int)
num_decoder_layers (None | int)
num_kv_heads (None | int)
num_query_heads (None | int)
num_diloco_replicas (None | int)
ici_parallelism (None | list[int])
dcn_parallelism (None | list[int])
using_pipeline_parallelism (None | bool)
context_parallel_size (None | int)
num_target_devices (None | int)
global_batch_size_to_train_on (None | int)
global_batch_size_to_eval_on (None | int)
global_batch_size_to_load (None | int)
global_batch_size_to_load_eval (None | int)
micro_batch_size_to_train_on (None | int)
micro_batch_size_to_eval_on (None | int)
checkpoint_dir (None | str)
convert_checkpoint_if_possible (bool)
metrics_dir (None | str)
tensorboard_dir (None | str)
managed_mldiagnostics_dir (None | str)
rampup_end_step (None | int)
tensors_on_device (None | list[str])
tensors_to_offload (None | list[str])
global_batch_size_to_load_start (None | int)
global_batch_size_to_load_increment (None | int)
rampup_samples_per_increment_to_load (None | float)
d_model_for_audio (int)
encoder_attention_heads_for_audio (int)
encoder_ffn_dim_for_audio (int)
encoder_layers_for_audio (int)
attention_dropout_for_audio (float)
activation_dropout_for_audio (float)
activation_function_for_audio (str)
num_mel_bins_for_audio (int)
max_source_positions_for_audio (int)
scale_embedding_for_audio (bool)
n_window_for_audio (int)
n_window_infer_for_audio (int)
conv_chunksize_for_audio (int)
downsample_hidden_size_for_audio (int)
output_dim_for_audio (int)
num_conv_layers_for_audio (int)
max_timescale_for_audio (float)
max_sample_len_for_audio (int)
projector_input_dim_for_vit (int)
projector_output_dim_for_vit (int)
pixel_shuffle_ratio_for_vit (float)
projector_dropout_for_vit (float)
hidden_size_for_vit (int)
intermediate_size_for_vit (int)
num_attention_heads_for_vit (int)
num_channels_for_vit (int)
tile_size_for_vit (int)
patch_size_for_vit (int)
conv_stride_for_vit (int)
num_hidden_layers_for_vit (int)
rope_theta_for_vit (int)
vision_output_dim_for_vit (int)
spatial_merge_size_for_vit (int)
out_hidden_size_for_vit (int)
temporal_patch_size_for_vit (int)
num_position_embeddings_for_vit (int)
deepstack_visual_indexes_for_vit (list[int])
vision_output_length (int)
use_multimodal (bool)
freeze_vision_encoder_params (bool)
freeze_audio_encoder_params (bool)
use_audio (bool)
image_size_for_vit (int | list[int])
image_path (str)
image_placeholder (str)
posemb_type_for_vit (str)
max_num_images_per_example (int)
video_path (str)
audio_path (str)
video_placeholder (str)
audio_placeholder (str)
use_audio_in_video (bool)
use_mrope (bool)
mrope_section (list[int])
position_id_per_seconds (int)
managed_mldiagnostics (bool)
managed_mldiagnostics_run_group (str)
enable_tensorboard (bool)
use_vertex_tensorboard (bool)
vertex_tensorboard_project (str | None)
vertex_tensorboard_region (str | None)
report_heartbeat_metric_for_gcp_monitoring (bool)
heartbeat_reporting_interval_in_seconds (int)
report_performance_metric_for_gcp_monitoring (bool)
enable_goodput_recording (bool)
monitor_goodput (bool)
goodput_upload_interval_seconds (int)
enable_pathways_goodput (bool)
monitor_step_time_deviation (bool)
step_deviation_interval_seconds (int)
enable_gcp_goodput_metrics (bool)
enable_gcp_step_deviation_metrics (bool)
metrics_file (None | str)
gcs_metrics (bool)
save_config_to_gcs (bool)
record_internal_nn_metrics (int)
prometheus_port (int)
enable_checkpoint_cloud_logger (bool)
enable_tunix_perf_metrics (bool)
collect_stack_trace (bool)
stack_trace_to_cloud (bool)
stack_trace_interval_seconds (int)
dump_hlo (bool)
dump_step (int)
dump_hlo_local_dir (str)
dump_hlo_delete_local_after (bool)
dump_hlo_gcs_dir (str)
dump_hlo_module_name (str)
dump_hlo_local_module_name (str)
dump_hlo_xla_flags (str)
dump_hlo_upload_all (bool)
dump_jaxpr (bool)
dump_jaxpr_local_dir (str)
dump_jaxpr_delete_local_after (bool)
dump_jaxpr_gcs_dir (str)
profiler (ProfilerType)
upload_all_profiler_results (bool)
skip_first_n_steps_for_profiler (int)
profiler_steps (int)
profile_cleanly (bool)
profile_periodically_period (int)
hide_profiler_step_metric (bool)
enable_jax_profiler (bool)
jax_profiler_port (int)
enable_tpu_profiling_options (bool)
tpu_num_chips_to_profile_per_task (int)
tpu_num_sparse_cores_to_trace (int)
tpu_num_sparse_core_tiles_to_trace (int)
xprof_tpu_power_trace_level (XProfTPUPowerTraceMode)
xprof_e2e_enable_fw_throttle_event (bool)
xprof_e2e_enable_fw_power_level_event (bool)
xprof_e2e_enable_fw_thermal_event (bool)
profile_power_events (bool)
constant_bound_config (list)
jax_cache_dir (str | None)
jax_distributed_initialization_timeout (int)
jax_debug_log_modules (str)
skip_jax_distributed_system (bool)
enable_single_controller (bool)
subslice_shape (str)
max_checkify (bool)
compiled_trainstep_file (str)
compile_topology (str)
compile_topology_num_slices (int)
enable_prefix_caching (bool)
prefix_caching_hbm_byte (int)
prefix_caching_dram_byte (int)
inference_microbenchmark_prefill_lengths (str)
inference_microbenchmark_stages (str)
inference_microbenchmark_loop_iters (int)
inference_microbenchmark_log_file_path (str)
inference_microbenchmark_num_samples (list[int])
inference_metadata_file (str)
inference_benchmark_test (bool)
inference_server (str)
prefill_slice (str)
generate_slice (str)
stack_prefill_result_cache (bool)
prefill_cache_axis_order (str)
ar_cache_axis_order (str)
compute_axis_order (str)
reshape_q (bool)
decode_sampling_strategy (SamplingStrategy)
decode_sampling_nucleus_p (int | float)
decode_sampling_top_k (int)
decode_sampling_temperature (float)
max_target_length (int)
max_prefill_predict_length (int)
prompt (str)
load_from_prefill_dir (bool)
prefill_cache_dir (str)
autoregressive_decode_assert (str)
model_call_mode (str)
use_chunked_prefill (bool)
prefill_chunk_size (int)
enable_model_warmup (bool)
enable_llm_inference_pool (bool)
multi_sampling (bool)
return_log_prob (bool)
vocab_size (int)
tokenizer_path (None | str)
tokenizer_type (TokenizerType)
use_chat_template (bool)
chat_template_path (str)
chat_template (str)
tokenize_train_data (bool)
tokenize_eval_data (bool)
add_bos (bool)
add_eos (bool)
use_truncation (bool)
num_vocab_tiling (int)
olmo_index_path (str)
olmo_path_remap_from (str)
olmo_path_remap_to (str)
olmo_apply_ngram_filter (bool)
grain_train_files (str)
grain_eval_files (str)
grain_train_mixture_config_path (str)
grain_file_type (str)
grain_use_elastic_iterator (bool)
grain_worker_count (int)
grain_per_worker_buffer_size (int)
grain_worker_count_eval (int)
grain_per_worker_buffer_size_eval (int)
grain_ram_budget_mb (int)
grain_num_threads (int)
grain_prefetch_buffer_size (int)
grain_num_threads_eval (int)
grain_prefetch_buffer_size_eval (int)
grain_data_source_max_workers (int)
grain_shuffle_buffer_size (int)
hf_path (str)
hf_name (None | str)
hf_data_dir (None | str)
hf_train_files (None | str)
hf_eval_split (None | str)
hf_eval_files (None | str)
hf_access_token (None | str)
dataset_path (str)
dataset_name (str)
eval_dataset_name (str)
train_split (str)
eval_split (str)
dataset_type (DatasetType)
per_device_batch_size (int | float)
eval_per_device_batch_size (int | float)
max_corpus_chars (int)
train_data_columns (list[str])
train_image_column (str | list[str])
eval_data_columns (list[str])
eval_image_column (str | list[str])
packing (bool)
grain_packing_type (Literal['first_fit', 'best_fit', 'concat_then_split'])
max_segments_per_seq (int)
num_epoch (int)
expansion_factor_real_data (float)
reuse_example_batch (int)
generate_padding_batch_train (bool)
generate_padding_batch_eval (bool)
enable_rampup_batch_size (bool)
per_device_batch_size_start (float)
per_device_batch_size_increment (float)
global_rampup_samples (int)
colocated_python_data_input (bool)
max_position_embeddings (int)
original_max_position_embeddings (int)
rope_factor (int)
beta_fast (int)
beta_slow (int)
mscale (float)
rope_interleave (bool)
rope_truncate (bool)
rope_attention_scaling (bool)
rope_type (RopeType)
rope_use_scale (bool)
rope_min_timescale (int)
rope_max_timescale (int)
rope_linear_scaling_factor (float)
local_rope_max_timescale (int)
global_rope_max_timescale (int)
global_rope_proportion (float)
local_rope_proportion (float)
use_iota_embed (bool)
use_untrainable_positional_embedding (bool)
trainable_position_size (int)
nope_layer_interval (int)
reasoning_start_token (str)
reasoning_end_token (str)
solution_start_token (str)
solution_end_token (str)
reward_exact_answer (float)
reward_exact_format_match (float)
reward_white_space_format_match (float)
reward_partial_format_match (float)
reward_ratio_guess_to_answer_high (float)
reward_ratio_guess_to_answer_low (float)
penalty_incorrect_format (float)
penalty_incorrect_answer (float)
math_verify_timeout (int)
math_verify_num_procs (int | None)
eval_sampling_strategy (str)
generation_configs (dict[str, Any])
num_eval_passes (int)
eval_corr_lst (bool)
eval_make_lst (bool)
eval_mode (Literal['pass', 'maj', 'pass_at_1'])
batch_size (int)
num_batches (int)
num_test_batches (int)
test_batch_start_index (int)
train_fraction (float)
train_micro_batch_size (int)
rollout_micro_batch_size (int)
num_generations (int)
num_iterations (int)
grpo_beta (float)
grpo_epsilon (float)
loss_algo (Literal['grpo', 'gspo-token'])
use_agentic_rollout (bool)
max_concurrency (int)
off_policy_steps (int)
system_prompt (str)
degenerate_group_masking (bool)
epsilon_high (float | None)
reshard_chunk_size (int | None)
kv_cache_buffer (int)
hbm_utilization_vllm (float)
swap_space_vllm_gb (int)
enable_dp_attention (bool)
enable_expert_parallel (bool)
async_scheduling (bool)
max_num_batched_tokens (int | None)
max_num_seqs (int | None)
stop_strings (list[str] | None)
vllm_additional_config (dict[str, Any])
vllm_hf_overrides (dict[str, Any])
vllm_hf_config_path (str)
use_standalone_converter (bool)
vllm_load_format (str)
debug_converter (bool)
gcs_debug_path (str)
trainer_devices_fraction (float)
sampler_devices_fraction (float)
chips_per_vm (int)
use_pathways (bool)
num_trainer_slices (int)
num_samplers_slices (int)
rollout_data_parallelism (int)
rollout_tensor_parallelism (int)
rollout_expert_parallelism (int)
student_overrides (dict[str, Any])
teacher_overrides (dict[str, Any])
offline_data_dir (str | None)
distill_alpha (float)
distill_temperature (float)
distill_beta (float)
distill_feature_loss_type (Literal['cosine', 'l2'])
distill_layer_indices (None | list)
distill_alpha_end (float | None)
distill_alpha_schedule (Literal['constant', 'linear', 'cosine'])
distill_temperature_end (float | None)
distill_temperature_schedule (Literal['constant', 'linear', 'cosine'])
distill_beta_end (float | None)
distill_beta_schedule (Literal['constant', 'linear', 'cosine'])
learn_to_init_mode (bool)
lti_use_general_linear_map (bool)
distill_weights_copy_map (dict[str, Any])
distill_student_weights_share_map (dict[str, Any])
student_params_to_update (None | list)
use_dpo (bool)
dpo_label_smoothing (Annotated[float, Ge(ge=0.0), Le(le=1.0)])
dpo_beta (float)
use_sft (bool)
sft_train_on_completion_only (bool)
formatting_func_path (str)
formatting_func_kwargs (dict)
use_grpo (None | bool)
muon_beta (float)
muon_weight_decay (float)
muon_consistent_rms (float | None)
adam_b1 (float)
adam_b2 (float)
adam_eps (float)
adam_eps_root (float)
adam_weight_decay (float)
adamw_mask (list[str])
mu_dtype (str)
opt_type (OptimizerType)
skip_step_on_spikes (bool)
skip_step_interval (Annotated[int, Gt(gt=0)])
skip_step_scaling_factor (float)
gradient_accumulation_steps (Annotated[int, Gt(gt=0)])
use_tunix_gradient_accumulation (bool)
gradient_clipping_threshold (Annotated[float, Ge(ge=0)])
learning_rate (Annotated[float, Ge(ge=0)])
lr_schedule_type (LearningRateScheduleType)
learning_rate_final_fraction (float)
wsd_decay_steps_fraction (Annotated[float, Ge(ge=0.0), Le(le=1.0)])
wsd_decay_style (WsdDecayStyle)
warmup_steps_fraction (Annotated[float, Ge(ge=0.0), Le(le=1.0)])
learning_rate_schedule_steps (Annotated[int, Ge(ge=-1)])
trainable_parameters_mask (list[str])
enable_diloco (bool)
diloco_sync_period (int)
diloco_outer_lr (float)
diloco_outer_momentum (float)
mhc_expansion_rate (Annotated[int, Gt(gt=0)])
sinkhorn_iterations (Annotated[int, Gt(gt=0)])
steps (Annotated[int, Ge(ge=-1)])
log_period (int)
eval_interval (int)
eval_steps (int)
target_eval_loss (float)
abort_on_nan_loss (bool)
abort_on_inf_loss (bool)
enable_dropout (bool)
dropout_rate (Annotated[float, Ge(ge=0.0), Le(le=1.0)])
enable_data_shuffling (bool)
data_shuffle_seed (int)
init_weights_seed (int)
remat_policy (str)
remat_policy_for_vit (str)
decoder_layer_input (RematLocation)
context (RematLocation)
mlpwi (RematLocation)
mlpwi_0 (RematLocation)
mlpwi_1 (RematLocation)
mlpwo (RematLocation)
moe_mlpwi_0 (RematLocation)
moe_mlpwi_1 (RematLocation)
moe_mlpwo (RematLocation)
query_proj (RematLocation)
key_proj (RematLocation)
value_proj (RematLocation)
query_wa_proj (RematLocation)
kv_wa_proj (RematLocation)
qkv_proj (RematLocation)
out_proj (RematLocation)
mla_q (RematLocation)
mla_kv (RematLocation)
attention_out (RematLocation)
engram (RematLocation)
optimizer_memory_host_offload (bool)
parameter_memory_host_offload (bool)
pipeline_fsdp_ag_per_repeat (bool)
num_layers_per_pipeline_stage (int)
num_pipeline_repeats (int)
pipeline_parallel_layers (int)
num_pipeline_microbatches (int)
pipeline_delay_activation_forwarding (bool)
pipeline_fsdp_ag_once (bool)
scan_pipeline_iterations (bool)
scan_pipeline_repeats (bool)
scan_layers_per_stage (bool)
set_remat_policy_on_pipeline_iterations (bool)
set_remat_policy_on_layers_per_stage (bool)
ici_diloco_parallelism (int)
ici_data_parallelism (int)
ici_fsdp_parallelism (int)
ici_fsdp_transpose_parallelism (int)
ici_sequence_parallelism (int)
ici_context_parallelism (int)
ici_context_autoregressive_parallelism (int)
ici_tensor_parallelism (int)
ici_tensor_transpose_parallelism (int)
ici_tensor_sequence_parallelism (int)
ici_autoregressive_parallelism (int)
ici_pipeline_parallelism (int)
ici_expert_parallelism (int)
dcn_diloco_parallelism (int)
dcn_data_parallelism (int)
dcn_fsdp_parallelism (int)
dcn_fsdp_transpose_parallelism (int)
dcn_sequence_parallelism (int)
dcn_context_parallelism (int)
dcn_context_autoregressive_parallelism (int)
dcn_tensor_parallelism (int)
dcn_tensor_transpose_parallelism (int)
dcn_tensor_sequence_parallelism (int)
dcn_pipeline_parallelism (int)
dcn_expert_parallelism (int)
dcn_autoregressive_parallelism (int)
logical_axis_rules (Any)
data_sharding (Any)
context_sharding (str)
input_data_sharding_logical_axes (list[str])
sharding_tolerance (Annotated[float, Ge(ge=0.0), Le(le=1.0)])
shard_optimizer_over_data (bool)
internal_compile (bool)
internal_compile_num_devices (int)
compile_xla_flags (str)
hardware (Literal['tpu', 'gpu', 'gpu_multiprocess', 'cpu'])
num_slices (int)
mesh_axes (list[str])
shard_mode (ShardMode)
inhomogeneous_layer_cycle_interval (int)
scan_layers (bool)
param_scan_axis (int)
context_parallel_load_balance (bool)
context_parallel_strategy (str)
context_parallel_reorder_strategy (ReorderStrategy)
custom_mesh (str)
custom_mesh_and_rule (CustomRule)
allow_split_physical_axes (bool)
enable_nnx (bool)
optimize_mesh_for_tpu_v6e (bool)
shardy (bool)
pure_nnx_decoder (bool)
pure_nnx (bool)
remove_size_one_mesh_axis_from_type (bool)
gdn_conv_kernel_dim (int)
gdn_key_head_dim (int)
gdn_value_head_dim (int)
gdn_num_key_heads (int)
gdn_num_value_heads (int)
gdn_chunk_size (int)
use_qk_norm_in_gdn (bool)
partial_rotary_factor (float)
first_num_dense_layers (Annotated[int, Ge(ge=0)])
shared_experts (Annotated[int, Ge(ge=0)])
routed_scaling_factor (float)
routed_score_func (str)
routed_bias (bool)
routed_bias_update_rate (float)
mlp_bias (bool)
n_routing_groups (int)
topk_routing_group (int)
use_batch_split_schedule (bool)
batch_split_factor (int)
megablox (bool)
sparse_matmul (bool)
wi_tile_fwd_batch_seq (int)
wi_tile_fwd_embed_dim (int)
wi_tile_fwd_mlp_dim (int)
wi_tile_dlhs_batch_seq (int)
wi_tile_dlhs_embed_dim (int)
wi_tile_dlhs_mlp_dim (int)
wi_tile_drhs_batch_seq (int)
wi_tile_drhs_embed_dim (int)
wi_tile_drhs_mlp_dim (int)
wo_tile_fwd_batch_seq (int)
wo_tile_fwd_embed_dim (int)
wo_tile_fwd_mlp_dim (int)
wo_tile_dlhs_batch_seq (int)
wo_tile_dlhs_embed_dim (int)
wo_tile_dlhs_mlp_dim (int)
wo_tile_drhs_batch_seq (int)
wo_tile_drhs_embed_dim (int)
wo_tile_drhs_mlp_dim (int)
merge_gating_gmm (bool)
num_experts (Annotated[int, Gt(gt=0)])
num_experts_per_tok (Annotated[int, Gt(gt=0)])
capacity_factor (float)
ragged_buffer_factor (float)
moe_expert_input_dim (int)
base_moe_mlp_dim (int)
padded_base_moe_mlp_dim (int | None)
load_balance_loss_weight (Annotated[float, Ge(ge=0)])
use_custom_sort_vjp (bool)
use_ring_of_experts (bool)
use_gather_mosaic_kernel (bool)
use_random_routing (bool)
interleave_moe_layer_step (int)
moe_fsdp_use_two_stage_all_gather (bool)
shard_exp_on_fsdp (bool)
use_2d_fsdp_sharding (bool)
norm_topk_prob (bool)
float32_weight_sum (bool)
float32_gate_logits (bool)
prefuse_moe_weights (bool)
pagedattn_num_pages (int)
pagedattn_tokens_per_page (int)
pagedattn_pages_per_compute_block (int)
pagedattn_max_pages_per_group (int)
pagedattn_head_dim_alignment (int)
sa_block_q (int)
sa_block_kv (int)
sa_block_kv_compute (int)
sa_block_q_dkv (int)
sa_block_kv_dkv (int)
sa_block_kv_dkv_compute (int)
sa_block_q_dq (int)
sa_block_kv_dq (int)
sa_use_fused_bwd_kernel (bool)
sa_q_layout (str)
sa_k_layout (str)
sa_v_layout (str)
use_max_logit_estimate (int)
cost_estimate_flops_fwd (int)
cost_estimate_flops_bwd (int)
dq_reduction_steps (int)
use_splash_scheduler (bool)
use_qk_norm (bool)
temperature_tuning (bool)
use_indexer (bool)
indexer_head_dim (Annotated[int, Ge(ge=0)])
indexer_n_heads (Annotated[int, Ge(ge=0)])
indexer_topk (Annotated[int, Ge(ge=0)])
indexer_sparse_training (bool)
indexer_loss_scaling_factor (float)
moba (bool)
moba_chunk_size (int)
moba_topk (int)
mla_naive_kvcache (bool)
q_lora_rank (Annotated[int, Ge(ge=0)])
kv_lora_rank (Annotated[int, Ge(ge=0)])
qk_nope_head_dim (Annotated[int, Ge(ge=0)])
qk_rope_head_dim (Annotated[int, Ge(ge=0)])
v_head_dim (Annotated[int, Ge(ge=0)])
attention (str)
attention_type (Literal['global', 'local_sliding', 'chunk', 'mla', 'full'])
share_kv_projections (bool)
global_num_kv_heads (int)
attention_sink (bool)
float32_qk_product (bool)
float32_logits (bool)
sliding_window_size (Annotated[int, Ge(ge=0)])
chunk_attn_window_size (Annotated[int, Ge(ge=0)])
attn_logits_soft_cap (None | Annotated[float, Ge(ge=0)])
use_post_attn_norm (bool)
use_post_ffw_norm (bool)
use_ragged_attention (bool)
use_tokamax_gmm (bool)
ragged_block_size (int)
enable_padding_causal_mask (bool)
use_tokamax_splash (bool)
use_jax_splash (bool)
force_q_layout (bool)
use_qk_clip (bool)
qk_clip_threshold (float)
logits_via_embedding (bool)
normalize_embedding_logits (bool)
logits_dot_in_fp32 (bool)
cast_logits_to_fp32 (bool)
final_logits_soft_cap (None | Annotated[float, Ge(ge=0)])
z_loss_multiplier (float)
mtp_num_layers (Annotated[int, Ge(ge=0)])
mtp_loss_scaling_factor (Annotated[float, Ge(ge=0)])
mtp_eval_target_module (Annotated[int, Ge(ge=0)])
engram_layers (list[int])
engram_num_heads (int)
engram_head_dim (int)
engram_vocab_bases (list[int])
engram_max_ngram_size (int)
engram_kernel_size (int)
engram_seed (int)
decoder_block (DecoderBlockType)
global_parameter_scale (int)
base_emb_dim (int)
base_num_query_heads (int)
base_num_kv_heads (int)
base_mlp_dim (int)
dense_init_scale (float)
base_num_decoder_layers (int)
head_dim (int)
attention_output_dim (int)
global_head_dim (int)
mlp_activations (list[str])
mlp_activations_limit (float)
normalization_layer_epsilon (float)
fused_qkv (bool)
attention_bias (bool)
fused_mlp (bool)
qk_norm_with_scale (bool)
v_norm_with_scale (bool)
quantization (None | QuantizationType)
replicate_quant_scale (bool)
quant_cfg_path (str)
quantize_kvcache (bool)
kv_quant_axis (KvQuantAxis)
kv_quant_dtype (Literal['int8', 'int4'])
quantization_local_shard_count (int)
use_qwix_quantization (bool)
use_manual_quantization (bool)
weight_quantization_calibration_method (str)
act_quantization_calibration_method (str)
bwd_quantization_calibration_method (str)
weight_sparsity_n (int | None)
weight_sparsity_m (int | None)
weight_sparsity_update_step (int)
weight_sparsity_start_step (int)
dtype (DType)
grad_dtype (DType)
weight_dtype (DType)
matmul_precision (MatmulPrecision)
activations_in_float32 (bool)
dtype_mm (str)
elastic_enabled (bool)
elastic_timeout_seconds (int)
elastic_max_retries (int)
elastic_min_slice_count (int)
enable_multi_tier_checkpointing (bool)
local_checkpoint_directory (str)
local_checkpoint_period (Annotated[int, Ge(ge=0)])
multi_tier_checkpointing_backup_interval_minutes (Annotated[int, Ge(ge=0)])
mtc_data_parallelism (int)
enable_emergency_checkpoint (bool)
use_replicator_service (bool)
replicator_backup_interval_minutes (Annotated[int, Ge(ge=0)])
checkpoint_storage_target_data_file_size_bytes (int)
checkpoint_storage_use_ocdbt (bool)
checkpoint_storage_use_zarr3 (bool)
checkpoint_storage_concurrent_gb (int)
load_parameters_path (str)
lora_input_adapters_path (str)
load_full_state_path (str)
enable_checkpointing (bool)
load_checkpoint_only_once (bool)
async_checkpointing (bool)
checkpoint_period (int)
max_num_checkpoints_to_keep (int | None)
enable_single_replica_ckpt_restoring (bool)
checkpoint_todelete_subdir (str | None)
checkpoint_todelete_full_path (str | None)
force_unroll (bool)
checkpoint_is_quantized (bool)
save_quantized_params_path (str)
enable_orbax_v1 (bool)
checkpoint_conversion_fn (None | str)
source_checkpoint_layout (Literal['orbax', 'safetensors'])
save_checkpoint_on_completion (bool)
enable_continuous_checkpointing (bool)
colocated_python_checkpointing (bool)
enable_autocheckpoint (bool)
base_config (None | str)
run_name (str)
model_name (Literal['default', 'llama2-7b', 'llama2-13b', 'llama2-70b', 'llama3-8b', 'llama3.1-8b-Instruct', 'llama3-70b', 'llama3.1-70b-Instruct', 'llama3.1-8b', 'llama3.1-70b', 'llama3.1-405b', 'llama3.3-70b', 'mistral-7b', 'mixtral-8x7b', 'mixtral-8x22b', 'deepseek2-16b', 'deepseek2-236b', 'deepseek3-671b', 'deepseek3-671b-2dfsdp', 'deepseek3-671b-batchsplit', 'deepseek3-test', 'deepseek3-tiny', 'deepseek3.2-671b', 'deepseek-custom', 'kimi-k2-1t', 'gemma-7b', 'gemma-2b', 'gemma2-2b', 'gemma2-9b', 'gemma2-27b', 'gemma3-4b', 'gemma3-12b', 'gemma3-27b', 'gemma4-26b', 'gemma4-31b', 'qwen2.5-1.5b', 'qwen2.5-7b', 'qwen2.5-14b', 'qwen3-0.6b', 'qwen3-1.7b', 'qwen3-1.7b-base', 'qwen3-4b', 'qwen3-4b-base', 'qwen3-4b-thinking-2507', 'qwen3-8b', 'qwen3-8b-base', 'qwen3-14b', 'qwen3-14b-base', 'qwen3-32b', 'qwen3-235b-a22b', 'qwen3-30b-a3b', 'qwen3-30b-a3b-base', 'qwen3-480b-a35b', 'qwen3-next-80b-a3b', 'qwen3-omni-30b-a3b', 'qwen3-custom-30b-a3b', 'qwen3.5-397b-a17b', 'gpt3-175b', 'gpt3-22b', 'gpt3-6b', 'gpt3-52k', 'gpt-oss-20b', 'gpt-oss-120b', 'llama4-17b-16e', 'llama4-17b-128e', 'olmo3-7b', 'olmo3-7b-pt', 'olmo3-32b'])
override_model_config (bool)
override_logical_axis_rules (bool)
log_config (bool)
debug_sharding (bool)
base_output_directory (str)
sharding_strategy (None | Literal['experimental'])
debug (Debug)
rl (RL)
- model_config: ClassVar[ConfigDict] = {'extra': 'forbid', 'protected_namespaces': ()}#
Configuration for the model; it should be a dictionary conforming to pydantic.config.ConfigDict.
- classmethod load_model_specific_defaults(values)[source]#
This method is a no-op because pyconfig handles model-specific config loading.
- Parameters:
values (dict[str, Any])
- Return type:
dict[str, Any]