# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grouped matrix multiplication operations with custom VJPs."""
# pylint: disable=too-many-positional-arguments
import dataclasses
import functools
from typing import List, Literal, Tuple
import jax
import jax.numpy as jnp
from maxtext.kernels.megablox import backend
from maxtext.layers import quantizations
import qwix
import qwix.pallas as qpl
import tokamax
DRHS_RAGGED_DOT_DIM_NUMS = jax.lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([0], [0]), ([], [])),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[],
)
def gmm(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    group_sizes: jnp.ndarray,
    preferred_element_type: jnp.dtype = jnp.float32,
    tiling: tuple[int, int, int, int, int, int, int, int, int] = (
        128,
        128,
        128,
        128,
        128,
        128,
        128,
        128,
        128,
    ),
    group_offset: jnp.ndarray | None = None,
    existing_out: jnp.ndarray | None = None,
    transpose_rhs: bool = False,
    interpret: bool = False,
    lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
    rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
    use_qwix_quantization: bool = False,
    use_tokamax_backend: bool = False,
    weight_gather_axes: List[Tuple[str, int]] | None = None,
    # TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
    qwix_rule: qwix.QtRule | None = None,
    use_manual_quantization: bool = False,
):
  """Grouped matrix multiplication operation.

  Wraps `_gmm_fwd` / `_gmm_bwd` in a `jax.custom_vjp` so the backward pass can
  reuse (possibly quantized) forward operands.

  Args:
    lhs: Activations, [m, k]; dimension 0 is split into groups by `group_sizes`.
    rhs: Per-group weights, [g, k, n], or [g, n, k] when `transpose_rhs`.
    group_sizes: Group sizes along the rows of `lhs`.
    preferred_element_type: Element type requested for the forward output.
    tiling: Nine tile sizes. On the Megablox backend, [:3] is used for the
      forward gmm, [3:6] for the dlhs gmm and [-3:] for the drhs tgmm. The
      Tokamax backend ignores it.
    group_offset: Forwarded to the backend kernels.
    existing_out: Forwarded to the Megablox forward kernel only; the Tokamax
      path does not use it.
    transpose_rhs: Whether `rhs` is stored as [g, n, k].
    interpret: Forwarded to the Megablox kernels.
    lhs_quantize_dtype: When `use_qwix_quantization` is False, activation qtype
      for a handcrafted absmax `qwix.QtRule`.
    rhs_quantize_dtype: When `use_qwix_quantization` is False, weight qtype for
      a handcrafted absmax `qwix.QtRule`.
    use_qwix_quantization: Take the rule from `qwix_rule`, or from
      `qpl.get_current_rule("gmm")` when `qwix_rule` is not given.
    use_tokamax_backend: Use `tokamax.ragged_dot` instead of Megablox kernels.
    weight_gather_axes: (axis_name, axis_idx) pairs used for quantized
      all-gather of rhs in the forward pass and psum_scatter of drhs in the
      backward pass (Tokamax backend only, under the conditions checked there).
    qwix_rule: Explicit Qwix rule; see `use_qwix_quantization`.
    use_manual_quantization: Quantize rhs via `quantizations.manual_quantize`
      and pass manual-axis annotations to Tokamax.

  Returns:
    The grouped matmul output produced by `_gmm_fwd`.

  Raises:
    ValueError: If Qwix quantization is enabled and the rule is not a QtRule.
  """
  quantization_rule = None
  if use_qwix_quantization:
    # get_current_rule has to be called outside of the _gmm_fwd function.
    quantization_rule = qwix_rule if qwix_rule else qpl.get_current_rule("gmm")
    if quantization_rule and not isinstance(quantization_rule, qwix.QtRule):
      raise ValueError("Expect a QtRule for quantized training.")
  else:
    # Handcraft a rule that matches the AQT's behavior.
    if lhs_quantize_dtype or rhs_quantize_dtype:
      quantization_rule = qwix.QtRule(
          weight_qtype=rhs_quantize_dtype,
          weight_calibration_method="absmax",
          act_qtype=lhs_quantize_dtype,
          act_calibration_method="absmax",
      )

  def _forward_only(*args):
    # Primal function seen by custom_vjp: the forward output without residuals.
    return _gmm_fwd(*args)[0]

  # Argument positions (see the call below): 0 lhs, 1 rhs, 2 group_sizes,
  # 3 preferred_element_type, 4 tiling, 5 group_offset, 6 existing_out,
  # 7 transpose_rhs, 8 interpret, 9 quantization_rule, 10 use_tokamax_backend,
  # 11 weight_gather_axes, 12 use_manual_quantization. Everything except
  # 0, 1, 2, 5 and 6 is non-differentiable and is passed to _gmm_bwd first.
  gmm_fwd_bwd = jax.custom_vjp(_forward_only, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11, 12))
  # The original operand dtypes are bound here because the residuals seen by
  # _gmm_bwd may be quantized.
  gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
  return gmm_fwd_bwd(
      lhs,
      rhs,
      group_sizes,
      preferred_element_type,
      tiling,
      group_offset,
      existing_out,
      transpose_rhs,
      interpret,
      quantization_rule,
      use_tokamax_backend,
      weight_gather_axes,
      use_manual_quantization,
  )
