# maxtext.layers package
## Submodules
- maxtext.layers.attention_mla module
- maxtext.layers.attention_op module
  - validate_compute_axis_order()
  - apply_mask_to_logits()
  - validate_gpu_flash_attention()
  - ChunkedCausalMask
  - attention_op_as_linen()
  - AttentionOp: AttentionOp.check_attention_inputs(), AttentionOp.generate_attention_mask(), AttentionOp.calculate_moba_gate_logic(), AttentionOp.generate_moba_mask_single_item(), AttentionOp.apply_attention(), AttentionOp.gpu_ragged_attention(), AttentionOp.tpu_ragged_attention(), AttentionOp.tpu_flash_attention(), AttentionOp.cudnn_flash_attention(), AttentionOp.cudnn_jax_flash_attention(), AttentionOp.compute_local_attention(), AttentionOp.is_partition_in_decode(), AttentionOp.apply_attention_dot(), AttentionOp.qk_product(), AttentionOp.wv_product(), AttentionOp.reverse_transepose(), AttentionOp.normalize_cudnn_attention(), AttentionOp.normalize_attention()
  - LoadBalancedCausalMask
- maxtext.layers.attentions module
  - L2Norm
  - l2_norm_as_linen()
  - attention_as_linen()
  - Attention: Attention.config, Attention.num_query_heads, Attention.num_kv_heads, Attention.head_dim, Attention.max_target_length, Attention.mesh, Attention.attention_kernel, Attention.inputs_q_shape, Attention.inputs_kv_shape, Attention.dtype, Attention.weight_dtype, Attention.max_prefill_predict_length, Attention.dropout_rate, Attention.kernel_init, Attention.float32_qk_product, Attention.float32_logits, Attention.quant, Attention.kv_quant, Attention.attention_type, Attention.attn_logits_soft_cap, Attention.init_query_w(), Attention.query_projection(), Attention.init_kv_w(), Attention.kv_projection(), Attention.init_qkv_w(), Attention.qkv_projection(), Attention.out_head_dim, Attention.init_out_w(), Attention.out_projection(), Attention.convert_dense_general_inputs_shape(), Attention.init_rotary_embedding(), Attention.apply_rotary_embedding(), Attention.init_kv_caches(), Attention.update_kv_caches(), Attention.forward_serve_vllm()
- maxtext.layers.decoders module
  - DecoderLayer
  - SequentialBlockDecoderLayers: SequentialBlockDecoderLayers.decoder_layer, SequentialBlockDecoderLayers.num_decoder_layers, SequentialBlockDecoderLayers.config, SequentialBlockDecoderLayers.mesh, SequentialBlockDecoderLayers.quant, SequentialBlockDecoderLayers.model_mode, SequentialBlockDecoderLayers.name, SequentialBlockDecoderLayers.parent, SequentialBlockDecoderLayers.scope
  - deepstack_process()
  - Decoder: Decoder.config, Decoder.mesh, Decoder.quant, Decoder.model_mode, Decoder.setup(), Decoder.minimal_policy(), Decoder.get_remat_policy(), Decoder.get_decoder_layers(), Decoder.set_remat_policy(), Decoder.get_norm_layer(), Decoder.scan_decoder_layers(), Decoder.get_pipeline_stage_module(), Decoder.apply_output_head(), Decoder.name, Decoder.parent, Decoder.scope
