maxtext.kernels.gather_reduce_sc module
SparseCore gather-reduce kernel implementation.
This module contains a kernel implementation for performing a gather-reduce operation on TPU SparseCore. It groups rows of an operand based on provided indices, sums them up, and scatters the results.
- class maxtext.kernels.gather_reduce_sc.VectorTypeHelper(element_type_fn)
Bases: object
Helper to create a VectorType with a specific element type.
- maxtext.kernels.gather_reduce_sc.sc_gather_reduce(op, idx, topk_weights=None, *, reduce_group_size, single_sc=False, col_chunk_size=3584, row_chunk_size=16, loop_unroll_factor_1=2, loop_unroll_factor_2=2, loop_unroll_factor_3=8, loop_parallel_access_1=True, loop_parallel_access_2=False, loop_parallel_access_3=False, topk_wgt_zero_nan=False)
Performs a gather-reduce operation on SparseCore.
This kernel groups rows of the operand op based on idx, sums them up, and scatters the results. The gather and add operations are performed in fp32, and the results are written back in bf16.
Equivalent JAX NumPy code:
```
# acc_dtype is the accumulation dtype; the kernel accumulates in fp32.
gathered = op[idx, :]
if topk_weights is not None:
  flat_weights = topk_weights.flatten()
  gathered = gathered * flat_weights[:, None].astype(acc_dtype)
gathered = jnp.reshape(gathered, (-1, reduce_group_size, op.shape[1]))
output = jnp.sum(gathered.astype(acc_dtype), axis=1).astype(jnp.bfloat16)
```
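To make the grouping concrete, here is a minimal runnable sketch of the same semantics in plain JAX, with a small worked input. The function name `gather_reduce_reference` is hypothetical and not part of this module, and setting the accumulation dtype to `jnp.float32` is an assumption based on the fp32 note in the description. It is a reference for checking results, not the SparseCore kernel itself.

```python
import jax.numpy as jnp


def gather_reduce_reference(op, idx, topk_weights=None, *, reduce_group_size):
  """Hypothetical pure-JAX reference; mirrors the equivalent code, not the kernel."""
  acc_dtype = jnp.float32  # assumption: accumulate in fp32, as the description states
  gathered = op[idx, :].astype(acc_dtype)  # shape [M, K]
  if topk_weights is not None:
    # One weight per gathered row, broadcast across the K columns.
    gathered = gathered * topk_weights.flatten()[:, None].astype(acc_dtype)
  # Each consecutive run of reduce_group_size gathered rows forms one output row.
  grouped = jnp.reshape(gathered, (-1, reduce_group_size, op.shape[1]))
  return jnp.sum(grouped, axis=1).astype(jnp.bfloat16)


# B=3 operand rows, K=2 columns.
op = jnp.array([[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]], dtype=jnp.float32)
# M=4 indices; with reduce_group_size=2 this yields 2 output rows.
idx = jnp.array([0, 2, 1, 1], dtype=jnp.int32)

out = gather_reduce_reference(op, idx, reduce_group_size=2)
# Row 0 is op[0] + op[2]; row 1 is op[1] + op[1].
print(out.astype(jnp.float32))  # [[101. 202.] [ 20.  40.]]
```

Passing `topk_weights` with one entry per index (for example, shape [2, 2] here) scales each gathered row before the sum, so a weight of 0 removes that row's contribution for finite inputs.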
- Parameters:
op (Array) – The fp32 operand matrix of shape [B, K] to reduce.
idx (Array) – The int32 indices of shape [M] guiding the reduction and scatter.
topk_weights (Array | None) – Optional weights to apply to the gathered operands.
reduce_group_size (int) – The size of the groups to reduce.
single_sc (bool) – Whether to use a single SparseCore.
col_chunk_size (int) – The size of column chunks to process.
row_chunk_size (int) – The size of row chunks for internal processing.
loop_unroll_factor_1 (int) – Unroll factor for the main loop over column chunks.
loop_unroll_factor_2 (int) – Unroll factor for the loop over row chunks in offset calculation.
loop_unroll_factor_3 (int) – Unroll factor for the inner loop within offset calculation.
loop_parallel_access_1 (bool) – Enables parallel access for the main column chunk loop.
loop_parallel_access_2 (bool) – Enables parallel access for the row chunk loop in offset calculation.
loop_parallel_access_3 (bool) – Enables parallel access for the inner loop within offset calculation.
topk_wgt_zero_nan (bool) – If true, a zero entry in topk_weights is treated as an indicator of NaN during multiplication, and the resulting output is zero.
- Returns:
The result of the operation, a bf16 matrix of shape [M/reduce_group_size, K].
- Return type:
Array
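The shape contract described above can be checked before calling the kernel. The helper below is a hypothetical sketch, not part of this module. It assumes that M must be an exact multiple of `reduce_group_size`, which is implied by the reshape in the equivalent code but is not stated as a kernel requirement here.

```python
def gather_reduce_output_shape(op_shape, idx_shape, reduce_group_size):
  """Hypothetical helper: returns the expected (M / reduce_group_size, K) shape."""
  if len(op_shape) != 2:
    raise ValueError(f"op must be 2-D [B, K], got shape {op_shape}")
  if len(idx_shape) != 1:
    raise ValueError(f"idx must be 1-D [M], got shape {idx_shape}")
  m = idx_shape[0]
  k = op_shape[1]
  # Assumption: the groups must tile idx exactly, as the reshape implies.
  if reduce_group_size <= 0 or m % reduce_group_size != 0:
    raise ValueError(
        f"M={m} is not a positive multiple of reduce_group_size={reduce_group_size}"
    )
  return (m // reduce_group_size, k)


# Example: 4096 indices reduced in groups of 8 over a [1024, 7168] operand.
shape = gather_reduce_output_shape((1024, 7168), (4096,), 8)
print(shape)  # (512, 7168)
```

Note that B, the number of operand rows, does not appear in the output shape; it only bounds the valid values in `idx`.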