def _gmm_fwd(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    group_sizes: jnp.ndarray,
    preferred_element_type: jnp.dtype = jnp.float32,
    tiling: tuple[int, int, int, int, int, int, int, int, int] = (
        128,
        128,
        128,
        128,
        128,
        128,
        128,
        128,
        128,
    ),
    group_offset: jnp.ndarray | None = None,
    existing_out: jnp.ndarray | None = None,
    transpose_rhs: bool = False,
    interpret: bool = False,
    quantization_rule: qwix.QtRule | None = None,
    use_tokamax_backend: bool = False,
    weight_gather_axes: List[Tuple[str, int]] | None = None,
    use_manual_quantization: bool = False,
) -> tuple[
    jnp.ndarray,
    tuple[
        jnp.ndarray | qpl.QArray,
        jnp.ndarray | qpl.QArray,
        jnp.ndarray,
        jnp.ndarray | None,
    ],
]:
  """Forward function for GMM VJP.

  Quantizes the operands according to `quantization_rule` (if any), then runs
  the grouped matmul on either the Tokamax or the Megablox backend.

  Returns:
    (out, (lhs, rhs, group_sizes, group_offset)). The saved lhs/rhs are the
    possibly-quantized operands. On the Tokamax path rhs also includes any QAG
    all-gather. rhs is saved in the caller's layout ([g, k, n], or [g, n, k]
    when `transpose_rhs`) on both backends.
  """
  if quantization_rule:
    if quantization_rule.act_qtype and not isinstance(lhs, qpl.QArray):
      lhs = qpl.quantize(
          lhs,
          quantization_rule.act_qtype,
          channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [0],
          calibration_method=quantization_rule.act_calibration_method,
      )
    if quantization_rule.weight_qtype and not isinstance(rhs, qpl.QArray):
      if not use_manual_quantization:
        rhs = qpl.quantize(
            rhs,
            quantization_rule.weight_qtype,
            # If only considering the fwd pass, we could also enable channelwise
            # axes for the group axis, i.e., [0, 1 or 2]. However, this makes the
            # bwd pass unable to reuse the scale easily.
            channelwise_axes=([] if quantization_rule.disable_channelwise_axes else ([1] if transpose_rhs else [2])),
            calibration_method=quantization_rule.weight_calibration_method,
        )
      else:
        rhs = quantizations.manual_quantize(
            rhs,
            quantization_rule.weight_calibration_method,
            quantization_rule.weight_qtype,
        )

  if use_tokamax_backend:
    # QAG (quantized all-gather) is only supported under these conditions.
    if (
        quantization_rule
        and quantization_rule.bwd_qtype
        and quantization_rule.weight_calibration_method.startswith("fixed")
        and isinstance(rhs, qpl.QArray)
        and weight_gather_axes
    ):
      for axis_name, axis_idx in weight_gather_axes:
        rhs_qvalue = jax.lax.all_gather(rhs.qvalue, axis_name, axis=axis_idx, tiled=True)
        rhs = dataclasses.replace(rhs, qvalue=rhs_qvalue)
    # ragged_dot assumes (G, K, N), so handle transpose_rhs manually. Only a
    # local copy is transposed: the saved residual must keep the caller's
    # layout, because _gmm_bwd decides whether to swap rhs based on
    # transpose_rhs alone.
    ragged_dot_rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs
    out = tokamax.ragged_dot(
        lhs=lhs,
        rhs=ragged_dot_rhs,
        group_sizes=group_sizes,
        precision=jax.lax.Precision.DEFAULT,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
        implementation="mosaic",
        manual_axis_type=jax.sharding.ManualAxisType(
            varying=frozenset(["data", "fsdp", "expert"])
        ) if use_manual_quantization else None,
    )
  else:
    out = backend.gmm(
        lhs,
        rhs,
        group_sizes,
        preferred_element_type,
        tiling[:3],
        group_offset,
        existing_out,
        transpose_rhs=transpose_rhs,
        interpret=interpret,
    )
  return out, (lhs, rhs, group_sizes, group_offset)
