# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quantization library."""
import functools
import json
import qwix.pallas as qpl
import re
from typing import Tuple, Sequence, Callable
from dataclasses import dataclass
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2 import aqt_tensor
from aqt.jax.v2.flax import aqt_flax
from aqt.jax.v2 import tiled_dot_general
from aqt.jax.v2 import calibration
import qwix
from qwix._src.core import dot_general_qt
from qwix._src.core import sparsity
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten_with_path, tree_unflatten
from flax.linen import fp8_ops
from flax.linen import initializers as flax_initializers
import flax.linen as nn
from maxtext.common.common_types import DType, Config
from maxtext.inference.kvcache import KVQuant
# Keys used in mixed precision quantization configs (one dict of these per layer pattern).
DEFAULT = "__default__" # Key of the fallback config used when no layer pattern matches
_W_BITS = "w_bits" # Number of bits used to represent weights
_A_BITS = "a_bits" # Number of bits used to represent activations
_W_SCALE = "w_scale" # Clipping scale for weights
_A_SCALE = "a_scale" # Clipping scale for activations
_TILE_SIZE = "tile_size" # Tile size for subchannel quantization; -1 disables tiling
[docs]
@dataclass
class Quantization:
  """Base class for quantization configurations.

  Both methods here are placeholders whose bodies consist only of a docstring,
  so they return None until a subclass overrides them.
  """

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Placeholder for the dot_general implementation supplied by subclasses."""

  def einsum(self, dtype: DType = jnp.float32):
    """Placeholder for the einsum implementation supplied by subclasses."""
def _tiling_fn(lhs, rhs, dimension_numbers, tile_size):
  """Build a tiled_dot_general config that tiles every contraction axis.

  Args:
    lhs: Unused; accepted only to satisfy the tiling-function signature.
    rhs: Unused; accepted only to satisfy the tiling-function signature.
    dimension_numbers: A pair whose first element is
      `(lhs_contraction_axes, rhs_contraction_axes)`; the second element is ignored.
    tile_size: Tile size applied to each contraction axis of both operands.

  Returns:
    A `tiled_dot_general.Cfg` holding one `AxisTiling` per contraction axis pair
    and no tiling for the remaining axes.
  """
  del lhs, rhs
  (lhs_contraction, rhs_contraction), _ = dimension_numbers

  def tile(axis):
    # tile_count is left as None, exactly as the config is built for both operands.
    return tiled_dot_general.AxisTiling(axis=axis, tile_size=tile_size, tile_count=None)

  # zip() stops at the shorter sequence, so both operands get the same number of tilings.
  axis_pairs = list(zip(lhs_contraction, rhs_contraction))
  return tiled_dot_general.Cfg(
      lhs=tiled_dot_general.TensorTiling(
          contraction_axes=[tile(lhs_axis) for lhs_axis, _ in axis_pairs],
          remaining_axes=[],
      ),
      rhs=tiled_dot_general.TensorTiling(
          contraction_axes=[tile(rhs_axis) for _, rhs_axis in axis_pairs],
          remaining_axes=[],
      ),
  )
def _rhs_axis_metadata_wrapper(
    x: jnp.ndarray,
    tile_map,
    no_sharding_axis: Sequence[int],
    mesh_axes: Tuple[str, ...],
    is_tiled: bool,
    replicate_scale: bool = False,
):
  """Attach logical partitioning metadata to a right-hand-side value.

  Args:
    x: The value to wrap.
    tile_map: Mapping from an original axis index to the list of axis indices it
      became after tiling; only read when `is_tiled` is true.
    no_sharding_axis: Axis indices whose logical axis name is cleared to None.
    mesh_axes: Logical axis names for the (untiled) value.
    is_tiled: Whether `x` has been tiled and `mesh_axes` must be remapped.
    replicate_scale: When true, rank-1 values get all-None logical axes.

  Returns:
    The result of calling `nn.with_logical_partitioning` on a thunk returning `x`.
  """
  # Temporarily using the shape to identify the scale.
  # TODO: remove the replication once the 2d sharding quantization
  # works as expected.
  if replicate_scale and len(x.shape) == 1:
    return nn.with_logical_partitioning((lambda: x), tuple(None for _ in mesh_axes))()

  axes = list(mesh_axes)
  if is_tiled:
    # tile_map is a mapping between original rank and a list of new, tiled rank.
    missing = len(tile_map) - len(axes)
    if missing > 0:
      # Left-pad so that every original rank in tile_map has an entry.
      axes = [None] * missing + axes
    tiled_axes = [None] * len(x.shape)
    for orig_rank, new_rank in tile_map.items():
      assert new_rank
      assert len(new_rank) <= 2
      # The original axis name goes to the last of the new axes.
      tiled_axes[new_rank[-1]] = axes[orig_rank]
    axes = tiled_axes

  # `axes` is always a list here, so a plain truthiness test means "non-empty".
  if axes:
    for axis_idx in no_sharding_axis:
      if axis_idx < len(axes):
        axes[axis_idx] = None
  return nn.with_logical_partitioning((lambda: x), axes)()
[docs]
@dataclass
class AqtQuantization:
  """Configures AQT quantization github.com/google/aqt.

  `quant_dg` is either a single AQT dot_general config, or (mixed precision) a
  dict mapping a module-path regex or `DEFAULT` to a
  `(dot_general config, tile_size)` pair.
  """

  quant_dg: aqt_config.DotGeneral
  quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN
  # Forwarded unchanged to the rhs axis metadata wrapper.
  replicate_scale: bool = False

  def _get_mixed_precision_cfg(self):
    """Pick the mixed-precision entry for the module currently being built.

    Every pattern in `quant_dg` is full-matched against the current module path;
    when several match, the last one iterated wins. `DEFAULT` is the fallback.

    Returns:
      `(quant_dg, is_tiled, tiling_fn)`; `tiling_fn` is None when the chosen
      tile size is -1.
    """
    # pylint: disable=protected-access
    module_path = "/".join(nn.module._context.module_stack[-1].path)
    quant_dg, tile_size = None, -1
    for pattern, entry in self.quant_dg.items():
      if re.fullmatch(pattern, module_path):
        quant_dg, tile_size = entry
    if quant_dg is None:
      quant_dg, tile_size = self.quant_dg[DEFAULT]
    if tile_size == -1:
      return quant_dg, False, None
    return quant_dg, True, functools.partial(_tiling_fn, tile_size=tile_size)

  def _get_rhs_axis_metadata_wrapper(
      self, mesh_axes: Tuple[str, ...] = (), is_tiled: bool = False, replicate_scale: bool = False
  ):
    """Return the rhs metadata wrapper bound to these options, or None in CONVERT mode."""
    if self.quant_mode == aqt_flax.QuantMode.CONVERT:
      return None
    return functools.partial(
        _rhs_axis_metadata_wrapper, mesh_axes=mesh_axes, is_tiled=is_tiled, replicate_scale=replicate_scale
    )

  def _shared_aqt_args(self, mesh_axes):
    """Resolve the active config and the keyword arguments common to dot_general and einsum."""
    if isinstance(self.quant_dg, dict):
      quant_dg, is_tiled, tiling_fn = self._get_mixed_precision_cfg()
    else:
      quant_dg, is_tiled, tiling_fn = self.quant_dg, False, None
    shared_kwargs = {
        "rhs_quant_mode": self.quant_mode,
        "lhs_freeze_mode": aqt_flax.FreezerMode.NONE,
        "rhs_freeze_mode": aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
        "rhs_axis_metadata_wrapper": self._get_rhs_axis_metadata_wrapper(
            mesh_axes, is_tiled, replicate_scale=self.replicate_scale
        ),
        "use_legacy_freezer": False,
        "tiling_fn": tiling_fn,
    }
    return quant_dg, shared_kwargs

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns dot_general configured with aqt params.

    The result is a partial of `aqt_flax.AqtDotGeneral` with the config bound as
    its first positional argument.
    """
    quant_dg, shared_kwargs = self._shared_aqt_args(mesh_axes)
    return functools.partial(aqt_flax.AqtDotGeneral, quant_dg, **shared_kwargs)

  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns einsum configured with aqt params.

    The result is an argument-less partial wrapped around an `aqt_flax.AqtEinsum`
    instance.
    """
    quant_dg, shared_kwargs = self._shared_aqt_args(mesh_axes)
    return functools.partial(aqt_flax.AqtEinsum(cfg=quant_dg, **shared_kwargs))
[docs]
@dataclass
class QwixQuantization:
  """Configures Qwix quantization github.com/google/qwix, for training only."""

  # Unannotated, so this is a plain class attribute rather than a dataclass field.
  quant_mode = "train"  # needed by external call
  act_calibration_method: str = "absmax"
  weight_calibration_method: str = "absmax"
  bwd_calibration_method: str = "absmax"

  def _get_fp8_full_qwix_config(self) -> dot_general_qt.DotGeneralQtConfig:
    """Returns Qwix dot_general config for fp8_full quantization.

    Forward operands use float8_e4m3fn and both gradients use float8_e5m2; the
    calibration methods come from this object's fields.
    """
    forward_qtype = jnp.float8_e4m3fn
    grad_qtype = jnp.float8_e5m2
    return dot_general_qt.DotGeneralQtConfig(
        lhs_qtype=forward_qtype,  # activation
        rhs_qtype=forward_qtype,  # weight
        dlhs_grad_qtype=grad_qtype,  # activation gradient
        drhs_grad_qtype=grad_qtype,  # weight gradient
        lhs_calibration_method=self.act_calibration_method,
        rhs_calibration_method=self.weight_calibration_method,
        dlhs_grad_calibration_method=self.bwd_calibration_method,
        drhs_grad_calibration_method=self.bwd_calibration_method,
        tile_size=None,
    )

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns Qwix dot_general: a partial of `QwixDotGeneral`; `mesh_axes` is not used."""
    qwix_config = self._get_fp8_full_qwix_config()
    return functools.partial(QwixDotGeneral, config=qwix_config)

  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns Qwix einsum: a `QwixEinsum` instance; `mesh_axes` is not used."""
    qwix_config = self._get_fp8_full_qwix_config()
    return QwixEinsum(config=qwix_config)
[docs]
class QwixDotGeneral(nn.Module):
  """A callable class for Qwix dot_general."""

  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      dimension_numbers: jax.lax.DotDimensionNumbers,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      *,
      out_sharding=None,
  ) -> jax.Array:
    """Run the Qwix quantized dot_general with this module's config.

    `precision`, `preferred_element_type` and `out_sharding` are accepted but
    not forwarded to the quantized op.
    """
    del precision, preferred_element_type, out_sharding
    qwix_config = self.config
    return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, qwix_config)
[docs]
class QwixEinsum(nn.Module):
  """A callable class for Qwix einsum."""

  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      einsum_str: str,
      *operands: jax.Array,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      _dot_general: Callable[..., jax.Array] | None = None,
      out_sharding=None,
  ) -> jax.Array:
    """Evaluate `jnp.einsum` with its dot_general replaced by the Qwix quantized one.

    The caller-supplied `_dot_general` is accepted but never used.
    """
    qwix_config = self.config

    def quantized_dot_general(*args, **kwargs):
      # Only the first three positional arguments are forwarded; the rest,
      # including all keyword arguments, are dropped.
      return dot_general_qt.dot_general_qt(*args[:3], qwix_config)

    einsum_kwargs = {
        "precision": precision,
        "preferred_element_type": preferred_element_type,
        "_dot_general": quantized_dot_general,
        "out_sharding": out_sharding,
    }
    with jax.disable_jit():
      return jnp.einsum(einsum_str, *operands, **einsum_kwargs)
[docs]
@dataclass
class Fp8Quantization(Quantization):
  """Configures Fp8 quantization for NVIDIA GPUs"""

  # Unannotated, so this is a plain class attribute rather than a dataclass field.
  quant_mode = "train"

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns the `nn.Fp8DirectDotGeneralOp` class; `mesh_axes` is not used."""
    del mesh_axes
    return nn.Fp8DirectDotGeneralOp

  def einsum(self, dtype: DType = jnp.float32):
    """Returns an fp8 einsum wrapper that computes in `dtype`."""
    wrapper = _Fp8EinsumWrapper(dtype=dtype)
    return wrapper
