[TIR][Schedule] Add FuseReductionEpilogue primitive to fuse epilogue into reduction init - 1. Problem Analysis and Limitations of Existing Solutions
Most AI accelerators support MAC (Multiply-Accumulate) instructions that compute Output = Input * Weight + Bias as a single operation. However, if the multiplication and the bias addition end up in separate stages of the compiler's intermediate representation (TIR), the hardware cannot be fully utilized.
1. Overview
Currently, TVM cannot fuse a Bias Addition into a MatMul block, and none of the existing scheduling primitives can handle this simple pattern. This post analyzes why and outlines how to fix it.
2. Problem Situation
In the code we want to optimize, the Reduction (multiply-accumulate) and the Epilogue (bias addition) are separate blocks:
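from tvm.script import tir as T

@T.prim_func
def before(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32"),
           C: T.Buffer((16, 16), "int32"), D: T.Buffer((16, 16), "int32")):
    temp = T.alloc_buffer((16, 16), "int32")
    # Block 1: MatMul (Reduction)
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])  # k is the reduction axis
            with T.init():
                temp[vi, vj] = 0
            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
    # Block 2: Bias Add (Epilogue)
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj] + C[vi, vj]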
We want to eliminate the temp buffer and merge both computations into a single block whose T.init() loads the bias C[vi, vj] instead of 0. However, TVM currently provides no scheduling primitive that performs this transformation.
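The fused form rests on a simple identity: starting the accumulator at C[i, j] instead of 0 gives the same D as accumulating first and adding the bias afterwards. The plain-Python sketch below models both loop nests with the same indexing as the TIR above (A[i][k] * B[j][k]); it does not use TVM, and the function names `unfused`/`fused` and the random int inputs are illustrative assumptions.

```python
import random

N = 16  # matches the 16x16x16 grid in the TIR example


def unfused(A, B, C):
    # Block 1 "multiply": reduction into an intermediate buffer, init = 0
    temp = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                temp[i][j] += A[i][k] * B[j][k]
    # Block 2 "add": the epilogue reads temp and adds the bias
    return [[temp[i][j] + C[i][j] for j in range(N)] for i in range(N)]


def fused(A, B, C):
    # Single reduction block: the init step loads the bias instead of 0,
    # so no intermediate buffer is needed
    D = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            D[i][j] = C[i][j]  # plays the role of T.init(): D[vi, vj] = C[vi, vj]
            for k in range(N):
                D[i][j] += A[i][k] * B[j][k]
    return D


random.seed(0)
A = [[random.randint(-8, 8) for _ in range(N)] for _ in range(N)]
B = [[random.randint(-8, 8) for _ in range(N)] for _ in range(N)]
C = [[random.randint(-8, 8) for _ in range(N)] for _ in range(N)]
assert unfused(A, B, C) == fused(A, B, C)
```

For integer types such as the int32 used here the two forms agree exactly. For floating-point types the fused form changes the order of additions, so results may differ in the last bits.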
3. Problems with Existing Solutions
Attempt 1: compute_inline (Producer → Consumer)
This attempts to push the MatMul block (multiply) into the Bias Add block (add).
def compute_inline(self, block: Union[BlockRV, str]) -> None:
    """Inline a block into its consumer(s). It requires:
    1) The block is a complete non-root block...
    3) The body of the block must be a BufferStore statement in the form of,
       ``A[i, j, k, ...] = ...`` where the indices of the LHS are all distinct atomic variables...
    """
According to constraint 3 in the docstring, the body of the block being inlined must be a plain BufferStore statement.
However, the MatMul block is a reduction block. It carries a T.init() statement and accumulates into the buffer it writes (temp = temp + …), so it is not the plain assignment statement that constraint 3 expects. It also fails constraint 1: because vk is a reduction block var, the block is not a complete block (a complete block has only data-parallel block vars).
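To make the failure concrete, here is a toy model of those checks. It is plain Python, not TVM's actual analysis code: the `Block` dataclass, the `can_compute_inline` helper, and the element-wise `scale` producer used for contrast are all hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class Block:
    # Hypothetical, simplified view of a TIR block
    name: str
    iter_kinds: dict          # block var -> "S" (spatial) or "R" (reduction)
    writes: str               # buffer written by the block body
    reads: list = field(default_factory=list)  # buffers loaded on the RHS
    has_init: bool = False    # whether the block carries a T.init() statement


def can_compute_inline(block):
    """Toy version of the compute_inline preconditions discussed above."""
    # Constraint 1: a complete block has only data-parallel (spatial) block vars
    if any(kind == "R" for kind in block.iter_kinds.values()):
        return False, "reduction block var: block is not complete"
    if block.has_init:
        return False, "block has an init statement"
    # Constraint 3: the body must be a plain store, not an accumulation
    if block.writes in block.reads:
        return False, "body reads the buffer it writes"
    return True, "ok"


multiply = Block("multiply", {"vi": "S", "vj": "S", "vk": "R"},
                 writes="temp", reads=["temp", "A", "B"], has_init=True)
# Hypothetical element-wise producer, e.g. temp[vi, vj] = A[vi, vj] * 2
scale = Block("scale", {"vi": "S", "vj": "S"}, writes="temp", reads=["A"])

print(can_compute_inline(multiply))  # (False, 'reduction block var: block is not complete')
print(can_compute_inline(scale))     # (True, 'ok')
```

An element-wise producer passes every check, while the reduction block is rejected before its body shape is even considered.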
Attempt 2: reverse_compute_inline (Consumer → Producer)
Conversely, this attempts to bring the Bias Add block (add) into the MatMul block (multiply).
def reverse_compute_inline(self, block: Union[BlockRV, str]) -> None:
    """Inline a block into its only producer. It requires:
    3) The only producer of the block is a read-after-write producer and a
       complete non-root block
    4) The body of the block must be a BufferStore statement...
    """
According to the docstring, the consumer block (add) itself satisfies condition 4 (a plain BufferStore). The problem is condition 3, which constrains the producer (the multiply block).
The producer MatMul block has a reduction axis (k), so its output is incomplete until the k loop has fully finished, and a block with a reduction block var is not a complete block. reverse_compute_inline expects a simple read-after-write producer that writes each output element once, whereas a reduction block reads and rewrites the same element on every k iteration.
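The "incomplete until the loop ends" point is easy to see numerically. This small Python sketch (the helper name `partial_sums` and the input values are illustrative) records the value of temp[i][j] after each k step for one output element. Only the last value is the real producer output; a consumer that reads temp before the k loop finishes sees a partial sum.

```python
def partial_sums(a_row, b_row):
    """Values held by temp[i][j] after each step of the k loop
    (a_row plays the role of A[i, :], b_row of B[j, :])."""
    acc = 0  # T.init(): temp[vi, vj] = 0
    history = []
    for a, b in zip(a_row, b_row):
        acc += a * b  # update: temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
        history.append(acc)
    return history


a_row = [1, 2, 3, 4]
b_row = [5, 6, 7, 8]
hist = partial_sums(a_row, b_row)
print(hist)  # [5, 17, 38, 70]
# Only hist[-1] equals the finished dot product; earlier entries are partial
assert hist[-1] == sum(a * b for a, b in zip(a_row, b_row))
```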
Attempt 3: decompose_reduction then inline
What if we split the reduction into its init and update parts and then inline the epilogue?
def decompose_reduction(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV:
    """Decompose a reduction block into two separate blocks.
    a) The init block... inserted right before the given loop.
    b) The update block... original block without init statement.
    """
After decompose_reduction, the initialization becomes a separate block placed outside the reduction loop:
- init: before the k loop (initializes temp to 0)
- update: inside the k loop (multiply-accumulate)
- add: after the reduction (adds the bias)
The add block must run only after every update iteration has finished. If we force the bias addition into the update loop, the bias is accumulated on every k iteration instead of once, which gives a mathematically incorrect result.
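The difference between the decomposed schedule and the forced inline can be checked with a small Python model. The loop structure follows the list above; the sizes, input values, and function names are illustrative assumptions, and the "bad" variant models folding the bias into the update statement.

```python
N, K = 4, 4  # small sizes for illustration


def decomposed_correct(A, B, C):
    temp = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            temp[i][j] = 0  # init block: inserted right before the k loop
            for k in range(K):
                temp[i][j] += A[i][k] * B[j][k]  # update block
    D = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            D[i][j] = temp[i][j] + C[i][j]  # add block: runs after every update finished
    return D


def bias_in_update(A, B, C):
    # Incorrect: the bias is folded into the update, so it is accumulated K times
    D = [[0] * N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            D[i][j] = 0
            for k in range(K):
                D[i][j] += A[i][k] * B[j][k] + C[i][j]
    return D


A = [[i + k for k in range(K)] for i in range(N)]
B = [[j - k for k in range(K)] for j in range(N)]
C = [[10 * i + j for j in range(N)] for i in range(N)]
good = decomposed_correct(A, B, C)
bad = bias_in_update(A, B, C)
# The forced inline over-counts the bias by exactly (K - 1) * C
assert all(bad[i][j] == good[i][j] + (K - 1) * C[i][j]
           for i in range(N) for j in range(N))
```

The error term (K - 1) * C grows with the length of the reduction axis, so it cannot be fixed by post-processing without reintroducing a separate pass over the output.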
4. Conclusion
This analysis shows that the existing primitives were not designed to inject an external operation into the initialization (Init) stage of a reduction. The constraints stated in their docstrings, such as "complete non-root block" and "BufferStore statement", exclude these reduction patterns.
To achieve the transformation we want, we need a new primitive that satisfies the following conditions:
def fuse_reduction_epilogue(self, reduction_block, epilogue_block):
    """
    1) The reduction block is a complete reduction block
    2) The epilogue block only reads from the reduction block's output
    3) The epilogue performs a simple addition: output = reduction_result + bias
    """
This is why I designed fuse_reduction_epilogue.
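As a rough illustration of how the three conditions in the fuse_reduction_epilogue docstring could be checked, here is a toy model in plain Python. It is not the primitive's implementation: the `Block` dataclass, the `can_fuse` helper, and the `(op, lhs, rhs)` expression encoding are hypothetical, and condition 2 is interpreted here as "exactly one operand of the epilogue is the reduction output".

```python
from dataclasses import dataclass


@dataclass
class Block:
    # Hypothetical, simplified view of a TIR block (not TVM's data structures)
    name: str
    has_reduction_var: bool  # does the block have a reduction block var?
    has_init: bool           # does it carry a T.init() statement?
    writes: str              # the single buffer it writes
    expr: tuple = ()         # epilogue RHS as (op, lhs_buffer, rhs_buffer)


def can_fuse(reduction, epilogue):
    """Toy check of the three conditions listed in the docstring."""
    # 1) the reduction block is a complete reduction: a reduction var plus an init
    if not (reduction.has_reduction_var and reduction.has_init):
        return False
    if len(epilogue.expr) != 3:
        return False
    op, lhs, rhs = epilogue.expr
    # 2) the epilogue reads the reduction output exactly once (other operand is the bias)
    if [lhs, rhs].count(reduction.writes) != 1:
        return False
    # 3) the epilogue is a plain addition: output = reduction_result + bias
    return op == "add"


multiply = Block("multiply", has_reduction_var=True, has_init=True, writes="temp")
add = Block("add", False, False, "D", ("add", "temp", "C"))
scale = Block("scale", False, False, "D", ("mul", "temp", "C"))  # hypothetical

assert can_fuse(multiply, add)        # bias addition: fusable
assert not can_fuse(multiply, scale)  # multiplication is not "reduction_result + bias"
```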
