maxtext.kernels.gather_reduce_pallas module


SparseCore gather-reduce kernel implementation using Pallas.

This module contains a Pallas kernel that performs a gather-reduce operation on the TPU SparseCore. The kernel gathers rows of an operand at the provided indices, sums each group of gathered rows, and scatters the reduced results to the output.

maxtext.kernels.gather_reduce_pallas.sc_gather_reduce(op, idx, topk_weights=None, *, reduce_group_size, single_sc=False, col_chunk_size=3584, row_chunk_size=512, topk_wgt_zero_nan=False)

Performs a gather-reduce operation on SparseCore.

This kernel gathers the rows of the operand op selected by idx, sums each consecutive group of reduce_group_size gathered rows, and scatters the results. The gather and add operations are performed in fp32, and the results are written back in bf16.
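A minimal sketch of the numeric contract described above, written in plain JAX rather than as the SparseCore kernel: a block of bf16 rows is upcast to fp32, summed, and only the final sum is cast back to bf16. The helper name `sum_rows_fp32_to_bf16` is hypothetical and is not part of maxtext.

```python
import jax.numpy as jnp

def sum_rows_fp32_to_bf16(rows):
    """Sum a [G, K] block of rows in fp32 and return one [K] row in bf16.

    Hypothetical helper illustrating the documented numerics only.
    """
    acc = jnp.sum(rows.astype(jnp.float32), axis=0)  # accumulate in fp32
    return acc.astype(jnp.bfloat16)  # write the reduced row back in bf16

# 256 copies of 1 + 2**-7 (exactly representable in bf16); the fp32 sum is 258.
rows = jnp.full((256, 4), 1.0 + 2.0**-7, dtype=jnp.bfloat16)
reduced = sum_rows_fp32_to_bf16(rows)
print(reduced.dtype, reduced)
```

Only the final result is rounded to bf16, so intermediate partial sums keep fp32 precision.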

Equivalent JAX code:

import jax.numpy as jnp

gathered = op[idx, :]  # [M, K]
if topk_weights is not None:
  flat_weights = topk_weights.flatten()  # [M]
  gathered = gathered * flat_weights[:, None].astype(jnp.float32)
gathered = jnp.reshape(gathered, (-1, reduce_group_size, op.shape[1]))  # [M // reduce_group_size, reduce_group_size, K]
output = jnp.sum(gathered.astype(jnp.float32), axis=1).astype(jnp.bfloat16)  # [M // reduce_group_size, K]
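The equivalent code above can be wrapped as a self-contained reference function and checked on a toy input. This is a sketch for testing and intuition, not the kernel itself; the function name `gather_reduce_reference` and the small shapes are illustrative assumptions.

```python
import jax.numpy as jnp

def gather_reduce_reference(op, idx, reduce_group_size, topk_weights=None):
    """Hypothetical host-side reference that transcribes the equivalent JAX code."""
    gathered = op[idx, :]  # [M, K]
    if topk_weights is not None:
        flat_weights = topk_weights.flatten()  # [M]
        gathered = gathered * flat_weights[:, None].astype(jnp.float32)
    gathered = jnp.reshape(gathered, (-1, reduce_group_size, op.shape[1]))
    return jnp.sum(gathered.astype(jnp.float32), axis=1).astype(jnp.bfloat16)

op = jnp.arange(8, dtype=jnp.float32).reshape(4, 2)  # rows [0,1], [2,3], [4,5], [6,7]
idx = jnp.array([0, 3, 1, 1], dtype=jnp.int32)  # M = 4 -> two groups of two rows
out = gather_reduce_reference(op, idx, reduce_group_size=2)
print(out)  # group 0: [0,1] + [6,7]; group 1: [2,3] + [2,3]
```

Row i of the output is the sum of gathered rows i * reduce_group_size through (i + 1) * reduce_group_size - 1, so an index may appear in several groups or more than once in the same group.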

Parameters:
  • op (Array) – The operand matrix [B, K] in f32 or bf16 to gather from and reduce.

  • idx (Array) – The int32 row indices [M] into op that drive the gather; each consecutive run of reduce_group_size indices produces one output row.

  • topk_weights (Array | None) – Optional bf16 weights of shape [M // 128, 128]. They are flattened to [M], and each weight scales the corresponding gathered row before reduction.

  • reduce_group_size (int) – The number of gathered rows to sum per output row.

  • single_sc (bool) – Whether to use a single SparseCore.

  • col_chunk_size (int) – The number of columns of op processed per chunk.

  • row_chunk_size (int) – The number of rows per internal processing chunk. Must equal 2 * reduce_group_size.

  • topk_wgt_zero_nan (bool) – If True, a zero entry in topk_weights forces the corresponding weighted row to zero, so that row contributes zero to the output even if it contains NaN (plain multiplication would give 0 * NaN = NaN).
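One plausible reading of topk_wgt_zero_nan, sketched in plain JAX as an assumption (the kernel's exact handling is not shown in this reference): where a weight is zero, the product is replaced by zero with `jnp.where`, so a NaN in the gathered row cannot propagate into the sum. The helper `weight_rows` is hypothetical.

```python
import jax.numpy as jnp

def weight_rows(gathered, flat_weights, zero_nan):
    """Hypothetical helper: scale [M, K] rows by [M] weights in fp32.

    Assumed semantics: with zero_nan=True, zero-weight rows become exactly zero.
    """
    w = flat_weights[:, None].astype(jnp.float32)  # [M, 1], broadcasts over K
    prod = gathered.astype(jnp.float32) * w
    if zero_nan:
        prod = jnp.where(w == 0, 0.0, prod)  # override 0 * NaN = NaN with 0
    return prod

gathered = jnp.array([[jnp.nan, 1.0], [2.0, 3.0]], dtype=jnp.float32)
weights = jnp.array([0.0, 1.0], dtype=jnp.bfloat16)
print(weight_rows(gathered, weights, zero_nan=False))  # first row keeps a NaN
print(weight_rows(gathered, weights, zero_nan=True))   # first row is all zeros
```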

Returns:

The reduced result as a bf16 matrix [M // reduce_group_size, K].

Return type:

Array
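The shape arithmetic implied by the parameter and return descriptions can be checked ahead of a call. This is a sketch that encodes only the constraints stated in this reference (M divisible by reduce_group_size so the reshape succeeds, row_chunk_size equal to 2 * reduce_group_size, topk_weights of shape [M // 128, 128]); the function `expected_output_shape` is hypothetical and does not reproduce any validation the kernel itself performs.

```python
def expected_output_shape(op_shape, idx_shape, reduce_group_size,
                          row_chunk_size=512, topk_weights_shape=None):
    """Hypothetical helper: return [M // reduce_group_size, K] after checking
    the constraints documented for sc_gather_reduce."""
    _, k = op_shape  # op is [B, K]
    (m,) = idx_shape  # idx is [M]
    if m % reduce_group_size != 0:
        raise ValueError("M must be divisible by reduce_group_size")
    if row_chunk_size != 2 * reduce_group_size:
        raise ValueError("row_chunk_size must equal 2 * reduce_group_size")
    if topk_weights_shape is not None and tuple(topk_weights_shape) != (m // 128, 128):
        raise ValueError("topk_weights must have shape [M // 128, 128]")
    return (m // reduce_group_size, k)

# Example: 4096 gathered rows reduced in groups of 8.
print(expected_output_shape((1024, 7168), (4096,), reduce_group_size=8,
                            row_chunk_size=16, topk_weights_shape=(32, 128)))  # (512, 7168)
```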