Source code for maxtext.models.deepseek

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer model definition."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

import functools
from typing import Optional

from flax import nnx
import jax
from jax.ad_checkpoint import checkpoint_name
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Config
from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL
from maxtext.inference import page_manager
from maxtext.layers import attention_mla
from maxtext.layers import initializers
from maxtext.layers import linears
from maxtext.layers import mhc
from maxtext.layers import moe
from maxtext.layers import nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.linears import Dropout
from maxtext.layers.engram import Engram
from maxtext.layers.engram import NgramHashMapping
from maxtext.layers.normalizations import RMSNorm
from maxtext.models import deepseek_batchsplit
from maxtext.models import deepseek_batchsplit_fp8
from maxtext.utils import max_utils
from maxtext.utils.sharding import create_sharding
from maxtext.utils.sharding import maybe_shard_with_logical

import transformers

# -----------------------------------------
# The Decoder Layer for DeepSeek v3
# -----------------------------------------


class DeepSeekGenericLayer(nnx.Module):
  """Generic DeepSeek layer with Multi-Head Latent Attention.

  This is to be used as a base class for DeepSeek layers with dense/sparse MLPs.
  This class follows a pattern of separating module creation from execution:
  sub-modules are built in `__init__`, the `*_op` methods apply them, and
  subclasses compose those ops in `__call__` and implement `mlp_op`.
  """

  def __init__(
      self,
      config: Config,
      model_mode: str,
      mesh: Mesh,
      rngs: nnx.Rngs,
      quant: Optional[quantizations.AqtQuantization] = None,
      layer_idx: int = -1,
  ) -> None:
    """Creates norms, MLA attention, dropout and the optional Engram / mHC blocks.

    Args:
      config: Model configuration.
      model_mode: Mode string. It selects the batch/sequence sizes of the dummy
        input shape and the sequence-length logical axis name.
      mesh: Device mesh used for sharding.
      rngs: NNX RNG streams handed to every sub-module.
      quant: Optional AQT quantization config forwarded to attention / Engram.
      layer_idx: Position of this layer in the decoder stack. The Engram block
        is only built when this index appears in `config.engram_layers`.
    """
    self.config = config
    self.model_mode = model_mode
    self.mesh = mesh
    self.quant = quant
    self.rngs = rngs
    self.is_mhc_enabled = config.mhc_expansion_rate > 1
    self.layer_idx = layer_idx
    # `bool(...)` guarantees a real boolean. With a bare `and`, an empty or
    # None `engram_layers` would be stored on the module as that value.
    self.is_engram_enabled = bool(config.engram_layers) and layer_idx in config.engram_layers

    batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode)
    self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)

    self.out_sharding = create_sharding(self.mesh, self.logical_axis_names, rules=self.config.logical_axis_rules)
    self.mlp_intermediate_sharding = create_sharding(
        self.mesh, self.mlp_logical_axis_names, rules=self.config.logical_axis_rules
    )

    # Every RMSNorm in this layer uses the same hyper-parameters.
    make_rms_norm = functools.partial(
        RMSNorm,
        num_features=self.dummy_inputs_shape[-1],
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        kernel_axes=("norm",),
        epsilon=self.config.normalization_layer_epsilon,
        rngs=rngs,
    )
    self.pre_self_attention_layer_norm = make_rms_norm()
    self.post_self_attention_layer_norm = make_rms_norm()

    if self.is_engram_enabled:
      self.engram_layer_norm = make_rms_norm()
      tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path, token=config.hf_access_token)
      # TODO(ranran): Refactor NgramHashMapping to initialize once globally or at the model level.
      # Moving this to decoders.py currently causes JAX initialization errors.
      self.ngram_hash_mapping = NgramHashMapping(
          engram_vocab_bases=config.engram_vocab_bases,
          max_ngram_size=config.engram_max_ngram_size,
          engram_num_heads=config.engram_num_heads,
          layer_ids=config.engram_layers,
          tokenizer=tokenizer,
          pad_id=tokenizer.pad_token_id,
          seed=config.engram_seed,
      )
      self.engram = Engram(
          config=config,
          mesh=mesh,
          vocab_sizes=self.ngram_hash_mapping.get_vocab_sizes(layer_idx),
          engram_num_heads=config.engram_num_heads,
          engram_head_dim=config.engram_head_dim,
          engram_max_ngram_size=config.engram_max_ngram_size,
          engram_kernel_size=config.engram_kernel_size,
          mhc_expansion_rate=config.mhc_expansion_rate,
          quant=quant,
          rngs=rngs,
      )
    else:
      self.engram_layer_norm = None
      self.engram = None

    self.self_attention = attention_mla.MLA(
        config=self.config,
        num_query_heads=self.config.num_query_heads,
        num_kv_heads=self.config.num_kv_heads,
        head_dim=self.config.head_dim,
        max_target_length=self.config.max_target_length,
        max_prefill_predict_length=self.config.max_prefill_predict_length,
        attention_kernel=self.config.attention,
        attention_type=self.config.attention_type,
        inputs_q_shape=self.dummy_inputs_shape,
        inputs_kv_shape=self.dummy_inputs_shape,
        mesh=mesh,
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        dropout_rate=self.config.dropout_rate,
        name="self_attention",
        quant=quant,
        kv_quant=quantizations.configure_kv_quant(config),
        q_lora_rank=self.config.q_lora_rank,
        kv_lora_rank=self.config.kv_lora_rank,
        qk_nope_head_dim=self.config.qk_nope_head_dim,
        qk_rope_head_dim=self.config.qk_rope_head_dim,
        v_head_dim=self.config.v_head_dim,
        max_position_embeddings=self.config.max_position_embeddings,
        original_max_position_embeddings=self.config.original_max_position_embeddings,
        mscale=self.config.mscale,
        rope_factor=self.config.rope_factor,
        model_mode=model_mode,
        rngs=rngs,
        attn_logits_soft_cap=self.config.attn_logits_soft_cap,
    )

    self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

    if self.is_mhc_enabled:
      self.mhc_attention = mhc.ManifoldConstrainedHyperConnections(self.config, self.config.emb_dim, self.mesh, self.rngs)
      self.mhc_mlp = mhc.ManifoldConstrainedHyperConnections(self.config, self.config.emb_dim, self.mesh, self.rngs)

  def mlp_op(self, x, deterministic, *args, **kwargs):
    """Executes the MLP operation. To be implemented by subclasses."""
    raise NotImplementedError()

  def with_logical_constraint(self, x):
    """Constrains `x` to this layer's activation logical axes (`logical_axis_names`)."""
    return maybe_shard_with_logical(
        x,
        logical_axes=self.logical_axis_names,
        mesh=self.mesh,
        shard_mode=self.config.shard_mode,
        debug_sharding=self.config.debug_sharding,
        extra_stack_level=1,
        rules=self.config.logical_axis_rules,
    )

  def dropout_op(self, x, deterministic):
    """Applies dropout (broadcast over axis -2), then the activation sharding constraint."""
    dropout = self.dropout(x, deterministic=deterministic)
    return self.with_logical_constraint(dropout)

  def pre_attention_norm_op(self, x):
    """Applies the pre-self-attention RMSNorm, then the activation sharding constraint."""
    pre_attention_norm = self.pre_self_attention_layer_norm(x)
    return self.with_logical_constraint(pre_attention_norm)

  def post_attention_norm_op(self, x):
    """Applies the post-self-attention RMSNorm, then the activation sharding constraint."""
    post_attention_norm = self.post_self_attention_layer_norm(x)
    return self.with_logical_constraint(post_attention_norm)

  def attention_op(
      self,
      x,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
  ):
    """Executes the attention layer.

    `x` is used as both the query and key/value input. The second element of
    the MLA return value is discarded.
    """
    attention_result, _ = self.self_attention(
        x,
        x,
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
        model_mode=self.model_mode,
        out_sharding=self.out_sharding,
        previous_chunk=previous_chunk,
        page_state=page_state,
        slot=slot,
    )
    return self.with_logical_constraint(attention_result)

  @property
  def _activation_length_name(self):
    """Logical name of the sequence-length axis (prefill mode uses its own name)."""
    if self.model_mode == MODEL_MODE_PREFILL:
      return "prefill_activation_norm_length"
    return "activation_norm_length"

  @property
  def logical_axis_names(self):
    """Generate logical names for activations generally."""
    return ["activation_batch", self._activation_length_name, "activation_embed"]

  @property
  def mlp_logical_axis_names(self):
    """Generate logical names for activations in MLP."""
    return ["activation_batch", self._activation_length_name, "activation_mlp"]

  def post_process(self, layer_output, load_balance_loss, moe_bias_updates, kv_cache=None):
    """Sows optional MoE losses / activation metrics and builds the layer result.

    Args:
      layer_output: Final activations of the layer.
      load_balance_loss: MoE load-balance loss, or None. It is sown only when
        `config.load_balance_loss_weight > 0`.
      moe_bias_updates: Routed-bias updates, or None. They are sown only when
        `config.routed_bias` is set and `config.routed_bias_update_rate > 0`.
      kv_cache: Returned unchanged unless `config.scan_layers` is set.

    Returns:
      `(layer_output, None)` when `config.scan_layers` is set, otherwise
      `(layer_output, kv_cache)`.
    """
    if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
      self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)

    if self.config.routed_bias and self.config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
      self.sow(nnx.Intermediate, "moe_bias_updates", moe_bias_updates)

    if self.config.record_internal_nn_metrics:
      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
      self.sow(
          nnx.Intermediate,
          "activation_fraction_zero",
          jnp.sum(layer_output == 0) / jnp.size(layer_output),
      )

    if self.config.scan_layers:
      return layer_output, None
    return layer_output, kv_cache

  def self_attention_with_norm_op(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
  ):
    """Self-attention with normalization.

    With mHC enabled, the norm + attention + residual are delegated to
    `mhc_attention`. Otherwise this is pre-norm attention plus a residual add.

    Returns:
      `(hidden_states, intermediate_inputs)`: the post-attention-normalized
      activations and the un-normalized residual stream they were computed from.
    """
    if self.is_mhc_enabled:
      intermediate_inputs, _ = self.mhc_attention(
          self.pre_attention_norm_op,
          self.self_attention,
          x=inputs,
          mhc_type=HyperConnectionType.ATTENTION,
          decoder_segment_ids=decoder_segment_ids,
          inputs_positions=decoder_positions,
          deterministic=deterministic,
          model_mode=self.model_mode,
          out_sharding=self.out_sharding,
          previous_chunk=previous_chunk,
          page_state=page_state,
          slot=slot,
      )
    else:
      lnx = self.pre_attention_norm_op(inputs)
      attention_lnx = self.attention_op(
          lnx,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
          previous_chunk,
          page_state,
          slot,
      )
      intermediate_inputs = inputs + attention_lnx

    # Normalization
    hidden_states = self.post_attention_norm_op(intermediate_inputs)
    return hidden_states, intermediate_inputs

  def engram_op(self, x, decoder_input_tokens):
    """Applies the Engram block to RMS-normalized activations.

    Only usable when Engram is enabled for this layer. Otherwise
    `engram_layer_norm` and `engram` are None and `ngram_hash_mapping` is unset.
    The hash mapping's result is indexed by `self.layer_idx` to pick this
    layer's hash ids.
    """
    normed_x = self.engram_layer_norm(x)
    hash_ids = self.ngram_hash_mapping(decoder_input_tokens)[self.layer_idx]
    return self.engram(normed_x, hash_ids)
