maxtext.layers.quantizations module
Quantization library.
- class maxtext.layers.quantizations.Quantization
Bases: object
Base class for quantization configurations.
- class maxtext.layers.quantizations.AqtQuantization(quant_dg, quant_mode=QuantMode.TRAIN, replicate_scale=False)
Bases: object
Configures AQT quantization (github.com/google/aqt).
- Parameters:
quant_dg (DotGeneral)
quant_mode (QuantMode)
replicate_scale (bool)
- quant_dg: DotGeneral
- quant_mode: QuantMode = QuantMode.TRAIN
- replicate_scale: bool = False
- class maxtext.layers.quantizations.QwixQuantization(act_calibration_method='absmax', weight_calibration_method='absmax', bwd_calibration_method='absmax')
Bases: object
Configures Qwix quantization (github.com/google/qwix), for training only.
- Parameters:
act_calibration_method (str)
weight_calibration_method (str)
bwd_calibration_method (str)
- quant_mode = 'train'
- act_calibration_method: str = 'absmax'
- weight_calibration_method: str = 'absmax'
- bwd_calibration_method: str = 'absmax'
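All three calibration methods default to 'absmax'. The sketch below illustrates what absmax calibration means, assuming symmetric int8 quantization with a single per-tensor scale: the scale is chosen so that the largest absolute value maps to the largest integer code. It is an illustration of the technique, not Qwix's implementation (which may, for example, compute scales per channel); the function names are hypothetical.

```python
def absmax_quantize(values, num_bits=8):
    """Symmetric quantization whose scale is derived from max(|x|)."""
    qmax = 2 ** (num_bits - 1) - 1          # 127 for int8
    absmax = max(abs(v) for v in values)
    scale = absmax / qmax if absmax > 0 else 1.0
    # Round to the nearest integer code and clip to [-qmax, qmax].
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale


def dequantize(q, scale):
    """Map integer codes back to real values."""
    return [code * scale for code in q]


x = [0.5, -1.27, 0.01, 1.0]
q, scale = absmax_quantize(x)   # -1.27 has the largest magnitude and maps to -127
```

Because the range is set by the largest magnitude, no value is clipped and the round-trip error is at most half a quantization step; the trade-off is that a single outlier widens the step size for every other value.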
- class maxtext.layers.quantizations.QwixDotGeneral(config, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
A callable class for Qwix dot_general.
- Parameters:
config (DotGeneralQtConfig)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- config: DotGeneralQtConfig
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
- class maxtext.layers.quantizations.QwixEinsum(config, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
A callable class for Qwix einsum.
- Parameters:
config (DotGeneralQtConfig)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- config: DotGeneralQtConfig
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
- class maxtext.layers.quantizations.Fp8Quantization
Bases: Quantization
Configures FP8 quantization for NVIDIA GPUs.
- quant_mode = 'train'
- class maxtext.layers.quantizations.Fp8Einsum(amax_history_length=1024, e4m3_dtype=<class 'jax.numpy.float8_e4m3fn'>, e5m2_dtype=<class 'jax.numpy.float8_e5m2'>, dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
An FP8 einsum op.
- Parameters:
amax_history_length (int)
e4m3_dtype (dtype)
e5m2_dtype (dtype)
dtype (dtype)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- amax_history_length: int = 1024
Size of the amax history.
- e4m3_dtype
E4M3 variant, e.g., e4m3fn or e4m3fnuz. Alias of float8_e4m3fn.
- e5m2_dtype
E5M2 variant, e.g., e5m2 or e5m2fnuz. Alias of float8_e5m2.
- dtype
Computation dtype. Alias of float32.
- setup()
Initializes input_amax_history, kernel_amax_history, output_grad_amax_history, input_scale, kernel_scale, and output_grad_scale.
- Return type:
None
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
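The amax histories and scales initialized in setup() suggest delayed scaling: each tensor's scale is derived from the maximum absolute values (amax) observed over previous steps rather than from the current tensor. The sketch below shows that technique in plain Python under stated assumptions: the scale is max(history) / fp8_max, values are divided by the scale and saturated to the FP8 range, and rounding onto the FP8 grid is omitted. In common FP8 recipes E4M3 is typically used for inputs and kernels and E5M2 for output gradients, which need more range. The class name is hypothetical and this is not Fp8Einsum's code.

```python
from collections import deque

E4M3_MAX = 448.0     # largest finite float8_e4m3fn value
E5M2_MAX = 57344.0   # largest finite float8_e5m2 value


class DelayedScaleSketch:
    """Rolling amax history and the per-tensor scale derived from it."""

    def __init__(self, fp8_max, history_length=1024):
        self.fp8_max = fp8_max
        # deque(maxlen=...) drops the oldest entry once the history is full.
        self.history = deque(maxlen=history_length)

    def scale(self):
        # Delayed scaling: only amax values from earlier steps are used.
        amax = max(self.history, default=0.0)
        return amax / self.fp8_max if amax > 0 else 1.0

    def quantize(self, values):
        s = self.scale()
        # Divide by the scale and saturate to the FP8 range. Real kernels
        # would also cast (round) to the FP8 dtype at this point.
        scaled = [max(-self.fp8_max, min(self.fp8_max, v / s)) for v in values]
        # Record this step's amax for use by later steps.
        self.history.append(max(abs(v) for v in values))
        return scaled, s


tracker = DelayedScaleSketch(E4M3_MAX, history_length=2)
results = [tracker.quantize(step) for step in ([1.0, -2.0], [4.0, 0.5], [0.25, 0.1])]
```

In the second step the value 4.0 exceeds the amax recorded so far, so it saturates at 448; this lag is the price of not having to scan each tensor before casting it. A longer history (1024 steps by default here) makes the scale more stable but slower to shrink after a spike.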
- class maxtext.layers.quantizations.NANOOFp8Quantization
Bases: Quantization
Configures NANOO FP8 quantization for AMD MI300/MI325 GPUs.
- quant_mode = 'train'
- class maxtext.layers.quantizations.ConstantBoundConfig(fwd_lhs_bound: float | None = None, fwd_rhs_bound: float | None = None, dlhs_lhs_bound: float | None = None, dlhs_rhs_bound: float | None = None, drhs_lhs_bound: float | None = None, drhs_rhs_bound: float | None = None)
Bases: object
- Parameters:
fwd_lhs_bound (float | None)
fwd_rhs_bound (float | None)
dlhs_lhs_bound (float | None)
dlhs_rhs_bound (float | None)
drhs_lhs_bound (float | None)
drhs_rhs_bound (float | None)
- fwd_lhs_bound: float | None = None
- fwd_rhs_bound: float | None = None
- dlhs_lhs_bound: float | None = None
- dlhs_rhs_bound: float | None = None
- drhs_lhs_bound: float | None = None
- drhs_rhs_bound: float | None = None
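Each ConstantBoundConfig field is optional. A minimal sketch of how such a bound could be used, assuming None means the scale is calibrated from the data (absmax) and a float means a fixed clipping bound for that operand: the sketch contrasts the two on a tensor with an outlier. The helper names and this None/float interpretation are assumptions for illustration, not MaxText's code.

```python
def calibrate_scale(values, bound=None, qmax=127):
    """Symmetric int8 scale: data-dependent if bound is None, else fixed.

    Assumption: a ConstantBoundConfig field of None corresponds to dynamic
    absmax calibration, and a float corresponds to a constant bound.
    """
    if bound is None:
        bound = max(abs(v) for v in values)   # dynamic (absmax) calibration
    return bound / qmax


def quantize(values, scale, qmax=127):
    # Values beyond the bound saturate at +/-qmax.
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]


x = [0.2, -3.0, 1.0]
dynamic = quantize(x, calibrate_scale(x))            # outlier -3.0 sets the range
fixed = quantize(x, calibrate_scale(x, bound=1.0))   # -3.0 is clipped to -127
```

With the fixed bound the small values get finer resolution (0.2 maps to 25 instead of 8), at the cost of clipping the outlier; a constant bound also avoids computing a reduction over the tensor on every step.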
- class maxtext.layers.quantizations.PerTensorScales(fwd_lhs: bool = False, fwd_rhs: bool = False, dlhs_lhs: bool = False, dlhs_rhs: bool = False, drhs_lhs: bool = False, drhs_rhs: bool = False)
Bases: object
- Parameters:
fwd_lhs (bool)
fwd_rhs (bool)
dlhs_lhs (bool)
dlhs_rhs (bool)
drhs_lhs (bool)
drhs_rhs (bool)
- fwd_lhs: bool = False
- fwd_rhs: bool = False
- dlhs_lhs: bool = False
- dlhs_rhs: bool = False
- drhs_lhs: bool = False
- drhs_rhs: bool = False
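The six field names in ConstantBoundConfig and PerTensorScales follow one pattern: a dot product (fwd for the forward pass, dlhs and drhs for the two gradient computations) followed by one of its operands (lhs or rhs). The plain-Python sketch below lays out the three matmuls of a dense layer, assuming the AQT convention in which each backward dot takes the incoming gradient as its lhs; the comments map each operand to the field name it would correspond to under that assumption.

```python
def matmul(a, b):
    """Naive matrix product of two lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(m):
    return [list(row) for row in zip(*m)]


x = [[1.0, 2.0], [3.0, 4.0]]     # activations, shape (batch=2, in=2)
w = [[0.5], [-1.0]]              # kernel, shape (in=2, out=1)

# Forward dot: lhs = x (fwd_lhs), rhs = w (fwd_rhs).
y = matmul(x, w)

g = [[1.0], [1.0]]               # dL/dy, e.g. for L = sum(y)

# Gradient w.r.t. the forward lhs: lhs = g (dlhs_lhs), rhs = w (dlhs_rhs).
dx = matmul(g, transpose(w))

# Gradient w.r.t. the forward rhs: lhs = g (drhs_lhs), rhs = x (drhs_rhs).
# Computed as (g^T x)^T, which equals x^T g.
dw = transpose(matmul(transpose(g), x))
```

Because the gradient g appears in both backward dots, quantizing it (dlhs_lhs, drhs_lhs) affects both weight and activation gradients, which is why these operands are configured separately from the forward ones.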
- maxtext.layers.quantizations.get_quant_mode(quant_mode_str='train')
Returns the quantization mode corresponding to quant_mode_str.
- Parameters:
quant_mode_str (str)
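A minimal sketch of a string-to-mode lookup in the spirit of get_quant_mode, assuming the accepted strings are 'train', 'convert' and 'serve' and that unknown strings raise ValueError. The QuantMode enum below is a hypothetical stand-in for the one used by AqtQuantization (whose default is QuantMode.TRAIN); its member values are illustrative.

```python
import enum


class QuantMode(enum.Enum):
    # Hypothetical stand-in; member values are illustrative only.
    TRAIN = 1
    CONVERT = 2
    SERVE = 3


_MODES = {
    "train": QuantMode.TRAIN,
    "convert": QuantMode.CONVERT,
    "serve": QuantMode.SERVE,
}


def get_quant_mode_sketch(quant_mode_str="train"):
    """Map a config string to a QuantMode, rejecting unknown strings."""
    try:
        return _MODES[quant_mode_str]
    except KeyError:
        raise ValueError(f"Invalid quantization mode: {quant_mode_str!r}") from None
```

A dictionary lookup keeps the accepted strings in one place and makes the error path explicit instead of silently falling back to a default mode.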
- maxtext.layers.quantizations.configure_quantization(config, quant_mode_str='train')
Configure quantization based on user config and quant mode.
- Parameters:
config (Any)
quant_mode_str (str)
- maxtext.layers.quantizations.match_aqt_and_unquantized_param(aqt_params, params)
Matches AQT-quantized parameters with their corresponding unquantized parameters.
- maxtext.layers.quantizations.remove_quantized_params(params, aqt_vars)
Sets parameter values that have AQT-quantized counterparts to null to reduce memory usage.
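The idea behind remove_quantized_params can be sketched with nested dicts: once a quantized copy of a weight exists, the full-precision copy no longer needs to be kept. This sketch assumes params and aqt_vars share identical key paths and uses None as the placeholder; the real aqt_vars tree nests extra levels under each layer, and MaxText's path matching and placeholder value may differ. All names here are hypothetical.

```python
def paths(tree, prefix=()):
    """Yield the key path of every non-dict leaf in a nested dict."""
    for k, v in tree.items():
        if isinstance(v, dict):
            yield from paths(v, prefix + (k,))
        else:
            yield prefix + (k,)


def remove_quantized_params_sketch(params, quantized_paths):
    """Return a copy of params with leaves at quantized_paths set to None."""
    def walk(tree, prefix):
        out = {}
        for k, v in tree.items():
            path = prefix + (k,)
            if isinstance(v, dict):
                out[k] = walk(v, path)
            else:
                out[k] = None if path in quantized_paths else v
        return out
    return walk(params, ())


params = {"decoder": {"mlp": {"kernel": [1.0, 2.0], "bias": [0.1]}}}
# Hypothetical layout: only the MLP kernel has a quantized copy.
aqt_vars = {"decoder": {"mlp": {"kernel": "int8 values + scale"}}}
slim = remove_quantized_params_sketch(params, set(paths(aqt_vars)))
```

The bias has no quantized counterpart, so it is kept; the input tree is not modified, which matters if the caller still holds a reference to it.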
- class maxtext.layers.quantizations.NvidaFp8Provider(rules, *, disable_jit=False)
Bases: QtProvider
Wraps nn.Fp8DirectDotGeneralOp with Qwix’s provider interface.
- Parameters:
rules (Sequence[QuantizationRule])
disable_jit (bool)
- class maxtext.layers.quantizations.NANOOFp8Provider(rules, *, disable_jit=False)
Bases: QtProvider
- Parameters:
rules (Sequence[QuantizationRule])
disable_jit (bool)
- maxtext.layers.quantizations.get_fp8_full_qwix_rule_w_sparsity(config)
- Parameters:
config (Any)
- maxtext.layers.quantizations.get_qt_provider(config)
Get quantization rules based on the config.
- maxtext.layers.quantizations.maybe_quantize_model(model, config)
Quantize the model if quantization is enabled.
- maxtext.layers.quantizations.manual_quantize(tensor, calibration_method, dtype=<class 'jax.numpy.float8_e4m3fn'>)
Manually quantizes a tensor based on a fixed calibration method.
- Parameters:
tensor – The tensor to quantize.
calibration_method – A string specifying the calibration method. Expected format is "fixed,{scale},{max_val}".
- Returns:
A qwix.QArray containing the quantized value and the scale.
- Raises:
ValueError – If calibration_method is None or has an unexpected format.
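The documented calibration string has a strict shape, "fixed,{scale},{max_val}", and malformed input raises ValueError. The sketch below parses and validates such a string under those documented rules; it is a hypothetical helper, not MaxText's parser, and it deliberately does not guess how manual_quantize applies the two numbers.

```python
def parse_fixed_calibration(calibration_method):
    """Parse 'fixed,{scale},{max_val}' into (scale, max_val) floats.

    Hypothetical helper mirroring the format and ValueError behaviour
    documented for manual_quantize.
    """
    if calibration_method is None:
        raise ValueError("calibration_method must not be None")
    parts = calibration_method.split(",")
    if len(parts) != 3 or parts[0] != "fixed":
        raise ValueError(f"Unexpected calibration_method format: {calibration_method!r}")
    try:
        scale, max_val = float(parts[1]), float(parts[2])
    except ValueError:
        raise ValueError(f"Non-numeric value in calibration_method: {calibration_method!r}") from None
    return scale, max_val


scale, max_val = parse_fixed_calibration("fixed,0.5,448")
```

Validating the whole string up front means a typo in a config (for example "absmax" passed where a fixed calibration is required) fails immediately rather than producing a silently wrong scale.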
- class maxtext.layers.quantizations.TransformerEngineQuantization(config)
Bases: Quantization
Class for TransformerEngine quantization recipes.
- get_block_size()
Returns the block size used for quantization by recipes that require blocks.
If the current recipe has no block requirement, returns 1.
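Returning 1 when there is no block requirement makes sense if, as assumed here, the block size is used to check or pad tensor dimensions to a whole number of blocks. Block-scaled formats such as the OCP Microscaling (MX) MXFP8 format share one scale per 32 elements, while a block size of 1 divides every length. The helper names below are hypothetical, and the padding use is an assumption about how callers consume get_block_size().

```python
def padded_length(n, block_size):
    """Round n up to the next multiple of block_size (ceiling division)."""
    return -(-n // block_size) * block_size


def is_block_aligned(n, block_size):
    """True if n splits into whole blocks of block_size elements."""
    return n % block_size == 0


# A 32-element block (MXFP8-style) forces padding of a 100-element dimension;
# a block size of 1 never does.
mx_len = padded_length(100, 32)
plain_len = padded_length(100, 1)
```

Using 1 as the neutral value lets callers apply the same padding or alignment logic to every recipe without special-casing those that do not use blocks.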