Commit 470c5ca

[MLIR][XeGPU] Fix insert_strided_slice op in subgroup distribution (#180604)
This PR modifies the subgroup distribution pass to sink an insert_strided_slice operation only when it is the last op before the yield. This avoids sinking the same insert_strided_slice multiple times, which can cause problems in the worst case.
1 parent 4136d3f commit 470c5ca

1 file changed (+5, -2 lines changed)
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 5 additions & 2 deletions
@@ -1757,8 +1757,11 @@ struct VectorInsertStridedSliceDistribution
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      // Check if the InsertStridedSliceOp is the last op before yield op
+      return llvm::IsaPred<vector::InsertStridedSliceOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
     if (!operand)
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
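
The predicate added in this change can be illustrated outside MLIR with a toy model. The sketch below is not the pass's code: ToyOp, ToyBlock, and isLastOpBeforeYield are hypothetical names introduced only for illustration, and a block is assumed to be a plain vector whose last entry is the "yield" terminator. It shows the same two-part check as the diff: the op must be of the expected kind, and it must sit immediately before the terminator.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation inside the warp region body.
struct ToyOp {
  std::string name;
};

// Hypothetical stand-in for a block; the last element is assumed to be
// the terminator, named "yield".
using ToyBlock = std::vector<ToyOp>;

// Returns true only if the op at `index` has the given kind AND is the op
// immediately preceding the terminator. This mirrors the shape of the
// predicate in the diff (kind check && terminator's previous node == op).
bool isLastOpBeforeYield(const ToyBlock &block, std::size_t index,
                         const std::string &opName) {
  assert(!block.empty() && block.back().name == "yield" &&
         "toy block must end with a yield terminator");
  const std::size_t terminatorIndex = block.size() - 1;
  if (index >= terminatorIndex)
    return false; // the terminator itself, or out of range
  const bool isKind = block[index].name == opName;
  const bool isPrevOfTerminator = (index + 1 == terminatorIndex);
  return isKind && isPrevOfTerminator;
}
```

In this model, an insert_strided_slice followed by any other op before the yield is rejected, so a pattern guarded this way would not match it until it actually is the last op before the terminator.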