class DeepSeekDenseLayer(DeepSeekGenericLayer):
  """DeepSeek-style dense layer with Multi-Head Latent Attention.

  The feed-forward part is a single `linears.MlpBlock` of width
  `config.mlp_dim`. Attention, norms, Engram and mHC come from the base class.
  """

  def __init__(
      self,
      config: Config,
      model_mode: str,
      mesh: Mesh,
      rngs: nnx.Rngs,
      quant: Optional[quantizations.AqtQuantization] = None,
      layer_idx: int = -1,
  ) -> None:
    """Builds the shared layer components plus the dense MLP block."""
    super().__init__(config, model_mode, mesh, rngs, quant, layer_idx)
    cfg = self.config
    self.mlp = linears.MlpBlock(
        in_features=self.dummy_inputs_shape[-1],
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        config=cfg,
        quant=quant,
        model_mode=model_mode,
        mesh=mesh,
        rngs=self.rngs,
    )

  def mlp_op(self, x, deterministic):
    """Runs the dense MLP with this layer's shardings and constrains the result."""
    projected = self.mlp(
        x,
        deterministic,
        intermediate_sharding=self.mlp_intermediate_sharding,
        out_sharding=self.out_sharding,
    )
    return self.with_logical_constraint(projected)

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
      kv_cache=None,
      attention_metadata=None,
      decoder_input_tokens=None,
  ):
    """Runs attention followed by the dense MLP, each with a residual connection.

    Returns:
      The tuple produced by `post_process` (no MoE losses for a dense layer).
    """
    # A previous layer may hand over `(hidden_states, kv_cache)`; keep the activations only.
    activations = inputs[0] if isinstance(inputs, tuple) else inputs

    x = checkpoint_name(self.with_logical_constraint(activations), "decoder_layer_input")
    if self.is_engram_enabled:
      x = x + self.engram_op(x, decoder_input_tokens)

    normed, residual = self.self_attention_with_norm_op(
        x,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        previous_chunk,
        page_state,
        slot,
    )

    if not self.is_mhc_enabled:
      layer_output = self.mlp_op(normed, deterministic) + residual
    else:
      layer_output, _ = self.mhc_mlp(
          self.post_attention_norm_op,
          self.mlp,
          x=residual,
          mhc_type=HyperConnectionType.MLP_DENSE,
          deterministic=deterministic,
      )

    layer_output = self.dropout_op(layer_output, deterministic=deterministic)
    return self.post_process(layer_output, None, None, kv_cache)
