maxtext.models.llama4 module
Llama4 decoder layer definition.
- class maxtext.models.llama4.Llama4UnfoldConvolution(*args, **kwargs)
Bases: Module
Implementation of Llama4UnfoldConvolution for the Llama4 multimodal model.
This module extracts patches from input images and projects them to the hidden dimension.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config
Config containing model parameters
- maxtext.models.llama4.pixel_shuffle(input_tensor, shuffle_ratio)
Apply a pixel shuffle operation to the input tensor.
- Parameters:
input_tensor (Array)
shuffle_ratio (float)
- Return type:
Array
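As an illustration of what a pixel shuffle does, the following is a minimal NumPy sketch of one common formulation, not the MaxText implementation. It assumes the input has shape (batch, num_patches, channels) with a square patch grid, and `pixel_shuffle_sketch` is a hypothetical name. With a `shuffle_ratio` below 1, the patch count shrinks by `shuffle_ratio**2` while the channel count grows by the inverse factor.

```python
import math

import numpy as np


def pixel_shuffle_sketch(x, shuffle_ratio):
    """Trade spatial patches for channels (assumed layout: batch, patches, channels)."""
    batch, num_patches, channels = x.shape
    side = math.isqrt(num_patches)
    assert side * side == num_patches, "sketch assumes a square patch grid"

    # Restore the 2D patch grid: (batch, height, width, channels).
    x = x.reshape(batch, side, side, channels)

    # Fold part of the width axis into the channel axis.
    new_side = int(side * shuffle_ratio)
    x = x.reshape(batch, side, new_side, int(channels / shuffle_ratio))
    x = x.transpose(0, 2, 1, 3)

    # Fold part of the height axis into the channel axis as well.
    x = x.reshape(batch, new_side, new_side, int(channels / shuffle_ratio**2))
    x = x.transpose(0, 2, 1, 3)

    # Flatten the grid back into a sequence of patches.
    return x.reshape(batch, -1, x.shape[-1])


patches = np.arange(2 * 16 * 8, dtype=np.float32).reshape(2, 16, 8)
shuffled = pixel_shuffle_sketch(patches, 0.5)
print(shuffled.shape)  # (2, 4, 32)
```

The sketch only rearranges elements, so the total number of values is unchanged; `jax.numpy` arrays support the same `reshape` and `transpose` methods.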
- class maxtext.models.llama4.Llama4VisionMLP(*args, **kwargs)
Bases: Module
MLP block for Llama4EncoderLayer.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config
Config containing model parameters
- class maxtext.models.llama4.Llama4VisionMLP2(*args, **kwargs)
Bases: Module
MLP block for Llama4VisionPixelShuffleMLP.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config
Config containing model parameters
- class maxtext.models.llama4.Llama4VisionPixelShuffleMLP(*args, **kwargs)
Bases: Module
Implementation of Llama4VisionPixelShuffleMLP for the Llama4 multimodal model.
This module applies a pixel shuffle operation followed by an MLP to the encoded patches.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config
Config containing model parameters
- class maxtext.models.llama4.Llama4MultiModalProjector(*args, **kwargs)
Bases: Module
Implementation of Llama4MultiModalProjector for the Llama4 multimodal model.
This module projects vision features to the text hidden dimension.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config
Config containing model parameters
- maxtext.models.llama4.llama4multimodalprojector_as_linen(config, mesh)
- Parameters:
config (Any)
mesh (Mesh)
- maxtext.models.llama4.determine_is_nope_layer(layer_id, nope_layer_interval)
Determines whether the layer at layer_id should use NoPE (that is, skip RoPE) rather than RoPE.
- Parameters:
layer_id (int) – The index of the layer.
nope_layer_interval (int) – The interval at which layers should use NoPE.
- Returns:
True if the layer should use NoPE, False otherwise.
- Return type:
bool
- maxtext.models.llama4.determine_is_moe_layer(layer_id, interleave_moe_layer_step)
Determines whether the layer at layer_id is an MoE layer.
This function implements a striding pattern. For example:
- If interleave_moe_layer_step is 1, all layers are MoE layers.
- If interleave_moe_layer_step is 2, layers with index 1, 3, 5, … are MoE layers.
- Parameters:
layer_id (int) – The 0-based index of the layer being checked.
interleave_moe_layer_step (int) – The interval or stride for placing MoE layers.
- Returns:
True if the layer is an MoE layer, False otherwise.
- Return type:
bool
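One predicate consistent with the two examples given for this function (every layer for a step of 1; layers 1, 3, 5, … for a step of 2) is sketched below. `is_moe_layer_sketch` is a hypothetical stand-in, and the `(layer_id + 1) % step` form is an assumption inferred from those examples rather than taken from the source.

```python
def is_moe_layer_sketch(layer_id: int, interleave_moe_layer_step: int) -> bool:
    """Assumed rule: the last layer in every group of `step` layers is MoE."""
    return (layer_id + 1) % interleave_moe_layer_step == 0


# Which of the first six layers would be MoE layers under each step.
moe_layers_step_1 = [i for i in range(6) if is_moe_layer_sketch(i, 1)]
moe_layers_step_2 = [i for i in range(6) if is_moe_layer_sketch(i, 2)]
print(moe_layers_step_1)  # [0, 1, 2, 3, 4, 5]
print(moe_layers_step_2)  # [1, 3, 5]
```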
- class maxtext.models.llama4.Llama4DecoderLayer(*args, **kwargs)
Bases: Module
Transformer decoder layer for Llama4.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- property moe_block
- class maxtext.models.llama4.Llama4ScannableBlock(*args, **kwargs)
Bases: Module
A repeatable block given nope_layer_interval and interleave_moe_layer_step.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
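To see why a block built from these two intervals can be repeated, the sketch below assumes both layer-type rules are periodic in the form `(layer_id + 1) % interval == 0`. That form, the helper name, and the use of the least common multiple as the block length are illustrative assumptions, not confirmed details of Llama4ScannableBlock. Under those assumptions, the per-layer (NoPE, MoE) pattern repeats every lcm(nope_layer_interval, interleave_moe_layer_step) layers.

```python
import math


def layer_pattern_sketch(num_layers, nope_layer_interval, interleave_moe_layer_step):
    """Return (is_nope, is_moe) flags per layer under the assumed periodic rules."""
    return [
        ((i + 1) % nope_layer_interval == 0, (i + 1) % interleave_moe_layer_step == 0)
        for i in range(num_layers)
    ]


# Illustrative values only; they are not taken from any model config.
nope_layer_interval = 4
interleave_moe_layer_step = 2
block_len = math.lcm(nope_layer_interval, interleave_moe_layer_step)

pattern = layer_pattern_sketch(2 * block_len, nope_layer_interval, interleave_moe_layer_step)
first_block, second_block = pattern[:block_len], pattern[block_len:]
print(block_len)                    # 4
print(first_block == second_block)  # True
```

Because the flags of the second block match those of the first, a single block of `block_len` layers describes the whole stack under these assumptions.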
- class maxtext.models.llama4.Llama4VisionEncoderLayer(*args, **kwargs)
Bases: Module
Transformer encoder layer for the Llama4 vision model.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- class maxtext.models.llama4.Llama4VisionEncoder(*args, **kwargs)
Bases: Module
Transformer encoder consisting of multiple Llama4VisionEncoderLayer layers.
This encoder is based on the PyTorch reference implementation and uses multiple encoder layers to process vision input.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config
Config containing model parameters
- mesh
Mesh, JAX device mesh (used for sharding)
- class maxtext.models.llama4.Llama4VisionModel(*args, **kwargs)
Bases: Module
Llama4 vision model for processing image inputs.
This model extracts patches from input image tiles and processes them through Llama4VisionEncoder and other vision-specific layers.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- config
Config containing model parameters
- mesh
Mesh, JAX device mesh (used for sharding)