maxtext.kernels.megablox.backend module
Grouped matrix multiplication kernels for TPU written in Pallas.
- maxtext.kernels.megablox.backend.make_group_metadata(*, group_sizes, m, tm, start_group, num_nonzero_groups, visit_empty_groups=True)
Create the metadata needed for grouped matmul computation.
- Parameters:
group_sizes (Array) – A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
m (int) – The number of rows in lhs.
tm (int) – The m-dimension tile size being used.
start_group (Array) – The index in group_sizes of the first group to compute. This is useful when the num_groups dimension of rhs is sharded.
num_nonzero_groups (int) – Number of groups in group_sizes to compute, counting from start_group. Useful in combination with start_group.
visit_empty_groups (bool) – If True, do not squeeze tiles for empty groups out of the metadata. This is necessary for tgmm, where we at least need to zero the output for each group.
- Returns:
- group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32 dtype. group_offsets[i] indicates the row at which group [i] starts in the lhs matrix, and group_offsets[-1] = m.
- group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and jnp.int32 dtype. group_ids[i] indicates which group grid index ‘i’ will work on.
- m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and jnp.int32 dtype. m_tile_ids[i] indicates which m-dimension tile grid index ‘i’ will work on.
- num_tiles: The number of m-dimension tiles to execute.
- Return type:
tuple of the values listed under Returns
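To make the metadata concrete, here is a simplified plain-Python sketch of the schedule it describes; it is not the kernel's implementation. The helper name `reference_group_metadata` is hypothetical, and the sketch assumes start_group = 0, computes every group, skips empty groups (as with visit_empty_groups=False), and does not pad the outputs to the static [m_tiles + num_groups] length. It shows the key idea: an m-tile that straddles a group boundary is visited once for each group that owns rows in it, which is why the metadata arrays are sized for more than m_tiles entries.

```python
import math


def reference_group_metadata(group_sizes, tm):
    """Hypothetical sketch of the (group, m-tile) visit order for grouped matmul.

    Assumes start_group == 0, all groups are computed, empty groups are skipped,
    and no padding to the static [m_tiles + num_groups] length.
    """
    # group_offsets[i] is the first row of group i; the last entry is the total row count.
    group_offsets = [0]
    for size in group_sizes:
        group_offsets.append(group_offsets[-1] + size)

    group_ids, m_tile_ids = [], []
    for g, size in enumerate(group_sizes):
        if size == 0:
            continue  # an empty group owns no rows, so no tile is scheduled for it here
        start, end = group_offsets[g], group_offsets[g + 1]
        # Every m-tile overlapping rows [start, end) is visited once for group g.
        for tile in range(start // tm, math.ceil(end / tm)):
            group_ids.append(g)
            m_tile_ids.append(tile)

    num_tiles = len(group_ids)
    return group_offsets, group_ids, m_tile_ids, num_tiles


offsets, gids, tiles, n = reference_group_metadata([3, 0, 5], tm=4)
print(offsets)  # [0, 3, 3, 8]
print(gids)     # [0, 2, 2]
print(tiles)    # [0, 0, 1]
print(n)        # 3
```

Tile 0 appears twice because rows 0 to 2 belong to group 0 and row 3 belongs to group 2, so any kernel following this schedule has to restrict each visit to the rows of the current group.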
- maxtext.kernels.megablox.backend.gmm(lhs, rhs, group_sizes, preferred_element_type=jnp.float32, tiling=(128, 128, 128), group_offset=None, existing_out=None, transpose_rhs=False, interpret=False)
Compute lhs[offsets[i]:offsets[i+1], :] @ rhs[i] for each group ‘i’, where offsets[i] is the sum of group_sizes[:i].
- Parameters:
lhs (Array | QArray) – A 2d, jnp.ndarray with shape [m, k].
rhs (Array | QArray) – A 3d, jnp.ndarray with shape [num_groups, k, n].
group_sizes (Array) – A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
preferred_element_type (dtype) – jnp.dtype, the element type for the output matrix.
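tiling (tuple[int, int, int] | Callable[[int, int, int], tuple[int, int, int] | None] | None) – Either a 3-tuple of ints giving the m, k, and n-dimension tile sizes, or a callable that takes the problem dimensions (m, k, n) and returns such a 3-tuple.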
group_offset (Array | None) – The group in group_sizes to start computing from. This is useful when the num_groups dimension of rhs is sharded.
existing_out (Array | None) – Existing output to write to.
transpose_rhs (bool) – If True, rhs is stored transposed, with shape [num_groups, n, k], and each group's matrix is transposed before the multiplication.
interpret (bool) – Whether or not to run the kernel in interpret mode, helpful for testing and debugging.
- Returns:
A 2d, jnp.ndarray with shape [m, n].
- Return type:
Array
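The per-group semantics of gmm can be checked against a dense reference. The sketch below uses a hypothetical helper, `gmm_reference`, written with plain jax.numpy and a Python loop over groups; it is not the Pallas kernel and does no tiling. It assumes group_sizes is concrete (not traced), because it drives Python slicing, and in this sketch rows past sum(group_sizes) are simply left as zeros.

```python
import itertools

import jax.numpy as jnp
import numpy as np


def gmm_reference(lhs, rhs, group_sizes, preferred_element_type=jnp.float32,
                  transpose_rhs=False):
    """Hypothetical dense reference for gmm (not the Pallas kernel).

    lhs: [m, k]; rhs: [num_groups, k, n], or [num_groups, n, k] if transpose_rhs.
    group_sizes must be concrete (not traced), since it drives Python slicing.
    """
    if transpose_rhs:
        rhs = jnp.swapaxes(rhs, 1, 2)  # [num_groups, n, k] -> [num_groups, k, n]
    sizes = np.asarray(group_sizes).tolist()
    ends = list(itertools.accumulate(sizes))
    starts = [end - size for end, size in zip(ends, sizes)]
    out = jnp.zeros((lhs.shape[0], rhs.shape[2]), dtype=preferred_element_type)
    for i, (s, e) in enumerate(zip(starts, ends)):
        # Rows [s, e) of lhs belong to group i and are multiplied by rhs[i].
        block = jnp.dot(lhs[s:e], rhs[i], preferred_element_type=preferred_element_type)
        out = out.at[s:e].set(block)
    return out  # rows past sum(group_sizes) stay zero in this sketch


lhs = jnp.arange(8.0).reshape(4, 2)               # m=4, k=2
rhs = jnp.stack([jnp.eye(2), 2.0 * jnp.eye(2)])   # num_groups=2, k=n=2
group_sizes = jnp.array([1, 3], dtype=jnp.int32)
out = gmm_reference(lhs, rhs, group_sizes)
# Row 0 is multiplied by rhs[0] (identity), rows 1-3 by rhs[1] (doubling).
```

A helper like this can serve as an oracle when exercising backend.gmm with interpret=True on small inputs.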
- maxtext.kernels.megablox.backend.tgmm(lhs, rhs, group_sizes, preferred_element_type=jnp.float32, tiling=(128, 128, 128), group_offset=None, num_actual_groups=None, existing_out=None, interpret=False)
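For each group ‘i’, compute lhs[:, offsets[i]:offsets[i+1]] @ rhs[offsets[i]:offsets[i+1], :] and store it as output slice [i], where offsets[i] is the sum of group_sizes[:i].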
- Parameters:
lhs (Array | QArray) – A 2d, jnp.ndarray with shape [k, m].
rhs (Array | QArray) – A 2d, jnp.ndarray with shape [m, n].
group_sizes (Array) – A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
preferred_element_type (dtype) – jnp.dtype, the element type for the output matrix.
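tiling (tuple[int, int, int] | Callable[[int, int, int], tuple[int, int, int] | None] | None) – Either a 3-tuple of ints giving the m, k, and n-dimension tile sizes, or a callable that takes the problem dimensions (m, k, n) and returns such a 3-tuple.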
group_offset (Array | None) – The group in group_sizes to start computing from. This is useful when the num_groups dimension is sharded.
num_actual_groups (int | None) – The number of groups to compute, starting from group_offset. Use this when num_groups is sharded and only the local groups should be computed.
existing_out (Array | None) – Existing output to write to.
interpret (bool) – Whether or not to run the kernel in interpret mode, helpful for testing and debugging.
- Returns:
A 3d, jnp.ndarray with shape [num_groups, k, n].
- Return type:
Array
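A matching dense sketch for tgmm follows, again with a hypothetical helper name (`tgmm_reference`) and the assumption that group_sizes is concrete. It is not the Pallas kernel. An empty group yields an all-zero [k, n] slice, which is the per-group zeroing that the visit_empty_groups note for make_group_metadata refers to.

```python
import itertools

import jax.numpy as jnp
import numpy as np


def tgmm_reference(lhs, rhs, group_sizes, preferred_element_type=jnp.float32):
    """Hypothetical dense reference for tgmm (not the Pallas kernel).

    lhs: [k, m], rhs: [m, n], group_sizes: concrete ints of shape [num_groups].
    Returns an array of shape [num_groups, k, n].
    """
    sizes = np.asarray(group_sizes).tolist()
    ends = list(itertools.accumulate(sizes))
    starts = [end - size for end, size in zip(ends, sizes)]
    blocks = [
        # Contract over the m rows owned by this group; an empty group gives zeros.
        jnp.dot(lhs[:, s:e], rhs[s:e, :], preferred_element_type=preferred_element_type)
        for s, e in zip(starts, ends)
    ]
    return jnp.stack(blocks)


lhs = jnp.ones((2, 3))                        # k=2, m=3
rhs = jnp.arange(6.0).reshape(3, 2)           # m=3, n=2
group_sizes = jnp.array([2, 0, 1], dtype=jnp.int32)
out = tgmm_reference(lhs, rhs, group_sizes)   # shape (3, 2, 2)
```

One reason this contraction is useful: if each group's rows are computed as lhs_rows @ rhs[i], as in gmm, then the gradient with respect to rhs[i] is lhs_rows.T @ dout_rows, which is exactly a tgmm of lhs transposed to [k, m] with the upstream gradient dout.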