# Linen-compatible class wrapping the NNX `DeepSeekDenseLayer`, produced by
# `nnx_wrappers.to_linen_class`. Variable metadata is mapped through
# `initializers.variable_to_logically_partitioned` (presumably so Linen callers
# see logically-partitioned params — confirm against nnx_wrappers).
DeepSeekDenseLayerToLinen = nnx_wrappers.to_linen_class(
    DeepSeekDenseLayer,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)
class DeepSeekMoELayer(DeepSeekGenericLayer):
  """DeepSeek-style MoE layer with Multi-Head Latent Attention.

  Supports dropless and dropping based on configs.
  Uses a bias in routing instead of load balancing loss.
  """

  def __init__(
      self,
      config: Config,
      model_mode: str,
      mesh: Mesh,
      rngs: nnx.Rngs,
      quant: Optional[quantizations.AqtQuantization] = None,
      layer_idx: int = -1,
  ) -> None:
    """Builds the shared layer components plus the routed+shared MoE block."""
    super().__init__(config, model_mode, mesh, rngs, quant, layer_idx)
    self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
        config=self.config,
        mesh=mesh,
        kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"),
        kernel_axes=("embed", None),
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        quant=quant,
        rngs=self.rngs,
    )

  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
      kv_cache=None,
      attention_metadata=None,
      decoder_input_tokens=None,
  ):
    """Runs attention followed by the MoE block, each with a residual connection.

    When `config.use_batch_split_schedule` is set, dispatches to one of the
    batch-split helpers and returns `(outputs, None)`. Otherwise returns the
    tuple produced by `post_process`.
    """
    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
    if isinstance(inputs, tuple):
      inputs = inputs[0]

    # This code should only be traced during initialization when using
    # batch-split schedule. It is never run during model execution, since
    # `Decoder` directly calls `batch_split_schedule` during execution.
    # That is also why we can split/merge activations here as well as
    # in `Decoder`, since they will never be executed together.
    if self.config.use_batch_split_schedule:
      if self.config.use_qwix_quantization and not self.config.use_manual_quantization:
        return self._qwix_batch_split_forward(inputs, decoder_segment_ids, decoder_positions, model_mode)
      return self._batch_split_forward(inputs, decoder_positions)

    x = self.with_logical_constraint(inputs)
    x = checkpoint_name(x, "decoder_layer_input")
    if self.is_engram_enabled:
      engram_output = self.engram_op(x, decoder_input_tokens)
      x = x + engram_output

    hidden_states, intermediate_inputs = self.self_attention_with_norm_op(
        x,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        previous_chunk,
        page_state,
        slot,
    )

    if self.is_mhc_enabled:
      layer_output, metadata = self.mhc_mlp(
          self.post_attention_norm_op,
          self.DeepSeekMoeBlock_0,
          x=intermediate_inputs,
          mhc_type=HyperConnectionType.MLP_MOE,
      )
      load_balance_loss = metadata["load_balance_loss"]
      moe_bias_updates = metadata["moe_bias_updates"]
    else:
      mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp_op(hidden_states, deterministic)
      layer_output = mlp_lnx + intermediate_inputs

    layer_output = self.dropout_op(layer_output, deterministic=deterministic)
    return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache)

  def _qwix_batch_split_forward(self, inputs, decoder_segment_ids, decoder_positions, model_mode):
    """Older batch-split path that fully uses qwix quantization (init-time tracing only).

    Splits activations, positions and segment ids into
    `config.batch_split_factor` micro-batches, runs
    `deepseek_batchsplit_fp8.batch_split_schedule`, and merges the results.
    """
    split_factor = self.config.batch_split_factor
    activation_pspec = jax.sharding.PartitionSpec(
        ("data", "fsdp", "fsdp_transpose", "expert", "context"),
        None,
        None,
    )
    inputs = jax.shard_map(
        functools.partial(deepseek_batchsplit_fp8.split, split_factor=split_factor),
        mesh=self.mesh,
        in_specs=activation_pspec,
        out_specs=[activation_pspec] * split_factor,
    )(inputs)
    dpos = deepseek_batchsplit_fp8.split(decoder_positions, split_factor)
    dseg = deepseek_batchsplit_fp8.split(decoder_segment_ids, split_factor)
    weights = deepseek_batchsplit_fp8.fetch_weights(nnx.to_pure_dict(nnx.state(self, nnx.Param)), self.config.dtype)
    outputs = deepseek_batchsplit_fp8.batch_split_schedule(
        inputs,
        weights,
        dpos,
        dseg,
        model_mode=model_mode,
        mesh=self.mesh,
        quant=self.quant,
        cfg=self.config,
    )
    outputs = jax.shard_map(
        functools.partial(deepseek_batchsplit_fp8.merge, split_factor=split_factor),
        mesh=self.mesh,
        in_specs=([activation_pspec] * split_factor,),
        out_specs=activation_pspec,
    )(outputs)
    return outputs, None

  def _batch_split_forward(self, inputs, decoder_positions):
    """bf16 and fp8 code path for pure-JAX batch-split (init-time tracing only).

    The fp8 code path supports both manual quantization and qwix quantization.
    Activations are resharded to the batch-split layout, processed, and finally
    resharded back to the sharding `inputs` arrived with.
    """
    split_factor = self.config.batch_split_factor
    input_sharding = jax.typeof(inputs).sharding
    activation_pspec = jax.sharding.PartitionSpec(
        ("data", "fsdp", "expert"),
        None,
        None,
    )
    inputs = jax.reshard(inputs, jax.sharding.NamedSharding(self.mesh, activation_pspec))
    yarn_freqs = deepseek_batchsplit.initialize_yarn_freqs(
        decoder_positions,
        embedding_dims=self.config.qk_rope_head_dim,
        rope_theta=self.config.rope_max_timescale,
        max_position_embeddings=self.config.max_position_embeddings,
        original_max_position_embeddings=self.config.original_max_position_embeddings,
        beta_fast=self.config.beta_fast,
        beta_slow=self.config.beta_slow,
        rope_factor=self.config.rope_factor,
        mesh=self.mesh,
        activation_pspec=activation_pspec,
    )
    yarn_mask = deepseek_batchsplit.initialize_yarn_mask(self.config.qk_rope_head_dim)
    splash_kernel = deepseek_batchsplit.init_splash_kernel(self.config)
    inputs = jax.shard_map(
        functools.partial(deepseek_batchsplit.split, split_factor=split_factor),
        mesh=self.mesh,
        in_specs=activation_pspec,
        out_specs=[activation_pspec] * split_factor,
    )(inputs)
    yarn_freqs = deepseek_batchsplit.split(yarn_freqs, split_factor)

    def extract_fn(x):
      # Parameters are re-constrained to their logical sharding; other leaves pass through.
      if isinstance(x, nnx.variablelib.Variable):
        return maybe_shard_with_logical(
            x.value,
            x.sharding_names,
            self.mesh,
            shard_mode=self.config.shard_mode,
            rules=self.config.logical_axis_rules,
        )
      return x

    weights = deepseek_batchsplit.fetch_weights(
        nnx.to_pure_dict(nnx.state(self, nnx.Param), extract_fn), self.config.dtype
    )
    weights = deepseek_batchsplit.gather_weights(weights, self.mesh)
    outputs, _ = deepseek_batchsplit.batch_split_schedule(
        inputs,
        weights,
        yarn_freqs,
        mesh=self.mesh,
        cfg=self.config,
        splash_kernel=splash_kernel,
        activation_pspec=activation_pspec,
        pairwise_swap_and_negate_mask=yarn_mask,
    )
    # Micro-batch 1 comes back still routed; finish its MoE unroute here.
    # NOTE(review): relies on `outputs` being mutable (list-like) — confirm.
    moe_inputs, routed_expert_out, shared_expert_out, selected_experts = outputs[1]
    outputs[1], _ = deepseek_batchsplit.unroute_ubatch_shard_mapped(
        moe_inputs,
        routed_expert_out,
        shared_expert_out,
        selected_experts,
        expert_axis_name="expert",
        use_gather_mosaic_kernel=False,
        target_length=self.config.max_target_length,
        mesh=self.mesh,
        activation_pspec=activation_pspec,
    )
    outputs = jax.shard_map(
        functools.partial(deepseek_batchsplit.merge, split_factor=split_factor),
        mesh=self.mesh,
        in_specs=([activation_pspec] * split_factor,),
        out_specs=activation_pspec,
    )(outputs)
    outputs = jax.reshard(outputs, input_sharding)
    return outputs, None

  def mlp_op(self, x, deterministic, *args, **kwargs):
    """Runs the MoE block and constrains its output.

    `deterministic` and any extra arguments are accepted for interface
    compatibility but not forwarded to the MoE block.

    Returns:
      `(mlp_output, load_balance_loss, moe_bias_updates)`.
    """
    mlp_lnx, load_balance_loss, moe_bias_updates = self.DeepSeekMoeBlock_0(
        x, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding
    )
    return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates
# Linen-compatible class wrapping the NNX `DeepSeekMoELayer`, produced by
# `nnx_wrappers.to_linen_class`. Variable metadata is mapped through
# `initializers.variable_to_logically_partitioned` (presumably so Linen callers
# see logically-partitioned params — confirm against nnx_wrappers).
DeepSeekMoELayerToLinen = nnx_wrappers.to_linen_class(
    DeepSeekMoELayer,
    base_metadata_fn=initializers.variable_to_logically_partitioned,
)