maxtext.kernels.megablox.ops module

Grouped matrix multiplication operations with custom VJPs.
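A custom VJP for a grouped matmul has to return the gradient with respect to both operands. The NumPy sketch below spells out those two gradients under stated assumptions. It does not use the kernels in this module. With `lhs` of shape [m, k], `rhs` of shape [num_groups, k, n] and an output cotangent `grad` of shape [m, n], each group's `lhs` gradient is `grad_g @ rhs[g].T`, which is itself a grouped matmul with the rhs transposed. Each group's `rhs` gradient is `lhs_g.T @ grad_g`, which contracts over the rows of that group. The helper names `gmm_forward` and `gmm_vjp` are hypothetical.

```python
import numpy as np


def gmm_forward(lhs, rhs, group_sizes):
    """Hypothetical reference: lhs rows in group g are multiplied by rhs[g]."""
    out = np.zeros((lhs.shape[0], rhs.shape[2]))
    start = 0
    for g, size in enumerate(group_sizes):
        end = start + int(size)
        out[start:end] = lhs[start:end] @ rhs[g]
        start = end
    return out


def gmm_vjp(lhs, rhs, group_sizes, grad):
    """Return (d_lhs, d_rhs) for out = gmm_forward(lhs, rhs) and cotangent grad [m, n]."""
    d_lhs = np.zeros(lhs.shape)
    d_rhs = np.zeros(rhs.shape)
    start = 0
    for g, size in enumerate(group_sizes):
        end = start + int(size)
        # d_lhs: grouped matmul of the cotangent with the transposed weights.
        d_lhs[start:end] = grad[start:end] @ rhs[g].T
        # d_rhs: contract over the rows of group g, one [k, n] result per group.
        d_rhs[g] = lhs[start:end].T @ grad[start:end]
        start = end
    return d_lhs, d_rhs


rng = np.random.default_rng(0)
group_sizes = np.array([3, 1, 2])     # m = 6 rows in 3 groups
lhs = rng.standard_normal((6, 4))     # [m, k]
rhs = rng.standard_normal((3, 4, 5))  # [num_groups, k, n]
grad = rng.standard_normal((6, 5))    # cotangent of the [m, n] output
d_lhs, d_rhs = gmm_vjp(lhs, rhs, group_sizes, grad)


def loss(l, r):
    # Scalar whose gradient with respect to the output is exactly `grad`.
    return np.sum(gmm_forward(l, r, group_sizes) * grad)


# The output is linear in each operand, so a unit-step finite difference
# along a random direction equals the VJP's directional derivative,
# up to floating-point rounding.
dl, dr = rng.standard_normal(lhs.shape), rng.standard_normal(rhs.shape)
fd_lhs, vjp_lhs = loss(lhs + dl, rhs) - loss(lhs, rhs), np.sum(dl * d_lhs)
fd_rhs, vjp_rhs = loss(lhs, rhs + dr) - loss(lhs, rhs), np.sum(dr * d_rhs)
```

Checking against a finite difference rather than trusting the formulas is a cheap way to validate any hand-written VJP.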

maxtext.kernels.megablox.ops.gmm(lhs, rhs, group_sizes, preferred_element_type=jax.numpy.float32, tiling=(128, 128, 128, 128, 128, 128, 128, 128, 128), group_offset=None, existing_out=None, transpose_rhs=False, interpret=False, lhs_quantize_dtype=None, rhs_quantize_dtype=None, use_qwix_quantization=False, use_tokamax_backend=False, weight_gather_axes=None, qwix_rule=None, use_manual_quantization=False)

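Grouped matrix multiplication operation. For lhs of shape [m, k] and rhs of shape [num_groups, k, n] (with transpose_rhs=False), computes lhs[start_i:end_i, :] @ rhs[i] for each group i, where the consecutive row ranges [start_i, end_i) have lengths group_sizes[i]. The result has shape [m, n].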
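The grouped semantics can be pinned down with a minimal NumPy sketch of what the kernel is assumed to compute. `gmm_reference` is a hypothetical name and not part of MaxText. It loops over groups instead of tiling, handles `transpose_rhs`, and ignores `group_offset`, quantization and the other options. This sketch also leaves any rows past `sum(group_sizes)` as zeros, which is an assumption of the sketch.

```python
import numpy as np


def gmm_reference(lhs, rhs, group_sizes, transpose_rhs=False):
    """Hypothetical NumPy model of a grouped matmul (not the MaxText kernel).

    lhs:         [m, k]; rows are split into consecutive groups.
    rhs:         [num_groups, k, n], or [num_groups, n, k] if transpose_rhs.
    group_sizes: [num_groups] non-negative ints with sum <= m.
    Returns an [m, n] array; rows past sum(group_sizes) are left as zeros.
    """
    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
    out = np.zeros((lhs.shape[0], n), dtype=np.result_type(lhs, rhs))
    start = 0
    for g, size in enumerate(group_sizes):
        end = start + int(size)
        w = rhs[g].T if transpose_rhs else rhs[g]  # [k, n] weight of group g
        # Every row in group g is multiplied by the same weight matrix.
        out[start:end] = lhs[start:end] @ w
        start = end
    return out


lhs = np.arange(12, dtype=np.float32).reshape(4, 3)  # m=4, k=3
rhs = np.stack([np.eye(3, 2), 2 * np.eye(3, 2)])     # 2 groups, k=3, n=2
out = gmm_reference(lhs, rhs, np.array([1, 3]))      # row 0 -> group 0, rows 1..3 -> group 1
# out == [[0, 1], [6, 8], [12, 14], [18, 20]]
```

When sum(group_sizes) == m, an equivalent vectorized form gathers one weight matrix per row, `np.einsum('mk,mkn->mn', lhs, rhs[np.repeat(np.arange(num_groups), group_sizes)])`. It is simple, but it materializes an [m, k, n] tensor.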

Parameters:
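  • lhs (Array) – Input of shape [m, k] whose rows are split into consecutive groups according to group_sizes.

  • rhs (Array) – Per-group weights of shape [num_groups, k, n], or [num_groups, n, k] when transpose_rhs is True. rhs[i] is applied to the lhs rows of group i (shifted by group_offset when it is given).

  • group_sizes (Array) – 1-D integer array giving the number of lhs rows in each group.

  • preferred_element_type (dtype) – Element type of the output.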

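  • tiling (tuple[int, int, int, int, int, int, int, int, int]) – Tile sizes for the kernels. The nine entries are presumably three (m, k, n) triples: one for the forward kernel and one for each of the two backward kernels.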

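  • group_offset (Array | None) – Index of the first group handled by this call. This is useful when the group dimension of rhs is sharded.

  • existing_out (Array | None) – Existing output array to write the results into.

  • transpose_rhs (bool) – If True, each rhs[i] is transposed before it is multiplied.

  • interpret (bool) – Run the kernel in Pallas interpret mode, which is useful for testing and debugging.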

  • lhs_quantize_dtype (Literal[jax.numpy.int4, jax.numpy.int8] | None)

  • rhs_quantize_dtype (Literal[jax.numpy.int4, jax.numpy.int8] | None)

  • use_qwix_quantization (bool)

  • use_tokamax_backend (bool)

  • weight_gather_axes (List[Tuple[str, int]] | None)

  • qwix_rule (QtRule | None)

  • use_manual_quantization (bool)
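group_offset matters when the group dimension of rhs is sharded. In that case each device holds only a contiguous slice of the per-group weights, while group_sizes still describes every group. The NumPy sketch below models that case under stated assumptions. `gmm_local` is a hypothetical helper, and `group_offset` is a plain int here rather than the Array the signature declares. Rows belonging to groups outside the local slice are assumed to come out as zeros. Summing the per-shard outputs only checks the model against the unsharded result. It is not a claim about how MaxText combines shard outputs.

```python
import numpy as np


def gmm_local(lhs, rhs_local, group_sizes, group_offset):
    """Hypothetical model of a grouped matmul on one shard of rhs.

    rhs_local holds groups [group_offset, group_offset + rhs_local.shape[0])
    of the full [num_groups, k, n] weight; group_sizes covers all groups.
    Rows of groups outside that range are left as zeros (an assumption).
    """
    out = np.zeros((lhs.shape[0], rhs_local.shape[2]))
    # Row boundaries of every group, computed from the global group_sizes.
    starts = np.concatenate([[0], np.cumsum(group_sizes)])
    for local_g in range(rhs_local.shape[0]):
        g = group_offset + local_g  # global index of this local weight
        start, end = starts[g], starts[g + 1]
        out[start:end] = lhs[start:end] @ rhs_local[local_g]
    return out


rng = np.random.default_rng(0)
group_sizes = np.array([2, 0, 3, 1])  # 4 groups (one empty), m = 6
lhs = rng.standard_normal((6, 5))     # [m, k]
rhs = rng.standard_normal((4, 5, 3))  # [num_groups, k, n]

# Unsharded: every group is local and the offset is 0.
full = gmm_local(lhs, rhs, group_sizes, group_offset=0)

# Two shards, each owning two consecutive groups of rhs.
shard_outs = [
    gmm_local(lhs, rhs[s * 2:(s + 1) * 2], group_sizes, group_offset=s * 2)
    for s in range(2)
]
combined = shard_outs[0] + shard_outs[1]
assert np.allclose(combined, full)
```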