# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma4-specific utilities for multimodal features."""
from dataclasses import dataclass
import numpy as np
from PIL import Image
from maxtext.multimodal import utils as mm_utils
# Constants for Gemma4-specific processing.
# The rectangular 672 x 960 (height x width) target splits into
# (672 // 16) * (960 // 16) = 42 * 60 = 2520 patches; dividing by the 3 x 3
# pooling window (9 patches per token) leaves exactly 280 tokens per image.
GEMMA4_IMAGE_HEIGHT = 672
GEMMA4_IMAGE_WIDTH = 960
GEMMA4_PATCH_SIZE = 16
GEMMA4_POOLING_KERNEL = 3
# Image marker string for text prompts; not referenced by the code visible in
# this part of the file.
GEMMA4_IMAGE_PLACEHOLDER_IN_PROMPT = "<|image|>"
# Special token ids. NOTE(review): the meanings are taken from the constant
# names only -- confirm the values against the Gemma4 tokenizer vocabulary.
GEMMA4_BEGIN_IMAGE_TOKEN = 255999
GEMMA4_END_IMAGE_TOKEN = 258882
GEMMA4_NEW_LINE_TOKEN = 108
GEMMA4_TOKEN_PLACEHOLDER = 258880
# 2 extra tokens to pad around image: <begin_image>, <end_of_image>
GEMMA4_NUM_EXTRA_TOKENS_PER_MEDIA = 2
[docs]
@dataclass
class Gemma4PreprocessorOutput(mm_utils.PreprocessorOutput):
  """The output of Gemma4 image preprocessor.

  Every field has a default, so an instance can be built with only the fields
  a given preprocessing path actually produces.

  Attributes:
    num_images: Number of images that were preprocessed.
    pixel_values: Preprocessed image data. `preprocess_mm_data_gemma4` stores a
      float32 array here with a singleton axis inserted at position 1, i.e.
      (num_images, 1, ...) followed by the per-image axes.
    pixel_mask: Not populated by `preprocess_mm_data_gemma4`; presumably a
      validity mask over `pixel_values` -- TODO confirm with the consumer.
    positions_xy: Not populated by `preprocess_mm_data_gemma4`; presumably
      per-patch (x, y) coordinates -- TODO confirm with the consumer.
  """

  num_images: int = 0
  pixel_values: None | np.ndarray = None
  pixel_mask: None | np.ndarray = None
  positions_xy: None | np.ndarray = None
[docs]
def preprocess_mm_data_gemma4(images):
  """Resizes and rescales raw images into the layout Gemma4 consumes.

  Args:
    images: Either a single image as a NumPy array, or an iterable of such
      arrays. Each array must be accepted by `PIL.Image.fromarray`.

  Returns:
    A `Gemma4PreprocessorOutput` whose `num_images` is the number of input
    images and whose `pixel_values` is a float32 array with values divided by
    255 and a singleton axis inserted at position 1: (num_images, 1, ...).
  """
  # A lone ndarray is treated as exactly one image; anything else is iterated.
  raw_images = [images] if isinstance(images, np.ndarray) else list(images)

  # PIL's resize takes (width, height), the reverse of NumPy's (rows, cols).
  resize_to = (GEMMA4_IMAGE_WIDTH, GEMMA4_IMAGE_HEIGHT)

  rescaled = []
  for raw in raw_images:
    resized = Image.fromarray(raw).resize(resize_to, resample=Image.Resampling.BICUBIC)
    # Gemma 4 expects inputs strictly in the [0, 1] range.
    rescaled.append(np.asarray(resized, dtype=np.float32) / 255.0)

  batch = np.stack(rescaled, axis=0).astype(np.float32)
  return Gemma4PreprocessorOutput(
      num_images=len(raw_images),
      pixel_values=batch[:, np.newaxis, ...],
  )
[docs]
def get_image_offsets_gemma4(processor_output: mm_utils.PreprocessorOutput | None):
  """Returns how many tokens the prompt grows by once images are expanded.

  Args:
    processor_output: Preprocessor result, or None. The image count is read
      from the leading axis of its `pixel_values`.

  Returns:
    The per-image token increase multiplied by the number of images. When
    `processor_output` or its `pixel_values` is None, one image is assumed.
  """
  if processor_output is None or processor_output.pixel_values is None:
    image_count = 1
  else:
    image_count = processor_output.pixel_values.shape[0]

  # Patch grid of the fixed-size image, then reduced by the 3x3 pooling window.
  patch_rows = GEMMA4_IMAGE_HEIGHT // GEMMA4_PATCH_SIZE
  patch_cols = GEMMA4_IMAGE_WIDTH // GEMMA4_PATCH_SIZE
  soft_tokens_per_image = (patch_rows * patch_cols) // (GEMMA4_POOLING_KERNEL**2)

  # One token is subtracted per image -- presumably the placeholder token that
  # is already present in the prompt before expansion (verify against caller).
  added_per_image = soft_tokens_per_image + GEMMA4_NUM_EXTRA_TOKENS_PER_MEDIA - 1
  return added_per_image * image_count
