maxtext.layers.quantizations module

Quantization library.

class maxtext.layers.quantizations.Quantization

Bases: object

Base class for quantization configurations

dot_general_cls(mesh_axes=())

Placeholder for dot_general implementation in subclasses.

Parameters:

mesh_axes (Tuple[str, ...])

einsum(dtype=jax.numpy.float32)

Placeholder for einsum implementation in subclasses.

Parameters:

dtype (dtype)

class maxtext.layers.quantizations.AqtQuantization(quant_dg, quant_mode=QuantMode.TRAIN, replicate_scale=False)

Bases: object

Configures AQT quantization (github.com/google/aqt).

Parameters:
  • quant_dg (DotGeneral)

  • quant_mode (QuantMode)

  • replicate_scale (bool)

quant_dg: DotGeneral
quant_mode: QuantMode = QuantMode.TRAIN
replicate_scale: bool = False

dot_general_cls(mesh_axes=())

Returns dot_general configured with aqt params.

Parameters:

mesh_axes (Tuple[str, ...])

einsum(mesh_axes=())

Returns einsum configured with aqt params.

Parameters:

mesh_axes (Tuple[str, ...])
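To illustrate what a quantized dot_general does, here is a minimal int8 matmul sketch in plain JAX. It assumes symmetric absmax scaling with one scale per row of lhs and one per column of rhs; it is a generic illustration, not AQT's implementation, and the helpers quantize_int8 and int8_matmul are hypothetical.

```python
import jax
import jax.numpy as jnp


def quantize_int8(x, axis):
  """Symmetric absmax int8 quantization with one scale per slice along `axis`."""
  absmax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
  scale = jnp.where(absmax > 0, absmax / 127.0, 1.0)  # avoid dividing by zero
  q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
  return q, scale


def int8_matmul(lhs, rhs):
  """Quantize both operands, multiply in int8 with int32 accumulation, then rescale."""
  # Scales are reduced over the contracting dimension, so they factor out of the sum.
  q_lhs, s_lhs = quantize_int8(lhs, axis=1)  # s_lhs has shape (M, 1)
  q_rhs, s_rhs = quantize_int8(rhs, axis=0)  # s_rhs has shape (1, N)
  acc = jax.lax.dot_general(
      q_lhs,
      q_rhs,
      dimension_numbers=(((1,), (0,)), ((), ())),
      preferred_element_type=jnp.int32,
  )
  return acc.astype(jnp.float32) * s_lhs * s_rhs


key_l, key_r = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(key_l, (4, 64))
rhs = jax.random.normal(key_r, (64, 8))
approx = int8_matmul(lhs, rhs)
exact = lhs @ rhs
```

Per-row and per-column scales are chosen here because they stay constant along the contracting axis; finer-grained scales along that axis would not factor out of the integer accumulation.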

class maxtext.layers.quantizations.QwixQuantization(act_calibration_method='absmax', weight_calibration_method='absmax', bwd_calibration_method='absmax')

Bases: object

Configures Qwix quantization (github.com/google/qwix), for training only.

Parameters:
  • act_calibration_method (str)

  • weight_calibration_method (str)

  • bwd_calibration_method (str)

quant_mode = 'train'
act_calibration_method: str = 'absmax'
weight_calibration_method: str = 'absmax'
bwd_calibration_method: str = 'absmax'

dot_general_cls(mesh_axes=())

Returns Qwix dot_general.

Parameters:

mesh_axes (Tuple[str, ...])

einsum(mesh_axes=())

Returns Qwix einsum.

Parameters:

mesh_axes (Tuple[str, ...])
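Training-time quantization is commonly implemented as fake quantization (quantize, then immediately dequantize) with a straight-through estimator for the gradient. The sketch below shows that generic technique with per-tensor absmax calibration; it is an assumption about the general approach, not Qwix's API, and fake_quant_absmax is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def fake_quant_absmax(x, bits=8):
  """Quantize-dequantize x with a per-tensor absmax scale and a straight-through gradient."""
  qmax = 2 ** (bits - 1) - 1
  absmax = jnp.max(jnp.abs(x))
  scale = jnp.where(absmax > 0, absmax / qmax, 1.0)
  x_q = jnp.clip(jnp.round(x / scale), -qmax, qmax) * scale
  # The forward value is x_q; stop_gradient makes the backward pass treat rounding as identity.
  return x + jax.lax.stop_gradient(x_q - x)


x = jnp.array([0.1, -0.5, 1.0, 0.33])
y = fake_quant_absmax(x)
grad = jax.grad(lambda v: jnp.sum(fake_quant_absmax(v)))(x)
```

Because the whole rounding error sits inside stop_gradient, the gradient of jnp.sum(y) with respect to x is all ones, which is what lets the unquantized weights keep training.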

class maxtext.layers.quantizations.QwixDotGeneral(config, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A callable class for Qwix dot_general.

Parameters:
  • config (DotGeneralQtConfig)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

config: DotGeneralQtConfig
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None

class maxtext.layers.quantizations.QwixEinsum(config, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

A callable class for Qwix einsum.

Parameters:
  • config (DotGeneralQtConfig)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

config: DotGeneralQtConfig
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None

class maxtext.layers.quantizations.Fp8Quantization

Bases: Quantization

Configures Fp8 quantization for NVIDIA GPUs

quant_mode = 'train'

dot_general_cls(mesh_axes=())

Returns dot_general configured with aqt params.

Parameters:

mesh_axes (Tuple[str, ...])

einsum(dtype=jax.numpy.float32)

Placeholder for einsum implementation in subclasses.

Parameters:

dtype (dtype)
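FP8 recipes typically trade range against precision: e4m3 formats keep an extra mantissa bit, while e5m2 formats keep an extra exponent bit and are commonly used for gradients. The snippet below, which relies only on jnp.finfo, prints the largest finite value and machine epsilon of the FP8 variants named in the Fp8Einsum attributes.

```python
import jax.numpy as jnp

# Range (largest finite value) and precision (machine epsilon) of each FP8 variant.
FP8_VARIANTS = {
    "e4m3fn": jnp.float8_e4m3fn,
    "e4m3fnuz": jnp.float8_e4m3fnuz,
    "e5m2": jnp.float8_e5m2,
    "e5m2fnuz": jnp.float8_e5m2fnuz,
}

for name, dtype in FP8_VARIANTS.items():
  info = jnp.finfo(dtype)
  print(f"{name:9s} max={float(info.max):>8.1f} eps={float(info.eps)}")
```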

class maxtext.layers.quantizations.Fp8Einsum(amax_history_length=1024, e4m3_dtype=jax.numpy.float8_e4m3fn, e5m2_dtype=jax.numpy.float8_e5m2, dtype=jax.numpy.float32, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

An fp8 einsum op.

Parameters:
  • amax_history_length (int)

  • e4m3_dtype (dtype)

  • e5m2_dtype (dtype)

  • dtype (dtype)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

amax_history_length: int = 1024

Size of the amax history.

e4m3_dtype

The e4m3 variant to use, e.g., e4m3fn or e4m3fnuz.

Defaults to jax.numpy.float8_e4m3fn.

e5m2_dtype

The e5m2 variant to use, e.g., e5m2 or e5m2fnuz.

Defaults to jax.numpy.float8_e5m2.

dtype

Computation dtype.

Defaults to jax.numpy.float32.

setup()

Initializes input_amax_history, kernel_amax_history, output_grad_amax_history, input_scale, kernel_scale, and output_grad_scale.

Return type:

None
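The amax histories and scales initialized in setup() suggest a delayed-scaling scheme: each step's scale is derived from absolute maxima recorded on earlier steps. The sketch below shows that technique for a single tensor, assuming scale = max(history) / fp8_max and a rolling window of amax_history_length entries; the exact formula inside Fp8Einsum may differ, and all helper names are hypothetical.

```python
import jax.numpy as jnp


def record_amax(history, x):
  """Shift the rolling amax window by one slot and store max|x| at index 0."""
  history = jnp.roll(history, 1)
  return history.at[0].set(jnp.max(jnp.abs(x)).astype(history.dtype))


def scale_from_history(history, dtype=jnp.float8_e4m3fn):
  """Delayed scaling: map the largest amax in the window onto the fp8 maximum."""
  fp8_max = float(jnp.finfo(dtype).max)
  amax = jnp.max(history)
  return jnp.where(amax > 0, amax / fp8_max, 1.0)  # 1.0 while the window is empty


def quantize_fp8(x, scale, dtype=jnp.float8_e4m3fn):
  """Scale into the fp8 range, clip (e4m3fn has no infinity), and cast."""
  fp8_max = float(jnp.finfo(dtype).max)
  return jnp.clip(x / scale, -fp8_max, fp8_max).astype(dtype)


def dequantize_fp8(q, scale):
  return q.astype(jnp.float32) * scale


amax_history_length = 1024  # mirrors the Fp8Einsum default
history = jnp.zeros((amax_history_length,), jnp.float32)
x = jnp.array([0.5, -3.0, 12.0, 0.001], jnp.float32)

# Step 1: the window is empty, so the scale is 1.0; x's amax is recorded afterwards.
scale_1 = scale_from_history(history)
q_1 = quantize_fp8(x, scale_1)
history = record_amax(history, x)

# Step 2: the scale now comes from the amax observed in step 1.
scale_2 = scale_from_history(history)
x_restored = dequantize_fp8(quantize_fp8(x, scale_2), scale_2)
```

Deriving the scale from past steps avoids an extra reduction over the current tensor before the matmul, at the cost of clipping if a new value exceeds everything in the window.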

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None

class maxtext.layers.quantizations.NANOOFp8Quantization

Bases: Quantization

Configures NANOO Fp8 quantization for AMD MI300/MI325 GPUs

quant_mode = 'train'

dot_general_cls(mesh_axes=())

Returns dot_general configured with aqt params.

Parameters:

mesh_axes (Tuple[str, ...])

class maxtext.layers.quantizations.ConstantBoundConfig(fwd_lhs_bound: float | None = None, fwd_rhs_bound: float | None = None, dlhs_lhs_bound: float | None = None, dlhs_rhs_bound: float | None = None, drhs_lhs_bound: float | None = None, drhs_rhs_bound: float | None = None)

Bases: object

Parameters:
  • fwd_lhs_bound (float | None)

  • fwd_rhs_bound (float | None)

  • dlhs_lhs_bound (float | None)

  • dlhs_rhs_bound (float | None)

  • drhs_lhs_bound (float | None)

  • drhs_rhs_bound (float | None)

fwd_lhs_bound: float | None = None
fwd_rhs_bound: float | None = None
dlhs_lhs_bound: float | None = None
dlhs_rhs_bound: float | None = None
drhs_lhs_bound: float | None = None
drhs_rhs_bound: float | None = None
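ConstantBoundConfig and PerTensorScales each configure six operand slots named with fwd_, dlhs_, and drhs_ prefixes and _lhs/_rhs suffixes. Assuming these follow AQT's naming for the forward dot_general and the two backward dot_generals that produce the gradients with respect to lhs and rhs, the sketch below writes out those three matmuls for y = lhs @ rhs and checks the hand-written backward products against jax.vjp. It illustrates the naming only; it does not show how the bounds or scales are applied.

```python
import jax
import jax.numpy as jnp

key_a, key_b, key_g = jax.random.split(jax.random.PRNGKey(0), 3)
lhs = jax.random.normal(key_a, (4, 3))
rhs = jax.random.normal(key_b, (3, 5))
g = jax.random.normal(key_g, (4, 5))  # cotangent (upstream gradient) of the output

# Forward dot_general ("fwd"): its operands are lhs and rhs.
y, vjp_fn = jax.vjp(lambda a, b: a @ b, lhs, rhs)

# The two backward dot_generals produced by reverse-mode autodiff.
dlhs = g @ rhs.T  # gradient with respect to lhs ("dlhs")
drhs = lhs.T @ g  # gradient with respect to rhs ("drhs")

auto_dlhs, auto_drhs = vjp_fn(g)
```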
class maxtext.layers.quantizations.PerTensorScales(fwd_lhs: bool = False, fwd_rhs: bool = False, dlhs_lhs: bool = False, dlhs_rhs: bool = False, drhs_lhs: bool = False, drhs_rhs: bool = False)

Bases: object

Parameters:
  • fwd_lhs (bool)

  • fwd_rhs (bool)

  • dlhs_lhs (bool)

  • dlhs_rhs (bool)

  • drhs_lhs (bool)

  • drhs_rhs (bool)

fwd_lhs: bool = False
fwd_rhs: bool = False
dlhs_lhs: bool = False
dlhs_rhs: bool = False
drhs_lhs: bool = False
drhs_rhs: bool = False

maxtext.layers.quantizations.in_convert_mode(quant)

maxtext.layers.quantizations.in_serve_mode(quant)

maxtext.layers.quantizations.get_quant_mode(quant_mode_str='train')

Returns the quant mode corresponding to quant_mode_str.

Parameters:

quant_mode_str (str)

maxtext.layers.quantizations.configure_quantization(config, quant_mode_str='train')

Configure quantization based on user config and quant mode.

Parameters:
  • config (Any)

  • quant_mode_str (str)

maxtext.layers.quantizations.match_aqt_and_unquantized_param(aqt_params, params)

Matches AQT params with the corresponding unquantized params.

maxtext.layers.quantizations.remove_quantized_params(params, aqt_vars)

Clears the values of params that have corresponding AQT tensors, to reduce memory usage.
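The idea of dropping unquantized copies once quantized tensors exist can be sketched with jax.tree_util. This assumes the quantized parameters are identified by key-path tuples into a nested dict and that cleared values become None; clear_quantized_params is a hypothetical helper, not the function documented above.

```python
import jax
import jax.numpy as jnp


def clear_quantized_params(params, quantized_paths):
  """Replace leaves whose key path is in `quantized_paths` with None."""
  def maybe_clear(path, leaf):
    keys = tuple(entry.key for entry in path)  # assumes plain nested dicts (DictKey entries)
    return None if keys in quantized_paths else leaf
  return jax.tree_util.tree_map_with_path(maybe_clear, params)


params = {"dense": {"kernel": jnp.ones((256, 256)), "bias": jnp.zeros((256,))}}
slim = clear_quantized_params(params, {("dense", "kernel")})
```

Once no other reference to the original kernel array remains, its buffer can be freed, which is where the memory saving comes from.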

maxtext.layers.quantizations.configure_kv_quant(config)

class maxtext.layers.quantizations.NvidaFp8Provider(rules, *, disable_jit=False)

Bases: QtProvider

Wraps nn.Fp8DirectDotGeneralOp with Qwix’s provider interface.

Parameters:
  • rules (Sequence[QuantizationRule])

  • disable_jit (bool)

dot_general(*args, **kwargs)

QT dot_general.

einsum(*args, **kwargs)

QT einsum.

class maxtext.layers.quantizations.NANOOFp8Provider(rules, *, disable_jit=False)

Bases: QtProvider

Parameters:
  • rules (Sequence[QuantizationRule])

  • disable_jit (bool)

dot_general(*args, **kwargs)

QT dot_general.

maxtext.layers.quantizations.get_fp8_full_qwix_rule_w_sparsity(config)

Parameters:

config (Any)

maxtext.layers.quantizations.get_quantization_rule(config)

Parameters:

config (Any)

maxtext.layers.quantizations.get_qt_provider(config)

Returns a quantization provider (QtProvider) built from the quantization rules for the config.

maxtext.layers.quantizations.maybe_quantize_model(model, config)

Quantize the model if quantization is enabled.

maxtext.layers.quantizations.manual_quantize(tensor, calibration_method, dtype=jax.numpy.float8_e4m3fn)

Manually quantizes a tensor based on a fixed calibration method.

Parameters:
  • tensor – The tensor to quantize.

  • calibration_method – A string specifying the calibration method. Expected format is "fixed,{scale},{max_val}".

  • dtype – The dtype of the quantized values (defaults to jax.numpy.float8_e4m3fn).

Returns:

A qwix.QArray containing the quantized value and the scale.

Raises:

ValueError – If calibration_method is None or has an unexpected format.
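The format check described above can be sketched as a small parser for "fixed,{scale},{max_val}" strings that raises ValueError on None or a malformed value. parse_fixed_calibration is a hypothetical helper; it only validates and splits the string and makes no claim about how manual_quantize uses scale and max_val afterwards.

```python
def parse_fixed_calibration(calibration_method):
  """Split a "fixed,{scale},{max_val}" string into two floats, validating the format."""
  if calibration_method is None:
    raise ValueError("calibration_method must not be None")
  parts = calibration_method.split(",")
  if len(parts) != 3 or parts[0] != "fixed":
    raise ValueError(f"Expected 'fixed,{{scale}},{{max_val}}', got {calibration_method!r}")
  try:
    scale, max_val = float(parts[1]), float(parts[2])
  except ValueError as e:
    raise ValueError(f"Non-numeric value in {calibration_method!r}") from e
  return scale, max_val


scale, max_val = parse_fixed_calibration("fixed,1.0,448")
```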

class maxtext.layers.quantizations.TransformerEngineQuantization(config)

Bases: Quantization

Class for TransformerEngine quantization recipes.

get_block_size()

Get the block size for quantization for recipes that require blocks.

If there is no block requirement for the current recipe, returns 1.
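Block-scaled recipes typically compute one scale per contiguous block of elements, which requires the blocked dimension to be a multiple of the block size; a block size of 1 divides every dimension, so such a check never fails. The sketch below illustrates generic block-wise absmax scaling along the last axis; blockwise_absmax_scales and the int8-style qmax default are assumptions, not TransformerEngine's recipe.

```python
import jax.numpy as jnp


def blockwise_absmax_scales(x, block_size, qmax=127.0):
  """One absmax scale per contiguous block of `block_size` elements along the last axis."""
  if x.shape[-1] % block_size != 0:
    raise ValueError("last dimension must be divisible by block_size")
  blocks = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
  absmax = jnp.max(jnp.abs(blocks), axis=-1)
  return jnp.where(absmax > 0, absmax / qmax, 1.0)  # 1.0 for all-zero blocks


x = jnp.arange(8, dtype=jnp.float32).reshape(1, 8)
scales = blockwise_absmax_scales(x, block_size=4)
```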

dot_general_cls(mesh_axes=())

Placeholder for dot_general implementation in subclasses.

Parameters:

mesh_axes (Tuple[str, ...])

einsum(dtype=jax.numpy.float32)

Placeholder for einsum implementation in subclasses.

Parameters:

dtype (dtype)