maxtext.layers.nnx_decoders module

Module for decoder layers

class maxtext.layers.nnx_decoders.NNXDecoderLayer(*args, **kwargs)

Bases: Module

Transformer decoder layer converted to NNX

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

maxtext.layers.nnx_decoders.deepstack_process(hidden_states, bidirectional_mask, visual_embeds)

Process deepstack visual embeddings by adding them to hidden states at visual token positions.

Parameters:
  • hidden_states – [batch, seq_len, hidden_dim] decoder hidden states

  • bidirectional_mask – [batch, seq_len] boolean mask marking visual token positions

  • visual_embeds – [batch, num_visual_tokens, hidden_dim] visual features from encoder layer

Returns:

Updated hidden_states with visual features added at visual positions
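The description above can be illustrated with a minimal JAX sketch. This is not the MaxText implementation. It assumes that the k-th True entry of bidirectional_mask in each sequence corresponds to the k-th row of visual_embeds, and the name deepstack_process_sketch is hypothetical.

```python
import jax.numpy as jnp


def deepstack_process_sketch(hidden_states, bidirectional_mask, visual_embeds):
  """Adds visual_embeds to hidden_states where bidirectional_mask is True."""
  batch = hidden_states.shape[0]
  # Rank of each visual token within its sequence: 0 for the first, 1 for the next, ...
  visual_idx = jnp.cumsum(bidirectional_mask.astype(jnp.int32), axis=1) - 1
  # Clip so positions before the first visual token (index -1) still gather a valid row.
  visual_idx = jnp.clip(visual_idx, 0, visual_embeds.shape[1] - 1)
  batch_idx = jnp.arange(batch)[:, None]
  # Advanced indexing yields [batch, seq_len, hidden_dim].
  gathered = visual_embeds[batch_idx, visual_idx]
  # Zero out the gathered rows at text positions before adding.
  return hidden_states + jnp.where(bidirectional_mask[..., None], gathered, 0)


# Usage: one sequence of four tokens, two of which are visual.
hidden = jnp.zeros((1, 4, 2))
mask = jnp.array([[False, True, True, False]])
visual = jnp.array([[[1.0, 1.0], [2.0, 2.0]]])
out = deepstack_process_sketch(hidden, mask, visual)
```

Text positions are left unchanged because the gathered rows are masked to zero there.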

class maxtext.layers.nnx_decoders.NNXDecoder(*args, **kwargs)

Bases: Module

A stack of decoder layers as a part of an encoder-decoder architecture, using NNX.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

get_decoder_layers()

Retrieves decoder layer classes based on config using a dictionary lookup.
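A dictionary lookup of this kind might look like the following sketch. Every name here (Config, decoder_block, the layer classes, and the table keys) is a hypothetical placeholder and not taken from MaxText.

```python
from dataclasses import dataclass


class DenseDecoderLayer:
  """Placeholder for a dense decoder layer class."""


class MoEDecoderLayer:
  """Placeholder for a mixture-of-experts decoder layer class."""


@dataclass
class Config:
  decoder_block: str  # hypothetical config field selecting the architecture


# Maps a config value to the list of layer classes that make up the stack.
_LAYER_TABLE = {
    "dense": [DenseDecoderLayer],
    "moe": [DenseDecoderLayer, MoEDecoderLayer],
}


def get_decoder_layers_sketch(config):
  """Returns the layer classes for config.decoder_block, or raises ValueError."""
  try:
    return _LAYER_TABLE[config.decoder_block]
  except KeyError:
    raise ValueError(f"Unknown decoder block: {config.decoder_block!r}") from None


layers = get_decoder_layers_sketch(Config(decoder_block="moe"))
```

Keeping the mapping in one table means adding an architecture is a one-line change, and an unknown value fails with a clear error instead of falling through a chain of conditionals.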

minimal_policy(with_context=False, with_quantization=False)

Helper for creating minimal checkpoint policies.

get_remat_policy()

Gets the rematerialization (remat) policy for jax.checkpoint.
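For context, a policy passed to jax.checkpoint controls which intermediate values are saved for the backward pass and which are recomputed. The sketch below uses only public jax.checkpoint_policies entries. The policy names in the table ("full", "save_dots", "none") are illustrative assumptions and not MaxText config values.

```python
import jax
import jax.numpy as jnp

# Hypothetical name-to-policy table built from public JAX checkpoint policies.
POLICIES = {
    "full": jax.checkpoint_policies.nothing_saveable,      # recompute everything
    "save_dots": jax.checkpoint_policies.dots_saveable,    # keep matmul outputs
    "none": jax.checkpoint_policies.everything_saveable,   # recompute nothing
}


def layer(w, x):
  return jnp.tanh(x @ w)


def loss(w, x, policy_name):
  # Wrap the layer so its activations follow the chosen remat policy.
  remat_layer = jax.checkpoint(layer, policy=POLICIES[policy_name])
  return jnp.sum(remat_layer(w, x) ** 2)


w = jnp.full((3, 3), 0.1)
x = jnp.ones((2, 3))
grads_full = jax.grad(lambda w_: loss(w_, x, "full"))(w)
grads_none = jax.grad(lambda w_: loss(w_, x, "none"))(w)
```

The policy changes memory use and recomputation cost only, so the gradients agree across policies.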

get_norm_layer(num_features, rngs)

Gets the normalization layer (the return type inherits from nn.Module).

Parameters:
  • num_features (int)

  • rngs (Rngs)

apply_output_head(shared_embedding, y, deterministic, model_mode)

Applies final normalization and projects hidden states to logits.
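The two steps named above, normalization followed by projection to logits, can be sketched as follows. This is an illustration under stated assumptions and not the MaxText code: it assumes RMS normalization and a projection that reuses the embedding table (tied weights). The names apply_output_head_sketch, rms_norm, and scale are hypothetical.

```python
import jax
import jax.numpy as jnp


def rms_norm(y, scale, eps=1e-6):
  """Root-mean-square normalization over the last axis (assumed norm type)."""
  mean_sq = jnp.mean(jnp.square(y), axis=-1, keepdims=True)
  return y * jax.lax.rsqrt(mean_sq + eps) * scale


def apply_output_head_sketch(embedding_table, y, scale):
  """Normalizes y, then projects it onto the vocabulary.

  embedding_table: [vocab, hidden_dim], y: [batch, seq_len, hidden_dim].
  """
  y = rms_norm(y, scale)
  # Tied-embedding projection: logits[b, s, v] = sum_d y[b, s, d] * E[v, d].
  return jnp.einsum("bsd,vd->bsv", y, embedding_table)


embedding_table = jnp.ones((5, 4))  # vocab=5, hidden_dim=4
hidden = jnp.ones((2, 3, 4))        # batch=2, seq_len=3
logits = apply_output_head_sketch(embedding_table, hidden, scale=jnp.ones((4,)))
```

The result has shape [batch, seq_len, vocab]. A separate dense output layer, if the model does not share embeddings, would replace the einsum.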

maxtext.layers.nnx_decoders.decoder_as_linen(config, mesh, rngs, model_mode, quant=None)

Creates a Decoder module

Parameters:
  • config (Any)

  • mesh (Mesh)

  • rngs (Rngs)

  • model_mode (str)

  • quant (None | AqtQuantization)