maxtext package
MaxText is a high-performance, highly scalable, open-source LLM written in pure Python/JAX that targets Google Cloud TPUs and GPUs for training and inference. MaxText achieves high Model FLOPs Utilization (MFU) and scales from a single host to very large clusters while staying simple and “optimization-free” thanks to the power of JAX and the XLA compiler.
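To make the MFU claim concrete, here is a minimal sketch of how Model FLOPs Utilization is commonly computed: achieved model FLOPs per second divided by the aggregate peak FLOPs per second of the hardware. The helper names, the `6 * N` FLOPs-per-token rule of thumb for dense decoder-only training, and every number in the example are illustrative assumptions. They are not MaxText APIs and not measured MaxText results.

```python
def approx_train_flops_per_token(num_params: float) -> float:
    """Rule-of-thumb training cost per token for a dense decoder-only model.

    Assumption: about 6 FLOPs per parameter per token (forward plus backward),
    ignoring attention FLOPs. This is a common approximation, not necessarily
    the accounting MaxText itself uses.
    """
    return 6.0 * num_params


def model_flops_utilization(
    tokens_per_second: float,
    flops_per_token: float,
    num_devices: int,
    peak_flops_per_device: float,
) -> float:
    """Return MFU as a fraction in [0, 1]: achieved FLOPs/s over peak FLOPs/s."""
    achieved_flops_per_second = tokens_per_second * flops_per_token
    peak_flops_per_second = num_devices * peak_flops_per_device
    return achieved_flops_per_second / peak_flops_per_second


# Hypothetical numbers chosen only to make the arithmetic easy to follow.
flops_per_token = approx_train_flops_per_token(num_params=1e9)
mfu = model_flops_utilization(
    tokens_per_second=100_000,
    flops_per_token=flops_per_token,
    num_devices=8,
    peak_flops_per_device=2e14,
)
print(f"MFU = {mfu:.1%}")
```

With these made-up inputs the job performs 6e14 model FLOPs per second against a 1.6e15 peak, so the sketch reports an MFU of 37.5%. In practice the per-token FLOPs and the per-device peak should come from the model architecture and the accelerator's published specifications.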
Subpackages
- maxtext.checkpoint_conversion package
- maxtext.configs package
- Submodules
- maxtext.configs.pyconfig module
- maxtext.configs.pyconfig_deprecated module
  yaml_key_to_env_key(), string_to_bool(), validate_compute_axis_order(), validate_shard_mode(), validate_kv_quant_axis(), validate_attention_kernel(), validate_attention_type(), validate_moba_attention(), validate_attention_window_params(), validate_profiler_type(), validate_periodic_profiler(), validate_model_call_mode(), validate_prefill_and_target_lengths(), validate_rope_type(), validate_expert_shard_attention_option(), validate_vocab_tiling(), validate_rampup_batch_size(), validate_context_parallel_strategy_ring(), validate_keys(), validate_tokenizer(), validate_constant_bound(), validate_quantization_methods(), validate_tokamax_usage(), validate_data_input(), validate_llama4_config(), validate_model_name(), validate_multimodal_model_name(), validate_no_keys_overwritten_twice(), validate_and_assign_remat_tensors(), resolve_config_path(), create_parallelisms_list(), set_mu_dtype(), validate_and_set_hlo_dump_defaults(), validate_multiple_slices(), set_and_validate_pipeline_config(), validate_deepseek_moe(), validate_mlp_dim(), validate_gpt_oss_moe(), validate_sparse_matmul_parallelism(), validate_ring_of_experts_parallelism(), validate_shard_expert_on_fsdp(), validate_ragged_dot(), validate_optimizer_sharding_over_data(), create_new_logical_axis_rules(), update_model_keys(), validate_and_update_keys(), get_individual_scales(), calculate_global_batch_sizes(), calculate_rampup_samples_and_steps(), get_num_target_devices(), get_quantization_local_shard_count(), get_context_parallel_size(), using_pipeline_parallelism(), using_tensor_parallelism(), using_sequence_parallelism(), using_expert_parallelism(), using_fsdp_and_transpose_parallelism(), HyperParameters, initialize()
- maxtext.configs.types module
  XProfTPUPowerTraceMode, DType, MatmulPrecision, QuantizationType, KvQuantAxis, RematPolicy, RematLocation, OptimizerType, LearningRateScheduleType, WsdDecayStyle, RopeType, TokenizerType, DatasetType, SamplingStrategy, ProfilerType, RunInfo, Checkpointing, OrbaxStorage, EmergencyCheckpointing, DataTypes, Quantization, ModelArchitecture, MTP, Logits, Attention, MoBa, MlaAttention, AttentionIndexer, Llama4Attention, SplashAttention, PagedAttention, MoEGeneral, MoEKernels, DeepSeekMoE, Qwen3Next, HardwareAndMesh, LayoutAndSharding, DcnParallelism, IciParallelism, PipelineParallelism, RematAndOffload, Tokenizer, DatasetGeneral, TfdsDataset, HfDataset, GrainDataset, OlmoGrainDataset, FineTuning, Distillation, TrainingLoop, ManifoldConstrainedHyperConnections, DilocoParams, Optimizer, AdamW, Muon, PositionalEmbedding, Rope, YarnRope, InferenceGeneral, Decoding, InferenceLayout, InferenceServer, InferenceBenchmark, PrefixCaching, AOT, DevelopmentAndDebugging, Profiling, HloDump, StackTrace, Metrics, ManagedMLDiagnostics, Goodput, ElasticTraining, GcpMonitoring, Tensorboard, MultimodalGeneral, VisionTower, VisionProjector, AudioEncoder, Debug, RLHardware, VLLM, RL, RLDataset, RLEvaluation, Reward, SpecialTokens, Engram, DerivedValues, get_individual_scales(), MaxTextConfig
- Submodules
- maxtext.experimental package
- maxtext.input_pipeline package
- Subpackages
- Submodules
- maxtext.input_pipeline.data_processing_utils module
- maxtext.input_pipeline.distillation_data_processing module
- maxtext.input_pipeline.grain_data_processing module
- maxtext.input_pipeline.grain_tokenizer module
- maxtext.input_pipeline.hf_data_processing module
- maxtext.input_pipeline.input_pipeline_interface module
- maxtext.input_pipeline.input_pipeline_utils module
  normalize_features(), get_tokenizer(), truncate_to_max_allowable_length(), shift_data_by_truncation(), add_segmentation_and_position(), TokenizeOp(), reformat_prompt(), reformat_response(), merge_image_columns(), pre_process_image_sft(), prepare_text_for_image_fusion(), combine_columns(), is_conversational(), extract_token_ids(), verify_chat_template_generation_prompt_logic(), apply_chat_template(), tokenization(), SFTPromptMasking, SFTPromptMaskingVision, HFNormalizeFeatures, HFDataSource, GCSTFRecordIterDataset, make_tfrecord_iter_dataset(), ParseFeatures, NormalizeFeatures, KeepFeatures, Rekey, ReformatPacking, PadOrTrimToMaxLength, ExtractImagesAndMasks, FoldImagesIntoBatch, shift_right(), shift_left(), shift_and_refine(), ShiftData, ComputeQwen3OmniPositions
- maxtext.input_pipeline.instruction_data_processing module
- maxtext.input_pipeline.multihost_dataloading module
- maxtext.input_pipeline.olmo_data module
- maxtext.input_pipeline.olmo_data_grain module
- maxtext.input_pipeline.olmo_grain_data_processing module
- maxtext.input_pipeline.synthetic_data_processing module
- maxtext.input_pipeline.tfds_data_processing module
- maxtext.input_pipeline.tfds_data_processing_c4_mlperf module
- maxtext.input_pipeline.tokenizer module
- maxtext.integration package
- maxtext.kernels package
- maxtext.layers package
- Submodules
- maxtext.layers.attention_mla module
- maxtext.layers.attention_op module
- maxtext.layers.attentions module
- maxtext.layers.decoders module
- maxtext.layers.embeddings module
  embed_as_linen(), Embed, attend_on_embedding(), rotary_embedding_as_linen(), RotaryEmbedding, llama_rotary_embedding_as_linen(), partial_rotary_embedding_as_linen(), PartialRotaryEmbedding, Gemma4PartialRotaryEmbedding, LLaMARotaryEmbedding, yarn_rotary_embedding_as_linen(), YarnRotaryEmbedding, positional_embedding_as_linen(), PositionalEmbedding, llama_vision_rotary_embedding_as_linen(), LlamaVisionRotaryEmbedding, Qwen3OmniMoeVisionRotaryEmbedding, qwen3omnimoe_vision_pos_embed_interpolate_as_linen(), Qwen3OmniMoeVisionPosEmbedInterpolate, Qwen3OmniMoeThinkerTextRotaryEmbedding, qwen3_omni_mrope_embedding_as_linen()
- maxtext.layers.encoders module
- maxtext.layers.engram module
- maxtext.layers.initializers module
- maxtext.layers.learn_to_init_layer module
- maxtext.layers.linears module
- maxtext.layers.mhc module
- maxtext.layers.moe module
- maxtext.layers.multi_token_prediction module
- maxtext.layers.nnx_decoders module
- maxtext.layers.nnx_wrappers module
- maxtext.layers.normalizations module
- maxtext.layers.pipeline module
- maxtext.layers.pipeline_deprecated module
- maxtext.layers.quantizations module
  Quantization, AqtQuantization, QwixQuantization, QwixDotGeneral, QwixEinsum, Fp8Quantization, Fp8Einsum, NANOOFp8Quantization, ConstantBoundConfig, PerTensorScales, in_convert_mode(), in_serve_mode(), get_quant_mode(), configure_quantization(), match_aqt_and_unquantized_param(), remove_quantized_params(), configure_kv_quant(), NvidaFp8Provider, NANOOFp8Provider, get_fp8_full_qwix_rule_w_sparsity(), get_quantization_rule(), get_qt_provider(), maybe_quantize_model(), manual_quantize(), TransformerEngineQuantization
- maxtext.layers.train_state_nnx module
- Submodules
- maxtext.models package
- Submodules
- maxtext.models.deepseek module
- maxtext.models.deepseek_batchsplit module
  scheduling_group(), fetch_weights(), split(), merge(), extract_layer_weights(), insert_layer_ws_grad(), gather_weights(), reduce_scatter_ws_grad(), all_reduce_ws_grad_dcn(), init_splash_kernel(), tpu_flash_attention(), tpu_flash_attention_bwd(), scan_batch_split_layers(), batch_split_schedule(), batch_split_schedule_bwd(), staggered_call(), dot(), mla_with_norms(), mla_with_norms_remat(), mla_with_norms_bwd(), mla(), mla_remat(), mla_bwd(), query_projection(), kv_projection(), get_key_value(), rms_norm(), initialize_yarn_mask(), initialize_yarn_freqs(), yarn(), shared_expert_and_route(), shared_expert(), expert_group_mask(), expert_indices_and_weights(), expert_selection(), route(), unroute(), route_impl_fwd(), route_impl_bwd(), unroute_impl_fwd(), unroute_impl_bwd(), gmm(), compute_gating(), compute_linear(), route_compute_unroute(), unroute_ubatch_shard_mapped(), unroute_ubatch_fn(), unroute_ubatch_remat_and_bwd_shard_mapped(), unroute_ubatch_fn_remat(), unroute_ubatch_fn_bwd(), sum_grads(), route_compute_unroute_bwd(), moe(), moe_bwd()
- maxtext.models.deepseek_batchsplit_fp8 module
  fetch_weights(), split(), merge(), gather_weights(), scan_batch_split_layers(), batch_split_schedule(), staggered_call(), with_data_parallel_constraint(), dot(), mla_with_norms(), mla(), query_projection(), kv_projection(), get_key_value(), rms_norm(), yarn(), moe(), expert_indices_and_weights(), expert_selection(), route(), unroute(), compute(), route_compute_unroute(), process_activations()
- maxtext.models.gemma module
- maxtext.models.gemma2 module
- maxtext.models.gemma3 module
- maxtext.models.gemma4 module
- maxtext.models.gemma4_vision module
- maxtext.models.gpt3 module
- maxtext.models.gpt_oss module
- maxtext.models.llama2 module
- maxtext.models.llama4 module
  Llama4UnfoldConvolution, pixel_shuffle(), Llama4VisionMLP, Llama4VisionMLP2, Llama4VisionPixelShuffleMLP, Llama4MultiModalProjector, llama4multimodalprojector_as_linen(), determine_is_nope_layer(), determine_is_moe_layer(), Llama4DecoderLayer, Llama4ScannableBlock, Llama4VisionEncoderLayer, Llama4VisionEncoder, Llama4VisionModel, llama4visionmodel_as_linen()
- maxtext.models.mistral module
- maxtext.models.mixtral module
- maxtext.models.models module
- maxtext.models.olmo3 module
- maxtext.models.qwen2 module
- maxtext.models.qwen3 module
  naive_jax_chunk_gated_delta_rule(), jax_chunk_gated_delta_rule(), Qwen3NextGatedDeltaNet, Qwen3NextFullAttention, Qwen3NextSparseMoeBlock, Qwen3NextScannableBlock, Qwen3NextDecoderLayer, AttentionWithNorm, Qwen3DecoderLayer, Qwen3MoeDecoderLayer, Qwen3OmniMoeVisionPatchMerger, Qwen3OmniMoeVisionMLP, Qwen3OmniMoeVisionPatchEmbed, Qwen3OmniMoeVisionAttention, Qwen3OmniMoeVisionBlock, Qwen3OmniMoeVisionEncoder, Qwen3OmniMoeVisionProjector, qwen3omni_visionencoder_as_linen(), qwen3omni_visionprojector_as_linen(), Qwen3OmniAudioEncoderLayer, Qwen3OmniAudioEncoder, Qwen3OmniAudioProjector, qwen3omni_audioencoder_as_linen(), qwen3omni_audioprojector_as_linen()
- maxtext.models.qwen3_5 module
- maxtext.models.qwen3_custom module
- maxtext.models.simple_layer module
- Submodules
- maxtext.multimodal package
- Submodules
- maxtext.multimodal.processor module
- maxtext.multimodal.processor_gemma3 module
- maxtext.multimodal.processor_gemma4 module
- maxtext.multimodal.processor_llama4 module
  Llama4PreprocessorOutput, get_factors(), find_supported_resolutions(), get_best_resolution(), pad_to_best_fit_jax(), pad_to_max_tiles(), split_to_tiles(), preprocess_mm_data_llama4(), get_num_tokens_for_this_image(), get_image_offsets_llama4(), reformat_prompt_llama4(), get_tokens_for_this_image(), add_extra_tokens_for_images_llama4(), get_dummy_image_shape_for_init_llama4()
- maxtext.multimodal.processor_qwen3_omni module
  Qwen3OmniPreprocessorOutput, smart_resize(), pre_process_qwen3_image(), calculate_video_frame_range(), smart_nframes(), preprocess_video(), pre_process_audio_qwen3_omni(), preprocess_mm_data_qwen3_omni(), add_extra_tokens_for_qwen3_omni(), get_dummy_image_shape_for_init_qwen3_omni(), get_dummy_audio_shape_for_init_qwen3_omni(), get_llm_pos_ids_for_vision(), get_chunked_index(), get_rope_index(), reformat_prompt_qwen3_omni(), get_mm_offsets_qwen3_omni()
- maxtext.multimodal.utils module
- Submodules