[ModelOpt MXFP8] #18449
Conversation
Summary of Changes (Gemini Code Assist): This pull request significantly expands the ModelOpt quantization capabilities by introducing full support for MXFP8 (microscaling FP8) across the system. This enables the use of MXFP8-quantized models, which is particularly beneficial for performance on newer hardware such as Blackwell GPUs. The changes add new configuration and processing methods for MXFP8 weights and integrate specialized, high-performance kernels for both dense and Mixture-of-Experts (MoE) layers, ensuring efficient execution while maintaining model integrity.
Code Review
This pull request introduces support for ModelOpt MXFP8 quantization for both dense and MoE layers. The implementation leverages flashinfer for linear layers and provides both a high-performance CUTLASS path and a Triton fallback for MoE layers. The changes are well-structured and cover the necessary modifications in configuration, quantization logic, and server arguments. My review focuses on potential performance optimizations and a possible issue in the scale swizzling logic.
```python
scale_layout, scale_layout_opts = (
    layout.make_default_matmul_mxfp4_w_scale_layout(
        mx_axis=1, num_warps=num_warps
    )
)
```
The function _swizzle_mxfp8_scales is using layout.make_default_matmul_mxfp4_w_scale_layout to swizzle MXFP8 scales. This is potentially incorrect or at least confusing. Please verify if this is the correct layout function for MXFP8 scales. If the layout is indeed identical for both formats, it would be helpful to add a comment explaining this. Otherwise, this could lead to incorrect results and should be updated to use an MXFP8-specific layout function if available.
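As a minimal sketch of the clarification requested above, assuming MXFP4 and MXFP8 really do share the same swizzled scale layout (both OCP MX formats carry one E8M0 scale per 32-element block, so the scale tensor has the same shape and packing), the call could simply document that reuse in place:

```python
# Sketch only: the same call as in the diff, with the comment the review asks for.
# The claim that the layout helper is format-independent is an assumption to verify.
scale_layout, scale_layout_opts = (
    # NOTE: MXFP8 reuses the MXFP4 weight-scale layout. Both formats store one
    # E8M0 scale per 32-element block, so the scale tensor's shape and swizzle
    # pattern are identical; only the packed weight element type differs.
    layout.make_default_matmul_mxfp4_w_scale_layout(
        mx_axis=1, num_warps=num_warps
    )
)
```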
```python
weight_bf16 = self._dequantize_mxfp8(layer.weight.data, layer.weight_scale.data)
weight_q, weight_scale = mxfp8_quantize(
    weight_bf16, is_sf_swizzled_layout=True, backend="cuda"
)
```
The current implementation dequantizes the weights to bf16 and then re-quantizes them to get the swizzled scale layout required by flashinfer.mm_mxfp8. This process can be memory-intensive and slow down model loading, as noted in the PR description. For future optimization, consider if it's possible to swizzle the scales directly without the dequantization/re-quantization step. This would significantly reduce peak memory usage and improve model loading performance.
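For context on what the round-trip costs, here is a minimal sketch of MXFP8 dequantization (one E8M0 scale per 32-element block along the last dimension). The function name is hypothetical and this is not the repository's dequantize_mxfp8() implementation:

```python
import torch

def dequantize_mxfp8_sketch(w_fp8: torch.Tensor, scale_e8m0: torch.Tensor) -> torch.Tensor:
    """Expand per-block E8M0 scales and apply them to the fp8 weight.

    w_fp8:      (n, k) tensor of torch.float8_e4m3fn values
    scale_e8m0: (n, k // 32) uint8 biased exponents (E8M0, bias 127)
    """
    # Decode E8M0: each byte is a biased power-of-two exponent.
    scales = torch.exp2(scale_e8m0.to(torch.float32) - 127.0)
    # Broadcast each scale over its 32-element block along the last dim.
    scales = scales.repeat_interleave(32, dim=-1)
    return (w_fp8.to(torch.float32) * scales).to(torch.bfloat16)
```

The round-trip materializes a full bf16 copy of the weight before re-quantizing, which is where the extra peak memory during loading comes from.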
```python
if m % 128 != 0:
    m_padded = ceil_div(m, 128) * 128
    pad_rows = m_padded - m
    q_input = torch.cat(
        [
            q_input,
            torch.zeros((pad_rows, k), device=q_input.device, dtype=q_input.dtype),
        ],
        dim=0,
    )
    x_scale_u8 = torch.cat(
        [
            x_scale_u8,
            torch.full(
                (pad_rows, k // 32),
                127,
                device=x_scale_u8.device,
                dtype=x_scale_u8.dtype,
            ),
        ],
        dim=0,
    )
```
The padding logic for q_input and x_scale_u8 creates new tensors and concatenates on every forward pass if m is not a multiple of 128. This can introduce overhead on the hot path, especially for small batches common in decoding. Consider using a pre-allocated buffer and torch.nn.functional.pad for a more efficient implementation that avoids repeated memory allocations.
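A minimal sketch of the suggested alternative using torch.nn.functional.pad, which avoids building the zero/127 filler tensors explicitly (tensor names mirror the diff above; whether F.pad accepts the fp8 dtype of q_input may depend on the PyTorch version, so treat that as an assumption to verify):

```python
import torch
import torch.nn.functional as F

def pad_rows_to_multiple_of_128(q_input: torch.Tensor, x_scale_u8: torch.Tensor):
    """Pad the row dimension so the GEMM sees m as a multiple of 128."""
    m = q_input.shape[0]
    pad_rows = (-m) % 128
    if pad_rows == 0:
        return q_input, x_scale_u8
    # Pad only the second-to-last (row) dimension; columns are untouched.
    q_input = F.pad(q_input, (0, 0, 0, pad_rows))                   # zero rows
    x_scale_u8 = F.pad(x_scale_u8, (0, 0, 0, pad_rows), value=127)  # E8M0 encoding of 1.0
    return q_input, x_scale_u8
```

A pre-allocated buffer sized for the maximum batch would avoid even these per-step allocations, at the cost of holding the buffer permanently.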
Scaling factor swizzling is always required: https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html#scale-factor-layouts
SM90 does not have mxf4/mxf8 mma.
DeepGEMM only supports mxfp in a very recent release. Current …
Thank you @zianglih!
Currently MxFP8 x MxFP8 is less accurate than DeepSeek FP8 in that unit test file, so the unit tests are failing. However, if you change the percent in the …
@vincentzed The accuracy issue in the unit test is fixed. Can you try the latest commit? Thanks.
Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
The flashinfer_trtllm MXFP8 MoE path was untested and non-functional. Remove all related code: align_mxfp8_moe_weights_for_flashinfer_trtllm, use_mxfp8 field on FlashInferTrtllmFp8MoeQuantInfo, MXFP8 branches in fused_experts_none_to_flashinfer_trtllm_fp8, and auto-selection of flashinfer_trtllm for modelopt_mxfp8. Working backends are cutlass and triton. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…constant
- Remove unused cublas_mxfp8_blockscaled_linear, prepare_mxfp8_weight_for_cublas, and _interleave_mxfp8_scales_for_cublas from fp8_utils.py
- Extract _dequantize_mxfp8 to shared dequantize_mxfp8() in fp8_utils.py
- Add warning for unrecognized MXFP8 weight names in MoE weight loader
- Extract magic 90000 workspace size to CUTLASS_MOE_WORKSPACE_BYTES in cutlass_moe.py
Force-pushed from 1dc4248 to ec9ff1c.
Motivation
Support both MoE and dense layers for ModelOpt MXFP8.
NVIDIA/Model-Optimizer#736 adds MXFP8 PTQ.
After the checkpoint is exported, users should be able to use --quantization modelopt_mxfp8 (see #18258).
We use the GEMM from flashinfer-ai/flashinfer#2464.
bmm_mxfp8 seems buggy and is incompatible with CUDA Graph. Follow-up needed.
We tried using flashinfer-ai/flashinfer#2505, but the results seem incorrect, and its tests are not passing locally. Future: add a trtllm_moe runner for MXFP8.
This PR depends on flashinfer-ai/flashinfer#2464 being merged.
TRTLLM MXFP8 flashinfer: flashinfer-ai/flashinfer#2505
We support moe-runner-backend with Cutlass and Triton. Triton is very slow at the moment. Both are correct.
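For illustration, a hedged sketch of how a user might exercise this from the offline engine; it assumes sglang.Engine forwards server arguments such as quantization and moe_runner_backend (the flag names come from this PR, the model path is a placeholder):

```python
import sglang as sgl

if __name__ == "__main__":
    # Placeholder checkpoint path: an MXFP8 checkpoint exported by ModelOpt.
    llm = sgl.Engine(
        model_path="/path/to/modelopt-mxfp8-checkpoint",
        quantization="modelopt_mxfp8",   # quantization method added in this PR
        moe_runner_backend="cutlass",    # or "triton" (currently much slower)
    )
    print(llm.generate("Hello, world!", {"max_new_tokens": 16}))
    llm.shutdown()
```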
Future, in order of difficulty:

Minor
- cuda (default). mm_mxfp8 should always be better (but it is slower for small sizes).

Medium
- cutlass_fused_experts_fp8, until the TRTLLM MoE runner is unblocked.
- _swizzle_mxfp8_scales, and figure out why non-swizzled produces wrong outputs.
- process_weights_after_loading for MXFP8 ModelOpt.

High
- CUTLASS MXFP8 requires m >= 32, which fails on models like Qwen-3-next coder / Qwen 3 next 80b. Fix the kernel.
- Verify whether flashinfer::cutlass_fused_moe supports mxfp8, and add it if necessary, so flashinfer_cutlass is valid. It may not be more performant than trtllm_moe, and it is similar to regular cutlass.
- ue8m0 scale; it supports MXFP8.
- Refactor ModelOptQuantConfig, which has grown extremely large and hard to reason about across configs, and has many brittle/duplicated code paths for special cases that cause severe bugs.

Benchmark
Expected performance (MXFP8 should be similar to Block FP8):
MXFP8 (ModelOpt), using FlashInfer MXFP8 GEMM:
- --moe-runner-backend cutlass (SGL Cutlass MoE runner, not flashinfer_cutlass). flashinfer_cutlass MXFP8 support needs to be verified; if/when it supports MxE4m3, it may be a better option.
- Expected >130 tok/s; the current gap should be addressable with further Cutlass MoE integration/kernel optimizations.
Triton MoE runner:
BF16 (FlashInfer trtllm_moe):

Block FP8 (FlashInfer trtllm_moe):

Block FP8 (Cutlass MoE runner):
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci