Source code for maxtext.layers.quantizations

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Quantization library."""

import functools
import json
import qwix.pallas as qpl
import re
from typing import Tuple, Sequence, Callable
from dataclasses import dataclass

from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2 import aqt_tensor
from aqt.jax.v2.flax import aqt_flax
from aqt.jax.v2 import tiled_dot_general
from aqt.jax.v2 import calibration

import qwix
from qwix._src.core import dot_general_qt
from qwix._src.core import sparsity

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten_with_path, tree_unflatten

from flax.linen import fp8_ops
from flax.linen import initializers as flax_initializers
import flax.linen as nn

from maxtext.common.common_types import DType, Config
from maxtext.inference.kvcache import KVQuant

# Keys of the mixed precision ("intmp") quantization config. The JSON config maps
# layer-path regexes to dicts using these keys; keys a layer omits fall back to the
# DEFAULT entry's values, and then to the built-in defaults.
DEFAULT = "__default__"  # Config entry whose values are the fallback for every layer
_W_BITS = "w_bits"  # Number of bits used to represent weights
_A_BITS = "a_bits"  # Number of bits used to represent activations
_W_SCALE = "w_scale"  # Clipping scale for weights (only applied when < 1.0)
_A_SCALE = "a_scale"  # Clipping scale for activations (only applied when < 1.0)
_TILE_SIZE = "tile_size"  # Subchannel tile size for contraction axes; -1 disables tiling


@dataclass
class Quantization:
  """Base class for quantization configurations.

  Subclasses override `dot_general_cls` and `einsum` to return the quantized
  dot_general / einsum implementations they configure. The base versions are
  placeholders that return None.
  """

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Placeholder for dot_general implementation in subclasses; returns None."""
    del mesh_axes  # unused by the placeholder
    return None

  def einsum(self, dtype: DType = jnp.float32):
    """Placeholder for einsum implementation in subclasses; returns None."""
    del dtype  # unused by the placeholder
    return None
def _tiling_fn(lhs, rhs, dimension_numbers, tile_size):
  """Builds an AQT tiling config that tiles every contraction axis by `tile_size`.

  Args:
    lhs: Unused.
    rhs: Unused.
    dimension_numbers: `((lhs_contracting, rhs_contracting), batch_dims)` in the
      `jax.lax.dot_general` layout; only the contracting axes are read.
    tile_size: Tile size applied to each contraction axis.

  Returns:
    A `tiled_dot_general.Cfg` with one `AxisTiling` per (lhs, rhs) contraction
    axis pair and no tiling of the remaining axes.
  """
  del lhs, rhs
  (lhs_ca, rhs_ca), _ = dimension_numbers

  def _tile(axis):
    return tiled_dot_general.AxisTiling(axis=axis, tile_size=tile_size, tile_count=None)

  # zip() pairs the contraction axes exactly like the lockstep loop would.
  axis_pairs = list(zip(lhs_ca, rhs_ca))
  lhs_tiles = [_tile(lhs_axis) for lhs_axis, _ in axis_pairs]
  rhs_tiles = [_tile(rhs_axis) for _, rhs_axis in axis_pairs]
  return tiled_dot_general.Cfg(
      lhs=tiled_dot_general.TensorTiling(contraction_axes=lhs_tiles, remaining_axes=[]),
      rhs=tiled_dot_general.TensorTiling(contraction_axes=rhs_tiles, remaining_axes=[]),
  )


def _rhs_axis_metadata_wrapper(
    x: jnp.ndarray,
    tile_map,
    no_sharding_axis: Sequence[int],
    mesh_axes: Tuple[str, ...],
    is_tiled: bool,
    replicate_scale: bool = False,
):
  """Attaches logical-partitioning metadata to an AQT right-hand-side tensor.

  Args:
    x: The tensor to annotate.
    tile_map: For tiled tensors, maps each original rank to the list (length 1
      or 2) of ranks it became after tiling; the last of those ranks inherits
      the original rank's mesh axis.
    no_sharding_axis: Ranks whose mesh axis is forced to None.
    mesh_axes: Logical axis names for the untiled tensor.
    is_tiled: Whether `x` has been tiled according to `tile_map`.
    replicate_scale: When True, rank-1 tensors are fully replicated.

  Returns:
    `x` wrapped with `nn.with_logical_partitioning`.
  """
  if replicate_scale and len(x.shape) == 1:
    # Temporarily using the shape to identify the scale.
    # TODO: remove the replication once the 2d sharding quantization
    # works as expected.
    return nn.with_logical_partitioning((lambda: x), tuple(None for _ in mesh_axes))()

  axes = list(mesh_axes)
  if is_tiled:
    # Left-pad with None so every original rank in tile_map has an entry.
    pad = len(tile_map) - len(axes)
    if pad > 0:
      axes = [None] * pad + axes
    tiled_axes = [None] * len(x.shape)
    for orig_rank, new_rank in tile_map.items():
      assert new_rank
      assert len(new_rank) <= 2
      tiled_axes[new_rank[-1]] = axes[orig_rank]
    axes = tiled_axes

  if axes:
    for no_shard_idx in no_sharding_axis:
      if no_shard_idx < len(axes):
        axes[no_shard_idx] = None
  return nn.with_logical_partitioning((lambda: x), axes)()
@dataclass
class AqtQuantization:
  """Configures AQT quantization github.com/google/aqt."""

  # Either a single AQT DotGeneral config, or (mixed precision) a dict mapping
  # layer-path regexes to `[DotGeneral, tile_size]` pairs; its DEFAULT entry is
  # looked up when no regex matches the current module path.
  quant_dg: aqt_config.DotGeneral
  quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN
  replicate_scale: bool = False

  def _get_mixed_precision_cfg(self):
    """Selects the mixed precision entry for the Flax module currently being built.

    The current module path (joined with "/") is tested against every regex key
    with `re.fullmatch`; the last matching entry wins, and the DEFAULT entry is
    used when nothing matches.

    Returns:
      `(quant_dg, is_tiled, tiling_fn)`; tiling is enabled whenever the
      selected tile size is not -1.
    """
    # pylint: disable=protected-access
    module_path = "/".join(nn.module._context.module_stack[-1].path)
    quant_dg, tile_size = None, -1
    for layer_name_re, layer_quant_dg in self.quant_dg.items():
      if re.fullmatch(layer_name_re, module_path):
        quant_dg, tile_size = layer_quant_dg
    if quant_dg is None:
      quant_dg, tile_size = self.quant_dg[DEFAULT]
    if tile_size == -1:
      return quant_dg, False, None
    return quant_dg, True, functools.partial(_tiling_fn, tile_size=tile_size)

  def _get_rhs_axis_metadata_wrapper(
      self, mesh_axes: Tuple[str, ...] = (), is_tiled: bool = False, replicate_scale: bool = False
  ):
    """Returns the rhs sharding-metadata callback for AQT, or None in CONVERT mode."""
    if self.quant_mode != aqt_flax.QuantMode.CONVERT:
      return functools.partial(
          _rhs_axis_metadata_wrapper, mesh_axes=mesh_axes, is_tiled=is_tiled, replicate_scale=replicate_scale
      )
    return None

  def _resolve_quant_cfg(self):
    """Returns `(quant_dg, is_tiled, tiling_fn)` for a single or mixed precision config."""
    if isinstance(self.quant_dg, dict):
      return self._get_mixed_precision_cfg()
    return self.quant_dg, False, None

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns dot_general configured with aqt params.

    The result is `aqt_flax.AqtDotGeneral` partially applied with the resolved
    config: rhs uses `self.quant_mode` and the CALIBRATION_AND_VALUE freezer,
    lhs is not frozen.
    """
    quant_dg, is_tiled, tiling_fn = self._resolve_quant_cfg()
    rhs_metadata_fn = self._get_rhs_axis_metadata_wrapper(mesh_axes, is_tiled, replicate_scale=self.replicate_scale)
    return functools.partial(
        aqt_flax.AqtDotGeneral,
        quant_dg,
        rhs_quant_mode=self.quant_mode,
        lhs_freeze_mode=aqt_flax.FreezerMode.NONE,
        rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
        rhs_axis_metadata_wrapper=rhs_metadata_fn,
        use_legacy_freezer=False,
        tiling_fn=tiling_fn,
    )

  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns einsum configured with aqt params (same settings as `dot_general_cls`)."""
    quant_dg, is_tiled, tiling_fn = self._resolve_quant_cfg()
    rhs_metadata_fn = self._get_rhs_axis_metadata_wrapper(mesh_axes, is_tiled, replicate_scale=self.replicate_scale)
    einsum_module = aqt_flax.AqtEinsum(
        cfg=quant_dg,
        rhs_quant_mode=self.quant_mode,
        lhs_freeze_mode=aqt_flax.FreezerMode.NONE,
        rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
        rhs_axis_metadata_wrapper=rhs_metadata_fn,
        use_legacy_freezer=False,
        tiling_fn=tiling_fn,
    )
    return functools.partial(einsum_module)
@dataclass
class QwixQuantization:
  """Configures Qwix quantization github.com/google/qwix, for training only.

  `dot_general_cls` and `einsum` share one fp8 config: e4m3 for activations and
  weights, e5m2 for both gradients, no tiling, and the calibration methods given
  by the fields below.
  """

  quant_mode = "train"  # needed by external call
  act_calibration_method: str = "absmax"
  weight_calibration_method: str = "absmax"
  bwd_calibration_method: str = "absmax"

  def _get_fp8_full_qwix_config(self) -> dot_general_qt.DotGeneralQtConfig:
    """Returns Qwix dot_general config for fp8_full quantization."""
    forward_qtype = jnp.float8_e4m3fn  # activation (lhs) and weight (rhs)
    gradient_qtype = jnp.float8_e5m2  # activation and weight gradients
    return dot_general_qt.DotGeneralQtConfig(
        lhs_qtype=forward_qtype,
        rhs_qtype=forward_qtype,
        dlhs_grad_qtype=gradient_qtype,
        drhs_grad_qtype=gradient_qtype,
        lhs_calibration_method=self.act_calibration_method,
        rhs_calibration_method=self.weight_calibration_method,
        dlhs_grad_calibration_method=self.bwd_calibration_method,
        drhs_grad_calibration_method=self.bwd_calibration_method,
        tile_size=None,
    )

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns a `QwixDotGeneral` factory bound to the fp8 config; `mesh_axes` is unused."""
    del mesh_axes
    return functools.partial(QwixDotGeneral, config=self._get_fp8_full_qwix_config())

  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns a `QwixEinsum` module bound to the fp8 config; `mesh_axes` is unused."""
    del mesh_axes
    return QwixEinsum(config=self._get_fp8_full_qwix_config())
class QwixDotGeneral(nn.Module):
  """Flax module running `dot_general` through Qwix's `dot_general_qt`.

  Only `lhs`, `rhs`, `dimension_numbers` and `config` reach Qwix; the other
  arguments exist for `jax.lax.dot_general` signature compatibility and are
  ignored.
  """

  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      dimension_numbers: jax.lax.DotDimensionNumbers,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      *,
      out_sharding=None,
  ) -> jax.Array:
    del precision, preferred_element_type, out_sharding  # not forwarded to Qwix
    return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)
class QwixEinsum(nn.Module):
  """Flax module running `jnp.einsum` with Qwix's `dot_general_qt` as the contraction."""

  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      einsum_str: str,
      *operands: jax.Array,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      _dot_general: Callable[..., jax.Array] | None = None,
      out_sharding=None,
  ) -> jax.Array:
    del _dot_general  # always replaced by the Qwix dot_general below
    qt_config = self.config

    def _qwix_dot_general(*args, **kwargs):
      # Only (lhs, rhs, dimension_numbers) are forwarded; extra kwargs from
      # jnp.einsum are dropped.
      del kwargs
      return dot_general_qt.dot_general_qt(*args[:3], qt_config)

    # NOTE(review): jit is disabled around the call, presumably so jnp.einsum
    # does not cache/trace the per-call closure as a static argument - confirm.
    with jax.disable_jit():
      return jnp.einsum(
          einsum_str,
          *operands,
          precision=precision,
          preferred_element_type=preferred_element_type,
          _dot_general=_qwix_dot_general,
          out_sharding=out_sharding,
      )
@dataclass
class Fp8Quantization(Quantization):
  """Configures Fp8 quantization for NVIDIA GPUs"""

  quant_mode = "train"

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns Flax's `nn.Fp8DirectDotGeneralOp` module class; `mesh_axes` is unused."""
    del mesh_axes
    return nn.Fp8DirectDotGeneralOp

  def einsum(self, dtype: DType = jnp.float32):
    """Returns an fp8 einsum module whose computation dtype is `dtype`."""
    return _Fp8EinsumWrapper(dtype=dtype)
class _Fp8EinsumWrapper(nn.Module):
  """Runs `nn.Fp8Einsum` with `rhs` cast to a configurable computation dtype."""

  dtype: DType

  @nn.compact
  def __call__(self, eqn, lhs, rhs, **kwargs):
    # nn.Fp8Einsum takes its compute dtype from rhs and casts lhs to match, so
    # casting rhs here is what selects the computation dtype.
    fp8_einsum = nn.Fp8Einsum(name="fp8_einsum")
    return fp8_einsum(eqn, lhs, rhs.astype(self.dtype), **kwargs)
class Fp8Einsum(nn.Module):
  """An fp8 einsum op.

  Both operands are cast to `dtype` and passed through an e4m3
  quantize-dequantize (`fp8_ops.in_qdq`). They are contracted with
  `fp8_ops.dot_general_with_precision`, and the result goes through
  `fp8_ops.out_qdq` with the e5m2 dtype and the output-gradient scale/history.
  """

  #: size of the amax history.
  amax_history_length: int = 1024
  #: e4m3 variants, e.g., e4m3fn, e4m3fnuz.
  e4m3_dtype: DType = jnp.float8_e4m3fn
  #: e5m2 variants, e.g., e5m2, e5m2fnuz.
  e5m2_dtype: DType = jnp.float8_e5m2
  #: computation dtype.
  dtype: DType = jnp.float32

  def setup(self) -> None:
    """Creates the amax histories and scales for input, kernel and output gradient.

    All six float32 variables live in the "_overwrite_with_gradient" collection.
    Scales have shape (1,) and start at one; histories have shape
    (amax_history_length,) and start at zero.
    """
    collection = "_overwrite_with_gradient"
    init_rng = jax.random.PRNGKey(0)
    history_args = (flax_initializers.zeros_init(), init_rng, (self.amax_history_length,), jnp.float32)
    scale_args = (flax_initializers.ones_init(), init_rng, (1,), jnp.float32)

    self.input_amax_history = self.variable(collection, "input_amax_history", *history_args)
    self.kernel_amax_history = self.variable(collection, "kernel_amax_history", *history_args)
    self.output_grad_amax_history = self.variable(collection, "output_grad_amax_history", *history_args)
    self.input_scale = self.variable(collection, "input_scale", *scale_args)
    self.kernel_scale = self.variable(collection, "kernel_scale", *scale_args)
    self.output_grad_scale = self.variable(collection, "output_grad_scale", *scale_args)

  def __call__(self, eqn, *args, **kwargs):
    """Computes the fp8 einsum of exactly two operands `(x, k)`; extra kwargs are ignored."""
    assert len(args) == 2
    x, k = args
    compute_dtype = self.dtype
    k = jnp.asarray(k, compute_dtype)
    x = jnp.asarray(x, compute_dtype)

    x_qdq = fp8_ops.in_qdq(compute_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value)
    k_qdq = fp8_ops.in_qdq(compute_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value)
    y_qdq = jnp.einsum(eqn, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision)
    return fp8_ops.out_qdq(
        compute_dtype,
        self.e5m2_dtype,
        y_qdq,
        self.output_grad_scale.value,
        self.output_grad_amax_history.value,
    )
@dataclass
class NANOOFp8Quantization(Quantization):
  """Configures NANOO Fp8 quantization for AMD MI300/MI325 GPUs"""

  quant_mode = "train"

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns Flax's `nn.NANOOFp8DotGeneralOp` module class; `mesh_axes` is unused."""
    del mesh_axes
    return nn.NANOOFp8DotGeneralOp
def _get_int8_quant_config(config):
  """Builds the AQT int8 DotGeneral config.

  The forward and dlhs passes are always int8 with int32 accumulation. The
  drhs pass is quantized to int8 with local AQT only when
  `config.quantization_local_shard_count` is non-zero; otherwise it stays
  unquantized.
  """
  shard_count = config.quantization_local_shard_count
  local_drhs = shard_count != 0
  return aqt_config.config_v3(
      fwd_bits=8,
      dlhs_bits=8,
      drhs_bits=8 if local_drhs else None,
      rng_type="jax.uniform",
      dlhs_local_aqt=None,
      drhs_local_aqt=aqt_config.LocalAqt(contraction_axis_shard_count=shard_count) if local_drhs else None,
      fwd_accumulator_dtype=jnp.int32,
      dlhs_accumulator_dtype=jnp.int32,
      drhs_accumulator_dtype=jnp.int32 if local_drhs else None,
  )
[docs] @dataclass(frozen=True) class ConstantBoundConfig: fwd_lhs_bound: float | None = None fwd_rhs_bound: float | None = None dlhs_lhs_bound: float | None = None dlhs_rhs_bound: float | None = None drhs_lhs_bound: float | None = None drhs_rhs_bound: float | None = None
def _build_const_scale_config(
    aqt_dg: aqt_config.DotGeneral,
    cst_bound_config: ConstantBoundConfig,
) -> aqt_config.DotGeneral:
  """Build a constant scale config for AQT dot general.

  For every bound in `cst_bound_config` that is not None, the matching operand
  quantizer (`aqt_dg.<pass>.dg_quantizer.<operand>`) is switched to
  `calibration.ConstantCalibration` with that bound. Bounds that are None leave
  the existing calibration untouched.

  Args:
    aqt_dg: The AQT dot general config; modified in place.
    cst_bound_config: The constant bound config.

  Returns:
    The AQT dot general config with constant scale config.
  """
  for pass_name in ("fwd", "dlhs", "drhs"):
    for operand in ("lhs", "rhs"):
      bound = getattr(cst_bound_config, f"{pass_name}_{operand}_bound")
      # Use `is not None` (not truthiness) for every bound, so an explicit 0.0
      # bound is honored consistently across all six operands.
      if bound is not None:
        operand_quantizer = getattr(getattr(aqt_dg, pass_name).dg_quantizer, operand)
        operand_quantizer.calibration = functools.partial(calibration.ConstantCalibration, bound=bound)
  return aqt_dg
@dataclass
class PerTensorScales:
  """Flags selecting which AQT dot_general operands use a single per-tensor scale.

  Fields are named `<pass>_<operand>` for the passes fwd/dlhs/drhs and the
  operands lhs/rhs; all default to False.
  """

  fwd_lhs: bool = False  # forward pass, left operand
  fwd_rhs: bool = False  # forward pass, right operand
  dlhs_lhs: bool = False  # dlhs pass, left operand
  dlhs_rhs: bool = False  # dlhs pass, right operand
  drhs_lhs: bool = False  # drhs pass, left operand
  drhs_rhs: bool = False  # drhs pass, right operand
