maxtext.models.gemma4_vision module

Vision transformer implementation for Gemma4.

maxtext.models.gemma4_vision.factorized_posemb(posemb, positions_xy, precision)

Computes factorized position embedding from (x, y) coordinates.

Parameters:
  • posemb (Array) – The factorized position embedding parameters.

  • positions_xy (Array) – The (x, y) coordinates for each patch.

  • precision – The precision for the einsum operation.

Returns:

The computed position embeddings.

Return type:

Array
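
As a rough illustration of the idea, a factorized position embedding can be computed by looking up one embedding per axis and summing them. The NumPy sketch below is not the MaxText implementation: the name factorized_posemb_sketch and the table layout [2, max_positions, D] (one table for x, one for y) are assumptions made for this example, and the precision argument is omitted.

```python
import numpy as np


def factorized_posemb_sketch(posemb, positions_xy):
    """Sums per-axis embeddings selected by integer (x, y) coordinates.

    Assumed shapes (hypothetical, for illustration only):
      posemb:       [2, max_positions, D], axis 0 selects the x or y table.
      positions_xy: [..., N, 2] integer coordinates.
    Returns an array of shape [..., N, D].
    """
    max_positions = posemb.shape[1]
    # One-hot encode each coordinate: [..., N, 2, max_positions].
    one_hot = np.eye(max_positions, dtype=posemb.dtype)[positions_xy]
    # Contract over the axis index (a) and the position index (p), which
    # adds the x embedding and the y embedding for every patch.
    return np.einsum("...nap,apd->...nd", one_hot, posemb)


table = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
coords = np.array([[0, 1], [3, 2]])
emb = factorized_posemb_sketch(table, coords)
```

The one-hot plus einsum form mirrors the fact that the documented function takes an einsum precision; a plain gather (posemb[0][x] + posemb[1][y]) gives the same values.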

maxtext.models.gemma4_vision.patchify(images, patch_size)

Patchifies images and returns patches and (x, y) coordinates.

Parameters:
  • images (Array) – The input images of shape […, H, W, C].

  • patch_size (int) – The size of each square patch.

Returns:

A tuple containing:

  • patches: The extracted patches of shape […, num_patches, patch_size * patch_size * C].

  • positions_xy: The (x, y) coordinates of the top-left corner of each patch, of shape […, num_patches, 2].

Return type:

tuple
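
The shapes described above can be reproduced with a short NumPy sketch. This is an illustration under stated assumptions rather than the library code: patchify_sketch is a hypothetical name, H and W are assumed to be divisible by patch_size, patches are emitted in row-major order, and coordinates are expressed in patch-grid units (the real function may use a different unit or ordering).

```python
import numpy as np


def patchify_sketch(images, patch_size):
    """Splits [..., H, W, C] images into flattened square patches.

    Returns (patches, positions_xy) with shapes
    [..., num_patches, patch_size * patch_size * C] and [..., num_patches, 2].
    """
    *lead, h, w, c = images.shape
    if h % patch_size or w % patch_size:
        raise ValueError("H and W must be divisible by patch_size")
    grid_h, grid_w = h // patch_size, w // patch_size
    n = len(lead)
    # [..., grid_h, p, grid_w, p, C] -> [..., grid_h, grid_w, p, p, C]
    x = images.reshape(*lead, grid_h, patch_size, grid_w, patch_size, c)
    x = np.swapaxes(x, n + 1, n + 2)
    patches = x.reshape(*lead, grid_h * grid_w, patch_size * patch_size * c)
    # Row-major (x, y) grid coordinates, assumed here to be in patch units.
    ys, xs = np.meshgrid(np.arange(grid_h), np.arange(grid_w), indexing="ij")
    positions = np.stack([xs.ravel(), ys.ravel()], axis=-1)
    positions = np.broadcast_to(positions, (*lead, grid_h * grid_w, 2))
    return patches, positions


demo_images = np.arange(4 * 6).reshape(1, 4, 6, 1)
demo_patches, demo_positions = patchify_sketch(demo_images, patch_size=2)
```

For a 4x6 single-channel image and patch_size=2, this yields 6 patches of 4 values each, with the second patch at grid position (x=1, y=0).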

class maxtext.models.gemma4_vision.VisionEntry(*args, **kwargs)

Bases: Module

The vision entry layer.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

maxtext.models.gemma4_vision.apply_multidimensional_rope(inputs, positions, *, base_frequency, rotary_fraction=None, scale_factor=1.0)

Applies multidimensional rotary position embedding (RoPE), based on the Gemma 4 implementation.

Parameters:
  • inputs (Array) – The input array to apply RoPE to.

  • positions (Array) – The positional information. Can be 1D or ND.

  • base_frequency (int) – The base frequency for the sinusoidal functions.

  • rotary_fraction (float | None) – The fraction of the hidden dimension to apply RoPE to. If None, applies to the full dimension.

  • scale_factor (float) – A scale factor applied to the sinusoidal arguments.

Returns:

The input array with multidimensional RoPE applied.

Return type:

Array
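
One common way to extend RoPE to several position axes is to split the feature dimension into one chunk per axis and rotate each chunk with that axis's coordinate. The NumPy sketch below illustrates that general technique only; it is not taken from the MaxText source. The helper names, the [B, L, D] input layout, and the half-split rotation convention are assumptions, and rotary_fraction and scale_factor are left out.

```python
import numpy as np


def rope_1d_sketch(x, pos, base_frequency):
    """Standard 1D RoPE on x: [..., L, D] (D even) with positions pos: [..., L]."""
    half = x.shape[-1] // 2
    # Geometric progression of timescales, one per rotated pair.
    timescale = base_frequency ** (np.arange(half) / half)
    angles = pos[..., None] / timescale  # [..., L, half]
    sin, cos = np.sin(angles), np.cos(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)


def multidimensional_rope_sketch(inputs, positions, base_frequency=10_000):
    """Rotates one equal-sized feature chunk per position axis.

    Assumed shapes: inputs [B, L, D], positions [B, L, ndim], with D
    divisible by 2 * ndim.
    """
    ndim = positions.shape[-1]
    chunks = np.split(inputs, ndim, axis=-1)
    rotated = [
        rope_1d_sketch(chunk, positions[..., axis], base_frequency)
        for axis, chunk in enumerate(chunks)
    ]
    return np.concatenate(rotated, axis=-1)


rope_in = np.random.default_rng(0).normal(size=(1, 3, 8))
rope_pos = np.array([[[0, 0], [1, 2], [3, 1]]])
rope_out = multidimensional_rope_sketch(rope_in, rope_pos, base_frequency=100)
```

Because each pair of features is only rotated, the per-token norm is unchanged, and a token at position (0, 0) passes through unmodified.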

maxtext.models.gemma4_vision.avg_pool_by_positions(x, *, positions_xy, length, precision)

Performs 2D spatial pooling according to patch positions.

Parameters:
  • x (Array) – The input features of shape [B, L, D].

  • positions_xy (Array) – The (x, y) coordinates of each patch of shape [B, L, 2].

  • length (int) – The desired output sequence length after pooling.

  • precision – The precision for the einsum operation.

Returns:

A tuple containing:

  • output: The pooled features of shape [B, length, D].

  • mask: A boolean mask indicating valid pooled positions.

Return type:

tuple
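
A position-driven average pool can be sketched by mapping each patch to an output cell and averaging with a one-hot weight matrix. The NumPy code below is a simplified illustration, not the MaxText implementation: avg_pool_by_positions_sketch is a hypothetical name, the pooling window is assumed to be square with side k where k * k * length == L, every patch is assumed valid (no padding), and the precision argument is omitted.

```python
import numpy as np


def avg_pool_by_positions_sketch(x, positions_xy, length):
    """Averages k x k neighbourhoods of patches, located via positions_xy.

    x: [B, L, D] features; positions_xy: [B, L, 2] integer grid coordinates.
    Returns (output [B, length, D], mask [B, length]).
    """
    _, seq_len, _ = x.shape
    k = int(round((seq_len // length) ** 0.5))  # side of the pooling window
    if k * k * length != seq_len:
        raise ValueError("L must equal k * k * length for an integer k")
    # Grid width per example, inferred from the largest x coordinate.
    grid_w = positions_xy[..., 0].max(axis=-1, keepdims=True) + 1  # [B, 1]
    cell_xy = positions_xy // k  # output cell of each patch, [B, L, 2]
    cell_idx = cell_xy[..., 0] + (grid_w // k) * cell_xy[..., 1]  # [B, L]
    one_hot = np.eye(length, dtype=x.dtype)[cell_idx]  # [B, L, length]
    # Each cell receives k * k patches, so dividing by k * k gives the mean.
    output = np.einsum("bnm,bnd->bmd", one_hot / (k * k), x)
    mask = one_hot.any(axis=1)  # True where a cell received any patch
    return output, mask


idx = np.arange(16)
pool_pos = np.stack([idx % 4, idx // 4], axis=-1)[None]  # 4x4 grid, [1, 16, 2]
pool_x = idx.astype(np.float32).reshape(1, 16, 1)
pooled, pooled_mask = avg_pool_by_positions_sketch(pool_x, pool_pos, length=4)
```

On a 4x4 grid pooled to length 4, the first output cell averages the patches with indices 0, 1, 4 and 5.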

class maxtext.models.gemma4_vision.VisionExit(*args, **kwargs)

Bases: Module

Vision exit layer with scaling and optional spatial pooling.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.gemma4_vision.Gemma4VisionRotaryEmbedding(*args, **kwargs)

Bases: Module

Rotary position embedding for Gemma 4 vision.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.gemma4_vision.Gemma4Attention(*args, **kwargs)

Bases: Attention

Gemma 4 specific Attention module.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

init_rotary_embedding()

Initializes the rotary position embedding module for Gemma 4 vision.

Return type:

Gemma4VisionRotaryEmbedding

class maxtext.models.gemma4_vision.Gemma4EncoderBlock(*args, **kwargs)

Bases: Module

A single transformer encoder block: multi-head self-attention (MHSA) plus an MLP.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.gemma4_vision.Gemma4VisionEncoderLayer(*args, **kwargs)

Bases: Module

Gemma 4 Vision Encoder Layer.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.gemma4_vision.Gemma4VisionProjector(*args, **kwargs)

Bases: Module

A layer that projects image embeddings to the embedding space of the text encoder.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

maxtext.models.gemma4_vision.gemma4_vision_encoder_as_linen(config, mesh)

Wraps the Gemma 4 Vision Encoder as a Linen module.

Parameters:
  • config (Any)

  • mesh (Mesh)

Return type:

Module