class _Fp8EinsumWrapper(nn.Module):
  """Wrapper for nn.Fp8Einsum to handle computation dtype."""

  dtype: DType

  @nn.compact
  def __call__(self, eqn, lhs, rhs, **kwargs):
    """Cast `rhs` to the configured dtype, then delegate to `nn.Fp8Einsum`."""
    # nn.Fp8Einsum determines compute dtype from rhs.
    # We cast rhs to the desired computation dtype.
    # nn.Fp8Einsum will then cast lhs to the same dtype.
    cast_rhs = rhs.astype(self.dtype)
    fp8_einsum = nn.Fp8Einsum(name="fp8_einsum")
    return fp8_einsum(eqn, lhs, cast_rhs, **kwargs)
[docs]
class Fp8Einsum(nn.Module):
  """An fp8 einsum op.

  Holds a scale and an amax history for the input, the kernel and the output
  gradient, all stored in the "_overwrite_with_gradient" variable collection.
  """

  #: size of the amax history.
  amax_history_length: int = 1024
  #: e4m3 variants, e.g., e4m3fn, e4m3fnuz.
  e4m3_dtype: DType = jnp.float8_e4m3fn
  #: e5m2 variants, e.g., e5m2, e5m2fnuz.
  e5m2_dtype: DType = jnp.float8_e5m2
  #: computation dtype.
  dtype: DType = jnp.float32

  def setup(self) -> None:
    """Create the amax-history and scale variables.

    Histories are float32 zeros of length `amax_history_length`; scales are
    float32 ones of shape (1,).
    """
    collection = "_overwrite_with_gradient"
    zeros = flax_initializers.zeros_init()
    ones = flax_initializers.ones_init()

    def amax_history(name):
      return self.variable(
          collection, name, zeros, jax.random.PRNGKey(0), (self.amax_history_length,), jnp.float32
      )

    def scale(name):
      return self.variable(collection, name, ones, jax.random.PRNGKey(0), (1,), jnp.float32)

    self.input_amax_history = amax_history("input_amax_history")
    self.kernel_amax_history = amax_history("kernel_amax_history")
    self.output_grad_amax_history = amax_history("output_grad_amax_history")
    self.input_scale = scale("input_scale")
    self.kernel_scale = scale("kernel_scale")
    self.output_grad_scale = scale("output_grad_scale")

  def __call__(self, eqn, *args, **kwargs):
    """Compute einsum `eqn` over exactly two operands with fp8 quantize-dequantize.

    Both operands are cast to the computation dtype and passed through
    `fp8_ops.in_qdq` with the e4m3 dtype; the einsum result goes through
    `fp8_ops.out_qdq` with the e5m2 dtype. `kwargs` are ignored.
    """
    assert len(args) == 2
    compute_dtype = self.dtype
    kernel = jnp.asarray(args[1], compute_dtype)
    inputs = jnp.asarray(args[0], compute_dtype)

    inputs_qdq = fp8_ops.in_qdq(
        compute_dtype, self.e4m3_dtype, inputs, self.input_scale.value, self.input_amax_history.value
    )
    kernel_qdq = fp8_ops.in_qdq(
        compute_dtype, self.e4m3_dtype, kernel, self.kernel_scale.value, self.kernel_amax_history.value
    )
    product = jnp.einsum(eqn, inputs_qdq, kernel_qdq, _dot_general=fp8_ops.dot_general_with_precision)
    return fp8_ops.out_qdq(
        compute_dtype,
        self.e5m2_dtype,
        product,
        self.output_grad_scale.value,
        self.output_grad_amax_history.value,
    )
[docs]
@dataclass
class NANOOFp8Quantization(Quantization):
  """Configures NANOO Fp8 quantization for AMD MI300/MI325 GPUs"""

  # Unannotated, so this is a plain class attribute rather than a dataclass field.
  quant_mode = "train"

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns the `nn.NANOOFp8DotGeneralOp` class; `mesh_axes` is not used."""
    del mesh_axes
    return nn.NANOOFp8DotGeneralOp
def _get_int8_quant_config(config):
  """Build the AQT int8 config (`aqt_config.config_v3`).

  The forward and dlhs passes always use 8 bits with int32 accumulation. The
  drhs pass is configured (8 bits, int32 accumulator, LocalAqt) only when
  `config.quantization_local_shard_count` is non-zero; otherwise its three
  settings are passed as None.
  """
  shard_count = config.quantization_local_shard_count
  if shard_count != 0:
    drhs_kwargs = {
        "drhs_bits": 8,
        "drhs_accumulator_dtype": jnp.int32,
        "drhs_local_aqt": aqt_config.LocalAqt(contraction_axis_shard_count=shard_count),
    }
  else:
    drhs_kwargs = {"drhs_bits": None, "drhs_accumulator_dtype": None, "drhs_local_aqt": None}
  return aqt_config.config_v3(
      fwd_bits=8,
      dlhs_bits=8,
      rng_type="jax.uniform",
      dlhs_local_aqt=None,
      fwd_accumulator_dtype=jnp.int32,
      dlhs_accumulator_dtype=jnp.int32,
      **drhs_kwargs,
  )
[docs]
@dataclass(frozen=True)
class ConstantBoundConfig:
fwd_lhs_bound: float | None = None
fwd_rhs_bound: float | None = None
dlhs_lhs_bound: float | None = None
dlhs_rhs_bound: float | None = None
drhs_lhs_bound: float | None = None
drhs_rhs_bound: float | None = None
def _build_const_scale_config(
    aqt_dg: aqt_config.DotGeneral,
    cst_bound_config: ConstantBoundConfig,
) -> aqt_config.DotGeneral:
  """Build a constant scale config for AQT dot general.

  For every pass (fwd, dlhs, drhs) and operand (lhs, rhs) whose bound in
  `cst_bound_config` is not None, the matching quantizer's calibration is
  replaced by `calibration.ConstantCalibration` bound to that value.

  Args:
    aqt_dg: The AQT dot general config. It is modified in place.
    cst_bound_config: The constant bound config.

  Returns:
    The same `aqt_dg` object, with constant calibration applied where requested.
  """
  for pass_name in ("fwd", "dlhs", "drhs"):
    for operand in ("lhs", "rhs"):
      bound = getattr(cst_bound_config, f"{pass_name}_{operand}_bound")
      # Compare against None rather than truthiness so that an explicit bound
      # of 0.0 is honoured for every slot (dlhs_lhs previously skipped it).
      if bound is None:
        continue
      quantizer = getattr(getattr(aqt_dg, pass_name).dg_quantizer, operand)
      quantizer.calibration = functools.partial(calibration.ConstantCalibration, bound=bound)
  return aqt_dg
[docs]
@dataclass
class PerTensorScales:
  """Flags selecting per-tensor calibration, one per pass (fwd/dlhs/drhs) and operand (lhs/rhs)."""

  # Forward pass.
  fwd_lhs: bool = False
  fwd_rhs: bool = False
  # dlhs pass.
  dlhs_lhs: bool = False
  dlhs_rhs: bool = False
  # drhs pass.
  drhs_lhs: bool = False
  drhs_rhs: bool = False
def _build_per_tensor_config(
    aqt_dg: aqt_config.DotGeneral,
    per_tensor_scales: PerTensorScales,
) -> aqt_config.DotGeneral:
  """Build a per tensor config for AQT dot general.

  For every pass (fwd, dlhs, drhs) and operand (lhs, rhs) whose flag in
  `per_tensor_scales` is truthy, the matching quantizer's `calib_shared_axes`
  is set to "per_tensor".

  Args:
    aqt_dg: The AQT dot general config. It is modified in place.
    per_tensor_scales: The per tensor scales config.

  Returns:
    The same `aqt_dg` object, with per tensor calibration applied where requested.
  """
  for pass_name in ("fwd", "dlhs", "drhs"):
    for operand in ("lhs", "rhs"):
      if getattr(per_tensor_scales, f"{pass_name}_{operand}"):
        quantizer = getattr(getattr(aqt_dg, pass_name).dg_quantizer, operand)
        quantizer.calib_shared_axes = "per_tensor"
  return aqt_dg
# fp8 training recipe of dynamic scaling with configurable constant_bound_config for static scaling option
def _get_aqt_fp8_default_config(config):
  """Get aqt for 8-bit floating point quantization configuration.

  Starts from `aqt_config.config_v4` with e4m3 forward and e5m2 gradients
  (bfloat16 accumulators). If `config.constant_bound_config` has exactly six
  entries they are applied as constant bounds, in the order fwd_lhs, fwd_rhs,
  dlhs_lhs, dlhs_rhs, drhs_lhs, drhs_rhs. Stochastic rounding is then disabled
  for both gradients and all six quantizers are switched to per-tensor scales.
  """
  aqt_dg = aqt_config.config_v4(
      fwd_bits="e4m3",
      dlhs_bits="e5m2",
      drhs_bits="e5m2",
      use_dummy_static_bound=False,
      fwd_accumulator_dtype=jnp.bfloat16,
      dlhs_accumulator_dtype=jnp.bfloat16,
      drhs_accumulator_dtype=jnp.bfloat16,
      dlhs_use_fwd_quant=False,
      drhs_use_fwd_quant=False,
  )

  bounds = config.constant_bound_config
  if len(bounds) == 6:
    bound_names = (
        "fwd_lhs_bound",
        "fwd_rhs_bound",
        "dlhs_lhs_bound",
        "dlhs_rhs_bound",
        "drhs_lhs_bound",
        "drhs_rhs_bound",
    )
    bound_config = ConstantBoundConfig(**dict(zip(bound_names, bounds)))
    aqt_dg = _build_const_scale_config(aqt_dg, bound_config)

  aqt_config.set_stochastic_rounding(
      aqt_dg,
      vjp_lhs_stochastic_rounding=False,
      vjp_rhs_stochastic_rounding=False,
      implementation="jax.uniform",
  )

  all_per_tensor = PerTensorScales(
      fwd_lhs=True, fwd_rhs=True, dlhs_lhs=True, dlhs_rhs=True, drhs_lhs=True, drhs_rhs=True
  )
  return _build_per_tensor_config(aqt_dg, all_per_tensor)
def _get_aqt_fp8_quant_config(config):
  """Get the AQT config that quantizes only the forward pass to e4m3.

  Both gradient bit widths are passed as None and the forward accumulator is
  bfloat16. `config` is accepted but not read.
  """
  return aqt_config.config_v4(
      fwd_bits="e4m3",
      dlhs_bits=None,
      drhs_bits=None,
      fwd_accumulator_dtype=jnp.bfloat16,
  )
def _dot_general_make(quant_cfg):
  """Create quantization configs for input matrices to a matmul.

  Args:
    quant_cfg: Mapping holding the activation/weight bit widths and clipping
      scales under the `_A_BITS`, `_A_SCALE`, `_W_BITS` and `_W_SCALE` keys.

  Returns:
    An AQT dot_general config where activations are the lhs and weights the
    rhs. A forward quantizer gets `AbsMaxCalibration` with the given scale only
    when that scale is below 1.0.
  """
  act_bits, act_scale = quant_cfg[_A_BITS], quant_cfg[_A_SCALE]
  weight_bits, weight_scale = quant_cfg[_W_BITS], quant_cfg[_W_SCALE]
  aqt_dg = aqt_config.dot_general_make(lhs_bits=act_bits, rhs_bits=weight_bits)
  for operand, clip_scale in (("lhs", act_scale), ("rhs", weight_scale)):
    if clip_scale < 1.0:
      quantizer = getattr(aqt_dg.fwd.dg_quantizer, operand)
      quantizer.calibration = functools.partial(calibration.AbsMaxCalibration, scale=clip_scale)
  return aqt_dg
def _get_default_mp_config(default=None):
  """Return the baseline mixed-precision settings, overlaid with `default` when it is non-empty.

  The baseline has no bit widths (None), clipping scales of 1.0 and a tile
  size of -1.
  """
  merged = {
      _W_BITS: None,
      _A_BITS: None,
      _W_SCALE: 1.0,
      _A_SCALE: 1.0,
      _TILE_SIZE: -1,
  }
  # A falsy `default` (None or empty) contributes nothing.
  merged.update(default or {})
  return merged
def _get_mixed_precision_quant_config(mixed_precision_config):
  """Set quantization params based on user configuration.

  Args:
    mixed_precision_config: Mapping from a layer-name regex (or `DEFAULT`) to a
      dict of per-layer settings.

  Returns:
    A dict with the same keys, each mapped to
    `[AQT dot_general config, tile_size]`. Settings a layer does not specify are
    taken from the default settings.
  """
  defaults = _get_default_mp_config(default=mixed_precision_config.get(DEFAULT, None))
  resolved = {}
  for layer_name_re, layer_settings in mixed_precision_config.items():
    if layer_name_re == DEFAULT:
      # Work on a copy so the shared defaults are never handed out directly.
      quant_config = defaults.copy()
    else:
      quant_config = {key: layer_settings.get(key, fallback) for key, fallback in defaults.items()}
    resolved[layer_name_re] = [_dot_general_make(quant_config), quant_config["tile_size"]]
  return resolved
def _get_quant_config(config):
"""Set quantization params based on user configuration."""
if not config.quantization or config.quantization == "":
return None
if config.quantization == "int8":
return _get_int8_quant_config(config)
if config.quantization == "intmp":
assert config.quant_cfg_path, "Must specify quant_cfg for mixed precision quantization"
with open(config.quant_cfg_path, "rt", encoding="utf8") as config_file:
mixed_precision_config = json.load(config_file)
return _get_mixed_precision_quant_config(mixed_precision_config)
if config.quantization == "fp8":
return "fp8"
if config.quantization == "nanoo_fp8":
return "nanoo_fp8"
if config.quantization == "aqt_fp8":
return _get_aqt_fp8_quant_config(config)
if config.quantization == "aqt_fp8_full":
return _get_aqt_fp8_default_config(config)
if config.quantization.startswith("te_"):
return config.quantization
raise ValueError(f"Invalid value configured for quantization {config.quantization}.")
[docs]
def in_convert_mode(quant):
  """Tell whether `quant` is in AQT CONVERT mode.

  A falsy `quant` is returned unchanged; otherwise the result is the comparison
  of `quant.quant_mode` with `aqt_flax.QuantMode.CONVERT`.
  """
  if not quant:
    return quant
  return quant.quant_mode == aqt_flax.QuantMode.CONVERT
[docs]
def in_serve_mode(quant):
  """Tell whether `quant` is in AQT SERVE mode.

  A falsy `quant` is returned unchanged; otherwise the result is the comparison
  of `quant.quant_mode` with `aqt_flax.QuantMode.SERVE`.
  """
  if not quant:
    return quant
  return quant.quant_mode == aqt_flax.QuantMode.SERVE
[docs]
def get_quant_mode(quant_mode_str: str = "train"):
"""Set quant mode."""
if quant_mode_str == "train":
return aqt_flax.QuantMode.TRAIN
elif quant_mode_str == "serve":
return aqt_flax.QuantMode.SERVE
elif quant_mode_str == "convert":
return aqt_flax.QuantMode.CONVERT
raise ValueError(f"Invalid quantization mode {quant_mode_str}.")
[docs]
def match_aqt_and_unquantized_param(aqt_params, params):
  """Map every AQT (quantized) leaf to the key path of its unquantized parameter.

  Args:
    aqt_params: Pytree of AQT variables; `aqt_tensor.QTensor` instances are
      treated as leaves.
    params: Pytree of unquantized parameters.

  Returns:
    A pytree with the structure of `aqt_params` whose leaves are the key paths
    of the matching entries in `params`.

  Raises:
    ValueError: If some AQT leaf has no matching unquantized parameter.
  """
  aqt_param_flat, aqt_tree_def = jax.tree_util.tree_flatten_with_path(
      aqt_params, is_leaf=lambda x: isinstance(x, aqt_tensor.QTensor)
  )
  param_tree_flat, _ = jax.tree_util.tree_flatten_with_path(params)
  # Original (unquantized) path for each quantized AQT leaf, in AQT leaf order.
  param_paths = []
  for aqt_k, _ in aqt_param_flat:
    for index, (k, _) in enumerate(param_tree_flat):
      # every quantized parameter has AQT.. as the leaf node
      # AqtDotGeneral and AqtEinsum replace leaf node.
      # Therefore, leaf node should be ignored for path matching
      # Note: Aqt only operates on kernels so don't pop bias parameters.
      # Ref: https://github.com/AI-Hypercomputer/maxtext/compare/main...quantize_r1
      parent_depth = len(k) - 1
      if k[:parent_depth] == aqt_k[:parent_depth] and k[-1].key != "bias":
        param_paths.append(k)
        break
    else:
      # Without this guard the loop would fall through with `index` pointing at
      # the last candidate and an unrelated parameter would be dropped below.
      raise ValueError(f"No unquantized parameter matches AQT parameter path {aqt_k}.")
    # since the parameter is already added, we can delete it.
    param_tree_flat.pop(index)
  return jax.tree_util.tree_unflatten(aqt_tree_def, param_paths)
def _get_aqt_key_paths(aqt_vars, params):
  """Generate a list of paths which have aqt state.

  Each returned element is the key path (a tuple) of an unquantized parameter
  that has a matching AQT variable.
  """
  matched_paths = match_aqt_and_unquantized_param(aqt_vars, params)

  def is_key_path(node):
    # Key paths are tuples; stop there instead of descending into them.
    return isinstance(node, tuple)

  leaves, _ = jax.tree_util.tree_flatten(matched_paths, is_leaf=is_key_path)
  return list(leaves)
[docs]
def remove_quantized_params(params, aqt_vars):
  """Replace parameters that have an AQT counterpart with empty dicts to save memory.

  Args:
    params: Pytree of unquantized parameters.
    aqt_vars: Pytree of AQT variables.

  Returns:
    A pytree with the structure of `params` in which every leaf whose key path
    has AQT state is `{}` and all other leaves are unchanged.
  """
  quantized_param_paths = _get_aqt_key_paths(aqt_vars, params)
  leaves_with_path, tree_struct = tree_flatten_with_path(params)
  pruned_leaves = [{} if path in quantized_param_paths else leaf for path, leaf in leaves_with_path]
  return tree_unflatten(tree_struct, pruned_leaves)
[docs]
class NvidaFp8Provider(qwix.QtProvider):
  """Wraps nn.Fp8DirectDotGeneralOp with Qwix's provider interface."""

  def dot_general(self, *args, **kwargs):
    """Use `nn.Fp8DirectDotGeneralOp` when a rule applies, else plain `jax.lax.dot_general`."""
    # Here we only check if the rule is None or not.
    rule, op_id = self._get_current_rule_and_op_id("dot_general")
    op = jax.lax.dot_general if rule is None else nn.Fp8DirectDotGeneralOp(name=op_id)
    return op(*args, **kwargs)

  def einsum(self, *args, **kwargs):
    """Use `nn.Fp8Einsum` when a rule applies, else plain `jnp.einsum`."""
    rule, op_id = self._get_current_rule_and_op_id("einsum")
    op = jnp.einsum if rule is None else nn.Fp8Einsum(name=op_id)
    return op(*args, **kwargs)
