# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma3-specific utilities for multimodal features. """
from dataclasses import dataclass
import numpy as np
from PIL import Image
from maxtext.multimodal import utils as mm_utils
# Constants for Gemma3-specific processing

# Image preprocessing: images are resized to a square of this side length and
# normalized with the per-channel mean/std below.
GEMMA_DEFAULT_IMAGE_SIZE = 896
GEMMA_IMAGE_MEAN = (127.5, 127.5, 127.5)
GEMMA_IMAGE_STD = (127.5, 127.5, 127.5)

# Text that stands in for an image inside a prompt.
GEMMA_IMAGE_PLACEHOLDER_IN_PROMPT = "<start_of_image>"

# Special token ids used when splicing image tokens into a token sequence.
GEMMA_BEGIN_IMAGE_TOKEN = 255999
GEMMA_END_IMAGE_TOKEN = 256000
GEMMA_NEW_LINE_TOKEN = 108  # presumably the "\n\n" token mentioned below; TODO confirm
GEMMA_TOKEN_PLACEHOLDER = 262144

# The number of GEMMA_TOKEN_PLACEHOLDER tokens per image in Gemma3
GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE = 256

# Total tokens per media item: the placeholders plus 4 extra tokens padded
# around the image (\n\n, <start_of_image>, <end_of_image>, \n\n).
# One MEDIA means one image, or multiple images from one video; currently only
# a single image is supported.
GEMMA_NUM_TOKENS_PER_MEDIA = 4 + GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE
[docs]
@dataclass
class Gemma3PreprocessorOutput(mm_utils.PreprocessorOutput):
  """Container for the results of Gemma3 image preprocessing.

  Extends `mm_utils.PreprocessorOutput`; any attribute not listed below is
  inherited from that base class.

  Attributes:
    num_images: Number of images that were preprocessed. Defaults to 0.
    pixel_values: Preprocessed image array, or None when there are no images.
      `preprocess_mm_data_gemma3` fills this with a float32 array stacked
      along a leading image axis, i.e. (N, H, W, C).
    pixel_mask: Optional mask array; None by default.
  """

  # Image attributes.
  num_images: int = 0
  pixel_values: np.ndarray | None = None
  pixel_mask: np.ndarray | None = None
[docs]
def preprocess_mm_data_gemma3(images):
  """Resize and normalize images into the format expected by Gemma3.

  Every image is resized to a square of side `GEMMA_DEFAULT_IMAGE_SIZE`,
  normalized with `GEMMA_IMAGE_MEAN` / `GEMMA_IMAGE_STD`, and clipped to
  [-1, 1].

  Args:
    images: Either a single image as a NumPy array, or an iterable of such
      arrays. Each array must be accepted by `PIL.Image.fromarray`.

  Returns:
    A `Gemma3PreprocessorOutput` whose `num_images` is the number of input
    images and whose `pixel_values` is a float32 array of the processed images
    stacked along a new leading axis, i.e. (N, H, W, C).
  """
  target_size = (GEMMA_DEFAULT_IMAGE_SIZE, GEMMA_DEFAULT_IMAGE_SIZE)

  # A bare ndarray is treated as one image; anything else as a collection.
  raw_images = [images] if isinstance(images, np.ndarray) else list(images)

  processed = []
  for raw in raw_images:
    pil_img = Image.fromarray(raw)
    width, height = pil_img.size
    # Bilinear by default; when the source exceeds the target in either
    # dimension, use a higher quality downsampling filter to approximate
    # antialias=True.
    is_downsampling = width > target_size[0] or height > target_size[1]
    resample_method = Image.Resampling.LANCZOS if is_downsampling else Image.Resampling.BILINEAR
    pixels = np.asarray(pil_img.resize(target_size, resample=resample_method), dtype=np.float32)
    pixels = mm_utils.normalize_images(pixels, mean=GEMMA_IMAGE_MEAN, std=GEMMA_IMAGE_STD)
    processed.append(np.clip(pixels, -1, 1))

  return Gemma3PreprocessorOutput(
      num_images=len(raw_images),
      pixel_values=np.stack(processed, axis=0).astype(np.float32),  # (N, H, W, C)
  )
[docs]
def get_image_offsets_gemma3(processor_output: mm_utils.PreprocessorOutput | None):
  """Return how many tokens are added once image placeholders are inserted.

  Args:
    processor_output: Preprocessor output holding `pixel_values`, or None.

  Returns:
    `(GEMMA_NUM_TOKENS_PER_MEDIA - 1)` times the number of images, where the
    number of images is the leading dimension of `pixel_values`. When
    `processor_output` or its `pixel_values` is None, a count of 1 is used.
  """
  if processor_output is None or processor_output.pixel_values is None:
    image_count = 1
  else:
    image_count = processor_output.pixel_values.shape[0]
  # -1 because <start_of_image> is already present in the input tokens.
  tokens_added_per_image = GEMMA_NUM_TOKENS_PER_MEDIA - 1
  return tokens_added_per_image * image_count
