maxtext.layers.decoders module
Module for decoder layers.
- class maxtext.layers.decoders.DecoderLayer(config, mesh, model_mode, quant=None, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Transformer decoder layer that attends to the encoder. This is the core, reusable building block for both the main model’s decoder stack and the auxiliary multi-token prediction (MTP) layers.
- Parameters:
config (Any)
mesh (Mesh)
model_mode (str)
quant (None | AqtQuantization)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- config: Any
- mesh: Mesh
- model_mode: str
- quant: None | AqtQuantization = None
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
- class maxtext.layers.decoders.SequentialBlockDecoderLayers(decoder_layer, num_decoder_layers, config, mesh, quant, model_mode, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
A sequential (unscanned) series of decoder layers.
- Parameters:
decoder_layer (Any)
num_decoder_layers (int)
config (Any)
mesh (Mesh)
quant (AqtQuantization)
model_mode (str)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- decoder_layer: Any
- num_decoder_layers: int
- config: Any
- mesh: Mesh
- quant: AqtQuantization
- model_mode: str
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
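Conceptually, an unscanned stack calls each layer in turn on the previous layer's output, with separate parameters per layer. The following is a minimal NumPy sketch, not MaxText's implementation: `toy_layer` and `sequential_layers` are hypothetical stand-ins, and the real class instantiates `decoder_layer` Flax modules rather than plain functions.

```python
import numpy as np


def toy_layer(x, w):
    """Hypothetical stand-in for one decoder layer: a residual tanh projection."""
    return x + np.tanh(x @ w)


def sequential_layers(x, weights):
    """Apply one layer per weight matrix in a plain Python loop.

    Under jax.jit a loop like this is unrolled, so every layer appears
    separately in the traced program (unlike a scan over stacked weights).
    """
    for w in weights:  # weights: sequence of [hidden, hidden] arrays
        x = toy_layer(x, w)
    return x
```

For example, `sequential_layers(np.ones((2, 3)), [np.zeros((3, 3))] * 4)` returns its input unchanged, because `tanh(0) == 0` in every residual branch.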
- maxtext.layers.decoders.deepstack_process(hidden_states, bidirectional_mask, visual_embeds)
Process deepstack visual embeddings by adding them to hidden states at visual token positions.
- Parameters:
hidden_states – [batch, seq_len, hidden_dim] decoder hidden states
bidirectional_mask – [batch, seq_len] boolean mask marking visual token positions
visual_embeds – [batch, num_visual_tokens, hidden_dim] visual features from encoder layer
- Returns:
Updated hidden_states with visual features added at visual positions
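The operation can be sketched as a scatter-add: the i-th visual position in each sequence receives the i-th row of `visual_embeds`. This NumPy sketch assumes that ordering (visual embeddings appear in the same order as the True entries of the mask), which the docstring does not state explicitly; `deepstack_process_sketch` is a hypothetical name, and `jax.numpy` provides the same functions.

```python
import numpy as np


def deepstack_process_sketch(hidden_states, bidirectional_mask, visual_embeds):
    """Add visual features to hidden_states at positions where the mask is True.

    hidden_states: [batch, seq_len, hidden_dim]
    bidirectional_mask: [batch, seq_len] boolean
    visual_embeds: [batch, num_visual_tokens, hidden_dim]
    Assumes the i-th True entry per row corresponds to visual_embeds[b, i].
    """
    mask = bidirectional_mask.astype(bool)
    # Running index of the visual token at each position (-1 before the first one).
    visual_idx = np.cumsum(mask, axis=1) - 1
    # Clip so every position gathers a valid row; non-visual rows are zeroed below.
    visual_idx = np.clip(visual_idx, 0, visual_embeds.shape[1] - 1)
    # Gather features aligned to sequence positions: [batch, seq_len, hidden_dim].
    gathered = np.take_along_axis(visual_embeds, visual_idx[..., None], axis=1)
    # Only add at visual token positions.
    return hidden_states + np.where(mask[..., None], gathered, 0.0)
```

Using a cumulative-sum index keeps the shapes static, which matters under `jax.jit`, where boolean-mask indexing with data-dependent output sizes is not allowed.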
- class maxtext.layers.decoders.Decoder(config, mesh, quant=None, model_mode='train', parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
A stack of decoder layers as part of an encoder-decoder architecture.
- Parameters:
config (Any)
mesh (Mesh)
quant (None | AqtQuantization)
model_mode (str)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- config: Any
- mesh: Mesh
- quant: None | AqtQuantization = None
- model_mode: str = 'train'
- minimal_policy(with_context=False, with_quantization=False)
Helper for creating minimal gradient-checkpointing (rematerialization) policies.
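A minimal checkpoint policy saves only a short list of named activations and recomputes everything else in the backward pass; the two flags plausibly extend that list. This pure-Python sketch models the policy as a predicate over activation names. The function name and every name in the list are hypothetical, and MaxText's actual list may differ. In JAX the equivalent would be `jax.checkpoint_policies.save_only_these_names(*names)`, with values tagged via `jax.ad_checkpoint.checkpoint_name`.

```python
def minimal_policy_sketch(with_context=False, with_quantization=False):
    """Return a predicate: True means "save this named activation", False means "recompute"."""
    # Hypothetical activation names (projection outputs are common choices).
    names = ["query_proj", "key_proj", "value_proj", "out_proj", "mlpwi", "mlpwo"]
    if with_context:
        names.append("context")  # hypothetical: also keep the attention output
    if with_quantization:
        names.append("quantization")  # hypothetical name for quantization intermediates
    saved = frozenset(names)

    def policy(name):
        # JAX analogue: jax.checkpoint_policies.save_only_these_names(*names)
        return name in saved

    return policy
```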
- get_decoder_layers()
Retrieves a list of decoder layer classes based on the decoder_block config.
- Returns:
A list containing one or more nn.Module classes for the decoder.
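A lookup like this is typically a dispatch table from the `decoder_block` setting to a list of layer classes. The list allows for models that mix layer types, such as a few dense layers followed by mixture-of-experts layers. In the sketch below, the keys, the placeholder classes, and `get_decoder_layers_sketch` are all hypothetical.

```python
class DenseLayer: ...  # placeholder for a dense decoder-layer class


class MoELayer: ...  # placeholder for a mixture-of-experts decoder-layer class


# Hypothetical mapping from a decoder_block value to the layer classes it uses.
_DECODER_BLOCKS = {
    "default": [DenseLayer],
    "moe_only": [MoELayer],
    "dense_then_moe": [DenseLayer, MoELayer],  # mixed stack: dense layers, then MoE layers
}


def get_decoder_layers_sketch(decoder_block):
    """Return the list of layer classes for a decoder_block name, or raise ValueError."""
    try:
        return _DECODER_BLOCKS[decoder_block]
    except KeyError:
        raise ValueError(f"Unsupported decoder_block: {decoder_block!r}") from None
```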
- get_norm_layer(num_features)
Gets the normalization layer (the return type inherits from nn.Module).
- Parameters:
num_features (int)
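`num_features` sizes the learned per-feature scale of the normalization layer. As an illustration only, this sketch assumes RMSNorm, a common choice in decoder-only models. The docstring does not say which normalization each decoder block uses, and `RMSNormSketch` and `get_norm_layer_sketch` are hypothetical names.

```python
import numpy as np


class RMSNormSketch:
    """Hypothetical RMSNorm: x / sqrt(mean(x**2) + eps) * scale, with num_features scales."""

    def __init__(self, num_features, epsilon=1e-6):
        self.scale = np.ones(num_features)  # learned in a real layer; ones at init
        self.epsilon = epsilon

    def __call__(self, x):
        # Mean square over the feature (last) axis.
        ms = np.mean(np.square(x), axis=-1, keepdims=True)
        return x / np.sqrt(ms + self.epsilon) * self.scale


def get_norm_layer_sketch(num_features):
    """Return a normalization layer sized for num_features (RMSNorm assumed)."""
    return RMSNormSketch(num_features)
```

After normalization with unit scale, each feature vector has a root-mean-square of approximately 1.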
- scan_decoder_layers(cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs)
Scans the decoder layers by calling flax.linen.transforms.scan.
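Scanning replaces the per-layer Python loop with a single loop body applied over weights stacked along a leading layer axis. Under JAX, the body is traced once instead of once per layer. This NumPy sketch mirrors the semantics of `jax.lax.scan(f, init, xs)`, which returns `(final_carry, stacked_ys)`. The layer body `layer_step` is hypothetical. Flax's `nn.scan` lifts the same idea to modules by stacking their parameters along a new axis.

```python
import numpy as np


def scan(f, init, xs):
    """NumPy analogue of jax.lax.scan: thread a carry through f over xs' leading axis."""
    carry = init
    ys = []
    for x in xs:  # iterate over the stacked layer axis
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def layer_step(carry, w):
    """Hypothetical decoder-layer body: new hidden state plus a per-layer summary."""
    new = carry + np.tanh(carry @ w)
    return new, new.mean()


# Weights for 4 layers stacked along axis 0: [num_layers, hidden, hidden].
stacked_w = np.stack([np.eye(3) * 0.1 for _ in range(4)])
```

Calling `scan(layer_step, np.ones((2, 3)), stacked_w)` gives the same final hidden state as applying the four layers in a loop, plus one summary value per layer.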
- apply_output_head(shared_embedding, y, deterministic, model_mode)
Applies final normalization and projects hidden states to logits.
- Parameters:
shared_embedding (Module | Module)
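When the output projection reuses the input embedding table (tied embeddings, suggested by the `shared_embedding` argument), the logits are the normalized hidden states multiplied by the transposed table. This NumPy sketch assumes tied weights and a scale-free RMSNorm. Whether MaxText ties the weights depends on its configuration, and `apply_output_head_sketch` is a hypothetical name.

```python
import numpy as np


def rms_norm(x, eps=1e-6):
    """Scale-free RMSNorm over the last axis (assumption for this sketch)."""
    return x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)


def apply_output_head_sketch(embedding_table, y):
    """Final normalization, then project to vocabulary logits with the shared table.

    embedding_table: [vocab_size, hidden_dim]; y: [batch, seq_len, hidden_dim].
    Returns logits of shape [batch, seq_len, vocab_size].
    """
    y = rms_norm(y)
    # Equivalent to y @ embedding_table.T for each (batch, position).
    return np.einsum("bsd,vd->bsv", y, embedding_table)
```

Tying the weights saves a `[vocab_size, hidden_dim]` parameter matrix, which is large for big vocabularies.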
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None