[docs]
class NANOOFp8Provider(qwix.QtProvider):
  """Wraps nn.NANOOFp8DotGeneralOp with Qwix's provider interface."""

  def dot_general(self, *args, **kwargs):
    """Use `nn.NANOOFp8DotGeneralOp` when a rule applies, else plain `jax.lax.dot_general`."""
    # Here we only check if the rule is None or not.
    rule, op_id = self._get_current_rule_and_op_id("dot_general")
    op = jax.lax.dot_general if rule is None else nn.NANOOFp8DotGeneralOp(name=op_id)
    return op(*args, **kwargs)
[docs]
def get_fp8_full_qwix_rule_w_sparsity(config: Config):
  """Build the single Qwix rule for "fp8_full" quantization, optionally with weight sparsity.

  A `sparsity.SparsityRule` is attached only when both `config.weight_sparsity_n`
  and `config.weight_sparsity_m` are truthy; otherwise the rule carries None.

  Returns:
    A one-element list holding the `qwix.QtRule`.
  """
  sparsity_requested = config.weight_sparsity_n and config.weight_sparsity_m
  if sparsity_requested:
    sparsity_rule = sparsity.SparsityRule(
        weight_sparsity_n=config.weight_sparsity_n,
        weight_sparsity_m=config.weight_sparsity_m,
        weight_sparsity_update_step=config.weight_sparsity_update_step,
        weight_sparsity_start_step=config.weight_sparsity_start_step,
    )
  else:
    sparsity_rule = None

  rule = qwix.QtRule(
      module_path="decoder/.*layers.*",
      weight_qtype=jnp.float8_e4m3fn,
      act_qtype=jnp.float8_e4m3fn,
      bwd_qtype=jnp.float8_e5m2,
      weight_calibration_method=config.weight_quantization_calibration_method,
      act_calibration_method=config.act_quantization_calibration_method,
      bwd_calibration_method=config.bwd_quantization_calibration_method,
      additional_qt_config={"sparsity_rule": sparsity_rule},
      op_names=("dot_general", "gmm", "ragged_dot"),
  )
  return [rule]