def _get_new_text_positions(
    *,
    offset_on: np.ndarray,
    offset_by: int,
) -> np.ndarray:
  """Compute where each original token lands after the insertions.

  Every token is shifted right by `offset_by` for each trigger token that
  appears strictly before it along the last axis. A trigger token itself is
  therefore not shifted by its own insertion.

  Input: `[x, x, x, offset_on, x, x, offset_on, x]`
  Output: `[0, 1, 2, 3, 4+Offset, 5+Offset, 6+Offset, 7+2*Offset]`

  Args:
    offset_on: Mask that is True (or 1) at the trigger tokens.
    offset_by: The number of positions each trigger shifts later tokens by.

  Returns:
    The new positions of the tokens, with the same shape as `offset_on`.
  """
  # Inclusive running count minus the mask itself gives the number of triggers
  # strictly before each position, so `<start_of_image>` keeps its own slot
  # (it is overwritten by the MM tokens later).
  triggers_before = np.cumsum(offset_on, axis=-1) - offset_on
  original_positions = np.arange(offset_on.shape[-1])
  return original_positions + triggers_before * offset_by
[docs]
def insert_sequence(
    tokens: np.ndarray,
    *,
    at: int,
    sequence: list[int],
    max_num_images: int,
) -> np.ndarray:
  """Replace every occurrence of the token `at` with the tokens in `sequence`.

  The operation is vectorized over a batch of token sequences. Each row of the
  output has length `length + max_num_images * (len(sequence) - 1)`; rows with
  fewer than `max_num_images` occurrences of `at` are right-padded with zeros.

  Args:
    tokens: A 1D `(length,)` or 2D `(batch, length)` array of input tokens.
    at: The token ID to find and replace with the sequence.
    sequence: The list of new token IDs to insert.
    max_num_images: The maximum number of times `at` can appear in one row.
      Rows with more occurrences are not supported: the shifted positions no
      longer fit in the output buffer (with a multi-token `sequence`, NumPy
      indexing raises an out-of-bounds IndexError).

  Returns:
    An int64 array with the sequences inserted. It has the same number of
    dimensions as `tokens` (1D in, 1D out; 2D in, 2D out).
  """
  # Work on a 2D (batch, length) array; a 1D input is restored on return.
  original_dim = tokens.ndim
  if original_dim == 1:
    tokens = tokens[None, :]
  batch_size, length = tokens.shape

  mm_tokens_to_insert = np.array(sequence)
  # Net number of tokens added for each image placeholder.
  # It's -1 because the original `at` token is replaced.
  offset_by = len(mm_tokens_to_insert) - 1
  length_with_mm = length + max_num_images * offset_by

  # Boolean mask of where the image trigger token `at` is present.
  mm_start = tokens == at

  # 1. Scatter the original tokens into their new, shifted positions inside a
  # zero-initialized output buffer.
  new_tokens = np.zeros((batch_size, length_with_mm), dtype=np.int64)
  new_text_pos = _get_new_text_positions(offset_on=mm_start, offset_by=offset_by)
  np.put_along_axis(new_tokens, new_text_pos, tokens, axis=1)

  # 2. Insert the image token sequences at every trigger location.
  # Row and column indices of all trigger tokens, in row-major order (the same
  # order boolean-mask indexing with `mm_start` produces).
  batch_indices, seq_indices = np.nonzero(mm_start)
  if batch_indices.size > 0:
    # Clear the trigger token at its new position; the full image sequence is
    # written over it next.
    new_tokens[batch_indices, new_text_pos[mm_start]] = 0

    # Index of each image within its own row (0th, 1st, ...).
    intra_batch_img_idx = np.cumsum(mm_start, axis=1)[mm_start] - 1
    # Start column of each inserted sequence, accounting for the shifts caused
    # by earlier images in the same row.
    final_img_start_pos = seq_indices + intra_batch_img_idx * offset_by
    # Broadcast each start column against [0, 1, ..., len(sequence) - 1] to get
    # the destination column of every inserted token.
    indices_to_insert = final_img_start_pos[:, None] + np.arange(len(mm_tokens_to_insert))
    new_tokens[batch_indices[:, None], indices_to_insert] = mm_tokens_to_insert

  if original_dim == 1:
    # Drop only the batch axis we added. `np.squeeze` would also drop the
    # sequence axis when the output length is 1, returning a 0-d array.
    new_tokens = new_tokens[0]
  return new_tokens
[docs]
def get_dummy_image_shape_for_init_gemma3(batch_size=1, num_image_per_sequence=1):
  """Return the shape of the dummy image used to initialize a Gemma3 model.

  Args:
    batch_size: Size of the leading batch dimension.
    num_image_per_sequence: Number of images per sequence.

  Returns:
    The tuple `(batch_size, num_image_per_sequence, GEMMA_DEFAULT_IMAGE_SIZE,
    GEMMA_DEFAULT_IMAGE_SIZE, mm_utils.NUM_IMAGE_CHANNELS)`.
  """
  side = GEMMA_DEFAULT_IMAGE_SIZE
  leading_dims = (batch_size, num_image_per_sequence)
  return leading_dims + (side, side, mm_utils.NUM_IMAGE_CHANNELS)