maxtext.models.deepseek_batchsplit_fp8 module
Alternative DeepSeek model definition with a batch-split schedule.
- maxtext.models.deepseek_batchsplit_fp8.fetch_weights(params, dtype)
Fetches weights from params and arranges them in the format expected by the batch-split schedule.
- maxtext.models.deepseek_batchsplit_fp8.split(x, split_factor=2)
Splits the input into split_factor parts along the batch dimension.
- maxtext.models.deepseek_batchsplit_fp8.merge(x, split_factor=2)
Merges the input microbatches back into a single tensor.
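The following is a minimal sketch of the split/merge round trip. It assumes `x` is a single array whose leading axis is the batch and whose batch size is divisible by `split_factor`; the real functions may also accept pytrees. `split_microbatches` and `merge_microbatches` are hypothetical names and are not this module's code.

```python
import jax.numpy as jnp


def split_microbatches(x, split_factor=2):
    """Split x into `split_factor` equal microbatches along axis 0 (batch)."""
    # jnp.split raises an error if the batch size is not divisible by split_factor.
    return jnp.split(x, split_factor, axis=0)


def merge_microbatches(xs):
    """Concatenate microbatches back into one array along axis 0."""
    return jnp.concatenate(xs, axis=0)


x = jnp.arange(4 * 3 * 2).reshape(4, 3, 2)  # [batch, seq, hidden]
parts = split_microbatches(x)
restored = merge_microbatches(parts)
```

Merging the split parts returns the original batch unchanged. The batch-split schedule relies on this so it can process microbatches independently and reassemble the result afterward.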
- maxtext.models.deepseek_batchsplit_fp8.gather_weights(weights, mesh)
All-gathers FSDP-sharded weights.
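One way to express an all-gather of FSDP-sharded weights in JAX is to reshard each weight to a fully replicated `NamedSharding`. This is a sketch, not the module's implementation. The mesh axis name `"fsdp"` and the helper `gather_fsdp_weights` are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def gather_fsdp_weights(weights, mesh):
    """Reshard every leaf of `weights` so it is fully replicated on `mesh`.

    Moving from a spec sharded over the FSDP axis to the empty spec P()
    makes JAX perform an all-gather over that axis.
    """
    replicated = NamedSharding(mesh, P())
    return jax.tree_util.tree_map(lambda w: jax.device_put(w, replicated), weights)


# Assumed 1-D mesh over all local devices, with its single axis named "fsdp".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("fsdp",))

# Store a weight sharded along its first dimension, as FSDP would.
w = jnp.ones((8 * len(devices), 4), dtype=jnp.float32)
sharded = jax.device_put(w, NamedSharding(mesh, P("fsdp", None)))
gathered = gather_fsdp_weights({"wq_a": sharded}, mesh)
```

Inside a jitted function, `jax.lax.with_sharding_constraint(w, NamedSharding(mesh, P()))` requests the same replicated layout.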
- maxtext.models.deepseek_batchsplit_fp8.scan_batch_split_layers(inputs, params, positions, segment_ids, *, model_mode, mesh, quant, cfg, policy)
Scans over the decoder layers using the batch-split schedule.
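Scanning over layers usually means stacking each layer's parameters along a new leading axis and letting `jax.lax.scan` carry the activations from one layer to the next. The sketch below shows that pattern with a toy residual layer and `jax.checkpoint` for rematerialization. It assumes that `policy` is a JAX checkpoint policy; `toy_layer` and `scan_layers` are hypothetical stand-ins, not the DeepSeek layer.

```python
import jax
import jax.numpy as jnp


def toy_layer(x, p):
    """Hypothetical stand-in for one decoder layer: a residual tanh block."""
    return x + jnp.tanh(x @ p["w"])


def scan_layers(inputs, stacked_params, policy=None):
    """Apply toy_layer once per slice of the leading (layer) axis of params."""
    # Rematerialize each layer in the backward pass according to `policy`.
    body = jax.checkpoint(toy_layer, policy=policy)

    def step(carry, layer_params):
        return body(carry, layer_params), None  # no per-layer outputs kept

    out, _ = jax.lax.scan(step, inputs, stacked_params)
    return out


num_layers, hidden = 3, 4
params = {"w": 0.1 * jax.random.normal(jax.random.PRNGKey(0), (num_layers, hidden, hidden))}
x = jnp.ones((2, hidden))
y = scan_layers(x, params, policy=jax.checkpoint_policies.nothing_saveable)

# Reference: the same layers applied in an ordinary Python loop.
ref = x
for i in range(num_layers):
    ref = toy_layer(ref, {"w": params["w"][i]})
```

Because the scan body is traced once, compile time does not grow with the number of layers, whereas an unrolled Python loop is traced once per layer.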
- maxtext.models.deepseek_batchsplit_fp8.batch_split_schedule(inputs, weights, positions, segment_ids, *, model_mode, mesh, quant, cfg)
Applies the DeepSeek MoE layer with the batch-split schedule.
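The motivation for a batch-split schedule is to interleave two microbatches so that one microbatch's compute can overlap with the other's communication. The sketch below shows only the data flow, using hypothetical per-example stages `attention_stage` and `moe_stage`. Issuing operations in this order in Python does not by itself force XLA to overlap them. For layers that treat each example independently, such as attention within an example and a dropless MoE, the staggered result matches the full-batch result.

```python
import jax.numpy as jnp


def attention_stage(x):
    """Hypothetical per-example stand-in for the MLA block."""
    return x + 0.1 * jnp.cumsum(x, axis=1)  # mixes tokens within each example


def moe_stage(x):
    """Hypothetical per-token stand-in for the dropless MoE block."""
    return x + 0.5 * jnp.maximum(x, 0.0)


def batch_split_layer(x, split_factor=2):
    mbs = jnp.split(x, split_factor, axis=0)
    attn = [None] * split_factor
    outs = [None] * split_factor
    attn[0] = attention_stage(mbs[0])
    for i in range(split_factor):
        # Staggered issue order: attention for microbatch i + 1 sits next to
        # the MoE stage of microbatch i. The two are independent, so a
        # scheduler is free to overlap them.
        if i + 1 < split_factor:
            attn[i + 1] = attention_stage(mbs[i + 1])
        outs[i] = moe_stage(attn[i])
    return jnp.concatenate(outs, axis=0)


x = jnp.arange(4 * 3 * 2, dtype=jnp.float32).reshape(4, 3, 2)
full_batch = moe_stage(attention_stage(x))
batch_split = batch_split_layer(x)
```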
- maxtext.models.deepseek_batchsplit_fp8.dot(x, y, quant=None, axes=1)
Computes the dot product of two arrays, optionally using quantization.
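Below is a sketch of a dot product with optional simulated FP8 quantization. Two assumptions apply. First, `axes` is assumed to follow `jnp.tensordot` semantics. Second, the real `quant` object is replaced by a hypothetical boolean `use_fp8` that rounds each input through `float8_e4m3fn` after per-tensor scaling. This is fake quantization for illustration only. It is not MaxText's quantization API and not a true FP8 matmul.

```python
import jax.numpy as jnp

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def fake_quant_fp8(x):
    """Scale x per tensor into the e4m3 range, round through fp8, undo the scale."""
    scale = FP8_E4M3_MAX / jnp.maximum(jnp.max(jnp.abs(x)), 1e-12)
    q = (x * scale).astype(jnp.float8_e4m3fn)
    return q.astype(jnp.float32) / scale


def dot_sketch(x, y, use_fp8=False, axes=1):
    """tensordot-style contraction, optionally on fp8-rounded inputs."""
    if use_fp8:
        x, y = fake_quant_fp8(x), fake_quant_fp8(y)
    return jnp.tensordot(x, y, axes=axes)


x = jnp.full((2, 4), 0.5, dtype=jnp.float32)
y = jnp.full((4, 3), 2.0, dtype=jnp.float32)
exact = dot_sketch(x, y)
quantized = dot_sketch(x, y, use_fp8=True)
```

These inputs are exactly representable after scaling, so both paths return 4.0 in every entry. With general inputs, each value keeps only 3 mantissa bits, so the quantized result will differ slightly from the exact one.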
- maxtext.models.deepseek_batchsplit_fp8.mla_with_norms(inputs, weights, decoder_positions, decoder_segment_ids, *, mesh, model_mode, attn_op, normalization_layer_epsilon, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, rope_max_timescale, num_query_heads, max_position_embeddings, original_max_position_embeddings, beta_fast, beta_slow, rope_factor, mscale, dtype, quant)
Performs multi-head latent attention (MLA) with pre- and post-normalization.
- maxtext.models.deepseek_batchsplit_fp8.mla(inputs, positions, segment_ids, weights, *, model_mode, epsilon, kv_lora_rank, kv_norm_epsilon, qk_nope_head_dim, qk_rope_head_dim, num_query_heads, rope_theta, max_position_embeddings, original_max_position_embeddings, beta_fast, beta_slow, rope_factor, mscale, attention_op_fn, dtype, quant)
Performs MLA.
- maxtext.models.deepseek_batchsplit_fp8.query_projection(inputs_q, inputs_positions, wq_a_weights, wq_b_weights, q_norm_scale_weights, *, epsilon, qk_nope_head_dim, qk_rope_head_dim, rope_theta, max_position_embeddings, original_max_position_embeddings, beta_fast, beta_slow, rope_factor, dtype, mscale, quant)
Performs query projection.
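The weight names `wq_a`, `wq_b`, and `q_norm_scale` suggest a DeepSeek-style low-rank query path with three steps:

1. Down-project to a latent.
2. RMS-normalize the latent.
3. Up-project to per-head queries whose last axis is split into a no-RoPE part and a RoPE part.

The sketch below follows that reading. The `[q_lora_rank, num_heads, nope + rope]` layout of `wq_b` is an assumption, and YaRN rotation and `mscale` are omitted.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, scale, eps=1e-6):
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(var + eps) * scale


def query_projection_sketch(x, wq_a, wq_b, q_norm_scale, *, qk_nope_head_dim, eps=1e-6):
    """Low-rank query projection, split into no-RoPE and RoPE parts."""
    q_latent = rms_norm(x @ wq_a, q_norm_scale, eps)   # [B, S, q_lora_rank]
    q = jnp.einsum("bsr,rhd->bshd", q_latent, wq_b)     # [B, S, H, nope + rope]
    q_nope, q_rope = jnp.split(q, [qk_nope_head_dim], axis=-1)
    # The real layer would now rotate q_rope with YaRN; omitted here.
    return q_nope, q_rope


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
B, S, D, R, H, NOPE, ROPE = 2, 5, 16, 8, 4, 6, 2
x = jax.random.normal(k1, (B, S, D))
wq_a = jax.random.normal(k2, (D, R))
wq_b = jax.random.normal(k3, (R, H, NOPE + ROPE))
q_nope, q_rope = query_projection_sketch(x, wq_a, wq_b, jnp.ones((R,)), qk_nope_head_dim=NOPE)
```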
- maxtext.models.deepseek_batchsplit_fp8.kv_projection(inputs, inputs_positions, wkv_a_weights, wkv_b_weights, kv_norm_scale_weights, *, kv_lora_rank, kv_norm_epsilon, qk_rope_head_dim, rope_theta, max_position_embeddings, original_max_position_embeddings, beta_fast, beta_slow, rope_factor, dtype, qk_nope_head_dim, num_query_heads, quant)
Performs KV projection.
- maxtext.models.deepseek_batchsplit_fp8.get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads, quant)
Computes the key and value from the compressed KV latent vector and the RoPE part of the key.
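The KV path can be sketched in the same spirit as the query path:

1. A single down-projection (`wkv_a`) produces the compressed latent plus one RoPE key shared across heads.
2. The latent is normalized.
3. `wkv_b` expands the latent into per-head no-RoPE keys and values.
4. The shared RoPE key is broadcast to every head and appended to the no-RoPE keys.

The weight layouts and the `*_sketch` helpers are assumptions, and YaRN is omitted.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, scale, eps=1e-6):
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(var + eps) * scale


def get_key_value_sketch(low_rank, key_rope, wkv_b, *, qk_nope_head_dim):
    """Up-project the KV latent and attach the shared RoPE key to every head."""
    kv = jnp.einsum("bsr,rhd->bshd", low_rank, wkv_b)   # [B, S, H, nope + v]
    k_nope, value = jnp.split(kv, [qk_nope_head_dim], axis=-1)
    # key_rope has no head axis: one RoPE key is broadcast to all heads.
    k_rope = jnp.broadcast_to(key_rope[:, :, None, :], k_nope.shape[:3] + key_rope.shape[-1:])
    return jnp.concatenate([k_nope, k_rope], axis=-1), value


def kv_projection_sketch(x, wkv_a, wkv_b, kv_norm_scale, *, kv_lora_rank, qk_nope_head_dim):
    """Down-project to [latent | rope key], normalize the latent, then expand."""
    kv = x @ wkv_a                                       # [B, S, kv_lora_rank + rope]
    low_rank, key_rope = jnp.split(kv, [kv_lora_rank], axis=-1)
    low_rank = rms_norm(low_rank, kv_norm_scale)
    # The real layer would rotate key_rope with YaRN here; omitted.
    return get_key_value_sketch(low_rank, key_rope, wkv_b, qk_nope_head_dim=qk_nope_head_dim)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
B, S, D, R, ROPE, H, NOPE, V = 2, 5, 16, 8, 2, 4, 6, 6
x = jax.random.normal(k1, (B, S, D))
wkv_a = jax.random.normal(k2, (D, R + ROPE))
wkv_b = jax.random.normal(k3, (R, H, NOPE + V))
key, value = kv_projection_sketch(x, wkv_a, wkv_b, jnp.ones((R,)), kv_lora_rank=R, qk_nope_head_dim=NOPE)
```

Only the latent and the shared RoPE key need to be kept per token. Full per-head keys and values can be rebuilt from them, which is what keeps an MLA-style KV cache small.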
- maxtext.models.deepseek_batchsplit_fp8.rms_norm(x, scale, *, epsilon, dtype)
Applies RMS normalization.
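RMS normalization divides each vector by the root mean square of its last axis and then applies a learned per-feature scale: y = x / sqrt(mean(x**2) + epsilon) * scale. The sketch below computes the statistics in float32 and casts the result to `dtype`. That choice of accumulation precision is an assumption, not something the signature confirms.

```python
import jax.numpy as jnp


def rms_norm_sketch(x, scale, *, epsilon, dtype):
    """y = x / sqrt(mean(x**2, last axis) + epsilon) * scale, cast to dtype."""
    x32 = x.astype(jnp.float32)  # compute statistics in float32
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    y = x32 / jnp.sqrt(mean_sq + epsilon)
    return (y * scale.astype(jnp.float32)).astype(dtype)


x = jnp.array([[3.0, 4.0], [1.0, -1.0]], dtype=jnp.bfloat16)
y = rms_norm_sketch(x, jnp.ones((2,)), epsilon=1e-6, dtype=jnp.bfloat16)
```

With a unit scale, every output row has a mean square of about 1.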
- maxtext.models.deepseek_batchsplit_fp8.yarn(inputs, positions, *, embedding_dims, rope_theta, max_position_embeddings, original_max_position_embeddings, beta_fast, beta_slow, rope_factor, fprop_dtype)
Performs YaRN rotary embedding.
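YaRN extends a rotary embedding beyond its original context length by adjusting frequencies selectively:

- High-frequency dimensions are left unchanged.
- Low-frequency dimensions are divided by `rope_factor`.
- The dimensions in between are blended with a linear ramp. The ramp's endpoints come from `beta_fast` and `beta_slow`.

The sketch below follows the publicly described YaRN frequency formula. It assumes a half-split pairing of the rotated dimensions, which may differ from this module's interleaving. It also omits the `mscale` attention-temperature correction. All helper names are hypothetical.

```python
import math

import jax.numpy as jnp


def _correction_dim(num_rotations, dim, base, max_pos):
    """Rotary dim index whose wavelength fits `num_rotations` turns in max_pos."""
    return dim * math.log(max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


def yarn_inv_freqs(dim, base, original_max_pos, beta_fast, beta_slow, factor):
    """Blend original and interpolated (divided by factor) inverse frequencies."""
    inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    low = max(math.floor(_correction_dim(beta_fast, dim, base, original_max_pos)), 0)
    high = min(math.ceil(_correction_dim(beta_slow, dim, base, original_max_pos)), dim - 1)
    if low == high:
        high += 0.001  # avoid a zero-width ramp
    # ramp = 0 below `low` (high frequency, kept), 1 above `high` (interpolated).
    ramp = jnp.clip((jnp.arange(dim // 2, dtype=jnp.float32) - low) / (high - low), 0.0, 1.0)
    return inv_freq * (1.0 - ramp) + (inv_freq / factor) * ramp


def apply_rotary(x, positions, inv_freq):
    """Rotate (first half, second half) pairs of x's last axis by pos * inv_freq."""
    angles = positions[:, None].astype(jnp.float32) * inv_freq  # [S, dim // 2]
    sin, cos = jnp.sin(angles), jnp.cos(angles)
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)


inv_freq = yarn_inv_freqs(64, 10000.0, 4096, beta_fast=32, beta_slow=1, factor=40.0)
positions = jnp.arange(8)
x = jnp.ones((8, 64))
y = apply_rotary(x, positions, inv_freq)
```

With these example values, the first frequency stays at 1.0 and the last is divided by 40. Because the rotation is orthogonal, each vector keeps its norm.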
- maxtext.models.deepseek_batchsplit_fp8.moe(inputs, weights, *, mesh, num_experts, num_experts_per_tok, routed_scaling_factor, expert_axis_name, use_gather_mosaic_kernel, config, quant)
Performs dropless MoE with tensor/expert parallelism.
- maxtext.models.deepseek_batchsplit_fp8.expert_indices_and_weights(gate_logits, pre_bias_logits, num_experts_per_tok, routed_scaling_factor)
Computes expert indices for each token and their corresponding weights.
- Parameters:
gate_logits (Array)
pre_bias_logits (Array)
num_experts_per_tok (int)
routed_scaling_factor (float)
- Return type:
tuple[Array, Array]
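The two logit arguments suggest DeepSeek-style bias-adjusted routing. Experts are chosen by top-k over `gate_logits`, which are assumed to already include a routing bias. Their weights come from `pre_bias_logits`, which are assumed to be nonnegative scores such as sigmoid outputs. Those weights are normalized per token and multiplied by `routed_scaling_factor`. This is a sketch under those assumptions, not the module's code.

```python
import jax
import jax.numpy as jnp


def expert_indices_and_weights_sketch(gate_logits, pre_bias_logits, num_experts_per_tok, routed_scaling_factor):
    # Choose the top-k experts using the bias-adjusted scores...
    _, indices = jax.lax.top_k(gate_logits, num_experts_per_tok)       # [T, k]
    # ...but weight them with the scores from before the bias was added.
    weights = jnp.take_along_axis(pre_bias_logits, indices, axis=-1)   # [T, k]
    weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
    return indices, weights * routed_scaling_factor


scores = jax.nn.sigmoid(jax.random.normal(jax.random.PRNGKey(0), (3, 8)))  # pre-bias
bias = jnp.zeros((8,)).at[7].set(10.0)  # a large bias forces expert 7 into every top-k
indices, weights = expert_indices_and_weights_sketch(scores + bias, scores, 2, 2.5)
```

The bias changes which experts are selected but does not inflate their weights. Each token's weights sum to `routed_scaling_factor`.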
- maxtext.models.deepseek_batchsplit_fp8.expert_selection(x, routing_kernel, routing_bias, *, num_experts, num_experts_per_tok, routed_scaling_factor, quant)
Selects experts for each token and calculates group sizes for each expert.
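Group sizes record how many (token, slot) pairs each expert receives. Once the routed copies are sorted by expert id, these counts mark where each expert's contiguous block of rows begins and ends. The sketch assumes sigmoid gating on `x @ routing_kernel` with an additive bias used only for selection. Weight computation, as in expert_indices_and_weights, is omitted. `select_and_count` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def select_and_count(x, routing_kernel, routing_bias, *, num_experts, num_experts_per_tok):
    """Top-k expert choice per token plus the number of rows for each expert."""
    scores = jax.nn.sigmoid(x @ routing_kernel)                 # [T, E]
    _, selected = jax.lax.top_k(scores + routing_bias, num_experts_per_tok)
    # group_sizes[e] = number of (token, slot) pairs assigned to expert e.
    group_sizes = jnp.bincount(selected.reshape(-1), length=num_experts)
    return selected, group_sizes


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (6, 4))
kernel = jax.random.normal(k2, (4, 8))
selected, group_sizes = select_and_count(x, kernel, jnp.zeros((8,)), num_experts=8, num_experts_per_tok=2)
```

The group sizes always sum to `num_tokens * num_experts_per_tok`. Passing `length` keeps the output shape static, which `jit` requires.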
- maxtext.models.deepseek_batchsplit_fp8.route(x, selected_experts, weights, group_sizes, *, expert_axis_name, use_gather_mosaic_kernel)
All-gathers tokens, then performs local routing.
- maxtext.models.deepseek_batchsplit_fp8.unroute(x, selected_experts, *, expert_axis_name, use_gather_mosaic_kernel)
Undoes route().
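The local part of routing can be sketched on a single device with no all-gather over `expert_axis_name`. First, each token is copied once per selected expert and the copies are sorted by expert id. Unrouting then inverts that permutation and sums each token's k outputs, weighted by the routing weights. `route_local` and `unroute_local` are hypothetical helpers, and applying the routing weights during unrouting is an assumption.

```python
import jax
import jax.numpy as jnp


def route_local(x, selected_experts):
    """Copy each token k times and sort the copies by expert id."""
    num_tokens, k = selected_experts.shape
    flat_experts = selected_experts.reshape(-1)   # [T * k]
    sort_idx = jnp.argsort(flat_experts)          # stable by default in jnp
    token_idx = sort_idx // k                     # source token of each copy
    return x[token_idx], sort_idx


def unroute_local(y_sorted, sort_idx, weights):
    """Invert the sort and sum each token's k weighted expert outputs."""
    num_tokens, k = weights.shape
    unsorted = jnp.zeros_like(y_sorted).at[sort_idx].set(y_sorted)  # [T * k, D]
    unsorted = unsorted.reshape(num_tokens, k, -1)
    return jnp.einsum("tkd,tk->td", unsorted, weights)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
selected = jnp.array([[0, 2], [1, 0], [2, 1], [0, 1]])
weights = jnp.full((4, 2), 0.5)
routed, sort_idx = route_local(x, selected)
# With an identity "expert", unrouting recovers the input.
out = unroute_local(routed, sort_idx, weights)
```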
- maxtext.models.deepseek_batchsplit_fp8.compute(x, w0, w1, wo, group_sizes, weights, *, config, mesh)
Processes routed tokens through the MLP.
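The following is a reference-style sketch of an expert MLP applied to rows that are already grouped by expert. It assumes a SwiGLU MLP in which the activation is applied to the `w0` branch, `w1` is the linear up projection, and `wo` projects back down. It also assumes the layouts `[E, D, F]` and `[E, F, D]`. Gathering per-row weights is only for clarity; a grouped-matmul kernel computes the same products without materializing these copies. `grouped_swiglu` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def grouped_swiglu(x_sorted, w0, w1, wo, group_sizes, weights=None):
    """SwiGLU expert MLP over rows already sorted by expert id.

    x_sorted: [N, D], rows for expert 0 first, then expert 1, and so on
    w0, w1:   [E, D, F] gate and up projections (assumed layout)
    wo:       [E, F, D] down projection
    group_sizes: [E] number of rows per expert, summing to N
    weights:  optional [N] routing weight per row
    """
    num_experts = w0.shape[0]
    # Expand group sizes to one expert id per row, e.g. [2, 0, 3] -> [0, 0, 2, 2, 2].
    expert_ids = jnp.repeat(jnp.arange(num_experts), group_sizes, total_repeat_length=x_sorted.shape[0])
    gate = jnp.einsum("nd,ndf->nf", x_sorted, w0[expert_ids])
    up = jnp.einsum("nd,ndf->nf", x_sorted, w1[expert_ids])
    out = jnp.einsum("nf,nfd->nd", jax.nn.silu(gate) * up, wo[expert_ids])
    return out if weights is None else out * weights[:, None]


E, D, F = 3, 4, 5
group_sizes = jnp.array([2, 0, 3])
k0, k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 4)
xs = jax.random.normal(k0, (5, D))
w0 = 0.5 * jax.random.normal(k1, (E, D, F))
w1 = 0.5 * jax.random.normal(k2, (E, D, F))
wo = 0.5 * jax.random.normal(k3, (E, F, D))
out = grouped_swiglu(xs, w0, w1, wo, group_sizes)
```

Because the row-to-expert mapping is derived from `group_sizes` alone, an expert with zero rows, such as expert 1 here, simply contributes nothing.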
- maxtext.models.deepseek_batchsplit_fp8.route_compute_unroute(xs, weights, *, num_experts, num_experts_per_tok, routed_scaling_factor, expert_axis_name, use_gather_mosaic_kernel, config, mesh, quant)
Routes, processes, and unroutes activations.
- maxtext.models.deepseek_batchsplit_fp8.process_activations(xs, weights, *, mesh, num_experts, num_experts_per_tok, routed_scaling_factor, expert_axis_name, use_gather_mosaic_kernel, config, quant)
Processes activations that are fully sharded along the batch axis, using tensor/expert-sharded weights.