- maxtext.layers.embeddings module
  - embed_as_linen()
  - Embed
  - attend_on_embedding()
  - rotary_embedding_as_linen()
  - RotaryEmbedding
  - llama_rotary_embedding_as_linen()
  - partial_rotary_embedding_as_linen()
  - PartialRotaryEmbedding
  - Gemma4PartialRotaryEmbedding
  - LLaMARotaryEmbedding
  - yarn_rotary_embedding_as_linen()
  - YarnRotaryEmbedding: YarnRotaryEmbedding.embedding_dims, YarnRotaryEmbedding.max_position_embeddings, YarnRotaryEmbedding.original_max_position_embeddings, YarnRotaryEmbedding.beta_fast, YarnRotaryEmbedding.beta_slow, YarnRotaryEmbedding.rope_theta, YarnRotaryEmbedding.rope_factor, YarnRotaryEmbedding.cast_as_fprop_dtype, YarnRotaryEmbedding.fprop_dtype, YarnRotaryEmbedding.rope_interleave, YarnRotaryEmbedding.rope_truncate, YarnRotaryEmbedding.rope_attention_scaling, YarnRotaryEmbedding.rngs, YarnRotaryEmbedding.freqs_cis
  - positional_embedding_as_linen()
  - PositionalEmbedding
  - llama_vision_rotary_embedding_as_linen()
  - LlamaVisionRotaryEmbedding: LlamaVisionRotaryEmbedding.image_size, LlamaVisionRotaryEmbedding.patch_size, LlamaVisionRotaryEmbedding.hidden_size, LlamaVisionRotaryEmbedding.num_attention_heads, LlamaVisionRotaryEmbedding.rope_theta, LlamaVisionRotaryEmbedding.cast_as_fprop_dtype, LlamaVisionRotaryEmbedding.fprop_dtype, LlamaVisionRotaryEmbedding.rngs, LlamaVisionRotaryEmbedding.freqs_cis
  - Qwen3OmniMoeVisionRotaryEmbedding: Qwen3OmniMoeVisionRotaryEmbedding.hidden_size, Qwen3OmniMoeVisionRotaryEmbedding.num_attention_heads, Qwen3OmniMoeVisionRotaryEmbedding.spatial_merge_size, Qwen3OmniMoeVisionRotaryEmbedding.rope_theta, Qwen3OmniMoeVisionRotaryEmbedding.cast_as_fprop_dtype, Qwen3OmniMoeVisionRotaryEmbedding.fprop_dtype, Qwen3OmniMoeVisionRotaryEmbedding.rngs, Qwen3OmniMoeVisionRotaryEmbedding.compute_cos_sin()
  - qwen3omnimoe_vision_pos_embed_interpolate_as_linen()
  - Qwen3OmniMoeVisionPosEmbedInterpolate: Qwen3OmniMoeVisionPosEmbedInterpolate.num_position_embeddings, Qwen3OmniMoeVisionPosEmbedInterpolate.hidden_size, Qwen3OmniMoeVisionPosEmbedInterpolate.spatial_merge_size, Qwen3OmniMoeVisionPosEmbedInterpolate.dtype, Qwen3OmniMoeVisionPosEmbedInterpolate.cast_as_fprop_dtype, Qwen3OmniMoeVisionPosEmbedInterpolate.fprop_dtype, Qwen3OmniMoeVisionPosEmbedInterpolate.rngs
  - Qwen3OmniMoeThinkerTextRotaryEmbedding
  - qwen3_omni_mrope_embedding_as_linen()
- maxtext.layers.encoders module
- maxtext.layers.engram module
- maxtext.layers.initializers module
- maxtext.layers.learn_to_init_layer module
- maxtext.layers.linears module
- maxtext.layers.mhc module
- maxtext.layers.moe module
  - get_batchsplit_init_kernel_axes()
  - random_routing()
  - calculate_load_balance_updates()
  - GateLogit
  - RoutedMoE: RoutedMoE.get_expert_parallelism_size(), RoutedMoE.get_tensor_parallelism_size(), RoutedMoE.get_tensor_transpose_parallelism_size(), RoutedMoE.get_context_autoregressive_parallelism_size(), RoutedMoE.should_update_load_balance(), RoutedMoE.get_topk(), RoutedMoE.deepseek_scale_weights(), RoutedMoE.expert_group_mask(), RoutedMoE.deepseek_routing(), RoutedMoE.apply_ffn_activation(), RoutedMoE.permute(), RoutedMoE.unpermute(), RoutedMoE.local_permute(), RoutedMoE.get_all_to_all_params(), RoutedMoE.transform_bias(), RoutedMoE.get_ragged_buffer_size(), RoutedMoE.sparse_matmul(), RoutedMoE.reshape_and_update_weights(), RoutedMoE.get_context_partition_and_sub_seq(), RoutedMoE.generate_masks_subgroup(), RoutedMoE.generate_masks(), RoutedMoE.load_balance_loss(), RoutedMoE.get_einsum(), RoutedMoE.maybe_all_gather_kernel_weight_in_expert_parallelism(), RoutedMoE.dense_matmul(), RoutedMoE.fused_moe_matmul(), RoutedMoE.retrieve_quantized_weight()
  - RoutedAndSharedMoE
  - get_gate_logit()
  - get_routed_moe()
  - get_routed_and_shared_moe()
