maxtext.models.gemma4_vision module
Vision transformer implementation for Gemma 4.
- maxtext.models.gemma4_vision.factorized_posemb(posemb, positions_xy, precision)
Computes a factorized position embedding from the (x, y) coordinates of each patch.
- Parameters:
posemb (Array) – The factorized position embedding parameters.
positions_xy (Array) – The (x, y) coordinates for each patch.
precision – The precision for the einsum operation.
- Returns:
The computed position embeddings.
- Return type:
Array
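A minimal NumPy sketch of how a factorized 2D position embedding can be computed. It assumes (hypothetically) that `posemb` has shape `[2, max_len, D]`, holding one table for x and one for y, and that each patch's embedding is the sum of its x row and its y row. The lookup is written as a one-hot einsum, which is where an einsum `precision` argument would apply in a JAX version (`jnp.einsum` accepts `precision=`). The name `factorized_posemb_sketch` is illustrative and not part of MaxText.

```python
import numpy as np

def factorized_posemb_sketch(posemb, positions_xy):
    """Hypothetical sketch: x-embedding + y-embedding via one-hot einsum.

    Assumes posemb has shape [2, max_len, D] (row 0: x table, row 1: y table)
    and positions_xy has shape [..., L, 2] with integers in [0, max_len).
    """
    max_len = posemb.shape[1]
    # One-hot encode both coordinates: [..., L, 2, max_len]
    one_hot = np.eye(max_len, dtype=posemb.dtype)[positions_xy]
    # Contract each coordinate against its own table (c) and sum over x and y.
    return np.einsum("...lcp,cpd->...ld", one_hot, posemb)

posemb = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)
emb = factorized_posemb_sketch(posemb, np.array([[1, 2], [3, 0]]))
```

Factorizing into separate x and y tables needs `2 * max_len * D` parameters rather than `max_len**2 * D` for a full 2D table.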
- maxtext.models.gemma4_vision.patchify(images, patch_size)
Splits images into non-overlapping square patches and returns the flattened patches together with their (x, y) coordinates.
- Parameters:
images (Array) – The input images of shape […, H, W, C].
patch_size (int) – The size of each square patch.
- Returns:
A tuple containing:
patches: The extracted patches of shape […, num_patches, patch_size * patch_size * C].
positions_xy: The (x, y) coordinates of the top-left corner of each patch, of shape […, num_patches, 2].
- Return type:
tuple[Array, Array]
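A minimal NumPy sketch of patchification. It assumes that H and W are divisible by `patch_size`, that pixels inside a patch are flattened in (row, column, channel) order, that patches are ordered row-major over the grid, and that coordinates are reported in patch units (multiply by `patch_size` for pixel offsets). All of these conventions are assumptions, and `patchify_sketch` is a hypothetical name.

```python
import numpy as np

def patchify_sketch(images, patch_size):
    """Hypothetical sketch: [..., H, W, C] -> ([..., N, p*p*C], [..., N, 2])."""
    *batch, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size  # assumes exact divisibility
    x = images.reshape(*batch, gh, patch_size, gw, patch_size, c)
    # Bring the two grid axes together: [..., gh, gw, p, p, C]
    n = len(batch)
    x = x.transpose(list(range(n)) + [n, n + 2, n + 1, n + 3, n + 4])
    patches = x.reshape(*batch, gh * gw, patch_size * patch_size * c)
    # (x, y) = (column, row) of each patch in the grid, in row-major order.
    ys, xs = np.meshgrid(np.arange(gh), np.arange(gw), indexing="ij")
    positions_xy = np.stack([xs.ravel(), ys.ravel()], axis=-1)
    positions_xy = np.broadcast_to(positions_xy, (*batch, gh * gw, 2))
    return patches, positions_xy

imgs = np.arange(2 * 4 * 6 * 3).reshape(2, 4, 6, 3)  # batch of 2, 4x6 RGB
patches, pos = patchify_sketch(imgs, 2)  # 2x3 grid of 2x2 patches
```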
- class maxtext.models.gemma4_vision.VisionEntry(*args, **kwargs)
Bases: Module
The vision entry layer.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- maxtext.models.gemma4_vision.apply_multidimensional_rope(inputs, positions, *, base_frequency, rotary_fraction=None, scale_factor=1.0)
Applies multidimensional rotary position embedding (RoPE), based on the Gemma 4 implementation.
- Parameters:
inputs (Array) – The input array to apply RoPE to.
positions (Array) – The positional information. Can be 1D or ND.
base_frequency (int) – The base frequency for the sinusoidal functions.
rotary_fraction (float | None) – The fraction of the hidden dimension to apply RoPE to. If None, applies to the full dimension.
scale_factor (float) – A scale factor applied to the sinusoidal arguments.
- Returns:
The input array with multidimensional RoPE applied.
- Return type:
Array
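A minimal NumPy sketch of one common way to extend RoPE to several position axes. It assumes `inputs` has shape `[..., L, D]` and `positions` has shape `[..., L]` (1D) or `[..., L, K]` (K axes, e.g. x and y). The rotary channels (all of D, or `D * rotary_fraction`) are split evenly among the K axes, each slice is rotated by standard 1D RoPE with rotate-half pairing, and the remaining channels pass through unchanged. Dividing the angles by `scale_factor` is an assumed convention. The split scheme and the names `rope_1d` and `multidim_rope_sketch` are hypothetical and may differ from the MaxText code.

```python
import numpy as np

def rope_1d(x, pos, base_frequency, scale_factor):
    """Standard rotate-half RoPE on the last axis of x ([..., L, d], d even)."""
    half = x.shape[-1] // 2
    timescale = base_frequency ** (2 * np.arange(half) / x.shape[-1])
    angles = pos[..., None] / timescale / scale_factor  # [..., L, half]
    sin, cos = np.sin(angles), np.cos(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

def multidim_rope_sketch(inputs, positions, *, base_frequency,
                         rotary_fraction=None, scale_factor=1.0):
    """Hypothetical sketch: give each position axis its own channel slice."""
    if positions.ndim == inputs.ndim - 1:
        positions = positions[..., None]  # 1D positions: treat as K = 1
    k = positions.shape[-1]
    d = inputs.shape[-1]
    rot_d = d if rotary_fraction is None else int(d * rotary_fraction)
    chunk = rot_d // k
    if chunk % 2:
        raise ValueError("each position axis needs an even number of channels")
    outs = [
        rope_1d(inputs[..., i * chunk:(i + 1) * chunk], positions[..., i],
                base_frequency, scale_factor)
        for i in range(k)
    ]
    outs.append(inputs[..., k * chunk:])  # non-rotary channels pass through
    return np.concatenate(outs, axis=-1)

x = np.random.default_rng(0).normal(size=(5, 8))
pos_xy = np.stack([np.arange(5), np.arange(5)[::-1]], axis=-1)  # [L, 2]
y = multidim_rope_sketch(x, pos_xy, base_frequency=10_000)
```

Because each axis owns a separate channel slice, the contribution of that slice to a query-key dot product depends only on the two patches' offset along that axis.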
- maxtext.models.gemma4_vision.avg_pool_by_positions(x, *, positions_xy, length, precision)
Average-pools patch features over 2D spatial neighborhoods determined by each patch's (x, y) position.
- Parameters:
x (Array) – The input features of shape [B, L, D].
positions_xy (Array) – The (x, y) coordinates of each patch of shape [B, L, 2].
length (int) – The desired output sequence length after pooling.
precision – The precision for the einsum operation.
- Returns:
A tuple containing:
output: The pooled features of shape [B, length, D].
mask: A boolean mask indicating valid pooled positions.
- Return type:
tuple[Array, Array]
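A minimal NumPy sketch of position-based average pooling. It assumes that `length` is a perfect square (a square output grid of side `sqrt(length)`), that the input patches cover a square grid whose side is inferred from the largest coordinate, that every position is a valid patch (no padding), and that each output cell averages the patches falling into a `kernel x kernel` window. The mask marks output cells that received at least one patch. The binning rule and the name `avg_pool_by_positions_sketch` are hypothetical; in a JAX version the einsum would be where `precision` is passed.

```python
import numpy as np

def avg_pool_by_positions_sketch(x, *, positions_xy, length):
    """Hypothetical sketch: x [B, L, D], positions_xy [B, L, 2] (int)."""
    out_side = int(round(np.sqrt(length)))
    if out_side * out_side != length:
        raise ValueError("this sketch assumes a square output grid")
    grid_side = int(positions_xy.max()) + 1
    kernel = -(-grid_side // out_side)  # ceiling division
    # Output cell of each patch, flattened row-major: [B, L]
    bins = (positions_xy[..., 1] // kernel) * out_side + positions_xy[..., 0] // kernel
    one_hot = np.eye(length, dtype=x.dtype)[bins]  # [B, L, length]
    counts = one_hot.sum(axis=1)  # patches per output cell: [B, length]
    sums = np.einsum("bls,bld->bsd", one_hot, x)
    mask = counts > 0
    output = sums / np.maximum(counts, 1)[..., None]  # empty cells stay 0
    return output, mask

ys, xs = np.meshgrid(np.arange(4), np.arange(4), indexing="ij")
grid_pos = np.stack([xs.ravel(), ys.ravel()], axis=-1)[None]  # [1, 16, 2]
feats = np.arange(16, dtype=float).reshape(1, 16, 1)
pooled, valid = avg_pool_by_positions_sketch(feats, positions_xy=grid_pos, length=4)
```

Here a 4x4 grid is pooled to 2x2, so each output cell averages a 2x2 block of patches.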
- class maxtext.models.gemma4_vision.VisionExit(*args, **kwargs)
Bases: Module
Vision exit layer with scaling and optional spatial pooling.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- class maxtext.models.gemma4_vision.Gemma4VisionRotaryEmbedding(*args, **kwargs)
Bases: Module
Rotary position embedding for Gemma 4 vision.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- class maxtext.models.gemma4_vision.Gemma4Attention(*args, **kwargs)
Bases: Attention
Gemma 4 specific Attention module.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- class maxtext.models.gemma4_vision.Gemma4EncoderBlock(*args, **kwargs)
Bases: Module
Single transformer encoder block: multi-head self-attention (MHSA) followed by an MLP.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
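A minimal NumPy sketch of a generic pre-norm transformer encoder block (MHSA then MLP, each wrapped in a residual connection). The choices here are assumptions for illustration, not necessarily what Gemma 4 uses: RMSNorm without a learned scale, tanh-approximated GELU, no biases, no positional term, and a single unbatched sequence `x` of shape `[L, D]`. The weight-dictionary keys (`wq`, `wk`, `wv`, `wo`, `w_in`, `w_out`) and function names are hypothetical.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def encoder_block_sketch(x, p, num_heads):
    """Hypothetical pre-norm block: x + MHSA(norm(x)), then x + MLP(norm(x))."""
    L, D = x.shape
    hd = D // num_heads
    h = rms_norm(x)
    q = (h @ p["wq"]).reshape(L, num_heads, hd)
    k = (h @ p["wk"]).reshape(L, num_heads, hd)
    v = (h @ p["wv"]).reshape(L, num_heads, hd)
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(hd)  # [H, L, L]
    attn = np.einsum("hqk,khd->qhd", softmax(scores), v).reshape(L, D)
    x = x + attn @ p["wo"]  # residual around attention
    h = rms_norm(x)
    return x + gelu(h @ p["w_in"]) @ p["w_out"]  # residual around MLP

rng = np.random.default_rng(0)
D = 8
params = {n: rng.normal(size=(D, D)) * 0.1 for n in ("wq", "wk", "wv", "wo")}
params["w_in"] = rng.normal(size=(D, 16)) * 0.1
params["w_out"] = rng.normal(size=(16, D)) * 0.1
tokens = rng.normal(size=(5, D))
out = encoder_block_sketch(tokens, params, num_heads=2)
```

Without a positional term the block is permutation-equivariant over tokens, which is why vision encoders inject positions separately (for example through position embeddings or RoPE).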
- class maxtext.models.gemma4_vision.Gemma4VisionEncoderLayer(*args, **kwargs)
Bases: Module
Gemma 4 Vision Encoder Layer.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any