[TORCH][MLIR] Added _sdpa_flash_attention op #4417
base: main
Conversation
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
2. Fixed op signature.
3. Added DecomposeComplexOps template for (flash_attn, flash_attn_for_cpu) -> sdpa rewrite.
4. Lit test to check correct decomposition.

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Pull request overview
This PR adds support for the _scaled_dot_product_flash_attention and _scaled_dot_product_flash_attention_for_cpu operations in the Torch-MLIR dialect. These operations are decomposed into the existing scaled_dot_product_attention operation.
Key Changes:
- Added decomposition patterns to convert flash attention ops to standard scaled dot product attention
- Registered the new operations in the ODS generator
- Added comprehensive test coverage for both new operations
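Both flash-attention variants compute the same mathematical result as standard scaled dot product attention, which is what makes the decomposition valid. The sketch below is illustrative only, not the torch-mlir decomposition code: it shows the reference SDPA computation in NumPy (function name, shapes, and the default scale are my own choices for the example).

```python
# Illustrative sketch of the scaled dot product attention math that both
# flash-attention ops reduce to after decomposition. Not torch-mlir code.
import math
import numpy as np

def scaled_dot_product_attention(q, k, v, scale=None):
    # q, k, v: (seq_len, head_dim) arrays for a single attention head.
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])  # conventional 1/sqrt(d) scaling
    scores = (q @ k.T) * scale                         # (seq, seq) logits
    scores -= scores.max(axis=-1, keepdims=True)       # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)     # softmax over keys
    return weights @ v                                 # (seq, head_dim)

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((4, 8))
v = rng.standard_normal((4, 8))
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # (4, 8)
```

A flash-attention kernel produces the same `out` tensor; it differs only in how the softmax and matmuls are tiled for memory efficiency, which is why lowering to the plain SDPA op is semantically safe.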
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | Implements decomposition pattern template for both flash attention operations |
| include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td | Defines the two new flash attention operation signatures and parsing/printing logic |
| projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py | Registers the new operations with their type signatures |
| test/Dialect/Torch/decompose-complex-ops.mlir | Adds test cases verifying decomposition behavior for both operations |
Following the discussion at iree-org/iree-turbine#1224.