Conversation
There was a problem hiding this comment.
Pull request overview
This PR fixes known issues in the MHA (Multi-Head Attention) C++ API by updating a deprecated function call and improving dimension padding logic.
Changes:
- Updated copyright year to 2026
- Replaced deprecated
composesfunction withmake_composes - Refactored dimension padding logic to remove assertions and handle mismatched dimensions gracefully
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| op_tests/cpp/mha/benchmark_mha_fwd.cpp | Updated copyright year and replaced deprecated composes function with make_composes |
| csrc/cpp_itfs/mha_bwd.cu | Removed assertions and added conditional logic to handle dimension padding cases |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| return std::make_tuple(hdim_q, hdim_v); | ||
| assert(hdim_q == hdim_v); | ||
| if(hdim_q <= 64) | ||
There was a problem hiding this comment.
This line contains only whitespace. Remove the trailing whitespace for cleaner code formatting.
| } | ||
|
|
||
| assert(false); | ||
| return std::make_tuple(hdim_q, hdim_v); |
There was a problem hiding this comment.
When hdim_q != hdim_v and the special case (hdim_q == 192 && hdim_v == 128) is not met, the function returns unpadded dimensions without any validation or error handling. This silently allows unsupported dimension combinations to proceed, which could lead to runtime errors. Consider adding validation or error handling for unsupported dimension pairs.
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist