maxtext.layers.moe module
MoE-related layers.
- maxtext.layers.moe.random_routing(rng_key, gate_logits, num_experts_per_tok)
Performs random routing of tokens to experts.
- Parameters:
rng_key – A JAX PRNGKey for randomness.
gate_logits – A JAX array of shape (batch_size, sequence_length, num_experts) representing the logits for each expert.
num_experts_per_tok – The number of experts to select for each token.
- Returns:
A tuple containing:
top_k_indices: JAX array of shape (batch_size, sequence_length, num_experts_per_tok) representing the indices of the selected experts for each token.
top_k_weights: JAX array of shape (batch_size, sequence_length, num_experts_per_tok) representing the weights for the selected experts.
- Return type:
tuple
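A minimal JAX sketch of random routing follows. It assumes two things: each token selects num_experts_per_tok distinct experts uniformly at random, and the returned weights are the gate logits gathered at those experts. The sampling scheme and the name random_routing_sketch are illustrative assumptions, not MaxText's implementation.

```python
import jax
import jax.numpy as jnp


def random_routing_sketch(rng_key, gate_logits, num_experts_per_tok):
  """Hypothetical sketch: route each token to k distinct, uniformly random experts."""
  # Draw an independent uniform score for every (token, expert) pair.
  random_scores = jax.random.uniform(rng_key, gate_logits.shape)
  # The top-k positions of i.i.d. uniform scores are k distinct experts chosen
  # uniformly at random (sampling without replacement).
  _, top_k_indices = jax.lax.top_k(random_scores, num_experts_per_tok)
  # Assumption: the weights are the gate logits of the randomly chosen experts.
  top_k_weights = jnp.take_along_axis(gate_logits, top_k_indices, axis=-1)
  return top_k_indices, top_k_weights


key = jax.random.PRNGKey(0)
logits = jax.random.normal(jax.random.PRNGKey(1), (2, 3, 8))
indices, weights = random_routing_sketch(key, logits, num_experts_per_tok=2)
print(indices.shape, weights.shape)  # (2, 3, 2) (2, 3, 2)
```

Using top-k over random scores instead of repeated sampling keeps the selection vectorized and guarantees that no expert is picked twice for the same token.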
- maxtext.layers.moe.calculate_load_balance_updates(top_k_indices, num_experts, rate)
Computes a bias adjustment update based on expert load. Used in DeepSeek V3: https://arxiv.org/html/2412.19437v1. Implementation reference: https://arxiv.org/pdf/2408.15664.
- Parameters:
top_k_indices – Shape (batch, sequence, top_k).
num_experts – Total number of experts.
rate – The update rate.
- Returns:
update – The value to add to the expert bias, with shape (num_experts,).
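The auxiliary-loss-free balancing paper cited above nudges each expert's bias by a fixed step toward a balanced load. Below is a minimal sketch of that rule. It assumes the update is rate * sign(mean_load - expert_load), following arXiv:2408.15664. The function name is hypothetical, and MaxText's exact normalization may differ.

```python
import jax.numpy as jnp


def load_balance_update_sketch(top_k_indices, num_experts, rate):
  """Hypothetical sketch of the sign-based expert-bias update (arXiv:2408.15664)."""
  # Number of times each expert was selected across all tokens and top-k slots.
  expert_counts = jnp.bincount(top_k_indices.reshape(-1), length=num_experts)
  # Load each expert would receive under perfect balance.
  mean_load = top_k_indices.size / num_experts
  # Underloaded experts (count < mean) get +rate; overloaded experts get -rate.
  return rate * jnp.sign(mean_load - expert_counts)


# Every token picks expert 0, so expert 0 is overloaded.
indices = jnp.zeros((1, 4, 2), dtype=jnp.int32)
update = load_balance_update_sketch(indices, num_experts=4, rate=0.1)
# expert 0 gets -0.1; experts 1-3 get +0.1
```

Adding this update to the bias lowers the selection score of overloaded experts and raises it for underloaded ones, without adding a term to the training loss.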
- class maxtext.layers.moe.GateLogit(*args, **kwargs)
Bases: Module
A layer that computes gate logits and can also return the pre-bias values needed for DeepSeek routing.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- class maxtext.layers.moe.RoutedMoE(*args, **kwargs)
Bases: Module
Implements a routed MoE block.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- should_update_load_balance()
Determines if loss-free load balancing updates should be applied.
- deepseek_scale_weights(weights)
Scales weights according to the DeepSeek-V3 reference implementation.
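DeepSeek-V3's public reference code takes two steps. With sigmoid scoring, it first renormalizes the selected weights over the top-k experts. It then multiplies them by a routed scaling factor. Below is a sketch of those two steps. The parameter names score_func and routed_scaling_factor are hypothetical stand-ins for the corresponding configuration values, and 2.5 in the usage line is only an illustrative value.

```python
import jax.numpy as jnp


def deepseek_scale_weights_sketch(weights, score_func="sigmoid", routed_scaling_factor=1.0):
  """Hypothetical sketch of DeepSeek-V3-style routing-weight scaling."""
  if score_func == "sigmoid":
    # Sigmoid scores do not sum to 1, so renormalize over the selected experts.
    weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
  # Apply the routed scaling factor to every selected expert's weight.
  return weights * routed_scaling_factor


scaled = deepseek_scale_weights_sketch(jnp.array([[0.2, 0.6]]), routed_scaling_factor=2.5)
```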
- expert_group_mask(gate_logits)
Returns a mask that selects only the top-k groups of experts.
Groups of experts are selected based on the sum of the top-2 expert scores for each group.
- Parameters:
gate_logits (Array) – Array of shape (batch, seq, num_experts).
- Returns:
Array of shape (batch, seq, num_experts) that is 1 for experts in the top-k groups and 0 elsewhere.
- Return type:
Array
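The group selection can be sketched as follows. The sketch assumes the experts are split into n_routing_groups contiguous, equal-sized groups and that topk_routing_group groups are kept per token. n_routing_groups is the configuration name mentioned under deepseek_routing; topk_routing_group and the function name are hypothetical.

```python
import jax
import jax.numpy as jnp


def expert_group_mask_sketch(gate_logits, n_routing_groups, topk_routing_group):
  """Hypothetical sketch: keep only experts that belong to the top-k groups."""
  batch, seq, num_experts = gate_logits.shape
  experts_per_group = num_experts // n_routing_groups
  # Assumption: experts are assigned to groups in contiguous blocks.
  grouped = gate_logits.reshape(batch, seq, n_routing_groups, experts_per_group)
  # Score each group by the sum of its two highest expert scores.
  top2_values, _ = jax.lax.top_k(grouped, 2)
  group_scores = jnp.sum(top2_values, axis=-1)  # (batch, seq, n_routing_groups)
  # Pick the highest-scoring groups for each token.
  _, top_groups = jax.lax.top_k(group_scores, topk_routing_group)
  group_mask = jnp.sum(jax.nn.one_hot(top_groups, n_routing_groups), axis=-2)
  # Broadcast the per-group mask back to individual experts.
  return jnp.repeat(group_mask, experts_per_group, axis=-1)


logits = jnp.array([[[0.1, 0.2, 0.9, 0.8, 0.0, 0.1, 0.5, 0.5]]])
mask = expert_group_mask_sketch(logits, n_routing_groups=4, topk_routing_group=2)
```

In this example the group scores are 0.3, 1.7, 0.1, and 1.0, so groups 1 and 3 survive.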
- deepseek_routing(gate_logits, pre_bias_logits)
Performs DeepSeek routing.
If the configuration does not specify routing groups (n_routing_groups is -1), standard top-k routing is used. Otherwise, all selected experts are forced to come from a subset of the highest-rated expert groups.
Experts are selected using the post-bias logits, while the returned weights are taken from the pre-bias logits.
- Parameters:
gate_logits (Array) – Array of shape (batch, seq, num_experts).
pre_bias_logits (Array) – Array of shape (batch, seq, num_experts).
- Returns:
top_k_weights: (batch, seq, num_experts_per_tok) array of weight values for each selected expert.
top_k_indices: (batch, seq, num_experts_per_tok) array of indices identifying the selected experts for each token.
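The biased logits decide which experts are chosen, and the unbiased logits decide how much each chosen expert contributes. Below is a sketch of the ungrouped path (n_routing_groups is -1), assuming plain top-k selection. The optional expert_mask argument is a hypothetical hook for a group mask such as the one expert_group_mask returns. The function name is also hypothetical.

```python
import jax
import jax.numpy as jnp


def deepseek_routing_sketch(gate_logits, pre_bias_logits, num_experts_per_tok, expert_mask=None):
  """Hypothetical sketch: select with post-bias logits, weight with pre-bias logits."""
  selection_logits = gate_logits
  if expert_mask is not None:
    # Experts outside the chosen groups can never be selected.
    selection_logits = jnp.where(expert_mask > 0, selection_logits, -jnp.inf)
  # Selection uses the bias-adjusted (post-bias) logits.
  _, top_k_indices = jax.lax.top_k(selection_logits, num_experts_per_tok)
  # Weights are read from the pre-bias logits at the selected positions.
  top_k_weights = jnp.take_along_axis(pre_bias_logits, top_k_indices, axis=-1)
  return top_k_weights, top_k_indices


pre = jnp.array([[[0.9, 0.1, 0.5, 0.3]]])
post = pre + jnp.array([0.0, 1.0, 0.0, 0.0])  # a bias that promotes expert 1
w, idx = deepseek_routing_sketch(post, pre, num_experts_per_tok=2)
```

In this example the bias pulls expert 1 into the selection. Its weight, however, is still its small pre-bias logit.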
- permute(inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None)
Permutes tokens so that they are grouped by expert, which is the layout the gmm (grouped matrix multiplication) call expects.
- unpermute(intermediate, sorted_selected_experts, weights, batch_size, sequence_length, use_custom_sort_vjp=True)
Restores tokens to their original order and combines the expert outputs using the routing weights.
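Below is a simplified, single-device sketch of the permute/unpermute round trip. It assumes two things: the (token, expert) pairs are grouped with an argsort on expert ID, and the combine step is a weighted sum over each token's top-k expert outputs. It omits the custom sort VJP, sharding, and roll_to_expert_id. All names are hypothetical.

```python
import jax.numpy as jnp


def permute_sketch(inputs, top_k_indices, num_experts):
  """Group (token, expert) pairs by expert. inputs: (tokens, emb); top_k_indices: (tokens, k)."""
  k = top_k_indices.shape[-1]
  flat_experts = top_k_indices.reshape(-1)  # (tokens * k,)
  sorted_order = jnp.argsort(flat_experts)  # positions grouped by expert ID
  # Row j of the repeated inputs is token j // k, matching flat_experts[j].
  repeated = jnp.repeat(inputs, k, axis=0)  # (tokens * k, emb)
  sorted_inputs = repeated[sorted_order]
  # Rows per expert, as a grouped matmul would need.
  group_sizes = jnp.bincount(flat_experts, length=num_experts)
  return sorted_inputs, sorted_order, group_sizes


def unpermute_sketch(expert_outputs, sorted_order, top_k_weights):
  """Undo the sort and combine each token's k expert outputs with its routing weights."""
  tokens, k = top_k_weights.shape
  # Scatter each row back to its original (token, slot) position.
  unsorted = jnp.zeros_like(expert_outputs).at[sorted_order].set(expert_outputs)
  unsorted = unsorted.reshape(tokens, k, -1)
  return jnp.einsum("tkd,tk->td", unsorted, top_k_weights)


x = jnp.arange(12.0).reshape(4, 3)
idx = jnp.array([[0, 2], [1, 0], [2, 3], [1, 3]])
w = jnp.full((4, 2), 0.5)
sx, order, sizes = permute_sketch(x, idx, num_experts=4)
# With identity "experts" and weights summing to 1, the round trip returns x.
roundtrip = unpermute_sketch(sx, order, w)
```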
- static local_permute(inputs, global_group_sizes, local_expert_size, shard_index, is_offset=False, global_sorted_experts=None, use_custom_sort_vjp=True)
Permutes tokens locally within an expert shard.
This function prepares the input tokens for processing by the experts located on the current shard. It groups the tokens by their assigned local expert index (0 to local_expert_size - 1).
- Parameters:
inputs – The input data (tokens) assigned to the experts on this shard. Shape [tokens, emb_dim].
global_group_sizes – The number of token assignments to each global expert across all batch shards. Shape [num_batch_shards, num_experts].
local_expert_size – The number of experts handled by the current shard.
shard_index – The index of the current expert shard (0 to num_expert_parallelism - 1).
is_offset – If True, assumes inputs are pre-sorted by global expert ID and selects the slice relevant to this shard’s assigned experts. If False, assumes that inputs corresponding to the shard’s experts start from the beginning of the tensor but need to be permuted by expert ID.
global_sorted_experts – Global expert IDs for the inputs used when is_offset is True. Shape [total_tokens_for_this_shard].
- Returns:
A tuple containing:
sorted_inputs: Input data permuted by local expert ID.
sorted_indices: Indices used to permute the inputs.
local_group_size: Number of tokens assigned to each local expert on this shard.
sorted_experts_ids: Expert ID corresponding to each token of the permuted inputs.
- Return type:
tuple
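Below is a sketch of the is_offset=False case. It assumes the received tokens arrive as consecutive runs, ordered first by source batch shard and then by local expert. Under that assumption, each token's local expert ID can be rebuilt from global_group_sizes with a repeat and then sorted. The function name is hypothetical, and the custom sort VJP is omitted.

```python
import jax
import jax.numpy as jnp


def local_permute_sketch(inputs, global_group_sizes, local_expert_size, shard_index):
  """Hypothetical sketch of local_permute with is_offset=False."""
  # Columns owned by this shard: shape (num_batch_shards, local_expert_size).
  local_sizes = jax.lax.dynamic_slice_in_dim(
      global_group_sizes, shard_index * local_expert_size, local_expert_size, axis=1)
  # Tokens per local expert, summed over all source batch shards.
  local_group_size = jnp.sum(local_sizes, axis=0)
  # Assumed layout: for each source shard, one run per local expert 0..local_expert_size-1.
  run_sizes = local_sizes.reshape(-1)
  run_expert_ids = jnp.arange(run_sizes.shape[0]) % local_expert_size
  expert_ids = jnp.repeat(run_expert_ids, run_sizes, total_repeat_length=inputs.shape[0])
  # jnp.argsort is stable by default, so the arrival order within each expert is kept.
  sorted_indices = jnp.argsort(expert_ids)
  sorted_inputs = inputs[sorted_indices]
  sorted_experts_ids = expert_ids[sorted_indices]
  return sorted_inputs, sorted_indices, local_group_size, sorted_experts_ids


# 2 batch shards, 4 global experts; shard 1 owns global experts 2 and 3.
ggs = jnp.array([[1, 2, 3, 4], [2, 0, 1, 1]])
inp = jnp.arange(9.0)[:, None]
s_in, s_idx, lgs, s_ids = local_permute_sketch(inp, ggs, local_expert_size=2, shard_index=1)
```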
- static get_all_to_all_params(all_shards_group_sizes, shard_id, num_expert_parallelism, is_batch_sharded=True)
Generates input offsets, send sizes, output offsets, and receive sizes used for ragged_all_to_all.
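The four arrays relate as shown in the sketch below. It assumes the batch-sharded case, and it assumes all_shards_group_sizes has already been reduced to a square matrix whose entry [i, j] counts the tokens shard i sends to shard j. Following the jax.lax.ragged_all_to_all convention, output_offsets index into the receiving shard's buffer. The function name is hypothetical.

```python
import jax.numpy as jnp


def all_to_all_params_sketch(send_matrix, shard_id):
  """send_matrix[i, j]: number of tokens shard i sends to shard j (assumed layout)."""
  def exclusive_cumsum(x, axis):
    return jnp.cumsum(x, axis=axis) - x

  # Where each outgoing chunk starts in this shard's destination-sorted input buffer.
  input_offsets = exclusive_cumsum(send_matrix[shard_id], axis=0)
  send_sizes = send_matrix[shard_id]
  # Where each chunk lands on its receiver: after the chunks from lower-ranked senders.
  output_offsets = exclusive_cumsum(send_matrix, axis=0)[shard_id]
  # How many tokens this shard receives from each sender.
  recv_sizes = send_matrix[:, shard_id]
  return input_offsets, send_sizes, output_offsets, recv_sizes


m = jnp.array([[1, 2], [3, 4]])
params_1 = all_to_all_params_sketch(m, shard_id=1)
```

For shard 1, the 3 tokens sent to shard 0 land at offset 1, after the single token shard 0 keeps for itself.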
- transform_bias(experts_index, *biases)
Selects bias values for a variable number of bias tensors based on chosen experts.
- static get_ragged_buffer_size(local_batch, ep_degree, global_experts, top_k, ragged_buffer_factor)
Calculates the token batch size of the ragged buffer. When ragged_buffer_factor > 0, the size is balanced_size * ragged_buffer_factor, which can drop tokens. Otherwise, it is the worst-case size, which guarantees that no tokens are dropped.
- Inputs:
local_batch: the local token batch (batch * seq, expanded by top_k) sharded on this device (e.g., inside shard_map).
ep_degree: degree of expert parallelism, generally equal to ici_expert_parallelism.
global_experts: unsharded expert count, e.g. 256 for DeepSeek.
top_k: also known as num_experts_per_tok; 8 for DeepSeek.
ragged_buffer_factor: when set > 0, the buffer size is balanced_size * ragged_buffer_factor. A value of 1.0 is dropless only in the perfectly balanced case; otherwise, tokens are dropped.
- Outputs:
The ragged buffer’s token batch size.
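The arithmetic can be sketched in plain Python. The balanced size follows from the description above: ep_degree shards each hold local_batch expanded tokens, so under perfect balance every shard receives local_batch tokens. The worst-case formula is an assumption. It supposes that every source shard routes each of its tokens to as many experts on this shard as possible, which is at most top_k and at most the number of local experts. MaxText's exact formula may differ, and the function name is hypothetical.

```python
def ragged_buffer_size_sketch(local_batch, ep_degree, global_experts, top_k, ragged_buffer_factor):
  """Hypothetical sketch of the ragged buffer size computation."""
  # Under perfect balance, each shard receives exactly local_batch tokens.
  balanced_size = local_batch
  if ragged_buffer_factor > 0.0:
    # Smaller buffer; may drop tokens when routing is imbalanced.
    return int(balanced_size * ragged_buffer_factor)
  # Assumed worst case: each of the ep_degree source shards sends, for every one of
  # its local_batch // top_k tokens, as many assignments as this shard can hold.
  local_experts = global_experts // ep_degree
  max_hits_per_token = min(top_k, local_experts)
  return ep_degree * (local_batch // top_k) * max_hits_per_token


size = ragged_buffer_size_sketch(1024, ep_degree=4, global_experts=256, top_k=8, ragged_buffer_factor=0.0)
```

The worst-case buffer grows linearly with ep_degree. This is why a ragged_buffer_factor close to 1.0 saves memory at the cost of possible token drops.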
- sparse_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias)
Performs sparse matrix multiplication of the inputs with the experts, so that each token is processed only by its selected experts.
- generate_masks_subgroup(top_k_indices, softmax_probs)
Subgroup mask generation for inference only.
- get_einsum(rhs_mesh_axes=(), einsum_name=None)
Returns the einsum function to use, which may be a quantized einsum when quantization is configured.
- Parameters:
rhs_mesh_axes (Tuple[str | None, ...])
einsum_name (str | None)
- maybe_all_gather_kernel_weight_in_expert_parallelism(kernel, kernel_axes)
All-gather kernel weight in expert parallelism if needed.
- Parameters:
kernel (Array)
kernel_axes (Tuple[str | None, ...])
- dense_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias)
Performs the MoE computation using dense matrix multiplications.
- Return type:
tuple[Array, Array | None, Array | None]
- fused_moe_matmul(inputs, gate_logits, wo_kernel, w0_kernel=None, w1_kernel=None, fused_kernel=None)
Fused MoE via tpu_inference fused_moe_func (vllm_rpa path only).
fused_moe_func handles routing, GMM, and weighted combination internally. It does not compute lb_loss or bias_updates (inference-only).
- Return type:
tuple[Array, None, None]
Bases: Module
Implements a block that combines shared and routed experts.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
- maxtext.layers.moe.get_gate_logit(inputs_shape, out_features_shape, model_name, axis=-1, weight_dtype=<class 'jax.numpy.float32'>, dtype=<class 'jax.numpy.float32'>, kernel_init=<function nd_dense_init.<locals>.init_fn>, kernel_axes=(), use_bias=False, score_func='', quant=None, matmul_precision='default', name=None)
Creates a GateLogit Linen module.
- Parameters:
inputs_shape (tuple[int, ...])
out_features_shape (Iterable[int] | int)
model_name (str)
axis (Iterable[int] | int)
weight_dtype (dtype)
dtype (dtype)
kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array])
kernel_axes (Tuple[str | None, ...])
use_bias (bool)
score_func (str)
quant (AqtQuantization | None)
matmul_precision (str)
name (str | None)
- maxtext.layers.moe.get_routed_moe(config, num_experts, num_experts_per_tok, mesh, kernel_init, kernel_axes, intermediate_dim=2048, weight_dtype=<class 'jax.numpy.float32'>, dtype=<class 'jax.numpy.float32'>, quant=None, name=None)
Creates a RoutedMoE Linen module.
- Parameters:
config (Any)
num_experts (int)
num_experts_per_tok (int)
mesh (Mesh)
kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array])
kernel_axes (Tuple[str | None, ...])
intermediate_dim (int)
weight_dtype (dtype)
dtype (dtype)
quant (AqtQuantization | None)
name (str | None)
Creates a RoutedAndSharedMoE Linen module.
- Parameters:
config (Any)
mesh (Mesh)
kernel_init (Callable[[Array, Sequence[int], dtype, int | tuple[int, ...], int | tuple[int, ...]], Array])
kernel_axes (Tuple[str | None, ...])
weight_dtype (dtype)
dtype (dtype)
quant (AqtQuantization | None)
name (str | None)