@@ -22,6 +22,7 @@ SDPA_vulkan::SDPA_vulkan()
2222
2323 pipeline_sdpa_fa = 0 ;
2424 use_flash_attention = false ;
25+ FA_UNROLL_SG_M = 1 ;
2526 FA_UNROLL_WG_M = 1 ;
2627
2728 use_cooperative_matrix = false ;
@@ -59,7 +60,7 @@ int SDPA_vulkan::create_pipeline(const Option& opt)
5960 use_bf16_cooperative_matrix = true ;
6061 }
6162
62- use_flash_attention = true && (opt.use_bf16_storage || opt.use_bf16_packed );
63+ use_flash_attention = use_cooperative_matrix && (opt.use_bf16_storage || opt.use_bf16_packed );
6364
6465 if (use_flash_attention)
6566 {
@@ -78,18 +79,29 @@ int SDPA_vulkan::create_pipeline(const Option& opt)
7879
7980 // assert coopmat_M != 0 && coopmat_N != 0 && coopmat_K != 0
8081
81- FA_UNROLL_WG_M = 2 ;
82-
83- // fa
82+ if (coopmat_N != coopmat_K)
83+ {
84+ // not implemented yet
85+ use_flash_attention = false ;
86+ }
87+ else
8488 {
85- std::vector<vk_specialization_type> specializations (1 + 5 );
89+ // fa
90+
91+ FA_UNROLL_SG_M = 2 ;
92+
93+ FA_UNROLL_WG_M = 2 ;
94+ // FA_UNROLL_WG_M = 1;
95+
96+ std::vector<vk_specialization_type> specializations (1 + 6 );
8697 specializations[0 ].i = attn_mask;
8798
8899 specializations[1 + 0 ].u32 = coopmat_M;
89100 specializations[1 + 1 ].u32 = coopmat_N;
90101 specializations[1 + 2 ].u32 = coopmat_K;
91102 specializations[1 + 3 ].u32 = coopmat_subgroup_size;
92- specializations[1 + 4 ].u32 = FA_UNROLL_WG_M;
103+ specializations[1 + 4 ].u32 = FA_UNROLL_SG_M;
104+ specializations[1 + 5 ].u32 = FA_UNROLL_WG_M;
93105
94106 pipeline_sdpa_fa = new Pipeline (vkdev);
95107 pipeline_sdpa_fa->set_subgroup_size (coopmat_subgroup_size);
@@ -285,6 +297,7 @@ int SDPA_vulkan::destroy_pipeline(const Option& opt)
285297 }
286298
287299 use_flash_attention = false ;
300+ FA_UNROLL_SG_M = 1 ;
288301 FA_UNROLL_WG_M = 1 ;
289302
290303 use_cooperative_matrix = false ;
@@ -436,7 +449,7 @@ int SDPA_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkM
436449 constants[12 ].i = attn_mask_blob.cstep ;
437450
438451 const int blocks_x = 1 ;
439- const int blocks_y = (src_seqlen + coopmat_M * FA_UNROLL_WG_M - 1 ) / (coopmat_M * FA_UNROLL_WG_M);
452+ const int blocks_y = (src_seqlen + coopmat_M * FA_UNROLL_SG_M * FA_UNROLL_WG_M - 1 ) / (coopmat_M * FA_UNROLL_SG_M * FA_UNROLL_WG_M);
440453
441454 VkMat dispatcher;
442455 dispatcher.w = (blocks_x * blocks_y) * (coopmat_subgroup_size * FA_UNROLL_WG_M);
0 commit comments