Source code for maxtext.kernels.gather_reduce_sc

# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SparseCore gather-reduce kernel implementation.

This module contains a kernel implementation for performing a gather-reduce
operation on TPU SparseCore. It groups rows of an operand based on provided
indices, sums them up, and scatters the results.
"""

import array
import functools
from typing import Any

import jax
from jax import core
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
import jax.experimental.pallas.tpu as pltpu
from jax.interpreters import mlir
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import func
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import scf
from jaxlib.mlir.dialects import vector


class VectorTypeHelper:
  """Subscript-style factory for `ir.VectorType` with a fixed element type.

  Indexing an instance with a shape returns the vector type of that shape
  whose element type is produced by the stored factory, e.g. `helper[16]` or
  `helper[2, 16]`.
  """

  def __init__(self, element_type_fn):
    """Stores the element type factory.

    Args:
      element_type_fn: Zero-argument callable returning the element type. It
        is called on every subscript and never at construction time.
    """
    self.element_type_fn = element_type_fn

  def __getitem__(self, shape):
    """Builds the vector type for `shape`.

    Args:
      shape: Either a single int, which is treated as a rank-1 shape, or any
        other object (such as the tuple produced by `helper[2, 16]`), which
        is forwarded unchanged as the shape.

    Returns:
      Whatever `ir.VectorType.get` returns for the shape and a freshly created
      element type.
    """
    # A bare int subscript means one dimension; wrap it so the callee always
    # receives a sequence-like shape.
    dims = [shape] if isinstance(shape, int) else shape
    return ir.VectorType.get(dims, self.element_type_fn())
# Shorthand vector type factories: subscripting with a shape, e.g. `_F32[16]`
# or `_BF16[2, 16]`, returns the `ir.VectorType` of that shape. Each helper
# holds a callable that creates the element type and only calls it when
# subscripted.
# Vectors of 32-bit signless integers.
_I32 = VectorTypeHelper(functools.partial(ir.IntegerType.get_signless, 32))
# Vectors of f32 elements.
_F32 = VectorTypeHelper(ir.F32Type.get)
# Vectors of bf16 elements.
_BF16 = VectorTypeHelper(ir.BF16Type.get)
[docs] @jax.jit( static_argnames=[ "reduce_group_size", "single_sc", "col_chunk_size", "loop_unroll_factor_1", "loop_unroll_factor_2", "loop_unroll_factor_3", "loop_parallel_access_1", "loop_parallel_access_2", "loop_parallel_access_3", "topk_wgt_zero_nan", ], ) def sc_gather_reduce( op: jax.Array, idx: jax.Array, topk_weights: jax.Array | None = None, *, reduce_group_size: int, single_sc: bool = False, col_chunk_size: int = int(3.5 * 1024), row_chunk_size: int = 16, # writing back 2 rows given reduce size of 8 loop_unroll_factor_1: int = 2, loop_unroll_factor_2: int = 2, loop_unroll_factor_3: int = 8, loop_parallel_access_1: bool = True, loop_parallel_access_2: bool = False, loop_parallel_access_3: bool = False, topk_wgt_zero_nan: bool = False, ) -> jax.Array: """Performs a gather-reduce operation on SparseCore. This kernel groups rows of the operand `op` based on `idx`, sums them up, and scatters the results. The gather and add operations are performed in fp32, and the results are written back in bf16. Equivalent jax numpy code: ``` gathered = op[idx, :] if topk_wgt_local is not None: flat_weights = topk_wgt_local.flatten() gathered = gathered * flat_weights[:, None].astype(acc_dtype) gathered = jnp.reshape(gathered, (-1, reduce_group_size, op.shape[1])) output = jnp.sum(gathered.astype(acc_dtype), axis=1).astype(jnp.bfloat16) ``` Args: op: The operand matrix in fp32 [B, K] to reduce. idx: The indices in int32[M,] guiding the reduction and scatter. topk_weights: Optional weights to apply to the gathered operands. reduce_group_size: The size of the groups to reduce. single_sc: Whether to use a single SparseCore. col_chunk_size: The size of column chunks to process. row_chunk_size: The size of row chunks for internal processing. loop_unroll_factor_1: Unroll factor for the main loop over column chunks. loop_unroll_factor_2: Unroll factor for the loop over row chunks in offset calculation. 
loop_unroll_factor_3: Unroll factor for the inner loop within offset calculation. loop_parallel_access_1: Enables parallel access for the main column chunk loop. loop_parallel_access_2: Enables parallel access for the row chunk loop in offset calculation. loop_parallel_access_3: Enables parallel access for the inner loop within offset calculation. topk_wgt_zero_nan: If true, treat zero topk_weights as indicators of NaN during multiplication, resulting in zero output. Returns: The result of operation, bf16 matrix [M/reduce_group_size, K]. """ assert op.dtype in ( jnp.float32, jnp.bfloat16, ), f"op.dtype must be f32 or bf16, but got {op.dtype}" is_bf16 = op.dtype == jnp.bfloat16 assert op.shape[0] % reduce_group_size == 0, ( "op.shape[0] must be divisible by reduce_group_size, but got" f" {op.shape[0]} and {reduce_group_size}" ) assert row_chunk_size / reduce_group_size == 2, ( # writing back in bf16 (2 rows at once) f"row_chunk_size must be 2 * reduce_group_size, but got {row_chunk_size=}" f" and {reduce_group_size=}" ) if topk_weights is not None: assert topk_weights.ndim == 2 assert topk_weights.shape[0] * 128 == idx.shape[0] assert topk_weights.shape[1] == 128 tpu_info = pltpu.get_tpu_info() if tpu_info.sparse_core is None: raise ValueError("SparseCore is not available on this TPU version.") used_sc_cores = 1 if single_sc else 2 num_sc_per_core = tpu_info.sparse_core.num_subcores num_sc = num_sc_per_core * used_sc_cores vreg_size = tpu_info.sparse_core.num_lanes with mlir.make_ir_context() as ctx, ir.Location.unknown(): # The TPU dialect is required for its TPU_DimensionSemantics tpu.register_dialect(ctx) bf16 = ir.BF16Type.get() i32 = ir.IntegerType.get_signless(32) f32 = ir.F32Type.get() index = ir.IndexType.get() memoryspace_tilespmem = ir.Attribute.parse("#tpu.memory_space<vmem>") memoryspace_hbm = ir.Attribute.parse("#tpu.memory_space<hbm>") memory_space_semaphore = ir.Attribute.parse("#tpu.memory_space<semaphore_mem>") dma_semaphore_type = 
ir.Type.parse("!tpu.dma_semaphore") # A mask of all True values to enable all sublanes in vector operations. enable_all_sublanes_mask = ir.DenseBoolArrayAttr.get([True] * vreg_size) assert row_chunk_size == vreg_size def _kernel_impl( current_sc_core, current_local_core, idx_ref, op_ref, weights_ref, out_ref, func_op, ): constants: dict[tuple[Any, Any], Any] = {} def const_lut(val, ty=None): ty = index if ty is None else ty if (val, ty) not in constants: with ir.InsertionPoint.at_block_begin(func_op.entry_block): constants[(val, ty)] = arith.constant(ty, ir.IntegerAttr.get(ty, val)) return constants[(val, ty)] def fill_load_offset_tile(offset_tile_local, idx_tile_local, col_pos): """Fills the offset tile for indirect DMA gather. This function calculates the HBM offsets from which to gather rows based on the indices in idx_tile_local, for a given column chunk. The offsets are calculated to correctly index into the operand `op` in HBM, considering the memory layout and the current column chunk being processed. The calculated offsets are stored in offset_tile_local, which is later used by tpu.enqueue_indirect_dma. Args: offset_tile_local: The destination memref in TileSpMem to store calculated offsets. idx_tile_local: A memref in TileSpMem containing a chunk of indices of rows to gather from `op`. col_pos: The index of the current column chunk being processed. Returns: The offset_tile_local memref filled with offsets for DMA gather. 
""" idx_loaded = tpu.load( _I32[row_chunk_size], idx_tile_local, [const_lut(0)], enable_all_sublanes_mask, ) parity = arith.remui( idx_loaded, vector.broadcast(_I32[row_chunk_size], const_lut(2, i32)), ) if is_bf16: mul_mem_layout = 4 else: mul_mem_layout = 8 iota = arith.constant( _I32[row_chunk_size], ir.DenseIntElementsAttr.get( array.array("i", [i * mul_mem_layout for i in range(row_chunk_size)]), type=_I32[row_chunk_size], ), ) loop_over_idx = scf.ForOp( lower_bound=const_lut(0), upper_bound=const_lut(row_chunk_size + 1), step=const_lut(1), ) loop_over_idx.attributes["sc.loop_unroll_factor"] = ir.IntegerAttr.get(i32, loop_unroll_factor_2) if loop_parallel_access_2: loop_over_idx.attributes["sc.parallel_access"] = ir.UnitAttr.get() with ir.InsertionPoint(loop_over_idx.body): i = loop_over_idx.induction_variable base_idx = vector.extract( idx_loaded, dynamic_position=[i], static_position=ir.DenseI64ArrayAttr.get([-9223372036854775808]), ) # align base_idx with memory layout base_idx_div = arith.divui(base_idx, const_lut(8, i32)) if is_bf16: base_idx_rem = arith.divui(arith.remui(base_idx, const_lut(8, i32)), const_lut(2, i32)) else: base_idx_rem = arith.remui(base_idx, const_lut(8, i32)) base_idx = arith.addi( arith.muli( # NOTYPO base_idx_div, const_lut(op.shape[1] // 128 * mul_mem_layout, i32), ), base_idx_rem, ) # consider col chunk base_idx = arith.addi( base_idx, arith.muli( # NOTYPO arith.index_cast(i32, col_pos), const_lut(col_chunk_size // 128 * mul_mem_layout, i32), ), ) start_vec = arith.addi(vector.broadcast(_I32[row_chunk_size], base_idx), iota) loop_j = scf.ForOp( const_lut(0), const_lut((col_chunk_size // 128) // vreg_size), const_lut(1), iter_args=[start_vec], ) loop_j.attributes["sc.loop_unroll_factor"] = ir.IntegerAttr.get(i32, loop_unroll_factor_3) if loop_parallel_access_3: loop_j.attributes["sc.parallel_access"] = ir.UnitAttr.get() with ir.InsertionPoint(loop_j.body): vec = loop_j.inner_iter_args[0] idx_local = arith.addi( arith.muli(i, 
const_lut((col_chunk_size // 128))), # NOTYPO arith.muli(loop_j.induction_variable, const_lut(vreg_size)), # NOTYPO ) tpu.store( vec, offset_tile_local, [idx_local], enable_all_sublanes_mask, ) next_vec = arith.addi( vec, vector.broadcast( _I32[row_chunk_size], const_lut(mul_mem_layout * vreg_size, i32), ), ) scf.YieldOp([next_vec]) rem_vreg = col_chunk_size / 128 % vreg_size / vreg_size last_full_vec = tpu.load( _I32[vreg_size], offset_tile_local, [ arith.subi( arith.muli(const_lut((col_chunk_size // 128)), i), # NOTYPO const_lut(int(vreg_size * (1 + rem_vreg))), ) ], enable_all_sublanes_mask, ) add_vec = vector.broadcast( _I32[vreg_size], const_lut(int(mul_mem_layout * rem_vreg * vreg_size), i32), ) # 4 comes from bf16 last_full_vec = arith.addi(last_full_vec, add_vec) tpu.store( last_full_vec, offset_tile_local, [ arith.subi( arith.muli(const_lut((col_chunk_size // 128)), i), # NOTYPO const_lut(int(vreg_size)), ) ], enable_all_sublanes_mask, ) scf.YieldOp([]) return offset_tile_local, parity def load_weights(lin_idx, dst_tile, sflag): """Loads weights from HBM to TileSpMem. This function calculates offsets for reading weights from HBM based on a linear index `lin_idx`. It handles a specific packed memory layout of weights where weights for two rows are packed together (bf16). It calculates parity to determine which part of the packed data to use later in `perform_add`. It uses indirect DMA to load weights into dst_tile. Args: lin_idx: Linear index used to calculate weight offsets in HBM. dst_tile: The destination memref in TileSpMem to load weights into. sflag: The semaphore to use for the DMA operation. Returns: Parity bit (0 or 1) indicating which set of weights to use from the loaded packed data. 
""" lin_idx = arith.divui(lin_idx, const_lut(16)) # reading in increments of 16 # parity # (i >> 3) & 1 parity = arith.andi(arith.shrui(lin_idx, const_lut(3)), const_lut(1)) # address # (i & 7) | ((i >> 4) << 3) lin_idx = arith.ori( arith.andi(lin_idx, const_lut(7)), arith.shli(arith.shrui(lin_idx, const_lut(4)), const_lut(3)), ) # lin_idx is index type, cast to i32 lin_idx_i32 = arith.index_cast(i32, lin_idx) lin_idx_base = arith.muli(lin_idx_i32, const_lut(16, i32)) # NOTYPO lin_idx_base_vec = vector.broadcast(_I32[row_chunk_size], lin_idx_base) iota = arith.constant( _I32[row_chunk_size], ir.DenseIntElementsAttr.get( array.array("i", list(range(16))), type=_I32[row_chunk_size], ), ) offsets = arith.addi(lin_idx_base_vec, iota) tpu.enqueue_indirect_dma( source=weights_ref, target=dst_tile, offsets=offsets, semaphore=sflag, ) tpu.wait_indirect_dma(semaphore=sflag, src=weights_ref, dst=dst_tile) return parity def perform_add( scratch_local, scratch_out_local, idx_parity, weights_local=None, parity=None, ): """Performs reduction (summation) of rows in scratchpad. This function reduces `row_chunk_size` rows stored in scratch_local into 2 rows by summing rows in groups of `reduce_group_size`. If weights_local is provided, rows are multiplied by weights before summation. The result is packed into bf16 format and stored in scratch_out_local. With row_chunk_size=16 and reduce_group_size=8, this reduces 16 rows to 2 rows (rows 0-7 summed to one row, rows 8-15 summed to another). Args: scratch_local: Memref in TileSpMem containing rows gathered from `op`. scratch_out_local: Memref in TileSpMem to store the 2 reduced rows in bf16 format. idx_parity: Parity information for each index in the gathered rows. weights_local: Optional memref in TileSpMem containing weights to apply before reduction. parity: Optional parity bit indicating which weights to use from packed weight data in weights_local. Returns: The scratch_out_local memref containing reduced rows. 
""" weights_vecs = None if weights_local is not None: f32_zero = arith.constant(f32, ir.FloatAttr.get(f32, 0.0)) zero_vec_f32 = vector.broadcast(_F32[vreg_size], f32_zero) # Load 32 weights (64B) -> 16 weights for R0 + 16 weights for R1 from # packed layout. We check parity to decide which one to take. We need # to extract the even or odd elements weights_local_2x16 = memref.reinterpret_cast( ir.MemRefType.get( (2, row_chunk_size), bf16, ir.Attribute.parse("#tpu.tiled<,[" + str(row_chunk_size) + ",1]>"), memory_space=memoryspace_tilespmem, ), weights_local, [], [], [], static_offsets=[0], static_sizes=[2, row_chunk_size], static_strides=[row_chunk_size, 1], ) raw_weights = tpu.load( _BF16[2, row_chunk_size], weights_local_2x16, [const_lut(0), const_lut(0)], enable_all_sublanes_mask, ) # Reminder Mosaic SC compiler: flips [d0, d1] to [d1, d0] # part=1 corresponds to sc_tpu.unpackf result 1 (Even indices) weights_evens = tpu.unpack_subelements( _F32[16], raw_weights, 1, ir.Attribute.parse("#tpu.pack_format<interleaved>"), ) # part=0 corresponds to sc_tpu.unpackf result 0 (Odd indices) weights_odds = tpu.unpack_subelements( _F32[16], raw_weights, 0, ir.Attribute.parse("#tpu.pack_format<interleaved>"), ) is_odd = arith.cmpi( arith.CmpIPredicate.eq, parity, const_lut(0), ) raw_weights_f32 = arith.select(is_odd, weights_odds, weights_evens) weights_vecs = [ vector.broadcast( _F32[vreg_size], vector.extract( raw_weights_f32, dynamic_position=[], static_position=ir.DenseI64ArrayAttr.get([i]), ), ) for i in range(16) ] # reinterpret cast to correct shape to read from it if not is_bf16: new_scratch_shape = (row_chunk_size, col_chunk_size) new_scratch_layout = ir.Attribute.parse("#tpu.tiled<,[" + str(col_chunk_size) + ",1]>") new_scratch_ref_ty = ir.MemRefType.get( new_scratch_shape, f32, new_scratch_layout, memory_space=memoryspace_tilespmem, ) scratch_local = memref.reinterpret_cast( new_scratch_ref_ty, scratch_local, [], [], [], static_offsets=[0], 
static_sizes=new_scratch_shape, static_strides=[col_chunk_size, 1], ) loop_pack_scratch = scf.ForOp(const_lut(0), const_lut(col_chunk_size), const_lut(vreg_size)) loop_pack_scratch.attributes["sc.loop_unroll_factor"] = ir.IntegerAttr.get(i32, loop_unroll_factor_1) if loop_parallel_access_1: loop_pack_scratch.attributes["sc.parallel_access"] = ir.UnitAttr.get() with ir.InsertionPoint(loop_pack_scratch.body): col_offset = loop_pack_scratch.induction_variable # map col_offset to idx and col inside 128 if is_bf16: row_add = arith.divui(col_offset, const_lut(128)) col_pos = arith.remui(col_offset, const_lut(128)) else: row_add = None col_pos = None def get_row_val(row_idx): if is_bf16: row_idx_l = row_idx * (col_chunk_size // 128) vec_bf16_2x16 = tpu.load( _BF16[2, 16], scratch_local, [ arith.addi(const_lut(row_idx_l), row_add), arith.muli(col_pos, const_lut(2)), # NOTYPO ], enable_all_sublanes_mask, ) vec_f32_evens = tpu.unpack_subelements( _F32[16], vec_bf16_2x16, 1, ir.Attribute.parse("#tpu.pack_format<interleaved>"), ) vec_f32_odds = tpu.unpack_subelements( _F32[16], vec_bf16_2x16, 0, ir.Attribute.parse("#tpu.pack_format<interleaved>"), ) parity_of_row = vector.extract( idx_parity, dynamic_position=[], static_position=ir.DenseI64ArrayAttr.get([row_idx]), ) is_odd_scalar = arith.cmpi(arith.CmpIPredicate.eq, parity_of_row, const_lut(0, i32)) return arith.select(is_odd_scalar, vec_f32_odds, vec_f32_evens) else: return tpu.load( _F32[vreg_size], scratch_local, [const_lut(row_idx), col_offset], enable_all_sublanes_mask, ) row0 = get_row_val(0) if weights_local is not None: row0 = arith.mulf(row0, weights_vecs[0]) if topk_wgt_zero_nan: row0 = arith.select( arith.cmpf(arith.CmpFPredicate.OEQ, weights_vecs[0], zero_vec_f32), zero_vec_f32, row0, ) row8 = get_row_val(8) if weights_local is not None: row8 = arith.mulf(row8, weights_vecs[8]) if topk_wgt_zero_nan: row8 = arith.select( arith.cmpf(arith.CmpFPredicate.OEQ, weights_vecs[8], zero_vec_f32), zero_vec_f32, row8, ) for 
sum_idx in range(7): tmp_row0 = get_row_val(sum_idx + 1) if weights_local is not None: tmp_row0 = arith.mulf(tmp_row0, weights_vecs[sum_idx + 1]) if topk_wgt_zero_nan: tmp_row0 = arith.select( arith.cmpf( arith.CmpFPredicate.OEQ, weights_vecs[sum_idx + 1], zero_vec_f32, ), zero_vec_f32, tmp_row0, ) row0 = arith.addf(row0, tmp_row0) tmp_row8 = get_row_val(8 + sum_idx + 1) if weights_local is not None: tmp_row8 = arith.mulf(tmp_row8, weights_vecs[8 + sum_idx + 1]) if topk_wgt_zero_nan: tmp_row8 = arith.select( arith.cmpf( arith.CmpFPredicate.OEQ, weights_vecs[8 + sum_idx + 1], zero_vec_f32, ), zero_vec_f32, tmp_row8, ) row8 = arith.addf(row8, tmp_row8) packed = tpu.pack_subelements( _BF16[2, vreg_size], [row0, row8], [0, 1], ir.Attribute.parse("#tpu.pack_format<interleaved>"), ) tpu.store( packed, scratch_out_local, [const_lut(0), arith.muli(col_offset, const_lut(2))], # NOTYPO enable_all_sublanes_mask, ) scf.YieldOp([]) # undo reshape cast on scratch_local if not is_bf16: mem_num = 128 new_scratch_shape = ( row_chunk_size * col_chunk_size // mem_num, mem_num, ) new_scratch_layout = ir.Attribute.parse("#tpu.tiled<,[" + str(mem_num) + ",1]>") new_scratch_ref_ty = ir.MemRefType.get( new_scratch_shape, f32, new_scratch_layout, memory_space=memoryspace_tilespmem, ) scratch_local = memref.reinterpret_cast( new_scratch_ref_ty, scratch_local, [], [], [], static_offsets=[0], static_sizes=new_scratch_shape, static_strides=[mem_num, 1], ) return scratch_out_local def fill_out_offset_tile(offset_tile_out_local, col_pos, row_pos=None): """Fills the offset tile for indirect DMA scatter for outputs (bf16). This function calculates the HBM offsets to scatter the reduced rows to. The offsets are calculated to correctly index into the output tensor in HBM, based on the row index of the reduction group `row_pos` and the current column chunk `col_pos`. 
The calculated offsets are stored in offset_tile_out_local, which is later used by tpu.enqueue_indirect_dma to scatter results from TileSpMem to HBM. Args: offset_tile_out_local: The destination memref in TileSpMem to store calculated offsets. col_pos: The index of the current column chunk being processed. row_pos: The starting row index for the reduction group in the output tensor. If None, computes offsets for prologue. Returns: The offset_tile_out_local memref filled with offsets for DMA scatter. """ if row_pos is not None: rest_row_pos = arith.remui(row_pos, const_lut(8)) rest_row_pos = arith.divui(rest_row_pos, const_lut(2)) tmp = arith.divui(row_pos, const_lut(8)) tmp = arith.muli(tmp, const_lut(4 * op.shape[1] // 128)) # NOTYPO row_vec_offset = arith.addi(tmp, rest_row_pos) row_vec_offset = arith.index_cast(i32, row_vec_offset) row_vec_offset = vector.broadcast(_I32[row_chunk_size], row_vec_offset) iota = arith.constant( _I32[row_chunk_size], ir.DenseIntElementsAttr.get( array.array("i", [i * 4 for i in range(row_chunk_size)]), # 4 is key here for bf16 type=_I32[row_chunk_size], ), ) iota = arith.addi( iota, vector.broadcast( _I32[row_chunk_size], arith.muli( # NOTYPO arith.index_cast(i32, col_pos), const_lut(col_chunk_size // 128 * 4, i32), ), ), ) loop_over_idx = scf.ForOp( lower_bound=const_lut(0), upper_bound=const_lut(row_chunk_size // reduce_group_size), step=const_lut(2), ) loop_over_idx.attributes["sc.loop_unroll_factor"] = ir.IntegerAttr.get(i32, loop_unroll_factor_2) if loop_parallel_access_2: loop_over_idx.attributes["sc.parallel_access"] = ir.UnitAttr.get() with ir.InsertionPoint(loop_over_idx.body): loop_i = loop_over_idx.induction_variable start_vec = arith.addi( vector.broadcast(_I32[row_chunk_size], arith.index_cast(i32, loop_i)), iota, ) loop_j = scf.ForOp( const_lut(0), const_lut((col_chunk_size // 128) // vreg_size), const_lut(1), iter_args=[start_vec], ) loop_j.attributes["sc.loop_unroll_factor"] = ir.IntegerAttr.get(i32, 
loop_unroll_factor_3) if loop_parallel_access_3: loop_j.attributes["sc.parallel_access"] = ir.UnitAttr.get() with ir.InsertionPoint(loop_j.body): vec = loop_j.inner_iter_args[0] idx_local = arith.addi( arith.muli(loop_i, const_lut((op.shape[1] // 128))), # NOTYPO arith.muli(loop_j.induction_variable, const_lut(vreg_size)), # NOTYPO ) if row_pos is not None: local_vec = arith.addi(vec, row_vec_offset) else: local_vec = vec tpu.store( local_vec, offset_tile_out_local, [idx_local], enable_all_sublanes_mask, ) add_vec = vector.broadcast(_I32[vreg_size], const_lut(4 * vreg_size, i32)) next_vec = arith.addi(vec, add_vec) scf.YieldOp([next_vec]) scf.YieldOp([]) rem_vreg = col_chunk_size / 128 % vreg_size / vreg_size last_full_vec = tpu.load( _I32[vreg_size], offset_tile_out_local, [const_lut(offset_sizes_out - int(vreg_size * (1 + rem_vreg)))], enable_all_sublanes_mask, ) add_vec = vector.broadcast(_I32[vreg_size], const_lut(int(4 * rem_vreg * vreg_size), i32)) # 4 comes from bf16 last_full_vec = arith.addi(last_full_vec, add_vec) tpu.store( last_full_vec, offset_tile_out_local, [const_lut(offset_sizes_out - vreg_size)], enable_all_sublanes_mask, ) return offset_tile_out_local offset_sizes_out = ( (row_chunk_size // reduce_group_size) * (col_chunk_size // 128) // 2 ) # two because doing two rows at once offset_tile_out_0 = memref.alloca( ir.MemRefType.get( shape=(offset_sizes_out,), element_type=i32, memory_space=memoryspace_tilespmem, ), [], [], ) offset_tile_out_1 = memref.alloca( ir.MemRefType.get( shape=(offset_sizes_out,), element_type=i32, memory_space=memoryspace_tilespmem, ), [], [], ) if is_bf16: scratch_0 = memref.alloca( ir.MemRefType.get( shape=(row_chunk_size * 2, col_chunk_size), element_type=bf16, memory_space=memoryspace_tilespmem, ), [], [], ) scratch_1 = memref.alloca( ir.MemRefType.get( shape=(row_chunk_size * 2, col_chunk_size), element_type=bf16, memory_space=memoryspace_tilespmem, ), [], [], ) mem_num = 128 * 2 new_scratch_shape = ( row_chunk_size * 
col_chunk_size * 2 // mem_num, mem_num, ) new_scratch_layout = ir.Attribute.parse("#tpu.tiled<,[" + str(mem_num) + ",1]>") new_scratch_ref_ty = ir.MemRefType.get( new_scratch_shape, bf16, new_scratch_layout, memory_space=memoryspace_tilespmem, ) else: scratch_0 = memref.alloca( ir.MemRefType.get( shape=(row_chunk_size, col_chunk_size), element_type=f32, memory_space=memoryspace_tilespmem, ), [], [], ) scratch_1 = memref.alloca( ir.MemRefType.get( shape=(row_chunk_size, col_chunk_size), element_type=f32, memory_space=memoryspace_tilespmem, ), [], [], ) mem_num = 128 new_scratch_shape = ( row_chunk_size * col_chunk_size // mem_num, mem_num, ) new_scratch_layout = ir.Attribute.parse("#tpu.tiled<,[" + str(mem_num) + ",1]>") new_scratch_ref_ty = ir.MemRefType.get( new_scratch_shape, f32, new_scratch_layout, memory_space=memoryspace_tilespmem, ) scratch_0 = memref.reinterpret_cast( new_scratch_ref_ty, scratch_0, [], [], [], static_offsets=[0], static_sizes=new_scratch_shape, static_strides=[mem_num, 1], ) scratch_1 = memref.reinterpret_cast( new_scratch_ref_ty, scratch_1, [], [], [], static_offsets=[0], static_sizes=new_scratch_shape, static_strides=[mem_num, 1], ) scratch_out_0 = memref.alloca( ir.MemRefType.get( shape=(2, col_chunk_size), element_type=bf16, memory_space=memoryspace_tilespmem, ), [], [], ) scratch_out_1 = memref.alloca( ir.MemRefType.get( shape=(2, col_chunk_size), element_type=bf16, memory_space=memoryspace_tilespmem, ), [], [], ) mem_num = 128 * 2 new_scratch_shape = (2 * col_chunk_size // mem_num, mem_num) new_scratch_layout = ir.Attribute.parse("#tpu.tiled<,[" + str(mem_num) + ",1]>") new_scratch_ref_ty = ir.MemRefType.get( new_scratch_shape, bf16, new_scratch_layout, memory_space=memoryspace_tilespmem, ) scratch_out_0 = memref.reinterpret_cast( new_scratch_ref_ty, scratch_out_0, [], [], [], static_offsets=[0], static_sizes=new_scratch_shape, static_strides=[mem_num, 1], ) scratch_out_1 = memref.reinterpret_cast( new_scratch_ref_ty, scratch_out_1, 
[], [], [], static_offsets=[0], static_sizes=new_scratch_shape, static_strides=[mem_num, 1], ) new_input_shape = ( idx.shape[0] // reduce_group_size * op.shape[1] // mem_num, mem_num, ) new_input_ref_ty = ir.MemRefType.get( new_input_shape, bf16, layout=ir.Attribute.parse(f"#tpu.tiled<,[{mem_num}, 1]>"), memory_space=memoryspace_hbm, ) out_ref = tpu.reinterpret_cast( new_input_ref_ty, out_ref, ) if topk_weights is not None: new_weights_shape = (topk_weights.size // 2, 2) new_weights_ref_ty = ir.MemRefType.get( new_weights_shape, bf16, layout=ir.Attribute.parse("#tpu.tiled<,[2, 1]>"), memory_space=memoryspace_hbm ) weights_ref = tpu.reinterpret_cast(new_weights_ref_ty, weights_ref) sflag_0 = tpu.sem_alloc(ir.MemRefType.get((), dma_semaphore_type, memory_space=memory_space_semaphore)) sflag_1 = tpu.sem_alloc(ir.MemRefType.get((), dma_semaphore_type, memory_space=memory_space_semaphore)) sflag_out_0 = tpu.sem_alloc(ir.MemRefType.get((), dma_semaphore_type, memory_space=memory_space_semaphore)) sflag_out_1 = tpu.sem_alloc(ir.MemRefType.get((), dma_semaphore_type, memory_space=memory_space_semaphore)) idx_tile_0 = memref.alloca( ir.MemRefType.get( shape=(row_chunk_size,), element_type=i32, memory_space=memoryspace_tilespmem, ), [], [], ) idx_tile_1 = memref.alloca( ir.MemRefType.get( shape=(row_chunk_size,), element_type=i32, memory_space=memoryspace_tilespmem, ), [], [], ) weights_tile_0 = None weights_tile_1 = None sflag_weights_0 = None sflag_weights_1 = None if topk_weights is not None: weights_tile_0 = memref.alloca( ir.MemRefType.get( shape=(row_chunk_size, 2), # matches Offsets shape (16) element_type=bf16, memory_space=memoryspace_tilespmem, ), [], [], ) weights_tile_1 = memref.alloca( ir.MemRefType.get( shape=(row_chunk_size, 2), # matches Offsets shape (16) element_type=bf16, memory_space=memoryspace_tilespmem, ), [], [], ) tiled_weights_ref_ty = ir.MemRefType.get( (row_chunk_size, 2), bf16, layout=ir.Attribute.parse("#tpu.tiled<,[2, 1]>"), 
memory_space=memoryspace_tilespmem, ) weights_tile_0 = memref.reinterpret_cast( tiled_weights_ref_ty, weights_tile_0, [], [], [], static_offsets=[0], static_sizes=[row_chunk_size, 2], static_strides=[2, 1], ) weights_tile_1 = memref.reinterpret_cast( tiled_weights_ref_ty, weights_tile_1, [], [], [], static_offsets=[0], static_sizes=[row_chunk_size, 2], static_strides=[2, 1], ) sflag_weights_0 = tpu.sem_alloc(ir.MemRefType.get((), dma_semaphore_type, memory_space=memory_space_semaphore)) sflag_weights_1 = tpu.sem_alloc(ir.MemRefType.get((), dma_semaphore_type, memory_space=memory_space_semaphore)) offset_sizes = row_chunk_size * (col_chunk_size // 128) offset_tile_0 = memref.alloca( ir.MemRefType.get( shape=(offset_sizes,), element_type=i32, memory_space=memoryspace_tilespmem, ), [], [], ) offset_tile_1 = memref.alloca( ir.MemRefType.get( shape=(offset_sizes,), element_type=i32, memory_space=memoryspace_tilespmem, ), [], [], ) global_chunk_to_process = arith.addi( arith.muli(current_sc_core, const_lut(num_sc_per_core, i32)), # NOTYPO current_local_core, ) global_row_idx_start = arith.index_cast( index, arith.muli( # NOTYPO global_chunk_to_process, const_lut(idx.shape[0] // num_sc, i32), ), ) loop_col_chunk = scf.ForOp( const_lut(0), const_lut(op.shape[1] // col_chunk_size), const_lut(1), ) with ir.InsertionPoint(loop_col_chunk.body): col_chunk_ij = loop_col_chunk.induction_variable offset_tile_out_0 = fill_out_offset_tile(offset_tile_out_0, col_chunk_ij) # ############# # # prologue # ############# # setup stream for #0 loop_row_chunk_idx = const_lut(0) base_idx_val = arith.index_cast(i32, arith.addi(global_row_idx_start, loop_row_chunk_idx)) base_idx_val = tpu.assume_multiple(base_idx_val, 8) source_slice = tpu.memref_slice( ir.MemRefType.get((row_chunk_size,), i32, memory_space=memoryspace_hbm), idx_ref, base_idx=[base_idx_val], dynamic_sizes=[], ) tpu.enqueue_dma(source_slice, idx_tile_0, target_semaphore=sflag_0) tpu.wait_dma2(semaphore=sflag_0, 
src=source_slice, dst=idx_tile_0) if topk_weights is not None: # The linear index is global_row_idx_start + loop_row_chunk_idx lin_idx = arith.addi(global_row_idx_start, loop_row_chunk_idx) parity_0 = load_weights(lin_idx, weights_tile_0, sflag_weights_0) parity_1 = const_lut(0) else: parity_0 = const_lut(0) parity_1 = const_lut(0) offset_tile_0, idx_parity_0 = fill_load_offset_tile(offset_tile_0, idx_tile_0, col_chunk_ij) if is_bf16: mem_num = 128 * 2 new_input_shape = (op.shape[0] * op.shape[1] // mem_num, mem_num) new_input_layout = ir.Attribute.parse("#tpu.tiled<,[" + str(mem_num) + ",1]>") new_input_ref_ty = ir.MemRefType.get( new_input_shape, bf16, new_input_layout, memory_space=memoryspace_hbm, ) else: mem_num = 128 new_input_shape = (op.shape[0] * op.shape[1] // mem_num, mem_num) new_input_layout = ir.Attribute.parse("#tpu.tiled<,[" + str(mem_num) + ",1]>") new_input_ref_ty = ir.MemRefType.get( new_input_shape, f32, new_input_layout, memory_space=memoryspace_hbm, ) op_ref = memref.reinterpret_cast( new_input_ref_ty, op_ref, [], [], [], static_offsets=[0], static_sizes=new_input_shape, static_strides=[mem_num, 1], ) tpu.enqueue_indirect_dma( source=op_ref, target=scratch_0, offsets=offset_tile_0, semaphore=sflag_0, ) ############# # corpus ############# loop_over_row_chunks = scf.ForOp( lower_bound=const_lut(row_chunk_size * 1), upper_bound=const_lut(row_chunk_size * ((idx.shape[0] // num_sc // row_chunk_size) - 1)), step=const_lut(row_chunk_size * 2), iter_args=[ scratch_0, scratch_out_0, sflag_0, sflag_out_0, idx_tile_0, offset_tile_0, offset_tile_out_0, scratch_1, scratch_out_1, sflag_1, sflag_out_1, idx_tile_1, offset_tile_1, offset_tile_out_1, weights_tile_0 if topk_weights is not None else scratch_0, weights_tile_1 if topk_weights is not None else scratch_0, sflag_weights_0 if topk_weights is not None else sflag_0, sflag_weights_1 if topk_weights is not None else sflag_0, parity_0 if topk_weights is not None else const_lut(0), parity_1 if topk_weights 
is not None else const_lut(0), idx_parity_0, idx_parity_0, # idx_parity_1 is not initialized yet ], ) with ir.InsertionPoint(loop_over_row_chunks.body): loop_row_chunk_idx = loop_over_row_chunks.induction_variable ( scratch_0, scratch_out_0, sflag_0, sflag_out_0, idx_tile_0, offset_tile_0, offset_tile_out_0, scratch_1, scratch_out_1, sflag_1, sflag_out_1, idx_tile_1, offset_tile_1, offset_tile_out_1, weights_tile_0, weights_tile_1, sflag_weights_0, sflag_weights_1, parity_0, parity_1, idx_parity_0, idx_parity_1, ) = loop_over_row_chunks.inner_iter_args # setup stream for #1 base_idx_val = arith.index_cast( i32, arith.addi( global_row_idx_start, arith.addi(loop_row_chunk_idx, const_lut(0)), ), ) base_idx_val = tpu.assume_multiple(base_idx_val, 8) source_slice = tpu.memref_slice( ir.MemRefType.get((row_chunk_size,), i32, memory_space=memoryspace_hbm), idx_ref, base_idx=[base_idx_val], dynamic_sizes=[], ) tpu.enqueue_dma(source_slice, idx_tile_1, target_semaphore=sflag_1) tpu.wait_dma2(semaphore=sflag_1, src=source_slice, dst=idx_tile_1) if topk_weights is not None: # The linear index is global_row_idx_start + loop_row_chunk_idx lin_idx = arith.addi(global_row_idx_start, loop_row_chunk_idx) parity_1 = load_weights(lin_idx, weights_tile_1, sflag_weights_1) offset_tile_1, idx_parity_1 = fill_load_offset_tile(offset_tile_1, idx_tile_1, col_chunk_ij) tpu.enqueue_indirect_dma( source=op_ref, target=scratch_1, offsets=offset_tile_1, semaphore=sflag_1, ) # wait stream #0 tpu.wait_indirect_dma(semaphore=sflag_0, src=op_ref, dst=scratch_0) # process #0 scratch_out_0 = perform_add( scratch_0, scratch_out_0, idx_parity_0, weights_local=weights_tile_0 if topk_weights is not None else None, parity=parity_0 if topk_weights is not None else None, ) offset_tile_out_0 = fill_out_offset_tile( offset_tile_out_0, col_chunk_ij, arith.divui( arith.addi( global_row_idx_start, arith.subi(loop_row_chunk_idx, const_lut(row_chunk_size)), ), const_lut(reduce_group_size), ), ) 
tpu.enqueue_indirect_dma( source=scratch_out_0, target=out_ref, offsets=offset_tile_out_0, semaphore=sflag_out_0, ) if topk_weights is not None: # The linear index is: # global_row_idx_start + loop_row_chunk_idx + row_chunk_size lin_idx = arith.addi( global_row_idx_start, arith.addi(loop_row_chunk_idx, const_lut(row_chunk_size)), ) parity_0 = load_weights(lin_idx, weights_tile_0, sflag_weights_0) base_idx_val = arith.index_cast( i32, arith.addi( global_row_idx_start, arith.addi(loop_row_chunk_idx, const_lut(row_chunk_size)), ), ) base_idx_val = tpu.assume_multiple(base_idx_val, 8) source_slice = tpu.memref_slice( ir.MemRefType.get((row_chunk_size,), i32, memory_space=memoryspace_hbm), idx_ref, base_idx=[base_idx_val], dynamic_sizes=[], ) tpu.enqueue_dma(source_slice, idx_tile_0, target_semaphore=sflag_0) tpu.wait_dma2(semaphore=sflag_0, src=source_slice, dst=idx_tile_0) offset_tile_0, idx_parity_0 = fill_load_offset_tile(offset_tile_0, idx_tile_0, col_chunk_ij) tpu.enqueue_indirect_dma( source=op_ref, target=scratch_0, offsets=offset_tile_0, semaphore=sflag_0, ) # wait stream #1 tpu.wait_indirect_dma(semaphore=sflag_1, src=op_ref, dst=scratch_1) # # process #1 scratch_out_1 = perform_add( scratch_1, scratch_out_1, idx_parity_1, weights_local=weights_tile_1 if topk_weights is not None else None, parity=parity_1 if topk_weights is not None else None, ) offset_tile_out_1 = fill_out_offset_tile( offset_tile_out_1, col_chunk_ij, arith.divui( arith.addi(global_row_idx_start, loop_row_chunk_idx), const_lut(reduce_group_size), ), ) tpu.enqueue_indirect_dma( source=scratch_out_1, target=out_ref, offsets=offset_tile_out_1, semaphore=sflag_out_1, ) # wait stream #0 out tpu.wait_indirect_dma(semaphore=sflag_out_0, src=scratch_out_0, dst=out_ref) # wait stream #1 out tpu.wait_indirect_dma(semaphore=sflag_out_1, src=scratch_out_1, dst=out_ref) # return #0 and #1 scf.YieldOp( [ scratch_0, scratch_out_0, sflag_0, sflag_out_0, idx_tile_0, offset_tile_0, offset_tile_out_0, 
scratch_1, scratch_out_1, sflag_1, sflag_out_1, idx_tile_1, offset_tile_1, offset_tile_out_1, weights_tile_0 if topk_weights is not None else scratch_0, weights_tile_1 if topk_weights is not None else scratch_0, sflag_weights_0 if topk_weights is not None else sflag_0, sflag_weights_1 if topk_weights is not None else sflag_0, parity_0 if topk_weights is not None else const_lut(0), parity_1 if topk_weights is not None else const_lut(0), idx_parity_0, idx_parity_1, ] ) ############# # epilogue ############# ( scratch_0, scratch_out_0, sflag_0, sflag_out_0, _, _, offset_tile_out_0, scratch_1, scratch_out_1, sflag_1, sflag_out_1, idx_tile_1, offset_tile_1, offset_tile_out_1, weights_tile_0, weights_tile_1, _, sflag_weights_1, parity_0, parity_1, idx_parity_0, idx_parity_1, ) = loop_over_row_chunks.results_ epi_idx_loc = row_chunk_size * ((idx.shape[0] // num_sc // row_chunk_size) - 1) add_f = arith.divui(global_row_idx_start, const_lut(reduce_group_size)) # setup stream for #1 base_idx_val = arith.index_cast(i32, arith.addi(global_row_idx_start, const_lut(epi_idx_loc))) base_idx_val = tpu.assume_multiple(base_idx_val, 8) source_slice = tpu.memref_slice( ir.MemRefType.get((row_chunk_size,), i32, memory_space=memoryspace_hbm), idx_ref, base_idx=[base_idx_val], dynamic_sizes=[], ) tpu.enqueue_dma(source_slice, idx_tile_1, target_semaphore=sflag_1) tpu.wait_dma2(semaphore=sflag_1, src=source_slice, dst=idx_tile_1) if topk_weights is not None: lin_idx = arith.addi(global_row_idx_start, const_lut(epi_idx_loc)) parity_1 = load_weights(lin_idx, weights_tile_1, sflag_weights_1) offset_tile_1, idx_parity_1 = fill_load_offset_tile(offset_tile_1, idx_tile_1, col_chunk_ij) tpu.enqueue_indirect_dma( source=op_ref, target=scratch_1, offsets=offset_tile_1, semaphore=sflag_1, ) # wait stream #0 tpu.wait_indirect_dma(semaphore=sflag_0, src=op_ref, dst=scratch_0) # process #0 scratch_out_0 = perform_add( scratch_0, scratch_out_0, idx_parity_0, weights_local=weights_tile_0 if topk_weights 
is not None else None, parity=parity_0 if topk_weights is not None else None, ) offset_tile_out_0 = fill_out_offset_tile( offset_tile_out_0, col_chunk_ij, arith.addi( add_f, const_lut((idx.shape[0] // num_sc // reduce_group_size) - 4), ), ) tpu.enqueue_indirect_dma( source=scratch_out_0, target=out_ref, offsets=offset_tile_out_0, semaphore=sflag_out_0, ) # wait stream #1 tpu.wait_indirect_dma(semaphore=sflag_1, src=op_ref, dst=scratch_1) # process #1 scratch_out_1 = perform_add( scratch_1, scratch_out_1, idx_parity_1, weights_local=weights_tile_1 if topk_weights is not None else None, parity=parity_1 if topk_weights is not None else None, ) offset_tile_out_1 = fill_out_offset_tile( offset_tile_out_1, col_chunk_ij, arith.addi( add_f, const_lut((idx.shape[0] // num_sc // reduce_group_size) - 2), ), ) tpu.enqueue_indirect_dma( source=scratch_out_1, target=out_ref, offsets=offset_tile_out_1, semaphore=sflag_out_1, ) tpu.wait_indirect_dma(semaphore=sflag_out_0, src=scratch_out_0, dst=out_ref) tpu.wait_indirect_dma(semaphore=sflag_out_1, src=scratch_out_1, dst=out_ref) scf.YieldOp([]) input_types = [ i32, i32, ir.MemRefType.get( idx.shape, i32, memory_space=memoryspace_hbm, ), ir.MemRefType.get( op.shape, bf16 if is_bf16 else f32, memory_space=memoryspace_hbm, ), ir.MemRefType.get( (idx.shape[0] // reduce_group_size, op.shape[1]), bf16, memory_space=memoryspace_hbm, ), ] if topk_weights is not None: input_types.insert( 4, ir.MemRefType.get( topk_weights.shape, bf16, memory_space=memoryspace_hbm, ), ) # Configure the wrappers if topk_weights is not None: @func.FuncOp.from_py_func(*input_types, name="main") def kernel_main( current_sc_core, current_local_core, idx_ref, op_ref, weights_ref, out_ref, func_op, ): return _kernel_impl( current_sc_core, current_local_core, idx_ref, op_ref, weights_ref, out_ref, func_op, ) else: @func.FuncOp.from_py_func(*input_types, name="main") def kernel_main( current_sc_core, current_local_core, idx_ref, op_ref, out_ref, func_op, ): return 
_kernel_impl( current_sc_core, current_local_core, idx_ref, op_ref, None, out_ref, func_op, ) # Configure the Mosaic iteration space f = kernel_main.func_op f.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get([used_sc_cores, num_sc_per_core]) f.attributes["dimension_semantics"] = ir.ArrayAttr.get( [ ir.Attribute.parse("#tpu.dimension_semantics<core_parallel>"), ir.Attribute.parse("#tpu.dimension_semantics<parallel>"), ] ) window_params = [ ir.DictAttr.get( { "transform_indices": ir.Attribute.parse("affine_map<(n,m) -> (0)>"), } ), ir.DictAttr.get( { "transform_indices": ir.Attribute.parse("affine_map<(n,m) -> (0,0)>"), } ), ir.DictAttr.get( { "transform_indices": ir.Attribute.parse("affine_map<(n,m) -> (0,0)>"), } ), ] if topk_weights is not None: # Insert weights before output - we append here because it's the same # attribute as output. window_params.append( ir.DictAttr.get( { "transform_indices": ir.Attribute.parse("affine_map<(n,m) -> (0,0)>"), } ), ) f.attributes["window_params"] = ir.ArrayAttr.get(window_params) f.attributes["tpu.core_type"] = ir.Attribute.parse("#tpu.core_type<sc_vector_subcore>") assert f.verify(), f m = ir.Module.create() m.body.append(f) ir.SymbolTable(m.operation).insert(f) # If we are in a shard map and the input has a manual axis type, # preserve it. if jax.typeof(op).manual_axis_type: out_type = core.ShapedArray( (idx.shape[0] // reduce_group_size, op.shape[1]), jnp.bfloat16, sharding=jax.sharding.NamedSharding( jax.sharding.get_abstract_mesh(), jax.sharding.PartitionSpec(), ), manual_axis_type=jax.typeof(op).manual_axis_type, ) else: out_type = core.ShapedArray( (idx.shape[0] // reduce_group_size, op.shape[1]), jnp.bfloat16, ) return mosaic.as_tpu_kernel( m, out_type=out_type, )( *( [ idx, op, topk_weights, ] if topk_weights is not None else [ idx, op, ] ) )