@@ -103,7 +103,7 @@ namespace Flux {
         auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
 
         auto qkv = qkv_proj->forward(ctx, x);
-        auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
+        auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true);
         int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
         auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
         auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
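ggml_ext_chunk is a helper added elsewhere in this PR; from its call sites it appears to split a tensor into n equal parts along the given dimension, torch.chunk-style, with the final flag requesting contiguous copies. A minimal sketch of that assumed behavior using only stock ggml calls (only the call shape is visible in this hunk, so the body is a guess):

    // Hedged sketch: n equal zero-copy views along dim 0, optionally
    // materialized with ggml_cont(). Assumed semantics of ggml_ext_chunk.
    static std::vector<ggml_tensor *> chunk_dim0_sketch(ggml_context * ctx,
                                                        ggml_tensor * t,
                                                        int n, bool cont) {
        std::vector<ggml_tensor *> parts;
        const int64_t ne0 = t->ne[0] / n;  // assumes ne[0] divisible by n
        for (int i = 0; i < n; i++) {
            ggml_tensor * v = ggml_view_3d(ctx, t,
                                           ne0, t->ne[1], t->ne[2],
                                           t->nb[1], t->nb[2],
                                           i * ne0 * t->nb[0]);  // byte offset
            parts.push_back(cont ? ggml_cont(ctx, v) : v);
        }
        return parts;
    }

The trailing `true` matters here: the chunks feed ggml_reshape_4d immediately below, and ggml's reshape ops assert a contiguous source.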
@@ -153,7 +153,7 @@ namespace Flux {
         if (use_mlp_silu_act) {
             x = ggml_ext_silu_act(ctx->ggml_ctx, x);
         } else {
-            x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         }
         x = mlp_2->forward(ctx, x);
         return x;
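ggml_ext_gelu is likewise a PR-local wrapper whose definition is not in this hunk. Since it replaces a direct ggml_gelu_inplace call and is passed a trailing `true`, it plausibly just selects between the in-place and copying stock ops; a hedged sketch of that assumption:

    // Assumed shape of the wrapper; the real definition lives in the
    // PR's ggml extension header and may differ (e.g. a tanh variant).
    static ggml_tensor * ext_gelu_sketch(ggml_context * ctx,
                                         ggml_tensor * x, bool inplace) {
        return inplace ? ggml_gelu_inplace(ctx, x) : ggml_gelu(ctx, x);
    }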
@@ -376,26 +376,23 @@ namespace Flux {
         auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
         auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
-        auto attn = Rope::attention(ctx, q, k, v, pe, mask);                                  // [N, n_txt_token + n_img_token, n_head*d_head]
-        attn      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
+        auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_txt_token + n_img_token, n_head*d_head]
         auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                          attn,
                                          attn->ne[0],
-                                         attn->ne[1],
                                          txt->ne[1],
+                                         attn->ne[2],
                                          attn->nb[1],
                                          attn->nb[2],
-                                         0);  // [n_txt_token, N, hidden_size]
-        txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3));  // [N, n_txt_token, hidden_size]
+                                         0);  // [N, n_txt_token, hidden_size]
         auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                          attn,
                                          attn->ne[0],
-                                         attn->ne[1],
                                          img->ne[1],
+                                         attn->ne[2],
                                          attn->nb[1],
                                          attn->nb[2],
-                                         attn->nb[2] * txt->ne[1]);  // [n_img_token, N, hidden_size]
-        img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]
+                                         txt->ne[1] * attn->nb[1]);  // [N, n_img_token, hidden_size]
 
         // calculate the img blocks
         img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
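The pattern in this hunk replaces a permute/cont/view/permute/cont round trip with a single strided view. Because a ggml tensor carries per-dimension byte strides (nb), slicing a token range along dim 1 only needs a smaller ne[1] and a byte offset of row_start * nb[1]; no data is moved. A generic form of the two views above (the helper name is illustrative, not part of the PR):

    // Zero-copy slice of rows [row0, row0 + n_rows) along dim 1.
    // The txt/img split above is exactly this: txt starts at offset 0,
    // img starts at txt->ne[1] * attn->nb[1] bytes.
    static ggml_tensor * rows_dim1(ggml_context * ctx, ggml_tensor * t,
                                   int64_t row0, int64_t n_rows) {
        return ggml_view_3d(ctx, t,
                            t->ne[0], n_rows, t->ne[2],
                            t->nb[1], t->nb[2],
                            row0 * t->nb[1]);
    }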
@@ -492,43 +489,29 @@ namespace Flux {
         }
 
         auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
-        auto qkv_mlp = linear1->forward(ctx, x_mod);                                                // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
-        qkv_mlp      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3));  // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
-
-        auto qkv = ggml_view_3d(ctx->ggml_ctx,
-                                qkv_mlp,
-                                qkv_mlp->ne[0],
-                                qkv_mlp->ne[1],
-                                hidden_size * 3,
-                                qkv_mlp->nb[1],
-                                qkv_mlp->nb[2],
-                                0);  // [hidden_size * 3, N, n_token]
-        qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3));  // [N, n_token, hidden_size * 3]
-        auto mlp = ggml_view_3d(ctx->ggml_ctx,
-                                qkv_mlp,
-                                qkv_mlp->ne[0],
-                                qkv_mlp->ne[1],
-                                mlp_hidden_dim * mlp_mult_factor,
-                                qkv_mlp->nb[1],
-                                qkv_mlp->nb[2],
-                                qkv_mlp->nb[2] * hidden_size * 3);  // [mlp_hidden_dim*mlp_mult_factor, N, n_token]
-        mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3));  // [N, n_token, mlp_hidden_dim*mlp_mult_factor]
-
-        auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);  // q,k,v: [N, n_token, hidden_size]
+        auto qkv_mlp = linear1->forward(ctx, x_mod);  // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor]
+
+        auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0);
+        auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]);
+        auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]);
+
         int64_t head_dim = hidden_size / num_heads;
-        auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);  // [N, n_token, n_head, d_head]
-        auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);  // [N, n_token, n_head, d_head]
-        auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]
-        q = norm->query_norm(ctx, q);
-        k = norm->key_norm(ctx, k);
-        auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]
 
+        q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]);  // [N, n_token, n_head, d_head]
+        k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]);  // [N, n_token, n_head, d_head]
+        v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]);  // [N, n_token, n_head, d_head]
+
+        q = norm->query_norm(ctx, q);
+        k = norm->key_norm(ctx, k);
+        auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]
+
+        auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]);
         if (use_yak_mlp) {
             mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
         } else if (use_mlp_silu_act) {
             mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
         } else {
-            mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
+            mlp = ggml_ext_gelu(ctx->ggml_ctx, mlp, true);
         }
         auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0);  // [N, n_token, hidden_size + mlp_hidden_dim]
         auto output = linear2->forward(ctx, attn_mlp);             // [N, n_token, hidden_size]
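Same idea one dimension lower: the fused qkv+mlp projection is split along dim 0, so offsets are counted in units of nb[0] (bytes per element). A worked check of the offsets with illustrative numbers that are not fixed by this hunk (f32 activations, hidden_size = 3072):

    // q   starts at byte 0
    // k   starts at hidden_size     * nb[0] = 3072 * 4 = 12288
    // v   starts at hidden_size * 2 * nb[0] = 6144 * 4 = 24576
    // mlp starts at hidden_size * 3 * nb[0] = 9216 * 4 = 36864
    // Each view keeps the parent's nb[1]/nb[2], so it stays zero-copy;
    // the ggml_cont() wrapped around q/k/v is needed because
    // ggml_reshape_4d asserts a contiguous source.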
@@ -580,13 +563,10 @@ namespace Flux {
         } else {
             auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
 
-            auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c));         // [N, 2 * hidden_size]
-            m      = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]);              // [N, 2, hidden_size]
-            m      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3));  // [2, N, hidden_size]
-
-            int64_t offset = m->nb[1] * m->ne[1];
-            shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);  // [N, hidden_size]
-            scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);  // [N, hidden_size]
+            auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c));  // [N, 2 * hidden_size]
+            auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
+            shift = m_vec[0];  // [N, hidden_size]
+            scale = m_vec[1];  // [N, hidden_size]
         }
 
         x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
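In the 2D case the same chunk could be spelled with stock views directly; a hedged equivalent of what ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0) is assumed to return here:

    // m: ne[0] == 2 * hidden_size, ne[1] == N (ggml dim order).
    // Two zero-copy halves along dim 0; no reshape/permute/cont needed.
    const int64_t half = m->ne[0] / 2;
    shift = ggml_view_2d(ctx->ggml_ctx, m, half, m->ne[1], m->nb[1], 0);                // [N, hidden_size]
    scale = ggml_view_2d(ctx->ggml_ctx, m, half, m->ne[1], m->nb[1], half * m->nb[0]);  // [N, hidden_size]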
@@ -1034,16 +1014,14 @@ namespace Flux {
             txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
         }
 
-        txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
-        img     = ggml_view_3d(ctx->ggml_ctx,
-                               txt_img,
-                               txt_img->ne[0],
-                               txt_img->ne[1],
-                               img->ne[1],
-                               txt_img->nb[1],
-                               txt_img->nb[2],
-                               txt_img->nb[2] * txt->ne[1]);  // [n_img_token, N, hidden_size]
-        img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]
+        img = ggml_view_3d(ctx->ggml_ctx,
+                           txt_img,
+                           txt_img->ne[0],
+                           img->ne[1],
+                           txt_img->ne[2],
+                           txt_img->nb[1],
+                           txt_img->nb[2],
+                           txt->ne[1] * txt_img->nb[1]);  // [N, n_img_token, hidden_size]
 
         if (final_layer) {
             img = final_layer->forward(ctx, img, vec);  // (N, T, patch_size ** 2 * out_channels)
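This is the same dim-1 slice as in the double-block attention split above. For concreteness, a quick offset check with illustrative numbers (f32, hidden_size = 3072, n_txt_token = 256; none of these are fixed by this hunk):

    // txt_img->nb[1] = 3072 * 4     = 12288 bytes per token row
    // offset         = 256 * 12288  = 3145728 bytes
    // i.e. the img view begins exactly where the txt rows end, and the
    // final layer consumes it without the former permute/cont round trip.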
@@ -1196,9 +1174,8 @@ namespace Flux {
         auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers);  // [N, num_tokens, C * patch_size * patch_size]
 
         if (out->ne[1] > img_tokens) {
-            out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3));  // [num_tokens, N, C * patch_size * patch_size]
-            out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
-            out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3));  // [N, h*w, C * patch_size * patch_size]
+            out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
+            out = ggml_cont(ctx->ggml_ctx, out);
         }
 
         // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
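One detail worth keeping: the trailing ggml_cont() is not optional. The rearrange below proceeds through reshapes, and ggml's reshape ops assert ggml_is_contiguous() on their input, which a dim-1 view generally is not once ne[2] > 1. A compact sketch of the truncate-then-materialize pattern (helper name illustrative):

    // Keep only the first n_rows tokens along dim 1, then materialize
    // the view so it can be reshaped/unpatchified afterwards.
    static ggml_tensor * take_tokens(ggml_context * ctx, ggml_tensor * t,
                                     int64_t n_rows) {
        ggml_tensor * v = ggml_view_3d(ctx, t,
                                       t->ne[0], n_rows, t->ne[2],
                                       t->nb[1], t->nb[2], 0);
        return ggml_cont(ctx, v);  // reshape requires contiguous memory
    }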