maxtext.models.llama4 module

Llama4 decoder layer definition.

class maxtext.models.llama4.Llama4UnfoldConvolution(*args, **kwargs)

Bases: Module

Implementation of Llama4UnfoldConvolution for the Llama4 multimodal model.

This module extracts patches from input images and projects them to the hidden dimension.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config#

Config containing model parameters
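The patch extraction and projection described above can be sketched in NumPy. This is a minimal illustration, not the MaxText implementation. It assumes NCHW input, non-overlapping patches (stride equal to the patch size) and a bias-free projection. The helper names `unfold_patches` and `unfold_convolution` are hypothetical.

```python
import numpy as np

def unfold_patches(images: np.ndarray, patch_size: int) -> np.ndarray:
    """Split NCHW images into non-overlapping patches, one flattened row per patch."""
    b, c, h, w = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "H and W must be divisible by patch_size"
    gh, gw = h // patch_size, w // patch_size
    # (B, C, gh, p, gw, p) -> (B, gh, gw, C, p, p): group pixels by patch position.
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    x = x.transpose(0, 2, 4, 1, 3, 5)
    # Flatten each patch to a (C * p * p) vector; patches are ordered row-major over the grid.
    return x.reshape(b, gh * gw, c * patch_size * patch_size)

def unfold_convolution(images: np.ndarray, patch_size: int, weight: np.ndarray) -> np.ndarray:
    """Unfold into patches, then apply a bias-free linear projection (assumed) to hidden_dim."""
    patches = unfold_patches(images, patch_size)
    return patches @ weight  # (B, num_patches, hidden_dim)

rng = np.random.default_rng(0)
images = rng.standard_normal((2, 3, 28, 28))
weight = rng.standard_normal((3 * 14 * 14, 64))  # hypothetical hidden_dim of 64
out = unfold_convolution(images, 14, weight)
print(out.shape)  # (2, 4, 64)
```

Because the patches do not overlap, this unfold-then-matmul computes the same result as a convolution whose kernel size and stride both equal the patch size.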

maxtext.models.llama4.pixel_shuffle(input_tensor, shuffle_ratio)

Apply pixel shuffle operation to the input tensor.

Parameters:
  • input_tensor (Array)

  • shuffle_ratio (float)

Return type:

Array
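The following NumPy sketch shows one formulation of pixel shuffle for patch sequences. It is an assumption for illustration and is not copied from MaxText. It assumes `input_tensor` has shape `(batch, num_patches, channels)` with a square patch grid, and that `shuffle_ratio` (for example 0.5) divides the grid and the channels evenly. Each output token concatenates a small neighborhood of input patches, so the token count scales by `shuffle_ratio**2` and the channel count by `1 / shuffle_ratio**2`. The name `pixel_shuffle_sketch` is hypothetical.

```python
import math
import numpy as np

def pixel_shuffle_sketch(input_tensor: np.ndarray, shuffle_ratio: float) -> np.ndarray:
    """Trade spatial resolution for channels on a (batch, num_patches, channels) array."""
    batch, num_patches, channels = input_tensor.shape
    side = math.isqrt(num_patches)
    assert side * side == num_patches, "num_patches must form a square grid"
    x = input_tensor.reshape(batch, side, side, channels)  # (B, H, W, C)

    # Fold groups of adjacent columns into the channel axis: (B, H, W*r, C/r).
    x = x.reshape(batch, side, int(side * shuffle_ratio), int(channels / shuffle_ratio))
    x = x.transpose(0, 2, 1, 3)  # (B, W*r, H, C/r)

    # Fold groups of adjacent rows into the channel axis: (B, W*r, H*r, C/r^2).
    x = x.reshape(batch, int(side * shuffle_ratio), int(side * shuffle_ratio),
                  int(channels / shuffle_ratio**2))
    x = x.transpose(0, 2, 1, 3)  # back to (B, H*r, W*r, C/r^2)

    # Flatten the reduced grid back into a token sequence.
    return x.reshape(batch, -1, x.shape[-1])

x = np.arange(1 * 16 * 8, dtype=np.float32).reshape(1, 16, 8)
out = pixel_shuffle_sketch(x, 0.5)
print(out.shape)  # (1, 4, 32)
```

In this toy case each output token is the concatenation of a 2×2 neighborhood of input patches. The first token holds patches 0, 1, 4 and 5 of the 4×4 grid, so 16 tokens of width 8 become 4 tokens of width 32.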

class maxtext.models.llama4.Llama4VisionMLP(*args, **kwargs)

Bases: Module

MLP block for Llama4VisionEncoderLayer.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config#

Config containing model parameters

class maxtext.models.llama4.Llama4VisionMLP2(*args, **kwargs)

Bases: Module

MLP block for Llama4VisionPixelShuffleMLP.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config#

Config containing model parameters

class maxtext.models.llama4.Llama4VisionPixelShuffleMLP(*args, **kwargs)

Bases: Module

Implementation of Llama4VisionPixelShuffleMLP for the Llama4 multimodal model.

This module applies a pixel shuffle operation to the encoded patches and passes the result through an MLP.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config#

Config containing model parameters

class maxtext.models.llama4.Llama4MultiModalProjector(*args, **kwargs)

Bases: Module

Implementation of Llama4MultiModalProjector for the Llama4 multimodal model.

This module projects vision features to the hidden dimension of the text model.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config#

Config containing model parameters

maxtext.models.llama4.llama4multimodalprojector_as_linen(config, mesh)

Parameters:
  • config (Any)

  • mesh (Mesh)

maxtext.models.llama4.determine_is_nope_layer(layer_id, nope_layer_interval)

Determines whether the layer at layer_id uses rotary position embeddings (RoPE) or no positional encoding (NoPE).

Parameters:
  • layer_id (int) – The index of the layer.

  • nope_layer_interval (int) – The interval at which layers should use NoPE.

Returns:

True if the layer should use NoPE, False otherwise.

Return type:

bool
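Here is a minimal sketch of a NoPE selection rule. It rests on two assumptions that this page does not confirm. First, it uses the same 1-based striding documented for `determine_is_moe_layer`, so every `nope_layer_interval`-th layer skips RoPE. Second, a missing or non-positive interval disables NoPE entirely. The helper name `is_nope_layer` is hypothetical.

```python
from typing import Optional

def is_nope_layer(layer_id: int, nope_layer_interval: Optional[int]) -> bool:
    # Assumption: a missing or non-positive interval means every layer uses RoPE.
    if nope_layer_interval is None or nope_layer_interval <= 0:
        return False
    # Assumption: 1-based striding, so with interval 4 the NoPE layers are 3, 7, 11, ...
    return (layer_id + 1) % nope_layer_interval == 0

print([i for i in range(12) if is_nope_layer(i, 4)])  # [3, 7, 11]
```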

maxtext.models.llama4.determine_is_moe_layer(layer_id, interleave_moe_layer_step)

Determines whether the layer at layer_id is an MoE layer.

This function implements a striding pattern controlled by interleave_moe_layer_step. For example:

  • If interleave_moe_layer_step is 1, all layers are MoE layers.

  • If interleave_moe_layer_step is 2, layers with index 1, 3, 5, … are MoE layers.

Parameters:
  • layer_id (int) – The 0-based index of the layer being checked.

  • interleave_moe_layer_step (int) – The interval or stride for placing MoE layers.

Returns:

True if the layer is an MoE layer, False otherwise.

Return type:

bool
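The striding examples for this function reduce to a one-line rule: layer `layer_id` is an MoE layer when `layer_id + 1` is a multiple of `interleave_moe_layer_step`. The hypothetical `is_moe_layer` helper below is inferred from those documented examples and is not the MaxText source.

```python
def is_moe_layer(layer_id: int, interleave_moe_layer_step: int) -> bool:
    # Inferred from the documented examples: step 1 -> every layer, step 2 -> layers 1, 3, 5, ...
    return (layer_id + 1) % interleave_moe_layer_step == 0

print([i for i in range(6) if is_moe_layer(i, 1)])  # [0, 1, 2, 3, 4, 5]
print([i for i in range(6) if is_moe_layer(i, 2)])  # [1, 3, 5]
```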

class maxtext.models.llama4.Llama4DecoderLayer(*args, **kwargs)

Bases: Module

Transformer decoder layer for Llama4.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

property moe_block

class maxtext.models.llama4.Llama4ScannableBlock(*args, **kwargs)

Bases: Module

A repeatable block of decoder layers whose per-layer pattern (RoPE vs. NoPE, dense vs. MoE) is determined by nope_layer_interval and interleave_moe_layer_step.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any
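To show how the two intervals combine, the sketch below lists the (position encoding, feed-forward) type of each layer in one repeating block. It makes two assumptions. Both rules use 1-based striding, where a layer is selected when `layer_id + 1` is divisible by the interval. The block length is the least common multiple of the two intervals, so the pattern repeats exactly. The helpers `is_nope_layer`, `is_moe_layer` and `scannable_block_layout` are hypothetical and not part of MaxText.

```python
import math

def is_nope_layer(layer_id: int, nope_layer_interval: int) -> bool:
    # Assumed rule: every nope_layer_interval-th layer (1-based) skips RoPE.
    return (layer_id + 1) % nope_layer_interval == 0

def is_moe_layer(layer_id: int, interleave_moe_layer_step: int) -> bool:
    # Rule matching the documented examples for determine_is_moe_layer.
    return (layer_id + 1) % interleave_moe_layer_step == 0

def scannable_block_layout(nope_layer_interval: int, interleave_moe_layer_step: int):
    # Hypothetical: the shortest block that repeats both patterns has lcm(...) layers.
    # Both intervals are assumed to be positive integers.
    block_len = math.lcm(nope_layer_interval, interleave_moe_layer_step)
    return [
        (
            "nope" if is_nope_layer(i, nope_layer_interval) else "rope",
            "moe" if is_moe_layer(i, interleave_moe_layer_step) else "dense",
        )
        for i in range(block_len)
    ]

for layer in scannable_block_layout(4, 2):
    print(layer)
# ('rope', 'dense')
# ('rope', 'moe')
# ('rope', 'dense')
# ('nope', 'moe')
```

Under these assumptions, repeating this block reproduces the layer-by-layer pattern of the full decoder stack. A single block definition can therefore be reused for every repetition.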

class maxtext.models.llama4.Llama4VisionEncoderLayer(*args, **kwargs)

Bases: Module

Transformer encoder layer for Llama4 vision model.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.llama4.Llama4VisionEncoder(*args, **kwargs)

Bases: Module

Transformer encoder consisting of multiple Llama4VisionEncoderLayer layers.

This encoder is based on the PyTorch reference implementation and uses multiple encoder layers to process vision input.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config#

Config containing model parameters

mesh#

JAX device mesh (Mesh) used for sharding.

class maxtext.models.llama4.Llama4VisionModel(*args, **kwargs)

Bases: Module

Llama4 vision model for processing image inputs.

This model extracts patches from input image tiles and processes them through Llama4VisionEncoder and other vision-specific layers.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

config#

Config containing model parameters

mesh#

JAX device mesh (Mesh) used for sharding.

maxtext.models.llama4.llama4visionmodel_as_linen(config, mesh)

Parameters:
  • config (Any)

  • mesh (Mesh)

Return type:

Module