moefinalize_allreduce_fusion_kernel_oneshot_lamport - potential 20-30% lift?#11
Conversation
```cpp
// * Clear previous buffer
for (int idx = access_id; idx < clear_access; idx += access_stride) {
  clear_vec.store(reinterpret_cast<T*>(comm.clear_buf) + idx * VEC_SIZE);
}
```
Oh, wait - I was missing the `clear_vec` step, so my numbers are actually off: without clearing, the wait loop won't wait for the new PUT to arrive, since the previous PUT's data is still there.
Probably why my numerics are a bit off as well. Let me fix.
Ah yep - that was why I was getting faster - my bad. Getting similar results as before:

| batch size | fused-trtllm-ar (us) | baseline-do-finalize (us) | speedup (baseline/fused) |
| --- | --- | --- | --- |
| 8 | 242 | 264 | 1.091x |
| 32 | 335 | 354 | 1.057x |
| 128 | 367 | 380 | 1.035x |
Hmm - have we hit hardware limits, then?
I think overlapping doesn't really make sense, actually. Because we can assign one element to one thread, there's no serial workload to overlap. Even if we completely optimized away moe_finalize and fused_op, we'd get at most a ~5% speedup (I tried commenting them out: ~1.5% for moe_finalize, ~3% for fused_op). So the only meaningful gain left is to reduce the latency of the PUT itself.
aleozlx
left a comment
reviewed. this looks good
we likely have no problem integrating it as well if you want to send a PR to upstream FI
cross linking
flashinfer-ai/flashinfer#2269
Just a POC PR - RE: flashinfer-ai/flashinfer#2269 (comment)
Getting: [benchmark screenshot omitted]
As compared to the current implementation: [benchmark screenshot omitted]
But it's a bit inconsistent - sometimes I'm getting: [benchmark screenshot omitted]
But maybe there's a way to fix it?
(Numerics are also still a bit wonky - but wanted to check with you if you think this is a good direction to take before going further)
Some other ideas I'm thinking about are:
Basically, the current implementation is mainly bound waiting for the PUT to arrive; the finalize and RMS norm parts are fairly small, so overlapping them seems possible, but it would probably be at most a 5-10% win.