def _build_per_tensor_config( aqt_dg: aqt_config.DotGeneral, per_tensor_scales: PerTensorScales, ) -> aqt_config.DotGeneral: """Build a per tensor config for AQT dot general. Args: aqt_dg: The AQT dot general config. per_tensor_scales: The per tensor scales config. Returns: The AQT dot general config with per tensor config. """ if per_tensor_scales.fwd_lhs: aqt_dg.fwd.dg_quantizer.lhs.calib_shared_axes = "per_tensor" if per_tensor_scales.fwd_rhs: aqt_dg.fwd.dg_quantizer.rhs.calib_shared_axes = "per_tensor" if per_tensor_scales.dlhs_lhs: aqt_dg.dlhs.dg_quantizer.lhs.calib_shared_axes = "per_tensor" if per_tensor_scales.dlhs_rhs: aqt_dg.dlhs.dg_quantizer.rhs.calib_shared_axes = "per_tensor" if per_tensor_scales.drhs_lhs: aqt_dg.drhs.dg_quantizer.lhs.calib_shared_axes = "per_tensor" if per_tensor_scales.drhs_rhs: aqt_dg.drhs.dg_quantizer.rhs.calib_shared_axes = "per_tensor" return aqt_dg # fp8 training recipe of dynamic scaling with configurable constant_bound_config for static scaling option def _get_aqt_fp8_default_config(config): """Get aqt for 8-bit floating point quantization configuration.""" aqt_dg = aqt_config.config_v4( fwd_bits="e4m3", dlhs_bits="e5m2", drhs_bits="e5m2", use_dummy_static_bound=False, fwd_accumulator_dtype=jnp.bfloat16, dlhs_accumulator_dtype=jnp.bfloat16, drhs_accumulator_dtype=jnp.bfloat16, dlhs_use_fwd_quant=False, drhs_use_fwd_quant=False, ) constant_bound_config = None if len(config.constant_bound_config) == 6: fwd_lhs_bound, fwd_rhs_bound, dlhs_lhs_bound, dlhs_rhs_bound, drhs_lhs_bound, drhs_rhs_bound = ( config.constant_bound_config ) constant_bound_config = ConstantBoundConfig( fwd_lhs_bound=fwd_lhs_bound, fwd_rhs_bound=fwd_rhs_bound, dlhs_lhs_bound=dlhs_lhs_bound, dlhs_rhs_bound=dlhs_rhs_bound, drhs_lhs_bound=drhs_lhs_bound, drhs_rhs_bound=drhs_rhs_bound, ) aqt_dg = _build_const_scale_config(aqt_dg, constant_bound_config) aqt_config.set_stochastic_rounding( aqt_dg, vjp_lhs_stochastic_rounding=False, vjp_rhs_stochastic_rounding=False, 
implementation="jax.uniform", ) per_tensor_scales = PerTensorScales( fwd_lhs=True, fwd_rhs=True, dlhs_lhs=True, dlhs_rhs=True, drhs_lhs=True, drhs_rhs=True, ) return _build_per_tensor_config(aqt_dg, per_tensor_scales) def _get_aqt_fp8_quant_config(config): """get aqt for 8-bit floating point quantization configuration""" cfg = aqt_config.config_v4(fwd_bits="e4m3", dlhs_bits=None, drhs_bits=None, fwd_accumulator_dtype=jnp.bfloat16) return cfg def _dot_general_make(quant_cfg): """Create quantization configs for input matrices to a matmul""" lhs_bits = quant_cfg[_A_BITS] lhs_scale = quant_cfg[_A_SCALE] rhs_bits = quant_cfg[_W_BITS] rhs_scale = quant_cfg[_W_SCALE] aqt_dg = aqt_config.dot_general_make(lhs_bits=lhs_bits, rhs_bits=rhs_bits) if lhs_scale < 1.0: aqt_dg.fwd.dg_quantizer.lhs.calibration = functools.partial(calibration.AbsMaxCalibration, scale=lhs_scale) if rhs_scale < 1.0: aqt_dg.fwd.dg_quantizer.rhs.calibration = functools.partial(calibration.AbsMaxCalibration, scale=rhs_scale) return aqt_dg def _get_default_mp_config(default=None): default_config = {_W_BITS: None, _A_BITS: None, _W_SCALE: 1.0, _A_SCALE: 1.0, _TILE_SIZE: -1} if default: default_config.update(default) return default_config def _get_mixed_precision_quant_config(mixed_precision_config): """Set quantization params based on user configuration.""" ret_config = {} default_mp_config = _get_default_mp_config(default=mixed_precision_config.get(DEFAULT, None)) for layer_name_re, layer_quantization_config in mixed_precision_config.items(): # Make a copy of default_mp_config to avoid updating original dict quant_config = default_mp_config.copy() # print(f"Mixed precision config: processing # {layer_name_re} - {layer_quantization_config}, default config - {quant_config}") if layer_name_re != DEFAULT: for k in quant_config: quant_config[k] = layer_quantization_config.get(k, default_mp_config[k]) ret_config[layer_name_re] = [_dot_general_make(quant_config), quant_config["tile_size"]] return ret_config def 
_get_quant_config(config): """Set quantization params based on user configuration.""" if not config.quantization or config.quantization == "": return None if config.quantization == "int8": return _get_int8_quant_config(config) if config.quantization == "intmp": assert config.quant_cfg_path, "Must specify quant_cfg for mixed precision quantization" with open(config.quant_cfg_path, "rt", encoding="utf8") as config_file: mixed_precision_config = json.load(config_file) return _get_mixed_precision_quant_config(mixed_precision_config) if config.quantization == "fp8": return "fp8" if config.quantization == "nanoo_fp8": return "nanoo_fp8" if config.quantization == "aqt_fp8": return _get_aqt_fp8_quant_config(config) if config.quantization == "aqt_fp8_full": return _get_aqt_fp8_default_config(config) if config.quantization.startswith("te_"): return config.quantization raise ValueError(f"Invalid value configured for quantization {config.quantization}.")
def in_convert_mode(quant):
  """Returns whether `quant` is in AQT CONVERT mode.

  A falsy `quant` (e.g. None) is returned as-is, mirroring `and` short-circuiting.
  """
  if not quant:
    return quant
  return quant.quant_mode == aqt_flax.QuantMode.CONVERT
def in_serve_mode(quant):
  """Returns whether `quant` is in AQT SERVE mode.

  A falsy `quant` (e.g. None) is returned as-is, mirroring `and` short-circuiting.
  """
  if not quant:
    return quant
  return quant.quant_mode == aqt_flax.QuantMode.SERVE
[docs] def get_quant_mode(quant_mode_str: str = "train"): """Set quant mode.""" if quant_mode_str == "train": return aqt_flax.QuantMode.TRAIN elif quant_mode_str == "serve": return aqt_flax.QuantMode.SERVE elif quant_mode_str == "convert": return aqt_flax.QuantMode.CONVERT raise ValueError(f"Invalid quantization mode {quant_mode_str}.")
def configure_quantization(config: Config, quant_mode_str: str = "train"):
  """Configure quantization based on user config and quant mode.

  Returns:
    A QwixQuantization (batch-split "fp8_full" without manual quantization),
    an Fp8Quantization / NANOOFp8Quantization / TransformerEngineQuantization
    for the matching quantization strings, an AqtQuantization for AQT configs,
    or None when quantization is disabled or handled elsewhere (other
    batch-split setups, or Qwix interception).
  """
  if config.use_batch_split_schedule and config.quantization:
    # The older version of batch-split that fully uses qwix quantization.
    if config.quantization == "fp8_full" and not config.use_manual_quantization:
      return QwixQuantization(
          weight_calibration_method=config.weight_quantization_calibration_method,
          act_calibration_method=config.act_quantization_calibration_method,
          bwd_calibration_method=config.bwd_quantization_calibration_method,
      )
    # The pure JAX version of batch-split that uses manual quantization for dot general.
    return None
  if config.use_qwix_quantization:
    return None

  quant_cfg = _get_quant_config(config)
  if not quant_cfg:
    return None
  if quant_cfg == "fp8":
    return Fp8Quantization()
  if quant_cfg == "nanoo_fp8":
    return NANOOFp8Quantization()
  if isinstance(quant_cfg, str) and quant_cfg.startswith("te_"):
    return TransformerEngineQuantization(config)
  return AqtQuantization(
      quant_dg=quant_cfg,
      quant_mode=get_quant_mode(quant_mode_str),
      replicate_scale=config.replicate_quant_scale or False,
  )
