Skip to content

Commit c8bbea2

Browse files
committed
qwq
1 parent 306c8be commit c8bbea2

File tree

3 files changed

+543
-96
lines changed

3 files changed

+543
-96
lines changed

src/layer/vulkan/sdpa_vulkan.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ SDPA_vulkan::SDPA_vulkan()
2222

2323
pipeline_sdpa_fa = 0;
2424
use_flash_attention = false;
25+
FA_UNROLL_SG_M = 1;
2526
FA_UNROLL_WG_M = 1;
2627

2728
use_cooperative_matrix = false;
@@ -59,7 +60,7 @@ int SDPA_vulkan::create_pipeline(const Option& opt)
5960
use_bf16_cooperative_matrix = true;
6061
}
6162

62-
use_flash_attention = true && (opt.use_bf16_storage || opt.use_bf16_packed);
63+
use_flash_attention = use_cooperative_matrix && (opt.use_bf16_storage || opt.use_bf16_packed);
6364

6465
if (use_flash_attention)
6566
{
@@ -78,18 +79,29 @@ int SDPA_vulkan::create_pipeline(const Option& opt)
7879

7980
// assert coopmat_M != 0 && coopmat_N != 0 && coopmat_K != 0
8081

81-
FA_UNROLL_WG_M = 2;
82-
83-
// fa
82+
if (coopmat_N != coopmat_K)
83+
{
84+
// flash attention with coopmat_N != coopmat_K is not implemented yet
85+
use_flash_attention = false;
86+
}
87+
else
8488
{
85-
std::vector<vk_specialization_type> specializations(1 + 5);
89+
// fa (flash attention) pipeline
90+
91+
FA_UNROLL_SG_M = 2;
92+
93+
FA_UNROLL_WG_M = 2;
94+
// FA_UNROLL_WG_M = 1;
95+
96+
std::vector<vk_specialization_type> specializations(1 + 6);
8697
specializations[0].i = attn_mask;
8798

8899
specializations[1 + 0].u32 = coopmat_M;
89100
specializations[1 + 1].u32 = coopmat_N;
90101
specializations[1 + 2].u32 = coopmat_K;
91102
specializations[1 + 3].u32 = coopmat_subgroup_size;
92-
specializations[1 + 4].u32 = FA_UNROLL_WG_M;
103+
specializations[1 + 4].u32 = FA_UNROLL_SG_M;
104+
specializations[1 + 5].u32 = FA_UNROLL_WG_M;
93105

94106
pipeline_sdpa_fa = new Pipeline(vkdev);
95107
pipeline_sdpa_fa->set_subgroup_size(coopmat_subgroup_size);
@@ -285,6 +297,7 @@ int SDPA_vulkan::destroy_pipeline(const Option& opt)
285297
}
286298

287299
use_flash_attention = false;
300+
FA_UNROLL_SG_M = 1;
288301
FA_UNROLL_WG_M = 1;
289302

290303
use_cooperative_matrix = false;
@@ -436,7 +449,7 @@ int SDPA_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkM
436449
constants[12].i = attn_mask_blob.cstep;
437450

438451
const int blocks_x = 1;
439-
const int blocks_y = (src_seqlen + coopmat_M * FA_UNROLL_WG_M - 1) / (coopmat_M * FA_UNROLL_WG_M);
452+
const int blocks_y = (src_seqlen + coopmat_M * FA_UNROLL_SG_M * FA_UNROLL_WG_M - 1) / (coopmat_M * FA_UNROLL_SG_M * FA_UNROLL_WG_M);
440453

441454
VkMat dispatcher;
442455
dispatcher.w = (blocks_x * blocks_y) * (coopmat_subgroup_size * FA_UNROLL_WG_M);

src/layer/vulkan/sdpa_vulkan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class SDPA_vulkan : public SDPA
3131
Pipeline* pipeline_sdpa_fa;
3232

3333
bool use_flash_attention;
34+
int FA_UNROLL_SG_M;
3435
int FA_UNROLL_WG_M;
3536

3637
// cooperative matrix

0 commit comments

Comments
 (0)