maxtext.multimodal.processor_llama4 module

Llama4-specific utilities for multimodal features.

class maxtext.multimodal.processor_llama4.Llama4PreprocessorOutput(pixel_values=None, pixel_mask=None, aspect_ratios=None, num_images=0, audio_values=None, audio_mask=None)

Bases: PreprocessorOutput

Holds the output of the Llama4 image preprocessor.

Parameters:
  • pixel_values (None | ndarray)

  • pixel_mask (None | ndarray)

  • aspect_ratios (None | ndarray)

  • num_images (int)

  • audio_values (None | ndarray)

  • audio_mask (None | ndarray)

Inherited from `mm_utils.PreprocessorOutput`.
num_images: int = 0
pixel_values: None | ndarray = None
pixel_mask: None | ndarray = None
aspect_ratios: None | ndarray = None
maxtext.multimodal.processor_llama4.get_factors(dividend)

Calculate all factors of a given number, i.e. every divisor that leaves no remainder. For example, if dividend=12, it will return {1, 2, 3, 4, 6, 12}.

Returns:

A set containing all factors of the number.

Return type:

set

Parameters:

dividend (int) – The number to find factors for.
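A minimal pure-Python sketch of the behavior described above (trial division up to the square root); it is an illustration, not the MaxText source:

```python
def get_factors(dividend: int) -> set[int]:
    """Return every positive divisor of `dividend` (illustrative sketch)."""
    factors = set()
    i = 1
    # Only test candidates up to sqrt(dividend); each hit i also yields its partner dividend // i.
    while i * i <= dividend:
        if dividend % i == 0:
            factors.add(i)
            factors.add(dividend // i)
        i += 1
    return factors


print(sorted(get_factors(12)))  # [1, 2, 3, 4, 6, 12]
```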

maxtext.multimodal.processor_llama4.find_supported_resolutions(max_num_tiles=16, tile_size=336)

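Find all possible resolutions for the image based on the number of tiles (the "chunks" of the image).

Parameters:
  • max_num_tiles (int) – The maximum number of tiles.

  • tile_size (int) – The side length of each square tile, in pixels.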

Return type:

list[tuple[int, int]]
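A sketch of one plausible reading of this function, assuming that "possible resolutions" means every (rows, cols) grid containing at most max_num_tiles tiles, scaled by tile_size. The ordering (by descending tile count) is also an assumption; the real function may return the same canvases in a different order.

```python
def find_supported_resolutions(max_num_tiles: int = 16, tile_size: int = 336) -> list[tuple[int, int]]:
    """Sketch: all (height, width) canvases tiled by at most `max_num_tiles` square tiles."""
    resolutions = []
    # Each tile count n contributes one canvas per factorization n = rows * cols.
    for num_tiles in range(max_num_tiles, 0, -1):
        for rows in range(1, num_tiles + 1):
            if num_tiles % rows == 0:
                cols = num_tiles // rows
                resolutions.append((rows * tile_size, cols * tile_size))
    return resolutions


print(len(find_supported_resolutions()))  # 50
print((672, 672) in find_supported_resolutions())  # True
```

Because every grid appears exactly once, the result for the defaults has one entry per (rows, cols) pair with rows * cols <= 16.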

maxtext.multimodal.processor_llama4.get_best_resolution(img_height, image_width, possible_resolutions, resize_to_max_canvas=False)

Get the best resolution for the image from the possible resolutions.

Returns:

The best resolution for the image.

Return type:

tuple

Parameters:
  • img_height (int) – The height of the image.

  • image_width (int) – The width of the image.

  • possible_resolutions (list[tuple[int, int]]) – A list of possible resolutions.

  • resize_to_max_canvas (bool) – Whether to resize the image to the maximum canvas. Defaults to False.
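The selection rule is not spelled out above. The sketch below assumes the rule we understand image_processing_llama4.py to use: for each canvas, take the largest distortion-free scale that fits the image; prefer the smallest upscale (or the largest when resize_to_max_canvas is set), fall back to the mildest downscale, and break ties by smallest area. Treat it as an illustration under that assumption. It reproduces the (536, 640) → (672, 672) example given for preprocess_mm_data_llama4.

```python
from fractions import Fraction


def get_best_resolution(img_height, image_width, possible_resolutions, resize_to_max_canvas=False):
    """Sketch of an assumed best-fit canvas rule (not the MaxText source)."""
    # Largest uniform scale at which the whole image still fits inside each canvas.
    # Fractions keep the equality test below exact.
    scales = [min(Fraction(h, img_height), Fraction(w, image_width)) for h, w in possible_resolutions]
    upscales = [s for s in scales if s >= 1]
    if upscales:
        # Smallest upscale by default; the largest one when filling the max canvas.
        chosen = max(upscales) if resize_to_max_canvas else min(upscales)
    else:
        # No canvas allows upscaling, so take the mildest downscale.
        chosen = max(scales)
    candidates = [r for r, s in zip(possible_resolutions, scales) if s == chosen]
    # Several canvases can share the chosen scale; the smallest area needs the least padding.
    return min(candidates, key=lambda r: r[0] * r[1])


# Canvases built from at most 16 tiles of 336 x 336.
canvases = [(r * 336, c * 336) for r in range(1, 17) for c in range(1, 17) if r * c <= 16]
print(get_best_resolution(536, 640, canvases))  # (672, 672)
```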

maxtext.multimodal.processor_llama4.pad_to_best_fit_jax(images, target_size, background_color=0)

Pads and/or crops an image or batch of images to a target size using JAX. If the image is larger than the target size, it is cropped so that its top-left region is kept; if it is smaller, it is padded on the right and bottom.

Parameters:
  • images (np.ndarray) – The images to process. Expected shape (…, H, W, C).

  • target_size (tuple[int, int]) – The target (height, width).

  • background_color (int | tuple[int, ...] | None) – The color to use for padding. If int, it’s used for the first channel and subsequent channels are padded with 0. If tuple, its length must match the number of channels in the image. Defaults to 0.

Returns:

The processed images of shape (…, target_height, target_width, C).

Return type:

np.ndarray
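A NumPy sketch of the documented pad/crop semantics (the function itself uses JAX). The int and tuple handling of background_color follows the parameter description; treating None like 0 is an assumption, since the description does not say what None does.

```python
import numpy as np


def pad_to_best_fit(images, target_size, background_color=0):
    """NumPy sketch of the documented behavior (not the MaxText source)."""
    target_h, target_w = target_size
    h, w, c = images.shape[-3:]
    # Keep the top-left region along any dimension that exceeds the target.
    cropped = images[..., : min(h, target_h), : min(w, target_w), :]

    # Per-channel fill value. Treating None like 0 is an assumption.
    fill = np.zeros(c, dtype=images.dtype)
    if isinstance(background_color, int):
        fill[0] = background_color  # int: first channel only, other channels stay 0
    elif background_color is not None:
        if len(background_color) != c:
            raise ValueError("background_color must have one value per channel")
        fill[:] = background_color

    # Start from a background-filled canvas, then paste the (cropped) image at the top-left.
    out = np.empty(images.shape[:-3] + (target_h, target_w, c), dtype=images.dtype)
    out[...] = fill  # broadcasts over the trailing channel axis
    out[..., : cropped.shape[-3], : cropped.shape[-2], :] = cropped
    return out
```

Each spatial dimension is handled independently, so one call can crop the width while padding the height.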

maxtext.multimodal.processor_llama4.pad_to_max_tiles(images, max_num_tiles=20)

Pads the image tiles to the maximum number of tiles using JAX.

Parameters:
  • images (ndarray) – The input image tiles with shape (num_tiles, C, H, W).

  • max_num_tiles (int) – The maximum number of tiles to pad to.

Returns:

A tuple of the padded image tiles, with shape (max_num_tiles, C, H, W), and a mask indicating the valid tiles, with shape (max_num_tiles,).

Return type:

tuple[ndarray, ndarray]
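A NumPy sketch of the padding and mask construction. Zero-filled padding tiles, an int32 mask with 1 marking valid tiles, and raising an error when there are too many tiles are all assumptions; the description above does not fix those details.

```python
import numpy as np


def pad_to_max_tiles(images, max_num_tiles=20):
    """Sketch: zero-pad (num_tiles, C, H, W) tiles to max_num_tiles and build a validity mask."""
    num_tiles = images.shape[0]
    if num_tiles > max_num_tiles:
        # Assumption: the real function's behavior in this case is not documented.
        raise ValueError(f"got {num_tiles} tiles, more than max_num_tiles={max_num_tiles}")
    # Pad only the leading tile axis; np.pad fills with zeros by default.
    pad_width = ((0, max_num_tiles - num_tiles),) + ((0, 0),) * (images.ndim - 1)
    padded = np.pad(images, pad_width)
    # 1 for real tiles, 0 for padding tiles.
    mask = np.zeros(max_num_tiles, dtype=np.int32)
    mask[:num_tiles] = 1
    return padded, mask
```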

maxtext.multimodal.processor_llama4.split_to_tiles(images, num_tiles_height, num_tiles_width)

Splits an image tensor into tiles using JAX.

Parameters:
  • images (ndarray) – The input image tensor with shape (batch_size, num_channels, height, width).

  • num_tiles_height (int) – The number of tiles along the height dimension.

  • num_tiles_width (int) – The number of tiles along the width dimension.

Returns:

The tiled image tensor with shape (batch_size * num_tiles_height * num_tiles_width, num_channels, height // num_tiles_height, width // num_tiles_width).
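A NumPy sketch of the reshape/transpose pattern that produces the documented output shape. It assumes tiles are emitted in row-major order (left to right, then top to bottom) and that height and width divide evenly by the tile counts.

```python
import numpy as np


def split_to_tiles(images, num_tiles_height, num_tiles_width):
    """Sketch: (B, C, H, W) -> (B * nh * nw, C, H // nh, W // nw), tiles in row-major order."""
    b, c, h, w = images.shape
    th, tw = h // num_tiles_height, w // num_tiles_width
    # (B, C, nh, th, nw, tw): expose the tile-grid axes.
    x = images.reshape(b, c, num_tiles_height, th, num_tiles_width, tw)
    # (B, nh, nw, C, th, tw): move the grid axes in front of the channel axis.
    x = x.transpose(0, 2, 4, 1, 3, 5)
    # Merge batch and grid axes into one tile axis.
    return x.reshape(b * num_tiles_height * num_tiles_width, c, th, tw)
```

For the 672 x 672 canvas in the preprocess_mm_data_llama4 example, split_to_tiles(x, 2, 2) turns one (1, 3, 672, 672) image into four (3, 336, 336) tiles.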

maxtext.multimodal.processor_llama4.preprocess_mm_data_llama4(images)

Pre-process images for the Llama4 model: find the best resolution and split each image into tiles, with an additional global tile. The original implementation is in image_processing_llama4.py: http://shortn/_VXLgQ1lmkz

Parameters:

images – The image as an np.array of shape [H, W, C], or a batch of images of shape [N, H, W, C], to pre-process.

Returns:

A Llama4PreprocessorOutput containing the pre-processed images as an np.array of shape [N, NUM_TILES, C, TILE_SIZE, TILE_SIZE].

Example

For an image of shape (536, 640, 3), the best resolution is (672, 672), so the image is split into 4 tiles of (336, 336). An additional global tile of (336, 336) is added, and the final output image_tiles has shape (1, 5, 3, 336, 336).
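The tile arithmetic in this example can be reproduced with a small hypothetical helper, num_tiles_for_canvas (not part of the module). It assumes that the global tile is appended only when the image is split into more than one tile. This matches the token layout described for get_tokens_for_this_image, but it is not stated here.

```python
def num_tiles_for_canvas(best_resolution, tile_size=336):
    """Hypothetical helper: number of tiles produced for one image on its best-fit canvas."""
    canvas_h, canvas_w = best_resolution
    # Local tiles form a (canvas_h // tile_size) x (canvas_w // tile_size) grid.
    num_local = (canvas_h // tile_size) * (canvas_w // tile_size)
    # Assumption: a global tile is added only when there is more than one local tile.
    return num_local + 1 if num_local > 1 else num_local


# Output shape for the single-image example: (N, NUM_TILES, C, TILE_SIZE, TILE_SIZE).
example_shape = (1, num_tiles_for_canvas((672, 672)), 3, 336, 336)
print(example_shape)  # (1, 5, 3, 336, 336)
```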

maxtext.multimodal.processor_llama4.get_num_tokens_for_this_image(this_aspect_ratio, num_patches_per_chunk)

Computes the length of the token sequence that get_tokens_for_this_image would generate, without explicit loops.

Parameters:
  • this_aspect_ratio – A tuple (ratio_h, ratio_w) representing the number of tiles along the height and width.

  • num_patches_per_chunk – The number of patch tokens per image tile.

Returns:

The total number of tokens for the image representation.
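A closed-form sketch derived from the token layout documented for get_tokens_for_this_image. There are 3 fixed tokens (begin, fake global image, end) plus the global tile's patches. For a multi-tile image, each local tile also adds its patches plus exactly one separator: ratio_w - 1 X separators and one Y separator per row. This is our derivation, checked against the documented 27-token example, not the MaxText source.

```python
def get_num_tokens_for_this_image(this_aspect_ratio, num_patches_per_chunk):
    """Sketch: token count for one image, computed without loops."""
    ratio_h, ratio_w = this_aspect_ratio
    num_tiles = ratio_h * ratio_w
    # Begin-of-image, fake global-image and end-of-image tokens, plus the global tile's patches.
    total = 3 + num_patches_per_chunk
    if num_tiles > 1:
        # Local tiles: patches plus one separator per tile (ratio_h * ratio_w separators in all).
        total += num_tiles * (num_patches_per_chunk + 1)
    return total


print(get_num_tokens_for_this_image((2, 2), 4))  # 27
```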

maxtext.multimodal.processor_llama4.get_image_offsets_llama4(processor_output)

Get the increase in the total token count after inserting the image token placeholders.

Parameters:

processor_output (PreprocessorOutput | None)

maxtext.multimodal.processor_llama4.reformat_prompt_llama4(prompt, image_placeholder, num_images)

Reformat the prompt for the Llama4 model.

maxtext.multimodal.processor_llama4.get_tokens_for_this_image(this_aspect_ratio, num_patches_per_chunk)

Constructs the token sequence for a single image in Llama4. This function generates a list of special tokens that represent an image, including its tiled structure (if applicable) and a global representation. The sequence includes:

  • A beginning-of-image token.

  • Patch tokens for each local tile, interspersed with tile separators, if the image is divided into multiple tiles (ratio_h * ratio_w > 1).

  • A fake image token placeholder for the global image representation.

  • Patch tokens associated with the global image representation.

  • An end-of-image token.

Parameters:
  • this_aspect_ratio – A tuple (ratio_h, ratio_w) representing the number of tiles along the height and width dimensions for the current image.

  • num_patches_per_chunk – The number of patch tokens to use for each image tile (both local and global).

Returns:

A list of integer token IDs representing the image.

Example

If this_aspect_ratio is [2, 2] and num_patches_per_chunk is 4, the output will be:

[
  LLAMA4_BEGIN_IMAGE_TOKEN,
  LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_TILE_X_SEPARATOR_TOKEN,
  LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_TILE_Y_SEPARATOR_TOKEN,
  LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_TILE_X_SEPARATOR_TOKEN,
  LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_TILE_Y_SEPARATOR_TOKEN,
  LLAMA4_FAKE_IMAGE_TOKEN,
  LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN,
  LLAMA4_END_IMAGE_TOKEN
]

27 tokens in total.
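A loop-based sketch of the layout described above, which reproduces the 27-token example. The constant names mirror the ones in the example, but their string values are hypothetical placeholders; the real constants are integer token IDs that are not listed here.

```python
# Hypothetical placeholder values; the real constants are integer token IDs.
LLAMA4_BEGIN_IMAGE_TOKEN = "BEGIN_IMAGE"
LLAMA4_END_IMAGE_TOKEN = "END_IMAGE"
LLAMA4_PATCH_TOKEN = "PATCH"
LLAMA4_TILE_X_SEPARATOR_TOKEN = "TILE_X_SEP"
LLAMA4_TILE_Y_SEPARATOR_TOKEN = "TILE_Y_SEP"
LLAMA4_FAKE_IMAGE_TOKEN = "FAKE_IMAGE"


def get_tokens_for_this_image(this_aspect_ratio, num_patches_per_chunk):
    """Sketch of the documented token layout for one image."""
    ratio_h, ratio_w = this_aspect_ratio
    tokens = [LLAMA4_BEGIN_IMAGE_TOKEN]
    if ratio_h * ratio_w > 1:
        for _ in range(ratio_h):
            for x in range(ratio_w):
                tokens.extend([LLAMA4_PATCH_TOKEN] * num_patches_per_chunk)
                # X separator between tiles in a row, but not after the last one.
                if x < ratio_w - 1:
                    tokens.append(LLAMA4_TILE_X_SEPARATOR_TOKEN)
            # Y separator closes every row of tiles.
            tokens.append(LLAMA4_TILE_Y_SEPARATOR_TOKEN)
    # Global image: fake image placeholder followed by its patch tokens.
    tokens.append(LLAMA4_FAKE_IMAGE_TOKEN)
    tokens.extend([LLAMA4_PATCH_TOKEN] * num_patches_per_chunk)
    tokens.append(LLAMA4_END_IMAGE_TOKEN)
    return tokens


print(len(get_tokens_for_this_image((2, 2), 4)))  # 27
```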

maxtext.multimodal.processor_llama4.add_extra_tokens_for_images_llama4(tokens, processor_output)

Add the extra image tokens to the text tokens for Llama4.

Parameters:
  • tokens – The text tokens.

  • processor_output (PreprocessorOutput)

maxtext.multimodal.processor_llama4.get_dummy_image_shape_for_init_llama4(batch_size=1, num_image_per_sequence=1)

Return the shape of the dummy image used to initialize the Llama4 model.