def match_aqt_and_unquantized_param(aqt_params, params):
  """Maps every AQT-quantized leaf to the key path of its unquantized parameter.

  Args:
    aqt_params: Pytree of AQT variables; `aqt_tensor.QTensor` instances are
      treated as leaves.
    params: Pytree of unquantized parameters.

  Returns:
    A pytree with the structure of `aqt_params` whose leaves are key paths into
    `params`.

  Raises:
    ValueError: If an AQT leaf has no matching, not-yet-claimed, non-bias
      parameter.
  """
  aqt_param_flat, aqt_tree_def = jax.tree_util.tree_flatten_with_path(
      aqt_params, is_leaf=lambda x: isinstance(x, aqt_tensor.QTensor)
  )
  # Candidate parameter paths; a path is removed once matched so each param is
  # claimed by at most one AQT leaf.
  candidate_paths = [path for path, _ in jax.tree_util.tree_flatten_with_path(params)[0]]
  param_paths = []
  for aqt_k, _ in aqt_param_flat:
    match_index = None
    for index, k in enumerate(candidate_paths):
      prefix_len = len(k) - 1
      # every quantized parameter has AQT.. as the leaf node
      # AqtDotGeneral and AqtEinsum replace leaf node.
      # Therefore, leaf node should be ignored for path matching
      # Note: Aqt only operates on kernels so don't pop bias parameters.
      # Ref: https://github.com/AI-Hypercomputer/maxtext/compare/main...quantize_r1
      if k[:prefix_len] == aqt_k[:prefix_len] and k[-1].key != "bias":
        match_index = index
        break
    # Track the match explicitly: the loop variable alone cannot distinguish
    # "matched the last candidate" from "matched nothing".
    if match_index is None:
      raise ValueError(
          f"No unquantized parameter matches AQT parameter at path {jax.tree_util.keystr(aqt_k)}."
      )
    param_paths.append(candidate_paths.pop(match_index))
  return jax.tree_util.tree_unflatten(aqt_tree_def, param_paths)
def _get_aqt_key_paths(aqt_vars, params):
  """Returns the key paths in `params` of every parameter that has AQT state in `aqt_vars`."""
  path_tree = match_aqt_and_unquantized_param(aqt_vars, params)
  # Key paths are tuples; treat them as leaves so they are not flattened further.
  return list(jax.tree_util.tree_leaves(path_tree, is_leaf=lambda x: isinstance(x, tuple)))
def remove_quantized_params(params, aqt_vars):
  """Replaces parameters that have AQT-quantized counterparts with empty dicts.

  This drops the unquantized copies to save memory.

  Args:
    params: Pytree of unquantized parameters.
    aqt_vars: Pytree of AQT variables used to find which parameters are quantized.

  Returns:
    A pytree with the structure of `params`, where each quantized parameter's
    value is replaced by `{}`.
  """
  # A set gives O(1) membership per parameter instead of scanning the list of
  # quantized paths for every leaf.
  quantized_param_paths = set(_get_aqt_key_paths(aqt_vars, params))
  tree_flat, tree_struct = tree_flatten_with_path(params)
  leaves = [{} if path in quantized_param_paths else value for path, value in tree_flat]
  return tree_unflatten(tree_struct, leaves)
def configure_kv_quant(config):
  """Returns `KVQuant(config)` when `config.quantize_kvcache` is set, else None."""
  if config.quantize_kvcache:
    return KVQuant(config)
  return None
class NvidaFp8Provider(qwix.QtProvider):
  """Wraps nn.Fp8DirectDotGeneralOp with Qwix's provider interface."""

  def dot_general(self, *args, **kwargs):
    """Uses the Flax fp8 dot_general op when a Qwix rule matches, else `jax.lax.dot_general`."""
    # Only the presence of a rule matters; its contents are not read.
    rule, op_id = self._get_current_rule_and_op_id("dot_general")
    if rule is not None:
      return nn.Fp8DirectDotGeneralOp(name=op_id)(*args, **kwargs)
    return jax.lax.dot_general(*args, **kwargs)

  def einsum(self, *args, **kwargs):
    """Uses `nn.Fp8Einsum` when a Qwix rule matches, else `jnp.einsum`."""
    rule, op_id = self._get_current_rule_and_op_id("einsum")
    if rule is not None:
      return nn.Fp8Einsum(name=op_id)(*args, **kwargs)
    return jnp.einsum(*args, **kwargs)
class NANOOFp8Provider(qwix.QtProvider):
  """Wraps nn.NANOOFp8DotGeneralOp with Qwix's provider interface."""

  def dot_general(self, *args, **kwargs):
    """Uses the Flax NANOO fp8 dot_general op when a Qwix rule matches, else `jax.lax.dot_general`."""
    # Only the presence of a rule matters; its contents are not read.
    rule, op_id = self._get_current_rule_and_op_id("dot_general")
    if rule is not None:
      return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs)
    return jax.lax.dot_general(*args, **kwargs)
