Conversation

@BLaZeKiLL (Contributor) commented Feb 1, 2026

📌 Description

Follow-up to #2243.

quant_scale being a float causes CUDA graph capture to fail; making it a tensor fixes CUDA graph capture for the fusion passes in SGLang.

Also added docs for the fused kernels.
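For reviewers, a minimal usage sketch of the new calling convention. The argument order follows the module-level call listed in the walkthrough below (out, input, weight, scale, eps) and is illustrative only; see flashinfer/norm.py for the authoritative public signature.

```python
import torch
import flashinfer

hidden_size = 4096
x = torch.randn(8, hidden_size, device="cuda", dtype=torch.float16)
weight = torch.ones(hidden_size, device="cuda", dtype=torch.float16)
out = torch.empty(8, hidden_size, device="cuda", dtype=torch.float8_e4m3fn)

# quant_scale is now a single-element CUDA tensor rather than a Python float,
# so the kernel reads its value from device memory at run time.
scale = torch.tensor([0.5], device="cuda", dtype=torch.float32)

# Illustrative call; argument order and keywords may differ in the public wrapper.
flashinfer.norm.rmsnorm_quant(out, x, weight, scale, eps=1e-6)
```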

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Sglang's fusion pass MR: sgl-project/sglang#10549

Some performance numbers from SGLang on an RTX 5090 running Llama 3.1 8B FP8 on a 16-prompt benchmark:

| server config | throughput (tok/sec) |
| --- | --- |
| cuda graph | 503.25 |
| cuda graph + torch compile | 600.37 |
| cuda graph + torch compile + fusion (vllm kernels) | 612.58 |
| cuda graph + torch compile + fusion (flashinfer kernels) | 619.11 |

Summary by CodeRabbit

  • Refactor

    • Normalization routines updated to accept the scale parameter as a tensor (single-element) instead of a scalar, changing how quantized normalization is invoked.
  • Documentation

    • API docs and docstrings updated to describe FP8 quantization behavior and the tensor-shaped scale parameter.
  • Tests

    • Tests updated to pass scale as a CUDA tensor to match the new signatures.


@coderabbitai (bot) commented Feb 1, 2026

📝 Walkthrough

The PR changes the scale parameter from a scalar to a tensor across Python, C++ bindings, CUDA kernels, and tests, updating signatures, validations, and call sites to pass a single-element tensor (device-backed float pointer / float*) instead of a scalar float/double.
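For context, a self-contained PyTorch sketch (not FlashInfer code) of why a device-resident scale is CUDA-graph friendly while a Python float is not: captured kernels re-read the tensor's device memory on every replay, so the scale can be updated in place without re-capturing.

```python
import torch

x = torch.randn(8, device="cuda")
scale = torch.tensor([2.0], device="cuda")   # single-element scale, like the new quant_scale
out = torch.empty_like(x)

# Warm up on a side stream before capture (standard CUDA graph practice).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    out.copy_(x * scale)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out.copy_(x * scale)   # the captured kernel dereferences scale's device memory

g.replay()                 # computes with scale == 2.0
scale.fill_(0.5)           # update in place; no re-capture needed
g.replay()                 # now computes with scale == 0.5
```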

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| C++ Binding Layer (`csrc/flashinfer_norm_binding.cu`) | Exported signatures updated: `rmsnorm_quant` and `fused_add_rmsnorm_quant` now accept `TensorView scale` instead of scalar `double`. |
| CUDA Host / Dispatch (`csrc/norm.cu`) | Signatures accept `TensorView scale`; added device/shape checks for the single-element scale; call sites pass `static_cast<float*>(scale.data_ptr())` to downstream host/kernel functions. |
| CUDA Kernel Interface (`include/flashinfer/norm.cuh`) | Kernel and host wrappers updated to accept `float* scale` (indexed as `scale[0]`) instead of `float`; `scale_inv` computed from `scale[0]`. |
| Python API (`flashinfer/norm.py`) | Public function signatures changed to accept `scale: torch.Tensor` (shape `(1,)`) for `rmsnorm_quant`, `_rmsnorm_quant_fake`, `fused_add_rmsnorm_quant`, and `_fused_add_rmsnorm_quant_fake`; return/docstring adjustments reflect FP8 quantization semantics. |
| Tests (`tests/utils/test_norm.py`) | Updated tests to pass quant_scale as a CUDA tensor (`torch.tensor(..., device="cuda")`) instead of a scalar. |
| Docs (`docs/api/norm.rst`) | Added `rmsnorm_quant` and `fused_add_rmsnorm_quant` to the API autosummary. |

Sequence Diagram(s)

(Skipped — changes are API/signature updates and do not introduce a new multi-step runtime control flow requiring visualization.)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

model: dsr1

Suggested reviewers

  • djmmoss
  • yzh119
  • cyx-6
  • wenscarl
  • bkryu
  • nvmbreughe
  • kahyunnam

Poem

🐰 I nudged a tiny tensor in place,

scale tucked single, snug in its space,
From Python, through C++ and CUDA's glow,
scale[0] whispers where the numbers flow,
Hop, quantize, and onward we race. 🥕

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 20.00%, which is insufficient; the required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly describes the main change: converting quant_scale from float to tensor to fix CUDA graph capture for RMSNorm/FusedRMSNorm quantization kernels. |
| Description check | ✅ Passed | The description covers the main issue (quant_scale causing CUDA graph failure), the fix (making it a tensor), and includes related documentation updates and performance metrics. |


@gemini-code-assist (bot) commented

Summary of Changes

Hello @BLaZeKiLL, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue preventing CUDA graph capture for quantized RMSNorm and FusedRMSNorm kernels. By refactoring the quant_scale parameter from a primitive float type to a tensor, the underlying CUDA operations can now be correctly integrated into CUDA graphs. This change is essential for enabling advanced optimizations like fusion passes, which significantly boost performance, as demonstrated by the provided Sglang benchmarks.

Highlights

  • CUDA Graph Fix: The quant_scale parameter in RMSNorm and FusedRMSNorm quantized kernels was changed from a float to a Tensor to resolve issues with CUDA graph capture, enabling proper optimization for fusion passes.
  • API and Implementation Updates: The C++ and Python interfaces for rmsnorm_quant and fused_add_rmsnorm_quant have been updated to reflect the change in the scale parameter type, now accepting a TensorView in C++ and torch.Tensor in Python.
  • Documentation and Tests: New documentation entries have been added for the rmsnorm_quant and fused_add_rmsnorm_quant kernels, and existing test cases have been updated to align with the modified function signatures.


@gemini-code-assist (bot) left a comment

Code Review

This pull request correctly addresses a CUDA graph capture issue by changing the scale parameter in RMSNorm and FusedRMSNorm quantization kernels from a scalar to a tensor. The changes are consistent across the C++ backend, Python bindings, and tests. My review includes suggestions to improve robustness by adding validation for the new scale tensor to ensure it's a scalar. I've also pointed out several inaccuracies in the Python docstrings for the updated functions, including incorrect formulas and parameter types, and provided corrections to improve clarity and prevent misuse.

@coderabbitai (bot) left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/norm.py (2)

117-118: ⚠️ Potential issue | 🟡 Minor

Docstring inconsistent with type hint for scale parameter.

The type hint on line 101 is scale: torch.Tensor, but the docstring on line 117-118 still says scale: float. Update the docstring to reflect the tensor type.

📝 Proposed fix
-    scale: float
-        Scale factor for quantization.
+    scale: torch.Tensor
+        Scale factor tensor for quantization (scalar tensor on CUDA device).

225-225: ⚠️ Potential issue | 🟡 Minor

Same docstring inconsistency in fused_add_rmsnorm_quant.

Line 225 shows scale: float in the docstring, but the actual parameter type on line 203 is torch.Tensor.

📝 Proposed fix
-    scale: float
-        Scale factor for quantization.
+    scale: torch.Tensor
+        Scale factor tensor for quantization (scalar tensor on CUDA device).
🤖 Fix all issues with AI agents
In `@csrc/norm.cu`:
- Around line 80-81: Add validation for the scale TensorView before launching
the kernel in rmsnorm_quant: verify device compatibility using
CHECK_DEVICE(input, scale), ensure scale->dtype is float32, and ensure scale has
at least one element (non-empty). Apply the same checks to the other
rmsnorm_quant overload referenced in the diff (the alternate signature around
the second occurrence) so both entry points validate scale consistently before
use.
- Around line 147-148: The fused_add_rmsnorm_quant implementation is missing
validation for the scale TensorView; add the same checks used in rmsnorm_quant
to verify that scale is non-null/non-empty, on the expected device (same device
as input/output), has the expected dtype/shape (compatible with
weight/broadcasting), and contains valid data before using it. Update
fused_add_rmsnorm_quant (and the similar block around lines 173-174) to perform
these validations and return or throw a clear error if validation fails,
mirroring the validation logic and error messages from rmsnorm_quant so the two
functions behave consistently.

In `@flashinfer/norm.py`:
- Line 210: The docstring for Step 1 is incorrect: it claims `residual[i] +=
input[i].to(fp8)` but the code actually adds the raw input to the residual and
only converts the normalized output to fp8 later; update the Step 1 description
in the norm.py docstring to state that the raw input is added to `residual`
(e.g., `residual[i] += input[i]`) and that fp8 conversion is applied to the
normalized output, not to the value being added to `residual`, referencing the
symbols `residual`, `input`, and `fp8` so the doc matches the implementation.
🧹 Nitpick comments (1)
include/flashinfer/norm.cuh (1)

161-161: Potential division by zero if scale[0] is zero.

The computation const float scale_inv = 1.0f / scale[0]; will produce infinity if the scale tensor contains zero. While this may be an invalid input, a defensive check or documented precondition would be helpful.
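A hedged Python-level sketch of the checks suggested in the comments above (device, dtype, single element, plus a defensive non-zero check for this nitpick); the helper name is hypothetical, and the actual fix would live in csrc/norm.cu as described.

```python
import torch

def _check_quant_scale(input: torch.Tensor, scale: torch.Tensor) -> None:
    # Mirrors the review's suggested C++ checks, expressed in Python for illustration.
    if scale.device != input.device:
        raise ValueError("scale must be on the same device as input")
    if scale.dtype != torch.float32:
        raise ValueError("scale must be float32")
    if scale.numel() != 1:
        raise ValueError("scale must be a single-element tensor")
    # Defensive check for the division-by-zero nitpick; this reads the value back
    # to the host, so it must run outside CUDA graph capture.
    if scale.item() == 0.0:
        raise ValueError("scale must be non-zero")
```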

@BLaZeKiLL force-pushed the dev/dlal/norm_quant_fixes_and_docs branch from 004478e to 245e315 on February 1, 2026 at 07:50
@coderabbitai (bot) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/norm.py (1)

97-133: ⚠️ Potential issue | 🟠 Major

Return the output tensor from rmsnorm_quant.

The function is annotated to return torch.Tensor and the docstring promises an output, but it currently returns None. This breaks API expectations and callers using the return value.

💡 Suggested fix
     if enable_pdl is None:
         enable_pdl = device_support_pdl(input.device)
     get_norm_module().rmsnorm_quant(out, input, weight, scale, eps, enable_pdl)
+    return out

follow up on flashinfer-ai#2243

quant_scale being a float causes CUDA graph capture to fail even with the workaround; making it a tensor fixes CUDA graph capture for fusion passes in sglang.

also added docs for the fused kernels.

Signed-off-by: Devashish Lal <laldevashish@gmail.com>
@BLaZeKiLL force-pushed the dev/dlal/norm_quant_fixes_and_docs branch from 245e315 to 86420fa on February 2, 2026 at 01:59
@BLaZeKiLL (Contributor, Author) commented

@yzh119 I saw you were refactoring these kernels to cutedsl; I could just wait for those, and hopefully you incorporate the CUDA graph fixes there too.

@yzh119 (Collaborator) commented Feb 3, 2026

Yes, it would be preferable if you can work on this feature on cutedsl, after PR #2428.
