maxtext.models.gemma3 module

Specialised layers for Gemma 3.

maxtext.models.gemma3.get_attention_type(layer_id)
maxtext.models.gemma3.get_query_pre_attn_scalar(config)

Returns the scalar to multiply the query by before attention.

Return type:

float
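
To illustrate how such a scalar is typically applied, here is a minimal sketch in JAX. It does not reproduce how `get_query_pre_attn_scalar` derives its value from `config`; the value `head_dim ** -0.5` is the conventional scaled dot-product choice and is an assumption here, as are the helper name `scaled_attention_logits` and the tensor layout.

```python
import jax.numpy as jnp


def scaled_attention_logits(query, key, query_pre_attn_scalar):
    """Scale the query, then compute raw attention logits.

    query: [batch, q_len, heads, head_dim]
    key:   [batch, kv_len, heads, head_dim]
    Returns logits of shape [batch, heads, q_len, kv_len].
    """
    # The scalar is applied to the query before the dot product with the keys.
    query = query * query_pre_attn_scalar
    return jnp.einsum("bqhd,bkhd->bhqk", query, key)


head_dim = 64
# Assumption: the conventional 1/sqrt(head_dim) scaling, not a value read from a MaxText config.
scalar = head_dim ** -0.5

q = jnp.ones((1, 2, 4, head_dim))
k = jnp.ones((1, 3, 4, head_dim))
logits = scaled_attention_logits(q, k, scalar)
```

Scaling the query rather than the logits gives the same result mathematically, but touches a smaller tensor when the key/value sequence is long.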

class maxtext.models.gemma3.Gemma3DecoderLayer(*args, **kwargs)

Bases: Module

Transformer decoder layer for Gemma3.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.gemma3.Gemma3ScannableBlock(*args, **kwargs)

Bases: Module

A repeatable block of Gemma3 decoder layers.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.gemma3.MlpBlockViT(*args, **kwargs)

Bases: Module

NNX version of Transformer MLP / feed-forward block.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.gemma3.Encoder1DBlock(*args, **kwargs)

Bases: Module

Single transformer encoder block (MHSA + MLP).

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.gemma3.Encoder(*args, **kwargs)

Bases: Module

Transformer Model Encoder for sequence to sequence translation.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

class maxtext.models.gemma3.Einsum(*args, **kwargs)

Bases: Module

A convenience module for parameterized tensor multiplication: it holds a learnable weight and contracts it with the input via an einsum equation.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any
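
The idea of a parameterized einsum can be sketched in plain JAX as follows. This is not the `Einsum` class itself: the helper names `init_einsum_weight` and `einsum_apply`, the initializer scale, and the equation string are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def init_einsum_weight(rng, shape, dtype=jnp.float32):
    """Create the learnable weight (hypothetical initializer: small normal noise)."""
    return jax.random.normal(rng, shape, dtype) * 0.02


def einsum_apply(eqn, x, w):
    """Contract the input x with the weight w according to the einsum equation."""
    return jnp.einsum(eqn, x, w)


rng = jax.random.PRNGKey(0)
w = init_einsum_weight(rng, (8, 16))   # maps feature size 8 to 16
x = jnp.ones((2, 5, 8))                # [batch, tokens, features]
y = einsum_apply("btd,df->btf", x, w)  # [batch, tokens, 16]
```

Expressing the projection as an einsum equation keeps the axis roles explicit, which helps when the weight has more than two dimensions (for example, a per-head projection).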

class maxtext.models.gemma3.VisionEmbedder(*args, **kwargs)

Bases: Module

Projects image embeddings to the embedding space of the text encoder.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

maxtext.models.gemma3.visionembedder_as_linen(config, mesh)

Creates a VisionEmbedder module.

Parameters:
  • config (Any)

  • mesh (Mesh)

class maxtext.models.gemma3.VisionExit(*args, **kwargs)

Bases: Module

The vision exit layer.

Optionally downsamples the soft tokens to a required output length.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

output_length

The embeddings are spatially average-pooled to this output length.
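
Spatial average pooling of soft tokens can be sketched as below. This is an illustration of the technique and not the `VisionExit` implementation: it assumes the tokens form a square grid in row-major order, that `output_length` is a perfect square, and that the grid width divides evenly. The function name `avg_pool_soft_tokens` is hypothetical.

```python
import math

import jax.numpy as jnp


def avg_pool_soft_tokens(x, output_length):
    """Average-pool a square grid of tokens down to output_length tokens.

    x: [batch, cur_length, dim], with tokens laid out row-major on a square grid.
    """
    batch, cur_length, dim = x.shape
    if cur_length == output_length:
        return x  # nothing to downsample

    cur_width = math.isqrt(cur_length)
    out_width = math.isqrt(output_length)
    # Assumptions of this sketch: square grids and an integer pooling window.
    assert cur_width * cur_width == cur_length
    assert out_width * out_width == output_length
    assert cur_width % out_width == 0
    window = cur_width // out_width

    # Split each spatial axis into (output position, offset within window),
    # then average over the two within-window axes.
    x = x.reshape(batch, out_width, window, out_width, window, dim)
    x = x.mean(axis=(2, 4))
    return x.reshape(batch, output_length, dim)


# A 4x4 grid of scalar tokens pooled to a 2x2 grid.
tokens = jnp.arange(16, dtype=jnp.float32).reshape(1, 16, 1)
pooled = avg_pool_soft_tokens(tokens, 4)
```

In this example each output token is the mean of one non-overlapping 2x2 patch of the input grid.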

maxtext.models.gemma3.vision_exit_as_linen(x, output_length)

A wrapper to use VisionExit as a function.

Parameters:
  • x (Array)

  • output_length (int)

Return type:

Array

class maxtext.models.gemma3.Gemma3VisionEncoderLayer(*args, **kwargs)

Bases: Module

Gemma 3 vision encoder layer.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

maxtext.models.gemma3.gemma3visionencoder_as_linen(config, mesh)

Creates a Gemma3VisionEncoder module.

Parameters:
  • config (Any)

  • mesh (Mesh)