maxtext.models.deepseek module

Transformer model definition.

class maxtext.models.deepseek.DeepSeekGenericLayer(*args, **kwargs)

Bases: Module

Generic DeepSeek layer with Multi-Head Latent Attention.

Intended as a base class for DeepSeek layers with dense or sparse (MoE) MLPs. The class separates module creation from execution.
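The split between creation and execution can be illustrated with a minimal, dependency-free sketch. All names here (`GenericLayerSketch`, `DenseLayerSketch`, `attn_scale`, `mlp_scale`) are hypothetical stand-ins, and the residual wiring is an assumption made for illustration; this is not the MaxText implementation.

```python
from typing import Callable, List

Vector = List[float]


class GenericLayerSketch:
    """Hypothetical base layer: sub-modules are built once, then used by *_op methods."""

    def __init__(self, attn_scale: float) -> None:
        # Creation: construct the shared sub-module up front.
        self.attention: Callable[[Vector], Vector] = lambda x: [attn_scale * v for v in x]

    def attention_op(self, x: Vector) -> Vector:
        # Execution: apply the module that was built in __init__.
        return self.attention(x)

    def mlp_op(self, x: Vector, deterministic: bool) -> Vector:
        # The base class leaves the MLP to its subclasses.
        raise NotImplementedError("Subclasses supply the MLP.")

    def __call__(self, x: Vector, deterministic: bool = True) -> Vector:
        # Assumed wiring: a residual around attention, then a residual around the MLP.
        h = [a + b for a, b in zip(x, self.attention_op(x))]
        return [a + b for a, b in zip(h, self.mlp_op(h, deterministic))]


class DenseLayerSketch(GenericLayerSketch):
    """Hypothetical dense variant: adds only its own MLP module."""

    def __init__(self, attn_scale: float, mlp_scale: float) -> None:
        super().__init__(attn_scale)
        # Creation: the subclass builds the MLP it will execute later.
        self.mlp: Callable[[Vector], Vector] = lambda x: [mlp_scale * v for v in x]

    def mlp_op(self, x: Vector, deterministic: bool) -> Vector:
        # Execution: the dense MLP ignores `deterministic` in this sketch.
        return self.mlp(x)


layer = DenseLayerSketch(attn_scale=0.5, mlp_scale=2.0)
output = layer([1.0, 2.0])
print(output)
```

A sparse variant would follow the same shape: build its experts and router during creation, then override `mlp_op` to execute them.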

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

mlp_op(x, deterministic, *args, **kwargs)

Executes the MLP operation. To be implemented by subclasses.

with_logical_constraint(x)

dropout_op(x, deterministic)

pre_attention_norm_op(x)

post_attention_norm_op(x)

attention_op(x, decoder_segment_ids, decoder_positions, deterministic, previous_chunk=None, page_state=None, slot=None)

Executes the attention layer.

Parameters:
  • page_state (None | PageState)

  • slot (None | int)

property logical_axis_names

Generates logical axis names for activations in general.

property mlp_logical_axis_names

Generates logical axis names for activations in the MLP.

post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache=None)

Post-processes the layer output.

self_attention_with_norm_op(inputs, decoder_segment_ids, decoder_positions, deterministic, previous_chunk=None, page_state=None, slot=None)

Executes self-attention with normalization.

Parameters:
  • page_state (None | PageState)

  • slot (None | int)

engram_op(x, decoder_input_tokens)

class maxtext.models.deepseek.DeepSeekDenseLayer(*args, **kwargs)

Bases: DeepSeekGenericLayer

DeepSeek-style dense layer with Multi-Head Latent Attention.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

mlp_op(x, deterministic)

Executes the dense MLP operation.

class maxtext.models.deepseek.DeepSeekMoELayer(*args, **kwargs)

Bases: DeepSeekGenericLayer

DeepSeek-style MoE layer with Multi-Head Latent Attention.

Supports both dropless and dropping modes, depending on the config. Uses a bias in routing instead of a load-balancing loss.
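Bias-based routing can be sketched as follows. This follows the general auxiliary-loss-free idea: the bias affects only which experts are selected, while the combine weights come from the unbiased scores. The function names, the weight normalization, and the fixed-step bias update are assumptions made for illustration, not the MaxText implementation.

```python
from typing import List, Tuple


def route_with_bias(scores: List[float], bias: List[float], top_k: int) -> Tuple[List[int], List[float]]:
    """Select top_k experts by (score + bias); weight them by the unbiased scores."""
    biased = [s + b for s, b in zip(scores, bias)]
    # The bias influences selection only.
    chosen = sorted(range(len(scores)), key=lambda i: biased[i], reverse=True)[:top_k]
    # Combine weights are normalized over the original scores of the chosen experts.
    total = sum(scores[i] for i in chosen)
    weights = [scores[i] / total for i in chosen]
    return chosen, weights


def update_bias(bias: List[float], expert_load: List[float], step: float) -> List[float]:
    """Lower the bias of overloaded experts and raise it for underloaded ones."""
    mean_load = sum(expert_load) / len(expert_load)
    new_bias = []
    for b, load in zip(bias, expert_load):
        if load > mean_load:
            new_bias.append(b - step)
        elif load < mean_load:
            new_bias.append(b + step)
        else:
            new_bias.append(b)
    return new_bias


# Hypothetical per-token affinity scores for four experts.
scores = [0.5, 0.25, 0.125, 0.125]
bias = [0.0, 0.0, 0.5, 0.0]

chosen, weights = route_with_bias(scores, bias, top_k=2)
new_bias = update_bias(bias, expert_load=[4.0, 2.0, 1.0, 1.0], step=0.25)
print(chosen, weights, new_bias)
```

In this example the bias lifts expert 2 into the selected set even though its raw score is low, yet its combine weight still reflects that raw score. Balancing then comes from nudging the bias between steps rather than from an extra loss term.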

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

mlp_op(x, deterministic, *args, **kwargs)

Executes the MoE MLP operation.