maxtext.models.deepseek_batchsplit module
Alternative DeepSeek model definition with batch-split schedule.
The model logic and optimizations are very explicit in this implementation. Weights are explicitly pre-fetched and gathered in the forward pass, and gradients are explicitly reduced and post-scattered in the backward pass. Optimization barriers enforce the ordering of both large blocks of operations (e.g., attention and dispatch) and individual operations (e.g., the all-gather (AG) and gather within dispatch). To control rematerialization (remat), residuals from the forward pass are explicitly stored and passed to the backward pass in a custom VJP over the entire layer scan. The backward pass consists of a remat/bwd function for each forward-pass function, with the relevant residuals passed between them.
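As a rough illustration of forward functions that return explicit residuals paired with backward functions that consume them, the NumPy sketch below hand-writes the VJP of a toy layer `y = tanh(x @ w)`. The names `layer_fwd` and `layer_bwd` are hypothetical and do not belong to this module, which applies the same idea in JAX through a custom VJP over the whole layer scan.

```python
import numpy as np

def layer_fwd(x, w):
    """Hypothetical toy forward pass: returns the output plus explicit residuals."""
    h = x @ w
    y = np.tanh(h)
    residuals = (x, y)  # keep only what the backward pass needs; h is dropped
    return y, residuals

def layer_bwd(residuals, y_grad, w):
    """Hypothetical toy backward pass: consumes the stored residuals."""
    x, y = residuals
    h_grad = y_grad * (1.0 - y ** 2)  # d tanh(h)/dh = 1 - tanh(h)^2, computed from y
    x_grad = h_grad @ w.T
    w_grad = x.T @ h_grad
    return x_grad, w_grad

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
w = rng.normal(size=(3, 2))
y, res = layer_fwd(x, w)
# Gradient of loss = sum(y), so the incoming cotangent is all ones.
x_grad, w_grad = layer_bwd(res, np.ones_like(y), w)
```

Storing `y` instead of the pre-activation `x @ w` is one possible remat choice here: the tanh derivative can be recovered from its output, so the pre-activation never has to be kept or recomputed.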
- maxtext.models.deepseek_batchsplit.scheduling_group(group_id)
- Return type:
AbstractContextManager[None]
- maxtext.models.deepseek_batchsplit.fetch_weights(params, dtype)
Fetches weights from params in the format expected by the batch-split schedule.
- maxtext.models.deepseek_batchsplit.split(x, split_factor=2)
Splits the input into split_factor parts along the batch dimension.
- maxtext.models.deepseek_batchsplit.merge(x, split_factor=2)
Merges the input microbatches back into a single tensor.
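The split/merge contract can be sketched in NumPy as follows, assuming the batch is the leading axis and is divisible by `split_factor`. `split_batch` and `merge_batch` are hypothetical stand-ins; whether the module returns a list or a stacked array is not stated here.

```python
import numpy as np

def split_batch(x, split_factor=2):
    """Hypothetical: splits x into split_factor microbatches along axis 0."""
    if x.shape[0] % split_factor != 0:
        raise ValueError("batch size must be divisible by split_factor")
    return np.split(x, split_factor, axis=0)

def merge_batch(xs):
    """Hypothetical: concatenates microbatches back along axis 0."""
    return np.concatenate(xs, axis=0)

x = np.arange(24).reshape(4, 3, 2)  # (batch, seq, hidden)
parts = split_batch(x)
print([p.shape for p in parts])  # [(2, 3, 2), (2, 3, 2)]
```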
- maxtext.models.deepseek_batchsplit.extract_layer_weights(all_weights, layer_idx, layer_axis)
Extracts the weights for the given layer.
- maxtext.models.deepseek_batchsplit.insert_layer_ws_grad(all_ws_grad, ws_grad, layer_idx, layer_axis)
Inserts the weight gradients for the given layer.
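A minimal sketch of per-layer extraction and gradient insertion, assuming the weights are a dict of arrays stacked along `layer_axis` (the real parameter tree may be structured differently). The helper names are hypothetical. Inside a JAX scan, where arrays are immutable and the index is traced, `jax.lax.dynamic_index_in_dim` and `jax.lax.dynamic_update_index_in_dim` would be natural counterparts of the NumPy indexing used here.

```python
import numpy as np

def extract_layer(all_weights, layer_idx, layer_axis):
    """Hypothetical: returns one layer's slice of every stacked weight array."""
    return {name: np.take(w, layer_idx, axis=layer_axis) for name, w in all_weights.items()}

def insert_layer_grad(all_grads, grads, layer_idx, layer_axis):
    """Hypothetical: returns copies of the stacked gradients with one layer written in."""
    out = {}
    for name, g_all in all_grads.items():
        g_all = g_all.copy()
        index = [slice(None)] * g_all.ndim   # select everything on the other axes...
        index[layer_axis] = layer_idx        # ...and a single layer on layer_axis
        g_all[tuple(index)] = grads[name]
        out[name] = g_all
    return out

num_layers = 3
all_w = {"wq": np.arange(num_layers * 4 * 2, dtype=np.float32).reshape(num_layers, 4, 2)}
layer_w = extract_layer(all_w, 1, layer_axis=0)
zero_grads = {"wq": np.zeros_like(all_w["wq"])}
new_grads = insert_layer_grad(zero_grads, {"wq": np.ones((4, 2))}, 1, layer_axis=0)
```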
- maxtext.models.deepseek_batchsplit.gather_weights(weights, mesh)
All-gathers FSDP-sharded weights.
- maxtext.models.deepseek_batchsplit.reduce_scatter_ws_grad(ws_grad, mesh)
Reduce-scatters weight gradients to FSDP sharding.
- maxtext.models.deepseek_batchsplit.all_reduce_ws_grad_dcn(ws_grad, mesh)
All-reduces weight gradients across DCN axes.
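The semantics of these three collectives can be simulated in a single process, with a Python list standing in for the devices along a mesh axis. This only sketches what each collective computes, not how the module invokes them over `mesh`; inside a JAX `shard_map`, the corresponding primitives would be `jax.lax.all_gather`, `jax.lax.psum_scatter`, and `jax.lax.psum`.

```python
import numpy as np

def all_gather(shards, axis=0):
    """Each device ends up with the concatenation of every device's shard."""
    full = np.concatenate(shards, axis=axis)
    return [full.copy() for _ in shards]

def reduce_scatter(per_device_grads, axis=0):
    """Sums full-size gradients across devices; each device keeps one shard of the sum."""
    total = np.sum(np.stack(per_device_grads), axis=0)
    return np.split(total, len(per_device_grads), axis=axis)

def all_reduce(per_replica_grads):
    """Each replica (e.g. each slice across DCN) ends up with the sum over replicas."""
    total = np.sum(np.stack(per_replica_grads), axis=0)
    return [total.copy() for _ in per_replica_grads]

num_devices = 4
full_w = np.arange(8.0).reshape(8, 1)
shards = np.split(full_w, num_devices)   # FSDP: each device stores 2 rows
gathered = all_gather(shards)            # forward: every device sees the full weight
grads = [np.full_like(full_w, d + 1.0) for d in range(num_devices)]  # per-device grads
grad_shards = reduce_scatter(grads)      # backward: back to FSDP sharding
```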
- maxtext.models.deepseek_batchsplit.init_splash_kernel(config)
Initializes the Splash kernel.
- maxtext.models.deepseek_batchsplit.tpu_flash_attention(query, key, value, mesh, splash_kernel, activation_pspec)
TPU Flash Attention.
- maxtext.models.deepseek_batchsplit.tpu_flash_attention_bwd(attention_out_grad, query, key, value, attention_output, logsumexp, mesh, splash_kernel, activation_pspec)
TPU Flash Attention backward.
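Because `tpu_flash_attention_bwd` receives both `attention_output` and `logsumexp`, the forward pass presumably saves the per-query log-sum-exp of the attention scores. The dense NumPy reference below (single head, causal mask assumed) shows why that residual is enough: the backward pass can rebuild the softmax probabilities from it instead of storing the full score matrix. It is not the Splash kernel, and the kernel's exact residual layout is an assumption.

```python
import numpy as np

def reference_attention(q, k, v, causal=True):
    """Dense single-head attention that also returns the per-query logsumexp."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ k.T) * scale                          # (q_len, kv_len)
    if causal:
        mask = np.tril(np.ones(scores.shape, dtype=bool))
        scores = np.where(mask, scores, -np.inf)
    row_max = scores.max(axis=-1, keepdims=True)
    exp = np.exp(scores - row_max)                      # numerically stable softmax
    denom = exp.sum(axis=-1, keepdims=True)
    out = (exp / denom) @ v
    logsumexp = (row_max + np.log(denom))[:, 0]         # residual for the backward pass
    return out, logsumexp

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(5, 8)) for _ in range(3))
out, lse = reference_attention(q, k, v)

# Backward-pass view: rebuild the probabilities from the saved logsumexp alone.
scores = np.where(np.tril(np.ones((5, 5), dtype=bool)), (q @ k.T) / np.sqrt(8), -np.inf)
probs = np.exp(scores - lse[:, None])
```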
- maxtext.models.deepseek_batchsplit.scan_batch_split_layers(inputs, params, positions, *, mesh, cfg, num_layers)
Scans the layers with the batch-split schedule.
- maxtext.models.deepseek_batchsplit.batch_split_schedule(inputs, weights, positions, *, mesh, cfg, splash_kernel, activation_pspec, pairwise_swap_and_negate_mask)
Applies the DeepSeek MoE layer with the batch-split schedule.
- maxtext.models.deepseek_batchsplit.batch_split_schedule_bwd(residuals, outputs_grad, weights, positions, *, mesh, cfg, splash_kernel, activation_pspec, pairwise_swap_and_negate_mask)
Performs the backward pass for a single layer.
- maxtext.models.deepseek_batchsplit.staggered_call(fn, xs)
Calls a function in a staggered manner while accumulating residuals.
- maxtext.models.deepseek_batchsplit.mla_with_norms(inputs, weights, yarn_freqs, *, mesh, config, splash_kernel, normalization_layer_epsilon, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, num_query_heads, max_position_embeddings, original_max_position_embeddings, rope_factor, mscale, pairwise_swap_and_negate_mask, dtype, activation_pspec)
Performs multi-head latent attention (MLA) with pre-normalization.
- maxtext.models.deepseek_batchsplit.mla_with_norms_remat(residuals, weights, yarn_freqs, *, mesh, config, splash_kernel, normalization_layer_epsilon, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, num_query_heads, max_position_embeddings, original_max_position_embeddings, rope_factor, mscale, pairwise_swap_and_negate_mask, dtype, activation_pspec)
Performs remat for the mla_with_norms function.
- maxtext.models.deepseek_batchsplit.mla_with_norms_bwd(outputs_grad, bwds)
Performs the backward pass for the mla_with_norms function.
- maxtext.models.deepseek_batchsplit.mla(inputs, yarn_freqs, weights, *, epsilon, kv_lora_rank, kv_norm_epsilon, qk_nope_head_dim, qk_rope_head_dim, num_query_heads, max_position_embeddings, original_max_position_embeddings, rope_factor, mscale, config, splash_kernel, pairwise_swap_and_negate_mask, dtype, mesh, activation_pspec)
Performs MLA.
- maxtext.models.deepseek_batchsplit.mla_remat(residuals, yarn_freqs, weights, *, epsilon, kv_lora_rank, kv_norm_epsilon, qk_nope_head_dim, qk_rope_head_dim, num_query_heads, max_position_embeddings, original_max_position_embeddings, rope_factor, mscale, config, splash_kernel, pairwise_swap_and_negate_mask, dtype, mesh, activation_pspec)
Performs remat for the mla function.
- maxtext.models.deepseek_batchsplit.mla_bwd(out_grad, bwds)
Performs the backward pass for the mla function.
- maxtext.models.deepseek_batchsplit.query_projection(inputs_q, yarn_freqs, wq_a_weights, wq_b_weights, q_norm_scale_weights, *, epsilon, qk_nope_head_dim, qk_rope_head_dim, max_position_embeddings, original_max_position_embeddings, rope_factor, pairwise_swap_and_negate_mask, dtype, mscale, config, mesh, activation_pspec)
Performs query projection.
- maxtext.models.deepseek_batchsplit.kv_projection(inputs, yarn_freqs, wkv_a_weights, wkv_b_weights, kv_norm_scale_weights, *, kv_lora_rank, kv_norm_epsilon, pairwise_swap_and_negate_mask, dtype, qk_nope_head_dim, num_query_heads, config, mesh, activation_pspec)
Performs KV projection.
- maxtext.models.deepseek_batchsplit.get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads)
Gets the key and value from the compressed KV latent vector and the key's RoPE component.
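A shape-level sketch of `get_key_value`, under assumptions not stated in this reference: `low_rank_main` is the (already normalized) latent of shape `(batch, seq, kv_lora_rank)`, `wkv_b_weights` has shape `(kv_lora_rank, num_query_heads, qk_nope_head_dim + v_head_dim)`, and `key_rope` is a single rotary key of shape `(batch, seq, qk_rope_head_dim)` shared by all heads. The function name `get_key_value_sketch` is hypothetical.

```python
import numpy as np

def get_key_value_sketch(low_rank_main, key_rope, wkv_b, *, qk_nope_head_dim, num_query_heads):
    """Hypothetical: up-projects the KV latent and appends the shared rotary key part."""
    # (batch, seq, rank) x (rank, heads, nope + v) -> (batch, seq, heads, nope + v)
    kv = np.einsum("bsr,rhd->bshd", low_rank_main, wkv_b)
    if kv.shape[2] != num_query_heads:
        raise ValueError("wkv_b must have one slice per query head")
    key_nope, value = kv[..., :qk_nope_head_dim], kv[..., qk_nope_head_dim:]
    # One rotary key per position is shared by every head: broadcast over the head axis.
    rope_shape = key_nope.shape[:3] + key_rope.shape[-1:]
    rope_per_head = np.broadcast_to(key_rope[:, :, None, :], rope_shape)
    key = np.concatenate([key_nope, rope_per_head], axis=-1)
    return key, value

b, s, rank, heads, nope, rope, v_dim = 2, 3, 4, 2, 5, 3, 6
rng = np.random.default_rng(0)
latent = rng.normal(size=(b, s, rank))
k_rope = rng.normal(size=(b, s, rope))
wkv_b = rng.normal(size=(rank, heads, nope + v_dim))
key, value = get_key_value_sketch(latent, k_rope, wkv_b,
                                  qk_nope_head_dim=nope, num_query_heads=heads)
```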
- maxtext.models.deepseek_batchsplit.rms_norm(x, scale, *, epsilon, dtype, out_sharding=None)
RMS normalization.
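For reference, RMS normalization computes `x / sqrt(mean(x**2) + epsilon) * scale` over the last axis. The sketch below assumes the statistic is computed in float32 before casting to `dtype`, a common choice that may or may not match this module; `out_sharding` has no NumPy equivalent and is omitted, and `rms_norm_sketch` is a hypothetical name.

```python
import numpy as np

def rms_norm_sketch(x, scale, *, epsilon, dtype=np.float32):
    """Normalizes by the root mean square of the last axis, then applies a learned scale."""
    x32 = x.astype(np.float32)                       # statistic in float32 (assumed)
    mean_sq = np.mean(x32 * x32, axis=-1, keepdims=True)
    normed = x32 / np.sqrt(mean_sq + epsilon)
    return (normed * scale).astype(dtype)

x = np.array([[3.0, 4.0], [1.0, -1.0]])
y = rms_norm_sketch(x, np.ones(2), epsilon=1e-6)
```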
- maxtext.models.deepseek_batchsplit.initialize_yarn_mask(embedding_dims)
Initializes YaRN mask.
- maxtext.models.deepseek_batchsplit.initialize_yarn_freqs(positions, embedding_dims, rope_theta, max_position_embeddings, original_max_position_embeddings, beta_fast, beta_slow, rope_factor, mesh, activation_pspec)
Initializes YaRN frequencies.
- maxtext.models.deepseek_batchsplit.yarn(inputs, freqs, *, pairwise_swap_and_negate_mask, fprop_dtype)
Performs YaRN rotary embedding.
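The `pairwise_swap_and_negate_mask` argument suggests rotary embedding applied to interleaved pairs `(x[2i], x[2i+1])` via `x * cos + swap(x) * mask * sin`. The sketch below shows that rotation with plain RoPE angles; the YaRN frequency correction (`beta_fast`, `beta_slow`, `rope_factor`) and `mscale` are omitted, and the mask layout `(-1, +1, -1, +1, ...)` is an assumption rather than the documented output of `initialize_yarn_mask`. All helper names are hypothetical.

```python
import numpy as np

def rope_angles(positions, dim, theta=10000.0):
    """Plain RoPE angles (no YaRN correction): one angle per position and channel pair."""
    inv_freq = theta ** (-np.arange(0, dim, 2) / dim)    # (dim // 2,)
    return positions[:, None] * inv_freq[None, :]         # (seq, dim // 2)

def swap_and_negate_mask(dim):
    """Hypothetical sign pattern (-1, +1, -1, +1, ...) for the swapped pairs."""
    return np.tile(np.array([-1.0, 1.0]), dim // 2)

def apply_rotary(x, angles, mask):
    """Rotates each interleaved pair (x[2i], x[2i+1]) by its angle."""
    cos = np.repeat(np.cos(angles), 2, axis=-1)           # (seq, dim)
    sin = np.repeat(np.sin(angles), 2, axis=-1)
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    swapped = pairs[..., ::-1].reshape(x.shape)           # (x1, x0, x3, x2, ...)
    return x * cos + swapped * mask * sin                 # (x0 cos - x1 sin, x1 cos + x0 sin)

dim = 4
x = np.random.default_rng(0).normal(size=(3, dim))
out = apply_rotary(x, rope_angles(np.arange(3.0), dim), swap_and_negate_mask(dim))
```

Expressing the rotation as an elementwise multiply plus a fixed swap avoids materializing per-pair 2x2 rotation matrices, which is presumably why a precomputed mask is passed around.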
Computes the shared expert and routes the activations.
- maxtext.models.deepseek_batchsplit.expert_group_mask(gate_logits, *, n_routing_groups, topk_routing_group, top_k_in_group)
Computes expert group mask for node-limited routing.
- maxtext.models.deepseek_batchsplit.expert_indices_and_weights(gate_logits, pre_bias_logits, num_experts_per_tok, routed_scaling_factor, n_routing_groups, topk_routing_group, top_k_in_group)
Computes expert indices for each token and their corresponding weights.
- Parameters:
gate_logits (Array)
pre_bias_logits (Array)
num_experts_per_tok (int)
routed_scaling_factor (float)
n_routing_groups (int)
topk_routing_group (int)
top_k_in_group (int)
- Return type:
tuple[Array, Array]
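Node-limited routing in the style of DeepSeek-V3 can be sketched as follows: experts are split into `n_routing_groups` contiguous groups, each group is scored by the sum of its `top_k_in_group` best experts, only the best `topk_routing_group` groups stay eligible, and the final top-k experts are chosen among those. The sketch assumes the logits are already activated affinity scores, that `gate_logits` carries the routing bias while `pre_bias_logits` does not, and that the selected weights are normalized before scaling by `routed_scaling_factor`; `group_mask` and `indices_and_weights` are hypothetical names.

```python
import numpy as np

def group_mask(gate_logits, *, n_routing_groups, topk_routing_group, top_k_in_group):
    """Hypothetical: (tokens, num_experts) mask keeping experts in the best groups."""
    tokens, num_experts = gate_logits.shape
    per_group = num_experts // n_routing_groups
    grouped = gate_logits.reshape(tokens, n_routing_groups, per_group)
    # Group score: sum of the top_k_in_group largest scores inside the group.
    group_scores = np.sort(grouped, axis=-1)[..., -top_k_in_group:].sum(axis=-1)
    top_groups = np.argsort(-group_scores, axis=-1)[:, :topk_routing_group]
    keep = np.zeros((tokens, n_routing_groups), dtype=bool)
    np.put_along_axis(keep, top_groups, True, axis=-1)
    return np.repeat(keep, per_group, axis=-1)           # expand groups to experts

def indices_and_weights(gate_logits, pre_bias_logits, *, num_experts_per_tok,
                        routed_scaling_factor, n_routing_groups, topk_routing_group,
                        top_k_in_group):
    """Hypothetical: top-k experts among eligible groups, weighted by unbiased scores."""
    mask = group_mask(gate_logits, n_routing_groups=n_routing_groups,
                      topk_routing_group=topk_routing_group, top_k_in_group=top_k_in_group)
    masked = np.where(mask, gate_logits, -np.inf)
    indices = np.argsort(-masked, axis=-1)[:, :num_experts_per_tok]
    weights = np.take_along_axis(pre_bias_logits, indices, axis=-1)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # normalization (assumed)
    return indices, weights * routed_scaling_factor

# Expert 0 has the single highest score, but its group (experts 0-1) is not in the top 2.
scores = np.array([[0.95, 0.0, 0.6, 0.6, 0.5, 0.5, 0.1, 0.1]])
idx, w = indices_and_weights(scores, scores, num_experts_per_tok=2, routed_scaling_factor=2.5,
                             n_routing_groups=4, topk_routing_group=2, top_k_in_group=2)
```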
- maxtext.models.deepseek_batchsplit.expert_selection(x, routing_kernel, routing_bias, *, num_experts, num_experts_per_tok, routed_scaling_factor, n_routing_groups, topk_routing_group, top_k_in_group)
Selects experts for each token and calculates group sizes for each expert.
- Parameters:
num_experts (int)
num_experts_per_tok (int)
routed_scaling_factor (float)
n_routing_groups (int)
topk_routing_group (int)
top_k_in_group (int)
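The group sizes are what the grouped matrix multiplications later consume: the number of token copies routed to each expert. A minimal sketch, assuming sigmoid affinities, a bias that influences selection but not the weights, and no node-limited group masking (omitted for brevity); `select_experts` is a hypothetical name.

```python
import numpy as np

def select_experts(x, routing_kernel, routing_bias, *, num_experts, num_experts_per_tok):
    """Hypothetical: top-k experts per token plus token-copy counts per expert."""
    scores = 1.0 / (1.0 + np.exp(-(x @ routing_kernel)))  # sigmoid affinities (assumed)
    biased = scores + routing_bias                          # bias steers selection only
    selected = np.argsort(-biased, axis=-1)[:, :num_experts_per_tok]
    weights = np.take_along_axis(scores, selected, axis=-1)
    group_sizes = np.bincount(selected.ravel(), minlength=num_experts)
    return selected, weights, group_sizes

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 4))
kernel = rng.normal(size=(4, 8))
selected, weights, group_sizes = select_experts(
    x, kernel, np.zeros(8), num_experts=8, num_experts_per_tok=2)
```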
- maxtext.models.deepseek_batchsplit.route(x, selected_experts, weights, group_sizes, *, expert_axis_name, use_gather_mosaic_kernel)
All-gathers tokens and then performs local routing.
- maxtext.models.deepseek_batchsplit.unroute(x, selected_experts, *, expert_axis_name, use_gather_mosaic_kernel)
Undoes route().
- maxtext.models.deepseek_batchsplit.route_impl_fwd(x, selected_experts, expert_axis_name, use_gather_mosaic_kernel)
Routes the activations and all-gathers across the expert axis.
- maxtext.models.deepseek_batchsplit.route_impl_bwd(expert_axis_name, use_gather_mosaic_kernel, res, grad)
- maxtext.models.deepseek_batchsplit.unroute_impl_fwd(x, selected_experts, expert_axis_name, use_gather_mosaic_kernel)
Unroutes the activations and reduce-scatters across the expert axis.
- maxtext.models.deepseek_batchsplit.unroute_impl_bwd(expert_axis_name, use_gather_mosaic_kernel, res, grad)
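The local part of routing is a permutation: each token is duplicated once per selected expert and the copies are sorted by expert id so that each expert's rows are contiguous; unrouting inverts that permutation. The sketch below covers only this single-device step. The all-gather and reduce-scatter across `expert_axis_name` and the Mosaic gather kernel are not modeled, and it assumes routing weights were already applied, so unrouting only sums each token's copies. `local_route` and `local_unroute` are hypothetical names.

```python
import numpy as np

def local_route(x, selected_experts):
    """Hypothetical: duplicates tokens per selected expert, sorted by expert id."""
    flat_experts = selected_experts.ravel()            # copy f = token * k + slot
    order = np.argsort(flat_experts, kind="stable")    # groups copies by expert
    source_token = order // selected_experts.shape[1]
    return x[source_token], order

def local_unroute(y_sorted, order, num_tokens, k):
    """Hypothetical: undoes the sort and sums each token's k (already weighted) outputs."""
    y_copies = np.empty_like(y_sorted)
    y_copies[order] = y_sorted
    return y_copies.reshape(num_tokens, k, -1).sum(axis=1)

x = np.arange(6.0).reshape(3, 2)                 # 3 tokens, hidden size 2
selected = np.array([[1, 0], [0, 2], [1, 2]])    # k = 2 experts per token
routed, order = local_route(x, selected)
```

With the identity standing in for every expert, unrouting returns each token summed over its two copies, i.e. `2 * x`, which makes the round trip easy to check.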
- maxtext.models.deepseek_batchsplit.gmm(inputs, kernel, group_sizes, preferred_element_type, config)
Performs a Grouped Matrix Multiplication (GMM).
This function can use either a quantized Megablox kernel or a standard jax.lax.ragged_dot for the GMM operation, based on the configuration.
- Parameters:
inputs – The left-hand side operand of the GMM.
kernel – The right-hand side operand (kernel) of the GMM.
group_sizes – An array indicating the size of each group.
preferred_element_type – The preferred element type for the computation.
config – Configuration object containing model settings, including use_qwix_quantization and merge_gating_gmm.
- Returns:
The result of the grouped matrix multiplication.
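The non-quantized path can be understood through a NumPy reference with the same contract as `jax.lax.ragged_dot`: `inputs` has shape `(m, k)`, `kernel` has shape `(num_groups, k, n)`, and the rows of `inputs` are split into consecutive blocks of `group_sizes[g]` rows, each multiplied by `kernel[g]`. This sketch (`gmm_reference` is a hypothetical name) requires the group sizes to sum to `m` and does not model the quantized Megablox kernel or `merge_gating_gmm`.

```python
import numpy as np

def gmm_reference(inputs, kernel, group_sizes):
    """Hypothetical reference GMM: row block g (consecutive rows) uses kernel[g]."""
    if int(np.sum(group_sizes)) != inputs.shape[0]:
        raise ValueError("group_sizes must sum to the number of input rows")
    out = np.zeros((inputs.shape[0], kernel.shape[-1]), dtype=np.result_type(inputs, kernel))
    start = 0
    for g, size in enumerate(group_sizes):
        out[start:start + size] = inputs[start:start + size] @ kernel[g]
        start += size
    return out

rng = np.random.default_rng(0)
inputs = rng.normal(size=(5, 3))
kernel = rng.normal(size=(3, 3, 2))   # 3 groups (experts), k=3, n=2
group_sizes = np.array([2, 0, 3])     # expert 1 received no tokens
result = gmm_reference(inputs, kernel, group_sizes)
```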
- maxtext.models.deepseek_batchsplit.compute_gating(x, w0, w1, group_sizes, *, dtype, config)
Computes the gating GMMs.
- maxtext.models.deepseek_batchsplit.compute_linear(layer_w0, layer_w1, wo, group_sizes, weights, *, dtype, config)
Combines the outputs of the gating GMMs and computes the final GMM.
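Assuming SwiGLU experts (a SiLU gate), which DeepSeek models use but this reference does not state, the two gating GMMs and the final GMM combine roughly as below. Where exactly the routing `weights` are applied is also an assumption; `expert_ffn` and `grouped_matmul` are hypothetical helpers.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def grouped_matmul(inputs, kernel, group_sizes):
    """Reference GMM: the g-th block of consecutive rows is multiplied by kernel[g]."""
    bounds = np.concatenate([[0], np.cumsum(group_sizes)])
    return np.concatenate(
        [inputs[bounds[g]:bounds[g + 1]] @ kernel[g] for g in range(len(group_sizes))])

def expert_ffn(x_sorted, w0, w1, wo, group_sizes, copy_weights):
    gate = grouped_matmul(x_sorted, w0, group_sizes)   # gating GMM 1 (assumed compute_gating)
    up = grouped_matmul(x_sorted, w1, group_sizes)     # gating GMM 2 (assumed compute_gating)
    hidden = silu(gate) * up                           # combine (assumed compute_linear)
    out = grouped_matmul(hidden, wo, group_sizes)      # final GMM (assumed compute_linear)
    return out * copy_weights[:, None]                 # routing weight per token copy

rng = np.random.default_rng(0)
hidden_dim, ffn_dim, num_experts = 4, 6, 2
x_sorted = rng.normal(size=(3, hidden_dim))   # 3 token copies, already sorted by expert
group_sizes = np.array([2, 1])
w0 = rng.normal(size=(num_experts, hidden_dim, ffn_dim))
w1 = rng.normal(size=(num_experts, hidden_dim, ffn_dim))
wo = rng.normal(size=(num_experts, ffn_dim, hidden_dim))
copy_weights = np.array([0.5, 1.0, 2.0])
y = expert_ffn(x_sorted, w0, w1, wo, group_sizes, copy_weights)
```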
- maxtext.models.deepseek_batchsplit.route_compute_unroute(xs, weights, *, num_experts, num_experts_per_tok, routed_scaling_factor, n_routing_groups, topk_routing_group, top_k_in_group, expert_axis_name, use_gather_mosaic_kernel, normalization_layer_epsilon, dtype, config)
Routes, processes, and unroutes activations.
- maxtext.models.deepseek_batchsplit.unroute_ubatch_shard_mapped(moe_inputs, routed_expert_out, shared_expert_out, selected_experts, *, expert_axis_name, use_gather_mosaic_kernel, target_length, mesh, activation_pspec)
Performs the unroute operation for a single microbatch in a shard map.
- maxtext.models.deepseek_batchsplit.unroute_ubatch_fn(moe_inputs, routed_expert_out, shared_expert_out, selected_experts, *, expert_axis_name, use_gather_mosaic_kernel, target_length)
Performs the unroute operation for a single microbatch.
- maxtext.models.deepseek_batchsplit.unroute_ubatch_remat_and_bwd_shard_mapped(selected_experts, outputs_grad, *, expert_axis_name, use_gather_mosaic_kernel, mesh, activation_pspec)
Performs remat and backward pass for unroute_ubatch in a shard map.
- maxtext.models.deepseek_batchsplit.unroute_ubatch_fn_remat(selected_experts, *, expert_axis_name, use_gather_mosaic_kernel)
- maxtext.models.deepseek_batchsplit.route_compute_unroute_bwd(residuals, outputs_grad, weights, *, num_experts, num_experts_per_tok, routed_scaling_factor, n_routing_groups, topk_routing_group, top_k_in_group, expert_axis_name, use_gather_mosaic_kernel, normalization_layer_epsilon, dtype, config)
Performs the backward pass for route_compute_unroute.
- maxtext.models.deepseek_batchsplit.moe(xs, weights, *, mesh, num_experts, num_experts_per_tok, routed_scaling_factor, n_routing_groups, topk_routing_group, top_k_in_group, expert_axis_name, use_gather_mosaic_kernel, config, normalization_layer_epsilon, dtype, activation_pspec)
Performs dropless MoE with tensor/expert parallelism.
- maxtext.models.deepseek_batchsplit.moe_bwd(residuals, outputs_grad, weights, *, mesh, num_experts, num_experts_per_tok, routed_scaling_factor, n_routing_groups, topk_routing_group, top_k_in_group, expert_axis_name, use_gather_mosaic_kernel, config, normalization_layer_epsilon, dtype, activation_pspec)
Performs the backward pass for the moe function.