[TRANSFORMATIONS] Align fused SDPA output#34097
[TRANSFORMATIONS] Align fused SDPA output#34097mryzhov wants to merge 8 commits intoopenvinotoolkit:masterfrom
Conversation
There was a problem hiding this comment.
Pull request overview
Improve SDPA Fusion output shape alignment for 2D inputs by inserting a post-fusion squeeze when needed, and extend tests to cover 2D static/dynamic shapes.
Changes:
- Added
try_align_outputshelper to insertSqueezewhen fused SDPA output shape differs from the original. - Updated SDPA fusion matcher to use
try_align_outputsbefore replacing the matched subgraph. - Extended SDPA test utilities and added parametrized 2D-input SDPA fusion tests (static and dynamic).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp | Adds output-shape alignment logic (optional squeeze) and integrates it into the fusion replacement. |
| src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp | Adds squeeze/unsqueeze helpers and new parametrized tests for 2D input fusion scenarios. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp
Show resolved
Hide resolved
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Outdated
Show resolved
Hide resolved
…ion_test.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
| }); | ||
| }; | ||
|
|
||
| std::shared_ptr<ov::Node> try_align_outputs(const std::shared_ptr<ov::Node>& src_output, |
There was a problem hiding this comment.
can you please add a comment where we have a case of 2D tensors (Q, K, V)? Which model and why it happens?
There was a problem hiding this comment.
The issue was identified in the GeekbenchAI model. It is not related to the SDPA inputs themselves; rather, it occurs during fusion of the SDPA subgraph. When the subgraph is replaced with the SDPA op, the resulting output shape can change in some cases, making it incompatible with the remainder of the model. The simplest way to reproduce the issue is with 2D inputs.
Another solution would be fixing the SDPA shape validation logic, @mitruska fyi
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| } | ||
|
|
||
| INSTANTIATE_TEST_SUITE_P(SDPAFusion, | ||
| SDPAUnsquieezeOutput, |
There was a problem hiding this comment.
Typo in parameter reference: "Unsquieeze" should be "Unsqueeze".
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Show resolved
Hide resolved
| comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); | ||
| } | ||
|
|
||
| class SDPASquieezeOutput : public TransformationTestsF, public ::testing::WithParamInterface<PartialShape> {}; |
There was a problem hiding this comment.
Typo in class name: "Squieeze" should be "Squeeze".
|
|
||
| class SDPASquieezeOutput : public TransformationTestsF, public ::testing::WithParamInterface<PartialShape> {}; | ||
|
|
||
| TEST_P(SDPASquieezeOutput, SDPAFusionTest_SquieezeOutput) { |
There was a problem hiding this comment.
Typo in test name: "Squieeze" should be "Squeeze".
| TEST_P(SDPASquieezeOutput, SDPAFusionTest_SquieezeOutput) { | |
| TEST_P(SDPASquieezeOutput, SDPAFusionTest_SqueezeOutput) { |
| } | ||
|
|
||
| INSTANTIATE_TEST_SUITE_P(SDPAFusion, | ||
| SDPASquieezeOutput, |
There was a problem hiding this comment.
Typo in parameter reference: "Squieeze" should be "Squeeze".
| class SDPAUnsquieezeOutput : public TransformationTestsF, public ::testing::WithParamInterface<PartialShape> {}; | ||
|
|
||
| TEST_P(SDPAUnsquieezeOutput, SDPAFusionTest_UnsquieezeOutput) { |
There was a problem hiding this comment.
Typo in class name: "Unsquieeze" should be "Unsqueeze".
| class SDPAUnsquieezeOutput : public TransformationTestsF, public ::testing::WithParamInterface<PartialShape> {}; | |
| TEST_P(SDPAUnsquieezeOutput, SDPAFusionTest_UnsquieezeOutput) { | |
| class SDPAUnsqueezeOutput : public TransformationTestsF, public ::testing::WithParamInterface<PartialShape> {}; | |
| TEST_P(SDPAUnsqueezeOutput, SDPAFusionTest_UnsqueezeOutput) { |
|
|
||
| class SDPAUnsquieezeOutput : public TransformationTestsF, public ::testing::WithParamInterface<PartialShape> {}; | ||
|
|
||
| TEST_P(SDPAUnsquieezeOutput, SDPAFusionTest_UnsquieezeOutput) { |
There was a problem hiding this comment.
Typo in test name: "Unsquieeze" should be "Unsqueeze".
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Details:
This pull request enhances the SDPA Fusion transformation by improving shape alignment between the original and fused nodes, particularly for handling cases when the fused SDPA operation output changes the original output shape (2D->3D, 4D-3D)
SDPA Fusion Logic Improvements
try_align_outputshelper function to insert aSqueezenode when the fused SDPA node's output shape does not match the original node's output, ensuring correct shape alignment after fusion.try_align_outputsafter creating the fused node, so that 2D inputs are properly handled and output shapes remain consistent.Testing Enhancements
SDPAtest utility class withsqueezeandunsqueezemethods for more flexible test graph construction.SDPAclass for unsqueezing Q, K, V inputs and squeezing SDPA outputs, supporting more detailed test scenarios.Tickets: