# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MoE related Layers."""
import enum
import functools
import math
import random
from typing import Iterable, Optional, Tuple, Union
from aqt.jax.v2 import aqt_tensor as aqt
from flax import nnx
import jax
from jax import ad_checkpoint as adc
from jax.experimental import xla_metadata
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from maxtext.common import common_types as ctypes
from maxtext.common.common_types import ShardMode
from maxtext.kernels import megablox as mblx
from maxtext.layers import attentions, linears, nnx_wrappers, quantizations
from maxtext.layers.initializers import NdInitializer, default_bias_init, nd_dense_init, variable_to_logically_partitioned
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils.sharding import create_sharding, maybe_shard_with_logical, maybe_shard_with_pspec
from maxtext.utils.sharding import logical_to_mesh_axes
import numpy as np
import qwix
from qwix.contrib.sparsity import sparsity_module
import qwix.pallas as qpl
import tokamax
set_xla_metadata = xla_metadata.set_xla_metadata
DISPATCH = "dispatch"
COMBINE = "combine"
def _sort_activations(
    inputs: jax.Array,
    sort_indices: jax.Array,
    use_custom_vjp: bool,
) -> jax.Array:
  """Reorders the rows of `inputs` according to `sort_indices`.

  When `use_custom_vjp=True` the gather goes through `_sort_activations_custom`.
  Its backward pass is itself a gather, using the inverse permutation
  `jnp.argsort(sort_indices)`. That path is only worthwhile when the compiler
  would otherwise emit a less efficient backward op. It is only valid when
  `sort_indices` is a permutation of `jnp.arange(inputs.shape[0])`.

  Args:
    inputs: Activations with the token dimension first, `(tokens, ...)`.
    sort_indices: `(tokens,)` row order to gather.
    use_custom_vjp: Select the explicit-backward implementation.

  Returns:
    `(tokens, ...)` array whose row `i` is `inputs[sort_indices[i]]`.
  """
  assert inputs.shape[0] == sort_indices.shape[0]
  with jax.named_scope("sort_activations"):
    if not use_custom_vjp:
      return inputs[sort_indices, ...]
    return _sort_activations_custom(inputs, sort_indices)
@jax.custom_vjp
def _sort_activations_custom(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array:
  """Gathers rows of `inputs` by `sort_indices` (custom backward rule registered below)."""
  return inputs[sort_indices, ...]


def _sort_activations_custom_fwd(inputs: jax.Array, sort_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
  """Forward rule: perform the gather and keep the indices as the residual."""
  sorted_inputs = _sort_activations_custom(inputs, sort_indices)
  return sorted_inputs, sort_indices


def _sort_activations_custom_bwd(residuals: jax.Array, grads: jax.Array) -> tuple[jax.Array, None]:
  """Backward rule: route cotangents back by gathering with the inverse order.

  Assumes the forward indices form a permutation, so `argsort` of them is the
  inverse permutation. The integer indices receive no cotangent (`None`).
  """
  inverse_order = jnp.argsort(residuals)
  return _sort_activations_custom(grads, inverse_order), None


_sort_activations_custom.defvjp(_sort_activations_custom_fwd, _sort_activations_custom_bwd)
[docs]
def get_batchsplit_init_kernel_axes():
  """Returns the `(wi, wo)` logical kernel axes used with the batch-split schedule."""
  wi_axes = ("embed_moe", None, "expert_only")
  wo_axes = ("embed_moe", "expert_only", None)
  return wi_axes, wo_axes
[docs]
def random_routing(rng_key, gate_logits, num_experts_per_tok):
  """Performs random routing of tokens to experts.

  Expert indices are drawn independently and uniformly from
  `[0, num_experts)`. Sampling is with replacement, so a token may be assigned
  the same expert more than once. Each weight is the gate logit of the drawn
  expert.

  Args:
    rng_key: A JAX PRNGKey for randomness.
    gate_logits: A JAX array of shape (batch_size, sequence_length, num_experts)
      representing the logits for each expert.
    num_experts_per_tok: The number of experts to select for each token.

  Returns:
    A tuple `(top_k_weights, top_k_indices)`:
      - top_k_weights: JAX array of shape (batch_size, sequence_length,
        num_experts_per_tok) holding `gate_logits` gathered at the selected
        experts.
      - top_k_indices: int32 JAX array of shape (batch_size, sequence_length,
        num_experts_per_tok) with the indices of the selected experts for each
        token.
  """
  bs, seq_len, num_experts = gate_logits.shape
  selected_num = bs * seq_len * num_experts_per_tok
  # Directly generate random integers in the range [0, num_experts)
  top_k_indices = jax.random.randint(
      rng_key,
      shape=(selected_num,),
      minval=0,
      maxval=num_experts,
      dtype=jnp.int32,
  )
  top_k_indices = top_k_indices.reshape(bs, seq_len, num_experts_per_tok)
  top_k_weights = jnp.take_along_axis(gate_logits, top_k_indices, axis=-1)
  # Weights first, indices second -- the order `get_topk` unpacks.
  return top_k_weights, top_k_indices
[docs]
def calculate_load_balance_updates(top_k_indices, num_experts, rate):
  """Computes the loss-free load-balancing update for the expert routing bias.

  Follows DeepSeek V3 (https://arxiv.org/html/2412.19437v1), with the reference
  algorithm from https://arxiv.org/pdf/2408.15664. Experts assigned fewer
  tokens than the mean get `+rate`, overloaded experts get `-rate`, and
  experts exactly at the mean get 0.

  Args:
    top_k_indices: Selected expert ids, shape (batch, sequence, top_k).
    num_experts: Total number of experts.
    rate: Magnitude of the update.

  Returns:
    Array of shape (num_experts,) to add to the expert bias.
  """
  assignments = top_k_indices.ravel()
  per_expert_load = jnp.bincount(assignments, length=num_experts)
  mean_load = assignments.size / num_experts
  return jnp.sign(mean_load - per_expert_load) * rate
[docs]
class GateLogit(nnx.Module):
  """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing."""

  def __init__(
      self,
      in_features_shape: Union[Iterable[int], int],
      out_features_shape: Union[Iterable[int], int],
      model_name: str,
      mesh: Mesh,
      rngs: nnx.Rngs,
      axis: Union[Iterable[int], int] = -1,
      weight_dtype: ctypes.DType = jnp.float32,
      dtype: ctypes.DType = jnp.float32,
      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
      kernel_axes: Tuple[Optional[str], ...] = (),
      use_bias: bool = False,
      score_func: str = "",
      quant: Optional[quantizations.AqtQuantization] = None,
      shard_mode: ShardMode = ShardMode.AUTO,
      matmul_precision: str = "default",
  ):
    """Initializes the GateLogit module.

    Args:
      in_features_shape: The shape of the input features.
      out_features_shape: The shape of the output features, typically the number of experts.
      model_name: The name of the model. Names starting with "deepseek3" make
        `__call__` also return the post-score-function, pre-bias values.
      mesh: Device mesh used to build the output sharding when `shard_mode` is
        `ShardMode.EXPLICIT`.
      rngs: An `nnx.Rngs` object used for initializing parameters.
      axis: The input axis or axes over which the transformation is applied.
      weight_dtype: The data type of the kernel and bias weights.
      dtype: The data type for the computation.
      kernel_init: The initializer function for the kernel weight matrix.
      kernel_axes: A tuple of logical axis names for partitioning the kernel.
        Its trailing entries (one per output dim) also partition the bias.
      use_bias: Whether to add learnable bias in gate logit scores. When enabled,
        this bias aids expert load balancing (like in DeepSeek V3), and is not
        part of the loss calculation.
      score_func: Scoring function for output normalization before applying
        bias. An empty string disables it.
      quant: The quantization configuration. If None, no quantization is applied.
      shard_mode: Sharding mode. `ShardMode.EXPLICIT` attaches an explicit
        output sharding to the gate matmul.
      matmul_precision: The precision level for the matrix multiplication.
    """
    self.in_features_shape = linears.canonicalize_tuple(in_features_shape)
    self.out_features_shape = linears.canonicalize_tuple(out_features_shape)
    self.model_name = model_name
    self.mesh = mesh
    self.axis = linears.canonicalize_tuple(axis)
    self.weight_dtype = weight_dtype
    self.dtype = dtype
    self.kernel_init = kernel_init
    self.kernel_axes = kernel_axes
    self.use_bias = use_bias
    self.score_func = score_func
    self.quant = quant
    self.shard_mode = shard_mode
    self.matmul_precision = matmul_precision

    # Parameter initialization
    kernel_shape = self.in_features_shape + self.out_features_shape
    kernel_in_axis = np.arange(len(self.axis))
    kernel_out_axis = np.arange(len(self.axis), len(self.axis) + len(self.out_features_shape))
    # In serve mode no kernel Param is created; `__call__` substitutes zeros.
    if not quantizations.in_serve_mode(self.quant):
      self.kernel = nnx.Param(
          self.kernel_init(
              rngs.params(),
              kernel_shape,
              self.weight_dtype,
              kernel_in_axis,
              kernel_out_axis,
          ),
          out_sharding=self.kernel_axes,
      )

    # The bias spans the output dims, so it reuses the trailing kernel axes.
    if self.use_bias:
      bias_axes = self.kernel_axes[-len(self.out_features_shape) :]
      bias_shape = kernel_shape[-len(self.out_features_shape) :]
      self.bias = nnx.Param(
          default_bias_init(rngs.params(), bias_shape, self.weight_dtype),
          out_sharding=bias_axes,
      )
    else:
      self.bias = None

    if quant:
      dot_general_cls = quant.dot_general_cls(mesh_axes=kernel_axes)
      dot_general_linen = dot_general_cls()
      quant_dot_general = nnx_wrappers.ToNNX(dot_general_linen, rngs=rngs)
      # Register the wrapped module as an attribute named "<ClassName>_0".
      self._quant_dot_general_name = f"{type(dot_general_linen).__name__}_0"
      setattr(self, self._quant_dot_general_name, quant_dot_general)
      # One dummy forward pass, presumably so the wrapped quantized dot_general
      # materializes its variables at construction time.
      dummy_inputs = jnp.zeros((1, *self.in_features_shape), dtype=self.dtype)
      self(dummy_inputs, _initializing=True)
    else:
      self._quant_dot_general_name = None

  @property
  def quant_dot_general(self) -> nnx_wrappers.ToNNX | None:
    """The registered quantized dot_general module, or None without quantization."""
    if self._quant_dot_general_name is None:
      return None
    return getattr(self, self._quant_dot_general_name)

  def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.Array, Optional[jax.Array]]:
    """Projects `inputs` onto the expert dimension.

    Args:
      inputs: Input activations, cast to `self.dtype` and contracted over
        `self.axis` against the kernel.
      _initializing: Forwarded to the quantized dot_general. `__init__` sets it
        to True for its dummy initialization call.

    Returns:
      A tuple `(output, pre_bias_logits)`.
        - `output` is the projection. It is passed through `score_func` when one
          is set, then shifted by the bias when `use_bias` is True.
        - `pre_bias_logits` is the post-`score_func`, pre-bias value when
          `score_func` is set and `model_name` starts with "deepseek3";
          otherwise it is None.
    """
    inputs = jnp.asarray(inputs, self.dtype)
    norm_axis = linears.normalize_axes(self.axis, inputs.ndim)

    if quantizations.in_serve_mode(self.quant):
      # No kernel Param exists in serve mode; use a zero kernel of the right
      # shape (presumably the real weights come from the quantized state).
      kernel_shape = self.in_features_shape + self.out_features_shape
      kernel = jnp.zeros(kernel_shape, dtype=self.dtype)
    else:
      kernel = self.kernel[...]
    kernel = jnp.asarray(kernel, self.dtype)

    contract_ind = tuple(range(0, len(norm_axis)))
    # The explicit output sharding names three logical axes, so it assumes a
    # 3-D (batch, length, experts) result -- TODO confirm for other ranks.
    output_sharding = (
        create_sharding(self.mesh, ("activation_batch", "activation_length", None))
        if self.shard_mode == ShardMode.EXPLICIT
        else None
    )
    output = linears._compute_dot_general_nnx(
        inputs,
        kernel,
        norm_axis,
        contract_ind,
        self.matmul_precision,
        self.quant_dot_general,
        _initializing,
        out_sharding=output_sharding,
    )

    pre_bias_logits = None
    if self.score_func:
      output = linears._convert_to_activation_function(self.score_func)(output)
      if self.model_name.startswith("deepseek3"):
        # DeepSeek-V3 routing selects on biased scores but weights with these.
        pre_bias_logits = output

    if self.use_bias:
      bias = jnp.asarray(self.bias[...], self.dtype)
      output += bias
    return output, pre_bias_logits
[docs]
class RoutedMoE(nnx.Module):
"""Implements a routed MoE block."""
def __init__(
    self,
    config: ctypes.Config,
    num_experts: int,
    num_experts_per_tok: int,
    mesh: jax.sharding.Mesh,
    kernel_init: attentions.NdInitializer,
    kernel_axes: Tuple[Optional[str], ...],
    rngs: nnx.Rngs,
    intermediate_dim: int = 2048,
    weight_dtype: ctypes.DType = jnp.float32,
    dtype: ctypes.DType = jnp.float32,
    quant: Optional[quantizations.AqtQuantization] = None,
):
  """Initializes the RoutedMoE module.

  Builds the following components:
    - the router (`self.gate`);
    - optional qwix sparsity modules;
    - the expert kernels (zeros in serve mode, a fused `wi` for prefused vLLM
      weights, otherwise `wi_0`/`wi_1`);
    - the output kernel `wo`;
    - optional MLP biases;
    - for Gemma4, a per-expert scale.

  Args:
    config: The main config setting.
    num_experts: Number of experts.
    num_experts_per_tok: Number of experts for each token.
    mesh: Mesh, device mesh.
    kernel_init: The initializer function for the kernel weight matrices.
    kernel_axes: A tuple of logical axis names for partitioning the router
      (gate) kernel. The expert kernels use `wi_kernel_axes`/`wo_kernel_axes`,
      which are chosen from the config.
    rngs: An `nnx.Rngs` object used for initializing parameters.
    intermediate_dim: Intermediate dimension of MoE.
    weight_dtype: The data type of the kernel weights.
    dtype: The data type for the computation.
    quant: The quantization configuration. If None, no quantization is applied.

  Raises:
    ValueError: If a qwix "gmm" rule is active but is not a `qwix.QtRule`.
  """
  self.config = config
  self.num_experts = num_experts
  self.num_experts_per_tok = num_experts_per_tok
  self.mesh = mesh
  self.kernel_init = kernel_init
  self.kernel_axes = kernel_axes
  self.intermediate_dim = intermediate_dim
  self.weight_dtype = weight_dtype
  self.dtype = dtype
  self.quant = quant
  self.rngs = rngs
  # A non-positive `moe_expert_input_dim` means "use the model embedding size".
  self.moe_expert_input_dim = (
      self.config.emb_dim if self.config.moe_expert_input_dim <= 0 else self.config.moe_expert_input_dim
  )
  # Logical partitioning axes for the expert kernels: wi (input projections)
  # and wo (output projection). The first matching config flag wins.
  if self.config.shard_exp_on_fsdp:
    # special sharding for dsv3
    self.wi_kernel_axes = ("embed_moe", None, "mlp_moe")
    self.wo_kernel_axes = ("embed_moe", "mlp_moe", None)
  elif self.config.use_2d_fsdp_sharding:
    self.wi_kernel_axes = ("embed_moe", "mlp_moe", None)
    self.wo_kernel_axes = ("embed_moe", "mlp_moe", None)
  elif self.config.use_batch_split_schedule:
    self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
  else:
    self.wi_kernel_axes = ("exp", "embed_moe", "mlp_moe")
    self.wo_kernel_axes = ("exp", "mlp_moe", "embed_moe")
  # Mesh axis name(s) used for tensor parallelism. A tuple means several axes
  # whose sizes are multiplied (see get_tensor_parallelism_size).
  if self.config.attention == "vllm_rpa":
    # vLLM uses 'model' as the tensor parallelism axis name
    self._tensor_parallelism_name = ("model", "attn_dp")
  else:
    self._tensor_parallelism_name = "tensor"
  # Mesh axis name(s) used for expert parallelism. Tuple sizes are multiplied
  # in get_expert_parallelism_size.
  if self.config.attention == "vllm_rpa" and self.config.enable_dp_attention:
    self._expert_parallelism_name = "attn_dp_expert"
  elif self.config.custom_mesh_and_rule == ctypes.CustomRule.CP_AS_EP:
    # when custom mesh and rule is cp-as-ep, context axis is same with expert in MoE component
    self._expert_parallelism_name = ("context", "expert")
  else:
    self._expert_parallelism_name = "expert"
  # Router: maps moe_expert_input_dim features to one score per expert.
  self.gate = GateLogit(
      in_features_shape=self.moe_expert_input_dim,
      out_features_shape=self.num_experts,
      mesh=self.mesh,
      model_name=self.config.model_name,
      dtype=jnp.float32 if self.config.float32_gate_logits else self.dtype,
      weight_dtype=self.weight_dtype,
      quant=self.quant,
      kernel_init=self.kernel_init,
      kernel_axes=self.kernel_axes,
      use_bias=self.config.routed_bias,
      # tpu-inference applies the score function in the fused_moe_gmm kernel,
      # so we don't apply it here to avoid redundant computation.
      # See https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/common/fused_moe_gmm.py#L58.
      score_func="" if self.config.attention == "vllm_rpa" else self.config.routed_score_func,
      matmul_precision=self.config.matmul_precision,
      shard_mode=config.shard_mode,
      rngs=self.rngs,
  )
  # Optional weight sparsity. It is enabled only when all of these hold:
  #   - the active qwix "gmm" rule is a QtRule;
  #   - its additional_qt_config carries a "sparsity_rule";
  #   - that rule has both weight_sparsity_n and weight_sparsity_m set
  #     (presumably N:M structured sparsity).
  rule = qpl.get_current_rule("gmm")
  sparsity_rule = None
  if rule is not None:
    if not isinstance(rule, qwix.QtRule):
      raise ValueError("Expect a QtRule for quantized training.")
    if rule.additional_qt_config and "sparsity_rule" in rule.additional_qt_config:
      q_s_rule = rule.additional_qt_config["sparsity_rule"]
      if q_s_rule and q_s_rule.weight_sparsity_n and q_s_rule.weight_sparsity_m:
        sparsity_rule = q_s_rule
  if sparsity_rule is not None:
    # NOTE(review): these shapes use config.emb_dim and the unpadded
    # self.intermediate_dim. The kernels below use moe_expert_input_dim and the
    # (possibly padded) moe_intermediate_dim -- confirm they always agree.
    self.wi_0_sparsity_module = sparsity_module.SparsityModule(
        shape=(self.num_experts, self.config.emb_dim, self.intermediate_dim),
        sharding_axes=self.wi_kernel_axes,
        sparsity_rule=sparsity_rule,
    )
    self.wi_1_sparsity_module = sparsity_module.SparsityModule(
        shape=(self.num_experts, self.config.emb_dim, self.intermediate_dim),
        sharding_axes=self.wi_kernel_axes,
        sparsity_rule=sparsity_rule,
    )
    self.wo_sparsity_module = sparsity_module.SparsityModule(
        shape=(self.num_experts, self.intermediate_dim, self.config.emb_dim),
        sharding_axes=self.wo_kernel_axes,
        sparsity_rule=sparsity_rule,
    )
  else:
    self.wi_0_sparsity_module = None
    self.wi_1_sparsity_module = None
    self.wo_sparsity_module = None
  # pylint: disable=protected-access
  self.activation_fn = linears._convert_to_activation_function(self.config.mlp_activations[0])
  # Axes 0 and 1 are passed to kernel_init as the in/out axes of every 3-D
  # expert kernel below.
  kernel_in_axis = np.arange(1)
  kernel_out_axis = np.arange(1, 2)
  if quantizations.in_serve_mode(self.quant):
    # During aqt convert state we delete kernel weight from params to save
    # memory. Instead they are retrieved from the tensors stored in the 'aqt'
    # collection.
    self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
    self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
    self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
  elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa":
    # Pad model dimension in Fused MoE weight kernels for GMM_v2 execution.
    moe_intermediate_dim = (
        self.config.padded_base_moe_mlp_dim
        if self.config.padded_base_moe_mlp_dim is not None
        else self.intermediate_dim
    )
    # Single fused input kernel whose last dim is 2 * moe_intermediate_dim
    # (presumably wi_0 and wi_1 side by side -- confirm against the consumer).
    self.wi = nnx.Param(
        self.kernel_init(
            self.rngs.params(),
            (num_experts, self.moe_expert_input_dim, moe_intermediate_dim * 2),
            weight_dtype,
            kernel_in_axis,
            kernel_out_axis,
        ),
        out_sharding=self.wi_kernel_axes,
    )
    # NOTE(review): wo is sized with the unpadded self.intermediate_dim, not
    # moe_intermediate_dim; confirm this is intended when padding is active.
    self.wo = nnx.Param(
        self.kernel_init(
            self.rngs.params(),
            (self.num_experts, self.intermediate_dim, self.moe_expert_input_dim),
            self.weight_dtype,
            kernel_in_axis,
            kernel_out_axis,
        ),
        out_sharding=self.wo_kernel_axes,
    )
  else:
    # Pad model dimension in Unfused MoE weight kernels for GMM_v2 execution.
    moe_intermediate_dim = (
        self.config.padded_base_moe_mlp_dim
        if self.config.padded_base_moe_mlp_dim is not None
        else self.intermediate_dim
    )
    self.wi_0 = nnx.Param(
        self.kernel_init(
            self.rngs.params(),
            (num_experts, self.moe_expert_input_dim, moe_intermediate_dim),
            weight_dtype,
            kernel_in_axis,
            kernel_out_axis,
        ),
        out_sharding=self.wi_kernel_axes,
    )
    self.wi_1 = nnx.Param(
        self.kernel_init(
            self.rngs.params(),
            (num_experts, self.moe_expert_input_dim, moe_intermediate_dim),
            weight_dtype,
            kernel_in_axis,
            kernel_out_axis,
        ),
        out_sharding=self.wi_kernel_axes,
    )
    # NOTE(review): as in the fused branch, wo uses the unpadded
    # self.intermediate_dim while wi_0/wi_1 use moe_intermediate_dim.
    self.wo = nnx.Param(
        self.kernel_init(
            self.rngs.params(),
            (self.num_experts, self.intermediate_dim, self.moe_expert_input_dim),
            self.weight_dtype,
            kernel_in_axis,
            kernel_out_axis,
        ),
        out_sharding=self.wo_kernel_axes,
    )
  # Optional per-expert MLP biases, sized with the unpadded intermediate_dim.
  if self.config.mlp_bias:
    wi_bias_axes = ("exp", "activation_mlp")
    wo_bias_axes = ("exp", "activation_embed")
    wi_bias_shape = (self.num_experts, self.intermediate_dim)
    wo_bias_shape = (self.num_experts, self.moe_expert_input_dim)
    self.wi_0_bias = nnx.Param(
        default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype),
        out_sharding=wi_bias_axes,
    )
    self.wi_1_bias = nnx.Param(
        default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype),
        out_sharding=wi_bias_axes,
    )
    self.wo_bias = nnx.Param(
        default_bias_init(self.rngs.params(), wo_bias_shape, self.weight_dtype),
        out_sharding=wo_bias_axes,
    )
  else:
    self.wi_0_bias = None
    self.wi_1_bias = None
    self.wo_bias = None
  # Gemma4 gets a per-expert scale parameter initialized to ones.
  if self.config.decoder_block == ctypes.DecoderBlockType.GEMMA4:
    self.per_expert_scale = nnx.Param(
        jnp.ones((self.num_experts,), dtype=self.weight_dtype),
        out_sharding=("exp",),
    )
  else:
    self.per_expert_scale = None
def _maybe_shard_with_logical(self, inputs, logical_name):
  """Applies a logical-axis sharding hint to `inputs` using this layer's mesh and shard mode."""
  cfg = self.config
  shard_kwargs = dict(
      mesh=self.mesh,
      shard_mode=cfg.shard_mode,
      debug_sharding=cfg.debug_sharding,
      extra_stack_level=1,
  )
  return maybe_shard_with_logical(inputs, logical_name, **shard_kwargs)
def _logical_to_mesh_axes(self, logical_name):
  """Maps logical axis names to mesh axes; no rules are passed under pipeline parallelism."""
  if self.config.using_pipeline_parallelism:
    rules = None
  else:
    rules = self.config.logical_axis_rules
  return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=rules)
def _maybe_shard_with_pspec(self, inputs, pspec: jax.sharding.PartitionSpec | None):
  """Applies a PartitionSpec sharding hint to `inputs` using this layer's mesh and shard mode."""
  cfg = self.config
  return maybe_shard_with_pspec(
      inputs,
      pspec,
      mesh=self.mesh,
      shard_mode=cfg.shard_mode,
      debug_sharding=cfg.debug_sharding,
      extra_stack_level=1,
  )
[docs]
def get_expert_parallelism_size(self):
  """Returns the expert-parallel degree.

  When the expert-parallel name is a tuple of mesh axes, their sizes are
  multiplied. Axes absent from the mesh count as 1.
  """
  axis_names = self._expert_parallelism_name
  if not isinstance(axis_names, tuple):
    return self.mesh.shape.get(axis_names, 1)
  return math.prod(self.mesh.shape.get(name, 1) for name in axis_names)
[docs]
def get_tensor_parallelism_size(self):
  """Returns the tensor-parallel degree.

  Sizes are multiplied when several mesh axes are used. Axes absent from the
  mesh count as 1.
  """
  axis_names = self._tensor_parallelism_name
  if isinstance(axis_names, tuple):
    return math.prod(self.mesh.shape.get(name, 1) for name in axis_names)
  return self.mesh.shape.get(axis_names, 1)
[docs]
def get_tensor_transpose_parallelism_size(self):
  """Size of the "tensor_transpose" mesh axis, or 1 when the mesh has no such axis."""
  axis_sizes = self.mesh.shape
  return axis_sizes.get("tensor_transpose", 1)
[docs]
def get_context_autoregressive_parallelism_size(self):
  """Size of the "context_autoregressive" mesh axis, or 1 when the mesh has no such axis."""
  axis_sizes = self.mesh.shape
  return axis_sizes.get("context_autoregressive", 1)
[docs]
def should_update_load_balance(self):
  """Whether loss-free load-balancing (routing-bias) updates should be applied.

  Requires `config.routed_bias` and a strictly positive
  `config.routed_bias_update_rate`.
  """
  cfg = self.config
  if not cfg.routed_bias:
    return cfg.routed_bias
  return cfg.routed_bias_update_rate > 0.0
[docs]
def get_topk(self, gate_logits, pre_bias_logits, rngs=None):
  """Selects each token's experts and their combine weights.

  Args:
    gate_logits: Router scores, (batch, sequence, num_experts).
    pre_bias_logits: Pre-bias router scores; only read by DeepSeek-V3 routing.
    rngs: `nnx.Rngs`. Required when `config.use_random_routing` is set, in
      which case its `params` stream is consumed.

  Returns:
    `(top_k_weights, top_k_indices)`, each shaped
    (batch, sequence, num_experts_per_tok).

  Raises:
    ValueError: If random routing is enabled and `rngs` is None.
  """
  cfg = self.config
  k = self.num_experts_per_tok

  if cfg.use_random_routing:
    if rngs is None:
      raise ValueError("The random key cannot be None for random routing.")
    # The 'params' RNG stream is reused as the routing randomness source.
    return random_routing(rngs.params(), gate_logits, k)

  block = cfg.decoder_block
  if cfg.model_name.startswith("deepseek3"):
    top_k_weights, top_k_indices = self.deepseek_routing(gate_logits, pre_bias_logits)
  elif block == ctypes.DecoderBlockType.GEMMA4:
    # Rank experts by raw logits; weight them by the softmax over all experts.
    probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1)
    _, top_k_indices = jax.lax.top_k(gate_logits, k)
    top_k_weights = jnp.take_along_axis(probs, top_k_indices, axis=-1).astype(self.dtype)
  else:
    top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, k)

  if block == ctypes.DecoderBlockType.DEEPSEEK:
    top_k_weights = self.deepseek_scale_weights(top_k_weights)
  elif block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4):
    # Softmax over the k selected scores only.
    top_k_weights = jax.nn.softmax(top_k_weights.astype(jnp.float32), axis=-1).astype(self.dtype)

  # Optional renormalization of router weights (e.g. used by Qwen3, Gemma4).
  if cfg.norm_topk_prob:
    top_k_weights /= top_k_weights.sum(axis=-1, keepdims=True)
  return top_k_weights, top_k_indices
[docs]
def deepseek_scale_weights(self, weights):
  """Post-processes DeepSeek routing weights as in the V3 reference implementation.

  With a sigmoid score function the weights are first renormalized to sum to 1
  along the last axis. In every case they are then multiplied by
  `config.routed_scaling_factor`.
  """
  # https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/model.py#L592-L594.
  cfg = self.config
  uses_sigmoid = cfg.routed_score_func == "sigmoid"
  if uses_sigmoid:
    weights /= weights.sum(-1, keepdims=True)
  weights *= cfg.routed_scaling_factor
  return weights
[docs]
def expert_group_mask(self, gate_logits: jax.Array) -> jax.Array:
  """Returns a mask that selects only the top-k groups of experts.

  The expert axis is split into `config.n_routing_groups` contiguous groups of
  equal size. Each group is scored by the sum of its two highest expert scores
  (summed in float32), and the `config.topk_routing_group` best-scoring groups
  are kept.

  Args:
    gate_logits: Array of shape `(batch, seq, num_experts)`. `num_experts` must
      be divisible by `n_routing_groups`, and every group needs at least two
      experts for the top-2 group score.

  Returns:
    float32 array of shape `(batch, seq, num_experts)` that is 1 for experts in
    the selected groups and 0 elsewhere.
  """
  # Split the expert axis into groups; the leading (batch, seq) dims are kept:
  # `scores_grouped.shape == (batch, seq, n_routing_groups, experts_per_group)`.
  scores_grouped = jnp.reshape(
      gate_logits,
      gate_logits.shape[:-1] + (self.config.n_routing_groups, -1),
  )
  top2_in_group_vals, _ = jax.lax.top_k(scores_grouped, k=2)
  group_scores = jnp.sum(jnp.astype(top2_in_group_vals, jnp.float32), axis=-1)
  _, group_idx = jax.lax.top_k(group_scores, k=self.config.topk_routing_group)

  # Mark the selected groups: one-hot per chosen group, then merge the choices
  # into a (batch, seq, n_routing_groups) 0/1 mask.
  group_mask = jax.nn.one_hot(group_idx, num_classes=self.config.n_routing_groups, dtype=jnp.float32)
  group_mask = jnp.sum(group_mask, axis=-2)

  # Expand each group's flag to all of its experts, then flatten back to the
  # expert axis.
  score_mask_expanded = jnp.broadcast_to(
      group_mask[..., None],
      group_mask.shape + (self.num_experts // self.config.n_routing_groups,),
  )
  return jnp.reshape(
      score_mask_expanded,
      score_mask_expanded.shape[:-2] + (self.num_experts,),
  )
[docs]
def deepseek_routing(self, gate_logits: jax.Array, pre_bias_logits: jax.Array) -> tuple[jax.Array, jax.Array]:
  """DeepSeek routing: choose experts on post-bias scores, weight by pre-bias ones.

  Experts are ranked by `gate_logits` (post-bias), and the returned weights are
  gathered from `pre_bias_logits` at the chosen indices. With
  `config.n_routing_groups == -1` every expert is a candidate (standard top-k).
  Otherwise only experts in the highest-rated groups (see `expert_group_mask`)
  can be chosen; all other experts are masked to -inf.

  Args:
    gate_logits: Array of shape `(batch, seq, num_experts)`.
    pre_bias_logits: Array of shape `(batch, seq, num_experts)`.

  Returns:
    - top_k_weights: `(batch, seq, num_experts_per_tok)` array of weight values
      for each selected expert.
    - top_k_indices: `(batch, seq, num_experts_per_tok)` array of indices
      identifying the selected experts for each token.
  """
  if self.config.n_routing_groups == -1:
    eligible = 1
  else:
    eligible = self.expert_group_mask(gate_logits)
  masked_logits = jnp.where(eligible > 0, gate_logits, -jnp.inf)
  top_k_indices = jax.lax.top_k(masked_logits, k=self.num_experts_per_tok)[1]
  top_k_weights = jnp.take_along_axis(pre_bias_logits, top_k_indices, axis=-1)
  return top_k_weights, top_k_indices
[docs]
def apply_ffn_activation(self, layer_w0, layer_w1):
  """Combines the two input projections into the gated FFN intermediate.

  For GPT-OSS:
    - w0 is clipped from above at `mlp_activations_limit`, and w1 to
      `[-limit, limit]`;
    - the result is `w0 * act(1.702 * w0) * (w1 + 1)`.
  Every other decoder block computes `act(w0) * w1`. The result is cast to
  `self.dtype`.
  """
  with jax.named_scope("ffn_act"):
    if self.config.decoder_block != ctypes.DecoderBlockType.GPT_OSS:
      gated = jnp.multiply(self.activation_fn(layer_w0), layer_w1)
    else:
      limit = self.config.mlp_activations_limit
      w0 = jnp.clip(layer_w0, min=None, max=limit)
      w1 = jnp.clip(layer_w1, min=-limit, max=limit)
      glu = jnp.multiply(w0, self.activation_fn(w0 * 1.702))
      gated = jnp.multiply(glu, (w1 + 1))
    return gated.astype(self.dtype)
[docs]
def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None):
  """Routes tokens and groups the replicated inputs by expert for a grouped matmul.

  Args:
    inputs: Activations of shape (batch, sequence, emb).
    gate_logits: Router scores, used by `get_topk` and by the load-balance loss.
    pre_bias_logits: Pre-bias router scores, forwarded to `get_topk`.
    use_custom_sort_vjp: Forwarded to `_sort_activations`.
    rngs: Forwarded to `get_topk` (needed for random routing).
    roll_to_expert_id: If set, expert ids are shifted to
      `(id - roll_to_expert_id) % num_experts` before sorting and counting. The
      grouping then starts at expert `roll_to_expert_id`, and `group_size` /
      `sorted_experts` are expressed in the shifted ids.

  Returns:
    Tuple `(sorted_inputs, sorted_selected_experts, weights, group_size,
    sorted_experts, lb_loss, bias_updates)`:
      - sorted_inputs: `(batch * sequence * num_experts_per_tok, emb)` copies of
        the inputs, ordered by (possibly shifted) expert id and cast to
        `self.dtype`.
      - sorted_selected_experts: The argsort permutation used for the ordering.
      - weights: Top-k weights from `get_topk`.
      - group_size: `(num_experts,)` count of assignments per expert.
      - sorted_experts: Expert id of each row of `sorted_inputs`.
      - lb_loss: Load-balance loss, or None when `load_balance_loss_weight` is 0.
      - bias_updates: `(num_experts,)` routing-bias update, or None when
        `should_update_load_balance()` is False.
  """
  # reshape inputs (batch, sequence, emb) to (batch * sequence, emb)
  inputs_shape = inputs.shape
  bsz_times_seq_len = inputs_shape[0] * inputs_shape[1]
  inputs_2d = jnp.reshape(inputs, (bsz_times_seq_len, inputs_shape[2]))
  weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs)

  lb_loss = None
  if self.config.load_balance_loss_weight > 0.0:
    softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
    lb_loss = self.load_balance_loss(selected_experts, softmax_probs)

  if self.should_update_load_balance():
    # Use this layer's expert count (as the gate bias and the bincount below
    # do) so the update has the same length as the bias it adjusts.
    bias_updates = calculate_load_balance_updates(
        selected_experts, self.num_experts, self.config.routed_bias_update_rate
    )
  else:
    bias_updates = None

  if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
    # weights will be of shape (batch_size, seq_len, num_experts_per_tok)
    router_scores = jax.nn.sigmoid(weights.astype(jnp.float32))  # weights are top_k_weights here
    # Squeeze router_scores to (batch_size * seq_len, num_experts_per_tok)
    inputs_2d = inputs_2d * router_scores.reshape(bsz_times_seq_len, -1)

  flatten_selected_experts = jnp.ravel(selected_experts)
  if roll_to_expert_id is not None:
    flatten_selected_experts = (flatten_selected_experts - roll_to_expert_id) % self.num_experts
  sorted_selected_experts = jnp.argsort(flatten_selected_experts)
  # sort inputs for number of selected experts
  replicated_inputs_2d = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
  sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype(
      self.dtype
  )
  group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)
  # Return the experts for each sorted input.
  expert_indices = jnp.arange(self.num_experts)
  sorted_experts = jnp.repeat(
      expert_indices,
      repeats=group_size,
      total_repeat_length=flatten_selected_experts.shape[0],
  )
  return (
      sorted_inputs,
      sorted_selected_experts,
      weights,
      group_size,
      sorted_experts,
      lb_loss,
      bias_updates,
  )
[docs]
def unpermute(
    self,
    intermediate,
    sorted_selected_experts,
    weights,
    batch_size,
    sequence_length,
    use_custom_sort_vjp=True,
):
  """Restores token order after expert computation and merges each token's k outputs.

  Args:
    intermediate: Expert outputs in the expert-sorted order produced by `permute`.
    sorted_selected_experts: The sort permutation returned by `permute`.
    weights: Top-k combine weights; reshaped to (tokens, num_experts_per_tok).
    batch_size: Batch size of the output.
    sequence_length: Sequence length of the output.
    use_custom_sort_vjp: Forwarded to `_sort_activations`.

  Returns:
    `(batch_size, sequence_length, emb)` weighted sum over each token's
    selected experts, cast to `self.dtype`.
  """
  inverse_order = jnp.argsort(sorted_selected_experts)
  token_major = _sort_activations(intermediate, inverse_order, use_custom_sort_vjp)
  combine_weights = jnp.reshape(weights, (-1, self.num_experts_per_tok))
  per_token_outputs = jnp.reshape(
      token_major,
      (combine_weights.shape[0], self.num_experts_per_tok, -1),
  )
  with jax.named_scope("weight_sum"):
    precision = jax.lax.Precision(self.config.matmul_precision)
    if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
      # Llama4 applies the router scores to the inputs in `permute`, so the
      # selected experts are combined with unit weights here.
      combine_weights = jnp.ones_like(combine_weights)
    if self.config.float32_weight_sum:
      per_token_outputs = per_token_outputs.astype(jnp.float32)
      combine_weights = combine_weights.astype(jnp.float32)
    combined = jnp.einsum(
        "BKE,BK -> BE",
        per_token_outputs,
        combine_weights,
        precision=precision,
    )
  return combined.reshape(batch_size, sequence_length, -1).astype(self.dtype)
[docs]
@staticmethod
def local_permute(
    inputs,
    global_group_sizes,
    local_expert_size,
    shard_index,
    is_offset=False,
    global_sorted_experts=None,
    use_custom_sort_vjp=True,
):
  """Regroups the tokens held by one expert shard by local expert id.

  Prepares the tokens for the experts on the current shard. Tokens are grouped
  by their local expert index (0 to local_expert_size - 1).

  Args:
    inputs: Tokens assigned to the experts on this shard, shape `[tokens, emb_dim]`.
    global_group_sizes: Token-assignment counts for every global expert across
      all batch shards, shape `[num_batch_shards, num_experts]`.
    local_expert_size: Number of experts handled by the current shard.
    shard_index: Index of the current expert shard (0 to
      num_expert_parallelism - 1).
    is_offset: If True, `inputs` are assumed pre-sorted by global expert id, and
      only the rows for this shard's experts are selected; all other rows sort
      to the end. If False, the rows for this shard's experts are assumed to
      start at the beginning of `inputs` but still need permuting by expert id.
    global_sorted_experts: Global expert ids of `inputs`, used when `is_offset`
      is True. Shape `[total_tokens_for_this_shard]`.
    use_custom_sort_vjp: Forwarded to `_sort_activations`.

  Returns:
    A tuple `(sorted_inputs, sorted_indices, local_group_size, sorted_experts_ids)`:
      - sorted_inputs: `inputs` permuted by local expert id.
      - sorted_indices: Indices used to permute the inputs.
      - local_group_size: Number of tokens assigned to each local expert on this
        shard.
      - sorted_experts_ids: Local expert id of each row of `sorted_inputs`.
  """
  # Per-batch-shard counts for the experts owned by this shard:
  # shape [num_batch_shards, local_expert_size].
  shard_counts = jax.lax.dynamic_slice_in_dim(
      global_group_sizes,
      shard_index * local_expert_size,
      local_expert_size,
      axis=1,
  )
  # Every batch shard sends its contributions to this expert shard, so the
  # per-local-expert total is the sum over batch shards.
  local_group_size = jnp.sum(shard_counts, axis=0)

  if is_offset:
    # `inputs` are replicated across expert shards and pre-sorted by global
    # expert id (via permute()). The data for this shard therefore does not
    # start at row 0. Rows belonging to other shards get the out-of-range id
    # `local_expert_size`, so the ascending argsort below moves them to the end.
    owning_shard = jnp.floor_divide(global_sorted_experts, local_expert_size)
    expert_indices = jnp.where(
        owning_shard == shard_index,
        jnp.mod(global_sorted_experts, local_expert_size),
        local_expert_size,
    )
  else:
    # `inputs` were received from the batch shards as consecutive runs, one per
    # (batch shard, local expert) pair in row-major order; label each run.
    flat_counts = shard_counts.reshape(-1)
    run_expert_ids = jnp.mod(jnp.arange(flat_counts.shape[0]), local_expert_size)
    expert_indices = jnp.repeat(run_expert_ids, flat_counts, total_repeat_length=inputs.shape[0])

  sorted_indices = jnp.argsort(expert_indices)
  sorted_inputs = _sort_activations(inputs, sorted_indices, use_custom_sort_vjp)
  sorted_experts_ids = expert_indices[sorted_indices]
  return (sorted_inputs, sorted_indices, local_group_size, sorted_experts_ids)