def _get_new_text_positions(
    *,
    offset_on: np.ndarray,
    offset_by: int,
) -> np.ndarray:
  """Maps each original index to its index after trigger slots are expanded.

  Args:
    offset_on: Mask over the last axis marking the trigger slots.
    offset_by: Number of extra slots each trigger introduces.

  Returns:
    Integer array shaped like `offset_on`. Every index is shifted right by
    `offset_by` for each trigger strictly before it, so a trigger itself is
    not shifted by its own expansion.
  """
  # Inclusive running count minus the mask gives the count of triggers that
  # occur strictly before each slot.
  triggers_before = np.cumsum(offset_on, axis=-1) - offset_on
  original_index = np.arange(offset_on.shape[-1])
  return original_index + triggers_before * offset_by
[docs]
def insert_sequence(
    tokens: np.ndarray,
    *,
    at: int,
    sequence: list[int],
    max_num_images: int,
) -> np.ndarray:
  """Expands every occurrence of the token `at` into the tokens of `sequence`.

  The trigger slot itself receives `sequence[0]` and the remaining
  `len(sequence) - 1` tokens are spliced in right after it, pushing the
  following text to the right. The output length is fixed up front from
  `max_num_images` so that all rows share one shape; slots that end up unused
  keep the value 0.

  Args:
    tokens: Token ids, either 1-D `(length,)` or 2-D `(batch, length)`.
    at: Token id that marks where `sequence` is inserted.
    sequence: Token ids written at each marker. Expected to be non-empty.
    max_num_images: Number of markers per row the output is sized for. Each
      row is expected to contain at most this many markers.

  Returns:
    An int64 array with the same rank as `tokens` and a last axis of length
    `length + max_num_images * (len(sequence) - 1)`.
  """
  was_1d = tokens.ndim == 1
  if was_1d:
    tokens = tokens[None, :]
  batch_size, length = tokens.shape

  mm_tokens_to_insert = np.array(sequence)
  # The trigger slot is reused for sequence[0], so each marker only adds
  # len(sequence) - 1 new slots.
  offset_by = len(mm_tokens_to_insert) - 1
  length_with_mm = length + max_num_images * offset_by

  # Boolean mask where the image trigger token `at` is present.
  mm_start = tokens == at

  # Scatter the original tokens to their shifted positions.
  new_text_pos = _get_new_text_positions(offset_on=mm_start, offset_by=offset_by)
  new_tokens = np.zeros((batch_size, length_with_mm), dtype=np.int64)
  np.put_along_axis(new_tokens, new_text_pos, tokens, axis=1)

  batch_indices, _ = np.nonzero(mm_start)
  if batch_indices.size:
    # A trigger's shifted position is exactly where its expansion begins, so
    # the scattered trigger token is overwritten by sequence[0] here and no
    # separate clearing pass is needed.
    start_pos = new_text_pos[mm_start]
    indices_to_insert = start_pos[:, None] + np.arange(len(mm_tokens_to_insert))
    new_tokens[batch_indices[:, None], indices_to_insert] = mm_tokens_to_insert

  if was_1d:
    # Drop only the batch axis added above; an axis-less squeeze would also
    # collapse a length-1 sequence axis into a 0-d array.
    new_tokens = new_tokens[0]
  return new_tokens
[docs]
def get_dummy_image_shape_for_init_gemma4(batch_size=1, num_image_per_sequence=1):
  """Builds the 5-element dummy image shape used to initialize a Gemma4 model.

  Args:
    batch_size: Value placed in the first position of the shape.
    num_image_per_sequence: Value placed in the second position of the shape.

  Returns:
    The tuple `(batch_size, num_image_per_sequence, GEMMA4_IMAGE_HEIGHT,
    GEMMA4_IMAGE_WIDTH, mm_utils.NUM_IMAGE_CHANNELS)`.
  """
  per_image_shape = (GEMMA4_IMAGE_HEIGHT, GEMMA4_IMAGE_WIDTH, mm_utils.NUM_IMAGE_CHANNELS)
  return (batch_size, num_image_per_sequence, *per_image_shape)