fix: RMSNorm/FusedRMSNorm + Quant kernels cuda graph fixes #2459
base: main
Conversation
📝 Walkthrough
The PR changes the `scale` parameter of the RMSNorm/FusedRMSNorm quantization kernels from a scalar float to a tensor, applied consistently across the C++ backend, Python bindings, and tests.
Sequence Diagram(s): Skipped — changes are API/signature updates and do not introduce a new multi-step runtime control flow requiring visualization.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes
Hello @BLaZeKiLL, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical issue preventing CUDA graph capture for quantized RMSNorm and FusedRMSNorm kernels. By refactoring the `scale` parameter from a scalar float into a tensor, the kernels become compatible with CUDA graph capture.
Code Review
This pull request correctly addresses a CUDA graph capture issue by changing the scale parameter in RMSNorm and FusedRMSNorm quantization kernels from a scalar to a tensor. The changes are consistent across the C++ backend, Python bindings, and tests. My review includes suggestions to improve robustness by adding validation for the new scale tensor to ensure it's a scalar. I've also pointed out several inaccuracies in the Python docstrings for the updated functions, including incorrect formulas and parameter types, and provided corrections to improve clarity and prevent misuse.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/norm.py (2)
Lines 117-118: ⚠️ Potential issue | 🟡 Minor: Docstring inconsistent with type hint for the `scale` parameter.
The type hint on line 101 is `scale: torch.Tensor`, but the docstring on lines 117-118 still says `scale: float`. Update the docstring to reflect the tensor type.
📝 Proposed fix:
    - scale: float
    -     Scale factor for quantization.
    + scale: torch.Tensor
    +     Scale factor tensor for quantization (scalar tensor on CUDA device).
Line 225: ⚠️ Potential issue | 🟡 Minor: Same docstring inconsistency in `fused_add_rmsnorm_quant`.
Line 225 shows `scale: float` in the docstring, but the actual parameter type on line 203 is `torch.Tensor`.
📝 Proposed fix:
    - scale: float
    -     Scale factor for quantization.
    + scale: torch.Tensor
    +     Scale factor tensor for quantization (scalar tensor on CUDA device).
🤖 Fix all issues with AI agents
In `@csrc/norm.cu`:
- Around line 80-81: Add validation for the scale TensorView before launching
the kernel in rmsnorm_quant: verify device compatibility using
CHECK_DEVICE(input, scale), ensure scale->dtype is float32, and ensure scale has
at least one element (non-empty). Apply the same checks to the other
rmsnorm_quant overload referenced in the diff (the alternate signature around
the second occurrence) so both entry points validate scale consistently before
use.
- Around line 147-148: The fused_add_rmsnorm_quant implementation is missing
validation for the scale TensorView; add the same checks used in rmsnorm_quant
to verify that scale is non-null/non-empty, on the expected device (same device
as input/output), has the expected dtype/shape (compatible with
weight/broadcasting), and contains valid data before using it. Update
fused_add_rmsnorm_quant (and the similar block around lines 173-174) to perform
these validations and return or throw a clear error if validation fails,
mirroring the validation logic and error messages from rmsnorm_quant so the two
functions behave consistently.
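For illustration, here is a minimal Python-level sketch of the checks the two items above describe, assuming `scale` arrives as a torch tensor. The real validation belongs in the C++ entry points and their existing check macros, so treat this only as a statement of intent:

```python
import torch

def _validate_scale(input: torch.Tensor, scale: torch.Tensor) -> None:
    # Same device as the input (mirrors the CHECK_DEVICE(input, scale) idea above).
    if scale.device != input.device:
        raise ValueError("scale must be on the same device as input")
    # The kernel reads the scale as float32.
    if scale.dtype != torch.float32:
        raise ValueError("scale must be a float32 tensor")
    # Non-empty: the kernel dereferences the first element.
    if scale.numel() < 1:
        raise ValueError("scale must contain at least one element")
```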
In `@flashinfer/norm.py`:
- Line 210: The docstring for Step 1 is incorrect: it claims `residual[i] +=
input[i].to(fp8)` but the code actually adds the raw input to the residual and
only converts the normalized output to fp8 later; update the Step 1 description
in the norm.py docstring to state that the raw input is added to `residual`
(e.g., `residual[i] += input[i]`) and that fp8 conversion is applied to the
normalized output, not to the value being added to `residual`, referencing the
symbols `residual`, `input`, and `fp8` so the doc matches the implementation.
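To make the intended order of operations concrete, here is a hedged eager-mode reference for the behavior the item above describes (the raw input is added to `residual`, and fp8 conversion is applied only to the normalized output). The function name and the use of `torch.float8_e4m3fn` are illustrative assumptions, not the kernel's actual code:

```python
import torch

def fused_add_rmsnorm_quant_ref(input, residual, weight, scale, eps=1e-6):
    # Step 1: add the raw (unquantized) input into the residual, in place.
    residual.add_(input)
    # Step 2: RMS-normalize the updated residual and apply the weight.
    x = residual.float()
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = x * rms * weight.float()
    # Step 3: only the normalized output is quantized to fp8, dividing by the
    # scalar scale tensor (the kernel multiplies by 1/scale).
    return (normed / scale).to(torch.float8_e4m3fn)
```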
🧹 Nitpick comments (1)
include/flashinfer/norm.cuh (1)
Line 161: Potential division by zero if `scale[0]` is zero.
The computation `const float scale_inv = 1.0f / scale[0];` will produce infinity if the scale tensor contains zero. While this may be an invalid input, a defensive check or documented precondition would be helpful.
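A hedged example of such a defensive check on the host side (where exactly this guard should live, and its wording, are assumptions):

```python
import torch

def check_scale(scale: torch.Tensor) -> None:
    # Guard against 1.0f / scale[0] producing inf or NaN in the kernel.
    if not torch.isfinite(scale).all() or bool((scale == 0).any()):
        raise ValueError("scale must be finite and non-zero")
```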
Force-pushed from 004478e to 245e315.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/norm.py (1)
Lines 97-133: ⚠️ Potential issue | 🟠 Major: Return the output tensor from `rmsnorm_quant`.
The function is annotated to return `torch.Tensor` and the docstring promises an output, but it currently returns `None`. This breaks API expectations and callers using the return value.
💡 Suggested fix:
      if enable_pdl is None:
          enable_pdl = device_support_pdl(input.device)
      get_norm_module().rmsnorm_quant(out, input, weight, scale, eps, enable_pdl)
    + return out
follow up on flashinfer-ai#2243: quant_scale being a float causes CUDA graph capture to fail even with the workaround; making it a tensor fixes CUDA graph capture for the fusion passes in SGLang. Also added docs for the fused kernels.
Signed-off-by: Devashish Lal <laldevashish@gmail.com>
Force-pushed from 245e315 to 86420fa.
@yzh119 I saw you were refactoring these kernels to cutedsl; I could just wait for those and hopefully you can incorporate the CUDA graph fixes there too.
Yes, it would be preferable if you can work on this feature in cutedsl, after PR #2428.
📌 Description
Follow up on #2243.
quant_scale being a float causes CUDA graph capture to fail; making it a tensor fixes CUDA graph capture for the fusion passes in SGLang.
Also added docs for the fused kernels.
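To illustrate why the float-to-tensor change matters for capture (a self-contained sketch, not SGLang's fusion-pass code; `quantize` below is a hypothetical stand-in for the fused kernel): a Python float argument gets baked into a captured CUDA graph as a constant, whereas a tensor argument lives in device memory and can be updated in place between replays.

```python
import torch

def quantize(x, scale, out):
    # Stand-in for the fused kernel: scale is read from device memory,
    # so a replay observes whatever value the tensor holds at that time.
    out.copy_((x / scale).to(out.dtype))

x = torch.randn(16, 4096, device="cuda")
out = torch.empty(16, 4096, device="cuda", dtype=torch.float8_e4m3fn)
scale = torch.tensor(0.5, device="cuda")  # a tensor, not a Python float

# (Real code usually warms up on a side stream before capturing.)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    quantize(x, scale, out)

# Updating the tensor in place is visible to the next replay; a Python
# float would have been frozen into the graph at capture time.
scale.fill_(0.25)
g.replay()
```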
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
Installed the hooks with `pre-commit install`.
Ran `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
Tests have been added or updated as needed.
All tests are passing (`unittest`, etc.).
Reviewer Notes
SGLang's fusion pass PR: sgl-project/sglang#10549
Some performance numbers from SGLang on an RTX 5090 running Llama 3.1 8B FP8 on a 16-prompt benchmark.
Summary by CodeRabbit
Refactor
Documentation
Tests