[docs]
def get_quantization_rule(config: Config):
  """Return the Qwix quantization rules for `config.quantization`.

  "fp8_full" is delegated to `get_fp8_full_qwix_rule_w_sparsity`. The modes
  "int4", "int8", "fp8", "fp8_gpu" and "fp8_nanoo" all produce one rule for
  dot_general ops under "decoder/.*layers.*" and differ only in the dtype used
  for weights, activations and gradients.

  Args:
    config: Run configuration; reads `quantization` and
      `quantization_local_shard_count`.

  Returns:
    A one-element list of `qwix.QtRule`, or None for an empty or unrecognized
    quantization mode.
  """
  quantization = config.quantization
  if quantization == "fp8_full":
    return get_fp8_full_qwix_rule_w_sparsity(config)

  qtype_by_mode = {
      "int4": jnp.int4,
      "int8": jnp.int8,
      "fp8": jnp.float8_e4m3fn,
      "fp8_gpu": jnp.float8_e4m3fn,
      "fp8_nanoo": jnp.float8_e4m3fn,
  }
  qtype = qtype_by_mode.get(quantization)
  if qtype is None:
    # Covers "" as well as any mode this function does not know about.
    return None

  # A shard count of 0 used to raise ZeroDivisionError here. It is treated as
  # "no tiling of the weight gradient" instead (tile size None).
  shard_count = config.quantization_local_shard_count
  bwd_weight_grad_tile_size = 1 / shard_count if shard_count else None
  return [
      qwix.QtRule(
          module_path="decoder/.*layers.*",
          weight_qtype=qtype,
          act_qtype=qtype,
          bwd_qtype=qtype,
          bwd_weight_grad_tile_size=bwd_weight_grad_tile_size,
          op_names=("dot_general",),
      )
  ]
