maxtext.kernels.megablox.common module

Common utilities for grouped matrix multiplication (GMM) kernels.

maxtext.kernels.megablox.common.is_tpu()
Return type:

bool

maxtext.kernels.megablox.common.tpu_kind()

Query identification string for the currently attached TPU.

Return type:

str

maxtext.kernels.megablox.common.tpu_generation()

Generation number of the currently attached TPU.

Return type:

int
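
To illustrate how a generation number can relate to a kind string, here is a minimal sketch. It assumes kind strings of the form "TPU v4" or "TPU v5 lite"; the pattern and the helper name generation_from_kind are hypothetical and are not taken from this module.

```python
import re

# Assumption: kind strings start with "TPU v" followed by the generation
# number. This pattern is illustrative, not the module's own.
_KIND_PATTERN = re.compile(r"TPU v(\d+)")


def generation_from_kind(kind: str) -> int:
    """Hypothetical helper: extract the generation number from a kind string."""
    match = _KIND_PATTERN.match(kind)
    if match is None:
        raise ValueError(f"not a recognised TPU kind string: {kind!r}")
    return int(match.group(1))


print(generation_from_kind("TPU v4"))  # prints 4
```

Returning an int rather than the raw string lets callers compare generations numerically, for example to gate a feature on a minimum generation.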

maxtext.kernels.megablox.common.supports_bfloat16_matmul()

Does the currently attached TPU support bfloat16 inputs?

Return type:

bool

maxtext.kernels.megablox.common.assert_is_supported_dtype(dtype)
Parameters:

dtype (dtype)

Return type:

None

maxtext.kernels.megablox.common.select_input_dtype(lhs, rhs)

The dtype to which both inputs should be converted before the dot product.

Parameters:
  • lhs (Array)

  • rhs (Array)

Return type:

dtype
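
The selection rule itself is not spelled out here. As a rough sketch of one plausible policy, the example below keeps bfloat16 only when both operands already use it and the hardware can multiply bfloat16 inputs, and otherwise widens to float32. The function pick_input_dtype, its bf16_matmul_ok flag, and the use of plain dtype-name strings are all assumptions made for illustration.

```python
def pick_input_dtype(lhs_dtype: str, rhs_dtype: str, bf16_matmul_ok: bool) -> str:
    """Hypothetical stand-in: choose a common dtype name for a dot product."""
    # Stay in bfloat16 only if both operands already use it and the
    # hardware can multiply bfloat16 inputs directly.
    if bf16_matmul_ok and lhs_dtype == rhs_dtype == "bfloat16":
        return "bfloat16"
    # Assumed fallback: widen every other combination to float32.
    return "float32"


print(pick_input_dtype("bfloat16", "bfloat16", bf16_matmul_ok=True))
print(pick_input_dtype("bfloat16", "float32", bf16_matmul_ok=True))
```

Under this assumed policy, mixed-precision inputs are always promoted, so the dot product never sees operands of two different dtypes.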