def get_fp8_full_qwix_rule_w_sparsity(config: Config):
  """Returns the Qwix rule list for "fp8_full", optionally with weight sparsity.

  A `sparsity.SparsityRule` is attached only when both `weight_sparsity_n` and
  `weight_sparsity_m` are truthy; otherwise the rule carries `sparsity_rule=None`.
  """
  sparsity_rule = None
  if config.weight_sparsity_n and config.weight_sparsity_m:
    sparsity_rule = sparsity.SparsityRule(
        weight_sparsity_n=config.weight_sparsity_n,
        weight_sparsity_m=config.weight_sparsity_m,
        weight_sparsity_update_step=config.weight_sparsity_update_step,
        weight_sparsity_start_step=config.weight_sparsity_start_step,
    )
  fp8_rule = qwix.QtRule(
      module_path="decoder/.*layers.*",
      weight_qtype=jnp.float8_e4m3fn,
      act_qtype=jnp.float8_e4m3fn,
      bwd_qtype=jnp.float8_e5m2,
      weight_calibration_method=config.weight_quantization_calibration_method,
      act_calibration_method=config.act_quantization_calibration_method,
      bwd_calibration_method=config.bwd_quantization_calibration_method,
      additional_qt_config={"sparsity_rule": sparsity_rule},
      op_names=("dot_general", "gmm", "ragged_dot"),
  )
  return [fp8_rule]
