maxtext.kernels.megablox.ops module
Grouped matrix multiplication operations with custom VJPs.
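A custom VJP for a grouped matmul can express both gradients as grouped contractions: the lhs gradient multiplies each row's cotangent by its group's rhs transposed, and the rhs gradient accumulates lhs_g^T @ dout_g over the rows of each group. The sketch below illustrates that structure with a hypothetical pure-JAX `grouped_matmul`, not the megablox Pallas kernel, and it does not claim to mirror how this module implements its VJPs. It assumes lhs is [m, k], rhs is [num_groups, k, n], and group_sizes is an int32 vector whose entries sum to m.

```python
import jax
import jax.numpy as jnp


def _group_ids(group_sizes, m):
  # Map each lhs row to its group, e.g. sizes [2, 1] -> ids [0, 0, 1].
  # total_repeat_length keeps the output shape static under tracing.
  return jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes,
                    total_repeat_length=m)


@jax.custom_vjp
def grouped_matmul(lhs, rhs, group_sizes):
  """Hypothetical reference: out[i] = lhs[i] @ rhs[group_of(i)]."""
  ids = _group_ids(group_sizes, lhs.shape[0])
  return jnp.einsum("mk,mkn->mn", lhs, rhs[ids])


def _grouped_matmul_fwd(lhs, rhs, group_sizes):
  # Save the inputs as residuals for the backward pass.
  return grouped_matmul(lhs, rhs, group_sizes), (lhs, rhs, group_sizes)


def _grouped_matmul_bwd(residuals, dout):
  lhs, rhs, group_sizes = residuals
  ids = _group_ids(group_sizes, lhs.shape[0])
  # Gradient w.r.t. lhs: each row's cotangent times its group's rhs
  # transposed, i.e. another grouped matmul with the rhs transposed.
  dlhs = jnp.einsum("mn,mkn->mk", dout, rhs[ids])
  # Gradient w.r.t. rhs: per-row outer products lhs[i]^T dout[i],
  # summed over the rows belonging to each group.
  drhs = jax.ops.segment_sum(jnp.einsum("mk,mn->mkn", lhs, dout), ids,
                             num_segments=rhs.shape[0])
  # group_sizes is integer-valued; None signals a zero cotangent.
  return dlhs, drhs, None


grouped_matmul.defvjp(_grouped_matmul_fwd, _grouped_matmul_bwd)

# Usage: the custom gradients should match plain autodiff of the same math.
lhs = jnp.arange(6.0).reshape(3, 2)
rhs = jnp.arange(8.0).reshape(2, 2, 2)
group_sizes = jnp.array([2, 1], dtype=jnp.int32)


def loss(a, b):
  return jnp.sum(grouped_matmul(a, b, group_sizes) ** 2)


def dense_loss(a, b):
  ids = _group_ids(group_sizes, a.shape[0])
  return jnp.sum(jnp.einsum("mk,mkn->mn", a, b[ids]) ** 2)


dlhs, drhs = jax.grad(loss, argnums=(0, 1))(lhs, rhs)
ref_dlhs, ref_drhs = jax.grad(dense_loss, argnums=(0, 1))(lhs, rhs)
```

Materializing rhs[ids] as an [m, k, n] array keeps the sketch short and easy to check, but it is only sensible for small test shapes; a tiled kernel avoids that gather.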
- maxtext.kernels.megablox.ops.gmm(lhs, rhs, group_sizes, preferred_element_type=jax.numpy.float32, tiling=(128, 128, 128, 128, 128, 128, 128, 128, 128), group_offset=None, existing_out=None, transpose_rhs=False, interpret=False, lhs_quantize_dtype=None, rhs_quantize_dtype=None, use_qwix_quantization=False, use_tokamax_backend=False, weight_gather_axes=None, qwix_rule=None, use_manual_quantization=False)
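Grouped matrix multiplication: the rows of lhs are split into consecutive groups whose sizes are given by group_sizes, and each group is multiplied by its own matrix from rhs.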
- Parameters:
lhs (Array)
rhs (Array)
group_sizes (Array)
preferred_element_type (dtype)
tiling (tuple[int, int, int, int, int, int, int, int, int])
group_offset (Array | None)
existing_out (Array | None)
transpose_rhs (bool)
interpret (bool)
lhs_quantize_dtype (Literal[jax.numpy.int4, jax.numpy.int8] | None)
rhs_quantize_dtype (Literal[jax.numpy.int4, jax.numpy.int8] | None)
use_qwix_quantization (bool)
use_tokamax_backend (bool)
weight_gather_axes (List[Tuple[str, int]] | None)
qwix_rule (QtRule | None)
use_manual_quantization (bool)
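The core semantics of lhs, rhs, group_sizes, transpose_rhs and preferred_element_type can be illustrated with a dense reference. This is a minimal sketch under assumed shapes (lhs: [m, k]; rhs: [num_groups, k, n], or [num_groups, n, k] when transpose_rhs=True; group_sizes: int32 [num_groups] summing to m). `gmm_reference` is a hypothetical helper, not part of this module, and it ignores tiling, group_offset, existing_out and all quantization options.

```python
import jax.numpy as jnp


def gmm_reference(lhs, rhs, group_sizes, preferred_element_type=jnp.float32,
                  transpose_rhs=False):
  """Hypothetical dense reference for a grouped matmul (not the Pallas kernel).

  Assumes lhs: [m, k]; rhs: [num_groups, k, n], or [num_groups, n, k] when
  transpose_rhs=True; group_sizes: int32 [num_groups] summing to m.
  """
  if transpose_rhs:
    rhs = jnp.swapaxes(rhs, 1, 2)  # -> [num_groups, k, n]
  m = lhs.shape[0]
  # Row -> group id, e.g. group_sizes [2, 1] -> [0, 0, 1].
  group_ids = jnp.repeat(jnp.arange(rhs.shape[0]), group_sizes,
                         total_repeat_length=m)
  # Gather each row's weight matrix, then contract over k row by row.
  # This materializes [m, k, n], so it is only suitable for small tests.
  return jnp.einsum("mk,mkn->mn", lhs, rhs[group_ids],
                    preferred_element_type=preferred_element_type)


lhs = jnp.arange(6.0).reshape(3, 2)               # m=3, k=2
rhs = jnp.stack([jnp.eye(2), 2.0 * jnp.eye(2)])   # num_groups=2, k=n=2
group_sizes = jnp.array([2, 1], dtype=jnp.int32)  # rows 0-1 -> group 0, row 2 -> group 1
out = gmm_reference(lhs, rhs, group_sizes)
# out == [[0., 1.], [2., 3.], [8., 10.]]
```

Rows 0 and 1 pass through the identity of group 0 unchanged, while row 2 is doubled by group 1's matrix, which makes it easy to see which weights each row was routed to.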