maxtext.layers.multi_token_prediction module

JAX implementation of Multi-Token Prediction (MTP), as described in the DeepSeek-V3 technical report: https://arxiv.org/pdf/2412.19437

class maxtext.layers.multi_token_prediction.mtp_losses(*args, **kwargs)

Bases: Variable

Variable type for storing MTP loss components in the ‘mtp_losses’ collection.

class maxtext.layers.multi_token_prediction.mtp_acceptance(*args, **kwargs)

Bases: Variable

Variable type for storing MTP acceptance predictions in the ‘mtp_acceptance’ collection.

maxtext.layers.multi_token_prediction.roll_and_mask(x, shift=-1)

Performs a leftward roll along the sequence axis (axis 1) and zeroes the positions invalidated by the roll.

Parameters:
  • x (Array) – Input array of shape [batch, seq_len, …].

  • shift (int) – Roll offset along the sequence axis. The default of -1 shifts the sequence left by one position.

Returns:

Rolled array with masked positions set to zero.

Return type:

Array
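The roll-and-mask behaviour described above can be sketched in plain JAX. This is a minimal illustration, not the library's code: the name roll_and_mask_sketch is hypothetical, and the sketch assumes a negative shift (a leftward roll) with the sequence on axis 1.

```python
import jax.numpy as jnp


def roll_and_mask_sketch(x, shift=-1):
    """Hypothetical stand-in for roll_and_mask; assumes shift <= 0."""
    if shift == 0:
        return x
    seq_len = x.shape[1]
    rolled = jnp.roll(x, shift, axis=1)
    # After a leftward roll, the last abs(shift) positions hold values that
    # wrapped around from the front of the sequence, so they are invalid.
    valid = jnp.arange(seq_len) < seq_len + shift
    # Reshape the mask so it broadcasts over batch and trailing feature axes.
    valid = valid.reshape((1, seq_len) + (1,) * (x.ndim - 2))
    return jnp.where(valid, rolled, 0)


tokens = jnp.array([[1, 2, 3, 4, 5]])
shifted_once = roll_and_mask_sketch(tokens)             # [[2, 3, 4, 5, 0]]
shifted_twice = roll_and_mask_sketch(tokens, shift=-2)  # [[3, 4, 5, 0, 0]]
```

Zeroing the tail keeps wrapped-around tokens from being treated as real future tokens when the rolled array is used as an MTP input or target.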

class maxtext.layers.multi_token_prediction.MultiTokenPredictionLayer(*args, **kwargs)

Bases: Module

Multi-Token Prediction layer: normalize, concatenate, project, and transform.

Implements: h_next = TransformerLayer(W_p(concat(RMSNorm(h_prev), RMSNorm(e_target))))

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any
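The formula h_next = TransformerLayer(W_p(concat(RMSNorm(h_prev), RMSNorm(e_target)))) can be traced step by step in plain JAX. The sketch below is an illustration under stated assumptions, not the layer's implementation: rms_norm omits the learned scale, w_p is a plain matrix rather than a MaxText projection module, and the transformer layer is replaced by an identity function. All names are hypothetical.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Simplified RMSNorm without a learned gain (assumption made for brevity).
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def mtp_layer_sketch(h_prev, e_target, w_p, transformer_layer):
    # 1. Normalize the previous hidden state and the target-token embedding.
    # 2. Concatenate them on the feature axis: [batch, seq, 2 * dim].
    combined = jnp.concatenate([rms_norm(h_prev), rms_norm(e_target)], axis=-1)
    # 3. Project back to the model dimension: [batch, seq, dim].
    projected = combined @ w_p
    # 4. Transform with a decoder layer (an identity stand-in in this example).
    return transformer_layer(projected)


batch, seq_len, dim = 2, 4, 8
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
h_prev = jax.random.normal(k1, (batch, seq_len, dim))
e_target = jax.random.normal(k2, (batch, seq_len, dim))
w_p = jax.random.normal(k3, (2 * dim, dim)) * 0.02

h_next = mtp_layer_sketch(h_prev, e_target, w_p, transformer_layer=lambda x: x)
```

The four steps correspond to the hidden_state_norm, embedding_norm, projection_layer, and transformer_layer properties of the class.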

property embedding_norm

property hidden_state_norm

property projection_layer

property transformer_layer

class maxtext.layers.multi_token_prediction.MultiTokenPredictionBlock(*args, **kwargs)

Bases: Module

Orchestrates the MTP process by running a sequence of MTP layers.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any
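One plausible reading of "running a sequence of MTP layers" is the loop sketched below. For each layer it shifts the token streams one more position to the left, feeds the previous hidden state and the shifted token embeddings through the layer, and records a masked cross-entropy loss. This follows the general MTP recipe and is not the block's actual code. Every name here (mtp_block_sketch, the toy layer, the embedding table used as the output head) is a hypothetical stand-in.

```python
import jax
import jax.numpy as jnp


def roll_left_and_mask(x, shift=-1):
    # Hypothetical helper: roll along axis 1 and zero the wrapped-around tail.
    seq_len = x.shape[1]
    valid = jnp.arange(seq_len) < seq_len + shift
    valid = valid.reshape((1, seq_len) + (1,) * (x.ndim - 2))
    return jnp.where(valid, jnp.roll(x, shift, axis=1), 0)


def mtp_block_sketch(hidden, input_ids, target_ids, target_mask,
                     layers, embed, output_head):
    losses, weights = [], []
    for layer in layers:
        # Each extra MTP depth looks one token further into the future.
        input_ids = roll_left_and_mask(input_ids)
        target_ids = roll_left_and_mask(target_ids)
        target_mask = roll_left_and_mask(target_mask)
        hidden = layer(hidden, embed(input_ids))
        log_probs = jax.nn.log_softmax(output_head(hidden), axis=-1)
        nll = -jnp.take_along_axis(log_probs, target_ids[..., None], axis=-1)[..., 0]
        losses.append(jnp.sum(nll * target_mask))   # summed loss for this layer
        weights.append(jnp.sum(target_mask))        # number of valid tokens
    return hidden, losses, weights


vocab, dim = 11, 8
table = jax.random.normal(jax.random.PRNGKey(0), (vocab, dim))
input_ids = jnp.array([[1, 2, 3, 4]])
target_ids = jnp.array([[2, 3, 4, 5]])
target_mask = jnp.ones((1, 4))
toy_layer = lambda h, e: h + e  # stand-in for a real MTP layer

hidden, losses, weights = mtp_block_sketch(
    table[input_ids], input_ids, target_ids, target_mask,
    layers=[toy_layer, toy_layer],
    embed=lambda ids: table[ids],
    output_head=lambda h: h @ table.T,
)
```

With a sequence of four tokens, the first layer in this sketch scores three valid positions and the second scores two, because each roll masks one more trailing position.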

maxtext.layers.multi_token_prediction.calculate_mtp_loss(intermediate_outputs, config)

Calculates Multi-Token Prediction loss from intermediate outputs.
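As a rough illustration of how per-layer loss components might be reduced to a single scalar, the sketch below averages summed losses over the total number of valid tokens and applies a scaling factor. The aggregation scheme, the scaling_factor argument, and the function name are all assumptions made for this example. The layout of intermediate_outputs and the config fields actually read by calculate_mtp_loss are not documented here.

```python
import jax.numpy as jnp


def aggregate_mtp_loss_sketch(layer_losses, layer_weights, scaling_factor, eps=1e-8):
    """Hypothetical reduction: token-weighted mean of MTP losses, then scaled."""
    if not layer_losses:
        # No MTP layers ran, so there is nothing to add to the main loss.
        return jnp.float32(0.0)
    total_loss = jnp.sum(jnp.array(layer_losses))
    total_weight = jnp.sum(jnp.array(layer_weights))
    return scaling_factor * total_loss / (total_weight + eps)


# Two layers: summed losses 6.0 and 2.0 over 3 and 1 valid tokens.
mtp_loss = aggregate_mtp_loss_sketch([6.0, 2.0], [3.0, 1.0], scaling_factor=0.1)
# 0.1 * (6.0 + 2.0) / (3.0 + 1.0) = 0.2
```

The resulting scalar would typically be added to the main next-token loss, so the scaling factor controls how strongly the auxiliary MTP objective influences training.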

maxtext.layers.multi_token_prediction.calculate_mtp_acceptance_rate(intermediate_outputs, config)

Calculates MTP acceptance rate from intermediate outputs.
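An acceptance rate of this kind is commonly the fraction of valid positions where the MTP prediction agrees with a reference prediction. The sketch below computes that fraction as a percentage. The function name, the choice of reference, and the percentage scaling are assumptions for illustration and do not describe the library's exact computation.

```python
import jax.numpy as jnp


def acceptance_rate_sketch(mtp_preds, reference_preds, valid_mask, eps=1e-8):
    """Hypothetical metric: percent of valid positions where predictions agree."""
    matches = (mtp_preds == reference_preds) * valid_mask
    return 100.0 * jnp.sum(matches) / (jnp.sum(valid_mask) + eps)


mtp_preds = jnp.array([[5, 7, 9, 0]])
reference_preds = jnp.array([[5, 8, 9, 3]])
valid_mask = jnp.array([[1.0, 1.0, 1.0, 0.0]])  # last position is padding

rate = acceptance_rate_sketch(mtp_preds, reference_preds, valid_mask)
# 2 of the 3 valid positions agree, so the rate is about 66.67
```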

maxtext.layers.multi_token_prediction.multi_token_prediction_block_as_linen(*, config, mesh, transformer_layer_module, decoder, rngs, name=None)

Initializes MultiTokenPredictionBlock as a Linen module.

Parameters:
  • config (Any) – Configuration object containing model hyperparameters.

  • mesh (Mesh) – JAX Mesh for model parallelism.

  • transformer_layer_module (Type[DecoderLayer]) – The Transformer Decoder Layer class to use.

  • decoder (Module) – The decoder module that provides embedding and output head.

  • rngs (Rngs) – Random number generators for initialization.

  • name (str | None) – Optional name for the module.

Returns:

An instance of MultiTokenPredictionBlock wrapped as a Linen module.

Return type:

Module