Converting Split->conv->concat to Grouped conv #3124
kumarappan-cmyk wants to merge 8 commits into onnx:main
Conversation
Signed-off-by: Kumarappan <kumarappan.thiyagarajan@multicorewareinc.com>

Can one of the admins verify this patch?

@tungld Thanks for the review and verification, could you please approve and merge this patch!

@tungld thanks for the verification, could you please trigger the tests and merge the patch.
tungld left a comment:
@kumarappan-cmyk thank you for the PR!
I put in my first round of review comments. Please refactor the code to make it concise, and please do use DialectBuilder.
For the lit tests, could you add some tests that do not satisfy the conditions for fusion, to make sure that in such cases the recomposing pattern is not applied?
    llvm::SmallVector<ONNXConvOp, 2> convOps;
    ONNXConcatOp concatOp;

    // Ensure the pattern exists: Split → Conv → Concat

Please put . at the end of all comments.
    if (!weightType)
      return failure();
    int64_t rank = weightType.getRank();
    if (1 >= rank)
      return failure(); // Ensure axis is within valid range

I don't understand the comment. Is this check the same as the following check for axis?
    // **Concatenating Conv Weights Correctly**
    SmallVector<Value, 2> weightTensors;
    int64_t total_C_out = 0;

Please use camelCase instead of snake_case for naming.
    }

    // Create correct IntegerAttrs
    IntegerAttr axis0 = rewriter.getI64IntegerAttr(0);

Is the value rewriter.getI64IntegerAttr(0) unused, since you replace it by axis0 = IntegerAttr::get(si64Type, 0); later?
    biasTensors.push_back(conv.getB());

    Type newBiasType =
        RankedTensorType::get({total_C_out}, weightType.getElementType());

Define Type elementType = weightType.getElementType(); at the beginning of this function and reuse it to avoid boilerplate code.
    Type newBiasType =
        RankedTensorType::get({total_C_out}, weightType.getElementType());
    axis0 = IntegerAttr::get(si64Type, 0);

Define this at the beginning of the function also. Thanks!
    // **Create new Grouped ConvOp**
    auto newConv = rewriter.create<ONNXConvOp>(loc, resultType, input,
        concatenatedWeight, hasBias ? concatenatedBias : Value(), autoPadAttr,
        dilationsAttr, groupAttrVal, kernelShapeAttr, padsAttr, stridesAttr);

Please use DialectBuilder for conv, c.f. https://github.com/onnx/onnx-mlir/blob/main/src/Dialect/ONNX/DialectBuilder.hpp#L76
    concatenatedBias =
        rewriter.create<ONNXConcatOp>(biasLoc, newBiasType, biasTensors,
            axis0); // Bias should be concatenated along axis=0.
    }

Please use DialectBuilder for concat.
    // RUN: onnx-mlir-opt --recompose-onnx --remove-dead-values --constprop-onnx %s -split-input-file | FileCheck %s

    func.func @simple_split_conv_concat(%arg0: tensor<1x6x512x512xf64> {onnx.name = "input"}) -> (tensor<1x6x512x512xf64> {onnx.name = "output"}) {
      %0 = onnx.Constant dense<[[[[-0.0017646604683250189, 0.12644097208976746, -0.19399359822273254], [-0.17346249520778656, -0.090781755745410919, 0.0632052943110466], [-0.0046700113452970982, 0.18688584864139557, -0.020917171612381935]], [[0.062369778752326965, -0.071232303977012634, -0.046330906450748444], [-0.22517779469490051, -0.15610139071941376, -0.097161918878555298], [0.008731253445148468, 0.093181401491165161, 0.14142672717571259]]], [[[-0.15979224443435669, -0.1026395708322525, 0.085611097514629364], [0.19572432339191437, -0.048507567495107651, 0.1763787716627121], [-0.037991281598806381, 0.024940622970461845, 0.21342279016971588]], [[-0.21865400671958923, -0.14838351309299469, -0.059671621769666672], [-0.09187673032283783, 0.2036469429731369, -0.15277740359306335], [-0.10850150138139725, -0.16467113792896271, -0.22074954211711884]]]]> : tensor<2x2x3x3xf64>

Please use splat constants for better readability, since the values here are not the essential part of the test.
This optimization fuses the pattern Split → Conv → Concat into a single grouped Conv operation when the Split and Concat both operate along the channel axis and every branch convolution has group = 1 and identical attributes (strides, pads, dilations, kernel shape).
By converting multiple independent convolutions into a single grouped convolution, we reduce memory usage, improve cache locality, and lower the number of kernel launches in backend runtimes.
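The equivalence behind the fusion can be sketched in plain numpy, independent of onnx-mlir. The `conv2d` below is a toy implementation (no padding, stride 1) and all names are illustrative, not taken from the PR:

```python
import numpy as np

def conv2d(x, w, groups=1):
    # x: (N, C_in, H, W); w: (C_out, C_in // groups, kH, kW). No padding, stride 1.
    n, c_in, h, wd = x.shape
    c_out, c_in_g, kh, kw = w.shape
    oh, ow = h - kh + 1, wd - kw + 1
    c_out_g = c_out // groups
    out = np.zeros((n, c_out, oh, ow))
    for g in range(groups):
        xs = x[:, g * c_in_g:(g + 1) * c_in_g]   # input channels of this group
        ws = w[g * c_out_g:(g + 1) * c_out_g]    # filters of this group
        for oc in range(c_out_g):
            for i in range(oh):
                for j in range(ow):
                    out[:, g * c_out_g + oc, i, j] = np.sum(
                        xs[:, :, i:i + kh, j:j + kw] * ws[oc], axis=(1, 2, 3))
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 6, 8, 8))
w1 = rng.standard_normal((2, 3, 3, 3))  # conv over the first 3 channels
w2 = rng.standard_normal((2, 3, 3, 3))  # conv over the last 3 channels

# Before fusion: Split -> Conv (group=1 each) -> Concat along the channel axis.
x1, x2 = np.split(x, 2, axis=1)
before = np.concatenate([conv2d(x1, w1), conv2d(x2, w2)], axis=1)

# After fusion: one Conv with weights concatenated along C_out and group=2.
after = conv2d(x, np.concatenate([w1, w2], axis=0), groups=2)

assert np.allclose(before, after)
```

Each group of the fused convolution sees exactly the channel slice that the corresponding branch saw before fusion, which is why the outputs match element-wise.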
Before optimization: Split → Conv → Concat.

After optimization: a single Conv with an adjusted group attribute.