Commit 4b2b639
[TIR][Schedule]Generalize fuseReductionEpilogue to support arbitrary epilogue expressions (#18636)
## Major Changes for Generalization
### 1. Pattern Matching Removal
**Removed Items:**
- `EpilogueType` enum (Bias, BiasReLU, Clipping)
- `AnalyzeEpiloguePattern()` function
- Pattern-specific branching logic
**Current Approach:**
- Directly process the entire epilogue expression without pattern matching
### 2. Store Entire Epilogue Expression
- Store the entire epilogue expression in `epilogue_expression_`
- Use the expression directly without pattern analysis
```cpp
// Store the epilogue expression and reduction buffer load
epilogue_expression_ = inlined_store_->value;
reduction_buffer_load_ = loads[0];
```
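To illustrate how the reduction buffer load might be located inside an arbitrary epilogue expression, here is a minimal sketch on a toy expression tree. The `Expr`, `Kind`, `Const`, `Load`, `Binary`, and `CollectLoads` names are hypothetical stand-ins, not TVM's `PrimExpr` API, and the actual collection logic in the commit may differ.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Toy expression tree (assumption: a stand-in for TVM's PrimExpr).
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;
enum class Kind { kConst, kLoad, kAdd, kMax, kMin };

struct Expr {
  Kind kind;
  int value;           // meaningful for kConst only
  std::string buffer;  // meaningful for kLoad only
  ExprPtr a, b;        // operands of kAdd / kMax / kMin
};

ExprPtr Const(int v) {
  return std::make_shared<Expr>(Expr{Kind::kConst, v, "", nullptr, nullptr});
}
ExprPtr Load(const std::string& buf) {
  return std::make_shared<Expr>(Expr{Kind::kLoad, 0, buf, nullptr, nullptr});
}
ExprPtr Binary(Kind k, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{k, 0, "", a, b});
}

// Append every load of `buffer` found in `e` to `out`, in left-to-right order.
void CollectLoads(const ExprPtr& e, const std::string& buffer,
                  std::vector<ExprPtr>* out) {
  if (!e) return;  // leaf nodes have null operands
  if (e->kind == Kind::kLoad) {
    if (e->buffer == buffer) out->push_back(e);
    return;
  }
  CollectLoads(e->a, buffer, out);
  CollectLoads(e->b, buffer, out);
}
```

With the epilogue `max(temp + C, 0)`, collecting loads of `temp` yields a single entry, which plays the role of `loads[0]` in the snippet above.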
### 3. Generalized Init Transformation
- Replace the reduction buffer load with the reduction's identity element (0, the additive identity)
- Apply to the entire expression to generate init value
```cpp
InitSubstituter init_subst(inlined_buffer_, identity_elem);
PrimExpr init_epilogue = init_subst(epilogue_expression_);
// Simplify: 0 + C[vi, vj] -> C[vi, vj]
```
**Examples:**
- `temp + C` → `0 + C` → `C` (simplify)
- `max(temp + C, 0)` → `max(0 + C, 0)` → `max(C, 0)`
- `min(max(temp, lower), upper)` → `min(max(0, lower), upper)`
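The examples above can be reproduced with a small sketch of the substitute-then-simplify idea. This uses a toy expression tree; `Expr`, `SubstituteInit`, `IsZero`, and `ToString` are hypothetical names for illustration and do not reflect the real `InitSubstituter` or TVM's simplifier, which handle far more cases.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Toy expression tree (assumption: a stand-in for TVM's PrimExpr).
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;
enum class Kind { kConst, kLoad, kAdd, kMax, kMin };

struct Expr {
  Kind kind;
  int value;           // meaningful for kConst only
  std::string buffer;  // meaningful for kLoad only
  ExprPtr a, b;        // operands of kAdd / kMax / kMin
};

ExprPtr Const(int v) {
  return std::make_shared<Expr>(Expr{Kind::kConst, v, "", nullptr, nullptr});
}
ExprPtr Load(const std::string& buf) {
  return std::make_shared<Expr>(Expr{Kind::kLoad, 0, buf, nullptr, nullptr});
}
ExprPtr Binary(Kind k, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{k, 0, "", a, b});
}

bool IsZero(const ExprPtr& e) {
  return e->kind == Kind::kConst && e->value == 0;
}

// Replace loads of `reduction_buffer` with `identity`, then fold `0 + x` and `x + 0`.
ExprPtr SubstituteInit(const ExprPtr& e, const std::string& reduction_buffer,
                       int identity) {
  if (e->kind == Kind::kConst) return e;
  if (e->kind == Kind::kLoad) {
    return e->buffer == reduction_buffer ? Const(identity) : e;
  }
  ExprPtr a = SubstituteInit(e->a, reduction_buffer, identity);
  ExprPtr b = SubstituteInit(e->b, reduction_buffer, identity);
  if (e->kind == Kind::kAdd) {
    if (IsZero(a)) return b;
    if (IsZero(b)) return a;
  }
  return Binary(e->kind, a, b);
}

// Render an expression as text for inspection.
std::string ToString(const ExprPtr& e) {
  if (e->kind == Kind::kConst) return std::to_string(e->value);
  if (e->kind == Kind::kLoad) return e->buffer;
  std::string lhs = ToString(e->a), rhs = ToString(e->b);
  if (e->kind == Kind::kAdd) return "(" + lhs + " + " + rhs + ")";
  return std::string(e->kind == Kind::kMax ? "max" : "min") + "(" + lhs + ", " + rhs + ")";
}
```

In this sketch only the additive fold is implemented, so the clipping example keeps its `max(0, lower)` term, matching the third bullet above.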
### 4. Generalized Update Transformation
- Replace reduction buffer load with reduction update
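- If the load's parent is an `Add` whose other operand is not a reduction buffer load → treat that operand as a bias addend and remove it (the init value already contains the bias, so keeping it would add the bias again on every reduction step)
- Otherwise → apply the expression as-is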
```cpp
class GeneralizedEpilogueApplier : public ExprMutator {
// Replace reduction buffer load with reduction update
// Automatically detect and remove bias addend in Add nodes
// Automatically support other activation functions
};
```
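The rules listed for the update transformation can be sketched on a toy expression tree as follows. All names here (`Expr`, `ApplyUpdate`, `IsLoadOf`, `ToString`) are hypothetical, and the `update` argument is an opaque placeholder for the reduction update expression. This is a simplified reading of the description, not the actual `GeneralizedEpilogueApplier` implementation.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Toy expression tree (assumption: a stand-in for TVM's PrimExpr).
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;
enum class Kind { kConst, kLoad, kAdd, kMax, kMin };

struct Expr {
  Kind kind;
  int value;           // meaningful for kConst only
  std::string buffer;  // meaningful for kLoad only
  ExprPtr a, b;        // operands of kAdd / kMax / kMin
};

ExprPtr Const(int v) {
  return std::make_shared<Expr>(Expr{Kind::kConst, v, "", nullptr, nullptr});
}
ExprPtr Load(const std::string& buf) {
  return std::make_shared<Expr>(Expr{Kind::kLoad, 0, buf, nullptr, nullptr});
}
ExprPtr Binary(Kind k, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{k, 0, "", a, b});
}

bool IsLoadOf(const ExprPtr& e, const std::string& buf) {
  return e->kind == Kind::kLoad && e->buffer == buf;
}

// Replace loads of `reduction_buffer` with `update`. An Add with exactly one
// reduction-buffer operand collapses to `update`, dropping the bias addend.
ExprPtr ApplyUpdate(const ExprPtr& e, const std::string& reduction_buffer,
                    const ExprPtr& update) {
  if (e->kind == Kind::kConst) return e;
  if (e->kind == Kind::kLoad) {
    return e->buffer == reduction_buffer ? update : e;
  }
  if (e->kind == Kind::kAdd) {
    bool lhs_red = IsLoadOf(e->a, reduction_buffer);
    bool rhs_red = IsLoadOf(e->b, reduction_buffer);
    if (lhs_red != rhs_red) return update;  // the other operand is the bias
  }
  return Binary(e->kind, ApplyUpdate(e->a, reduction_buffer, update),
                ApplyUpdate(e->b, reduction_buffer, update));
}

// Render an expression as text for inspection.
std::string ToString(const ExprPtr& e) {
  if (e->kind == Kind::kConst) return std::to_string(e->value);
  if (e->kind == Kind::kLoad) return e->buffer;
  std::string lhs = ToString(e->a), rhs = ToString(e->b);
  if (e->kind == Kind::kAdd) return "(" + lhs + " + " + rhs + ")";
  return std::string(e->kind == Kind::kMax ? "max" : "min") + "(" + lhs + ", " + rhs + ")";
}
```

Passing `Load("update")` as the placeholder, `temp + C` and `C + temp` both become `update`, while `max(temp + C, 0)` becomes `max(update, 0)`; the wrapping activation is preserved without any pattern-specific branch.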
## Results and Verification
### Existing Tests Pass
All existing tests pass, maintaining backward compatibility:
- `test_fuse_reduction_epilogue_basic`
- `test_fuse_reduction_epilogue_fp32`
- `test_fuse_reduction_epilogue_numerical_correctness`
- `test_fuse_reduction_epilogue_multiple_epilogue`
- `test_matmul_bias_relu`
- `test_matmul_bias_relu_correctness_unified`
- `test_matmul_clipping`
- `test_matmul_clipping_correctness_unified`
- Tests for the other commutative variants
Total: all 15 tests pass
---------
Signed-off-by: Hyun Gyu Kim <kimm240@telepix.net>
Co-authored-by: Hyun Gyu Kim <kimm240@telepix.net>
Parent commit: 4747a92
2 files changed (+334, -212 lines):
- src/tir/schedule/primitive
- tests/python/tir-schedule