
Commit 7837232

perf: make dit faster (#1228)
1 parent 4ccce02 commit 7837232

File tree

16 files changed, +204 -187 lines

clip.hpp

Lines changed: 2 additions & 2 deletions
@@ -479,9 +479,9 @@ struct CLIPMLP : public GGMLBlock {
 
         x = fc1->forward(ctx, x);
         if (use_gelu) {
-            x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         } else {
-            x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu_quick(ctx->ggml_ctx, x, true);
         }
         x = fc2->forward(ctx, x);
         return x;
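
Note: the ggml_ext_* activation wrappers come from the repository's ggml_extend.hpp, which is not part of this commit's visible hunks. Below is a minimal sketch consistent with the call sites above, assuming the trailing bool simply selects ggml's in-place variant; the actual helper may do more (e.g. pick a faster kernel path).

    #include "ggml.h"

    // Sketch only -- hedged reconstruction from the call sites above.
    // Assumption: the trailing flag chooses between ggml's in-place and
    // out-of-place GELU ops.
    static struct ggml_tensor* ggml_ext_gelu(struct ggml_context* ctx,
                                             struct ggml_tensor* x,
                                             bool inplace) {
        return inplace ? ggml_gelu_inplace(ctx, x) : ggml_gelu(ctx, x);
    }

    static struct ggml_tensor* ggml_ext_gelu_quick(struct ggml_context* ctx,
                                                   struct ggml_tensor* x,
                                                   bool inplace) {
        return inplace ? ggml_gelu_quick_inplace(ctx, x) : ggml_gelu_quick(ctx, x);
    }

Keeping the flag explicit at every call site makes it easy to audit which activations are allowed to clobber their input tensor.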

common.hpp

Lines changed: 4 additions & 4 deletions
@@ -200,7 +200,7 @@ class GEGLU : public UnaryBlock {
 
         gate = ggml_cont(ctx->ggml_ctx, gate);
 
-        gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
+        gate = ggml_ext_gelu(ctx->ggml_ctx, gate, true);
 
         x = ggml_mul(ctx->ggml_ctx, x, gate);  // [ne3, ne2, ne1, dim_out]
 
@@ -220,7 +220,7 @@ class GELU : public UnaryBlock {
         auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
 
         x = proj->forward(ctx, x);
-        x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+        x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         return x;
     }
 };
@@ -536,8 +536,8 @@ class AlphaBlender : public GGMLBlock {
         // image_only_indicator is always tensor([0.])
         float alpha = get_alpha();
         auto x      = ggml_add(ctx->ggml_ctx,
-                               ggml_scale(ctx->ggml_ctx, x_spatial, alpha),
-                               ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
+                               ggml_ext_scale(ctx->ggml_ctx, x_spatial, alpha),
+                               ggml_ext_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
         return x;
     }
 };
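
Note: ggml_ext_scale is likewise defined outside this diff. One plausible, perf-motivated reading, offered purely as a guess: short-circuit a no-op factor so the compute graph gets one node fewer. The real helper may instead wrap ggml_scale_inplace or do something else entirely.

    #include "ggml.h"

    // Hypothetical sketch of ggml_ext_scale -- an assumption, not the
    // repository's actual code.
    static struct ggml_tensor* ggml_ext_scale(struct ggml_context* ctx,
                                              struct ggml_tensor* x,
                                              float s) {
        if (s == 1.0f) {
            return x;  // scaling by 1 is a no-op; skip the extra graph op
        }
        return ggml_scale(ctx, x, s);
    }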

esrgan.hpp

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ class ResidualDenseBlock : public GGMLBlock {
         x_cat   = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
         auto x5 = conv5->forward(ctx, x_cat);
 
-        x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x);
+        x5 = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, x5, 0.2f), x);
         return x5;
     }
 };
@@ -76,7 +76,7 @@ class RRDB : public GGMLBlock {
         out = rdb2->forward(ctx, out);
         out = rdb3->forward(ctx, out);
 
-        out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x);
+        out = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, out, 0.2f), x);
         return out;
     }
 };

flux.hpp

Lines changed: 37 additions & 60 deletions
@@ -103,7 +103,7 @@ namespace Flux {
         auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
 
         auto qkv     = qkv_proj->forward(ctx, x);
-        auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
+        auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true);
         int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
         auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
         auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
@@ -153,7 +153,7 @@ namespace Flux {
         if (use_mlp_silu_act) {
             x = ggml_ext_silu_act(ctx->ggml_ctx, x);
         } else {
-            x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         }
         x = mlp_2->forward(ctx, x);
         return x;
@@ -376,26 +376,23 @@ namespace Flux {
         auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
         auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
-        auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_txt_token + n_img_token, n_head*d_head]
-        attn      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
+        auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_txt_token + n_img_token, n_head*d_head]
         auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                          attn,
                                          attn->ne[0],
-                                         attn->ne[1],
                                          txt->ne[1],
+                                         attn->ne[2],
                                          attn->nb[1],
                                          attn->nb[2],
-                                         0);  // [n_txt_token, N, hidden_size]
-        txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3));  // [N, n_txt_token, hidden_size]
+                                         0);  // [N, n_txt_token, hidden_size]
         auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                          attn,
                                          attn->ne[0],
-                                         attn->ne[1],
                                          img->ne[1],
+                                         attn->ne[2],
                                          attn->nb[1],
                                          attn->nb[2],
-                                         attn->nb[2] * txt->ne[1]);  // [n_img_token, N, hidden_size]
-        img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]
+                                         txt->ne[1] * attn->nb[1]);  // [N, n_img_token, hidden_size]
 
         // calculate the img bloks
         img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
@@ -492,43 +489,29 @@ namespace Flux {
         }
 
         auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
-        auto qkv_mlp = linear1->forward(ctx, x_mod);  // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
-        qkv_mlp      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3));  // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
-
-        auto qkv = ggml_view_3d(ctx->ggml_ctx,
-                                qkv_mlp,
-                                qkv_mlp->ne[0],
-                                qkv_mlp->ne[1],
-                                hidden_size * 3,
-                                qkv_mlp->nb[1],
-                                qkv_mlp->nb[2],
-                                0);  // [hidden_size * 3, N, n_token]
-        qkv      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3));  // [N, n_token, hidden_size * 3]
-        auto mlp = ggml_view_3d(ctx->ggml_ctx,
-                                qkv_mlp,
-                                qkv_mlp->ne[0],
-                                qkv_mlp->ne[1],
-                                mlp_hidden_dim * mlp_mult_factor,
-                                qkv_mlp->nb[1],
-                                qkv_mlp->nb[2],
-                                qkv_mlp->nb[2] * hidden_size * 3);  // [mlp_hidden_dim*mlp_mult_factor, N, n_token]
-        mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3));  // [N, n_token, mlp_hidden_dim*mlp_mult_factor]
-
-        auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);  // q,k,v: [N, n_token, hidden_size]
+        auto qkv_mlp = linear1->forward(ctx, x_mod);  // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor]
+
+        auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0);
+        auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]);
+        auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]);
+
         int64_t head_dim = hidden_size / num_heads;
-        auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);  // [N, n_token, n_head, d_head]
-        auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);  // [N, n_token, n_head, d_head]
-        auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]
-        q         = norm->query_norm(ctx, q);
-        k         = norm->key_norm(ctx, k);
-        auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]
 
+        q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]);  // [N, n_token, n_head, d_head]
+        k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]);  // [N, n_token, n_head, d_head]
+        v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]);  // [N, n_token, n_head, d_head]
+
+        q = norm->query_norm(ctx, q);
+        k = norm->key_norm(ctx, k);
+        auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]
+
+        auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]);
         if (use_yak_mlp) {
             mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
         } else if (use_mlp_silu_act) {
             mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
         } else {
-            mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
+            mlp = ggml_ext_gelu(ctx->ggml_ctx, mlp, true);
         }
         auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0);  // [N, n_token, hidden_size + mlp_hidden_dim]
         auto output   = linear2->forward(ctx, attn_mlp);           // [N, n_token, hidden_size]
@@ -580,13 +563,10 @@ namespace Flux {
         } else {
             auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
 
-            auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c));  // [N, 2 * hidden_size]
-            m      = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]);       // [N, 2, hidden_size]
-            m      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3));  // [2, N, hidden_size]
-
-            int64_t offset = m->nb[1] * m->ne[1];
-            shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);  // [N, hidden_size]
-            scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);  // [N, hidden_size]
+            auto m     = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c));  // [N, 2 * hidden_size]
+            auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
+            shift      = m_vec[0];  // [N, hidden_size]
+            scale      = m_vec[1];  // [N, hidden_size]
         }
 
         x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
@@ -1034,16 +1014,14 @@ namespace Flux {
             txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
         }
 
-        txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
-        img     = ggml_view_3d(ctx->ggml_ctx,
-                               txt_img,
-                               txt_img->ne[0],
-                               txt_img->ne[1],
-                               img->ne[1],
-                               txt_img->nb[1],
-                               txt_img->nb[2],
-                               txt_img->nb[2] * txt->ne[1]);  // [n_img_token, N, hidden_size]
-        img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]
+        img = ggml_view_3d(ctx->ggml_ctx,
+                           txt_img,
+                           txt_img->ne[0],
+                           img->ne[1],
+                           txt_img->ne[2],
+                           txt_img->nb[1],
+                           txt_img->nb[2],
+                           txt->ne[1] * txt_img->nb[1]);  // [N, n_img_token, hidden_size]
 
         if (final_layer) {
             img = final_layer->forward(ctx, img, vec);  // (N, T, patch_size ** 2 * out_channels)
@@ -1196,9 +1174,8 @@ namespace Flux {
         auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers);  // [N, num_tokens, C * patch_size * patch_size]
 
         if (out->ne[1] > img_tokens) {
-            out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3));  // [num_tokens, N, C * patch_size * patch_size]
-            out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
-            out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3));  // [N, h*w, C * patch_size * patch_size]
+            out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
+            out = ggml_cont(ctx->ggml_ctx, out);
         }
 
         // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
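
Note: two techniques recur throughout the flux.hpp hunks. First, split_qkv and its surrounding permute/cont round-trips are replaced by ggml_ext_chunk, a torch.chunk-style splitter whose definition is also outside this diff. A hedged sketch matching the (ctx, tensor, chunks, dim[, cont]) call sites above; it handles only the dim-0 case used here and assumes the trailing flag requests contiguous copies:

    #include <vector>
    #include "ggml.h"

    // Hypothetical dim-0 chunker reconstructed from the call sites above;
    // the repository's helper presumably supports other dims as well.
    static std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx,
                                                           struct ggml_tensor* x,
                                                           int chunks,
                                                           int dim,
                                                           bool cont = false) {
        GGML_ASSERT(dim == 0 && x->ne[0] % chunks == 0);  // sketch covers dim 0 only
        std::vector<struct ggml_tensor*> parts;
        int64_t step = x->ne[0] / chunks;
        for (int i = 0; i < chunks; i++) {
            // each chunk is a strided view along the innermost dimension
            auto v = ggml_view_3d(ctx, x,
                                  step, x->ne[1], x->ne[2],
                                  x->nb[1], x->nb[2],
                                  i * step * x->nb[0]);
            parts.push_back(cont ? ggml_cont(ctx, v) : v);
        }
        return parts;
    }

Called as in the first hunk, ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true) would yield three [N, n_token, hidden_size] tensors for q, k, and v.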
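Second, the commit drops permute + ggml_cont pairs around token-range slicing. ggml stores ne[0] innermost, so for a contiguous [hidden_size, n_token, N] tensor a token range along ne[1] is already a valid strided view: keep the strides, shrink ne[1], and offset by whole rows. An illustration (the helper name and signature are invented for this note):

    #include "ggml.h"

    // Hypothetical helper, not from the repo: view tokens [first, first + count)
    // of a contiguous tensor with ne = [hidden, n_token, N].
    static struct ggml_tensor* view_token_range(struct ggml_context* ctx,
                                                struct ggml_tensor* t,
                                                int64_t first,
                                                int64_t count) {
        return ggml_view_3d(ctx, t,
                            t->ne[0], count, t->ne[2],  // shrink only the token dim
                            t->nb[1], t->nb[2],         // keep the original strides
                            first * t->nb[1]);          // byte offset of first token
    }

This is exactly the shape of the new txt_attn_out/img_attn_out views (the offset txt->ne[1] * attn->nb[1] skips past the text tokens) and of the final img/out slices; only the one consumer that genuinely needs contiguous memory, the ggml_cont before the patch rearrange, still pays for a copy.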
