# Source code for maxtext.layers.multi_token_prediction

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX implementation of the Multi Token Prediction https://arxiv.org/pdf/2412.19437 """

from typing import Type

from flax import linen as nn
from flax import nnx
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Config, MODEL_MODE_TRAIN
from maxtext.layers.nnx_decoders import NNXDecoderLayer
from maxtext.utils.globals import EPS
from maxtext.layers.decoders import DecoderLayer
from maxtext.layers.initializers import variable_to_logically_partitioned
from maxtext.layers.linears import DenseGeneral
from maxtext.layers.normalizations import RMSNorm
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import sharding


# Custom Variable types for MTP intermediate outputs
# These will be automatically converted to Linen mutable collections by ToLinen wrapper
# The class names become collection names directly (no case conversion)
# Lowercase name is intentional: the class name is used verbatim as the Linen
# collection name when the NNX module is wrapped with ToLinen.
class mtp_losses(nnx.Variable):  # pylint: disable=invalid-name
  """Variable type for storing MTP loss components -> 'mtp_losses' collection.

  MultiTokenPredictionBlock stores its stacked per-layer masked cross-entropy
  sums (`losses`) and valid-token counts (`weights`) in variables of this type.
  calculate_mtp_loss reads them back from the 'mtp_losses' collection.
  """
# Lowercase name is intentional: the class name is used verbatim as the Linen
# collection name when the NNX module is wrapped with ToLinen.
class mtp_acceptance(nnx.Variable):  # pylint: disable=invalid-name
  """Variable type for storing MTP acceptance predictions -> 'mtp_acceptance' collection.

  MultiTokenPredictionBlock stores the target MTP layer's argmax predictions
  (`mtp_preds`, kept as float32) and the matching valid-token mask (`mtp_mask`)
  in variables of this type. calculate_mtp_acceptance_rate reads them back.
  """
def roll_and_mask(x: jnp.ndarray, shift: int = -1) -> jnp.ndarray:
  """Rolls `x` along the sequence axis (axis 1) and zeroes the wrapped-around positions.

  A negative `shift` rolls left. The default, -1, moves every element one
  position earlier, and the last `|shift|` positions are zeroed. A positive
  `shift` rolls right, and the first `shift` positions are zeroed. In both
  cases the zeroed slots are exactly those that `jnp.roll` filled with values
  wrapped around from the opposite end of the sequence.

  Args:
    x: Input array with at least two dimensions. Axis 1 is treated as the
      sequence axis, e.g. shape [batch, seq_len, ...].
    shift: Signed number of positions to roll along axis 1. Negative values
      roll left, positive values roll right, and 0 returns `x` unchanged.

  Returns:
    Rolled array with the same shape and dtype as `x`, with wrapped-around
    positions set to zero.
  """
  if shift == 0:
    return x
  rolled = jnp.roll(x, shift, axis=1)
  if shift < 0:
    # Left roll: the last |shift| positions wrapped around from the front.
    return rolled.at[:, shift:, ...].set(0)
  # Right roll: the first `shift` positions wrapped around from the back.
  # The original `[:, shift:]` slice zeroed the wrong side in this case.
  return rolled.at[:, :shift, ...].set(0)
class MultiTokenPredictionLayer(nnx.Module):
  """Multi-Token Prediction layer: normalize, concatenate, project, and transform.

  Implements:
    h_next = TransformerLayer(W_p(concat(RMSNorm(e_target), RMSNorm(h_prev))))

  Terms:
    e_target: the embedding of the target-position token.
    h_prev: the incoming hidden state.
    W_p: a bias-free projection from 2 * emb_dim back to emb_dim.

  The concatenation order in `__call__` is embedding first, hidden state
  second. This order fixes which half of W_p's input axis each operand is
  multiplied with, so it must not be swapped without also converting existing
  weights.

  Sub-modules are stored under layer-number-prefixed attribute names via the
  properties below:
    mtp_{layer_number}_embedding_norm
    mtp_{layer_number}_hidden_state_norm
    mtp_{layer_number}_projection
    mtp_{layer_number}_transformer_layer
  Presumably this keeps parameter paths distinct per MTP depth and stable for
  checkpoints -- TODO confirm.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      layer_number: int,
      transformer_layer_module: Type[NNXDecoderLayer],
      *,
      rngs: nnx.Rngs,
  ):
    """Builds the two norms, the projection, and the transformer layer for one MTP depth.

    Args:
      config: Model config. Reads emb_dim, normalization_layer_epsilon, dtype
        and weight_dtype here; shard_mode and logical_axis_rules in `__call__`.
      mesh: Device mesh. Passed to the transformer layer and used to shard the
        target embedding in `__call__`.
      layer_number: MTP depth index, embedded in the sub-module attribute names.
      transformer_layer_module: Decoder layer class, instantiated once for this depth.
      rngs: RNG streams passed to every sub-module constructor.
    """
    self.config = config
    self.mesh = mesh
    # Must be set before the sub-module assignments below: the property
    # setters read it to build the attribute names.
    self.layer_number = layer_number
    self.transformer_layer_module = transformer_layer_module
    self.rngs = rngs
    cfg = self.config
    # Each assignment below goes through a property setter, which stores the
    # sub-module under an `mtp_{layer_number}_...` attribute name.
    self.embedding_norm = RMSNorm(
        num_features=cfg.emb_dim,
        epsilon=cfg.normalization_layer_epsilon,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        kernel_axes=("norm",),
        rngs=rngs,
    )
    self.hidden_state_norm = RMSNorm(
        num_features=cfg.emb_dim,
        epsilon=cfg.normalization_layer_epsilon,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        kernel_axes=("norm",),
        rngs=rngs,
    )
    # Maps the concatenated [embedding; hidden] features (2 * emb_dim) back to emb_dim.
    self.projection_layer = DenseGeneral(
        in_features_shape=2 * cfg.emb_dim,
        out_features_shape=cfg.emb_dim,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        use_bias=False,
        kernel_axes=("concat_embed", "embed"),
        rngs=rngs,
    )
    # Use MODEL_MODE_TRAIN for initialization; runtime model_mode is passed dynamically.
    self.transformer_layer = transformer_layer_module(
        config=cfg,
        mesh=mesh,
        model_mode=MODEL_MODE_TRAIN,
        rngs=rngs,
    )

  @property
  def embedding_norm(self):
    # RMSNorm applied to the target-token embedding.
    return getattr(self, f"mtp_{self.layer_number}_embedding_norm")

  @embedding_norm.setter
  def embedding_norm(self, module):
    setattr(self, f"mtp_{self.layer_number}_embedding_norm", module)

  @property
  def hidden_state_norm(self):
    # RMSNorm applied to the incoming hidden state.
    return getattr(self, f"mtp_{self.layer_number}_hidden_state_norm")

  @hidden_state_norm.setter
  def hidden_state_norm(self, module):
    setattr(self, f"mtp_{self.layer_number}_hidden_state_norm", module)

  @property
  def projection_layer(self):
    # Stored as `..._projection` (no `_layer` suffix) despite the property name.
    return getattr(self, f"mtp_{self.layer_number}_projection")

  @projection_layer.setter
  def projection_layer(self, module):
    setattr(self, f"mtp_{self.layer_number}_projection", module)

  @property
  def transformer_layer(self):
    # Decoder layer applied to the projected features.
    return getattr(self, f"mtp_{self.layer_number}_transformer_layer")

  @transformer_layer.setter
  def transformer_layer(self, module):
    setattr(self, f"mtp_{self.layer_number}_transformer_layer", module)

  def __call__(
      self,
      prev_hidden_state: jnp.ndarray,
      target_token_embedding: jnp.ndarray,
      *,
      position_ids: jnp.ndarray,
      decoder_segment_ids: None | jnp.ndarray,
      deterministic: bool,
      model_mode: str = MODEL_MODE_TRAIN,
  ) -> jnp.ndarray:
    """Applies MTP combination, projection, and transformer processing.

    Args:
      prev_hidden_state: Shape [batch, seq_len, hidden_size].
      target_token_embedding: Embedding for token t+k. Shape [batch, seq_len, embed_dim].
      position_ids: Shape [batch, seq_len]. Forwarded to the transformer layer
        as `decoder_positions`.
      decoder_segment_ids: Shape [batch, seq_len] or None.
      deterministic: Whether to disable dropout.
      model_mode: Operational mode (train, eval, decode), forwarded to the
        transformer layer.

    Returns:
      Processed hidden state. Shape [batch, seq_len, hidden_size]. If the
      transformer layer returns a tuple, only its first element is returned.
    """
    # Apply the logical activation sharding to the target embedding only;
    # prev_hidden_state is used as received.
    target_token_embedding = sharding.maybe_shard_with_logical(
        target_token_embedding,
        ("activation_batch", "activation_length", "activation_embed"),
        self.mesh,
        self.config.shard_mode,
        self.config.logical_axis_rules,
    )
    embedding_norm = self.embedding_norm(target_token_embedding)
    hidden_state_norm = self.hidden_state_norm(prev_hidden_state)
    # Embedding first, hidden state second: this order determines how the
    # projection's 2 * emb_dim input axis is split between the two operands.
    concatenated_features = jnp.concatenate([embedding_norm, hidden_state_norm], axis=-1)
    projected_features = self.projection_layer(concatenated_features)
    output = self.transformer_layer(
        inputs=projected_features,
        decoder_segment_ids=decoder_segment_ids,
        decoder_positions=position_ids,
        deterministic=deterministic,
        model_mode=model_mode,
    )
    # Some decoder layers return a tuple (presumably hidden state plus
    # auxiliary outputs -- confirm); only the first element is kept.
    return output[0] if isinstance(output, tuple) else output
class MultiTokenPredictionBlock(nnx.Module):
  """Orchestrates the MTP process by running a sequence of MTP layers.

  Holds `config.mtp_num_layers` MultiTokenPredictionLayer instances, stored as
  attributes `mtp_layer_1` through `mtp_layer_{mtp_num_layers}`. On each call
  they are chained:
    - Layer k receives the hidden state produced by layer k-1; layer 1
      receives the main model's hidden state.
    - Layer k also receives the embeddings of the input ids rolled k positions
      left.
    - Layer k's logits are scored against the target ids rolled by the same
      amount.
  """

  def __init__(
      self,
      config: Config,
      mesh: Mesh,
      transformer_layer_module: Type[NNXDecoderLayer],
      decoder: nnx.Module,
      rngs: nnx.Rngs | None,
  ):
    """Creates the MTP layers.

    Args:
      config: Model config. `mtp_num_layers` controls how many layers are built.
      mesh: Device mesh, forwarded to every MTP layer.
      transformer_layer_module: Decoder layer class used inside each MTP layer.
      decoder: Decoder providing `_apply_embedding`, `apply_output_head` and
        `model_mode`, all used in `__call__`.
      rngs: RNG streams for initialization. If None, `nnx.Rngs(0)` is used.
    """
    self.config = config
    self.mesh = mesh
    self.transformer_layer_module = transformer_layer_module
    self.decoder = decoder
    self.rngs = rngs if rngs is not None else nnx.Rngs(0)
    # 1-indexed to match paper convention.
    for k in range(1, config.mtp_num_layers + 1):
      layer = MultiTokenPredictionLayer(
          config=config,
          mesh=mesh,
          layer_number=k,
          transformer_layer_module=transformer_layer_module,
          # Fork from self.rngs so the None fallback above is honored; calling
          # `.fork()` on the raw `rngs` argument raises AttributeError for None.
          rngs=self.rngs.fork(),
      )
      setattr(self, f"mtp_layer_{k}", layer)

  def __call__(
      self,
      shared_embedding,
      main_hidden_state,
      input_ids,
      target_ids,
      target_mask,
      *,
      position_ids,
      decoder_segment_ids,
      model_mode,
      deterministic,
  ) -> dict:
    """Runs the MTP chain and records losses and eval predictions on `self`.

    Args:
      shared_embedding: Embedding passed to both `decoder._apply_embedding` and
        `decoder.apply_output_head`.
      main_hidden_state: Hidden state fed to the first MTP layer.
      input_ids: Token ids; rolled one position left per MTP layer before embedding.
      target_ids: Target ids; rolled in lockstep with `input_ids`.
      target_mask: Loss mask; rolled in lockstep with `input_ids`.
      position_ids: Rolled copies feed the embedding lookup. The unrolled
        values are passed to every MTP layer.
      decoder_segment_ids: Passed unchanged to every MTP layer.
      model_mode: Forwarded to the output head. Per-layer losses are recorded
        only when it equals MODEL_MODE_TRAIN.
      deterministic: Whether to disable dropout.

    Returns:
      An empty dict. Results are exposed through Variables assigned on `self`:
        - `losses` / `weights` (mtp_losses type): recorded only when any loss
          was collected.
        - `mtp_preds` / `mtp_mask` (mtp_acceptance type): recorded only for
          the layer whose index equals `config.mtp_eval_target_module`.
    """
    cfg = self.config
    mtp_hidden_state = main_hidden_state

    # Each iteration rolls these one position left (roll_and_mask default), so
    # layer k sees inputs and targets k tokens ahead of the main model.
    rolled_input_ids = input_ids
    rolled_target_ids = target_ids
    rolled_target_mask = target_mask
    rolled_position_id = position_ids

    mtp_losses_list = []
    mtp_weights_list = []
    mtp_preds_list = []
    mtp_masks_list = []

    for k in range(1, cfg.mtp_num_layers + 1):
      rolled_input_ids = roll_and_mask(rolled_input_ids)
      rolled_target_ids = roll_and_mask(rolled_target_ids)
      rolled_target_mask = roll_and_mask(rolled_target_mask)
      rolled_position_id = roll_and_mask(rolled_position_id)

      # NOTE(review): the embedding lookup and the MTP layer use
      # self.decoder.model_mode, while the output head and the train-only loss
      # branch use the `model_mode` argument. Kept as-is; confirm this split
      # is intended.
      target_token_embedding = self.decoder._apply_embedding(
          shared_embedding,
          rolled_input_ids,
          rolled_position_id,
          deterministic,
          model_mode=self.decoder.model_mode,
      )

      mtp_layer = getattr(self, f"mtp_layer_{k}")
      mtp_hidden_state = mtp_layer(
          prev_hidden_state=mtp_hidden_state,
          target_token_embedding=target_token_embedding,
          position_ids=position_ids,
          decoder_segment_ids=decoder_segment_ids,
          deterministic=deterministic,
          model_mode=self.decoder.model_mode,
      )

      mtp_logits = self.decoder.apply_output_head(shared_embedding, mtp_hidden_state, deterministic, model_mode)
      mtp_xent, _ = max_utils.cross_entropy_with_logits(
          mtp_logits, jax.nn.one_hot(rolled_target_ids, cfg.vocab_size), 0.0
      )
      mtp_xent_masked = mtp_xent * rolled_target_mask

      if model_mode == MODEL_MODE_TRAIN:
        # Summed loss and token count per layer. The average is formed later,
        # across layers, in calculate_mtp_loss.
        mtp_losses_list.append(jnp.sum(mtp_xent_masked))
        mtp_weights_list.append(jnp.sum(rolled_target_mask).astype(jnp.float32))
      if cfg.mtp_eval_target_module == k:
        # Float32 to avoid gradient errors; converted back to int32 in acceptance calculation.
        mtp_preds_list.append(jnp.argmax(mtp_logits, axis=-1).astype(jnp.float32))
        mtp_masks_list.append(rolled_target_mask)

    if mtp_losses_list:
      # Not part of checkpoints, don't declare in __init__
      self.losses = mtp_losses(jnp.stack(mtp_losses_list))
      self.weights = mtp_losses(jnp.stack(mtp_weights_list))
    if mtp_preds_list:
      # Not part of checkpoints, don't declare in __init__
      self.mtp_preds = mtp_acceptance(jnp.stack(mtp_preds_list))
      self.mtp_mask = mtp_acceptance(jnp.stack(mtp_masks_list))
    return {}
def calculate_mtp_loss(intermediate_outputs, config):
  """Computes the scaled, token-averaged MTP loss from recorded intermediate outputs.

  Reads `losses` (summed masked cross-entropy per MTP layer) and `weights`
  (valid-token counts per layer) from `intermediate_outputs` under
  ("mtp_losses", "mtp_block").

  Args:
    intermediate_outputs: Nested mapping of intermediate outputs.
    config: Config providing `mtp_loss_scaling_factor`.

  Returns:
    sum(losses) / (sum(weights) + EPS), multiplied by
    `config.mtp_loss_scaling_factor`. Returns 0.0 when no losses were recorded.
  """
  block_path = ("mtp_losses", "mtp_block")
  losses = maxtext_utils.get_nested_value(intermediate_outputs, block_path + ("losses",), default=None)
  weights = maxtext_utils.get_nested_value(intermediate_outputs, block_path + ("weights",), default=None)
  if losses is None:
    return 0.0

  # Linen `sow` yields a tuple/list of recorded values; an NNX Variable yields one array.
  if isinstance(losses, (tuple, list)):
    if len(losses) == 0:
      return 0.0
    losses, weights = jnp.array(losses), jnp.array(weights)
  elif losses.size == 0:
    return 0.0

  total_weight = jnp.sum(weights) + EPS
  return (jnp.sum(losses) / total_weight) * config.mtp_loss_scaling_factor
def _first_sown_value(raw):
  """Returns the usable value from one recorded intermediate output, or None.

  Linen `sow` records a tuple of values, and the first one is used. An NNX
  Variable yields the value directly. An empty tuple or list means nothing was
  recorded and is mapped to None; previously it leaked through and crashed on
  `.astype`.
  """
  if isinstance(raw, (tuple, list)):
    return raw[0] if raw else None
  return raw


def calculate_mtp_acceptance_rate(intermediate_outputs, config):
  """Calculates MTP acceptance rate from intermediate outputs.

  The rate is the percentage of valid positions where the target MTP layer's
  prediction matches the main model's prediction. The target layer is
  `config.mtp_eval_target_module`. Its prediction is the stored argmax, and the
  main model's argmax is shifted left by the same number of positions before
  comparing.

  Args:
    intermediate_outputs: Nested mapping. Reads `mtp_preds` and `mtp_mask`
      under ("mtp_acceptance", "mtp_block"), and the main model's `logits`.
    config: Config providing `mtp_eval_target_module`.

  Returns:
    Acceptance rate in percent, or 0.0 when predictions or mask are missing or
    empty.
  """
  sown_data = maxtext_utils.get_nested_value(intermediate_outputs, ("mtp_acceptance", "mtp_block"), {})
  # Handle both tuple (Linen sow) and array (NNX Variable) formats.
  mtp_preds = _first_sown_value(maxtext_utils.get_nested_value(sown_data, ("mtp_preds",), None))
  valid_mask = _first_sown_value(maxtext_utils.get_nested_value(sown_data, ("mtp_mask",), None))

  # Only recorded for the MTP layer selected by config.mtp_eval_target_module.
  if mtp_preds is None or valid_mask is None:
    return 0.0

  # Predictions are stored as float32 by MultiTokenPredictionBlock; restore integer ids.
  mtp_preds = mtp_preds.astype(jnp.int32)
  main_model_preds = jnp.argmax(intermediate_outputs["logits"], axis=-1)

  # Align main model predictions with MTP head target by rolling k steps, the
  # same sequential rolls applied to the targets inside the MTP block.
  rolled_main_preds = main_model_preds
  for _ in range(config.mtp_eval_target_module):
    rolled_main_preds = roll_and_mask(rolled_main_preds)

  # NOTE(review): in the NNX path the stored arrays were built with jnp.stack,
  # so they may carry a leading axis of size 1; broadcasting against
  # rolled_main_preds absorbs it -- confirm for other callers.
  correct_predictions = jnp.sum((mtp_preds == rolled_main_preds) * valid_mask)
  total_valid_tokens = jnp.sum(valid_mask)
  return (correct_predictions / (total_valid_tokens + EPS)) * 100
def multi_token_prediction_block_as_linen(
    *,
    config: Config,
    mesh: Mesh,
    transformer_layer_module: Type[DecoderLayer],
    decoder: nnx.Module,
    rngs: nnx.Rngs,
    name: str | None = None,
) -> nn.Module:
  """Wraps MultiTokenPredictionBlock so it can be used as a Linen module.

  The NNX block is converted with `nnx.bridge.to_linen`, using
  `variable_to_logically_partitioned` as the metadata function (presumably to
  map NNX variable sharding metadata onto Linen logical partitioning).

  Args:
    config: Configuration object containing model hyperparameters.
    mesh: JAX Mesh for model parallelism.
    transformer_layer_module: Decoder layer class used inside each MTP layer.
    decoder: The decoder module that provides embedding and output head.
    rngs: Random number generators for initialization.
    name: Optional name for the resulting Linen module.

  Returns:
    A Linen module wrapping MultiTokenPredictionBlock.
  """
  # NOTE(review): annotated here as Type[DecoderLayer], but
  # MultiTokenPredictionBlock annotates the same argument as
  # Type[NNXDecoderLayer], and MultiTokenPredictionLayer instantiates it with
  # an `rngs=` keyword. Confirm which layer type callers actually pass.
  block_kwargs = {
      "config": config,
      "mesh": mesh,
      "transformer_layer_module": transformer_layer_module,
      "decoder": decoder,
      "rngs": rngs,
  }
  return nnx.bridge.to_linen(
      MultiTokenPredictionBlock,
      metadata_fn=variable_to_logically_partitioned,
      name=name,
      **block_kwargs,
  )