maxtext.layers.multi_token_prediction module
JAX implementation of Multi-Token Prediction (MTP), as described in the DeepSeek-V3 technical report: https://arxiv.org/pdf/2412.19437
- class maxtext.layers.multi_token_prediction.mtp_losses(*args, **kwargs)
Bases: Variable
Variable type for storing MTP loss components in the ‘mtp_losses’ collection.
- class maxtext.layers.multi_token_prediction.mtp_acceptance(*args, **kwargs)
Bases: Variable
Variable type for storing MTP acceptance predictions in the ‘mtp_acceptance’ collection.
- maxtext.layers.multi_token_prediction.roll_and_mask(x, shift=-1)
Performs a leftward roll along the sequence axis and masks the invalid positions.
- Parameters:
x (Array) – Input array of shape [batch, seq_len, …].
shift (int) – Roll offset along the sequence axis; negative values shift left (the default, -1, shifts left by one position).
- Returns:
Rolled array of the same shape as x, with the invalid (wrapped-around) positions set to zero.
- Return type:
Array
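A minimal sketch of this behavior in jax.numpy, assuming the sequence axis is axis 1 and that a negative shift means a left roll (the jnp.roll sign convention). The name roll_and_mask_sketch is hypothetical, this is not the module's own source, and it deliberately handles only left shifts.

```python
import jax.numpy as jnp


def roll_and_mask_sketch(x: jnp.ndarray, shift: int = -1) -> jnp.ndarray:
    """Hypothetical sketch: left-roll x along axis 1 and zero the wrapped positions."""
    if shift == 0:
        return x  # nothing to roll or mask
    if shift > 0:
        raise ValueError("this sketch only supports left shifts (shift < 0)")
    rolled = jnp.roll(x, shift, axis=1)
    # After a left roll by |shift|, the last |shift| positions contain values that
    # wrapped around from the start of the sequence; they are invalid, so zero them.
    return rolled.at[:, shift:, ...].set(0)
```

For example, a single sequence [1, 2, 3, 4] with the default shift becomes [2, 3, 4, 0]: every position now holds the next token, and the final position, which has no next token, is zeroed rather than wrapped.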
- class maxtext.layers.multi_token_prediction.MultiTokenPredictionLayer(*args, **kwargs)
Bases: Module
Multi-Token Prediction layer: normalize, concatenate, project, and transform.
Implements: h_next = TransformerLayer(W_p(concat(RMSNorm(h_prev), RMSNorm(e_target))))
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- property embedding_norm
- property projection_layer
- property transformer_layer
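The formula above can be sketched in plain jax.numpy to show the tensor shapes involved. Everything here is illustrative: the rms_norm helper, the params dictionary keys, and transformer_fn (a stand-in for the transformer_layer) are hypothetical, and the actual layer holds its norms, projection, and transformer as submodules rather than a dictionary of arrays.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, scale, eps=1e-6):
    # RMSNorm over the hidden axis: x / sqrt(mean(x^2) + eps), then a learned scale.
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * scale


def mtp_layer_sketch(h_prev, e_target, params, transformer_fn):
    """Hypothetical sketch of h_next = T(W_p(concat(RMSNorm(h_prev), RMSNorm(e_target))))."""
    h = rms_norm(h_prev, params["hidden_norm_scale"])        # [batch, seq, d_model]
    e = rms_norm(e_target, params["embedding_norm_scale"])   # [batch, seq, d_model]
    concat = jnp.concatenate([h, e], axis=-1)                # [batch, seq, 2 * d_model]
    projected = concat @ params["w_p"]                       # W_p maps 2 * d_model -> d_model
    return transformer_fn(projected)                         # stand-in for the transformer layer
```

The projection W_p exists so the concatenated features return to d_model width, which lets the transformer layer reuse the same block shape as the main decoder.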
- class maxtext.layers.multi_token_prediction.MultiTokenPredictionBlock(*args, **kwargs)
Bases: Module
Orchestrates the MTP process by running a sequence of MTP layers.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
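One way to picture this orchestration, following the sequential MTP scheme of the DeepSeek-V3 report: depth k takes the hidden state from depth k - 1 (the main model's final hidden state when k = 1) together with the embedding of the input token k positions ahead, and a shared output head turns each depth's hidden state into logits for the token k + 1 positions ahead. This is a hedged sketch; mtp_block_sketch, roll_left, and the callable arguments (embed_fn, mtp_layers, logits_fn) are hypothetical stand-ins, not this module's API.

```python
import jax
import jax.numpy as jnp


def roll_left(x, k):
    # Shift left by k along axis 1 and zero the k wrapped-around positions.
    return jnp.roll(x, -k, axis=1).at[:, x.shape[1] - k:].set(0)


def mtp_block_sketch(main_hidden, input_ids, embed_fn, mtp_layers, logits_fn):
    """Hypothetical sketch: run MTP depths k = 1..len(mtp_layers) in sequence."""
    h = main_hidden
    all_logits = []
    for k, layer in enumerate(mtp_layers, start=1):
        # Token at position t + k; positions past the end are filled with id 0
        # and should be excluded from any loss.
        ids_k = roll_left(input_ids, k)
        h = layer(h, embed_fn(ids_k))      # depth k consumes depth k - 1's hidden state
        all_logits.append(logits_fn(h))    # shared output head: predicts token t + k + 1
    return all_logits
```

Chaining h through the loop, rather than feeding every depth the main model's hidden state, is what keeps the causal chain intact across prediction depths.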
- maxtext.layers.multi_token_prediction.calculate_mtp_loss(intermediate_outputs, config)
Calculates Multi-Token Prediction loss from intermediate outputs.
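The DeepSeek-V3 report defines the MTP objective as the per-depth cross-entropy losses averaged over the D depths and scaled by a weight λ, i.e. L_MTP = (λ / D) · Σ_k L_MTP^k. Below is a sketch of that formula, assuming per-depth logits are available directly; the function documented here instead reads values recorded in intermediate_outputs and takes its settings from config. The names mtp_loss_sketch and mtp_loss_weight are hypothetical.

```python
import jax
import jax.numpy as jnp


def mtp_loss_sketch(all_logits, target_ids, target_mask, mtp_loss_weight):
    """Hypothetical sketch of (weight / D) * sum_k masked-mean cross-entropy at depth k."""
    depth_losses = []
    for k, logits in enumerate(all_logits, start=1):
        # target_ids[t] is the main model's next token (t + 1); depth k predicts
        # token t + k + 1, i.e. target_ids[t + k]. The last k positions have no target.
        tgt = jnp.roll(target_ids, -k, axis=1)
        valid = jnp.roll(target_mask, -k, axis=1).at[:, -k:].set(0)
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        nll = -jnp.take_along_axis(log_probs, tgt[..., None], axis=-1)[..., 0]
        depth_losses.append(jnp.sum(nll * valid) / jnp.maximum(jnp.sum(valid), 1))
    return mtp_loss_weight * jnp.mean(jnp.stack(depth_losses))
```

As a sanity check, all-zero logits over a vocabulary of size V give a uniform prediction, so every depth's loss is log V and the result is mtp_loss_weight · log V.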
- maxtext.layers.multi_token_prediction.calculate_mtp_acceptance_rate(intermediate_outputs, config)
Calculates MTP acceptance rate from intermediate outputs.
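The exact acceptance criterion is not stated in this reference. A plausible reading, borrowed from greedy speculative decoding and stated purely as an assumption: the depth-1 MTP draft at position t is accepted when it equals the main model's greedy prediction at position t + 1, since both target token t + 2. The function name and argument layout below are hypothetical.

```python
import jax.numpy as jnp


def mtp_acceptance_rate_sketch(main_logits, mtp_logits, mask):
    """Hypothetical sketch: fraction of valid positions where the depth-1 draft matches
    the main model's greedy token one position later."""
    draft = jnp.argmax(mtp_logits, axis=-1)                            # drafts token t + 2
    verified = jnp.roll(jnp.argmax(main_logits, axis=-1), -1, axis=1)  # main model's token t + 2
    # Both t and t + 1 must be valid; the last position has no t + 1.
    valid = mask * jnp.roll(mask, -1, axis=1).at[:, -1].set(0)
    accepted = (draft == verified) * valid
    return jnp.sum(accepted) / jnp.maximum(jnp.sum(valid), 1)
```

With main-model greedy tokens [1, 2, 3, 4] and drafts [2, 3, 0, 9], the first two drafts match and three positions are valid, giving a rate of 2/3.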
- maxtext.layers.multi_token_prediction.multi_token_prediction_block_as_linen(*, config, mesh, transformer_layer_module, decoder, rngs, name=None)
Initializes MultiTokenPredictionBlock as a Linen module.
- Parameters:
config (Any) – Configuration object containing model hyperparameters.
mesh (Mesh) – JAX Mesh for model parallelism.
transformer_layer_module (Type[DecoderLayer]) – The Transformer Decoder Layer class to use.
decoder (Module) – The decoder module that provides embedding and output head.
rngs (Rngs) – Random number generators for initialization.
name (str | None) – Optional name for the module.
- Returns:
An instance of MultiTokenPredictionBlock wrapped as a Linen module.
- Return type:
Module