maxtext.layers.nnx_decoders module
Module for decoder layers
- class maxtext.layers.nnx_decoders.NNXDecoderLayer(*args, **kwargs)
Bases: Module
Transformer decoder layer converted to NNX.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- maxtext.layers.nnx_decoders.deepstack_process(hidden_states, bidirectional_mask, visual_embeds)
Process deepstack visual embeddings by adding them to hidden states at visual token positions.
- Parameters:
hidden_states – [batch, seq_len, hidden_dim] decoder hidden states
bidirectional_mask – [batch, seq_len] boolean mask marking visual token positions
visual_embeds – [batch, num_visual_tokens, hidden_dim] visual features from encoder layer
- Returns:
Updated hidden_states with visual features added at visual positions
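The behavior described above can be illustrated with a small NumPy sketch. This is not the MaxText implementation, which operates on JAX arrays. It assumes that the i-th True entry in each row of bidirectional_mask corresponds to the i-th row of visual_embeds, and that no row has more True entries than num_visual_tokens. The name deepstack_process_sketch is hypothetical.

```python
import numpy as np


def deepstack_process_sketch(hidden_states, bidirectional_mask, visual_embeds):
    """Add visual embeddings to hidden states at masked positions (illustrative only).

    hidden_states:      [batch, seq_len, hidden_dim]
    bidirectional_mask: [batch, seq_len] boolean, True at visual token positions
    visual_embeds:      [batch, num_visual_tokens, hidden_dim]
    """
    mask = bidirectional_mask.astype(bool)
    # Assumption: the k-th True position in a row maps to visual token k.
    # cumsum - 1 gives that running index; clip keeps non-visual positions in range.
    idx = np.cumsum(mask, axis=1) - 1
    idx = np.clip(idx, 0, visual_embeds.shape[1] - 1)
    # Gather one visual embedding per sequence position: [batch, seq_len, hidden_dim].
    gathered = np.take_along_axis(visual_embeds, idx[..., None], axis=1)
    # Add the gathered features only where the mask is True.
    return hidden_states + np.where(mask[..., None], gathered, 0.0)


# Usage: two visual tokens sit at sequence positions 1 and 2.
hidden = np.zeros((1, 4, 2))
mask = np.array([[False, True, True, False]])
visual = np.array([[[1.0, 2.0], [3.0, 4.0]]])
updated = deepstack_process_sketch(hidden, mask, visual)
# updated[0] is [[0, 0], [1, 2], [3, 4], [0, 0]]
```

Positions where the mask is False are returned unchanged, so text tokens are unaffected by the visual features.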
- class maxtext.layers.nnx_decoders.NNXDecoder(*args, **kwargs)
Bases: Module
A stack of decoder layers as a part of an encoder-decoder architecture, using NNX.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- get_decoder_layers()
Retrieves decoder layer classes based on config using a dictionary lookup.
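As a rough illustration of the dictionary-lookup pattern mentioned above, the sketch below maps a config field to a list of layer classes. Everything in it is hypothetical: the field name decoder_block, the registry keys, and the placeholder classes are assumptions made for the example and do not reflect MaxText's actual registry or config schema.

```python
from dataclasses import dataclass


class DefaultLayerSketch:
    """Placeholder standing in for a generic decoder layer class."""


class AlternativeLayerSketch:
    """Placeholder standing in for a model-specific decoder layer class."""


# Hypothetical registry: config value -> list of layer classes to instantiate.
LAYER_REGISTRY_SKETCH = {
    "default": [DefaultLayerSketch],
    "alternative": [AlternativeLayerSketch],
}


@dataclass
class ConfigSketch:
    # Assumed field name, used only for this illustration.
    decoder_block: str = "default"


def get_decoder_layers_sketch(config):
    """Return the layer classes registered for config.decoder_block."""
    try:
        return LAYER_REGISTRY_SKETCH[config.decoder_block]
    except KeyError:
        known = ", ".join(sorted(LAYER_REGISTRY_SKETCH))
        raise ValueError(
            f"Unknown decoder block {config.decoder_block!r}; expected one of: {known}"
        ) from None


layers = get_decoder_layers_sketch(ConfigSketch(decoder_block="alternative"))
```

A lookup table like this keeps the selection logic in one place, and an unknown key fails with an explicit error instead of silently falling through a chain of conditionals.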
- minimal_policy(with_context=False, with_quantization=False)
Helper for creating minimal checkpoint policies.
- maxtext.layers.nnx_decoders.decoder_as_linen(config, mesh, rngs, model_mode, quant=None)
Creates a Decoder module.
- Parameters:
config (Any)
mesh (Mesh)
rngs (Rngs)
model_mode (str)
quant (None | AqtQuantization)