- maxtext.layers.multi_token_prediction module
- maxtext.layers.nnx_decoders module
- maxtext.layers.nnx_wrappers module
- maxtext.layers.normalizations module
- maxtext.layers.pipeline module
  - PipelineBase: PipelineBase.config, PipelineBase.layers, PipelineBase.mesh, PipelineBase.remat_policy, PipelineBase.setup(), PipelineBase.need_circ_storage(), PipelineBase.iterations_to_complete_first_microbatch_one_repeat(), PipelineBase.iterations_to_complete_first_microbatch(), PipelineBase.get_iteration_inputs(), PipelineBase.get_microbatch_and_repeat_ids(), PipelineBase.get_pipeline_remat_policy(), PipelineBase.get_weight_sharding(), PipelineBase.get_vmap_func_for_init(), PipelineBase.get_main_vmap_func_for_iterations(), PipelineBase.name, PipelineBase.parent, PipelineBase.scope
  - Pipeline: Pipeline.init_states(), Pipeline.shard_dim_by_stages(), Pipeline.vmap_parallel_gather(), Pipeline.vmap_gather(), Pipeline.get_new_loop_state(), Pipeline.permute_output_micro_per_stage_dim(), Pipeline.get_current_stage_weights(), Pipeline.get_current_repeat_from_stages(), Pipeline.run_one_iteration(), Pipeline.get_logical_spec_repeats_removed(), Pipeline.all_gather_over_fsdp(), Pipeline.name, Pipeline.parent, Pipeline.scope
  - CircularPipeline: CircularPipeline.init_states(), CircularPipeline.gather_weights_across_stages_vmap(), CircularPipeline.gather_microbatch_inputs_vmap(), CircularPipeline.advance_circular_buffers(), CircularPipeline.realign_output_microbatches(), CircularPipeline.fetch_active_stage_weights(), CircularPipeline.get_current_weights_from_bsw(), CircularPipeline.from_all_variables_to_repeat_weights(), CircularPipeline.from_repeat_weights_to_bsw(), CircularPipeline.weight_prefetching(), CircularPipeline.run_one_iteration(), CircularPipeline.name, CircularPipeline.parent, CircularPipeline.scope
  - create_pipeline()
- maxtext.layers.pipeline_deprecated module
  - Pipeline: Pipeline.config, Pipeline.layers, Pipeline.mesh, Pipeline.remat_policy, Pipeline.setup(), Pipeline.need_circ_storage(), Pipeline.iterations_to_complete_first_microbatch_one_repeat(), Pipeline.iterations_to_complete_first_microbatch(), Pipeline.init_states(), Pipeline.get_iteration_inputs(), Pipeline.shard_dim_by_stages(), Pipeline.get_microbatch_and_repeat_ids(), Pipeline.vmap_parallel_gather(), Pipeline.vmap_gather(), Pipeline.get_new_loop_state(), Pipeline.permute_output_micro_per_stage_dim(), Pipeline.get_current_stage_weights(), Pipeline.get_current_repeat_from_stages(), Pipeline.get_vmap_func_for_init(), Pipeline.get_main_vmap_func_for_iterations(), Pipeline.run_one_iteration(), Pipeline.get_pipeline_remat_policy(), Pipeline.get_weight_sharding(), Pipeline.get_logical_spec_repeats_removed(), Pipeline.all_gather_over_fsdp(), Pipeline.name, Pipeline.parent, Pipeline.scope
- maxtext.layers.quantizations module
  - Quantization
  - AqtQuantization
  - QwixQuantization
  - QwixDotGeneral
  - QwixEinsum
  - Fp8Quantization
  - Fp8Einsum
  - NANOOFp8Quantization
  - ConstantBoundConfig
  - PerTensorScales
  - in_convert_mode()
  - in_serve_mode()
  - get_quant_mode()
  - configure_quantization()
  - match_aqt_and_unquantized_param()
  - remove_quantized_params()
  - configure_kv_quant()
  - NvidaFp8Provider
  - NANOOFp8Provider
  - get_fp8_full_qwix_rule_w_sparsity()
  - get_quantization_rule()
  - get_qt_provider()
  - maybe_quantize_model()
  - manual_quantize()
  - TransformerEngineQuantization
- maxtext.layers.train_state_nnx module