@@ -22,6 +22,7 @@ SDPA_vulkan::SDPA_vulkan()
2222
2323 pipeline_sdpa_fa = 0 ;
2424 use_flash_attention = false ;
25+ FA_UNROLL_WG_M = 1 ;
2526
2627 use_cooperative_matrix = false ;
2728 coopmat_M = 0 ;
@@ -77,19 +78,22 @@ int SDPA_vulkan::create_pipeline(const Option& opt)
7778
7879 // assert coopmat_M != 0 && coopmat_N != 0 && coopmat_K != 0
7980
81+ FA_UNROLL_WG_M = 2 ;
82+
8083 // fa
8184 {
82- std::vector<vk_specialization_type> specializations (1 + 4 );
85+ std::vector<vk_specialization_type> specializations (1 + 5 );
8386 specializations[0 ].i = attn_mask;
8487
8588 specializations[1 + 0 ].u32 = coopmat_M;
8689 specializations[1 + 1 ].u32 = coopmat_N;
8790 specializations[1 + 2 ].u32 = coopmat_K;
8891 specializations[1 + 3 ].u32 = coopmat_subgroup_size;
92+ specializations[1 + 4 ].u32 = FA_UNROLL_WG_M;
8993
9094 pipeline_sdpa_fa = new Pipeline (vkdev);
9195 pipeline_sdpa_fa->set_subgroup_size (coopmat_subgroup_size);
92- pipeline_sdpa_fa->set_local_size_xyz (coopmat_subgroup_size, 1 , 1 );
96+ pipeline_sdpa_fa->set_local_size_xyz (coopmat_subgroup_size * FA_UNROLL_WG_M , 1 , 1 );
9397 pipeline_sdpa_fa->create (LayerShaderType::sdpa_fa_cm, opt, specializations);
9498 }
9599 }
@@ -281,6 +285,7 @@ int SDPA_vulkan::destroy_pipeline(const Option& opt)
281285 }
282286
283287 use_flash_attention = false ;
288+ FA_UNROLL_WG_M = 1 ;
284289
285290 use_cooperative_matrix = false ;
286291 coopmat_M = 0 ;
@@ -430,11 +435,11 @@ int SDPA_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkM
430435 constants[11 ].i = top_blob.cstep ;
431436 constants[12 ].i = attn_mask_blob.cstep ;
432437
433- const int blocks_x = (src_seqlen + coopmat_M - 1 ) / (coopmat_M) ;
434- const int blocks_y = 1 ; // (out_embed_dim + coopmat_N - 1) / (coopmat_N );
438+ const int blocks_x = 1 ;
439+ const int blocks_y = (src_seqlen + coopmat_M * FA_UNROLL_WG_M - 1 ) / (coopmat_M * FA_UNROLL_WG_M );
435440
436441 VkMat dispatcher;
437- dispatcher.w = (blocks_x * blocks_y) * (coopmat_subgroup_size);
442+ dispatcher.w = (blocks_x * blocks_y) * (coopmat_subgroup_size * FA_UNROLL_WG_M );
438443 dispatcher.h = 1 ;
439444 dispatcher.c = num_heads;
440445
0 commit comments