[docs]
@staticmethod
def get_all_to_all_params(
    all_shards_group_sizes,
    shard_id,
    num_expert_parallelism,
    is_batch_sharded=True,
):
  """Generates input offsets, send sizes, output offsets, and receive sizes used for ragged_all_to_all.

  Args:
    all_shards_group_sizes: When `is_batch_sharded` is True, a 2-D array whose
      row i holds the number of rows shard i sends to each shard j (so row
      `shard_id` is what this shard sends and column `shard_id` is what it
      receives). When False, a 1-D array whose entry j is the size of the group
      handled by expert shard j.
    shard_id: Index of the current shard.
    num_expert_parallelism: Number of expert-parallel shards.
    is_batch_sharded: Whether the batch is sharded across the expert shards.

  Returns:
    `(input_offsets, send_sizes, output_offsets, recv_sizes)`, in the argument
    order expected by `jax.lax.ragged_all_to_all`.
  """
  sizes = all_shards_group_sizes

  if is_batch_sharded:
    # What this shard sends to every peer; the chunks sit back-to-back in the
    # local buffer starting from row 0.
    send_sizes = sizes[shard_id]
    input_offsets = jnp.concatenate((jnp.array([0]), jnp.cumsum(send_sizes)[:-1]))
    # In each destination, the chunks from lower-indexed senders come first, so
    # our write offset is the column-wise running total over preceding rows.
    leading_zeros = jnp.zeros((1,) + sizes.shape[1:], dtype=sizes.dtype)
    running_totals = jnp.cumsum(jnp.concatenate((leading_zeros, sizes), axis=0), axis=0, dtype=sizes.dtype)
    output_offsets = running_totals[shard_id]
    # What every peer sends to this shard.
    recv_sizes = sizes[:, shard_id]
  else:
    # The batch is not sharded (a single batch shard, `sizes` is 1-D): every
    # shard sends the same locally processed slice, which starts at row 0, to
    # all peers.
    input_offsets = jnp.zeros(num_expert_parallelism, dtype=sizes.dtype)
    send_sizes = jnp.repeat(sizes[shard_id], num_expert_parallelism)
    # Each peer writes our slice at the start of the group this shard owns.
    group_starts = jnp.concatenate((jnp.array([0]), jnp.cumsum(sizes[:-1])))
    output_offsets = jnp.repeat(group_starts[shard_id], num_expert_parallelism)
    # From each peer we receive the full group sizes.
    recv_sizes = sizes

  return input_offsets, send_sizes, output_offsets, recv_sizes
[docs]
@staticmethod
def get_ragged_buffer_size(local_batch, ep_degree, global_experts, top_k, ragged_buffer_factor):
  """Calculates the token batch size of the ragged buffer.

  When explicitly setting ragged_buffer_factor>0, this is balanced_size * ragged_buffer_factor, which can drop tokens.
  Otherwise this will be worst case size to ensure no dropping.

  Inputs:
    local_batch: local token batch (batch*seq blown up by top_k) shard on this device (e.g. inside shard_map)
    ep_degree: degree of expert parallelism, generally equal to ici_expert_parallelism
    global_experts: unsharded expert count, e.g. 256 for deepseek
    top_k: aka num_experts_per_tok, 8 for deepseek.
    ragged_buffer_factor: When set > 0, the buffer is balanced_size * ragged_buffer_factor.
      The value 1.0 will be dropless only in the perfectly balanced case, else tokens will be dropped.

  Outputs:
    The ragged buffer's token batch size (an int).
  """
  balanced_size = local_batch
  if ragged_buffer_factor > 0.0:
    # This will drop tokens if the true distribution exceeds this buffer.
    return int(balanced_size * ragged_buffer_factor)

  # Worst case.
  # Either determined by degree of EP, or can be less when num_local_exp is smaller than top_k:
  # Example: If we have 4 EP shards, top_k=8, and experts=256 (deepseek), then worst case is
  # all tokens in our EP replica get routed to a single shard, e.g. rank 0 - thus is |EP|=4x larger than perfectly
  # balanced. However if we use EP=128, then there are only 256/128 = 2 local experts, and thus at most in an EP
  # replica group only the 2 experts of top_k=8 can be chosen, so at most 1/4 of all tokens goes to the most
  # popular shard. Thus the imbalance factor goes like |EP|/(top_k/local_exp) = 128/4 = 32.
  # In general for local_experts < top_k (e.g. |EP|>32), the balance will go as
  # EP * local_experts / top_k = EP * (global_exp/EP) / top_k = global_exp / top_k.
  # This is constant as a function of the model - e.g. for deepseek the imbalance is never worse than
  # 256 exp / 8 top_k = 32. In practice the imbalance should be much less and potentially can use
  # ragged_buffer_factor set to >1 e.g. 3.0, and likely have no dropping (not guaranteed).
  #
  # Computed as min(balanced * EP, ceil(balanced * global_exp / top_k)) in exact
  # integer arithmetic. The float form int(balanced * (global_exp / top_k)) can
  # land just below an exact integer (e.g. 49 * (64 / 49) == 63.99999999999999)
  # and truncate, under-sizing the buffer that must never drop tokens.
  ep_bound = balanced_size * ep_degree
  experts_bound = -((-balanced_size * global_experts) // top_k)  # ceil division
  return int(min(ep_bound, experts_bound))
[docs]
def sparse_matmul(
    self,
    inputs,
    gate_logits,
    pre_bias_logits,
    w0_kernel,
    w1_kernel,
    wo_kernel,
    w0_bias,
    w1_bias,
    wo_bias,
):
  """Perform sparse matrix multiplication of inputs and Experts.

  Routes tokens to experts with `self.permute`. When expert parallelism > 1
  (and ring-of-experts is off), it exchanges tokens between expert shards with
  `jax.lax.ragged_all_to_all`. It then applies the w0/w1 projections, the FFN
  activation and the wo projection as grouped matmuls, and restores the
  original token order with `self.unpermute`. The per-shard work runs inside
  `jax.shard_map`.

  Args:
    inputs: Token activations, unpacked as (batch, sequence, embed) inside the
      shard_map.
    gate_logits: Router logits for each token.
    pre_bias_logits: Router logits before bias. Given a sharding spec only for
      DeepSeek v3 models; None otherwise.
    w0_kernel: First input-projection expert kernel, [Experts, In, Hidden];
      may be an `aqt.QTensor`.
    w1_kernel: Second input-projection expert kernel, [Experts, In, Hidden];
      may be an `aqt.QTensor`.
    wo_kernel: Output-projection expert kernel, [Experts, Hidden, Out]; may be
      an `aqt.QTensor`.
    w0_bias: Optional bias for w0 (applied when `config.mlp_bias`).
    w1_bias: Optional bias for w1 (applied when `config.mlp_bias`).
    wo_bias: Optional bias for wo (applied when `config.mlp_bias`).

  Returns:
    Tuple `(output, lb_loss, bias_updates)`; `lb_loss` and `bias_updates` are
    forwarded from `self.permute`.
  """

  def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, padding_amount):
    """Execute jax.lax.ragged_dot, with potential quantization"""
    m, k, n = inputs.shape[0], inputs.shape[1], kernel.shape[2]
    tiling = (
        min(tiling[0], m),
        min(tiling[1], k),
        min(tiling[2], n),
    )
    rhs_inputs = kernel
    if isinstance(kernel, aqt.QTensor):
      if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
        raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
      rhs_inputs = kernel.qvalue
    if self.config.use_qwix_quantization:
      # Use full contraction for QWIX quantization to allow quantization
      # fusion (max reduce over contracting dimension).
      tiling = (tiling[0], k, tiling[2])
    # Compare the device's platform string. Comparing the Device object itself
    # to "tpu" is always False, which silently disabled the TPU path below.
    is_tpu = self.mesh.devices.flat[0].platform == "tpu"
    # TPU needs random mosaic_fusion_group; GPU/CPU needs deterministic ID for autotuner sync
    mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else "0"
    with set_xla_metadata(
        ragged_dot_tiling=",".join([str(t) for t in tiling]),
        mosaic_fusion_group=mosaic_group_id,
    ):
      output = jax.lax.ragged_dot(
          lhs=inputs,
          rhs=rhs_inputs,
          group_sizes=group_sizes,
          preferred_element_type=self.dtype,
      )
    if isinstance(kernel, aqt.QTensor):
      # Multiply outputs by the kernel scale, gathered per row by expert id.
      scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0)
      if padding_amount > 0:
        # Inputs were padded in gmm(); pad the scales to the same row count.
        scales = jax.lax.pad(
            scales,
            jnp.array(0.0, dtype=scales.dtype),
            [(0, padding_amount, 0), (0, 0, 0)],
        )
      output *= scales
    return output

  def get_tokamax_group_sizes(group_sizes, inputs, kernel):
    # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
    if self.config.use_qwix_quantization or (
        self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat
    ):
      return group_sizes
    elif self.config.attention == "vllm_rpa":
      return group_sizes
    else:
      return tokamax.RaggedDotGroupSizes(
          group_sizes,
          (inputs.shape[0] // kernel.shape[0],) * kernel.shape[0],
      )

  def get_quantization_dtypes():
    lhs_quantize_dtype, rhs_quantize_dtype = None, None
    if self.quant is not None:
      quant_dg = self.quant.quant_dg
      lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype()
      rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
    return lhs_quantize_dtype, rhs_quantize_dtype

  def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
    if inputs.shape[0] != expert_assignments.shape[0]:
      raise ValueError("The number of input tokens must match the number of expert assignments!")
    tokamax_group_sizes = get_tokamax_group_sizes(group_sizes, inputs, kernel)
    orig_inputs_shape = inputs.shape  # save shape of inputs before potentially padding.
    inputs, padding_amount = max_utils.maybe_pad(inputs, self.config.wi_tile_fwd_batch_seq)
    inputs = inputs.astype(self.dtype)
    kernel = kernel.astype(self.dtype)

    lhs_quantize_dtype, rhs_quantize_dtype = get_quantization_dtypes()
    # We support three implementations for gmm - tokamax, older forked kernel, or jax.lax.ragged_dot
    # For quantized tokamax we call a forked version that supports our quantization recipes.
    if self.config.use_tokamax_gmm:
      if self.config.quantization:  # tokamax (quantized)
        output = mblx.gmm(
            lhs=inputs,
            rhs=kernel,
            group_sizes=group_sizes,
            preferred_element_type=self.dtype,
            tiling=tiling,
            lhs_quantize_dtype=lhs_quantize_dtype,
            rhs_quantize_dtype=rhs_quantize_dtype,
            use_qwix_quantization=self.config.use_qwix_quantization,
            use_tokamax_backend=self.config.use_tokamax_gmm,
            weight_gather_axes=weight_gather_axes,
        )
      else:  # tokamax (unquantized)
        output = tokamax.ragged_dot(
            lhs=inputs,
            rhs=kernel,
            group_sizes=tokamax_group_sizes,
            precision=jax.lax.Precision.DEFAULT,
            preferred_element_type=self.dtype,
            implementation="mosaic",
        )
    elif self.config.megablox:  # Older forked megablox
      output = mblx.gmm(
          lhs=inputs,
          rhs=kernel,
          group_sizes=group_sizes,
          preferred_element_type=self.dtype,
          tiling=tiling,
          lhs_quantize_dtype=lhs_quantize_dtype,
          rhs_quantize_dtype=rhs_quantize_dtype,
          use_qwix_quantization=self.config.use_qwix_quantization,
          use_tokamax_backend=self.config.use_tokamax_gmm,
          weight_gather_axes=weight_gather_axes,
      )
    else:  # jax.lax.ragged_dot
      output = jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, padding_amount)
    if padding_amount > 0:
      # Drop the rows added by maybe_pad.
      output = output[: orig_inputs_shape[0]]
    return output

  def is_batch_sharded_by_ep(input_activation):
    # The batch is sharded by expert, except during inference decoding (where batch size == 1).
    # In the decoding case, the expert axis is instead replicated along the tensor's batch dimension.
    return input_activation.shape[0] > 1

  def explicitly_weight_ag(shard_exp_on_fsdp):
    if shard_exp_on_fsdp:
      quantization_rule = qpl.get_current_rule("gmm")
      if quantization_rule and quantization_rule.weight_calibration_method.startswith("fixed"):
        return True
    return False

  def maybe_aqt_partition(w0_kernel, w0_pspec, w1_kernel, w1_pspec, wo_kernel, wo_pspec):
    if isinstance(w0_kernel, aqt.QTensor):
      w0_pspec = aqt.partition_spec(w0_pspec, (1,), w0_kernel.dtype, use_bias=False)
    if isinstance(w1_kernel, aqt.QTensor):
      w1_pspec = aqt.partition_spec(w1_pspec, (1,), w1_kernel.dtype, use_bias=False)
    if isinstance(wo_kernel, aqt.QTensor):
      wo_pspec = aqt.partition_spec(wo_pspec, (1,), wo_kernel.dtype, use_bias=False)
    return w0_pspec, w1_pspec, wo_pspec

  def get_routed_moe_shardings(is_batch_sharded_by_expert):
    if is_batch_sharded_by_expert:
      batch_logical_axis = "activation_batch"
    else:
      batch_logical_axis = "decode_batch_moe"

    if self.get_tensor_transpose_parallelism_size() > 1:
      input_partition_pspec = self._logical_to_mesh_axes(
          (batch_logical_axis, "activation_norm_length", "activation_embed")
      )
      w0_bias_pspec = self._logical_to_mesh_axes(("exp", None))
      w1_bias_pspec = self._logical_to_mesh_axes(("exp", None))
      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
    else:
      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
      w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
      w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))

    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
    if self.config.model_name.startswith("deepseek3"):
      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
    else:
      # pre_bias_logits is None for non-DeepSeek v3 models
      pre_bias_logits_pspec = None

    # w0, w1, wo need to be unsharded on fsdp / fsdp_transpose axis, so use
    # mlp_no_fsdp axis
    if self.config.shard_exp_on_fsdp:
      quantization_rule = qpl.get_current_rule("gmm")
      if quantization_rule and quantization_rule.weight_calibration_method.startswith("fixed"):
        # special sharding when using static scaling for weights in quantization with shard_exp_on_fsdp
        w0_pspec = self._logical_to_mesh_axes(self.wi_kernel_axes)
        w1_pspec = self._logical_to_mesh_axes(self.wi_kernel_axes)
        wo_pspec = self._logical_to_mesh_axes(self.wo_kernel_axes)
      else:
        # special sharding for dsv3 to remove overhead between gmm/AG
        w0_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
        w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", None, "mlp_no_fsdp"))
        wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
    elif self.config.use_2d_fsdp_sharding:
      w0_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
      w1_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
      wo_pspec = self._logical_to_mesh_axes(("embed_tensor_transpose", "mlp_no_fsdp", None))
    else:
      # These are the main shardings used by default - they use funky rules to AG over FSDP.
      w0_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
      w1_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
      wo_pspec = self._logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
    return (
        batch_logical_axis,
        input_partition_pspec,
        gate_logits_pspec,
        pre_bias_logits_pspec,
        w0_pspec,
        w1_pspec,
        wo_pspec,
        w0_bias_pspec,
        w1_bias_pspec,
        wo_bias_pspec,
    )

  is_batch_sharded_by_expert = is_batch_sharded_by_ep(inputs)
  weight_gather = explicitly_weight_ag(self.config.shard_exp_on_fsdp)
  (
      batch_logical_axis,
      input_partition_pspec,
      gate_logits_pspec,
      pre_bias_logits_pspec,
      w0_pspec,
      w1_pspec,
      wo_pspec,
      w0_bias_pspec,
      w1_bias_pspec,
      wo_bias_pspec,
  ) = get_routed_moe_shardings(is_batch_sharded_by_expert)
  w0_pspec, w1_pspec, wo_pspec = maybe_aqt_partition(w0_kernel, w0_pspec, w1_kernel, w1_pspec, wo_kernel, wo_pspec)

  @functools.partial(
      jax.shard_map,
      mesh=self.mesh,
      in_specs=(
          input_partition_pspec,
          gate_logits_pspec,
          pre_bias_logits_pspec,
          w0_pspec,
          w1_pspec,
          wo_pspec,
          w0_bias_pspec,
          w1_bias_pspec,
          wo_bias_pspec,
          P(),  # Replicate the input key
      ),
      out_specs=(
          self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
          P(),  # Handle None or replicate the output
          P(),  # Handle None or replicate the output
      ),
      check_vma=False,
  )
  def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
    batch_size, sequence_length, _ = x.shape
    num_expert_parallelism = self.get_expert_parallelism_size()
    if num_expert_parallelism > 1:
      expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name)
    else:
      expert_shard_id = 0

    if self.config.use_ring_of_experts:
      # The ring-of-experts strategy first duplicates the inputs to all
      # expert shards, and then routes within each shard.

      # Duplicate inputs to all expert shards.
      x, logits, pre_bias_logits = tuple(
          jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
          for z in (x, logits, pre_bias_logits)
      )

      # "Route" tokens within each shard.
      num_experts_per_shard = self.config.num_experts // num_expert_parallelism
      x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
          x,
          logits,
          pre_bias_logits,
          self.config.use_custom_sort_vjp,
          roll_to_expert_id=num_experts_per_shard * expert_shard_id,
          rngs=rngs,
      )

      # Filter down to the group sizes that apply to only the experts in the
      # current shard.
      group_sizes = group_sizes[:num_experts_per_shard]
      mask = jnp.arange(x.shape[0]) < jnp.sum(group_sizes)
      x = jnp.where(mask[:, None], x, 0)
    else:
      x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute(
          x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs
      )

      if num_expert_parallelism > 1:
        batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
        # get group sizes for all shards
        local_expert_size = self.config.num_experts // num_expert_parallelism
        reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
        global_group_sizes = group_sizes
        if is_batch_sharded_by_expert:
          all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis)
          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
              all_shards_group_sizes,
              expert_shard_id,
              num_expert_parallelism,
          )
          buffer_size = self.get_ragged_buffer_size(
              jnp.shape(x)[0],
              num_expert_parallelism,
              self.config.num_experts,
              self.config.num_experts_per_tok,
              self.config.ragged_buffer_factor,
          )
          output_shape = jax.lax.empty((buffer_size, self.moe_expert_input_dim), dtype=x.dtype)

          x = jax.lax.ragged_all_to_all(
              x,
              output_shape,
              input_offsets,
              send_sizes,
              output_offsets,
              recv_sizes,
              axis_name=self._expert_parallelism_name,
          )
          global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
          x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
              x,
              global_group_sizes,
              local_expert_size,
              shard_index=expert_shard_id,
              use_custom_sort_vjp=self.config.use_custom_sort_vjp,
          )
        else:
          x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
              x,
              global_group_sizes[None, :],
              local_expert_size,
              shard_index=expert_shard_id,
              is_offset=True,
              global_sorted_experts=selected_experts,
              use_custom_sort_vjp=self.config.use_custom_sort_vjp,
          )

    if self.config.mlp_bias:
      w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)

    def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
      # Returns (mesh_axis, tensor_dim) pairs for mesh axes of size > 1.
      if pspec_dim_axes is None:
        return []
      axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
      active = []
      for ax in axes:
        if ax and self.mesh.shape.get(ax, 1) > 1:
          active.append((ax, tensor_dim_index))
      return active

    wi_gather_axes = []
    wo_gather_axes = []

    if weight_gather:
      # wi [Experts, In, Hidden] -> Gather Exp(0) and Hidden(2)
      wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[0], 0))
      wi_gather_axes.extend(get_active_sharding_axes(w0_pspec[2], 2))

      # wo [Experts, Hidden, Out] -> Gather Exp(0) and Hidden(1)
      wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0))
      wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1))

    gmm_fn = functools.partial(
        gmm,
        group_sizes=group_sizes,
        expert_assignments=selected_experts,
    )
    wi_tile_size = (
        self.config.wi_tile_fwd_batch_seq,  # m (LHS batch)
        self.config.wi_tile_fwd_embed_dim,  # k (contracting)
        self.config.wi_tile_fwd_mlp_dim,  # n (RHS batch)
        self.config.wi_tile_dlhs_batch_seq,  # m (LHS batch)
        self.config.wi_tile_dlhs_mlp_dim,  # k (contracting)
        self.config.wi_tile_dlhs_embed_dim,  # n (RHS batch)
        self.config.wi_tile_drhs_batch_seq,  # Called m in megablox, but this is contracting
        self.config.wi_tile_drhs_embed_dim,  # Called k in megablox, but this is LHS batch dim
        self.config.wi_tile_drhs_mlp_dim,  # Called n in megablox, and indeed is RHS batch dim
    )
    wo_tile_size = (
        self.config.wo_tile_fwd_batch_seq,  # m (LHS batch)
        self.config.wo_tile_fwd_mlp_dim,  # k (contracting)
        self.config.wo_tile_fwd_embed_dim,  # n (RHS batch)
        self.config.wo_tile_dlhs_batch_seq,  # m (LHS batch)
        self.config.wo_tile_dlhs_embed_dim,  # k (contracting)
        self.config.wo_tile_dlhs_mlp_dim,  # n (RHS)
        self.config.wo_tile_drhs_batch_seq,  # Called m in megablox, but this is contracting
        self.config.wo_tile_drhs_mlp_dim,  # Called k in megablox, but this is LHS batch dim
        self.config.wo_tile_drhs_embed_dim,  # Called n in megablox, and indeed is the RHS batch dim
    )

    layer_w0 = gmm_fn(
        x,
        w0,
        tiling=wi_tile_size,
        weight_gather_axes=wi_gather_axes,
    )
    if self.get_tensor_transpose_parallelism_size() > 1:
      layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
    if self.config.mlp_bias:
      layer_w0 = layer_w0 + w0_bias
    layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

    layer_w1 = gmm_fn(
        x,
        w1,
        tiling=wi_tile_size,
        weight_gather_axes=wi_gather_axes,
    )
    if self.get_tensor_transpose_parallelism_size() > 1:
      layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
    if self.config.mlp_bias:
      layer_w1 = layer_w1 + w1_bias
    layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")

    intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)

    intermediate_output = gmm_fn(
        intermediate_layer,
        wo,
        tiling=wo_tile_size,
        weight_gather_axes=wo_gather_axes,
    )
    if self.get_tensor_parallelism_size() > 1:
      intermediate_output = jax.lax.psum_scatter(
          intermediate_output, self._tensor_parallelism_name, scatter_dimension=1, tiled=True
      )
    if self.config.mlp_bias:
      intermediate_output = intermediate_output + wo_bias
    intermediate_output = adc.checkpoint_name(intermediate_output, "moe_mlpwo")

    if self.config.use_ring_of_experts:
      # Set the outputs of tokens which were not processed to 0.
      mask = jnp.arange(intermediate_output.shape[0]) < jnp.sum(group_sizes)
      intermediate_output = jnp.where(mask[:, None], intermediate_output, 0)

      # Unsort and deduplicate the outputs locally.
      output = self.unpermute(
          intermediate_output,
          sorted_selected_experts,
          weights,
          batch_size=batch_size,
          sequence_length=sequence_length,
          use_custom_sort_vjp=self.config.use_custom_sort_vjp,
      )

      # Sum up the partial outputs across the expert shards.
      output = jnp.reshape(
          output, (-1, sequence_length, self.moe_expert_input_dim // self.get_tensor_parallelism_size())
      )
      output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)

    else:
      if num_expert_parallelism > 1:
        original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok
        if sorted_selected_experts.shape[0] != original_inputs_first_dim:
          raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!")
        output_shape = jax.lax.empty(
            (
                original_inputs_first_dim,
                self.moe_expert_input_dim // self.get_tensor_parallelism_size(),
            ),
            dtype=intermediate_output.dtype,
        )
        if is_batch_sharded_by_expert:
          # locally unpermute back to the original order
          local_output = _sort_activations(
              intermediate_output,
              jnp.argsort(local_sorted_indices),  # pylint: disable=undefined-variable
              self.config.use_custom_sort_vjp,
          )
          # Reverse exchange: the transposed size matrix swaps senders and receivers.
          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
              jnp.transpose(all_shards_group_sizes),  # pylint: disable=undefined-variable
              expert_shard_id,
              num_expert_parallelism,
          )
          intermediate_output = jax.lax.ragged_all_to_all(
              local_output,
              output_shape,
              input_offsets,
              send_sizes,
              output_offsets,
              recv_sizes,
              axis_name=self._expert_parallelism_name,
          )
        else:
          # If batch is replicated across EP shards then each shard should send
          # 0..local_shard_size data to the other shards and receive the
          # local_shard data from all of the other shards using
          # ragged_all_to_all.
          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
              reshaped_group_sizes,  # pylint: disable=undefined-variable
              expert_shard_id,
              num_expert_parallelism,
              is_batch_sharded=False,
          )
          intermediate_output = jax.lax.ragged_all_to_all(
              intermediate_output,
              output_shape,
              input_offsets,
              send_sizes,
              output_offsets,
              recv_sizes,
              axis_name=self._expert_parallelism_name,
          )

      output = self.unpermute(
          intermediate_output,
          sorted_selected_experts,
          weights,
          batch_size=batch_size,
          sequence_length=sequence_length,
          use_custom_sort_vjp=self.config.use_custom_sort_vjp,
      )

    return output, lb_loss, bias_updates

  if self.config.moe_fsdp_use_two_stage_all_gather:
    # Unshard on fsdp axis
    w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))
    w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp"))

    # Unshard on fsdp_transpose axis
    wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose"))

    # Make sure XLA does not optimize by combining above All-Gather to unshard
    # on FSDP axis and the subsequent unshard on fsdp_transpose axis
    w0_kernel = jax.lax.optimization_barrier(w0_kernel)
    w1_kernel = jax.lax.optimization_barrier(w1_kernel)
    wo_kernel = jax.lax.optimization_barrier(wo_kernel)

    # Unshard on both fsdp and fsdp_transpose transpose
    w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp"))
    w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp"))
    wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))

  if self.get_tensor_transpose_parallelism_size() > 1:
    input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")
  else:
    input_axes = (batch_logical_axis, "activation_norm_length", None)

  gate_logits_axes = (batch_logical_axis, "activation_norm_length", None)
  if self.config.model_name.startswith("deepseek3"):
    pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None)
  else:
    pre_bias_logits_axes = None

  inputs = self._maybe_shard_with_logical(inputs, input_axes)
  gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes)
  pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes)
  w0_kernel = self._maybe_shard_with_pspec(w0_kernel, w0_pspec)
  w1_kernel = self._maybe_shard_with_pspec(w1_kernel, w1_pspec)
  wo_kernel = self._maybe_shard_with_pspec(wo_kernel, wo_pspec)
  if w0_bias is not None:
    w0_bias = self._maybe_shard_with_pspec(w0_bias, w0_bias_pspec)
  if w1_bias is not None:
    w1_bias = self._maybe_shard_with_pspec(w1_bias, w1_bias_pspec)
  if wo_bias is not None:
    wo_bias = self._maybe_shard_with_pspec(wo_bias, wo_bias_pspec)
  return wrapper(
      inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs
  )
[docs]
def reshape_and_update_weights(self, weights, indices):
  """Scatters per-token top-k routing weights into a dense per-expert tensor.

  Args:
    weights: Routing weights, (batch_size, seq_len, num_experts_per_tok).
    indices: Expert ids for each weight, same shape as `weights`.

  Returns:
    Array of shape (batch_size, seq_len, num_experts) in `self.dtype`, holding
    `weights` at the expert positions named by `indices` and zeros elsewhere.
  """
  batch_dim = weights.shape[0]
  seq_dim = weights.shape[1]
  dense_weights = jnp.zeros((batch_dim, seq_dim, self.num_experts), dtype=self.dtype)

  # Broadcastable batch/sequence coordinates that pair with `indices`.
  batch_ids = self._maybe_shard_with_logical(jnp.arange(batch_dim)[:, None, None], ("activation_batch", None, None))
  seq_ids = self._maybe_shard_with_logical(jnp.arange(seq_dim)[:, None], ("activation_length", None))

  if self.config.shard_mode == ShardMode.EXPLICIT:
    out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length", None))
  else:
    out_sharding = None

  return dense_weights.at[batch_ids, seq_ids, indices].set(weights, out_sharding=out_sharding)
[docs]
def get_context_partition_and_sub_seq(self, seq_len):
  """Splits a sequence length across context-autoregressive partitions.

  Args:
    seq_len: Sequence length to split.

  Returns:
    A tuple `(cp, sub_seq)` with `sub_seq == seq_len // cp`. `cp` is the
    context-autoregressive parallelism size. It falls back to 1 when that size
    does not evenly divide `seq_len`.
  """
  partitions = self.get_context_autoregressive_parallelism_size()
  if seq_len % partitions != 0:
    # Uneven split: route over the whole sequence as a single group.
    partitions = 1
  return partitions, seq_len // partitions
[docs]
def generate_masks_subgroup(self, top_k_indices, softmax_probs):
  """Subgroup mask generation for inference only.

  Splits the sequence axis into `cp` subgroups (see
  `get_context_partition_and_sub_seq`). Expert capacity is then enforced
  independently inside each subgroup, producing dispatch/combine masks for
  capacity-limited (token-dropping) expert routing.

  Args:
    top_k_indices: Selected expert ids, unpacked as (batch_size, seq_len, k).
      The reshapes below assume k == self.num_experts_per_tok.
    softmax_probs: Router probabilities, reshaped below to
      (batch_size, cp, sub_seq, num_experts); presumably
      (batch_size, seq_len, num_experts) on input -- TODO confirm.

  Returns:
    `(dispatch_mask, combine_mask)`, each of shape
    (batch_size, cp, sub_seq, num_experts, expert_capacity_per_batch).
    `combine_mask` holds the token's (capacity-masked) routing probability at
    the token's slot in that expert's capacity buffer. `dispatch_mask` is its
    boolean form.
  """
  # calculate
  # expert_capacity = (tokens_per_batch / num_experts) * capacity_factor
  batch_size, seq_len, _ = top_k_indices.shape
  cp, sub_seq = self.get_context_partition_and_sub_seq(seq_len)

  # Break sequence into subsequences (groups) of tokens, and route only within
  # each group.
  top_k_indices = jnp.reshape(top_k_indices, (batch_size, cp, sub_seq, top_k_indices.shape[2]))

  tokens_per_batch = sub_seq * self.num_experts_per_tok
  # this is to avoid expert_capacity_per_batch = 0
  expert_capacity_per_batch = int(
      max(
          math.ceil(tokens_per_batch / self.num_experts) * self.config.capacity_factor,
          self.config.capacity_factor,
      )
  )
  max_logging.log("Applying potential token dropping with a batch expert_capacity of" f" {expert_capacity_per_batch}")

  # calculate expert mask and drop tokens if needed
  # expert_mask: (batch, cp, sub_seq, num_experts_per_tok, num_experts) one-hot;
  # it is fused over (sub_seq, num_experts_per_tok) so the cumsum counts
  # arrivals per expert in token order within each subgroup.
  #
  # A small example:
  # give num_experts=4 & num_experts_per_tok=2, and two tokens are routed to
  # expert [0, 1] & [1, 3],
  # then expert_mask becomes
  # [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 1, 0, 0],[0, 0, 0, 1]]]],
  # after cumsum, expert_token_count becomes
  # [[[[1, 0, 0, 0],[1, 1, 0, 0]], [[1, 2, 0, 0],[1, 2, 0, 1]]]],
  # if we set expert_capacity=1,
  # trunc_expert_mask becomes
  # [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]],
  # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped. Summing over
  # the top-k axis (as done below) gives combined_expert_mask
  # [[[1, 1, 0, 0], [0, 0, 0, 1]]] (the cp/sub_seq axes are omitted in this example).
  expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32)
  expert_mask_fused = jnp.reshape(
      expert_mask,
      (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts),
  )
  expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None, None))
  # Running 1-based count of assignments to each expert within the subgroup.
  expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2)
  expert_token_count = jnp.reshape(
      expert_token_count_fused,
      ((batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts)),
  )
  expert_token_count = self._maybe_shard_with_logical(
      expert_token_count,
      ("activation_batch", "activation_norm_length", None, None, None),
  )
  # Keep an assignment only while the expert's count is within capacity.
  trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
  # (batch, cp, sub_seq, num_experts): 1 where the token is kept for that expert.
  combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3)

  # reshape & update weights
  softmax_probs = jnp.reshape(
      softmax_probs,
      ((batch_size, cp, sub_seq, self.num_experts)),
  )
  # Zero the probability of unrouted and dropped (token, expert) pairs.
  softmax_probs *= combined_expert_mask

  # calculate token position in expert capacity dimension
  expert_token_position_fused = expert_mask_fused * expert_token_count_fused
  expert_token_position = jnp.reshape(
      expert_token_position_fused,
      (batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts),
  )
  # 1-based slot for kept assignments; 0 for unrouted/dropped ones.
  combined_expert_token_position = jnp.sum(expert_token_position, axis=3) * combined_expert_mask
  expert_token_position_in_capacity = jax.nn.one_hot(
      combined_expert_token_position,
      num_classes=expert_capacity_per_batch + 1,
      dtype=jnp.int32,
  )

  # shape of combine_mask is
  # (batch_size, cp, sub_seq, num_experts, expert_capacity_per_batch + 1),
  # and cut 0-dimension which is always 0 (slot 0 only marks unrouted/dropped
  # pairs, whose probability was zeroed above)
  combine_mask = softmax_probs[..., None] * expert_token_position_in_capacity
  combine_mask = combine_mask[..., 1:]
  dispatch_mask = combine_mask.astype(bool)

  # ici_context_parallelism
  # NOTE(review): combine_mask (and hence dispatch_mask) already has this shape
  # at this point, so these reshapes appear to be no-ops.
  dispatch_mask = jnp.reshape(
      dispatch_mask,
      (batch_size, cp, sub_seq, self.num_experts, expert_capacity_per_batch),
  )
  combine_mask = jnp.reshape(
      combine_mask,
      (batch_size, cp, sub_seq, self.num_experts, expert_capacity_per_batch),
  )

  return dispatch_mask, combine_mask
[docs]
def generate_masks(self, top_k_indices, softmax_probs):
  """Generate capacity-limited dispatch and combine masks for expert routing.

  Each (batch row, expert) pair can accept at most `expert_capacity_per_batch`
  assignments. Later arrivals (in token order, then top-k order) are dropped.

  Args:
    top_k_indices: Selected expert ids, unpacked as (batch_size, seq_len, k).
      The reshapes below assume k == self.num_experts_per_tok.
    softmax_probs: Router probabilities; multiplied elementwise by a
      (batch_size, seq_len, num_experts) mask, so presumably that shape --
      TODO confirm.

  Returns:
    `(dispatch_mask, combine_mask)`, each of shape
    (batch_size, seq_len, num_experts, expert_capacity_per_batch).
    `combine_mask` holds the token's (capacity-masked) routing probability at
    the token's slot in that expert's capacity buffer. `dispatch_mask` is its
    boolean form.
  """
  # calculate
  # expert_capacity = (tokens_per_batch / num_experts) * capacity_factor
  batch_size, seq_len, _ = top_k_indices.shape

  tokens_per_batch = seq_len * self.num_experts_per_tok
  # this is to avoid expert_capacity_per_batch = 0
  expert_capacity_per_batch = int(
      max(
          math.ceil(tokens_per_batch / self.num_experts) * self.config.capacity_factor,
          self.config.capacity_factor,
      )
  )
  max_logging.log("Applying potential token dropping with a batch expert_capacity of" f" {expert_capacity_per_batch}")

  # calculate expert mask and drop tokens if needed
  # expert_mask: (batch, sequence, num_experts_per_tok, num_experts) one-hot;
  # it is fused over (sequence, num_experts_per_tok) so the cumsum counts
  # arrivals per expert in token order.
  #
  # A small example:
  # give num_experts=4 & num_experts_per_tok=2, and two tokens are routed to
  # expert [0, 1] & [1, 3],
  # then expert_mask becomes
  # [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 1, 0, 0],[0, 0, 0, 1]]]],
  # after cumsum, expert_token_count becomes
  # [[[[1, 0, 0, 0],[1, 1, 0, 0]], [[1, 2, 0, 0],[1, 2, 0, 1]]]],
  # if we set expert_capacity=1,
  # trunc_expert_mask becomes
  # [[[[1, 0, 0, 0],[0, 1, 0, 0]], [[0, 0, 0, 0],[0, 0, 0, 1]]]],
  # so the 2nd token for expert #1 ([0, 1] & [1, 3]) is dropped. Summing over
  # the top-k axis (as done below) gives combined_expert_mask
  # [[[1, 1, 0, 0], [0, 0, 0, 1]]].
  expert_mask = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32)
  expert_mask_fused = jnp.reshape(
      expert_mask,
      (batch_size, seq_len * self.num_experts_per_tok, self.num_experts),
  )
  expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None))
  # Running 1-based count of assignments to each expert within the batch row.
  expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1)
  expert_token_count = jnp.reshape(
      expert_token_count_fused,
      ((batch_size, seq_len, self.num_experts_per_tok, self.num_experts)),
  )
  expert_token_count = self._maybe_shard_with_logical(
      expert_token_count,
      ("activation_batch", "activation_norm_length", None, None),
  )
  # Keep an assignment only while the expert's count is within capacity.
  trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
  # (batch, seq_len, num_experts): 1 where the token is kept for that expert.
  combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2)

  # Zero the probability of unrouted and dropped (token, expert) pairs.
  softmax_probs *= combined_expert_mask

  # calculate token position in expert capacity dimension
  expert_token_position_fused = expert_mask_fused * expert_token_count_fused
  expert_token_position = jnp.reshape(
      expert_token_position_fused,
      (batch_size, seq_len, self.num_experts_per_tok, self.num_experts),
  )
  # 1-based slot for kept assignments; 0 for unrouted/dropped ones.
  combined_expert_token_position = jnp.sum(expert_token_position, axis=2) * combined_expert_mask
  expert_token_position_in_capacity = jax.nn.one_hot(
      combined_expert_token_position,
      num_classes=expert_capacity_per_batch + 1,
      dtype=jnp.int32,
  )

  # shape of combine_mask is
  # (batch_size, seq_len, num_experts, expert_capacity_per_batch + 1),
  # and cut 0-dimension which is always 0 (slot 0 only marks unrouted/dropped
  # pairs, whose probability was zeroed above)
  combine_mask = softmax_probs[..., None] * expert_token_position_in_capacity
  combine_mask = combine_mask[..., 1:]
  dispatch_mask = combine_mask.astype(bool)

  return dispatch_mask, combine_mask
# See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details.
[docs]
def load_balance_loss(self, top_k_indices, logits) -> jax.Array:
  """Auxiliary load-balancing loss in the style of Switch Transformer.

  Args:
    top_k_indices: Integer expert ids selected per token. The one-hot encoding
      is reduced over axis 2, so this is presumably (batch, seq, top_k).
    logits: Per-expert routing probabilities (the caller passes softmax
      outputs); averaged over axis 1.

  Returns:
    mean(token_fraction * prob_fraction) * num_experts**2 * load_balance_loss_weight.
  """
  choice_one_hot = jax.nn.one_hot(top_k_indices, num_classes=self.num_experts, dtype=jnp.int32)
  # How many of its top-k slots each token gave to each expert.
  per_token_choices = jnp.sum(choice_one_hot, axis=2)
  # Fraction of tokens dispatched to each expert (averaged over axis 1).
  token_fraction = jnp.mean(per_token_choices, axis=1)
  # Average router probability mass allocated to each expert.
  prob_fraction = jnp.mean(logits, axis=1)
  balance = jnp.mean(token_fraction * prob_fraction)
  # Keep the original left-to-right multiplication order for identical rounding.
  return balance * (self.num_experts**2) * self.config.load_balance_loss_weight
[docs]
def get_einsum(
self,
rhs_mesh_axes: Tuple[Optional[str], ...] = (),
einsum_name: str | None = None,
):
"""Get the Einstein summation."""
# the check is to prevent aqteinsum as einsum op for dispatch and combine
# einsums in ase when capacity_factor > 0
# this is necessary to load pre-quantized weights in case of inference
if self.config.model_call_mode == "inference" and einsum_name in (
DISPATCH,
COMBINE,
):
return jnp.einsum
if self.quant:
def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument
# simply skip kwargs, since aqt einsum doesn't support any kwargs
# like precision
is_aqt = not isinstance(self.quant, quantizations.Fp8Quantization)
kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype}
return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error
einsum_op = aqt_einsum
else:
einsum_op = jnp.einsum
return einsum_op
[docs]
def maybe_all_gather_kernel_weight_in_expert_parallelism(
    self, kernel: jax.Array, kernel_axes: Tuple[Optional[str], ...]
):
  """Reshards an expert kernel to `kernel_axes`, but only under expert parallelism.

  Args:
    kernel: The expert weight tensor.
    kernel_axes: Logical axes to shard `kernel` with.

  Returns:
    `kernel` unchanged when the expert-parallelism size is <= 1; otherwise the
    result of `self._maybe_shard_with_logical(kernel, kernel_axes)`.
  """
  if self.get_expert_parallelism_size() <= 1:
    # Without expert parallelism, leave communication decisions to the compiler.
    return kernel
  # Forcing the sharding here triggers an all-gather in the kernel's own dtype
  # (e.g. int8 when quantized) rather than weight_dtype; limit it to expert
  # parallelism, where it is actually needed.
  return self._maybe_shard_with_logical(kernel, kernel_axes)
[docs]
def dense_matmul(
    self,
    inputs,
    gate_logits,
    pre_bias_logits,
    w0_kernel,
    w1_kernel,
    wo_kernel,
    w0_bias,
    w1_bias,
    wo_bias,
) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
  """Dense (einsum-based) MoE forward pass.

  Two strategies are used, depending on `config.capacity_factor`:

  * `capacity_factor > 0`: tokens are scattered into fixed-capacity expert
    buffers via one-hot dispatch/combine masks (see `generate_masks`; tokens
    over capacity are masked out). The expert MLPs run on those buffers, and
    the results are gathered back with the combine mask. Inference uses an
    extra context-partition axis N (einsums "BNSM,...").
  * otherwise: every token runs through every expert, and the per-expert
    outputs are summed with the sparse routing weights.

  Args:
    inputs: Token activations, indexed as (batch, seq, emb) below.
    gate_logits: Router logits, sharded as (batch, length, expert).
    pre_bias_logits: Router logits before bias; only sharded for "deepseek3*"
      models (None for non-DeepSeek v3 models).
    w0_kernel: Expert gate projection, contracted as "EMH".
    w1_kernel: Expert up projection, contracted as "EMH".
    wo_kernel: Expert down projection, contracted as "EHM".
    w0_bias: Per-expert (E, H) bias, used only when `config.mlp_bias`.
    w1_bias: Per-expert (E, H) bias, used only when `config.mlp_bias`.
    wo_bias: Per-expert (E, M) bias, used only when `config.mlp_bias`.

  Returns:
    `(output, lb_loss, bias_updates)`. `lb_loss` is None in inference mode or
    when `load_balance_loss_weight <= 0`. `bias_updates` is None unless
    `should_update_load_balance()` is true.
  """

  def expand_expert_bias(bias, target):
    """Inserts singleton axes so an (E, F) bias broadcasts against an (E, ..., F) target."""
    return jnp.expand_dims(bias, axis=tuple(range(1, target.ndim - 1)))

  # gate_logits: batch, length, expert
  gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch_moe", "activation_length_moe", None))
  if self.config.model_name.startswith("deepseek3"):
    # pre_bias_logits is None for non-DeepSeek v3 models
    pre_bias_logits = self._maybe_shard_with_logical(
        pre_bias_logits, ("activation_batch_moe", "activation_length_moe", None)
    )
  top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs)
  is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
  if is_llama4_decoder_layer:
    # Llama4 scales the inputs by sigmoid(router score) up front, so expert
    # outputs are combined with unit weights afterwards. Defining `weights` here
    # also covers the capacity (token-dropping) training path, which previously
    # referenced an unassigned `weights` for Llama4.
    router_scores = jax.nn.sigmoid(top_k_weights.astype(jnp.float32)).astype(self.dtype)
    inputs = inputs * router_scores
    weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
  else:
    weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)
  matmul_precision = jax.lax.Precision(self.config.matmul_precision)

  # Calculate load balance loss
  if self.config.model_call_mode != "inference":
    softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
    lb_loss = (
        self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None
    )
  else:
    lb_loss = None

  # Calculate routed bias updates (loss-free)
  if self.should_update_load_balance():
    bias_updates = calculate_load_balance_updates(
        top_k_indices, self.config.num_experts, self.config.routed_bias_update_rate
    )
  else:
    bias_updates = None

  batch_size = inputs.shape[0]
  seq_len = inputs.shape[1]
  cp, sub_seq = self.get_context_partition_and_sub_seq(seq_len)

  if self.config.capacity_factor > 0:
    # token dropping if needed
    if self.config.model_call_mode != "inference":
      dispatch_mask, combine_mask = self.generate_masks(top_k_indices, weights)
      mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None)
      dispatch_axis = (
          "activation_exp",
          "activation_batch_moe",
          None,
          "activation_embed_moe",
      )
      mlp_axis = (
          "activation_exp",
          "activation_batch_moe",
          None,
          "activation_mlp",
      )
      dispatch_einsum = "BSM,BSEC -> EBCM"
      mlp_up_einsum = "EBCM,EMH -> EBCH"
      mlp_down_einsum = "EBCH,EHM -> EBCM"
      output_einsum = "EBCM,BSEC -> BSM"
    else:
      # TODO(b/425930507): Try replacing `softmax_probs` with padded weights
      # and verify with decode acc tests.
      softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
      dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs)
      if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1:
        mask_axes = (
            "activation_norm_length_moe",
            "activation_batch_moe",
            None,
            None,
            None,
        )
        input_axis = (
            "activation_norm_length_moe",
            "activation_batch_moe",
            None,
            "activation_embed_moe",
        )
      else:
        mask_axes = (
            "activation_batch_moe",
            "activation_norm_length_moe",
            None,
            None,
            None,
        )
        input_axis = (
            "activation_batch_moe",
            "activation_norm_length_moe",
            None,
            "activation_embed_moe",
        )
      # The dispatch / MLP shardings are the same for both layouts above.
      dispatch_axis = (
          "activation_exp",
          "activation_batch_moe",
          None,
          None,
          "activation_embed_moe",
      )
      mlp_axis = (
          "activation_exp",
          "activation_batch_moe",
          None,
          None,
          "activation_mlp",
      )
      dispatch_einsum = "BNSM,BNSEC -> EBNCM"
      mlp_up_einsum = "EBNCM,EMH -> EBNCH"
      mlp_down_einsum = "EBNCH,EHM -> EBNCM"
      output_einsum = "EBNCM,BNSEC -> BNSM"
      # Split the sequence into `cp` context partitions of length `sub_seq`.
      inputs = jnp.reshape(inputs, (batch_size, cp, sub_seq, inputs.shape[2]))
      inputs = self._maybe_shard_with_logical(inputs, input_axis)

    dispatch_mask = self._maybe_shard_with_logical(dispatch_mask, mask_axes)
    combine_mask = self._maybe_shard_with_logical(combine_mask, mask_axes)

    with jax.named_scope("dispatch"):
      # only cp during prefill
      dispatch = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=DISPATCH)(
          dispatch_einsum, inputs, dispatch_mask, precision=matmul_precision
      )
      if cp > 1:
        dispatch = self._maybe_shard_with_logical(
            dispatch,
            (
                None,
                "activation_batch_moe",
                "activation_norm_length_moe",
                None,
                "activation_embed_moe",
            ),
        )
      dispatch = self._maybe_shard_with_logical(dispatch, dispatch_axis)

    with jax.named_scope("wi_0"):
      w0_kernel_axes = ("exp", None, "mlp")
      w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
      layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
          mlp_up_einsum, dispatch, w0_kernel, precision=matmul_precision
      )
      if self.config.mlp_bias:
        # Rank-aware: EBCH in training, EBNCH in inference.
        layer_w0 = layer_w0 + expand_expert_bias(w0_bias, layer_w0)
      if self.config.activations_in_float32:
        layer_w0 = layer_w0.astype(jnp.float32)
      layer_w0 = self._maybe_shard_with_logical(layer_w0, mlp_axis)
      layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

    with jax.named_scope("wi_1"):
      w1_kernel_axes = ("exp", None, "mlp")
      w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
      layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
          mlp_up_einsum, dispatch, w1_kernel, precision=matmul_precision
      )
      if self.config.mlp_bias:
        layer_w1 = layer_w1 + expand_expert_bias(w1_bias, layer_w1)
      if self.config.activations_in_float32:
        layer_w1 = layer_w1.astype(jnp.float32)
      layer_w1 = self._maybe_shard_with_logical(layer_w1, mlp_axis)
      layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")

    layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)

    with jax.named_scope("wo"):
      wo_kernel_axes = ("exp", "mlp", None)
      wo_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(wo_kernel, wo_kernel_axes)
      intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)(
          mlp_down_einsum,
          layer_multiply,
          wo_kernel,
          precision=matmul_precision,
      )
      if self.config.mlp_bias:
        intermediate_layer = intermediate_layer + expand_expert_bias(wo_bias, intermediate_layer)
      if self.config.activations_in_float32:
        intermediate_layer = intermediate_layer.astype(jnp.float32)
      if self.config.model_call_mode != "inference":
        intermediate_layer = self._maybe_shard_with_logical(
            intermediate_layer,
            (
                "activation_exp",
                "activation_batch_moe",
                None,
                "activation_embed_moe",
            ),
        )
      intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")

    with jax.named_scope("combine"):
      # Matmul & element wise operation
      output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)(
          output_einsum,
          intermediate_layer,
          combine_mask,
          precision=matmul_precision,
      )
      if output.ndim == 4:
        # Inference produced (B, N, S, M); merge the context partitions back.
        output = jnp.reshape(
            output,
            (
                output.shape[0],
                output.shape[1] * output.shape[2],
                output.shape[3],
            ),
        )
    return output, lb_loss, bias_updates

  # No capacity limit: run every token through every expert and sum the
  # per-expert outputs with the (sparse) routing weights.
  inputs = self._maybe_shard_with_logical(
      inputs, ("activation_batch_moe", "activation_norm_length_moe", "activation_embed_moe")
  )
  with jax.named_scope("wi_0"):
    layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
        "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision
    )
    if self.config.mlp_bias:
      layer_w0 = layer_w0 + w0_bias[None, None, :, :]
    if self.config.activations_in_float32:
      layer_w0 = layer_w0.astype(jnp.float32)
    layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
  with jax.named_scope("wi_1"):
    layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
        "BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision
    )
    if self.config.mlp_bias:
      layer_w1 = layer_w1 + w1_bias[None, None, :, :]
    if self.config.activations_in_float32:
      layer_w1 = layer_w1.astype(jnp.float32)
    layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
  layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
  with jax.named_scope("wo"):
    intermediate_layer = self.get_einsum(rhs_mesh_axes=self.wo_kernel_axes)(
        "BSEH,EHM -> BSEM",
        layer_multiply,
        wo_kernel,
        precision=matmul_precision,
    )
    if self.config.mlp_bias:
      intermediate_layer = intermediate_layer + wo_bias[None, None, :, :]
    if self.config.activations_in_float32:
      intermediate_layer = intermediate_layer.astype(jnp.float32)
    intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
  with jax.named_scope("weight_sum"):
    if self.config.float32_weight_sum:
      intermediate_layer = intermediate_layer.astype(jnp.float32)
      weights = weights.astype(jnp.float32)
    # Contract over experts (in float32 if requested above), then cast to self.dtype.
    output = jnp.einsum(
        "BSEM,BSE -> BSM",
        intermediate_layer,
        weights,
        precision=matmul_precision,
    ).astype(self.dtype)
  return output, lb_loss, bias_updates
[docs]
def fused_moe_matmul(
    self,
    inputs,
    gate_logits,
    wo_kernel,
    w0_kernel=None,
    w1_kernel=None,
    fused_kernel=None,
) -> tuple[jax.Array, None, None]:
  """Runs the MoE layer through tpu_inference's `fused_moe_func` (vllm_rpa path only).

  `fused_moe_func` does the routing, the grouped matmuls and the weighted
  combination itself, so this method only adapts shapes and config values. No
  load-balance loss or routed-bias update is produced (inference only).

  Args:
    inputs: Activations of shape [B, S, D].
    gate_logits: Router logits, flattened here to [B*S, num_experts].
    wo_kernel: Down projection, passed as `w2`.
    w0_kernel: Gate projection; used only when `fused_kernel` is None.
    w1_kernel: Up projection; used only when `fused_kernel` is None.
    fused_kernel: Optional pre-concatenated [gate | up] kernel, passed as `w1`.

  Returns:
    `(output, None, None)`, with `output` reshaped back to [B, S, D].

  Raises:
    ImportError: If the tpu-inference package is not installed.
  """
  try:
    # pylint: disable=import-outside-toplevel
    # pytype: disable=import-error
    from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func
  except ImportError as e:
    raise ImportError("fused_moe_matmul requires the tpu-inference package.") from e

  cfg = self.config
  batch_size, seq_len, emb_dim = inputs.shape
  num_tokens = batch_size * seq_len
  # fused_moe_func operates on flat 2D [T, D] token tensors.
  hidden_states = jnp.reshape(inputs, (num_tokens, emb_dim))
  gating_output = jnp.reshape(gate_logits, (num_tokens, self.num_experts))

  # fused_moe_func expects gate and up stacked on the last axis
  # ([E, D, H] + [E, D, H] -> [E, D, 2H]) and splits them internally.
  if fused_kernel is not None:
    gate_up_kernel = fused_kernel
  else:
    gate_up_kernel = jnp.concatenate([w0_kernel, w1_kernel], axis=-1)

  use_ep = self.get_expert_parallelism_size() > 1  # expert parallelism when the expert axis > 1
  activation = cfg.mlp_activations[0]  # e.g. "silu"
  scoring_fn = cfg.routed_score_func or "softmax"
  # Renormalize top-k weights unless the architecture is one of those excluded below.
  renormalize = cfg.norm_topk_prob or cfg.decoder_block not in (
      ctypes.DecoderBlockType.LLAMA4,
      ctypes.DecoderBlockType.GEMMA4,
  )
  sc_kernel_threshold = 1 << 24  # == 16777216

  output_2d = fused_moe_func(
      hidden_states=hidden_states,
      w1=gate_up_kernel,
      w2=wo_kernel,
      w1_scale=None,
      w2_scale=None,
      w1_bias=None,
      w2_bias=None,
      gating_output=gating_output,
      topk=self.num_experts_per_tok,
      renormalize=renormalize,
      mesh=self.mesh,
      use_ep=use_ep,
      activation=activation,
      scoring_fn=scoring_fn,
      sc_kernel_threshold=sc_kernel_threshold,
      sc_kernel_col_chunk_size=1024,
  )
  return jnp.reshape(output_2d, (batch_size, seq_len, emb_dim)), None, None
[docs]
def retrieve_quantized_weight(
    self,
    inputs,
    gate_logits,
    pre_bias_logits,
    w0_kernel,
    w1_kernel,
    wo_kernel,
    w0_bias,
    w1_bias,
    wo_bias,
) -> tuple[aqt.QTensor, aqt.QTensor, aqt.QTensor]:
  """Returns the frozen quantized (w0, w1, wo) kernels held by the AQT einsums.

  Runs `dense_matmul` once, discarding its outputs, so that the quantized
  right-hand-side tensors get created inside the AqtEinsum submodules. This is
  intended to run only during tracing; after jit, the extra pass is expected
  to become a no-op because its results are unused.

  Returns:
    The unboxed `qrhs/frozen` tensors of AqtEinsum_0, AqtEinsum_1 and
    AqtEinsum_2, in that order (w0, w1, wo).
  """
  self.dense_matmul(
      inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
  )
  aqt_collection = self.variables["aqt"]
  frozen_kernels = [
      aqt_collection[einsum_name]["AqtDotGeneral_0"]["qrhs"]["frozen"]
      for einsum_name in ("AqtEinsum_0", "AqtEinsum_1", "AqtEinsum_2")
  ]
  return tuple(max_utils.unbox_logicallypartioned(kernel) for kernel in frozen_kernels)
def __call__(
self, inputs: jax.Array, gate_inputs: jax.Array | None = None, out_sharding: NamedSharding | None = None
) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
cfg = self.config
inputs = inputs.astype(cfg.dtype)
gate_dtype = jnp.float32 if cfg.float32_gate_logits else cfg.dtype
routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype)
gate_logits, pre_bias_logits = self.gate(routing_inputs)
wo_kernel = jnp.asarray(self.wo[...], self.dtype)
fused_kernel = None
w0_kernel = None
w1_kernel = None
if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa":
fused_kernel = jnp.asarray(self.wi[...], self.dtype)
else:
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
if self.per_expert_scale is not None:
wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
if self.wi_0_sparsity_module is not None:
_, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel)
_, w1_kernel = self.wi_1_sparsity_module(jnp.zeros_like(w1_kernel), w1_kernel)
_, wo_kernel = self.wo_sparsity_module(jnp.zeros_like(wo_kernel), wo_kernel)
if cfg.mlp_bias:
w0_bias = jnp.asarray(self.wi_0_bias[...], self.dtype)
w1_bias = jnp.asarray(self.wi_1_bias[...], self.dtype)
wo_bias = jnp.asarray(self.wo_bias[...], self.dtype)
else:
w0_bias, w1_bias, wo_bias = None, None, None
# vllm_rpa codepath uses fused_moe_func from tpu_inference for optimized inference.
if cfg.attention == "vllm_rpa":
output, lb_loss, bias_updates = self.fused_moe_matmul(
inputs, gate_logits, wo_kernel, w0_kernel=w0_kernel, w1_kernel=w1_kernel, fused_kernel=fused_kernel
)
elif cfg.sparse_matmul:
if quantizations.in_serve_mode(self.quant):
w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight(
inputs,
gate_logits,
pre_bias_logits,
w0_kernel,
w1_kernel,
wo_kernel,
w0_bias,
w1_bias,
wo_bias,
)
output, lb_loss, bias_updates = self.sparse_matmul(
inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
)
else:
output, lb_loss, bias_updates = self.dense_matmul(
inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
)
return output, lb_loss, bias_updates
[docs]
class RoutedAndSharedMoE(nnx.Module):
  """MoE block that adds an always-on shared-expert MLP to the routed experts."""

  def __init__(
      self,
      config: ctypes.Config,
      mesh: jax.sharding.Mesh,
      kernel_init: NdInitializer,
      kernel_axes: Tuple[Optional[str], ...],
      rngs: nnx.Rngs,
      weight_dtype: ctypes.DType = jnp.float32,
      dtype: ctypes.DType = jnp.float32,
      quant: Optional[quantizations.AqtQuantization] = None,
  ):
    """Builds the routed-expert and shared-expert submodules.

    Args:
      config: The main config setting.
      mesh: Device mesh.
      kernel_init: Initializer for the kernel weight matrices.
      kernel_axes: Logical partitioning axes; stored on the module (the routed
        experts are built with ("embed_moe", None)).
      rngs: `nnx.Rngs` used to initialize parameters.
      weight_dtype: Stored on the module; the submodules take their weight
        dtype from `config.weight_dtype`.
      dtype: Stored on the module; the submodules take their compute dtype
        from `config.dtype`.
      quant: Quantization configuration, or None for no quantization.
    """
    self.config = config
    self.mesh = mesh
    self.kernel_init = kernel_init
    self.kernel_axes = kernel_axes
    self.weight_dtype = weight_dtype
    self.dtype = dtype
    self.quant = quant
    self.rngs = rngs
    configured_input_dim = self.config.moe_expert_input_dim
    # A non-positive moe_expert_input_dim means "fall back to emb_dim".
    self.moe_expert_input_dim = configured_input_dim if configured_input_dim > 0 else self.config.emb_dim

    # NOTE: the attribute name MoeBlock_0 keeps checkpoints for the routed
    # experts backward compatible. It must be built before shared_experts
    # (same order as before) since both draw from `rngs`.
    self.MoeBlock_0 = RoutedMoE(
        config=self.config,
        num_experts=self.config.num_experts,
        num_experts_per_tok=self.config.num_experts_per_tok,
        mesh=self.mesh,
        kernel_init=self.kernel_init,
        kernel_axes=("embed_moe", None),
        intermediate_dim=self.config.moe_mlp_dim,
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        quant=self.quant,
        rngs=self.rngs,
    )

    if self.config.decoder_block == ctypes.DecoderBlockType.GEMMA4:
      per_shared_expert_dim = self.config.mlp_dim
    else:
      per_shared_expert_dim = self.config.moe_mlp_dim
    self.shared_experts = linears.MlpBlock(
        mesh=self.mesh,
        in_features=self.moe_expert_input_dim,
        intermediate_dim=self.config.shared_experts * per_shared_expert_dim,
        activations=self.config.mlp_activations,
        kernel_init=self.kernel_init,
        intermediate_dropout_rate=self.config.dropout_rate,
        dtype=self.config.dtype,
        weight_dtype=self.config.weight_dtype,
        config=self.config,
        quant=self.quant,
        rngs=self.rngs,
    )

  @property
  def routed_moe(self):
    """The routed-expert submodule (stored as `MoeBlock_0`)."""
    return self.MoeBlock_0

  def __call__(
      self,
      inputs: jax.Array,
      original_inputs: jax.Array | None = None,
      gate_inputs: jax.Array | None = None,
      intermediate_sharding: NamedSharding | None = None,
      out_sharding: NamedSharding | None = None,
  ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
    """Returns (routed + shared output, load-balance loss, routed-bias updates).

    `original_inputs` is accepted but not used here.
    """
    routed_out, lb_loss, bias_updates = self.routed_moe(inputs, gate_inputs=gate_inputs, out_sharding=out_sharding)
    shared_out = self.shared_experts(inputs, intermediate_sharding=intermediate_sharding, out_sharding=out_sharding)
    return routed_out + shared_out, lb_loss, bias_updates
[docs]
def get_gate_logit(
    inputs_shape: tuple[int, ...],
    out_features_shape: Union[Iterable[int], int],
    model_name: str,
    axis: Union[Iterable[int], int] = -1,
    weight_dtype: ctypes.DType = jnp.float32,
    dtype: ctypes.DType = jnp.float32,
    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
    kernel_axes: Tuple[Optional[str], ...] = (),
    use_bias: bool = False,
    score_func: str = "",
    quant: Optional[quantizations.AqtQuantization] = None,
    matmul_precision: str = "default",
    name: Optional[str] = None,
):
  """Creates a GateLogit Linen module.

  The input feature shape is read from `inputs_shape` at the (normalized)
  contraction `axis`; the remaining arguments are forwarded to `GateLogit`.
  """
  contraction_axes = linears.canonicalize_tuple(axis)
  normalized_axes = linears.normalize_axes(contraction_axes, len(inputs_shape))
  in_features_shape = tuple(inputs_shape[ax] for ax in normalized_axes)
  return nnx_wrappers.to_linen(
      GateLogit,
      in_features_shape=in_features_shape,
      out_features_shape=out_features_shape,
      model_name=model_name,
      axis=contraction_axes,
      weight_dtype=weight_dtype,
      dtype=dtype,
      kernel_init=kernel_init,
      kernel_axes=kernel_axes,
      use_bias=use_bias,
      score_func=score_func,
      quant=quant,
      matmul_precision=matmul_precision,
      name=name,
      metadata_fn=variable_to_logically_partitioned,
      abstract_init=False,
  )
[docs]
def get_routed_moe(
    config: ctypes.Config,
    num_experts: int,
    num_experts_per_tok: int,
    mesh: jax.sharding.Mesh,
    kernel_init: NdInitializer,
    kernel_axes: Tuple[Optional[str], ...],
    intermediate_dim: int = 2048,
    weight_dtype: ctypes.DType = jnp.float32,
    dtype: ctypes.DType = jnp.float32,
    quant: Optional[quantizations.AqtQuantization] = None,
    name: Optional[str] = None,
):
  """Creates a RoutedMoE Linen module by wrapping the NNX `RoutedMoE` class.

  All arguments are forwarded unchanged; parameters are annotated via
  `variable_to_logically_partitioned`.
  """
  return nnx_wrappers.to_linen(
      RoutedMoE,
      config=config,
      num_experts=num_experts,
      num_experts_per_tok=num_experts_per_tok,
      mesh=mesh,
      kernel_init=kernel_init,
      kernel_axes=kernel_axes,
      intermediate_dim=intermediate_dim,
      weight_dtype=weight_dtype,
      dtype=dtype,
      quant=quant,
      name=name,
      metadata_fn=variable_to_logically_partitioned,
      abstract_init=False,
  )
[docs]
def get_routed_and_shared_moe(
    config: ctypes.Config,
    mesh: jax.sharding.Mesh,
    kernel_init: NdInitializer,
    kernel_axes: Tuple[Optional[str], ...],
    weight_dtype: ctypes.DType = jnp.float32,
    dtype: ctypes.DType = jnp.float32,
    quant: Optional[quantizations.AqtQuantization] = None,
    name: Optional[str] = None,
):
  """Creates a RoutedAndSharedMoE Linen module by wrapping the NNX class.

  All arguments are forwarded unchanged; parameters are annotated via
  `variable_to_logically_partitioned`.
  """
  return nnx_wrappers.to_linen(
      RoutedAndSharedMoE,
      config=config,
      mesh=mesh,
      kernel_init=kernel_init,
      kernel_axes=kernel_axes,
      weight_dtype=weight_dtype,
      dtype=dtype,
      quant=quant,
      name=name,
      metadata_fn=variable_to_logically_partitioned,
      abstract_init=False,
  )