[QST] How to optimize cutlass int8 fprop2d kernel with NCxHWx interleaved layout, which is ~12% slower than TensorRT on Jetson Orin #2996

@JJXiangJiaoJun

Description

Summary

I'm observing a performance gap between the CUTLASS INT8 NCxHWx-interleaved Conv2d Fprop kernel and TensorRT's equivalent kernel on Jetson AGX Orin. For a specific convolution configuration, CUTLASS is approximately 12% slower than TensorRT.

Environment

  • Hardware: NVIDIA Jetson AGX Orin
  • GPU Architecture: SM87 (Ampere)
  • CUTLASS Version: 4.2 (tested with latest)
  • TensorRT Version: 8.6 (bundled with JetPack)

Problem Description

I benchmarked the INT8 Conv2d Fprop kernel with NCxHWx<32> interleaved layout using both CUTLASS templates and TensorRT, and found a noticeable performance difference.

Convolution Parameters

Parameter               Value
N (batch)               10
C (input channels)      128
H (input height)        72
W (input width)         120
K (output channels)     128
R (filter height)       3
S (filter width)        3
pad_h, pad_w            1, 1
stride_h, stride_w      1, 1
dilation_h, dilation_w  1, 1
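
For reference, mapping this problem onto the implicit GEMM formulation CUTLASS uses for fprop (my arithmetic; with a 3x3 filter, pad 1, stride 1, the output spatial size P x Q equals the input's 72 x 120):

GEMM_M = N * P * Q = 10 * 72 * 120 = 86400
GEMM_N = K         = 128
GEMM_K = C * R * S = 128 * 3 * 3   = 1152

With the 128x128x64 threadblock tile below, that is ceil(86400 / 128) = 675 tiles in M, a single tile in N, and 1152 / 64 = 18 mainloop iterations in K.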

Performance Comparison

Implementation       Epilogue Type              Time (μs)
CUTLASS NCxHWx<32>   Per-tensor alpha scaling   ~420
CUTLASS NCxHWx<32>   Per-channel alpha scaling  ~450
CUTLASS NHWC         Per-tensor alpha scaling   ~438
CUTLASS NHWC         Per-channel alpha scaling  ~470
TensorRT             Per-channel alpha scaling  ~400
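
For clarity on the epilogue column: the per-channel rows differ from the per-tensor ones only in the epilogue's scale type. A minimal sketch of that variant, assuming the ScaleType::OnlyAlphaPerChannelScaling kind from cutlass/epilogue/thread/linear_combination.h applies to LinearCombinationClamp as well:

using PerChannelEpilogue = cutlass::epilogue::thread::LinearCombinationClamp<
  int8_t,
  64 / cutlass::sizeof_bits<int8_t>::value,  // 8 elements per vector access
  int32_t,                                   // accumulator type
  float,                                     // compute type
  // alpha is supplied as a device pointer with one float per output channel
  cutlass::epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling
>;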

TensorRT Kernel Name (from nsys profiling)

sm80_xmma_fprop_implicit_gemm_interleaved_i8i8_i8i32_f32_nchw_vect_c_32kcrs_vect_c_32_nchw_vect_c_32_tilesize128x128x64_stage4_warpsize2x2x1_g1_tensor16x8x32_t1r3s3_execute_kernel_trt

From the kernel name, TensorRT appears to use:

  • Tile size: 128x128x64
  • 4 pipeline stages
  • Warp configuration: 2x2x1
  • Tensor instruction: 16x8x32
  • Specialized for 3x3 filter (t1r3s3)
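
Reading those fields back into CUTLASS template parameters (my interpretation of the kernel name):

ThreadblockShape = GemmShape<128, 128, 64>   // tilesize128x128x64
WarpShape        = GemmShape<64, 64, 64>     // warpsize2x2x1 -> 128/2 x 128/2 x 64
InstructionShape = GemmShape<16, 8, 32>      // tensor16x8x32
Stages           = 4                         // stage4

These match my CUTLASS configuration below exactly; the one property with no obvious public CUTLASS counterpart is the t1r3s3 filter specialization.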

CUTLASS Kernel Configuration

I am using the following CUTLASS template configuration:

#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

using Conv2dKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
  int8_t, cutlass::layout::TensorNCxHWx<32>,  // Input activations: NCxHWx<32>
  int8_t, cutlass::layout::TensorCxRSKx<32>,  // Filter: CxRSKx<32>
  int8_t, cutlass::layout::TensorNCxHWx<32>,  // Output: NCxHWx<32>
  int32_t,                                     // Accumulator type
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 64>,      // Threadblock tile
  cutlass::gemm::GemmShape<64, 64, 64>,        // Warp tile
  cutlass::gemm::GemmShape<16, 8, 32>,         // Tensor core instruction
  cutlass::epilogue::thread::LinearCombinationClamp<
    int8_t,
    64 / cutlass::sizeof_bits<int8_t>::value,  // 8 elements per vector access
    int32_t,                                   // accumulator type
    float,                                     // compute type for alpha
    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
  >,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
  4,                                           // Pipeline stages
  cutlass::arch::OpMultiplyAddSaturate,
  cutlass::conv::IteratorAlgorithm::kOptimized,
  cutlass::conv::StrideSupport::kStrided,
  16, 16                                       // A/B alignment (elements)
>::Kernel;
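
For completeness, a minimal sketch of the host-side launch path around this kernel (illustrative rather than my exact benchmark harness; tensor_a/b/c/d and alpha are placeholder device tensor references and scale, error checking elided):

#include "cutlass/conv/device/implicit_gemm_convolution.h"

using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dKernel>;

cutlass::conv::Conv2dProblemSize problem_size(
    {10, 72, 120, 128},                      // input size  (N, H, W, C)
    {128, 3, 3, 128},                        // filter size (K, R, S, C)
    {1, 1, 1, 1},                            // padding
    {1, 1},                                  // stride (h, w)
    {1, 1},                                  // dilation (h, w)
    cutlass::conv::Mode::kCrossCorrelation);

typename ImplicitGemm::Arguments arguments{
    problem_size,
    tensor_a.device_ref(),                   // activations, NCxHWx<32>
    tensor_b.device_ref(),                   // filters, CxRSKx<32>
    tensor_c.device_ref(),                   // source tensor for the epilogue
    tensor_d.device_ref(),                   // output, NCxHWx<32>
    {alpha}};                                // epilogue params: alpha only

ImplicitGemm conv_op;
conv_op.initialize(arguments);
conv_op();                                   // launch on the default stream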

Questions

How can I modify or tune the CUTLASS INT8 Conv2d Fprop kernel configuration to close the remaining ~12% gap and match TensorRT's performance? In particular, is there a CUTLASS equivalent of TensorRT's t1r3s3 filter-size specialization, or another knob I should try?

Additional Information

I can provide additional nsys profiling reports if needed.

Thank you for any guidance on how to achieve TensorRT-level performance with CUTLASS templates!