[docs] def get_quantization_rule(config: Config): match config.quantization: case "int4": return [ qwix.QtRule( module_path="decoder/.*layers.*", weight_qtype=jnp.int4, act_qtype=jnp.int4, bwd_qtype=jnp.int4, bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, op_names=("dot_general",), ) ] case "int8": return [ qwix.QtRule( module_path="decoder/.*layers.*", weight_qtype=jnp.int8, act_qtype=jnp.int8, bwd_qtype=jnp.int8, bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, op_names=("dot_general",), ) ] case "fp8": return [ qwix.QtRule( module_path="decoder/.*layers.*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e4m3fn, bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, op_names=("dot_general",), ) ] case "fp8_full": return get_fp8_full_qwix_rule_w_sparsity(config) case "fp8_gpu": return [ qwix.QtRule( module_path="decoder/.*layers.*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e4m3fn, bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, op_names=("dot_general",), ) ] case "fp8_nanoo": return [ qwix.QtRule( module_path="decoder/.*layers.*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e4m3fn, bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, op_names=("dot_general",), ) ] case "": return None
[docs] def get_qt_provider(config): """Get quantization rules based on the config.""" match config.quantization: case "int8": return qwix.QtProvider(get_quantization_rule(config)) case "int4": return qwix.QtProvider(get_quantization_rule(config)) case "fp8": return qwix.QtProvider(get_quantization_rule(config)) case "fp8_full": return qwix.QtProvider(get_quantization_rule(config)) case "fp8_gpu": return NvidaFp8Provider(get_quantization_rule(config)) case "fp8_nanoo": return NANOOFp8Provider(get_quantization_rule(config)) return None
def maybe_quantize_model(model, config):
  """Quantize the model if quantization is enabled.

  Applies `qwix.quantize_model` only when Qwix quantization is on, batch split
  is off, and a provider exists for `config.quantization`; otherwise returns
  `model` unchanged.
  """
  # Batch split is not using Qwix's interception feature but manual plumbing.
  if not config.use_qwix_quantization or config.use_batch_split_schedule:
    return model
  provider = get_qt_provider(config)
  return qwix.quantize_model(model, provider) if provider else model
def _cast_reduced_from(arr, reduced_arr): aval = jax.typeof(reduced_arr) # In shard map if aval.sharding.mesh.axis_types[0] == jax.sharding.AxisType.Manual: for axis in aval.mat.reduced: arr = jax.lax.pcast(arr, axis, to="reduced") return arr # Outside shard map return jax.reshard(arr, aval.sharding) def _make_scale_tensor(scale, arr): scale_tensor = jnp.full_like(arr, scale, dtype=jnp.bfloat16) return _cast_reduced_from(scale_tensor, arr) def _get_max_min(target_dtype): if target_dtype in (jnp.int4, jnp.int8): return jnp.iinfo(target_dtype).max, jnp.iinfo(target_dtype).min else: return jnp.finfo(target_dtype).max.astype(jnp.bfloat16), jnp.finfo(target_dtype).min.astype(jnp.bfloat16)
def manual_quantize(tensor, calibration_method, dtype=jnp.float8_e4m3fn):
  """Manually quantizes a tensor based on a fixed calibration method.

  Args:
    tensor: The tensor to quantize.
    calibration_method: A string of the form "fixed,{scale},{max_val}". Only
      `max_val` (the third field) is read. The scale is derived as
      `max_val / max(dtype)` (falling back to 1.0 when that is zero), so the
      middle field is ignored.
    dtype: Target quantized dtype; defaults to float8_e4m3fn.

  Returns:
    A `qpl.QArray` holding the clipped, quantized value and a scale with one
    size-1 dimension per dimension of `tensor`.

  Raises:
    ValueError: If calibration_method is None, does not start with "fixed", or
      does not have exactly three comma-separated fields.
  """
  if calibration_method is None:
    raise ValueError("calibration_method cannot be None for manual quantization")
  if not calibration_method.startswith("fixed"):
    raise ValueError("Only static weight/activation quantization is supported, but got" f" {calibration_method}")
  fields = calibration_method.split(",")
  if len(fields) != 3:
    raise ValueError(f"Unexpected format for weight calibration method: {calibration_method}")

  dtype_max, dtype_min = _get_max_min(dtype)
  max_val = float(fields[2])
  scale = jnp.where(max_val / dtype_max == 0, 1.0, max_val / dtype_max)

  # scale must be converted to a tensor because grad has reduced axes.
  scale_tensor = _make_scale_tensor(scale, tensor)
  lower_bound = _make_scale_tensor(dtype_min, tensor)
  upper_bound = _make_scale_tensor(dtype_max, tensor)
  q_tensor = jnp.clip(tensor / scale_tensor, lower_bound, upper_bound).astype(dtype)

  # The QArray scale must stay fully replicated for the backward pass and Pallas.
  qarray_scale = jnp.full([1] * tensor.ndim, scale, dtype=tensor.dtype)
  return qpl.QArray(qvalue=q_tensor, scale=qarray_scale)
class TransformerEngineQuantization(Quantization):
  """Class for TransformerEngine quantization recipes (selected by a "te_*" name)."""

  def __init__(self, config):
    """Initialize TransformerEngine quantization.

    Args:
      config: Object whose `quantization` attribute names a TransformerEngine
        recipe; it must start with "te_".

    Raises:
      ValueError: If the name does not start with "te_" or is not a known recipe.
    """
    self.quant_mode = "train"
    if not config.quantization.startswith("te_"):
      raise ValueError(f"Invalid TransformerEngine quantization config: {config.quantization}")
    self._recipe = TransformerEngineQuantization._get_recipe(config.quantization)

  def _identity(self):
    """Returns the tuple that defines equality and hashing."""
    return (self.quant_mode, self._recipe)

  def __hash__(self):
    return hash(self._identity())

  def __eq__(self, other):
    if isinstance(other, TransformerEngineQuantization):
      return self._identity() == other._identity()
    return False

  @staticmethod
  def _get_recipe(recipe_name: str):
    """Instantiates the TransformerEngine recipe registered under `recipe_name`.

    Raises:
      ValueError: If `recipe_name` is not a known recipe.
    """
    from transformer_engine.common import recipe  # pylint: disable=import-outside-toplevel # pytype: disable=import-error

    recipe_factories = {
        "te_fp8_delayedscaling": recipe.DelayedScaling,
        "te_fp8_currentscaling": recipe.Float8CurrentScaling,
        "te_mxfp8": recipe.MXFP8BlockScaling,
        "te_nvfp4": recipe.NVFP4BlockScaling,  # pytype: disable=module-attr
        "te_nvfp4_no_rht": functools.partial(recipe.NVFP4BlockScaling, disable_rht=True),  # pytype: disable=module-attr
    }
    factory = recipe_factories.get(recipe_name)
    if factory is None:
      raise ValueError(f"Invalid TransformerEngine recipe: {recipe_name}")
    return factory()

  def get_block_size(self):
    """Get the block size for quantization for recipes that require blocks.

    Returns 32 for MXFP8 block scaling, 128 for NVFP4 block scaling, and 1 when
    the current recipe has no block requirement.
    """
    from transformer_engine.common import recipe  # pylint: disable=import-outside-toplevel # pytype: disable=import-error

    if isinstance(self._recipe, recipe.MXFP8BlockScaling):
      return 32
    if isinstance(self._recipe, recipe.NVFP4BlockScaling):  # pytype: disable=module-attr
      return 128  # TODO(set this to 16 when unfused RHT is supported)
    return 1

  def _wrap(self, f, name=None):
    """Wraps `f` in a Flax linen module that supplies TransformerEngine quantizers.

    The returned module class derives from TransformerEngine's
    `TransformerEngineBase`. It stores no Flax parameters, but its quantizers
    may keep Flax variables in the "_overwrite_with_gradient" collection when
    the recipe requires them. Calling the module calls `f` with its
    `generate_quantizer_set` method prepended as the first argument; that
    method builds a TransformerEngine/JAX quantizer set for this object's recipe.

    Args:
      f: The function to wrap. Its first argument must be
        'generate_quantizer_set'.
      name: Name of the wrapped operation; defaults to `f.__name__`.

    Returns:
      A Flax linen module class that wraps `f`.
    """
    import transformer_engine.jax  # pylint: disable=import-outside-toplevel # pytype: disable=import-error

    quant_recipe = self._recipe
    quantizer_collection = "_overwrite_with_gradient"

    class TEWrapper(transformer_engine.jax.flax.module.TransformerEngineBase):
      """Wrapper module for TransformerEngine quantization."""

      def generate_quantizer_set(self, postfix: str = ""):
        return super().generate_quantizer_set(  # pytype: disable=wrong-keyword-args
            postfix=postfix,
            variable_collection=quantizer_collection,
            quantization_checkpoint_name="quantization",
            fp8_recipe=quant_recipe,
        )

      @nn.compact
      def __call__(self, *args, **kwargs):
        return f(self.generate_quantizer_set, *args, **kwargs)

    TEWrapper.__name__ = f"TEWrapper_{name or f.__name__}"
    return TEWrapper

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns a module class running dot_general through TransformerEngine's `dense`.

    Only dot_generals without batch dimensions are supported; `mesh_axes` is unused.
    """
    import transformer_engine.jax  # pylint: disable=import-outside-toplevel # pytype: disable=import-error

    def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs):
      del kwargs  # not forwarded to TransformerEngine
      contracting_dims, batch_dims = dims
      assert batch_dims == ((), ()), "Batch dimensions must be empty for TransformerEngine dot."
      return transformer_engine.jax.dense.dense(
          x,
          kernel,
          contracting_dims=contracting_dims,
          quantizer_set=generate_quantizer_set(),
      )

    return self._wrap(te_dot_general, "dot_general")

  def einsum(self, dtype: DType = jnp.float32):
    """Not supported for TransformerEngine; always raises ValueError."""
    del dtype
    # quant.einsum is only required for MoE or for inference with KVCache.
    raise ValueError("Einsum is not yet supported for TransformerEngine quantization.")