maxtext.layers.decoders module

Module for decoder layers.

class maxtext.layers.decoders.DecoderLayer(config, mesh, model_mode, quant=None, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Transformer decoder layer that attends to the encoder. This is the core, reusable building block for both the main model’s decoder stack and the auxiliary MTP layers.
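
To illustrate the kind of computation such a building block performs, here is a minimal NumPy sketch of a generic pre-norm decoder block: single-head causal self-attention followed by a ReLU MLP, each wrapped in a residual connection. It is not MaxText's implementation. The parameter names, the single head, the ReLU activation, and the pre-norm RMSNorm layout are assumptions chosen for brevity, and the sketch contains no encoder cross-attention.

```python
import numpy as np


def rms_norm(x, scale, eps=1e-6):
    # Divide each feature vector by its root-mean-square, then apply a learned scale.
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps) * scale


def causal_self_attention(x, wq, wk, wv, wo):
    # x: [batch, seq, d_model]; a single attention head for brevity.
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    seq_len = x.shape[1]
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(causal, scores, -1e9)  # block attention to future tokens
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return (probs @ v) @ wo


def init_params(rng, d_model, d_ff):
    # Hypothetical parameter layout, randomly initialized for demonstration only.
    def w(*shape):
        return rng.normal(scale=0.1, size=shape)

    return {
        "attn_norm": np.ones(d_model),
        "wq": w(d_model, d_model),
        "wk": w(d_model, d_model),
        "wv": w(d_model, d_model),
        "wo": w(d_model, d_model),
        "mlp_norm": np.ones(d_model),
        "w_in": w(d_model, d_ff),
        "w_out": w(d_ff, d_model),
    }


def decoder_layer(x, p):
    # Pre-norm residual block: h = x + Attn(Norm(x)), then h + MLP(Norm(h)).
    attn_out = causal_self_attention(rms_norm(x, p["attn_norm"]), p["wq"], p["wk"], p["wv"], p["wo"])
    h = x + attn_out
    mlp_hidden = np.maximum(rms_norm(h, p["mlp_norm"]) @ p["w_in"], 0.0)  # ReLU
    return h + mlp_hidden @ p["w_out"]


rng = np.random.default_rng(0)
params = init_params(rng, d_model=8, d_ff=16)
x = rng.normal(size=(2, 5, 8))
y = decoder_layer(x, params)  # same shape as x
```

Because of the causal mask, changing the last token leaves the outputs at all earlier positions unchanged, which is what lets a stack of these layers be used for autoregressive decoding.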

Parameters:
  • config (Any)

  • mesh (Mesh)

  • model_mode (str)

  • quant (None | AqtQuantization)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

config: Any
mesh: Mesh
model_mode: str
quant: None | AqtQuantization = None
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None

class maxtext.layers.decoders.SequentialBlockDecoderLayers(decoder_layer, num_decoder_layers, config, mesh, quant, model_mode, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Sequential unscanned series of decoder layers.
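
"Unscanned" means the layers are applied one after another in an ordinary Python loop, each with its own parameters. The NumPy sketch below shows that pattern; toy_layer is a hypothetical stand-in for a real decoder layer, and the function names are illustrative, not MaxText's API.

```python
import numpy as np


def toy_layer(x, w):
    # Hypothetical stand-in for one decoder layer: a residual tanh projection.
    return x + np.tanh(x @ w)


def sequential_block(x, layer_fn, per_layer_weights):
    # Unscanned: one Python-level call per layer, each with its own weights.
    for w in per_layer_weights:
        x = layer_fn(x, w)
    return x


rng = np.random.default_rng(0)
num_decoder_layers = 3
weights = [rng.normal(scale=0.1, size=(4, 4)) for _ in range(num_decoder_layers)]
x = rng.normal(size=(2, 4))
out = sequential_block(x, toy_layer, weights)
```

Under jax.jit, a loop like this is unrolled during tracing, so the compiled program contains one copy of the layer per iteration; that is the usual trade-off against scanning over layers.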

Parameters:
  • decoder_layer (Any)

  • num_decoder_layers (int)

  • config (Any)

  • mesh (Mesh)

  • quant (AqtQuantization)

  • model_mode (str)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

decoder_layer: Any
num_decoder_layers: int
config: Any
mesh: Mesh
quant: AqtQuantization
model_mode: str
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None

maxtext.layers.decoders.deepstack_process(hidden_states, bidirectional_mask, visual_embeds)

Process deepstack visual embeddings by adding them to hidden states at visual token positions.

Parameters:
  • hidden_states – [batch, seq_len, hidden_dim] decoder hidden states

  • bidirectional_mask – [batch, seq_len] boolean mask marking visual token positions

  • visual_embeds – [batch, num_visual_tokens, hidden_dim] visual features from encoder layer

Returns:

Updated hidden_states with visual features added at visual positions
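
One way to implement this scatter-add is sketched below in NumPy (jax.numpy supports the same cumsum, clip, where, and advanced-indexing operations). It assumes that the visual positions in each row of bidirectional_mask appear in the same order as the rows of visual_embeds, and that a row has at most num_visual_tokens visual positions; deepstack_process_sketch is a hypothetical name, not the library function.

```python
import numpy as np


def deepstack_process_sketch(hidden_states, bidirectional_mask, visual_embeds):
    # hidden_states: [B, S, D]; bidirectional_mask: [B, S] bool; visual_embeds: [B, V, D]
    batch = hidden_states.shape[0]
    num_visual = visual_embeds.shape[1]
    # Running count of visual tokens: the k-th visual position (from zero) maps to row k.
    idx = np.cumsum(bidirectional_mask, axis=1) - 1
    # Keep non-visual positions in range; their values are masked out below.
    idx = np.clip(idx, 0, num_visual - 1)
    gathered = visual_embeds[np.arange(batch)[:, None], idx]  # [B, S, D]
    return hidden_states + np.where(bidirectional_mask[..., None], gathered, 0.0)


hidden = np.zeros((1, 4, 2))
mask = np.array([[False, True, True, False]])
visual = np.array([[[1.0, 1.0], [2.0, 2.0]]])
out = deepstack_process_sketch(hidden, mask, visual)
# out[0] -> [[0, 0], [1, 1], [2, 2], [0, 0]]
```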

class maxtext.layers.decoders.Decoder(config, mesh, quant=None, model_mode='train', parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A stack of decoder layers as part of an encoder-decoder architecture.

Parameters:
  • config (Any)

  • mesh (Mesh)

  • quant (None | AqtQuantization)

  • model_mode (str)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

config: Any
mesh: Mesh
quant: None | AqtQuantization = None
model_mode: str = 'train'

setup()

Initialize decoder layer.

minimal_policy(with_context=False, with_quantization=False)

Helper for creating minimal checkpoint policies.

get_remat_policy()

Get the rematerialization (remat) policy.

get_decoder_layers()

Retrieves a list of decoder layer classes based on the decoder_block config.

Returns:

A list containing one or more nn.Module classes for the decoder.
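
A config-driven lookup like this is commonly a dispatch table from the decoder_block string to a list of layer classes. The sketch below is hypothetical throughout: the placeholder classes, the "dense" and "dense_then_moe" keys, and get_decoder_layers_sketch are illustrative names, not MaxText's. Returning a list presumably lets one architecture combine more than one layer type, for example dense layers followed by mixture-of-experts layers.

```python
class DenseDecoderLayer:  # hypothetical placeholder class
    pass


class MoEDecoderLayer:  # hypothetical placeholder class
    pass


# Hypothetical mapping from a decoder_block value to the layer classes it uses.
_DECODER_LAYERS = {
    "dense": [DenseDecoderLayer],
    "dense_then_moe": [DenseDecoderLayer, MoEDecoderLayer],
}


def get_decoder_layers_sketch(decoder_block):
    # Look up the classes for the configured block type; fail loudly on unknown names.
    try:
        return list(_DECODER_LAYERS[decoder_block])
    except KeyError:
        raise ValueError(f"Unsupported decoder_block: {decoder_block!r}") from None


layers = get_decoder_layers_sketch("dense_then_moe")
```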

set_remat_policy(block_layers, policy)

Set the rematerialization (remat) policy for block_layers.

get_norm_layer(num_features)

Get the normalization layer (the return type inherits from nn.Module).

Parameters:
  • num_features (int)
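
For reference, assuming the returned layer is an RMSNorm (an assumption; this page does not say which normalization is used, and some architectures may use LayerNorm instead), it normalizes each feature vector of length d = num_features and applies a learned per-feature scale γ, with a small constant ε for numerical stability:

```latex
\mathrm{RMSNorm}(x)_i = \gamma_i \, \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^{2} + \epsilon}}, \qquad i = 1, \dots, d
```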

scan_decoder_layers(cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs)

Scan over the decoder layers by calling flax.linen.transforms.scan.
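
Scanning keeps a single layer body and stacks the per-layer weights along a new leading axis, so the compiled program contains the body once rather than once per layer. The NumPy sketch below models jax.lax.scan semantics in pure Python and checks that scanning over stacked weights matches an unscanned loop. Mapping the leading axis to the length and metadata_axis_name arguments is an assumption, and layer_body is a hypothetical toy layer.

```python
import numpy as np


def scan(f, init, xs):
    # Pure-Python model of jax.lax.scan: iterate f over the leading axis of xs,
    # threading a carry and stacking the per-step outputs.
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def layer_body(h, w):
    # One toy decoder layer; returns (new carry, per-layer output).
    h = h + np.tanh(h @ w)
    return h, h


rng = np.random.default_rng(0)
num_layers, d = 4, 3
stacked_w = rng.normal(scale=0.1, size=(num_layers, d, d))  # [layers, d, d]
x = rng.normal(size=(2, d))
final, per_layer = scan(layer_body, x, stacked_w)

# The scanned result matches an unscanned loop over the same weights.
h = x
for w in stacked_w:
    h = h + np.tanh(h @ w)
```

In Flax, nn.scan lifts this same pattern to modules, so the stacked weights are the module's own parameters rather than an explicit array.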

get_pipeline_stage_module(decoder_blocks)

Get the pipeline stage module.

apply_output_head(shared_embedding, y, deterministic, model_mode)

Applies final normalization and projects hidden states to logits.

Parameters:
  • shared_embedding (Module | Module)
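
A minimal NumPy sketch of the two steps, assuming an RMSNorm final normalization and tied weights, in which the shared input-embedding table [vocab, d_model] doubles as the output projection. Both choices are assumptions for illustration (a model may use a separate projection instead), and the deterministic and model_mode arguments are omitted.

```python
import numpy as np


def apply_output_head_sketch(y, norm_scale, embedding_table, eps=1e-6):
    # y: [batch, seq, d_model]; embedding_table: [vocab, d_model], shared with the input embedding.
    # 1) Final RMSNorm over the feature axis.
    y = y / np.sqrt(np.mean(y**2, axis=-1, keepdims=True) + eps) * norm_scale
    # 2) Score every vocabulary entry: logits[b, s, v] = <y[b, s], E[v]>.
    return np.einsum("bsd,vd->bsv", y, embedding_table)


rng = np.random.default_rng(0)
batch, seq, d_model, vocab = 2, 3, 4, 10
y = rng.normal(size=(batch, seq, d_model))
emb = rng.normal(size=(vocab, d_model))
logits = apply_output_head_sketch(y, np.ones(d_model), emb)  # [batch, seq, vocab]
```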

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None