
Conversation


@vincentzed vincentzed commented Feb 8, 2026

Motivation

Support both MoE and dense layers for ModelOpt MXFP8.

NVIDIA/Model-Optimizer#736 adds MXFP8 PTQ.

python3 examples/llm_ptq/hf_ptq.py --pyt_ckpt_path /root/.cache/huggingface/hub/models--Qwen--Qwen3-30B-A3B-Instruct-2507/snapshots/0d7cf23991f47feeb3a57ecb4c9cee8ea4a17bfe --qformat mxfp8 --export_path ./Qwen3-30B-A3B-Instruct-2507-MXFP8

After the checkpoint is exported, users should be able to use --quantization modelopt_mxfp8.
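For reference, a minimal offline-engine sketch of how the exported checkpoint could be served (assuming the Python Engine accepts the same server arguments as the CLI flags; the checkpoint path is the export_path from the command above):

import sglang as sgl

# Illustrative only: model_path points at the MXFP8 checkpoint exported above;
# quantization / moe_runner_backend mirror the CLI flags discussed in this PR.
llm = sgl.Engine(
    model_path="./Qwen3-30B-A3B-Instruct-2507-MXFP8",
    quantization="modelopt_mxfp8",
    moe_runner_backend="cutlass",  # triton also works, but is much slower for now
)
print(llm.generate("The capital of France is", {"max_new_tokens": 16}))
llm.shutdown()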

#18258

We use the MXFP8 GEMM from flashinfer-ai/flashinfer#2464.
bmm_mxfp8 appears buggy and is incompatible with CUDA Graph; follow-up needed.

We tried using flashinfer-ai/flashinfer#2505, but the results seem incorrect, and its tests are not passing locally. Future: add a trtllm_moe runner for MXFP8.

This PR depends on flashinfer-ai/flashinfer#2464 being merged.

TRTLLM MXFP8 flashinfer: flashinfer-ai/flashinfer#2505

We support --moe-runner-backend with the Cutlass and Triton MoE runners.

Triton is very slow at the moment. Both are correct.

Future:

In order of difficulty:

Minor

Medium

  • Further optimize our usage of the Cutlass MoE runner cutlass_fused_experts_fp8, until the TRTLLM MoE runner is unblocked.
  • Further optimize _swizzle_mxfp8_scales, and figure out why non-swizzled produces wrong outputs.
  • Avoid the upcast to BF16 followed by downcast + swizzle in process_weights_after_loading for ModelOpt MXFP8, to reduce peak memory at startup (see the sketch after this list).

High

  • CUTLASS MXFP8 requires m >= 32, which fails on models like Qwen3-Next Coder / Qwen3-Next 80B. Fix the kernel.
  • Determine whether flashinfer::cutlass_fused_moe supports MXFP8, and add support if necessary so that flashinfer_cutlass becomes a valid backend. It may not be more performant than trtllm_moe, since it is similar to the regular cutlass path.
  • Make this PR work on SM90: even though Blackwell has dedicated MXFP8 quantize hardware, SM90 supports it too.
  • Support DeepGEMM with UE8M0 scales; it supports MXFP8.
  • Use the TRTLLM MoE runner as the default for MXFP8 once it is verified to be correct.
  • Refactor ModelOptQuantConfig, which has grown extremely large and hard to reason about across configs, with many brittle/duplicated special-case code paths that cause severe bugs.
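For reference, a minimal sketch of the upcast-then-requantize flow behind the Medium items above (tensor shapes and the mxfp8_quantize import path are assumptions; the actual implementation lives in process_weights_after_loading):

import torch
from flashinfer import mxfp8_quantize  # import path assumed

def requantize_mxfp8_swizzled(weight_fp8: torch.Tensor, scale_ue8m0: torch.Tensor):
    # Assumed MXFP8 checkpoint layout: E4M3 data of shape (N, K) plus one UE8M0
    # scale byte per 32 consecutive elements, i.e. scales of shape (N, K // 32).
    n, k = weight_fp8.shape
    # UE8M0 stores only an exponent: scale = 2 ** (byte - 127).
    scales = torch.exp2(scale_ue8m0.to(torch.float32) - 127.0)
    w = weight_fp8.to(torch.float32).view(n, k // 32, 32) * scales.unsqueeze(-1)
    weight_bf16 = w.view(n, k).to(torch.bfloat16)
    # Requantize so the scales come back in the swizzled layout the GEMM expects
    # (same call as in the review snippet further down).
    return mxfp8_quantize(weight_bf16, is_sf_swizzled_layout=True, backend="cuda")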

Benchmark

Expected performance (MXFP8 should be similar to Block FP8):

| Runner backend | MXFP8 expected (tok/s) | Reference (Block FP8 / BF16) | Notes |
| --- | --- | --- | --- |
| Cutlass MoE | ~130–135 | Block FP8 Cutlass: 135 | Current MXFP8 is ~115; target is to close the ~15–20 tok/s gap. |
| TRTLLM MoE | ~200 | Block FP8 TRTLLM MoE: 202 | MXFP8 TRTLLM path is currently blocked/unverified; once correct it should match Block FP8. |
| Triton MoE | N/A (not a perf target) | | Currently very slow; kept mainly for correctness/debug. |

MXFP8 (ModelOpt), using FlashInfer MXFP8 GEMM:

--moe-runner-backend cutlass (SGL Cutlass MoE runner, not flashinfer_cutlass). flashinfer_cutlass MXFP8 support needs to be verified; if/when it supports MxE4m3, it may be a better option.

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    4.419    |  512   |   1.000    |     115.87      |
+-------------+--------+------------+-----------------+

Expected >130 tok/s; current gap should be addressable with further Cutlass MoE integration/kernel optimizations.

Triton MoE runner:

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|   12.077    |  512   |   1.000    |      42.39      |
+-------------+--------+------------+-----------------+

BF16 (FlashInfer trtllm_moe):

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    2.344    |  512   |   1.000    |     218.40      |
+-------------+--------+------------+-----------------+

Block FP8 (FlashInfer trtllm_moe):

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    2.527    |  512   |   1.000    |     202.59      |
+-------------+--------+------------+-----------------+

Block FP8 (Cutlass MoE runner):

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    3.789    |  512   |   1.000    |     135.14      |
+-------------+--------+------------+-----------------+

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions bot added the quant LLM Quantization label Feb 8, 2026
@gemini-code-assist

Summary of Changes

Hello @vincentzed, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the ModelOpt quantization capabilities by introducing full support for MXFP8 (Microscaling FP8) across the system. This enables the use of MXFP8 quantized models, particularly beneficial for performance on newer hardware like Blackwell GPUs. The changes involve adding new configuration and processing methods for MXFP8 weights and integrating specialized, high-performance kernels for both dense and Mixture-of-Experts (MoE) layers, ensuring efficient execution while maintaining model integrity.

Highlights

  • MXFP8 Quantization Support: Introduced comprehensive support for ModelOpt MXFP8 quantization, enabling its use for both standard dense layers and Mixture-of-Experts (MoE) models.
  • New Configuration and Methods: Added ModelOptMxfp8Config along with ModelOptMxfp8LinearMethod and ModelOptMxfp8MoEMethod to manage MXFP8-specific weight loading, processing, and inference logic.
  • Backend Integration: Integrated flashinfer.mm_mxfp8 for efficient MXFP8 linear operations and cutlass_fused_experts_fp8 for optimized MXFP8 MoE computations when using the CUTLASS backend.
  • Scale Handling Logic: Implemented detailed logic for handling MXFP8 scales, including dequantization from UE8M0 to BF16, requantization, and swizzling of scales to match the requirements of cuBLAS and CUTLASS kernels.
  • Quantization Method Auto-detection: Updated the auto-detection mechanism for ModelOpt quantization methods to correctly identify and prioritize modelopt_mxfp8 over modelopt_fp8.
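The auto-detection item above boils down to checking the ModelOpt export's quant config before falling back to plain FP8. A minimal sketch (the config structure is an assumption based on ModelOpt's hf_quant_config.json export, not the exact sglang code):

def detect_modelopt_quant_method(hf_quant_config: dict) -> str:
    # ModelOpt exports typically record the algorithm under quantization.quant_algo.
    algo = (hf_quant_config.get("quantization", {}).get("quant_algo") or "").upper()
    if algo == "MXFP8":
        # Check MXFP8 first so the generic FP8 branch does not claim it.
        return "modelopt_mxfp8"
    if "FP8" in algo:
        return "modelopt_fp8"
    raise ValueError(f"Unsupported ModelOpt quant_algo: {algo!r}")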


Changelog
  • python/sglang/srt/configs/model_config.py
    • Added modelopt_mxfp8 to the list of compatible quantization methods.
  • python/sglang/srt/layers/moe/fused_moe_triton/layer.py
    • Minor code cleanup by removing a redundant comment.
  • python/sglang/srt/layers/quantization/__init__.py
    • Registered the new ModelOptMxfp8Config for system-wide recognition.
  • python/sglang/srt/layers/quantization/base_config.py
    • Enhanced the quantization method auto-detection logic to correctly identify and prioritize modelopt_mxfp8.
  • python/sglang/srt/layers/quantization/fp8_utils.py
    • Introduced new utility functions for cuBLAS MXFP8 operations, including scale interleaving and a block-scaled linear function.
  • python/sglang/srt/layers/quantization/modelopt_quant.py
    • Implemented the core logic for MXFP8 quantization, including ModelOptMxfp8Config, ModelOptMxfp8LinearMethod (using flashinfer.mm_mxfp8), and ModelOptMxfp8MoEMethod (supporting CUTLASS and Triton backends with MXFP8 scale handling).
  • python/sglang/srt/server_args.py
    • Updated the allowed quantization methods to include modelopt_mxfp8.
Activity
  • The author notes that the PR depends on external flashinfer pull requests (#2464, #2505).
  • There's an identified bug with bmm_mxfp8 and CUDA graph incompatibility that requires future follow-up.
  • Performance benchmarks are provided, showing that the Cutlass MoE implementation is significantly faster (115.87 tokens/s) compared to the Triton MoE implementation (42.39 tokens/s) for MXFP8.
  • The PR description also outlines several future optimization and refactoring tasks.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for ModelOpt MXFP8 quantization for both dense and MoE layers. The implementation leverages flashinfer for linear layers and provides both a high-performance CUTLASS path and a Triton fallback for MoE layers. The changes are well-structured and cover the necessary modifications in configuration, quantization logic, and server arguments. My review focuses on potential performance optimizations and a possible issue in the scale swizzling logic.

Comment on lines +2181 to +2184
scale_layout, scale_layout_opts = (
    layout.make_default_matmul_mxfp4_w_scale_layout(
        mx_axis=1, num_warps=num_warps
    )

high

The function _swizzle_mxfp8_scales is using layout.make_default_matmul_mxfp4_w_scale_layout to swizzle MXFP8 scales. This is potentially incorrect or at least confusing. Please verify if this is the correct layout function for MXFP8 scales. If the layout is indeed identical for both formats, it would be helpful to add a comment explaining this. Otherwise, this could lead to incorrect results and should be updated to use an MXFP8-specific layout function if available.

Comment on lines 2015 to 2018
weight_bf16 = self._dequantize_mxfp8(layer.weight.data, layer.weight_scale.data)
weight_q, weight_scale = mxfp8_quantize(
    weight_bf16, is_sf_swizzled_layout=True, backend="cuda"
)

medium

The current implementation dequantizes the weights to bf16 and then re-quantizes them to get the swizzled scale layout required by flashinfer.mm_mxfp8. This process can be memory-intensive and slow down model loading, as noted in the PR description. For future optimization, consider if it's possible to swizzle the scales directly without the dequantization/re-quantization step. This would significantly reduce peak memory usage and improve model loading performance.

Comment on lines 722 to 743
if m % 128 != 0:
    m_padded = ceil_div(m, 128) * 128
    pad_rows = m_padded - m
    q_input = torch.cat(
        [
            q_input,
            torch.zeros((pad_rows, k), device=q_input.device, dtype=q_input.dtype),
        ],
        dim=0,
    )
    x_scale_u8 = torch.cat(
        [
            x_scale_u8,
            torch.full(
                (pad_rows, k // 32),
                127,
                device=x_scale_u8.device,
                dtype=x_scale_u8.dtype,
            ),
        ],
        dim=0,
    )

medium

The padding logic for q_input and x_scale_u8 creates new tensors and concatenates on every forward pass if m is not a multiple of 128. This can introduce overhead on the hot path, especially for small batches common in decoding. Consider using a pre-allocated buffer and torch.nn.functional.pad for a more efficient implementation that avoids repeated memory allocations.
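A minimal sketch of the pre-allocated-buffer alternative suggested here (buffer shapes, the max-row bound, and the helper name are illustrative, not the PR's code; assumes fill_/zero_ support the FP8 dtype in the installed PyTorch):

import torch

def pad_rows_to_128(q_input: torch.Tensor, x_scale_u8: torch.Tensor,
                    q_buf: torch.Tensor, s_buf: torch.Tensor):
    # q_buf / s_buf are allocated once at startup for the largest expected batch
    # (a multiple of 128 rows), so the hot path does no new allocations.
    m = q_input.shape[0]
    m_padded = (m + 127) // 128 * 128
    q_out, s_out = q_buf[:m_padded], s_buf[:m_padded]
    q_out[:m].copy_(q_input)
    q_out[m:].zero_()      # clear rows left over from an earlier, larger batch
    s_out[:m].copy_(x_scale_u8)
    s_out[m:].fill_(127)   # UE8M0 exponent 127 encodes a scale of 1.0
    return q_out, s_out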


zianglih commented Feb 9, 2026

> Further optimize our usage of the Cutlass MoE runner cutlass_fused_experts_fp8

#14640

> Avoid upcast to BF16 then downcast + swizzle, to reduce peak memory on startup

#17945

> Further optimize _swizzle_mxfp8_scales, and figure out why non-swizzled produces wrong outputs.

Scaling factor swizzling is always required: https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html#scale-factor-layouts

> Make this PR work on SM90, since even though Blackwell has MXFP8 quantize HW, SM90 supports it too.

SM90 does not have mxf4/mxf8 mma.

> Support deepgemm with ue8m0 scale, it supports MXFP8.

DeepGEMM only supports MXFP in a very recent release. Current sgl-kernel does not yet have the update. Some of my existing work: #17294

@vincentzed

Thank you @zianglih!

@IwakuraRein

> We tried using flashinfer-ai/flashinfer#2505, but the results seem incorrect, and its tests are not passing locally. Future: add a trtllm_moe runner for MXFP8.

Currently, MXFP8 x MXFP8 is less accurate than DeepSeek FP8 in that unit test file, so the unit tests are failing. However, if you change the percent in get_tolerances to 0.7, the unit test passes. I am investigating whether there is an accuracy issue or a problem with the unit test implementation.

@IwakuraRein

@vincentzed The accuracy issue in the unit test is fixed. Can you try the latest commit? Thanks.

vincentzed and others added 14 commits February 10, 2026 21:06
The flashinfer_trtllm MXFP8 MoE path was untested and non-functional.
Remove all related code: align_mxfp8_moe_weights_for_flashinfer_trtllm,
use_mxfp8 field on FlashInferTrtllmFp8MoeQuantInfo, MXFP8 branches in
fused_experts_none_to_flashinfer_trtllm_fp8, and auto-selection of
flashinfer_trtllm for modelopt_mxfp8. Working backends are cutlass and
triton.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
…constant

- Remove unused cublas_mxfp8_blockscaled_linear, prepare_mxfp8_weight_for_cublas,
  and _interleave_mxfp8_scales_for_cublas from fp8_utils.py
- Extract _dequantize_mxfp8 to shared dequantize_mxfp8() in fp8_utils.py
- Add warning for unrecognized MXFP8 weight names in MoE weight loader
- Extract magic 90000 workspace size to CUTLASS_MOE_WORKSPACE_BYTES in cutlass_moe.py
