
Feat: Trtllm-gen MxFP8 MoE integration #2505

Open
IwakuraRein wants to merge 26 commits into flashinfer-ai:main from
IwakuraRein:siyuanf/mxfp8-trtllm-integration

Conversation

@IwakuraRein (Collaborator) commented Feb 6, 2026

📌 Description

Author: @nekorobov

Adds the trtllm-gen MxFP8 MoE. It reuses the existing trtllm_fp8_block_scale_moe API and is selected by setting fp8_quantization_type.
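For illustration, a rough sketch of how the new path might be selected from Python. The argument names follow the debug log further down in this thread, the tensor variables are assumed to already hold MxFP8-quantized inputs, and the import path and exact signature in core.py may differ:

    from flashinfer.fused_moe import Fp8QuantizationType, trtllm_fp8_block_scale_moe

    # Same entry point as DeepSeek FP8; only the quantization-type enum changes.
    # For MxFP8 the activations are float8_e4m3fn and hidden_states_scale holds
    # uint8 (e8m0) block scales over groups of 32 elements.
    output = trtllm_fp8_block_scale_moe(
        routing_logits=routing_logits,
        routing_bias=None,
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=gemm1_weights,
        gemm1_weights_scale=gemm1_weights_scale,
        gemm2_weights=gemm2_weights,
        gemm2_weights_scale=gemm2_weights_scale,
        num_experts=num_experts,
        top_k=top_k,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=1.0,
        routing_method_type=1,
        use_shuffled_weight=True,
        fp8_quantization_type=Fp8QuantizationType.MxFp8,
    )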

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added MxFP8 as an FP8 quantization option, exposed via CLI and public API.
  • Refactor

    • Replaced boolean FP8 flag with an explicit quantization-type enum propagated across MoE interfaces and launchers.
  • Bug Fixes

    • Added stricter config validation and bounds checks with clearer error messages.
  • Tests

    • Extended tests to cover DeepSeek and MxFP8 variants and new configuration constraints.
  • Chores

    • Updated runtime artifact checksums/paths.

nekorobov and others added 7 commits February 5, 2026 03:37
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@coderabbitai bot (Contributor) commented Feb 6, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds an FP8 quantization enum and MxFP8 support across Python, C++ launchers, benchmarks, and tests; threads a new fp8_quantization_type through MoE entry points, config generation, kernel launcher wiring, and autotuner dispatch, plus artifact checksum updates and diagnostic logging.

Changes

Cohort / File(s) Summary
Benchmark & CLI
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
Adds MxFP8xMxFP8 quant_mode, routes FP8 flows to distinct DeepSeek/MxFp8 paths, passes fp8_quantization_type to autotuner, prints scale-shape diagnostics.
Core Python API
flashinfer/fused_moe/core.py, flashinfer/fused_moe/__init__.py
Adds GatedActType and Fp8QuantizationType enums; replaces use_deepseek_fp8 boolean with fp8_quantization_type across MoE APIs; branches dtype/scale sourcing and public signatures for DeepSeek vs MxFp8.
CUDA Kernel Launcher & Public API
csrc/trtllm_fused_moe_kernel_launcher.cu, flashinfer/fused_moe/__init__.py
Introduces Fp8QuantizationType and string helper; threads quantization type through launcher constructors, getValidConfigs, tile selection, allocations, validation messages, and run paths to support DeepSeek/MxFp8/per-tensor variants.
CUDA Runners / GEMM Init
csrc/trtllm_batched_gemm_runner.cu, csrc/trtllm_fused_moe_runner.cu
Value-initializes BatchedGemmData, sets valid M/N/K fields, and adds bounds checks for configIndex in runner APIs.
Python package artifacts
flashinfer/artifacts.py
Updates TRTLLM_GEN_BMM artifact path and checksum constants.
Tests & Utilities
tests/moe/test_trtllm_gen_fused_moe.py, tests/moe/test_dpsk_fused_moe_fp8.py, tests/moe/utils.py
Adds MXFp8 test branches, MXFp8 quantize/dequantize helpers and MXFp8 reference runner; expands QuantMode with FP8_BLOCK_SCALE_MXFP8/FP8_PER_TENSOR; updates skip_checks to require shuffle/layout for MxFp8 and maps tests to new Fp8QuantizationType.

Sequence Diagram(s)

mermaid
sequenceDiagram
rect rgba(200,200,255,0.5)
participant CLI as CLI (bench/autotuner)
end
rect rgba(200,255,200,0.5)
participant Autotuner as Autotuner/Python
participant Core as fused_moe Core
end
rect rgba(255,200,200,0.5)
participant Launcher as C++ Launcher
participant GPU as GPU Kernel
end

CLI->>Autotuner: parse --quant-mode (e.g., MxFP8xMxFP8)
Autotuner->>Core: call autotune/run with fp8_quantization_type
Core->>Launcher: request valid configs / instantiate with quantization_type
Launcher->>GPU: launch kernels with quantization-aware buffers/scales
GPU-->>Launcher: return profiling/results
Launcher-->>Core: pass profiling/results
Core-->>Autotuner: autotuner records best config

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • djmmoss
  • cyx-6
  • bkryu
  • nvmbreughe
  • yzh119
  • aleozlx
  • joker-eph

Poem

🐇 I hopped through enums and tiny byte scales,
Routed MxFP8 flows across kernels and trails,
Autotuners hum, launchers line up the crew,
Scales reshape, tests follow — a quantized view,
🥕 Rabbit cheers: kernels tuned and passing too!

🚥 Pre-merge checks: ✅ 2 passed, ❌ 2 failed

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 33.93%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check (❓ Inconclusive): The description mentions the main feature but lacks details on the motivation and related issues, and the checklist items are unchecked; it provides minimal context beyond the feature name. Resolution: expand the description with the rationale for the feature, mention related issues if any, and either complete or remove unchecked checklist items to clarify the PR's readiness.

✅ Passed checks (2 passed)
  • Title check (✅ Passed): The title 'Feat: Trtllm-gen MxFP8 MoE integration' clearly summarizes the main change: adding MxFP8 quantization support to the TrtLLM-gen MoE implementation.
  • Merge Conflict Detection (✅ Passed): No merge conflicts detected when merging into main.


@gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello @IwakuraRein, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the TensorRT-LLM fused Mixture-of-Experts (MoE) implementation by integrating MxFP8 quantization. This integration provides a new, flexible FP8 quantization option alongside the existing DeepSeek FP8, allowing for fine-grained control over mixed-precision computations. The changes span core kernel logic, benchmarking, and testing, ensuring that the new quantization mode is robustly supported and validated across the system.

Highlights

  • MxFP8 Quantization Integration: Introduced support for MxFP8 (microscaling FP8) quantization within the TensorRT-LLM fused Mixture-of-Experts (MoE) kernels, allowing for more flexible and potentially optimized FP8 operations.
  • Fp8QuantizationType Enum: A new Fp8QuantizationType enum was added to differentiate between various FP8 quantization schemes, including DeepSeek FP8 and the newly integrated MxFp8, enabling explicit control over the quantization method used.
  • Benchmarking and Testing Expansion: The benchmarking suite (bench_trtllm_gen_fused_moe_autotuner.py) and unit tests (test_trtllm_gen_fused_moe.py) were extended to cover the new MxFP8 quantization mode, ensuring its correctness and performance characteristics are validated.
  • Kernel Configuration Adjustments: Modifications were made to the C++ kernel launcher (trtllm_fused_moe_kernel_launcher.cu) and batched GEMM runner (trtllm_batched_gemm_runner.cu) to correctly handle the specific requirements and configurations of MxFP8, including dtype handling and skipping incompatible configurations.
  • Artifact Updates: The TRTLLM_GEN_BMM artifact path and checksum in flashinfer/artifacts.py were updated, indicating changes to the pre-compiled batched GEMM kernels.


Changelog
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
    • Imported partial from functools and Fp8QuantizationType from flashinfer.fused_moe.
    • Added a new mxint4_quantize function for MxInt4 (microscaling int4) quantization.
    • Extended the quant_mode literal type to include "MxFP8xMxFP8" and "MxInt4xBf16".
    • Modified the quantization logic to handle "MxFP8xMxFP8" mode, utilizing mxfp8_quantize and reshaping scales accordingly.
    • Updated bench_gpu_time calls to include enable_cupti, use_cuda_graph, input_kwargs, and cold_l2_cache for more comprehensive benchmarking.
    • Introduced bench_trtllm_gen_fused_moe_autotuner_mxint4 for benchmarking MxInt4 quantization.
    • Refactored the main execution block to dynamically select the appropriate benchmark function based on args.quant_mode.
  • csrc/trtllm_batched_gemm_runner.cu
    • Added a new condition to skip specific configurations for MxE4m3 dtypes when mNumSlicesForSplitK > 1 to prevent unsupported operations.
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Defined a new Fp8QuantizationType enum to explicitly manage different FP8 quantization types.
    • Modified the Fp8BlockScaleLauncher constructor and init method to accept and utilize the new Fp8QuantizationType.
    • Adjusted the mDtypeAct and mDtypeWeights assignments within Fp8BlockScaleLauncher::init based on the quantization_type.
    • Updated the mUseDeepSeekFp8 flag logic to be conditional on quantization_type == Fp8QuantizationType::DeepSeekFp8.
    • Enhanced Fp8BlockScaleLauncher::check_moe_common to perform dtype and dimension checks for scale tensors specific to DeepSeekFp8 and MxFp8.
    • Modified the allocation of gemm1_output_scale and activation_output_scale to adapt to the chosen quantization_type.
    • Added Fp8QuantizationType as a parameter to Fp8BlockScaleLauncher::getValidConfigs.
    • Included Fp8QuantizationType as a parameter in the trtllm_fp8_block_scale_moe function signature.
    • Updated trtllm_get_valid_moe_configs to use Fp8QuantizationType instead of a boolean useDeepSeekFp8 for more granular control.
  • flashinfer/artifacts.py
    • Updated the TRTLLM_GEN_BMM artifact path and its corresponding checksum to reflect new pre-compiled binaries.
  • flashinfer/fused_moe/__init__.py
    • Imported the newly defined Fp8QuantizationType enum.
  • flashinfer/fused_moe/core.py
    • Introduced GatedActType and Fp8QuantizationType enums for better type management.
    • Refactored TrtllmGenFusedMoE class to use fp8_quantization_type instead of the boolean use_deepseek_fp8.
    • Modified the forward method to correctly handle MxFp8 quantization for hidden_states_scale based on the fp8_quantization_type.
    • Ensured fp8_quantization_type is passed to all relevant moe_op.trtllm_fp8_block_scale_moe calls.
    • Updated the signatures of trtllm_bf16_moe_op, trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp8_block_scale_moe_op, trtllm_fp4_block_scale_moe_op, trtllm_mxint4_block_scale_moe_op, trtllm_fp8_block_scale_moe, and trtllm_fp8_block_scale_routed_moe to include the fp8_quantization_type parameter.
  • tests/moe/test_trtllm_gen_fused_moe.py
    • Imported Fp8QuantizationType for use in tests.
    • Modified the FP8BlockScaleMoe class to accept fp8_quantization_type in its constructor and use it for the quant_mode property.
    • Updated quantize_weights in FP8BlockScaleMoe to use mxfp8_quantize_batches for MxFp8 mode.
    • Adjusted quantize_inputs in FP8BlockScaleMoe to use mxfp8_quantize for MxFp8 mode, including handling of swizzling.
    • Modified prepare_static_weights_for_kernel to dynamically set epilogue_tile_m and reorder weights for MxFp8 based on the quantization type.
    • Ensured quantization_mode is passed to the trtllm_fp8_block_scale_moe function call.
    • Updated compute_reference in FP8BlockScaleMoe to call run_moe_reference_mxfp8 for MxFp8 mode.
    • Added new helper functions mxfp8_quantize_batches, mxfp8_dequantize_batches, and run_moe_reference_mxfp8 for MxFp8 testing (a rough reference sketch of per-block MxFP8 quantization follows this changelog).
    • Modified run_moe_dequant to correctly handle FP8_BLOCK_SCALE_MXFP8.
    • Updated pytest.param definitions to explicitly differentiate between FP8_Block_DeepSeek and FP8_Block_MxFp8 test cases.
  • tests/moe/utils.py
    • Imported WeightLayout from flashinfer.fused_moe.
    • Renamed QuantMode.FP8_BLOCK_SCALE to QuantMode.FP8_BLOCK_SCALE_DEEPSEEK and added QuantMode.FP8_BLOCK_SCALE_MXFP8.
    • Added specific skip checks for MxFp8 quantization, enforcing that use_shuffled_weight must be true and weight_layout must be MajorK.
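For reviewers less familiar with the format, below is a rough, self-contained reference of per-block MxFP8 quantization with uint8 (e8m0) scales over 32-element groups, in the spirit of the new test helpers listed above. The function names are hypothetical and the exponent-selection rule is one reasonable choice; it is not claimed to match the repo's mxfp8_quantize bit-for-bit.

    import torch

    FP8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn

    def mxfp8_quantize_ref(x: torch.Tensor, block: int = 32):
        """Quantize the last dim of x in groups of `block`; returns fp8 data plus
        uint8 (biased e8m0) per-block scales of shape [..., last_dim // block]."""
        xb = x.float().reshape(*x.shape[:-1], -1, block)
        amax = xb.abs().amax(dim=-1, keepdim=True).clamp_min(1e-30)
        # Pick a power-of-two scale so every scaled block fits into e4m3's range.
        exp = torch.ceil(torch.log2(amax / FP8_E4M3_MAX)).clamp(-127, 127)
        q = (xb / torch.exp2(exp)).to(torch.float8_e4m3fn).reshape(x.shape)
        scale_u8 = (exp.squeeze(-1) + 127).to(torch.uint8)  # biased e8m0 encoding
        return q, scale_u8

    def mxfp8_dequantize_ref(q: torch.Tensor, scale_u8: torch.Tensor, block: int = 32):
        qb = q.float().reshape(*q.shape[:-1], -1, block)
        scale = torch.exp2(scale_u8.float() - 127.0).unsqueeze(-1)
        return (qb * scale).reshape(q.shape)

A [num_tokens, hidden_size] activation tensor then yields scales of shape [num_tokens, hidden_size // 32], matching the uint8 scale shapes seen elsewhere in this PR.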
Activity
  • The pull request introduces MxFP8 (microscaling FP8) quantization support for TensorRT-LLM fused MoE kernels.
  • It involves significant changes across C++ kernel implementations, Python bindings, benchmarking scripts, and unit tests.
  • The author, IwakuraRein, has provided a template for the PR description, but the specific details for this PR are left empty, indicating that the changes are primarily code-driven and self-explanatory through the diffs.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request integrates mxfp8 support into the trtllm fused MoE kernels. The changes are extensive, touching benchmark scripts, C++ kernel launchers, and Python bindings. The introduction of Fp8QuantizationType is a good refactoring that makes the code more extensible. The tests have also been updated to cover the new quantization modes.

My review focuses on improving code maintainability by reducing duplication in the benchmark scripts and C++ kernel launcher. I've also pointed out some leftover debugging code and minor issues that should be addressed before merging.

print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms")


def bench_trtllm_gen_fused_moe_autotuner_mxint4(
gemini-code-assist bot (Contributor) commented:

medium

This function bench_trtllm_gen_fused_moe_autotuner_mxint4 is very similar to bench_trtllm_gen_fused_moe_autotuner_fp8 and bench_trtllm_gen_fused_moe_autotuner_fp4. To improve maintainability and reduce code duplication, consider refactoring these into a more generic benchmark function or a base class. This could accept quantization functions and the specific MoE kernel as parameters, centralizing the common benchmarking logic.
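To make the suggestion concrete, here is a shape-only sketch of such a refactor. The parameter names (make_inputs, quantize_fn, run_moe_fn, time_fn) are hypothetical stand-ins for the mode-specific pieces, and the timing interface is simplified; this is not the benchmark's actual structure.

    from functools import partial

    def bench_fused_moe_autotuner_generic(make_inputs, quantize_fn, run_moe_fn, time_fn):
        """Generic skeleton: build inputs once, quantize them with the mode-specific
        function, then time the mode-specific MoE kernel with and without autotuning."""
        quantized = quantize_fn(make_inputs())
        run = partial(run_moe_fn, **quantized)
        ms = time_fn(run)                       # e.g. a bench_gpu_time-style helper
        ms_tuned = time_fn(run, autotune=True)  # hypothetical flag for the tuned pass
        print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms")

    # The existing fp8/fp4/mxint4 variants would then reduce to passing different
    # (make_inputs, quantize_fn, run_moe_fn) triples into this one skeleton.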

Comment on lines 855 to 912
FusedMoeLauncher::check_moe_common();

TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8.";
TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
<< "hidden_states_scale must be float.";
TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D.";
TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128)
<< "hidden_states_scale dim0 must match hidden_states dim1 / 128.";
TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens)
<< "hidden_states_scale dim1 must match num_tokens.";
if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
<< "hidden_states_scale must be float.";
TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D.";
TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128)
<< "hidden_states_scale dim0 must match hidden_states dim1 / 128.";
TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens)
<< "hidden_states_scale dim1 must match num_tokens.";
} else if (quantization_type == Fp8QuantizationType::MxFp8) {
TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_uint8);
}

TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8.";
TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8.";

TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
<< "gemm1_weights_scale must be float.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
<< "gemm1_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
<< "intermediate_size must be a multiple of 128.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * args->intermediate_size / 128)
<< "gemm1_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128)
<< "gemm1_weights_scale has incorrect shape.";
if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
<< "gemm1_weights_scale must be float.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
<< "gemm1_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
<< "intermediate_size must be a multiple of 128.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * args->intermediate_size / 128)
<< "gemm1_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128)
<< "gemm1_weights_scale has incorrect shape.";
} else if (quantization_type == Fp8QuantizationType::MxFp8) {
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_uint8)
<< "gemm1_weights_scale must be uint8.";
}

TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
<< "gemm2_weights_scale must be float.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
<< "gemm2_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
<< "gemm2_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
<< "gemm2_weights_scale has incorrect shape.";
if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
<< "gemm2_weights_scale must be float.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
<< "gemm2_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
<< "gemm2_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
<< "gemm2_weights_scale has incorrect shape.";
} else if (quantization_type == Fp8QuantizationType::MxFp8) {
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_uint8)
<< "gemm2_weights_scale must be uint8.";
}

check_weights_shape("gemm1");
check_weights_shape("gemm2");
TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
<< "intermediate_size must be a multiple of 128.";

if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
<< "intermediate_size must be a multiple of 128.";
}
}
gemini-code-assist bot (Contributor) commented:

medium

The validation logic within this function, particularly for gemm1_weights_scale and gemm2_weights_scale under the Fp8QuantizationType::DeepSeekFp8 condition, is quite repetitive. To improve maintainability and reduce code duplication, consider extracting the common checks into a helper function.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@vincentzed (Contributor) commented Feb 8, 2026

Hi @IwakuraRein. We currently use this in sglang.

However, it seems like we are missing a cubin for some dims.

I built from source from this branch at commit 1dc688d.

[DEBUG] TRTLLM-Gen launch info: numCtasX = 1, numCtasY = 4, numCtasZ = 4096, clusterDimX = 1
[2026-02-08 03:07:34] trtllm_fp8_block_scale_moe call:
  fp8_quantization_type=2
  routing_logits: shape=torch.Size([4096, 128]) dtype=torch.bfloat16
  routing_bias: None
  hidden_states (a_q): shape=torch.Size([4096, 2048]) dtype=torch.float8_e4m3fn
  hidden_states_scale (a_sf): shape=torch.Size([4096, 64]) dtype=torch.uint8
  gemm1_weights (w13): shape=torch.Size([128, 1536, 2048]) dtype=torch.float8_e4m3fn
  gemm1_weights_scale (w13_scale_inv): shape=torch.Size([128, 98304]) dtype=torch.uint8
  gemm2_weights (w2): shape=torch.Size([128, 2048, 768]) dtype=torch.float8_e4m3fn
  gemm2_weights_scale (w2_scale_inv): shape=torch.Size([128, 49152]) dtype=torch.uint8
  num_experts=128 top_k=8 n_group=0 topk_group=0
  intermediate_size=768 local_expert_offset=0 local_num_experts=128
  routed_scaling_factor=1.0 routing_method_type=1
  use_shuffled_weight=True tune_max_num_tokens=4096
 File "/sgl-workspace/flashinfer/flashinfer/fused_moe/core.py", line 2373, in trtllm_fp8_block_scale_moe
    return get_trtllm_moe_sm100_module().trtllm_fp8_block_scale_moe(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/flashinfer/flashinfer/fused_moe/core.py", line 1694, in trtllm_fp8_block_scale_moe_op
    result = moe_op.trtllm_fp8_block_scale_moe(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
RuntimeError: Error in function 'TrtllmGenBatchedGemmRunner' at /sgl-workspace/flashinfer/csrc/trtllm_batched_gemm_runner.cu:138: No kernel found for the given options: mDtypeA: MxE4m3, mDtypeB: MxE4m3, mDtypeC: Bfloat16, mUseDeepSeekFp8: 0, mActType: 0, mEltwiseActType: 0, mTransposeMmaOutput: 1, mRouteAct: 1, mFusedAct: 1, mIsStaticBatch: 0, mTileSize: 64

Context: we are building the sglang MXFP8 trtllm_moe runner along with the mm_mxfp8 flashinfer modelopt linear, so this would be quite useful.

If it turns out that my usage is wrong, that's user error, but even after inspecting the cubins it seems like this shape should be available. Do you have any ideas?

Should there be a tileSize=64 cubin?

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@IwakuraRein (Collaborator, Author) commented:

@vincentzed Hi. There are tile-size-64 cubins for MxFP8. I tried your problem shape and could not reproduce the error. Could you try pulling the latest commit? 1dc688d won't compile due to a typo, so maybe flashinfer is using the old JIT cache.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
IwakuraRein force-pushed the siyuanf/mxfp8-trtllm-integration branch from 0adc056 to aae1719 on February 9, 2026 21:13
IwakuraRein changed the title from "mxfp8 trtllm integration" to "Feat: Trtllm-gen MxFP8 MoE integration" on Feb 12, 2026
IwakuraRein marked this pull request as ready for review on February 12, 2026 17:45
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tests/moe/test_trtllm_gen_fused_moe.py (1)

2151-2164: ⚠️ Potential issue | 🟡 Minor

MxFP8 reference quantization likely uses the wrong swizzling flag.
quantize_inputs/run_moe_reference_mxfp8 use is_swizzling=False, but the FP8_BLOCK_SCALE_MXFP8 branch here forces True, which can skew the reference output. Consider splitting the cases.

🛠️ Suggested fix
-    elif (
-        quant_mode == QuantMode.FP4_MXFP4_MXFP8
-        or quant_mode == QuantMode.FP8_BLOCK_SCALE_MXFP8
-    ):
-        activation_output, scale_bytes = mxfp8_quantize(
-            activation_output.to(torch.bfloat16), True
-        )
+    elif quant_mode == QuantMode.FP4_MXFP4_MXFP8:
+        activation_output, scale_bytes = mxfp8_quantize(
+            activation_output.to(torch.bfloat16), True
+        )
+    elif quant_mode == QuantMode.FP8_BLOCK_SCALE_MXFP8:
+        activation_output, scale_bytes = mxfp8_quantize(
+            activation_output.to(torch.bfloat16), False
+        )
flashinfer/fused_moe/core.py (2)

1608-1660: ⚠️ Potential issue | 🟡 Minor

Validate fp8_quantization_type for the block-scale op.
Passing NoneFp8 currently falls into the MxE4m3 dtype path and then the per‑tensor branch, which lacks the required kwargs. A short guard makes failures explicit.

🛡️ Suggested guard
-        dtype_act = (
+        if fp8_quantization_type not in (
+            Fp8QuantizationType.DeepSeekFp8,
+            Fp8QuantizationType.MxFp8,
+        ):
+            raise ValueError(
+                "fp8_quantization_type must be DeepSeekFp8 or MxFp8 for block-scale MoE."
+            )
+
+        dtype_act = (
             DtypeTrtllmGen.E4m3
             if fp8_quantization_type == Fp8QuantizationType.DeepSeekFp8
             else DtypeTrtllmGen.MxE4m3
         )  # FP8 activation

2356-2364: ⚠️ Potential issue | 🟡 Minor

Docstring scale shapes for MxFP8 appear transposed.
Runtime checks expect hidden_states_scale.shape[0] == num_tokens for MxFP8, but the docs say [hidden_size//…, seq_len]. Updating the docs will prevent API misuse.

📝 Suggested doc tweak
-        hidden_states_scale: [hidden_size//128, seq_len] tensor of hidden states block scales
+        hidden_states_scale:
+            - DeepSeekFp8: ignored (kernel generates [hidden_size//128, seq_len])
+            - MxFp8: [seq_len, hidden_size//32] tensor of block scales
@@
-        hidden_states_scale: [hidden_size//(32 if mxfp8 else 128), seq_len] tensor of hidden states block scales
+        hidden_states_scale:
+            - DeepSeekFp8: ignored (kernel generates [hidden_size//128, seq_len])
+            - MxFp8: [seq_len, hidden_size//32] tensor of block scales

Also applies to: 2454-2459
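To make the two layouts concrete, here is what the runtime checks described above expect, using the shapes from the debug log later in this thread (seq_len=4096, hidden_size=2048); this is an illustration, not the documented contract.

    seq_len, hidden_size = 4096, 2048

    # DeepSeekFp8: float32 scales laid out as [hidden_size // 128, seq_len]
    deepseek_scale_shape = (hidden_size // 128, seq_len)  # -> (16, 4096)

    # MxFp8: uint8 (e8m0) scales laid out as [seq_len, hidden_size // 32]
    mxfp8_scale_shape = (seq_len, hidden_size // 32)      # -> (4096, 64)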

csrc/trtllm_fused_moe_kernel_launcher.cu (1)

890-950: ⚠️ Potential issue | 🟡 Minor

MxFp8 path lacks shape validation for scale tensors.

For DeepSeekFp8, scale tensors get thorough ndim and dimension checks (e.g., lines 895-901, 914-922, 931-937). For MxFp8, only dtype is verified — no ndim or size checks. Mis-shaped MxFp8 scale tensors would silently pass validation and could cause out-of-bounds reads in the kernel.

Consider adding at least basic shape assertions (e.g., ndim, total element count) for the MxFp8 scale tensors.

Also, line 903 uses TVM_FFI_CHECK while the rest of the file consistently uses TVM_FFI_ICHECK — minor inconsistency.

🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1751-1777: The checks around quantization_type (the if/else-if
chains that validate hidden_states_scale, gemm1_weights/gemm2_weights and
gemm1_weights_scale/gemm2_weights_scale, and the MxFp8-specific checks)
currently only handle DeepSeekFp8 and MxFp8 and silently skip validation for
other enum values; add an explicit else branch (or a guard before these blocks)
that rejects unsupported Fp8 variants (e.g., NoneFp8, PerTensorFp8) by
asserting/failing with a clear message referencing Fp8QuantizationType and
quantization_type so callers get an early, descriptive error instead of a
downstream kernel failure (apply this to the hidden_states_scale/type checks,
the gemm*_weights_scale checks, and the use_shuffled_weight/weight_layout
checks).

In `@flashinfer/fused_moe/__init__.py`:
- Around line 17-21: The module imports Fp8QuantizationType but does not include
it in the public export list; update the __all__ declaration(s) in
flashinfer.fused_moe.__init__ to include "Fp8QuantizationType" so that from
flashinfer.fused_moe import * exposes it (also apply the same addition to the
other __all__ block referenced around lines 54-74); keep existing names
(ActivationType, RoutingMethodType, WeightLayout, etc.) intact and add the
string "Fp8QuantizationType" to each __all__ tuple/list.
🧹 Nitpick comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

989-1007: MxFp8 scale pointers are uint8 data cast to float* — fragile type punning.

For the MxFp8 path, gemm1_output_scale, hidden_states_scale, gemm1_weights_scale, and gemm2_weights_scale are all uint8 tensors, but their data_ptr() is static_cast<float*> into struct fields typed as float*. This works today because the downstream runner passes them through as void* to kernels, but any future code that dereferences these as float will produce UB.

Consider using a reinterpret_cast<float*> to signal intentional type-punning, or (preferably) widening the struct fields to void* if multiple dtypes are expected.

Also, workspace.activation_output and workspace.activation_output_scale are never set on the MxFp8 path. They happen to go unused, but explicitly setting them to nullptr would be safer.

Suggested nullptr initialization for MxFp8
     if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
       workspace.activation_output = activation_output.data_ptr();
       workspace.activation_output_scale = static_cast<float*>(activation_output_scale.data_ptr());
+    } else {
+      workspace.activation_output = nullptr;
+      workspace.activation_output_scale = nullptr;
     }
csrc/trtllm_batched_gemm_runner.cu (1)

177-257: Value-init and mValid assignments look correct.

Line 182: consider removing the commented-out printf before merging — it's a debug artifact.

@aleozlx (Collaborator) commented Feb 12, 2026

I will help review this.


# The type of gated activation function
# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h
class GatedActType(IntEnum):
@aleozlx (Collaborator) commented Feb 13, 2026

We changed the activation type to include non-gated ones and have a function, isGatedActivation, to check for them:

https://github.com/flashinfer-ai/flashinfer/pull/2462/changes#diff-cc2263a5b65f54e0e1cf9f6eb3b7c9e36eaabe779f25af918cee1ada78e73116L178

IwakuraRein (Collaborator, Author) replied:

Thanks. However, the only API change in this PR is adding an fp8_quantization_type to trtllm_fp8_block_scale_routed_moe. I believe it won't conflict with other PRs introducing non-gated activations.

IwakuraRein (Collaborator, Author) replied:

@aleozlx pushed e18d73c to add intermediate_size_factor

@aleozlx (Collaborator) left a comment

Looks good overall.

Posted a comment about GatedActType.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@coderabbitai bot (Contributor) left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1073-1099: ⚠️ Potential issue | 🟠 Major

getValidConfigs uses a different Runner constructor than prepare_moe_common for MxFp8, causing config mismatch.

For MxFp8, getValidConfigs creates the runner using the 5-param weights-only constructor (line 1085–1088):

Runner(dtype_weights, /*useDeepSeekFp8=*/false, tile_N, use_shuffled_weight, weight_layout)

But at runtime, prepare_moe_common (lines 329–331) uses the 7-param act+weights constructor because the condition at line 323 checks for E4m3 (not MxE4m3), which is false for MxFp8:

Runner(mDtypeAct, mDtypeWeights, /*useDeepSeekFp8=*/false, tile_N, activation_type, ...)

These constructors have different signatures and parameters (the 5-param variant lacks activationType), so they may enumerate different kernel configs. This causes valid configs from autotuning to be rejected at runtime, potentially explaining the "No kernel found" errors for MxFp8 shapes.

🧹 Nitpick comments (2)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

89-89: Nit: scale_vec_size is unused in the MxFP8 path.

When quant_mode == "MxFP8xMxFP8", scale_vec_size is assigned 32 on this line but never referenced (it's only consumed inside the Fp8-Block branch). Consider moving the assignment into the if quant_mode == "Fp8-Block" block.

♻️ Suggested diff
-        scale_vec_size = 128 if quant_mode == "Fp8-Block" else 32
         if quant_mode == "Fp8-Block":
+            scale_vec_size = 128
             # block scale quantization is too slow, so we use per-tensor quantization for now
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

44-63: C++ enum has PerTensorFp8 not present in the Python Fp8QuantizationType.

The Python enum in flashinfer/fused_moe/core.py defines NoneFp8=0, DeepSeekFp8=1, MxFp8=2, but the C++ side adds PerTensorFp8=3. If this variant isn't meant to be used from Python, consider adding a comment. Also, the default label in fp8QuantizationTypeToString falls through to NoneFp8, which silently masks unexpected values rather than flagging them.

Suggested: make the default case explicit
   switch (quantization_type) {
-    default:
-    case Fp8QuantizationType::NoneFp8:
+    case Fp8QuantizationType::NoneFp8:
       return "NoneFp8";
     case Fp8QuantizationType::DeepSeekFp8:
       return "DeepSeekFp8";
     case Fp8QuantizationType::MxFp8:
       return "MxFp8";
     case Fp8QuantizationType::PerTensorFp8:
       return "PerTensorFp8";
+    default:
+      return "Unknown(" + std::to_string(static_cast<int>(quantization_type)) + ")";
   }
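For reference, a minimal sketch of the Python-side enum as discussed here; the values mirror this thread (the C++ launcher additionally defines PerTensorFp8 = 3) rather than quoting core.py verbatim.

    from enum import IntEnum

    class Fp8QuantizationType(IntEnum):
        NoneFp8 = 0      # no FP8 quantization
        DeepSeekFp8 = 1  # 128-element block scales, float32
        MxFp8 = 2        # 32-element block scales, uint8 (e8m0)
        # PerTensorFp8 = 3 exists only on the C++ side, per the comment above.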

@IwakuraRein (Collaborator, Author) commented:

/bot run

@flashinfer-bot (Collaborator) commented:

GitLab MR !316 has been created, and the CI pipeline #43998281 is currently running. I'll report back once the pipeline job completes.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)

1090-1101: ⚠️ Potential issue | 🔴 Critical

getValidConfigs uses weights-only Runner constructor, but MxFp8 runtime uses the two-dtype constructor — config index mismatch.

For MxFp8 (where dtype_act == MxE4m3 and dtype_weights == MxE4m3), getValidConfigs at line 1091 creates the Runner with 5 parameters: Runner(dtype_weights, useDeepSeekFp8, tile_N, use_shuffled_weight, weight_layout). However, in prepare_moe_common (lines 333–335), the same MxFp8 scenario matches the else branch (the condition at line 327 checks for E4m3, not MxE4m3), causing it to call a different 7-parameter constructor: Runner(dtype_act, dtype_weights, useDeepSeekFp8, tile_tokens_dim, activation_type, use_shuffled_weight, weight_layout). Different constructors produce different valid config indices, so the autotuner may select a config that the runtime runner rejects, causing "No kernel found" errors.


1020-1022: ⚠️ Potential issue | 🟡 Minor

Remove unnecessary static_cast<float*> on lines 1020–1022.

The args->hidden_states_scale, args->gemm1_weights_scale, and args->gemm2_weights_scale fields in MoERunnerArgs are typed as void*, not float*. In the MxFp8 case, these hold dl_uint8 tensor pointers, so casting to float* is both unnecessary and misleading. Other code paths (e.g., lines 1180, 1189, 1419, 1430) assign these same fields without casting. Remove the casts and assign data_ptr() directly.

🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 179-180: In check_routing_logits_shape(), remove the unused local
declaration "int64_t intermediate_size_factor =
isGatedActivation(activation_type) ? 2 : 1;" that shadows the class member
intermediate_size_factor (defined on the class) or replace its usage to
reference the member instead; ensure the function uses the class member
intermediate_size_factor (or a properly named local if truly needed) so the
dead/shadowing local is eliminated.
- Around line 987-991: The MxFp8 branch under-allocates gemm1_output_scale by
using args->intermediate_size/32 instead of accounting for
intermediate_size_factor (causing under-allocation for gated activations);
update the computeSwizzledLayoutSFSize call in the Fp8QuantizationType::MxFp8
branch to use (intermediate_size_factor * args->intermediate_size) / 32 (i.e.
pass the full swizzled width consistent with the gemm1_output allocation) so
gemm1_output_scale and alloc_tensor({sf_size}, ...) match the actual
gemm1_output width; references: gemm1_output_scale, computeSwizzledLayoutSFSize,
max_num_padded_tokens_gemm1, args->intermediate_size, intermediate_size_factor,
Fp8QuantizationType::MxFp8.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@IwakuraRein (Collaborator, Author) commented:

/bot run

@flashinfer-bot (Collaborator) commented:

[CANCELING] Pipeline #43998281: canceled

@coderabbitai bot (Contributor) left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1079-1105: ⚠️ Potential issue | 🔴 Critical

getValidConfigs uses wrong Runner constructor for MxFp8, causing config mismatch with runtime.

For MxFp8, prepare_moe_common (lines 326–335) constructs the Runner with the two-dtype constructor (passing mDtypeAct, mDtypeWeights, activation_type) when the condition E4m3 && E4m3 && mUseDeepSeekFp8 is false. However, getValidConfigs always uses the weights-only constructor (line 1091–1094), regardless of quantization_type. This means config enumeration and the actual kernel runner see different valid config sets — the root cause of "No kernel found" errors at runtime.

Proposed fix: branch getValidConfigs to match prepare_moe_common logic
     for (int32_t tile_N : selected_tile_nums) {
-      auto moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
-          dtype_weights,                                          // dtype_weights for DeepSeek FP8
-          quantization_type == Fp8QuantizationType::DeepSeekFp8,  // useDeepSeekFp8
-          tile_N, use_shuffled_weight, static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+      std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner;
+      if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+        moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+            dtype_weights, true /* useDeepSeekFp8 */, tile_N, use_shuffled_weight,
+            static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+      } else {
+        // MxFp8: match two-dtype constructor from prepare_moe_common
+        moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+            dtype_weights, dtype_weights, false /* useDeepSeekFp8 */, tile_N,
+            ActivationType::Swiglu, use_shuffled_weight,
+            static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+      }
 
       auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size,
                                                     num_local_experts, num_tokens);
🧹 Nitpick comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)

1004-1012: MxFp8 path does not explicitly set workspace.activation_output / workspace.activation_output_scale.

Only the DeepSeekFp8 branch (lines 1007–1010) assigns these workspace pointers. The MxFp8 path relies on implicit zero-initialization. Consider explicitly setting them to nullptr to be safe against future refactors where prepare_moe might be re-entered or workspace partially reused.

Proposed fix
     if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
       workspace.activation_output = activation_output.data_ptr();
       workspace.activation_output_scale = static_cast<float*>(activation_output_scale.data_ptr());
+    } else {
+      workspace.activation_output = nullptr;
+      workspace.activation_output_scale = nullptr;
     }

1006-1006: static_cast<float*> on a dl_uint8 tensor for MxFp8 — type mismatch in workspace pointer.

For MxFp8, gemm1_output_scale is allocated as dl_uint8 (line 990), but line 1006 unconditionally casts it to float*. The kernel likely consumes the raw address, but this cast is misleading and could mask bugs if the workspace struct gains type-safety. Consider a void* intermediate or a comment noting the intentional reinterpretation.

@flashinfer-bot (Collaborator) commented:

GitLab MR !316 has been updated with latest changes, and the CI pipeline #44003338 is currently running. I'll report back once the pipeline job completes.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
IwakuraRein force-pushed the siyuanf/mxfp8-trtllm-integration branch from 3e0dbdd to 03cac02 on February 14, 2026 00:36
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@IwakuraRein (Collaborator, Author) commented:

/bot run

@flashinfer-bot (Collaborator) commented:

GitLab MR !316 has been updated with latest changes, and the CI pipeline #44028049 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator) commented:

[FAILED] Pipeline #44028049: 14/20 passed