[docs]
def get_qt_provider(config):
"""Get quantization rules based on the config."""
match config.quantization:
case "int8":
return qwix.QtProvider(get_quantization_rule(config))
case "int4":
return qwix.QtProvider(get_quantization_rule(config))
case "fp8":
return qwix.QtProvider(get_quantization_rule(config))
case "fp8_full":
return qwix.QtProvider(get_quantization_rule(config))
case "fp8_gpu":
return NvidaFp8Provider(get_quantization_rule(config))
case "fp8_nanoo":
return NANOOFp8Provider(get_quantization_rule(config))
return None
[docs]
def maybe_quantize_model(model, config):
  """Quantize the model if quantization is enabled.

  The model is returned untouched when Qwix quantization is off, when the batch
  split schedule is in use, or when no provider exists for the configured mode.
  """
  # Batch split is not using Qwix's interception feature but manual plumbing
  if not config.use_qwix_quantization or config.use_batch_split_schedule:
    return model
  quantization_provider = get_qt_provider(config)
  if not quantization_provider:
    return model
  return qwix.quantize_model(model, quantization_provider)
def _cast_reduced_from(arr, reduced_arr):
  """Give `arr` the same reduced/sharding placement as `reduced_arr`.

  Only the first axis type of the mesh attached to `reduced_arr` is inspected.
  If it is Manual (taken to mean "inside shard_map"), `arr` is pcast to
  "reduced" over each axis listed in the aval's `mat.reduced`; otherwise `arr`
  is resharded to `reduced_arr`'s sharding.
  """
  aval = jax.typeof(reduced_arr)
  in_shard_map = aval.sharding.mesh.axis_types[0] == jax.sharding.AxisType.Manual
  if not in_shard_map:
    # Outside shard map
    return jax.reshard(arr, aval.sharding)
  # In shard map
  for axis in aval.mat.reduced:
    arr = jax.lax.pcast(arr, axis, to="reduced")
  return arr
def _make_scale_tensor(scale, arr):
  """Return a bfloat16 tensor shaped like `arr`, filled with `scale` and placed like `arr`."""
  filled = jnp.full_like(arr, scale, dtype=jnp.bfloat16)
  return _cast_reduced_from(filled, arr)
def _get_max_min(target_dtype):
  """Return `(max, min)` of the values representable in `target_dtype`.

  For int4 and int8 the limits come from `jnp.iinfo` unchanged; for any other
  dtype they come from `jnp.finfo` and are cast to bfloat16.
  """
  if target_dtype in (jnp.int4, jnp.int8):
    int_limits = jnp.iinfo(target_dtype)
    return int_limits.max, int_limits.min
  float_limits = jnp.finfo(target_dtype)
  return float_limits.max.astype(jnp.bfloat16), float_limits.min.astype(jnp.bfloat16)
[docs]
def manual_quantize(tensor, calibration_method, dtype=jnp.float8_e4m3fn):
  """Manually quantizes a tensor based on a fixed calibration method.

  Args:
    tensor: The tensor to quantize.
    calibration_method: A string starting with "fixed" and made of exactly three
      comma-separated fields. Only the third field is read, as the maximum
      value that maps to the largest representable value of `dtype`; the second
      field is ignored.
    dtype: The target quantized dtype.

  Returns:
    A QArray containing the quantized value and the scale. The scale has the
    rank of `tensor` with every dimension equal to 1.

  Raises:
    ValueError: If calibration_method is None or has an unexpected format.
  """
  if calibration_method is None:
    raise ValueError("calibration_method cannot be None for manual quantization")
  if not calibration_method.startswith("fixed"):
    raise ValueError(f"Only static weight/activation quantization is supported, but got {calibration_method}")
  fields = calibration_method.split(",")
  if len(fields) != 3:
    raise ValueError(f"Unexpected format for weight calibration method: {calibration_method}")

  dtype_max, dtype_min = _get_max_min(dtype)
  scale = float(fields[2]) / dtype_max
  # Guard against dividing by a zero scale.
  scale = jnp.where(scale == 0, 1.0, scale)

  # scale must be converted to a tensor because grad has reduced axes.
  scale_tensor = _make_scale_tensor(scale, tensor)
  lower = _make_scale_tensor(dtype_min, tensor)
  upper = _make_scale_tensor(dtype_max, tensor)
  quantized = jnp.clip(tensor / scale_tensor, lower, upper).astype(dtype)

  # Scale for the QArray.
  # It must stay fully replicated for the backward pass and Pallas.
  broadcast_scale = jnp.full([1] * tensor.ndim, scale, dtype=tensor.dtype)
  return qpl.QArray(qvalue=quantized, scale=broadcast_scale)