Description
Summary
I'm observing a performance gap between CUTLASS INT8 NCxHWx interleaved Conv2d Fprop kernel and TensorRT's equivalent kernel on Jetson AGX Orin. For a specific convolution configuration, CUTLASS is approximately 12% slower than TensorRT.
Environment
- Hardware: NVIDIA Jetson AGX Orin
- GPU Architecture: SM87 (Ampere)
- CUTLASS Version: 4.2 (tested with latest)
- TensorRT Version: 8.6 (bundled with JetPack)
Problem Description
I benchmarked the INT8 Conv2d Fprop kernel with NCxHWx<32> interleaved layout using both CUTLASS templates and TensorRT, and found a noticeable performance difference.
Convolution Parameters
| Parameter | Value |
|---|---|
| N (batch) | 10 |
| C (input channels) | 128 |
| H (input height) | 72 |
| W (input width) | 120 |
| K (output channels) | 128 |
| R (filter height) | 3 |
| S (filter width) | 3 |
| pad_h, pad_w | 1, 1 |
| stride_h, stride_w | 1, 1 |
| dilation_h, dilation_w | 1, 1 |
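For reference, the implicit-GEMM problem size implied by these parameters can be computed with the standard output-size formula. A minimal sketch (helper names are mine, not CUTLASS API):

```cpp
#include <cstdint>

// Standard output-size formula for one spatial dimension.
inline int conv_out_dim(int in, int filt, int pad, int stride, int dil) {
  return (in + 2 * pad - dil * (filt - 1) - 1) / stride + 1;
}

// Implicit-GEMM M/N/K for a Conv2d Fprop problem.
struct ImplicitGemmShape { int64_t m, n, k; };

inline ImplicitGemmShape implicit_gemm_fprop(int N, int C, int H, int W,
                                             int K, int R, int S,
                                             int pad, int stride, int dil) {
  const int P = conv_out_dim(H, R, pad, stride, dil);  // output height
  const int Q = conv_out_dim(W, S, pad, stride, dil);  // output width
  return { int64_t(N) * P * Q,    // GEMM M = N*P*Q
           int64_t(K),            // GEMM N = K
           int64_t(C) * R * S };  // GEMM K = C*R*S
}
```

For this shape the implicit GEMM is 86400 x 128 x 1152, i.e. a tall-skinny problem where M is much larger than N.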
Performance Comparison
| Implementation | Epilogue Type | Time (μs) |
|---|---|---|
| CUTLASS NCxHWx<32> | Per-tensor AlphaScaling | ~420 |
| CUTLASS NCxHWx<32> | Per-channel AlphaScaling | ~450 |
| CUTLASS NHWC | Per-tensor AlphaScaling | ~438 |
| CUTLASS NHWC | Per-channel AlphaScaling | ~470 |
| TensorRT | Per-channel AlphaScaling | ~400 |
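Converting these timings into effective throughput makes the gap concrete. A small helper, assuming the op count implied by the problem shape (2 x M x N x K for the implicit GEMM, about 25.48 GOPs):

```cpp
#include <cstdint>

// Effective throughput implied by a measured kernel time.
// 25'480'396'800 ops = 2 * (N*P*Q) * K * (C*R*S) for the shape in this issue.
inline double achieved_tops(double time_us, int64_t ops = 25480396800LL) {
  return static_cast<double>(ops) / (time_us * 1e-6) / 1e12;  // tera-ops/s
}
```

At ~400 us TensorRT sustains roughly 63.7 INT8 TOPS, versus roughly 60.7 TOPS for the fastest CUTLASS variant at ~420 us.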
TensorRT Kernel Name (from nsys profiling)
sm80_xmma_fprop_implicit_gemm_interleaved_i8i8_i8i32_f32_nchw_vect_c_32kcrs_vect_c_32_nchw_vect_c_32_tilesize128x128x64_stage4_warpsize2x2x1_g1_tensor16x8x32_t1r3s3_execute_kernel_trt
From the kernel name, TensorRT appears to use:
- Tile size: 128x128x64
- 4 pipeline stages
- Warp configuration: 2x2x1
- Tensor instruction: 16x8x32
- Specialized for 3x3 filter (t1r3s3)
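The fields above can be pulled out of such xmma kernel names mechanically. A minimal sketch (my own helper, assuming the fields always appear in this textual form):

```cpp
#include <cstdio>
#include <string>

// Tile size, stage count, and warp arrangement parsed from a TensorRT
// xmma kernel name as reported by nsys.
struct XmmaConfig { int tm, tn, tk, stages, wm, wn, wk; };

inline bool parse_xmma_name(const std::string& name, XmmaConfig& out) {
  const size_t t = name.find("tilesize");
  const size_t s = name.find("stage");
  const size_t w = name.find("warpsize");
  if (t == std::string::npos || s == std::string::npos || w == std::string::npos)
    return false;
  return std::sscanf(name.c_str() + t, "tilesize%dx%dx%d",
                     &out.tm, &out.tn, &out.tk) == 3 &&
         std::sscanf(name.c_str() + s, "stage%d", &out.stages) == 1 &&
         std::sscanf(name.c_str() + w, "warpsize%dx%dx%d",
                     &out.wm, &out.wn, &out.wk) == 3;
}
```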
CUTLASS Kernel Configuration
I am using the following CUTLASS template configuration:
```cpp
using Conv2dKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
    int8_t, cutlass::layout::TensorNCxHWx<32>,   // Input: NCxHWx<32>
    int8_t, cutlass::layout::TensorCxRSKx<32>,   // Filter: CxRSKx<32>
    int8_t, cutlass::layout::TensorNCxHWx<32>,   // Output: NCxHWx<32>
    int32_t,                                     // Accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,      // Threadblock shape
    cutlass::gemm::GemmShape<64, 64, 64>,        // Warp shape
    cutlass::gemm::GemmShape<16, 8, 32>,         // Instruction shape
    cutlass::epilogue::thread::LinearCombinationClamp<
        int8_t,
        64 / cutlass::sizeof_bits<int8_t>::value,  // 8 elements per access
        int32_t,
        float,
        cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    4,                                           // Stages
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kStrided,
    16, 16                                       // Alignment
>::Kernel;
```

Questions

How can I modify or tune the CUTLASS INT8 Conv2d Fprop kernel configuration to match TensorRT's performance for this problem size?
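For reference, one variant I am considering (an untested sketch: tile, warp, and instruction shapes are unchanged and already match what the TensorRT kernel name reports; only the threadblock swizzle is changed from `<1>` to `<4>` in the hope of better L2 reuse — I have not confirmed this helps):

```cpp
// Hypothetical variant: identical shapes, wider identity swizzle.
using Conv2dKernelSwizzled = typename cutlass::conv::kernel::DefaultConv2dFprop<
    int8_t, cutlass::layout::TensorNCxHWx<32>,
    int8_t, cutlass::layout::TensorCxRSKx<32>,
    int8_t, cutlass::layout::TensorNCxHWx<32>,
    int32_t,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 32>,
    cutlass::epilogue::thread::LinearCombinationClamp<
        int8_t, 64 / cutlass::sizeof_bits<int8_t>::value,
        int32_t, float,
        cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>,  // was <1>
    4,
    cutlass::arch::OpMultiplyAddSaturate,
    cutlass::conv::IteratorAlgorithm::kOptimized,
    cutlass::conv::StrideSupport::kStrided,
    16, 16
>::Kernel;
```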
Additional Information
I can provide additional nsys profiling reports if needed.
Thank you for your guidance on how to achieve TensorRT-level performance with CUTLASS templates!