Skip to content

Commit 4b2b639

Browse files
kimm240Hyun Gyu Kim
andauthored
[TIR][Schedule]Generalize fuseReductionEpilogue to support arbitrary epilogue expressions (#18636)
## Major Changes for Generalization ### 1. Pattern Matching Removal **Removed Items:** - `EpilogueType` enum (Bias, BiasReLU, Clipping) - `AnalyzeEpiloguePattern()` function - Pattern-specific branching logic **Current Approach:** - Directly process the entire epilogue expression without pattern matching ### 2. Store Entire Epilogue Expression - Store the entire epilogue expression in `epilogue_expression_` - Use the expression directly without pattern analysis ```cpp // Store the epilogue expression and reduction buffer load epilogue_expression_ = inlined_store_->value; reduction_buffer_load_ = loads[0]; ``` ### 3. Generalized Init Transformation - Replace reduction buffer load with identity element (0) - Apply to the entire expression to generate init value ```cpp InitSubstituter init_subst(inlined_buffer_, identity_elem); PrimExpr init_epilogue = init_subst(epilogue_expression_); // Simplify: 0 + C[vi, vj] -> C[vi, vj] ``` **Examples:** - `temp + C` → `0 + C` → `C` (simplify) - `max(temp + C, 0)` → `max(0 + C, 0)` → `max(C, 0)` - `min(max(temp, lower), upper)` → `min(max(0, lower), upper)` ### 4. Generalized Update Transformation - Replace reduction buffer load with reduction update - If parent is Add and the other operand is not a reduction buffer → treat as bias addend and remove - Otherwise → apply expression as-is ```cpp class GeneralizedEpilogueApplier : public ExprMutator { // Replace reduction buffer load with reduction update // Automatically detect and remove bias addend in Add nodes // Automatically support other activation functions }; ``` ## Results and Verification ### Existing Tests Pass All existing tests pass, maintaining backward compatibility: - `test_fuse_reduction_epilogue_basic` - `test_fuse_reduction_epilogue_fp32` - `test_fuse_reduction_epilogue_numerical_correctness` - `test_fuse_reduction_epilogue_multiple_epilogue` - `test_matmul_bias_relu` - `test_matmul_bias_relu_correctness_unified` - `test_matmul_clipping` - `test_matmul_clipping_correctness_unified` - Other commutative variants tests Total: All 15 tests pass --------- Signed-off-by: Hyun Gyu Kim <kimm240@telepix.net> Co-authored-by: Hyun Gyu Kim <kimm240@telepix.net>
1 parent 4747a92 commit 4b2b639

File tree

2 files changed

+334
-212
lines changed

2 files changed

+334
-212
lines changed

0 commit comments

Comments
 (0)