Feat: Trtllm-gen MxFP8 MoE integration #2505
IwakuraRein wants to merge 26 commits into flashinfer-ai:main
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough: Adds an FP8 quantization enum and MxFP8 support across Python, C++ launchers, benchmarks, and tests; threads a new quantization-type parameter through those layers.
Sequence Diagram(s): mermaid diagram omitted; it sketches the CLI passing --quant-mode (e.g., MxFP8xMxFP8) to the Autotuner.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (1 warning, 1 inconclusive)
Summary of Changes
Hello @IwakuraRein, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the TensorRT-LLM fused Mixture-of-Experts (MoE) implementation by integrating MxFP8 quantization. This integration provides a new, flexible FP8 quantization option alongside the existing DeepSeek FP8, allowing for fine-grained control over mixed-precision computations. The changes span core kernel logic, benchmarking, and testing, ensuring that the new quantization mode is robustly supported and validated across the system.
Highlights
Activity
Code Review
This pull request integrates mxfp8 support into the trtllm fused MoE kernels. The changes are extensive, touching benchmark scripts, C++ kernel launchers, and Python bindings. The introduction of Fp8QuantizationType is a good refactoring that makes the code more extensible. The tests have also been updated to cover the new quantization modes.
My review focuses on improving code maintainability by reducing duplication in the benchmark scripts and C++ kernel launcher. I've also pointed out some leftover debugging code and minor issues that should be addressed before merging.
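For readers following the thread, here is a minimal sketch of the Python-side enum being discussed. Member names and values are taken from the CodeRabbit comments later in this thread (the actual definition lives in flashinfer/fused_moe/core.py); treat this as illustrative, not the authoritative source.

```python
from enum import IntEnum


class Fp8QuantizationType(IntEnum):
    """FP8 quantization flavour for the trtllm-gen fused MoE (values per the review below)."""

    NoneFp8 = 0      # no FP8 quantization selected
    DeepSeekFp8 = 1  # DeepSeek-style 128-wide block scales (float32 scale tensors)
    MxFp8 = 2        # MxFP8: 32-element blocks with uint8 block-scale factors
```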
| print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms") | ||
|
|
||
|
|
||
| def bench_trtllm_gen_fused_moe_autotuner_mxint4( |
This function bench_trtllm_gen_fused_moe_autotuner_mxint4 is very similar to bench_trtllm_gen_fused_moe_autotuner_fp8 and bench_trtllm_gen_fused_moe_autotuner_fp4. To improve maintainability and reduce code duplication, consider refactoring these into a more generic benchmark function or a base class. This could accept quantization functions and the specific MoE kernel as parameters, centralizing the common benchmarking logic.
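To make the suggestion concrete, here is a rough sketch of the kind of shared helper meant here. The function and parameter names are hypothetical and not taken from the benchmark file; the per-dtype benchmarks would then only differ in the callables they pass in.

```python
from typing import Callable

import torch


def bench_trtllm_gen_fused_moe_autotuner(
    make_inputs: Callable[[], torch.Tensor],                 # builds the bf16 hidden states
    quantize_inputs: Callable[[torch.Tensor], tuple],        # e.g. an fp8 / fp4 / mxint4 quantizer
    run_moe_kernel: Callable[..., torch.Tensor],             # the fused-MoE kernel under test
    warmup: int = 10,
    iters: int = 100,
) -> float:
    """Shared timing loop for the fused-MoE autotuner benchmarks (sketch)."""
    quantized = quantize_inputs(make_inputs())
    for _ in range(warmup):
        run_moe_kernel(*quantized)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        run_moe_kernel(*quantized)
    stop.record()
    torch.cuda.synchronize()
    return start.elapsed_time(stop) / iters  # milliseconds per iteration
```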
(quoted diff context)

    FusedMoeLauncher::check_moe_common();

    TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8.";
-   TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
-       << "hidden_states_scale must be float.";
-   TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D.";
-   TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128)
-       << "hidden_states_scale dim0 must match hidden_states dim1 / 128.";
-   TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens)
-       << "hidden_states_scale dim1 must match num_tokens.";
+   if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+     TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
+         << "hidden_states_scale must be float.";
+     TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D.";
+     TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128)
+         << "hidden_states_scale dim0 must match hidden_states dim1 / 128.";
+     TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens)
+         << "hidden_states_scale dim1 must match num_tokens.";
+   } else if (quantization_type == Fp8QuantizationType::MxFp8) {
+     TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_uint8);
+   }

    TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8.";
    TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8.";

-   TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
-       << "gemm1_weights_scale must be float.";
-   TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
-   TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
-       << "gemm1_weights_scale has incorrect shape.";
-   TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
-       << "intermediate_size must be a multiple of 128.";
-   TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * args->intermediate_size / 128)
-       << "gemm1_weights_scale has incorrect shape.";
-   TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128)
-       << "gemm1_weights_scale has incorrect shape.";
+   if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+     TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
+         << "gemm1_weights_scale must be float.";
+     TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
+     TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
+         << "gemm1_weights_scale has incorrect shape.";
+     TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
+         << "intermediate_size must be a multiple of 128.";
+     TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * args->intermediate_size / 128)
+         << "gemm1_weights_scale has incorrect shape.";
+     TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128)
+         << "gemm1_weights_scale has incorrect shape.";
+   } else if (quantization_type == Fp8QuantizationType::MxFp8) {
+     TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_uint8)
+         << "gemm1_weights_scale must be uint8.";
+   }

-   TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
-       << "gemm2_weights_scale must be float.";
-   TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
-   TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
-       << "gemm2_weights_scale has incorrect shape.";
-   TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
-       << "gemm2_weights_scale has incorrect shape.";
-   TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
-       << "gemm2_weights_scale has incorrect shape.";
+   if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+     TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
+         << "gemm2_weights_scale must be float.";
+     TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
+     TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
+         << "gemm2_weights_scale has incorrect shape.";
+     TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
+         << "gemm2_weights_scale has incorrect shape.";
+     TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
+         << "gemm2_weights_scale has incorrect shape.";
+   } else if (quantization_type == Fp8QuantizationType::MxFp8) {
+     TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_uint8)
+         << "gemm2_weights_scale must be uint8.";
+   }

    check_weights_shape("gemm1");
    check_weights_shape("gemm2");
-   TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
-       << "intermediate_size must be a multiple of 128.";
+   if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+     TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
+         << "intermediate_size must be a multiple of 128.";
+   }
  }
Hi @IwakuraRein. Currently we use this in SGLang; however, it seems we are missing a cubin for some dims. I built from source on this branch at commit 1dc688d. Context: we are building the SGLang MXFP8 trtllm_moe runner along with the mm_mxfp8 FlashInfer ModelOpt linear, so this would be quite useful. If it turns out that my usage is wrong, that's user error, but even after inspecting the cubins it seems this shape should be available. Do you have any ideas? Should there be a tileSize=64 cubin?
@vincentzed Hi. There are tile size 64 cubins for mxfp8. I tried your problem shape and cannot reproduce the error. Could you try pulling the latest commit? 1dc688d won't compile due to a typo, so maybe FlashInfer is using the old JIT cache.
Force-pushed from 0adc056 to aae1719.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tests/moe/test_trtllm_gen_fused_moe.py (1)
2151-2164: ⚠️ Potential issue | 🟡 Minor: MxFP8 reference quantization likely uses the wrong swizzling flag.
quantize_inputs / run_moe_reference_mxfp8 use is_swizzling=False, but the FP8_BLOCK_SCALE_MXFP8 branch here forces True, which can skew the reference output. Consider splitting the cases. 🛠️ Suggested fix:
-        elif (
-            quant_mode == QuantMode.FP4_MXFP4_MXFP8
-            or quant_mode == QuantMode.FP8_BLOCK_SCALE_MXFP8
-        ):
-            activation_output, scale_bytes = mxfp8_quantize(
-                activation_output.to(torch.bfloat16), True
-            )
+        elif quant_mode == QuantMode.FP4_MXFP4_MXFP8:
+            activation_output, scale_bytes = mxfp8_quantize(
+                activation_output.to(torch.bfloat16), True
+            )
+        elif quant_mode == QuantMode.FP8_BLOCK_SCALE_MXFP8:
+            activation_output, scale_bytes = mxfp8_quantize(
+                activation_output.to(torch.bfloat16), False
+            )

flashinfer/fused_moe/core.py (2)
1608-1660: ⚠️ Potential issue | 🟡 Minor: Validate fp8_quantization_type for the block-scale op.
Passing NoneFp8 currently falls into the MxE4m3 dtype path and then the per-tensor branch, which lacks the required kwargs. A short guard makes failures explicit. 🛡️ Suggested guard:
-    dtype_act = (
+    if fp8_quantization_type not in (
+        Fp8QuantizationType.DeepSeekFp8,
+        Fp8QuantizationType.MxFp8,
+    ):
+        raise ValueError(
+            "fp8_quantization_type must be DeepSeekFp8 or MxFp8 for block-scale MoE."
+        )
+
+    dtype_act = (
         DtypeTrtllmGen.E4m3
         if fp8_quantization_type == Fp8QuantizationType.DeepSeekFp8
         else DtypeTrtllmGen.MxE4m3
     )  # FP8 activation
2356-2364: ⚠️ Potential issue | 🟡 Minor: Docstring scale shapes for MxFP8 appear transposed.
Runtime checks expect hidden_states_scale.shape[0] == num_tokens for MxFP8, but the docs say [hidden_size//…, seq_len]. Updating the docs will prevent API misuse. 📝 Suggested doc tweak:
- hidden_states_scale: [hidden_size//128, seq_len] tensor of hidden states block scales
+ hidden_states_scale:
+     - DeepSeekFp8: ignored (kernel generates [hidden_size//128, seq_len])
+     - MxFp8: [seq_len, hidden_size//32] tensor of block scales
@@
- hidden_states_scale: [hidden_size//(32 if mxfp8 else 128), seq_len] tensor of hidden states block scales
+ hidden_states_scale:
+     - DeepSeekFp8: ignored (kernel generates [hidden_size//128, seq_len])
+     - MxFp8: [seq_len, hidden_size//32] tensor of block scales

Also applies to: 2454-2459
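To make the corrected layout concrete, a tiny shape-only illustration (values are placeholder zeros; only the shapes matter, and they follow the doc tweak above):

```python
import torch

seq_len, hidden_size = 8, 1024

# fp8 activations plus one uint8 scale per 32-element block along hidden_size.
hidden_states = torch.zeros(seq_len, hidden_size, dtype=torch.uint8).view(torch.float8_e4m3fn)
hidden_states_scale = torch.zeros(seq_len, hidden_size // 32, dtype=torch.uint8)

# MxFp8 runtime check: scale dim0 equals num_tokens (seq_len), not hidden_size // 128.
assert hidden_states_scale.shape[0] == seq_len
```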
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
890-950: ⚠️ Potential issue | 🟡 Minor: MxFp8 path lacks shape validation for scale tensors.
For DeepSeekFp8, scale tensors get thorough ndim and dimension checks (e.g., lines 895-901, 914-922, 931-937). For MxFp8, only dtype is verified, with no ndim or size checks. Mis-shaped MxFp8 scale tensors would silently pass validation and could cause out-of-bounds reads in the kernel. Consider adding at least basic shape assertions (e.g., ndim, total element count) for the MxFp8 scale tensors.
Also, line 903 uses TVM_FFI_CHECK while the rest of the file consistently uses TVM_FFI_ICHECK, a minor inconsistency.
🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1751-1777: The checks around quantization_type (the if/else-if
chains that validate hidden_states_scale, gemm1_weights/gemm2_weights and
gemm1_weights_scale/gemm2_weights_scale, and the MxFp8-specific checks)
currently only handle DeepSeekFp8 and MxFp8 and silently skip validation for
other enum values; add an explicit else branch (or a guard before these blocks)
that rejects unsupported Fp8 variants (e.g., NoneFp8, PerTensorFp8) by
asserting/failing with a clear message referencing Fp8QuantizationType and
quantization_type so callers get an early, descriptive error instead of a
downstream kernel failure (apply this to the hidden_states_scale/type checks,
the gemm*_weights_scale checks, and the use_shuffled_weight/weight_layout
checks).
In `@flashinfer/fused_moe/__init__.py`:
- Around line 17-21: The module imports Fp8QuantizationType but does not include
it in the public export list; update the __all__ declaration(s) in
flashinfer.fused_moe.__init__ to include "Fp8QuantizationType" so that from
flashinfer.fused_moe import * exposes it (also apply the same addition to the
other __all__ block referenced around lines 54-74); keep existing names
(ActivationType, RoutingMethodType, WeightLayout, etc.) intact and add the
string "Fp8QuantizationType" to each __all__ tuple/list.
🧹 Nitpick comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
989-1007: MxFp8 scale pointers are uint8 data cast to float*, which is fragile type punning.
For the MxFp8 path, gemm1_output_scale, hidden_states_scale, gemm1_weights_scale, and gemm2_weights_scale are all uint8 tensors, but their data_ptr() is static_cast<float*> into struct fields typed as float*. This works today because the downstream runner passes them through as void* to kernels, but any future code that dereferences these as float will produce UB. Consider using a reinterpret_cast<float*> to signal intentional type punning, or (preferably) widening the struct fields to void* if multiple dtypes are expected.
Also, workspace.activation_output and workspace.activation_output_scale are never set on the MxFp8 path. They happen to go unused, but explicitly setting them to nullptr would be safer.
Suggested nullptr initialization for MxFp8:
   if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
     workspace.activation_output = activation_output.data_ptr();
     workspace.activation_output_scale = static_cast<float*>(activation_output_scale.data_ptr());
+  } else {
+    workspace.activation_output = nullptr;
+    workspace.activation_output_scale = nullptr;
   }

csrc/trtllm_batched_gemm_runner.cu (1)
177-257: Value-init and mValid assignments look correct.
Line 182: consider removing the commented-out printf before merging; it's a debug artifact.
I will help review this.
(quoted diff context)

# The type of gated activation function
# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h
class GatedActType(IntEnum):
we changed the activation type to include non-gated ones and have a function like isGatedActivation to check them
Thanks. However, the only API change in this PR is adding an fp8_quantization_type argument to trtllm_fp8_block_scale_routed_moe. I believe it won't conflict with other PRs introducing non-gated activations.
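Purely for illustration of the point above, and not code from this PR or FlashInfer: a Python-flavoured sketch of an isGatedActivation-style helper over an activation enum that mixes gated and non-gated entries. The member names are hypothetical.

```python
from enum import IntEnum


class ActivationType(IntEnum):
    # Hypothetical members; the real enums live in flashinfer/fused_moe and
    # include/flashinfer/trtllm/fused_moe/runner.h.
    Relu2 = 0
    Gelu = 1
    SwiGlu = 2
    GeGlu = 3


def is_gated_activation(act: ActivationType) -> bool:
    """Mirrors the isGatedActivation check mentioned in the review above."""
    return act in (ActivationType.SwiGlu, ActivationType.GeGlu)


assert is_gated_activation(ActivationType.SwiGlu)
assert not is_gated_activation(ActivationType.Relu2)
```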
aleozlx left a comment:
looks good overall.
posted a comment about GatedActType
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1073-1099:⚠️ Potential issue | 🟠 Major
getValidConfigs uses a different Runner constructor than prepare_moe_common for MxFp8, causing config mismatch.
For MxFp8, getValidConfigs creates the runner using the 5-param weights-only constructor (lines 1085–1088):
Runner(dtype_weights, /*useDeepSeekFp8=*/false, tile_N, use_shuffled_weight, weight_layout)
But at runtime, prepare_moe_common (lines 329–331) uses the 7-param act+weights constructor because the condition at line 323 checks for E4m3 (not MxE4m3), which is false for MxFp8:
Runner(mDtypeAct, mDtypeWeights, /*useDeepSeekFp8=*/false, tile_N, activation_type, ...)
These constructors have different signatures and parameters (the 5-param variant lacks activationType), so they may enumerate different kernel configs. This causes valid configs from autotuning to be rejected at runtime, potentially explaining the "No kernel found" errors for MxFp8 shapes.
🧹 Nitpick comments (2)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
89: Nit: scale_vec_size is unused in the MxFP8 path.
When quant_mode == "MxFP8xMxFP8", scale_vec_size is assigned 32 on this line but never referenced (it's only consumed inside the Fp8-Block branch). Consider moving the assignment into the if quant_mode == "Fp8-Block" block. ♻️ Suggested diff:
-    scale_vec_size = 128 if quant_mode == "Fp8-Block" else 32
     if quant_mode == "Fp8-Block":
+        scale_vec_size = 128
         # block scale quantization is too slow, so we use per-tensor quantization for now

csrc/trtllm_fused_moe_kernel_launcher.cu (1)
44-63: C++ enum has PerTensorFp8, which is not present in the Python Fp8QuantizationType.
The Python enum in flashinfer/fused_moe/core.py defines NoneFp8=0, DeepSeekFp8=1, MxFp8=2, but the C++ side adds PerTensorFp8=3. If this variant isn't meant to be used from Python, consider adding a comment. Also, the default label in fp8QuantizationTypeToString falls through to NoneFp8, which silently masks unexpected values rather than flagging them.
Suggested: make the default case explicit

  switch (quantization_type) {
-   default:
-   case Fp8QuantizationType::NoneFp8:
+   case Fp8QuantizationType::NoneFp8:
      return "NoneFp8";
    case Fp8QuantizationType::DeepSeekFp8:
      return "DeepSeekFp8";
    case Fp8QuantizationType::MxFp8:
      return "MxFp8";
    case Fp8QuantizationType::PerTensorFp8:
      return "PerTensorFp8";
+   default:
+     return "Unknown(" + std::to_string(static_cast<int>(quantization_type)) + ")";
  }
/bot run
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
1090-1101:⚠️ Potential issue | 🔴 Critical
getValidConfigs uses the weights-only Runner constructor, but the MxFp8 runtime uses the two-dtype constructor — config index mismatch.
For MxFp8 (where dtype_act == MxE4m3 and dtype_weights == MxE4m3), getValidConfigs at line 1091 creates the Runner with 5 parameters: Runner(dtype_weights, useDeepSeekFp8, tile_N, use_shuffled_weight, weight_layout). However, in prepare_moe_common (lines 333–335), the same MxFp8 scenario matches the else branch (the condition at line 327 checks for E4m3, not MxE4m3), causing it to call a different 7-parameter constructor: Runner(dtype_act, dtype_weights, useDeepSeekFp8, tile_tokens_dim, activation_type, use_shuffled_weight, weight_layout). Different constructors produce different valid config indices, so the autotuner may select a config that the runtime runner rejects, causing "No kernel found" errors.
1020-1022: ⚠️ Potential issue | 🟡 Minor: Remove unnecessary static_cast<float*> on lines 1020–1022.
The args->hidden_states_scale, args->gemm1_weights_scale, and args->gemm2_weights_scale fields in MoERunnerArgs are typed as void*, not float*. In the MxFp8 case, these hold dl_uint8 tensor pointers, so casting to float* is both unnecessary and misleading. Other code paths (e.g., lines 1180, 1189, 1419, 1430) assign these same fields without casting. Remove the casts and assign data_ptr() directly.
🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 179-180: In check_routing_logits_shape(), remove the unused local
declaration "int64_t intermediate_size_factor =
isGatedActivation(activation_type) ? 2 : 1;" that shadows the class member
intermediate_size_factor (defined on the class) or replace its usage to
reference the member instead; ensure the function uses the class member
intermediate_size_factor (or a properly named local if truly needed) so the
dead/shadowing local is eliminated.
- Around line 987-991: The MxFp8 branch under-allocates gemm1_output_scale by
using args->intermediate_size/32 instead of accounting for
intermediate_size_factor (causing under-allocation for gated activations);
update the computeSwizzledLayoutSFSize call in the Fp8QuantizationType::MxFp8
branch to use (intermediate_size_factor * args->intermediate_size) / 32 (i.e.
pass the full swizzled width consistent with the gemm1_output allocation) so
gemm1_output_scale and alloc_tensor({sf_size}, ...) match the actual
gemm1_output width; references: gemm1_output_scale, computeSwizzledLayoutSFSize,
max_num_padded_tokens_gemm1, args->intermediate_size, intermediate_size_factor,
Fp8QuantizationType::MxFp8.
/bot run
[CANCELING] Pipeline #43998281: canceled
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1079-1105:⚠️ Potential issue | 🔴 Critical
getValidConfigs uses the wrong Runner constructor for MxFp8, causing a config mismatch with the runtime.
For MxFp8, prepare_moe_common (lines 326–335) constructs the Runner with the two-dtype constructor (passing mDtypeAct, mDtypeWeights, activation_type) when the condition E4m3 && E4m3 && mUseDeepSeekFp8 is false. However, getValidConfigs always uses the weights-only constructor (lines 1091–1094), regardless of quantization_type. This means config enumeration and the actual kernel runner see different valid config sets — the root cause of "No kernel found" errors at runtime.
Proposed fix: branch getValidConfigs to match prepare_moe_common logic
  for (int32_t tile_N : selected_tile_nums) {
-   auto moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
-       dtype_weights,  // dtype_weights for DeepSeek FP8
-       quantization_type == Fp8QuantizationType::DeepSeekFp8,  // useDeepSeekFp8
-       tile_N, use_shuffled_weight, static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+   std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner;
+   if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+     moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+         dtype_weights, true /* useDeepSeekFp8 */, tile_N, use_shuffled_weight,
+         static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+   } else {
+     // MxFp8: match two-dtype constructor from prepare_moe_common
+     moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+         dtype_weights, dtype_weights, false /* useDeepSeekFp8 */, tile_N,
+         ActivationType::Swiglu, use_shuffled_weight,
+         static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+   }
    auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size,
                                                  num_local_experts, num_tokens);
🧹 Nitpick comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
1004-1012: MxFp8 path does not explicitly set workspace.activation_output / workspace.activation_output_scale.
Only the DeepSeekFp8 branch (lines 1007–1010) assigns these workspace pointers. The MxFp8 path relies on implicit zero-initialization. Consider explicitly setting them to nullptr to be safe against future refactors where prepare_moe might be re-entered or the workspace partially reused.
Proposed fix:
   if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
     workspace.activation_output = activation_output.data_ptr();
     workspace.activation_output_scale = static_cast<float*>(activation_output_scale.data_ptr());
+  } else {
+    workspace.activation_output = nullptr;
+    workspace.activation_output_scale = nullptr;
   }
1006: static_cast<float*> on a dl_uint8 tensor for MxFp8 — type mismatch in the workspace pointer.
For MxFp8, gemm1_output_scale is allocated as dl_uint8 (line 990), but line 1006 unconditionally casts it to float*. The kernel likely consumes the raw address, but this cast is misleading and could mask bugs if the workspace struct gains type safety. Consider a void* intermediate or a comment noting the intentional reinterpretation.
Force-pushed from 3e0dbdd to 03cac02.
/bot run
[FAILED] Pipeline #44028049: 14/20 passed
📌 Description
Author: @nekorobov
Add the trtllm-gen MxFP8 MoE. It uses the existing trtllm_fp8_block_scale_moe API and can be selected by setting fp8_quantization_type.
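A rough usage sketch, illustrative only: every argument other than the new one is elided with `...`, and the exact keyword name and defaults should be taken from the API added in this PR rather than from this snippet.

```python
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
from flashinfer.fused_moe.core import Fp8QuantizationType  # package-level export pending, per review

# Same call as before; MxFP8 is opted into via the new enum argument.
output = trtllm_fp8_block_scale_moe(
    ...,  # routing logits, hidden states, weights, scales, and MoE config as usual
    fp8_quantization_type=Fp8QuantizationType.MxFp8,
)
```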
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests are passing (unittest, etc.).
Reviewer Notes