def _gmm_bwd(
    lhs_dtype: jax.typing.DTypeLike,
    rhs_dtype: jax.typing.DTypeLike,
    preferred_element_type: jnp.dtype,
    tiling: tuple[int, int, int, int, int, int, int, int, int],
    transpose_rhs: bool,
    interpret: bool,
    quantization_rule: qwix.QtRule | None,
    use_tokamax_backend: bool,
    weight_gather_axes: List[Tuple[str, int]] | None,
    use_manual_quantization: bool,
    residual: tuple[
        jnp.ndarray | qpl.QArray,
        jnp.ndarray | qpl.QArray,
        jnp.ndarray,
        jnp.ndarray | None,
    ],
    grad: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, None, None, jnp.ndarray]:
  """Backward function for throughput GMM VJP.

  Args:
    lhs_dtype: Dtype of the original (unquantized) lhs; dtype of dlhs.
    rhs_dtype: Dtype of the original (unquantized) rhs; dtype of drhs.
    preferred_element_type: Unused; accepted because it is a nondiff argument.
    tiling: Megablox tile sizes; [3:6] for dlhs, [-3:] for drhs.
    transpose_rhs: Whether rhs was given as [g, n, k].
    interpret: Forwarded to the Megablox kernels.
    quantization_rule: Qwix rule; `bwd_qtype` enables gradient quantization.
    use_tokamax_backend: Use Tokamax ragged dots instead of Megablox kernels.
    weight_gather_axes: Axes the forward pass may have all-gathered rhs over.
    use_manual_quantization: Pass manual-axis annotations to Tokamax.
    residual: (lhs, rhs, group_sizes, group_offset) saved by `_gmm_fwd`.
    grad: Incoming cotangent of the forward output.

  Returns:
    Cotangents for (lhs, rhs, group_sizes, group_offset, existing_out): dlhs,
    drhs, None, None, and `grad` passed through unchanged.
  """
  del preferred_element_type
  lhs, rhs, group_sizes, group_offset = residual
  num_actual_groups = rhs.shape[0]
  # Jargon used here:
  # - lhs: input activation in forward pass, possibly quantized.
  # - rhs: weight in forward pass, possibly quantized.
  # - dout (or grad): the incoming gradient in the backward pass.
  # - dlhs: gradient of the lhs in the backward pass, what we want to compute.
  # - drhs: gradient of the rhs in the backward pass, what we want to compute.
  # - dlhs_dout: the incoming gradient used to calculate dlhs.
  # - drhs_dout: the incoming gradient used to calculate drhs.
  # dlhs_dout and drhs_dout can be different when quantization is enabled.
  dlhs_dout = grad
  drhs_dout = grad
  rhs_is_quantized = isinstance(rhs, qpl.QArray)
  if rhs_is_quantized:  # qvalue: [g, k, n] scale: [1, 1, n]
    # Apply rhs.scale to dlhs_dout to avoid dequantizing or requantizing rhs.
    # We cannot apply the scale to dlhs because axis n will disappear there.
    dlhs_dout *= rhs.scale.astype(grad.dtype).reshape(1, -1)  # [1, n]
    rhs = rhs.qvalue
  if isinstance(lhs, qpl.QArray):  # qvalue: [m, k] scale: [m, 1]
    # Apply lhs.scale to drhs_dout, as axis m will disappear in drhs.
    drhs_dout *= lhs.scale.astype(grad.dtype)
    lhs = lhs.qvalue
  if quantization_rule and quantization_rule.bwd_qtype:
    # Enable backward pass quantization
    dlhs_dout = qpl.quantize(
        dlhs_dout,
        quantization_rule.bwd_qtype,
        channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [0],
        calibration_method=quantization_rule.bwd_calibration_method,
    )
    drhs_dout = qpl.quantize(
        drhs_dout,
        quantization_rule.bwd_qtype,
        channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [1],
        calibration_method=quantization_rule.bwd_calibration_method,
    )
  if use_tokamax_backend:
    # ragged_dot needs the weight as (G, N, K) for dlhs. With transpose_rhs
    # False the saved rhs is (G, K, N) and is swapped here; with transpose_rhs
    # True this code assumes the saved rhs is already (G, N, K).
    dlhs_rhs = rhs
    if not transpose_rhs:
      dlhs_rhs = dlhs_rhs.swapaxes(1, 2)
    dlhs = tokamax.ragged_dot(
        lhs=dlhs_dout,
        rhs=dlhs_rhs,
        group_sizes=group_sizes,
        precision=jax.lax.Precision.DEFAULT,
        preferred_element_type=lhs_dtype,
        group_offset=group_offset,
        implementation="mosaic",
        manual_axis_type=jax.sharding.ManualAxisType(
            varying=frozenset(["data", "fsdp", "expert"])
        ) if use_manual_quantization else None,
    )
    drhs = tokamax.ragged_dot_general(
        lhs=lhs,
        rhs=drhs_dout,
        group_sizes=group_sizes,
        ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
        precision=jax.lax.Precision.DEFAULT,
        preferred_element_type=rhs_dtype,
        group_offset=group_offset,
        implementation="mosaic",
        manual_axis_type=jax.sharding.ManualAxisType(
            varying=frozenset(["expert"]),
            unreduced=frozenset(["data", "fsdp"])
        ) if use_manual_quantization else None,
    )
    # Mirror the forward pass's QAG condition exactly. The forward pass
    # all-gathers rhs only when bwd_qtype is set, the weight calibration is
    # "fixed*" and rhs is a QArray. Scattering in any other case would shrink
    # drhs below the shape of the rhs it is the gradient of.
    if (
        quantization_rule
        and quantization_rule.bwd_qtype
        and weight_gather_axes
        and rhs_is_quantized
        and quantization_rule.weight_calibration_method.startswith("fixed")
    ):
      # Scatter back in reverse order of gather
      for axis_name, axis_idx in reversed(weight_gather_axes):
        drhs = jax.lax.psum_scatter(drhs, axis_name, scatter_dimension=axis_idx, tiled=True)
  else:
    dlhs = backend.gmm(
        dlhs_dout,
        rhs,
        group_sizes,
        lhs_dtype,
        tiling[3:6],
        group_offset,
        transpose_rhs=not transpose_rhs,
        interpret=interpret,
    )
    drhs = backend.tgmm(
        lhs.swapaxes(0, 1),
        drhs_dout,
        group_sizes,
        rhs_dtype,
        tiling[-3:],
        group_offset,
        num_actual_groups,
        interpret=interpret,
    )
  # NOTE: If the rhs transposition is fused into the forward pass we need to
  # return the transpose of the rhs gradient that we calculated above.
  #
  # TODO(tgale, enriqueps, apaske): Fuse this transposition into the tgmm.
  drhs = drhs.swapaxes(1, 2) if transpose_rhs else drhs
  return dlhs, drhs, None, None, grad