Skip to content

Commit 306c8be

Browse files
committed
wip
1 parent 101fb03 commit 306c8be

File tree

3 files changed

+176
-109
lines changed

3 files changed

+176
-109
lines changed

src/layer/vulkan/sdpa_vulkan.cpp

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ SDPA_vulkan::SDPA_vulkan()
2222

2323
pipeline_sdpa_fa = 0;
2424
use_flash_attention = false;
25+
FA_UNROLL_WG_M = 1;
2526

2627
use_cooperative_matrix = false;
2728
coopmat_M = 0;
@@ -77,19 +78,22 @@ int SDPA_vulkan::create_pipeline(const Option& opt)
7778

7879
// assert coopmat_M != 0 && coopmat_N != 0 && coopmat_K != 0
7980

81+
FA_UNROLL_WG_M = 2;
82+
8083
// fa
8184
{
82-
std::vector<vk_specialization_type> specializations(1 + 4);
85+
std::vector<vk_specialization_type> specializations(1 + 5);
8386
specializations[0].i = attn_mask;
8487

8588
specializations[1 + 0].u32 = coopmat_M;
8689
specializations[1 + 1].u32 = coopmat_N;
8790
specializations[1 + 2].u32 = coopmat_K;
8891
specializations[1 + 3].u32 = coopmat_subgroup_size;
92+
specializations[1 + 4].u32 = FA_UNROLL_WG_M;
8993

9094
pipeline_sdpa_fa = new Pipeline(vkdev);
9195
pipeline_sdpa_fa->set_subgroup_size(coopmat_subgroup_size);
92-
pipeline_sdpa_fa->set_local_size_xyz(coopmat_subgroup_size, 1, 1);
96+
pipeline_sdpa_fa->set_local_size_xyz(coopmat_subgroup_size * FA_UNROLL_WG_M, 1, 1);
9397
pipeline_sdpa_fa->create(LayerShaderType::sdpa_fa_cm, opt, specializations);
9498
}
9599
}
@@ -281,6 +285,7 @@ int SDPA_vulkan::destroy_pipeline(const Option& opt)
281285
}
282286

283287
use_flash_attention = false;
288+
FA_UNROLL_WG_M = 1;
284289

285290
use_cooperative_matrix = false;
286291
coopmat_M = 0;
@@ -430,11 +435,11 @@ int SDPA_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkM
430435
constants[11].i = top_blob.cstep;
431436
constants[12].i = attn_mask_blob.cstep;
432437

433-
const int blocks_x = (src_seqlen + coopmat_M - 1) / (coopmat_M);
434-
const int blocks_y = 1; //(out_embed_dim + coopmat_N - 1) / (coopmat_N);
438+
const int blocks_x = 1;
439+
const int blocks_y = (src_seqlen + coopmat_M * FA_UNROLL_WG_M - 1) / (coopmat_M * FA_UNROLL_WG_M);
435440

436441
VkMat dispatcher;
437-
dispatcher.w = (blocks_x * blocks_y) * (coopmat_subgroup_size);
442+
dispatcher.w = (blocks_x * blocks_y) * (coopmat_subgroup_size * FA_UNROLL_WG_M);
438443
dispatcher.h = 1;
439444
dispatcher.c = num_heads;
440445

src/layer/vulkan/sdpa_vulkan.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@ class SDPA_vulkan : public SDPA
3131
Pipeline* pipeline_sdpa_fa;
3232

3333
bool use_flash_attention;
34+
int FA_UNROLL_WG_M;
3435

3536
// cooperative matrix
3637
bool use_cooperative_matrix;

0 commit comments

Comments
 (0)