# Source code for maxtext.models.deepseek_batchsplit_fp8

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Alternative DeepSeek model definition with batch-split schedule."""

import dataclasses
import functools
import math
from typing import Any, Sequence

from flax import linen as nn
import jax
import jax.numpy as jnp
from maxtext.kernels import megablox, sort_activations
from maxtext.layers import attention_op
from maxtext.layers import moe as moe_lib
from maxtext.layers import quantizations
import qwix.pallas as qpl
import tokamax


@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3))
def quantized_psum_scatter(x: jax.Array, axis_name: str, scatter_dimension: int, tiled: bool) -> jax.Array:
  """Reduce-scatters `x`, with a custom gradient rule.

  Forward: standard BF16 reduce-scatter (delegated to `_q_psum_scatter_fwd`).
  Backward: quantized FP8 all-gather (DeepSeek optimization), implemented by
  `_q_psum_scatter_bwd`.

  Only `x` is differentiable; the remaining arguments are declared
  non-differentiable through `nondiff_argnums`.

  Args:
    x: The input tensor.
    axis_name: The axis name for the psum_scatter/all_gather operation.
    scatter_dimension: The dimension along which to scatter.
    tiled: Whether the scatter/gather is tiled.

  Returns:
    The result of the reduce-scatter operation.
  """
  # The forward rule returns (primal_output, residuals); only the primal output
  # is needed when the function is called directly.
  primal_out, _ = _q_psum_scatter_fwd(x, axis_name, scatter_dimension, tiled)
  return primal_out


def _q_psum_scatter_fwd(x: jax.Array, axis_name: str, scatter_dimension: int, tiled: bool) -> tuple[jax.Array, None]:
  out = jax.lax.psum_scatter(x, axis_name=axis_name, scatter_dimension=scatter_dimension, tiled=tiled)
  return out, None


def _q_psum_scatter_bwd(
    axis_name: str,
    scatter_dimension: int,
    tiled: bool,
    res: Any,
    grads: jax.Array,
) -> tuple[jax.Array]:  # pylint: disable=g-one-element-tuple
  """Backward pass for quantized_psum_scatter.

  Performs a quantized All-Gather of the gradients: the local gradient shard is
  quantized to FP8 (e5m2) with per-channel scales along axis 0, the quantized
  values and their scales are all-gathered, and the result is dequantized.

  Args:
    axis_name: The axis name for the all_gather operation.
    scatter_dimension: The dimension along which the scatter occurred in the
      forward pass.
    tiled: Whether the gather is tiled.
    res: The residuals from the forward pass (_q_psum_scatter_fwd). The forward
      rule saves nothing, so this is None and is ignored.
    grads: The gradients from the next layer, which are in BF16.

  Returns:
    A one-element tuple holding the dequantized, all-gathered gradients.
  """
  del res  # The forward rule stores no residuals.

  # --- BACKWARD PASS (Dispatch) ---
  # 'grads' is the BF16 gradient arriving from the next layer.
  # We need to broadcast it back to all devices (All-Gather).
  grads_q = qpl.quantize(
      grads,
      jnp.float8_e5m2,
      channelwise_axes=[0],
  )

  gather = functools.partial(
      jax.lax.all_gather,
      axis_name=axis_name,
      tiled=tiled,
      axis=scatter_dimension,
  )
  gathered_qvals = gather(grads_q.qvalue)
  # Every device quantizes its shard with its own scales, so the scales must
  # travel with the values. Dequantizing remote shards with the local scale
  # would silently mis-scale them (and mismatch shapes when gathering along the
  # channelwise axis).
  # NOTE(review): when scatter_dimension != 0 and tiled=True the gathered scale
  # has one entry per device along that dimension; this relies on qwix
  # dequantize treating a scale dimension that evenly divides the qvalue
  # dimension as per-tile scales -- confirm against the qwix version in use.
  gathered_scale = gather(grads_q.scale)

  gathered_q = dataclasses.replace(grads_q, qvalue=gathered_qvals, scale=gathered_scale)
  return (qpl.dequantize(gathered_q),)


# Attach the custom differentiation rules: `_q_psum_scatter_fwd` runs the
# reduce-scatter, `_q_psum_scatter_bwd` produces the input cotangent.
quantized_psum_scatter.defvjp(_q_psum_scatter_fwd, _q_psum_scatter_bwd)


def fetch_weights(params, dtype):
  """Collects the layer weights from `params` as a nested tuple of arrays.

  The result has the layout
  `((norm_scales, attention_weights), (router, routed_experts, shared_experts))`
  and every leaf is converted to `dtype`.

  Args:
    params: Nested mapping of decoder-layer parameters. Leaves may be plain
      arrays or wrappers exposing the array through a `.value` attribute.
    dtype: Target dtype for every returned array.

  Returns:
    The nested tuple of weights described above.
  """

  def _as_array(leaf):
    # If the leaf is a LogicallyPartitioned array, `.value` is the underlying
    # array; otherwise the leaf itself is the array.
    raw = getattr(leaf, "value", leaf)
    return jnp.asarray(raw[...], dtype)

  norm_scales = (
      params["pre_self_attention_layer_norm"]["scale"],
      params["post_self_attention_layer_norm"]["scale"],
  )

  attn = params["self_attention"]
  attention_weights = (
      attn["wq_a"]["kernel"],
      attn["wq_b"]["kernel"],
      attn["q_norm"]["scale"],
      attn["wkv_a"]["kernel"],
      attn["wkv_b"]["kernel"],
      attn["kv_norm"]["scale"],
      attn["out"]["kernel"],
  )

  routed = params["DeepSeekMoeBlock_0"]["MoeBlock_0"]
  router_weights = (routed["gate"]["kernel"], routed["gate"]["bias"])
  routed_expert_weights = (routed["wi_0"], routed["wi_1"], routed["wo"])

  shared = params["DeepSeekMoeBlock_0"]["shared_experts"]
  shared_expert_weights = (
      shared["wi_0"]["kernel"],
      shared["wi_1"]["kernel"],
      shared["wo"]["kernel"],
  )

  selected = (
      (norm_scales, attention_weights),
      (router_weights, routed_expert_weights, shared_expert_weights),
  )
  # Anything that is not a sequence (arrays, `.value` wrappers) is a leaf, so
  # only the tuples built above are traversed.
  return jax.tree.map(
      _as_array,
      selected,
      is_leaf=lambda node: not isinstance(node, Sequence),
  )
@jax.named_scope("deepseek_batchsplit_split")
def split(x, split_factor=2):
  """Splits `x` into `split_factor` microbatches along the leading (batch) axis.

  The split is interleaved: row `r` of `x` goes to microbatch
  `r % split_factor`.

  Args:
    x: Array whose leading dimension is a multiple of `split_factor`, or None.
    split_factor: Number of microbatches to produce.

  Returns:
    A list of `split_factor` entries. If `split_factor` is 1 the list holds `x`
    itself; if `x` is None every entry is None.
  """
  if split_factor == 1:
    return [x]
  if x is None:
    return [None] * split_factor
  # View the batch axis as (groups, split_factor) and peel off each column.
  grouped = jnp.reshape(x, (-1, split_factor) + x.shape[1:])
  return [grouped[:, part, ...] for part in range(split_factor)]
@jax.named_scope("deepseek_batchsplit_merge")
def merge(x, split_factor=2):
  """Merges microbatches back into one tensor, interleaving them on the batch axis.

  Args:
    x: Sequence of equally shaped microbatch arrays.
    split_factor: Number of microbatches. When it is 1, `x[0]` is returned
      untouched.

  Returns:
    A single array in which row `i` of microbatch `j` lands at row
    `i * len(x) + j`.
  """
  if split_factor == 1:
    return x[0]
  # Stack the microbatches on a new axis 1, then fold that axis into the batch.
  interleaved = jnp.stack(x, axis=1)
  merged_shape = (-1,) + interleaved.shape[2:]
  return jnp.reshape(interleaved, merged_shape)
[docs] def gather_weights(weights, mesh): """all-gathers FSDP sharded weights.""" def fn(weights): ( (pre_attn_norm, post_attn_norm), (wq_a, wq_b, q_norm, wkv_a, wkv_b, kv_norm, out), ), ( (gate, bias), (routed_wi_0, routed_wi_1, routed_wo), (shared_wi_0, shared_wi_1, shared_wo), ) = weights # All-gather across FSDP axis. Expert axis is used for FSDP in attention. wq_a = jax.lax.all_gather(wq_a, axis_name="expert", tiled=True, axis=1) wq_a = jax.lax.all_gather(wq_a, axis_name="fsdp", tiled=True) wq_b = jax.lax.all_gather(wq_b, axis_name="expert", tiled=True, axis=1) wq_b = jax.lax.all_gather(wq_b, axis_name="fsdp", tiled=True) wkv_a = jax.lax.all_gather(wkv_a, axis_name="expert", tiled=True, axis=1) wkv_a = jax.lax.all_gather(wkv_a, axis_name="fsdp", tiled=True) wkv_b = jax.lax.all_gather(wkv_b, axis_name="expert", tiled=True, axis=1) wkv_b = jax.lax.all_gather(wkv_b, axis_name="fsdp", tiled=True) out = jax.lax.all_gather(out, axis_name="expert", tiled=True) out = jax.lax.all_gather(out, axis_name="fsdp", tiled=True, axis=2) gate = jax.lax.all_gather(gate, axis_name="fsdp", tiled=True) routed_wi_0 = jax.lax.all_gather(routed_wi_0, axis_name="fsdp", tiled=True) routed_wi_1 = jax.lax.all_gather(routed_wi_1, axis_name="fsdp", tiled=True) routed_wo = jax.lax.all_gather(routed_wo, axis_name="fsdp", tiled=True) shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="expert", tiled=True, axis=1) shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="fsdp", tiled=True) shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="expert", tiled=True, axis=1) shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="fsdp", tiled=True) shared_wo = jax.lax.all_gather(shared_wo, axis_name="expert", tiled=True) shared_wo = jax.lax.all_gather(shared_wo, axis_name="fsdp", tiled=True, axis=1) return ( ( (pre_attn_norm, post_attn_norm), (wq_a, wq_b, q_norm, wkv_a, wkv_b, kv_norm, out), ), ( (gate, bias), (routed_wi_0, routed_wi_1, routed_wo), (shared_wi_0, shared_wi_1, shared_wo), ), 
) return jax.shard_map( fn, mesh=mesh, in_specs=( ( ( ( jax.sharding.PartitionSpec(None), jax.sharding.PartitionSpec(None), ), ( jax.sharding.PartitionSpec("fsdp", "expert"), jax.sharding.PartitionSpec("fsdp", "expert", None), jax.sharding.PartitionSpec(None), jax.sharding.PartitionSpec("fsdp", "expert"), jax.sharding.PartitionSpec("fsdp", "expert", None), jax.sharding.PartitionSpec(None), jax.sharding.PartitionSpec("expert", None, "fsdp"), ), ), ( ( jax.sharding.PartitionSpec("fsdp", None), jax.sharding.PartitionSpec(None), ), ( jax.sharding.PartitionSpec("fsdp", None, "expert"), jax.sharding.PartitionSpec("fsdp", None, "expert"), jax.sharding.PartitionSpec("fsdp", "expert", None), ), ( jax.sharding.PartitionSpec("fsdp", "expert"), jax.sharding.PartitionSpec("fsdp", "expert"), jax.sharding.PartitionSpec("expert", "fsdp"), ), ), ), ), out_specs=( ( ( jax.sharding.PartitionSpec(None), jax.sharding.PartitionSpec(None), ), ( jax.sharding.PartitionSpec(None, None), jax.sharding.PartitionSpec(None, None, None), jax.sharding.PartitionSpec(None), jax.sharding.PartitionSpec(None, None), jax.sharding.PartitionSpec(None, None, None), jax.sharding.PartitionSpec(None), jax.sharding.PartitionSpec(None, None, None), ), ), ( ( jax.sharding.PartitionSpec(None, None), jax.sharding.PartitionSpec(None), ), ( jax.sharding.PartitionSpec(None, None, "expert"), jax.sharding.PartitionSpec(None, None, "expert"), jax.sharding.PartitionSpec(None, "expert", None), ), ( jax.sharding.PartitionSpec(None, None), jax.sharding.PartitionSpec(None, None), jax.sharding.PartitionSpec(None, None), ), ), ), check_vma=False, )(weights)
def scan_batch_split_layers(
    inputs,
    params,
    positions,
    segment_ids,
    *,
    model_mode,
    mesh,
    quant,
    cfg,
    policy,
):
  """Scans the decoder layers using the batch-split schedule.

  The batch is split into `cfg.batch_split_factor` microbatches, every layer is
  applied to the microbatches inside a `jax.lax.scan` (with rematerialization
  controlled by `policy`), and the microbatches are merged back at the end.

  Args:
    inputs: Activations entering the first scanned layer.
    params: Stacked layer parameters, as consumed by `fetch_weights`.
    positions: Token positions, or None.
    segment_ids: Segment ids, or None.
    model_mode: Forwarded to `batch_split_schedule`.
    mesh: Device mesh used for every `shard_map`.
    quant: Quantization configuration forwarded to the layer.
    cfg: Model config; `dtype`, `param_scan_axis` and `batch_split_factor` are
      read here.
    policy: Checkpoint policy passed to `jax.checkpoint`.

  Returns:
    The merged activations after all scanned layers.
  """
  num_microbatches = cfg.batch_split_factor

  def batch_split_scan_fn(carry, layer_weights, dpos, dseg):
    gathered = gather_weights(layer_weights, mesh)
    carry = batch_split_schedule(
        carry,
        gathered,
        dpos,
        dseg,
        model_mode=model_mode,
        mesh=mesh,
        quant=quant,
        cfg=cfg,
    )
    return carry, None

  # No need to prevent CSE inside scan.
  checkpointed_scan_fn = jax.checkpoint(
      batch_split_scan_fn,
      prevent_cse=False,
      policy=policy,
  )

  # `jax.lax.scan` iterates over the leading dimension, but the weights are
  # initialized/loaded with the layer dimension at `cfg.param_scan_axis`, so
  # move that axis to the front.
  stacked_weights = jax.tree.map(
      lambda w: jnp.swapaxes(w, 0, cfg.param_scan_axis),
      fetch_weights(params, cfg.dtype),
  )

  activation_pspec = jax.sharding.PartitionSpec(
      ("data", "fsdp", "fsdp_transpose", "expert", "context"),
      None,
      None,
  )
  microbatches = jax.shard_map(
      functools.partial(split, split_factor=num_microbatches),
      mesh=mesh,
      in_specs=activation_pspec,
      out_specs=[activation_pspec] * num_microbatches,
  )(inputs)
  dpos = split(positions, split_factor=num_microbatches)
  dseg = split(segment_ids, split_factor=num_microbatches)

  scan_body = functools.partial(checkpointed_scan_fn, dpos=dpos, dseg=dseg)
  microbatches, _ = jax.lax.scan(scan_body, microbatches, stacked_weights)

  return jax.shard_map(
      functools.partial(merge, split_factor=num_microbatches),
      mesh=mesh,
      in_specs=([activation_pspec] * num_microbatches,),
      out_specs=activation_pspec,
  )(microbatches)
def batch_split_schedule(
    inputs,
    weights,
    positions,
    segment_ids,
    *,
    model_mode,
    mesh,
    quant,
    cfg,
):
  """Applies one DeepSeek MoE decoder layer with the batch-split schedule.

  Each microbatch goes through MLA attention (with pre/post norms) and then
  through the MoE block.

  Args:
    inputs: Iterable of microbatch activations.
    weights: `(norm_and_attention_weights, moe_weights)` for this layer.
    positions: Per-microbatch positions.
    segment_ids: Per-microbatch segment ids.
    model_mode: Forwarded to the attention op.
    mesh: Device mesh.
    quant: Quantization configuration.
    cfg: Model config supplying the attention, RoPE and MoE hyper-parameters.

  Returns:
    The list of microbatch activations after the layer.
  """
  constrained = [with_data_parallel_constraint(microbatch, mesh) for microbatch in inputs]
  xs = jax.ad_checkpoint.checkpoint_name(constrained, "decoder_layer_input")

  attn_op = attention_op.AttentionOp(
      config=cfg,
      mesh=mesh,
      attention_kernel=cfg.attention,
      max_target_length=cfg.max_target_length,
      max_prefill_predict_length=cfg.max_prefill_predict_length,
      quant=quant,
      kv_quant=quantizations.configure_kv_quant(cfg),
      num_query_heads=cfg.num_query_heads,
      num_kv_heads=cfg.num_kv_heads,
      dropout_rate=cfg.dropout_rate,
      dtype=cfg.dtype,
      attention_type=cfg.attention_type,
  )

  norm_mla_ws, moe_ws = weights

  # Attention sub-layer: yields an (activation, post-norm activation) pair per
  # microbatch.
  attended = mla_with_norms(
      xs,
      norm_mla_ws,
      positions,
      segment_ids,
      mesh=mesh,
      model_mode=model_mode,
      attn_op=attn_op,
      normalization_layer_epsilon=cfg.normalization_layer_epsilon,
      kv_lora_rank=cfg.kv_lora_rank,
      qk_nope_head_dim=cfg.qk_nope_head_dim,
      qk_rope_head_dim=cfg.qk_rope_head_dim,
      rope_max_timescale=cfg.rope_max_timescale,
      num_query_heads=cfg.num_query_heads,
      max_position_embeddings=cfg.max_position_embeddings,
      original_max_position_embeddings=cfg.original_max_position_embeddings,
      beta_fast=cfg.beta_fast,
      beta_slow=cfg.beta_slow,
      rope_factor=cfg.rope_factor,
      mscale=cfg.mscale,
      dtype=cfg.dtype,
      quant=quant,
  )

  # MoE sub-layer.
  return moe(
      attended,
      moe_ws,
      mesh=mesh,
      num_experts=cfg.num_experts,
      num_experts_per_tok=cfg.num_experts_per_tok,
      routed_scaling_factor=cfg.routed_scaling_factor,
      expert_axis_name="expert",
      use_gather_mosaic_kernel=False,
      config=cfg,
      quant=quant,
  )
[docs] def staggered_call(fn, xs): for i, x in enumerate(xs): if i == len(xs) - 1: xs[i] = fn(x) else: xs[i], xs[i + 1] = jax.lax.optimization_barrier((fn(x), xs[i + 1])) return xs
def with_data_parallel_constraint(x, mesh):
  """Constrains `x` to be sharded on its leading axis across the batch mesh axes.

  Args:
    x: Value to constrain; the partition spec has three entries, the last two
      of which are unsharded.
    mesh: Device mesh defining the "data", "fsdp", "fsdp_transpose", "expert"
      and "context" axes.

  Returns:
    `x` with the sharding constraint applied.
  """
  batch_axes = ("data", "fsdp", "fsdp_transpose", "expert", "context")
  sharding = jax.NamedSharding(
      mesh,
      jax.sharding.PartitionSpec(batch_axes, None, None),
  )
  return jax.lax.with_sharding_constraint(x, sharding)
def dot(x, y, quant=None, axes=1):
  """Contracts `x` with `y`, optionally through a quantized dot_general.

  Args:
    x: Left operand.
    y: Right operand.
    quant: Optional quantization object. When given, `quant.dot_general_cls()`
      is instantiated and used in place of `jnp.tensordot`.
    axes: Either an int `n` (contract the last `n` axes of `x` with the first
      `n` axes of `y`) or a pair `(x_axes, y_axes)` of axis sequences.

  Returns:
    The contraction result.
  """
  if quant is None:
    # Unquantized path.
    return jnp.tensordot(x, y, axes=axes)

  # Translate `axes` into jax.lax.dot_general contracting dimensions.
  if isinstance(axes, int):
    lhs_contract = tuple(range(x.ndim - axes, x.ndim))
    rhs_contract = tuple(range(axes))
  else:
    lhs_contract, rhs_contract = axes
  no_batch_dims = ((), ())

  # Instantiate and call the qwix dot_general.
  quantized_dot_general = quant.dot_general_cls()()
  return quantized_dot_general(
      lhs=x,
      rhs=y,
      dimension_numbers=((lhs_contract, rhs_contract), no_batch_dims),
  )
def mla_with_norms(
    inputs,
    weights,
    decoder_positions,
    decoder_segment_ids,
    *,
    mesh,
    model_mode,
    attn_op,
    normalization_layer_epsilon,
    kv_lora_rank,
    qk_nope_head_dim,
    qk_rope_head_dim,
    rope_max_timescale,
    num_query_heads,
    max_position_embeddings,
    original_max_position_embeddings,
    beta_fast,
    beta_slow,
    rope_factor,
    mscale,
    dtype,
    quant,
):
  """Runs MLA with pre-normalization, a residual add and post-normalization.

  Args:
    inputs: Per-microbatch activations.
    weights: `((pre_attn_scale, post_attn_scale), attention_weights)`.
    decoder_positions: Per-microbatch positions.
    decoder_segment_ids: Per-microbatch segment ids.
    mesh: Device mesh used for the sharding constraint on the attention output.
    model_mode: Forwarded to `mla`.
    attn_op: Attention callable forwarded to `mla`.
    normalization_layer_epsilon: Epsilon for every RMS norm in this layer.
    kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, rope_max_timescale,
      num_query_heads, max_position_embeddings,
      original_max_position_embeddings, beta_fast, beta_slow, rope_factor,
      mscale: MLA / YaRN hyper-parameters forwarded to `mla`.
    dtype: Activation dtype.
    quant: Quantization configuration.

  Returns:
    A list with one `(residual_output, post_norm_output)` pair per microbatch.
  """
  norm_scales, attn_ws = weights
  pre_attn_scale, post_attn_scale = norm_scales

  def attend(microbatch):
    x, dseg, dpos = microbatch
    normed = rms_norm(
        x,
        pre_attn_scale,
        epsilon=normalization_layer_epsilon,
        dtype=dtype,
    )
    attn_out = mla(
        normed,
        dpos,
        dseg,
        attn_ws,
        model_mode=model_mode,
        epsilon=normalization_layer_epsilon,
        kv_lora_rank=kv_lora_rank,
        kv_norm_epsilon=normalization_layer_epsilon,
        qk_nope_head_dim=qk_nope_head_dim,
        qk_rope_head_dim=qk_rope_head_dim,
        rope_theta=rope_max_timescale,
        num_query_heads=num_query_heads,
        max_position_embeddings=max_position_embeddings,
        original_max_position_embeddings=original_max_position_embeddings,
        beta_fast=beta_fast,
        beta_slow=beta_slow,
        rope_factor=rope_factor,
        dtype=dtype,
        mscale=mscale,
        attention_op_fn=attn_op,
        quant=quant,
    )
    residual = x + with_data_parallel_constraint(attn_out, mesh)
    post_normed = rms_norm(
        residual,
        post_attn_scale,
        epsilon=normalization_layer_epsilon,
        dtype=dtype,
    )
    return residual, post_normed

  microbatches = list(zip(inputs, decoder_segment_ids, decoder_positions))
  return staggered_call(attend, microbatches)
def mla(
    inputs,
    positions,
    segment_ids,
    weights,
    *,
    model_mode,
    epsilon,
    kv_lora_rank,
    kv_norm_epsilon,
    qk_nope_head_dim,
    qk_rope_head_dim,
    num_query_heads,
    rope_theta,
    max_position_embeddings,
    original_max_position_embeddings,
    beta_fast,
    beta_slow,
    rope_factor,
    mscale,
    attention_op_fn,
    dtype,
    quant,
):
  """Performs multi-head latent attention (MLA).

  Projects the inputs to queries, keys and values, runs `attention_op_fn` and
  applies the output projection. Intermediate results are tagged with
  checkpoint names so a rematerialization policy can refer to them.

  Args:
    inputs: Normalized activations.
    positions: Token positions used for the rotary embedding.
    segment_ids: Segment ids passed to the attention op.
    weights: Tuple `(wq_a, wq_b, q_norm_scale, wkv_a, wkv_b, kv_norm_scale, out)`.
    model_mode: Passed to the attention op.
    epsilon: Epsilon of the query low-rank RMS norm.
    kv_lora_rank: Size of the compressed KV latent.
    kv_norm_epsilon: Epsilon of the KV latent RMS norm.
    qk_nope_head_dim: Per-head size of the non-rotary query/key part.
    qk_rope_head_dim: Per-head size of the rotary query/key part.
    num_query_heads: Number of query heads.
    rope_theta, max_position_embeddings, original_max_position_embeddings,
      beta_fast, beta_slow, rope_factor, mscale: YaRN parameters.
    attention_op_fn: Callable computing attention from query, key and value.
    dtype: Activation dtype.
    quant: Quantization configuration.

  Returns:
    The attention output after the output projection.
  """
  (
      wq_a_weights,
      wq_b_weights,
      q_norm_scale_weights,
      wkv_a_weights,
      wkv_b_weights,
      kv_norm_scale_weights,
      out_weights,
  ) = weights
  tag = jax.ad_checkpoint.checkpoint_name

  # Keyword arguments accepted by both the query and the KV projection.
  shared_kwargs = {
      "qk_rope_head_dim": qk_rope_head_dim,
      "rope_theta": rope_theta,
      "max_position_embeddings": max_position_embeddings,
      "original_max_position_embeddings": original_max_position_embeddings,
      "beta_fast": beta_fast,
      "beta_slow": beta_slow,
      "rope_factor": rope_factor,
      "dtype": dtype,
      "qk_nope_head_dim": qk_nope_head_dim,
      "quant": quant,
  }

  query = query_projection(
      inputs,
      positions,
      wq_a_weights,
      wq_b_weights,
      q_norm_scale_weights,
      epsilon=epsilon,
      mscale=mscale,
      **shared_kwargs,
  )
  query = tag(query, "query_proj")

  key, value = kv_projection(
      inputs,
      positions,
      wkv_a_weights,
      wkv_b_weights,
      kv_norm_scale_weights,
      kv_lora_rank=kv_lora_rank,
      kv_norm_epsilon=kv_norm_epsilon,
      num_query_heads=num_query_heads,
      **shared_kwargs,
  )
  key = tag(key, "key_proj")
  value = tag(value, "value_proj")

  attn_out = attention_op_fn(
      query,
      key,
      value,
      segment_ids,
      model_mode,
      cached_values=[None, None],
  )
  attn_out = tag(attn_out, "attention_out")

  # The output projection contracts the last two axes of the attention output.
  projected = dot(attn_out, out_weights, quant=quant, axes=2)
  return tag(projected, "out_proj")
def query_projection(
    inputs_q,
    inputs_positions,
    wq_a_weights,
    wq_b_weights,
    q_norm_scale_weights,
    *,
    epsilon,
    qk_nope_head_dim,
    qk_rope_head_dim,
    rope_theta,
    max_position_embeddings,
    original_max_position_embeddings,
    beta_fast,
    beta_slow,
    rope_factor,
    dtype,
    mscale,
    quant,
):
  """Computes the scaled MLA query.

  The input is projected to a low-rank representation, RMS-normalized,
  projected up, and the trailing rotary part of every head is rotated with
  YaRN. The result is multiplied by the softmax scale.

  Args:
    inputs_q: Activations to project.
    inputs_positions: Token positions for the rotary embedding.
    wq_a_weights: Down-projection kernel.
    wq_b_weights: Up-projection kernel.
    q_norm_scale_weights: Scale of the low-rank RMS norm.
    epsilon: Epsilon of the low-rank RMS norm.
    qk_nope_head_dim: Size of the non-rotary part of each head.
    qk_rope_head_dim: Size of the rotary part of each head.
    rope_theta, max_position_embeddings, original_max_position_embeddings,
      beta_fast, beta_slow, rope_factor: YaRN parameters.
    dtype: Activation dtype.
    mscale: YaRN magnitude-scaling coefficient used in the softmax scale.
    quant: Quantization configuration.

  Returns:
    The query with the softmax scale already applied.
  """
  # Softmax scaling, corrected when the context window has been extended.
  softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5
  if max_position_embeddings > original_max_position_embeddings:
    m = 0.1 * mscale * math.log(rope_factor) + 1.0
    softmax_scale = softmax_scale * m * m

  # LoRA path: down-project, normalize, up-project.
  latent_q = rms_norm(
      dot(inputs_q, wq_a_weights, quant=quant),
      q_norm_scale_weights,
      epsilon=epsilon,
      dtype=dtype,
  )
  latent_q = jax.ad_checkpoint.checkpoint_name(latent_q, "mla_q")
  q = dot(latent_q, wq_b_weights, quant=quant)

  # Only the rotary part receives the positional rotation.
  q_nope, q_pe = jnp.split(q, [qk_nope_head_dim], axis=-1)
  q_pe = yarn(
      q_pe,
      inputs_positions,
      embedding_dims=qk_rope_head_dim,
      rope_theta=rope_theta,
      max_position_embeddings=max_position_embeddings,
      original_max_position_embeddings=original_max_position_embeddings,
      beta_fast=beta_fast,
      beta_slow=beta_slow,
      rope_factor=rope_factor,
      fprop_dtype=dtype,
  )
  return jnp.concatenate([q_nope, q_pe], axis=-1) * softmax_scale
def kv_projection(
    inputs,
    inputs_positions,
    wkv_a_weights,
    wkv_b_weights,
    kv_norm_scale_weights,
    *,
    kv_lora_rank,
    kv_norm_epsilon,
    qk_rope_head_dim,
    rope_theta,
    max_position_embeddings,
    original_max_position_embeddings,
    beta_fast,
    beta_slow,
    rope_factor,
    dtype,
    qk_nope_head_dim,
    num_query_heads,
    quant,
):
  """Computes MLA keys and values.

  The input is projected down; the first `kv_lora_rank` features form the
  compressed KV latent (RMS-normalized), the rest form a rotary key part that
  gets a head axis and is rotated with YaRN. `get_key_value` then expands both
  into per-head keys and values.

  Args:
    inputs: Activations to project.
    inputs_positions: Token positions for the rotary embedding.
    wkv_a_weights: Down-projection kernel.
    wkv_b_weights: Up-projection kernel.
    kv_norm_scale_weights: Scale of the KV latent RMS norm.
    kv_lora_rank: Size of the compressed KV latent.
    kv_norm_epsilon: Epsilon of the KV latent RMS norm.
    qk_rope_head_dim: Size of the rotary key part.
    rope_theta, max_position_embeddings, original_max_position_embeddings,
      beta_fast, beta_slow, rope_factor: YaRN parameters.
    dtype: Activation dtype.
    qk_nope_head_dim: Size of the non-rotary key part of each head.
    num_query_heads: Number of heads the rotary key is broadcast to.
    quant: Quantization configuration.

  Returns:
    A `(key, value)` pair.
  """
  latent = dot(inputs, wkv_a_weights, quant=quant)
  latent_main, latent_rope = jnp.split(latent, [kv_lora_rank], axis=-1)

  latent_main = rms_norm(
      latent_main,
      kv_norm_scale_weights,
      epsilon=kv_norm_epsilon,
      dtype=dtype,
  )
  latent_main = jax.ad_checkpoint.checkpoint_name(latent_main, "mla_kv")

  # Insert a head axis at position 2 before applying the rotary embedding.
  key_rope = yarn(
      jnp.expand_dims(latent_rope, axis=2),
      inputs_positions,
      embedding_dims=qk_rope_head_dim,
      rope_theta=rope_theta,
      max_position_embeddings=max_position_embeddings,
      original_max_position_embeddings=original_max_position_embeddings,
      beta_fast=beta_fast,
      beta_slow=beta_slow,
      rope_factor=rope_factor,
      fprop_dtype=dtype,
  )

  return get_key_value(
      latent_main,
      key_rope,
      wkv_b_weights,
      qk_nope_head_dim=qk_nope_head_dim,
      num_query_heads=num_query_heads,
      quant=quant,
  )
def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads, quant):
  """Builds keys and values from the compressed KV latent and the rotary key.

  Args:
    low_rank_main: Normalized compressed KV latent.
    key_rope: Rotary key part; it is broadcast to a 4-D shape whose third axis
      has size `num_query_heads`.
    wkv_b_weights: Up-projection kernel producing the non-rotary key and the
      value features.
    qk_nope_head_dim: Number of leading features, on the last axis of the
      up-projection, that belong to the non-rotary key.
    num_query_heads: Number of heads the rotary key is broadcast to.
    quant: Quantization configuration.

  Returns:
    A `(key, value)` pair, where `key` is the non-rotary part concatenated with
    the broadcast rotary part on the last axis.
  """
  kv_out = dot(low_rank_main, wkv_b_weights, quant=quant)

  # The leading features are the non-rotary key, the remainder is the value.
  key_nope, value = jnp.split(kv_out, [qk_nope_head_dim], axis=-1)

  # Share the rotary key across all heads.
  broadcast_shape = (
      key_nope.shape[0],
      key_nope.shape[1],
      num_query_heads,
      key_rope.shape[3],
  )
  shared_rope = jnp.broadcast_to(key_rope, broadcast_shape)

  return jnp.concatenate([key_nope, shared_rope], axis=-1), value
def rms_norm(x, scale, *, epsilon, dtype):
  """Applies RMS normalization over the last axis and multiplies by `scale`.

  The statistics are computed in float32; the normalized value is cast to
  `dtype` before the scale is applied.

  Args:
    x: Input array.
    scale: Scale whose last axis matches the last axis of `x`.
    epsilon: Constant added to the mean square before the reciprocal square
      root.
    dtype: Dtype of the normalized activations.

  Returns:
    The normalized and scaled array.
  """
  x32 = jnp.asarray(x, jnp.float32)
  mean_square = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
  inv_rms = jax.lax.rsqrt(mean_square + epsilon)
  normalized = jnp.asarray(x32 * inv_rms, dtype)
  return jnp.einsum("i...k,...k->i...k", normalized, scale)
def yarn(
    inputs,
    positions,
    *,
    embedding_dims,
    rope_theta,
    max_position_embeddings,
    original_max_position_embeddings,
    beta_fast,
    beta_slow,
    rope_factor,
    fprop_dtype,
):
  """Applies the YaRN rotary embedding to `inputs`.

  Adjacent feature pairs `(2i, 2i + 1)` are rotated by an angle equal to the
  token position times a per-pair frequency. The frequencies blend the base
  RoPE frequencies with the same frequencies divided by `rope_factor`, using a
  linear ramp between two correction dimensions.

  Args:
    inputs: Array whose last axis has size `embedding_dims`.
    positions: Integer positions; indexed as `[B, S]`.
    embedding_dims: Size of the rotary feature axis.
    rope_theta: RoPE base.
    max_position_embeddings: Number of positions in the frequency table.
    original_max_position_embeddings: Context length before extension.
    beta_fast: Rotation count that sets the lower correction dimension.
    beta_slow: Rotation count that sets the upper correction dimension.
    rope_factor: Context extension factor.
    fprop_dtype: Dtype of the returned array.

  Returns:
    The rotated inputs, cast to `fprop_dtype`.
  """
  half_dim = embedding_dims // 2

  # Matrix that maps (x0, x1, x2, x3, ...) to (-x1, x0, -x3, x2, ...).
  lane = jnp.arange(embedding_dims)
  # [1, 0, 3, 2, 5, 4, ...]
  partner = jnp.where(lane % 2 == 0, lane + 1, lane - 1)
  sign = jnp.where(lane % 2 == 0, -1, 1)
  pairwise_swap_and_negate_mask = jnp.eye(embedding_dims, dtype=jnp.int32)[partner] * sign

  # Base frequency of each feature pair (float32 for precision).
  base_freqs = 1.0 / (rope_theta ** (2.0 * jnp.arange(0, half_dim, dtype=jnp.float32) / embedding_dims))

  def correction_dim(num_rotations):
    return (
        embedding_dims
        * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))
        / (2 * math.log(rope_theta))
    )

  low = max(math.floor(correction_dim(beta_fast)), 0)
  high = min(math.ceil(correction_dim(beta_slow)), embedding_dims - 1)
  # Avoid a zero-width ramp.
  ramp_width = high - low if high > low else 0.001
  ramp = (jnp.arange(half_dim, dtype=jnp.float32) - low) / ramp_width
  smooth = 1 - jnp.clip(ramp, 0, 1)

  # The corrected frequency is a weighted mix of the scaled and base values.
  freqs = base_freqs / rope_factor * (1 - smooth) + base_freqs * smooth

  # Angle table for every position: [max_position_embeddings, half_dim].
  steps = jnp.arange(max_position_embeddings, dtype=jnp.float32)
  angle_table = jnp.outer(steps, freqs)

  # Look up the angles of the requested positions, add a head axis and
  # duplicate each angle so both members of a pair share it.
  angles = jnp.take(angle_table, positions, axis=0)  # [B, S, half_dim]
  angles = angles[:, :, jnp.newaxis, :]  # [B, S, 1, half_dim]
  angles = jnp.repeat(angles, 2, axis=-1)  # [B, S, 1, embedding_dims]

  rotated = inputs * jnp.cos(angles) + jnp.matmul(inputs, pairwise_swap_and_negate_mask) * jnp.sin(angles)
  return rotated.astype(fprop_dtype)
def moe(
    inputs,
    weights,
    *,
    mesh,
    num_experts,
    num_experts_per_tok,
    routed_scaling_factor,
    expert_axis_name,
    use_gather_mosaic_kernel,
    config,
    quant,
):
  """Performs dropless MoE with tensor/expert parallelism.

  Args:
    inputs: Sequence of `(residual, normalized)` pairs, one per microbatch.
    weights: `(router, routed_experts, shared_experts)` weights.
    mesh: Device mesh.
    num_experts: Total number of routed experts.
    num_experts_per_tok: Number of experts selected per token.
    routed_scaling_factor: Multiplier applied to the routing weights.
    expert_axis_name: Mesh axis used for expert parallelism.
    use_gather_mosaic_kernel: Forwarded to the routing kernels.
    config: Model config.
    quant: Quantization configuration.

  Returns:
    A list holding, for every microbatch, the residual plus the MoE output of
    the normalized activation.
  """
  residuals, normed = list(zip(*inputs))
  expert_out = process_activations(
      normed,
      weights,
      mesh=mesh,
      num_experts=num_experts,
      num_experts_per_tok=num_experts_per_tok,
      routed_scaling_factor=routed_scaling_factor,
      expert_axis_name=expert_axis_name,
      use_gather_mosaic_kernel=use_gather_mosaic_kernel,
      config=config,
      quant=quant,
  )
  expert_out = with_data_parallel_constraint(expert_out, mesh)
  return [residual + update for residual, update in zip(residuals, expert_out)]
def expert_indices_and_weights(
    gate_logits: jax.Array,
    pre_bias_logits: jax.Array,
    num_experts_per_tok: int,
    routed_scaling_factor: float,
) -> tuple[jax.Array, jax.Array]:
  """Picks the top experts per token and computes their routing weights.

  Experts are ranked by `gate_logits`, but the weights are read from
  `pre_bias_logits`, normalized to sum to one over the selected experts and
  multiplied by `routed_scaling_factor`.

  Args:
    gate_logits: Scores used for selection; experts are on the last axis.
    pre_bias_logits: Scores used for the weights; same shape as `gate_logits`.
    num_experts_per_tok: Number of experts to select per token.
    routed_scaling_factor: Multiplier applied to the normalized weights.

  Returns:
    A `(indices, weights)` pair, each with a last axis of size
    `num_experts_per_tok`.
  """
  indices = jax.lax.top_k(gate_logits, k=num_experts_per_tok)[1]
  selected = jnp.take_along_axis(pre_bias_logits, indices, axis=-1)
  normalized = selected / selected.sum(-1, keepdims=True)
  return indices, routed_scaling_factor * normalized
def expert_selection(
    x,
    routing_kernel,
    routing_bias,
    *,
    num_experts,
    num_experts_per_tok,
    routed_scaling_factor,
    quant,
):
  """Selects experts for each token and counts the tokens sent to each expert.

  Args:
    x: Token activations.
    routing_kernel: Router projection kernel.
    routing_bias: Bias added to the sigmoid scores for selection only.
    num_experts: Total number of experts (length of the returned counts).
    num_experts_per_tok: Number of experts selected per token.
    routed_scaling_factor: Multiplier applied to the routing weights.
    quant: Quantization configuration for the router projection.

  Returns:
    `(selected_experts, weights, group_sizes)`, where `group_sizes[e]` is the
    number of times expert `e` was selected.
  """
  scores = jax.nn.sigmoid(dot(x, routing_kernel, quant=quant))
  # The bias influences which experts are chosen; the weights come from the
  # unbiased scores.
  biased_scores = scores + routing_bias
  selected_experts, weights = expert_indices_and_weights(
      biased_scores,
      scores,
      num_experts_per_tok=num_experts_per_tok,
      routed_scaling_factor=routed_scaling_factor,
  )
  flat_selection = jnp.ravel(selected_experts)
  group_sizes = jnp.bincount(flat_selection, length=num_experts)
  return selected_experts, weights, group_sizes
def route(
    x,
    selected_experts,
    weights,
    group_sizes,
    *,
    expert_axis_name,
    use_gather_mosaic_kernel,
):
  """All-gathers tokens over the expert axis, then routes them locally.

  Args:
    x: Local token activations.
    selected_experts: Local expert indices per token.
    weights: Local routing weights per token.
    group_sizes: Local per-expert token counts.
    expert_axis_name: Mesh axis to gather and sum over.
    use_gather_mosaic_kernel: Forwarded to `sort_activations.route`.

  Returns:
    `(x, selected_experts, weights, group_sizes)`: the routed activations, the
    gathered expert indices, the flattened weights ordered by expert index, and
    the per-expert counts summed over the expert axis.
  """
  gather = functools.partial(jax.lax.all_gather, axis_name=expert_axis_name, tiled=True)

  # Communicate local results across the expert axis.
  x = gather(x)
  weights = gather(weights)
  selected_experts = gather(selected_experts)
  group_sizes = jax.lax.psum(group_sizes, axis_name=expert_axis_name)

  # Put the weights in expert order, then sort the tokens.
  expert_order = jnp.argsort(jnp.ravel(selected_experts))
  weights = jnp.ravel(weights)[expert_order]
  x = sort_activations.route(
      x,
      selected_experts,
      use_gather_mosaic_kernel=use_gather_mosaic_kernel,
  )
  return x, selected_experts, weights, group_sizes
def unroute(
    x,
    selected_experts,
    *,
    expert_axis_name,
    use_gather_mosaic_kernel,
):
  """Undoes `route()`.

  Args:
    x: Expert outputs in routed order.
    selected_experts: Gathered expert indices, as returned by `route()`.
    expert_axis_name: Mesh axis to reduce-scatter over.
    use_gather_mosaic_kernel: Forwarded to `sort_activations.unroute`.

  Returns:
    The unsorted outputs, summed across expert shards and scattered along
    axis 0.
  """
  # Unsort the output.
  unsorted = sort_activations.unroute(
      x,
      selected_experts,
      use_gather_mosaic_kernel=use_gather_mosaic_kernel,
  )
  # Sum across expert shards.
  return jax.lax.psum_scatter(
      unsorted,
      expert_axis_name,
      scatter_dimension=0,
      tiled=True,
  )
def compute(x, w0, w1, wo, group_sizes, weights, *, config, mesh):
  """Processes routed tokens through the gated expert MLP.

  Computes `gmm(silu(gmm(x, w0)) * gmm(x, w1) * weights[:, None], wo)`, where
  `gmm` is a grouped matmul over `group_sizes`.

  Args:
    x: Routed token activations.
    w0: Gating kernel of the routed experts.
    w1: Up-projection kernel of the routed experts.
    wo: Down-projection kernel of the routed experts.
    group_sizes: Number of rows of `x` assigned to each expert.
    weights: Routing weight for each row of `x`.
    config: Model config; selects the grouped-matmul backend and supplies the
      tile sizes.
    mesh: Device mesh, consulted to decide which weight axes need gathering.

  Returns:
    The expert outputs, in the same routed order as `x`.
  """

  def gmm(
      inputs,
      kernel,
      tiling,
      group_sizes,
      preferred_element_type,
      weight_gather_axes,
  ):
    if not config.use_qwix_quantization:
      return tokamax.ragged_dot(
          lhs=inputs,
          rhs=kernel,
          group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
          precision=jax.lax.Precision.DEFAULT,
          preferred_element_type=preferred_element_type,
          implementation="mosaic",
      )
    return megablox.gmm(
        lhs=inputs,
        rhs=kernel,
        group_sizes=group_sizes,
        preferred_element_type=preferred_element_type,
        tiling=tiling,
        use_qwix_quantization=config.use_qwix_quantization,
        use_tokamax_backend=config.use_tokamax_gmm,
        weight_gather_axes=weight_gather_axes,
        qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config)[0],
    )

  gmm_fn = functools.partial(gmm, group_sizes=group_sizes, preferred_element_type=config.dtype)

  # Tile sizes are listed as (forward, dLHS, dRHS) triples of (m, k, n).
  wi_tile_size = (
      config.wi_tile_fwd_batch_seq,  # m (LHS batch)
      config.wi_tile_fwd_embed_dim,  # k (contracting)
      config.wi_tile_fwd_mlp_dim,  # n (RHS batch)
      config.wi_tile_dlhs_batch_seq,  # m (LHS batch)
      config.wi_tile_dlhs_mlp_dim,  # k (contracting)
      config.wi_tile_dlhs_embed_dim,  # n (RHS batch)
      config.wi_tile_drhs_batch_seq,  # Called m in megablox, but this is contracting
      config.wi_tile_drhs_embed_dim,  # Called k in megablox, but this is LHS batch dim
      config.wi_tile_drhs_mlp_dim,  # Called n in megablox, and indeed is the RHS batch dim
  )
  wo_tile_size = (
      config.wo_tile_fwd_batch_seq,  # m (LHS batch)
      config.wo_tile_fwd_mlp_dim,  # k (contracting)
      config.wo_tile_fwd_embed_dim,  # n (RHS batch)
      config.wo_tile_dlhs_batch_seq,  # m (LHS batch)
      config.wo_tile_dlhs_embed_dim,  # k (contracting)
      config.wo_tile_dlhs_mlp_dim,  # n (RHS)
      config.wo_tile_drhs_batch_seq,  # Called m in megablox, but this is contracting
      config.wo_tile_drhs_mlp_dim,  # Called k in megablox, but this is LHS batch dim
      config.wo_tile_drhs_embed_dim,  # Called n in megablox, and indeed is the RHS batch dim
  )

  wi_gather_axes = []
  wo_gather_axes = []
  if config.use_qwix_quantization:
    gating_pspec, linear_pspec = moe_lib.get_batchsplit_init_kernel_axes()
    w0_pspec = nn.logical_to_mesh_axes(gating_pspec)
    wo_pspec = nn.logical_to_mesh_axes(linear_pspec)
    ignored_axes = ("expert", "tensor", "tensor_transpose")

    def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
      """Returns (mesh_axis, tensor_dim) pairs for non-ignored axes of size > 1."""
      if pspec_dim_axes is None:
        return []
      axes = (pspec_dim_axes,) if isinstance(pspec_dim_axes, str) else pspec_dim_axes
      return [
          (ax, tensor_dim_index)
          for ax in axes
          if ax and ax not in ignored_axes and mesh.shape.get(ax, 1) > 1
      ]

    wi_gather_axes = get_active_sharding_axes(w0_pspec[0], 0) + get_active_sharding_axes(w0_pspec[2], 2)
    wo_gather_axes = get_active_sharding_axes(wo_pspec[0], 0) + get_active_sharding_axes(wo_pspec[1], 1)

  wi_gmm = functools.partial(gmm_fn, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
  if config.merge_gating_gmm:
    # One grouped matmul over the concatenated kernels, split afterwards.
    fused = wi_gmm(x, jnp.concatenate([w0, w1], axis=-1))
    layer_w0, layer_w1 = jnp.split(fused, 2, axis=-1)
  else:
    layer_w0 = wi_gmm(x, w0)
    layer_w1 = wi_gmm(x, w1)
  layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
  layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")

  hidden = jax.nn.silu(layer_w0) * layer_w1
  # Apply the routing weights before the down projection.
  hidden = hidden * weights[:, None]

  return gmm_fn(
      hidden,
      wo,
      tiling=wo_tile_size,
      weight_gather_axes=wo_gather_axes,
  )
def route_compute_unroute(
    xs,
    weights,
    *,
    num_experts,
    num_experts_per_tok,
    routed_scaling_factor,
    expert_axis_name,
    use_gather_mosaic_kernel,
    config,
    mesh,
    quant,
):
  """Routes, processes, and unroutes activations.

  The three stages (route + shared expert, expert compute, unroute) are each
  applied to all microbatches through `staggered_call` before the next stage
  starts.

  Args:
    xs: List of microbatch activations; all are assumed to share the shape of
      `xs[0]`.
    weights: `((gate_kernel, gate_bias), routed_expert_weights,
      shared_expert_weights)`.
    num_experts: Total number of routed experts.
    num_experts_per_tok: Number of experts selected per token.
    routed_scaling_factor: Multiplier applied to the routing weights.
    expert_axis_name: Mesh axis used for expert parallelism.
    use_gather_mosaic_kernel: Forwarded to the routing kernels.
    config: Model config forwarded to `compute`.
    mesh: Device mesh forwarded to `compute`.
    quant: Quantization configuration for the dense projections.

  Returns:
    The list of microbatch outputs (routed experts plus shared expert).
  """
  orig_shape = xs[0].shape
  router_ws, routed_ws, shared_ws = weights
  gate_kernel, gate_bias = router_ws
  routed_w0, routed_w1, routed_wo = routed_ws
  shared_w0, shared_w1, shared_wo = shared_ws

  def route_fn(inputs):
    # Shared expert: a dense gated MLP applied to every token.
    gate_act = jax.nn.silu(dot(inputs, shared_w0, quant=quant))
    up_proj = dot(inputs, shared_w1, quant=quant)
    shared_out = dot(gate_act * up_proj, shared_wo, quant=quant)

    # Flatten to (tokens, features) for routing.
    tokens = jnp.reshape(inputs, (-1, inputs.shape[-1]))
    selected_experts, routing_weights, group_sizes = expert_selection(
        tokens,
        gate_kernel,
        gate_bias,
        num_experts=num_experts,
        num_experts_per_tok=num_experts_per_tok,
        routed_scaling_factor=routed_scaling_factor,
        quant=quant,
    )
    routed, selected_experts, routing_weights, group_sizes = route(
        tokens,
        selected_experts,
        routing_weights,
        group_sizes,
        expert_axis_name=expert_axis_name,
        use_gather_mosaic_kernel=use_gather_mosaic_kernel,
    )
    return routed, shared_out, selected_experts, routing_weights, group_sizes

  def compute_fn(inputs):
    routed, shared_out, selected_experts, routing_weights, group_sizes = inputs
    expert_out = compute(
        routed,
        routed_w0,
        routed_w1,
        routed_wo,
        group_sizes,
        routing_weights,
        config=config,
        mesh=mesh,
    )
    return expert_out, shared_out, selected_experts

  def unroute_fn(inputs):
    expert_out, shared_out, selected_experts = inputs
    combined = unroute(
        expert_out,
        selected_experts,
        expert_axis_name=expert_axis_name,
        use_gather_mosaic_kernel=use_gather_mosaic_kernel,
    )
    return jnp.reshape(combined, orig_shape) + shared_out

  for stage in (route_fn, compute_fn, unroute_fn):
    xs = staggered_call(stage, xs)
  return xs
def process_activations(
    xs,
    weights,
    *,
    mesh,
    num_experts,
    num_experts_per_tok,
    routed_scaling_factor,
    expert_axis_name,
    use_gather_mosaic_kernel,
    config,
    quant,
):
  """Runs the MoE block on batch-sharded activations inside a `shard_map`.

  Activations are fully sharded on the batch axis; the routed expert weights
  are tensor/expert sharded.

  Args:
    xs: Sequence of microbatch activations; each is cast to `config.dtype`.
    weights: `(router, routed_experts, shared_experts)` weights.
    mesh: Device mesh.
    num_experts: Total number of routed experts.
    num_experts_per_tok: Number of experts selected per token.
    routed_scaling_factor: Multiplier applied to the routing weights.
    expert_axis_name: Mesh axis used for expert parallelism.
    use_gather_mosaic_kernel: Forwarded to the routing kernels.
    config: Model config.
    quant: Quantization configuration.

  Returns:
    The per-microbatch MoE outputs, sharded like the inputs.
  """
  pspec = jax.sharding.PartitionSpec
  activation_pspec = pspec(
      ("data", "fsdp", "fsdp_transpose", "expert", "context"),
      None,
      None,
  )

  # Partition specs of the routed expert kernels.
  if config.use_qwix_quantization:
    gating_axes, linear_axes = moe_lib.get_batchsplit_init_kernel_axes()
    gating_pspec = nn.logical_to_mesh_axes(gating_axes)
    linear_pspec = nn.logical_to_mesh_axes(linear_axes)
  else:
    gating_pspec = pspec(None, None, expert_axis_name)
    linear_pspec = pspec(None, expert_axis_name, None)

  replicated_matrix = pspec(None, None)
  weight_specs = (
      (replicated_matrix, pspec(None)),
      (gating_pspec, gating_pspec, linear_pspec),
      (replicated_matrix, replicated_matrix, replicated_matrix),
  )

  local_fn = functools.partial(
      route_compute_unroute,
      num_experts=num_experts,
      num_experts_per_tok=num_experts_per_tok,
      routed_scaling_factor=routed_scaling_factor,
      expert_axis_name=expert_axis_name,
      use_gather_mosaic_kernel=use_gather_mosaic_kernel,
      config=config,
      mesh=mesh,
      quant=quant,
  )
  sharded_fn = jax.shard_map(
      local_fn,
      mesh=mesh,
      in_specs=([activation_pspec] * len(xs), weight_specs),
      out_specs=activation_pspec,
      check_vma=False,
  )
  return sharded_fn([x.astype(config.dtype) for x in xs], weights)