moefinalize_allreduce_fusion_kernel_oneshot_lamport - potential 20-30% lift? #11

Open

aidando73 wants to merge 8 commits into fwai from aidand-kernel-fuse-3

Conversation


aidando73 commented Dec 28, 2025

Just a POC PR - RE: flashinfer-ai/flashinfer#2269 (comment)

Getting:

  batch size    fused-trtllm-ar (us)    baseline-do-finalize (us)  speedup (baseline/fused)
------------  ----------------------  ---------------------------  --------------------------
           8                     214                          255  1.192x
          32                     271                          338  1.247x
         128                     337                          446  1.323x

Compared to the current implementation:

  batch size    fused-trtllm-ar (us)    baseline-do-finalize (us)  speedup (baseline/fused)
------------  ----------------------  ---------------------------  --------------------------
           8                     230                          256  1.113x
          32                     318                          339  1.066x
         128                     447                          464  1.038x

But it's a bit inconsistent - sometimes I'm getting:

  batch size    fused-trtllm-ar (us)    baseline-do-finalize (us)  speedup (baseline/fused)
------------  ----------------------  ---------------------------  --------------------------
           8                     255                          256  1.003x
          32                     335                          337  1.005x
         128                     456                          462  1.013x

  batch size    fused-trtllm-ar (us)    baseline-do-finalize (us)  speedup (baseline/fused)
------------  ----------------------  ---------------------------  --------------------------
           8                     200                          256  1.280x
          32                     244                          337  1.381x
         128                     270                          462  1.711x

But maybe there's a way to fix it?

(Numerics are also still a bit wonky - but I wanted to check whether you think this is a good direction before going further.)

    fused      max_abs_diff=1.8872e+00, cos=0.9582
    baseline  max_abs_diff=9.1162e-01, cos=0.9788
(nvfp4 moe so some diff is expected)

Some other ideas I'm thinking about are:

  • Some ll128 / NvlinkTwoSided implementation?
  • Multimem?

Basically the current implementation is mainly bound waiting for the PUT to arrive; the finalize and RMS norm parts are fairly small. It seems possible to overlap them, but it would probably be at most a 5-10% win.

# Batch size = 32 - baseline
moefinalize_allreduce_oneshot timings (clocks): VEC_SIZE=8, moe_finalize=3638 (5.57%), ar_store=309 (0.47%), clear=1640 (2.51%), ar_load_total=52377 (80.17%), fused_op_total=7369 (11.28%), total=65333

@aidando73 aidando73 changed the base branch from main to fwai December 28, 2025 07:34
@aidando73 aidando73 changed the title moe_finalize_all_reduce - potential 20-30% lift? moefinalize_allreduce_fusion_kernel_oneshot_lamport - potential 20-30% lift? Dec 28, 2025

// * Clear previous buffer
for (int idx = access_id; idx < clear_access; idx += access_stride) {
  clear_vec.store(reinterpret_cast<T*>(comm.clear_buf) + idx * VEC_SIZE);
}
aidando73 (Author) commented Jan 1, 2026
Oh wait - I'm missing clear_vec, so my numbers are actually off: the wait loop won't wait for the PUT to arrive, since the previous PUT's data is still there.

Probably why my numerics are a bit off as well. Let me fix.

aidando73 (Author) commented Jan 1, 2026
Ah yep - that was why I was getting faster - my bad. Getting similar results as before:

  batch size    fused-trtllm-ar (us)    baseline-do-finalize (us)  speedup (baseline/fused)
------------  ----------------------  ---------------------------  --------------------------
           8                     242                          264  1.091x
          32                     335                          354  1.057x
         128                     367                          380  1.035x

Hmm - have we hit hardware limits, then?


aidando73 commented Jan 2, 2026

I think overlapping doesn't really make sense actually.

+--------+----------------------------------+--------+------------------------------------------+
| B200   |                                  |        | 32 token case                             |
+--------+----------------------------------+--------+------------------------------------------+
| 160    | SMs                              | 32     | tokens                                   |
| 32     | blocks per SM - limit            | 7168   | elements per token                       |
| 64     | concurrent warps / SM - limit    | 2      | bytes per element (bfloat16 case)        |
| 32     | threads per warp                 | 229376 | Total elements                           |
| 327680 | max concurrent threads           |        |                                          |
+--------+----------------------------------+--------+------------------------------------------+

Because the total element count (229,376) fits under the max concurrent thread count (327,680), we can assign one element to one thread - there's no serial workload left to overlap.

If we completely optimized away moe_finalize and fused_op, we'd get at most a ~5% speedup (I tried commenting them out: ~1.5% from moe_finalize, ~3% from fused_op). So the only meaningful gain here is to reduce the latency of the PUT itself.

aleozlx (Contributor) left a comment

reviewed. this looks good

we likely have no problem integrating it as well if you want to send a PR to upstream FI

cross linking
flashinfer-ai/flashinfer#2269
