refactor: refactoring cuda code to cute-dsl (part 1) #2428
base: main
Conversation
Summary of Changes
Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request marks the initial phase of refactoring the project's normalization kernels to leverage CuTe-DSL, aiming to enhance JIT compilation speed and overall kernel performance. It introduces a comprehensive set of CuTe-DSL-based normalization kernels and integrates them into the existing API with a conditional dispatch mechanism, paving the way for more efficient GPU computations.
📝 Walkthrough
Adds CuTe-DSL normalization kernels (RMSNorm, QK RMSNorm, quantized and fused variants, LayerNorm), norm utilities, and runtime detection for CuTe-DSL; conditions imports/exports on CuTe-DSL availability and adds a CUDA-vs-CuTe dispatch path in the norm package.
Sequence Diagram(s)
sequenceDiagram
participant App as Application
participant NormAPI as flashinfer.norm
participant Dispatcher as Dispatcher
participant CUDAJIT as CUDA JIT (gen_norm_module)
participant CuTeDSL as CuTe-DSL Path
participant Compiled as Compiled Kernel (TVM/ptx)
participant GPU as GPU Device
App->>NormAPI: call rmsnorm(...)
NormAPI->>Dispatcher: check FLASHINFER_USE_CUDA_NORM / is_cute_dsl_available()
alt CUDA JIT selected
Dispatcher->>CUDAJIT: request/jit module
CUDAJIT->>Compiled: produce kernel
Compiled->>GPU: execute kernel
else CuTe-DSL selected
Dispatcher->>CuTeDSL: request compiled CuTe kernel
CuTeDSL->>Compiled: produce kernel (TVM-FFI)
Compiled->>GPU: execute kernel
end
GPU-->>Compiled: result
Compiled-->>NormAPI: output tensor
NormAPI-->>App: return
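A minimal, self-contained sketch of the backend selection shown in the diagram above; the flag semantics (the string "1") and the fallback order are assumptions here, not necessarily the exact logic in flashinfer/norm.py:

import importlib.util
import os

def is_cute_dsl_available() -> bool:
    # Same availability probe as the helper quoted at the end of this thread.
    return importlib.util.find_spec("cutlass") is not None

def _select_norm_backend() -> str:
    # FLASHINFER_USE_CUDA_NORM forces the CUDA JIT path; exact flag parsing is assumed.
    if os.environ.get("FLASHINFER_USE_CUDA_NORM", "0") == "1":
        return "cuda-jit"
    return "cute-dsl" if is_cute_dsl_available() else "cuda-jit"

print(_select_norm_backend())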
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
/bot run
Code Review
This pull request is a significant refactoring effort, moving normalization kernels from a custom CUDA JIT implementation to the CuTe-DSL. This is a commendable step towards improving performance and maintainability. The new flashinfer/cute_dsl/norm.py file is extensive and well-structured. My review has identified a few critical and high-severity issues that need to be addressed, including a bug in the FP8 quantization logic, incorrect API parameter naming, and inefficient shared memory usage. Once these issues are resolved, this will be a solid improvement.
flashinfer/cute_dsl/norm.py
Outdated
.reg .b16 fp8_pair;
.reg .f32 zero;
mov.f32 zero, 0f00000000;
cvt.rn.satfinite.e4m3x2.f32 fp8_pair, zero, $0;
There is a bug in the PTX inline assembly. The cvt.rn.satfinite.e4m3x2.f32 instruction converts the second source operand and stores it in the upper half of the destination register. The st.global.b8 instruction then stores the lower 8 bits of the register. As written, this will store the converted zero value, not the intended val ($0).
To fix this, you should swap the source operands in the cvt instruction to place the converted value in the lower half of the fp8_pair register.
- cvt.rn.satfinite.e4m3x2.f32 fp8_pair, zero, $0;
+ cvt.rn.satfinite.e4m3x2.f32 fp8_pair, $0, zero;
flashinfer/cute_dsl/norm.py
Outdated
self.cols_per_tile_f32 * 4 * 2
+ self.cols_per_tile * elem_bytes * 2
+ 2 * self.num_warps * 4
The shared memory calculation for LayerNormKernel includes space for gamma/beta in the input dtype, but these shared memory tiles (sGamma, sBeta) are allocated and partitioned but never actually used in the kernel. The kernel reads gamma and beta values directly from the float32 shared memory tiles (sGamma_f32, sBeta_f32).
This wastes a significant amount of shared memory, which can negatively impact performance by reducing occupancy.
You should remove the allocation of sGamma and sBeta (lines 1483-1492) and their partitioning (lines 1565-1566) in the kernel method, and update this shared memory size calculation.
Suggested change:
  self.cols_per_tile_f32 * 4 * 2
- + self.cols_per_tile * elem_bytes * 2
  + 2 * self.num_warps * 4
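As a rough illustration of the savings (hypothetical sizes, not taken from the PR): with cols_per_tile == cols_per_tile_f32 == 4096, a 2-byte input dtype, and 4 warps, dropping the middle term saves about 16 KiB of shared memory per block:

cols_per_tile_f32 = cols_per_tile = 4096   # hypothetical tile width
elem_bytes, num_warps = 2, 4               # fp16 input, 4 warps

before = cols_per_tile_f32 * 4 * 2 + cols_per_tile * elem_bytes * 2 + 2 * num_warps * 4
after = cols_per_tile_f32 * 4 * 2 + 2 * num_warps * 4
print(before, after, before - after)       # 49184 32800 16384 -> ~16 KiB saved per block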
flashinfer/cute_dsl/norm.py
Outdated
def tensor_api(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    B: int,
    N: int,
    eps: float,
    num_blocks: int,
) -> None:
    compiled_kernel(
        input,
        weight,
        output,
        Int32(B),
        Int32(N),
        Float32(eps),
        Int32(num_blocks),
    )
The enable_pdl parameter is not being passed to the compiled kernel. The qk_rmsnorm_cute function accepts enable_pdl, but it's lost because the tensor_api wrapper doesn't accept it and pass it to the compiled_kernel call.
This is a bug that prevents Programmatic Dependent Launch from being used with this kernel. You should update tensor_api to accept enable_pdl and pass it through. You'll also need to update the call to kernel in qk_rmsnorm_cute (line 2087) to pass this new argument.
Suggested change:
  def tensor_api(
      input: torch.Tensor,
      weight: torch.Tensor,
      output: torch.Tensor,
      B: int,
      N: int,
      eps: float,
+     enable_pdl: bool,
      num_blocks: int,
  ) -> None:
      compiled_kernel(
          input,
          weight,
          output,
          Int32(B),
          Int32(N),
          Float32(eps),
+         enable_pdl,
          Int32(num_blocks),
      )
flashinfer/cute_dsl/norm.py
Outdated
def predicate_k_3d(tXcX: cute.Tensor, limit: int) -> cute.Tensor:
    """Create predicate tensor for bounds checking (3D tensors).

    For 3D tensors after local_tile, the last coordinate [2] is the head_dim dimension.
    """
    tXpX = cute.make_rmem_tensor(
        cute.make_layout(
            (
                cute.size(tXcX, mode=[0, 1]),
                cute.size(tXcX, mode=[1]),
                cute.size(tXcX, mode=[2]),
            ),
            stride=(cute.size(tXcX, mode=[2]), 0, 1),
        ),
        cutlass.Boolean,
    )
    for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
        for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
            # For 3D tensor, coordinate[2] is the head_dim index
            tXpX[rest_v, 0, rest_k] = cute.elem_less(
                tXcX[(0, rest_v), 0, rest_k][2], limit
            )
    return tXpX
flashinfer/cute_dsl/norm.py
Outdated
idX = cute.make_identity_tensor(mX.shape)
gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
cute.local_tile(mY, tiler_mn, (bidx, 0))
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/norm.py`:
- Around line 2044-2088: The qk_rmsnorm_cute function accepts enable_pdl but
never forwards it to the kernel compilation (kernel created via
_get_compiled_qk_rmsnorm_kernel uses a hardcoded value); update qk_rmsnorm_cute
to pass the enable_pdl flag into _get_compiled_qk_rmsnorm_kernel (or else remove
enable_pdl from qk_rmsnorm_cute's signature) so the compiled kernel respects PDL
support — locate the _get_compiled_qk_rmsnorm_kernel call in qk_rmsnorm_cute and
change its arguments to include enable_pdl (and ensure any downstream kernel
invocation/signature matches this added parameter).
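A sketch of the change described above, assuming _get_compiled_qk_rmsnorm_kernel simply takes the flag as an extra (cache-key) argument; the real compile helper would need a matching signature change:

- kernel = _get_compiled_qk_rmsnorm_kernel(
-     dtype_str, head_dim, weight_bias, num_warps
- )
+ kernel = _get_compiled_qk_rmsnorm_kernel(
+     dtype_str, head_dim, weight_bias, num_warps, enable_pdl
+ )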
🧹 Nitpick comments (8)
flashinfer/cute_dsl/norm.py (8)
858-862: Dead code: cute.local_tile(mY, ...) result is unused.
The result of cute.local_tile(mY, tiler_mn, (bidx, 0)) at line 860 is not assigned to a variable. The FP8 output is stored using PTX scalar stores later (lines 920-922), which access mY directly with computed offsets. This call appears to be unnecessary.
♻️ Proposed fix
  idX = cute.make_identity_tensor(mX.shape)
  gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
- cute.local_tile(mY, tiler_mn, (bidx, 0))
  cX = cute.local_tile(idX, tiler_mn, (bidx, 0))

1231-1236: Same issue: cute.local_tile(mY, ...) result is unused.
Same dead code pattern as in RMSNormQuantKernel.
♻️ Proposed fix
  idX = cute.make_identity_tensor(mX.shape)
- cute.local_tile(mY, tiler_mn, (bidx, 0))
  gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
  gR = cute.local_tile(mR, tiler_mn, (bidx, 0))
  cX = cute.local_tile(idX, tiler_mn, (bidx, 0))

1564-1567: Dead code: partition_D results are unused.
The results of thr_copy_load.partition_D(sGamma) and thr_copy_load.partition_D(sBeta) are not assigned to variables. Gamma/beta are loaded directly from sGamma_f32 / sBeta_f32 at lines 1634-1635.
♻️ Proposed fix
- # Partitions for gamma/beta (input dtype)
- thr_copy_load.partition_D(sGamma)
- thr_copy_load.partition_D(sBeta)
-
  # Register fragments - initialize to zero for proper handling of out-of-bounds threads

2016-2042: Missing @flashinfer_api decorator on public API function.
The rmsnorm_cute function is exported in __all__ and thus part of the public API, but it lacks the @flashinfer_api decorator required by coding guidelines.
Additionally, the enable_pdl parameter is accepted but completely ignored. The kernel is compiled with a hardcoded False value at line 1764. This breaks the API contract with callers who expect PDL to be honored.
♻️ Proposed fix for decorator
+ from ..api_logging import flashinfer_api
+
+ @flashinfer_api
  def rmsnorm_cute(
      input: torch.Tensor,
As per coding guidelines: "Use @flashinfer_api decorator for debugging API calls."

2090-2113: Same issues: missing @flashinfer_api decorator and unused enable_pdl.
rmsnorm_quant_cute has the same issues as rmsnorm_cute.

2116-2135: Same issues: missing @flashinfer_api decorator and unused enable_pdl.
fused_add_rmsnorm_cute has the same issues.

2138-2170: Same issues: missing @flashinfer_api decorator and unused enable_pdl.
fused_add_rmsnorm_quant_cute has the same issues.

2173-2192: Missing @flashinfer_api decorator.
layernorm_cute is missing the @flashinfer_api decorator. Note that this function doesn't have an enable_pdl parameter, which is consistent since it doesn't expose PDL functionality.
flashinfer/cute_dsl/norm.py
Outdated
def qk_rmsnorm_cute(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim].

    Supports arbitrary stride - no need to call contiguous().
    Each warp processes one (batch, head) pair independently using warp-only reduction.

    Args:
        input: Input tensor of shape [batch_size, num_heads, head_dim].
            Last dimension must be contiguous (stride[-1] == 1).
        weight: Weight tensor of shape [head_dim].
        output: Output tensor (same shape as input).
        eps: Small constant for numerical stability.
        weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma).
        enable_pdl: Enable Programmatic Dependent Launch for SM90+ GPUs.
    """
    assert input.dim() == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]"

    batch_size, num_heads, head_dim = input.shape
    M = batch_size * num_heads

    # Kernel configuration
    num_warps = 4

    # Calculate grid size based on SM count and estimated occupancy
    num_sms = get_num_sm(input.device)
    blocks_per_sm = 16  # Theoretical max for 128-thread blocks
    max_blocks = num_sms * blocks_per_sm
    needed_blocks = (M + num_warps - 1) // num_warps
    num_blocks = min(max_blocks, needed_blocks)

    dtype_str = _torch_dtype_to_str(input.dtype)
    kernel = _get_compiled_qk_rmsnorm_kernel(
        dtype_str, head_dim, weight_bias, num_warps
    )

    # Pass 3D tensors directly - kernel handles arbitrary stride
    kernel(input, weight, output, batch_size, num_heads, eps, num_blocks)
enable_pdl parameter is accepted but not effectively used.
The qk_rmsnorm_cute function accepts enable_pdl but the compiled kernel at line 1764 uses a hardcoded enable_pdl=False. The kernel supports PDL (lines 617-618, 747-748), but the parameter isn't being passed through during compilation.
🔧 Proposed fix to support PDL
To properly support PDL, the compilation would need to be done at runtime with the actual enable_pdl value, or the parameter should be removed from the API signature if PDL is intentionally disabled for CuTe-DSL kernels.
If PDL is intentionally disabled, consider removing the parameter:
 def qk_rmsnorm_cute(
     input: torch.Tensor,
     weight: torch.Tensor,
     output: torch.Tensor,
     eps: float = 1e-6,
     weight_bias: float = 0.0,
-    enable_pdl: bool = False,
 ) -> None:

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested code (with enable_pdl removed):

def qk_rmsnorm_cute(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
) -> None:
    """CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim].

    Supports arbitrary stride - no need to call contiguous().
    Each warp processes one (batch, head) pair independently using warp-only reduction.

    Args:
        input: Input tensor of shape [batch_size, num_heads, head_dim].
            Last dimension must be contiguous (stride[-1] == 1).
        weight: Weight tensor of shape [head_dim].
        output: Output tensor (same shape as input).
        eps: Small constant for numerical stability.
        weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma).
    """
    assert input.dim() == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]"

    batch_size, num_heads, head_dim = input.shape
    M = batch_size * num_heads

    # Kernel configuration
    num_warps = 4

    # Calculate grid size based on SM count and estimated occupancy
    num_sms = get_num_sm(input.device)
    blocks_per_sm = 16  # Theoretical max for 128-thread blocks
    max_blocks = num_sms * blocks_per_sm
    needed_blocks = (M + num_warps - 1) // num_warps
    num_blocks = min(max_blocks, needed_blocks)

    dtype_str = _torch_dtype_to_str(input.dtype)
    kernel = _get_compiled_qk_rmsnorm_kernel(
        dtype_str, head_dim, weight_bias, num_warps
    )

    # Pass 3D tensors directly - kernel handles arbitrary stride
    kernel(input, weight, output, batch_size, num_heads, eps, num_blocks)
🧰 Tools
🪛 Ruff (0.14.14)
2050-2050: Unused function argument: enable_pdl
(ARG001)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 2044 - 2088, The qk_rmsnorm_cute
function accepts enable_pdl but never forwards it to the kernel compilation
(kernel created via _get_compiled_qk_rmsnorm_kernel uses a hardcoded value);
update qk_rmsnorm_cute to pass the enable_pdl flag into
_get_compiled_qk_rmsnorm_kernel (or else remove enable_pdl from
qk_rmsnorm_cute's signature) so the compiled kernel respects PDL support —
locate the _get_compiled_qk_rmsnorm_kernel call in qk_rmsnorm_cute and change
its arguments to include enable_pdl (and ensure any downstream kernel
invocation/signature matches this added parameter).
[FAILED] Pipeline #42732703: 1/20 passed
/bot run
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/norm.py`:
- Around line 1875-1882: The parameter enable_pdl in rmsnorm_cute is unused and
triggers ARG001; explicitly mark it as intentionally unused by adding a no-op
assignment (e.g., _ = enable_pdl) or a targeted noqa comment inside rmsnorm_cute
to show API-parity intent, and apply the same change to the other wrapper
functions mentioned in the review so each unused enable_pdl is acknowledged
rather than left unused.
- Around line 1875-2051: The public CuTe-DSL wrapper functions (rmsnorm_cute,
qk_rmsnorm_cute, rmsnorm_quant_cute, fused_add_rmsnorm_cute,
fused_add_rmsnorm_quant_cute, layernorm_cute) need the `@flashinfer_api` decorator
added and the decorator imported from the project’s standard utilities; add a
single import for flashinfer_api near other imports and prepend `@flashinfer_api`
above each of these function definitions so all public entry points are traced
for API-call logging (keep existing signatures and bodies unchanged).
- Around line 371-379: Rename the unused kernel parameter M to _M in the kernel
signatures to silence Ruff ARG002 (e.g., change the argument name in
cute_dsl.norm.LayerNormKernel.kernel and the other flagged kernels
RMSNormQuantKernel.kernel, FusedAddRMSNormKernel.kernel,
FusedAddRMSNormQuantKernel.kernel); update the parameter name only in the
function signature (or alternatively add a targeted "# noqa: ARG002" comment) so
the intent is clear and linters stop reporting the unused argument.
- Around line 1188-1200: The scalar FP8 store computes out_offset assuming
row-major contiguous layout (out_offset = bidx * H + idx), which fails for
non-contiguous mY; update the store in the block that calls
cvt_and_store_f32_to_e4m3/get_ptr_as_int64 to compute the correct linear offset
using the output tensor's stride (e.g., out_offset = bidx * mY.stride[0] + idx)
or mirror the non-quantized kernels by using CuTe's local_tile/partition_D logic
(as in FusedAddRMSNormKernel) to derive the physical address; ensure you
reference mY.stride and preserve idx calculation so cvt_and_store_f32_to_e4m3
receives the correct out_ptr for any layout.
- Around line 835-847: The FP8 store currently computes out_offset as bidx * H +
idx which assumes a contiguous row stride; update the offset calculation to use
the actual row stride (sym_row_stride_y) so stores respect arbitrary output
tensor strides—replace the use of H in out_offset with sym_row_stride_y (i.e.,
compute out_offset = bidx * sym_row_stride_y + idx) in the block that calls
get_ptr_as_int64(mY, Int32(out_offset)) and cvt_and_store_f32_to_e4m3; ensure
any alternative tiled layout approach mirrors how inputs are handled so the
store remains stride-aware.
flashinfer/cute_dsl/norm.py
Outdated
def kernel(
    self,
    mX: cute.Tensor,
    mW: cute.Tensor,
    mY: cute.Tensor,
    M: Int32,
    eps: Float32,
    tv_layout: cute.Layout,
    tiler_mn: cute.Shape,
Silence unused M kernel args to keep Ruff clean.
Ruff reports ARG002 for M in kernel signatures. Since M is not used inside kernels, rename it to _M (or add a targeted # noqa: ARG002) to document intent and satisfy lint. Apply the same pattern to the other kernel methods flagged by Ruff (RMSNormQuantKernel.kernel, FusedAddRMSNormKernel.kernel, FusedAddRMSNormQuantKernel.kernel, LayerNormKernel.kernel).
♻️ Example fix (apply similarly to other kernels)
-    M: Int32,
+    _M: Int32,
🧰 Tools
🪛 Ruff (0.14.14)
376-376: Unused method argument: M
(ARG002)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 371 - 379, Rename the unused kernel
parameter M to _M in the kernel signatures to silence Ruff ARG002 (e.g., change
the argument name in cute_dsl.norm.LayerNormKernel.kernel and the other flagged
kernels RMSNormQuantKernel.kernel, FusedAddRMSNormKernel.kernel,
FusedAddRMSNormQuantKernel.kernel); update the parameter name only in the
function signature (or alternatively add a targeted "# noqa: ARG002" comment) so
the intent is clear and linters stop reporting the unused argument.
flashinfer/cute_dsl/norm.py
Outdated
col_offset = tidx * vec_size
for v in cutlass.range_constexpr(num_vec_blocks):
    for e in cutlass.range_constexpr(vec_size):
        idx = col_offset + v * threads_per_row * vec_size + e
        if idx < H:
            # Clamp and convert - use flat index for register tensor
            flat_idx = v * vec_size + e
            clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX))
            clamped = min(clamped, Float32(FLOAT8_E4M3_MAX))
            # Use PTX to convert and store FP8 byte
            out_offset = bidx * H + idx
            out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
            cvt_and_store_f32_to_e4m3(clamped, out_ptr)
Use stride-aware offset calculation for FP8 output store.
Line 96: out_offset = bidx * H + idx assumes contiguous row stride equal to H, which breaks for arbitrary-stride outputs declared in the tensor layout (stride = sym_row_stride_y). Replace with out_offset = bidx * sym_row_stride_y + idx, or apply consistent tiled layout to the output tensor (as done for input) to automatically respect strides.
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 835 - 847, The FP8 store currently
computes out_offset as bidx * H + idx which assumes a contiguous row stride;
update the offset calculation to use the actual row stride (sym_row_stride_y) so
stores respect arbitrary output tensor strides—replace the use of H in
out_offset with sym_row_stride_y (i.e., compute out_offset = bidx *
sym_row_stride_y + idx) in the block that calls get_ptr_as_int64(mY,
Int32(out_offset)) and cvt_and_store_f32_to_e4m3; ensure any alternative tiled
layout approach mirrors how inputs are handled so the store remains
stride-aware.
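A host-side toy example of the failure mode (plain PyTorch, not kernel code); sym_row_stride_y in the kernel plays the role of y.stride(0) here:

import torch

buf = torch.zeros(4, 16)
y = buf[:, :8]                        # H == 8, but y.stride(0) == 16
row, col = 2, 3
flat_wrong = row * y.shape[1] + col   # 19 -> lands in row 1 of the backing buffer
flat_right = row * y.stride(0) + col  # 35 -> the element actually at y[2, 3]
print(flat_wrong, flat_right, y.stride(0))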
flashinfer/cute_dsl/norm.py
Outdated
col_offset = tidx * vec_size
for v in cutlass.range_constexpr(num_vec_blocks):
    for e in cutlass.range_constexpr(vec_size):
        idx = col_offset + v * threads_per_row * vec_size + e
        if idx < H:
            # Clamp and convert - use flat index for register tensor
            flat_idx = v * vec_size + e
            clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX))
            clamped = min(clamped, Float32(FLOAT8_E4M3_MAX))
            # Use PTX to convert and store FP8 byte
            out_offset = bidx * H + idx
            out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
            cvt_and_store_f32_to_e4m3(clamped, out_ptr)
The FP8 scalar store path assumes row-major contiguous output layout.
The hardcoded out_offset = bidx * H + idx breaks non-contiguous outputs. Use CuTe's local_tile and partition_D like the non-quantized kernels (e.g., FusedAddRMSNormKernel), or query the output tensor's stride and compute out_offset = bidx * mY.stride[0] + idx.
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 1188 - 1200, The scalar FP8 store
computes out_offset assuming row-major contiguous layout (out_offset = bidx * H
+ idx), which fails for non-contiguous mY; update the store in the block that
calls cvt_and_store_f32_to_e4m3/get_ptr_as_int64 to compute the correct linear
offset using the output tensor's stride (e.g., out_offset = bidx * mY.stride[0]
+ idx) or mirror the non-quantized kernels by using CuTe's
local_tile/partition_D logic (as in FusedAddRMSNormKernel) to derive the
physical address; ensure you reference mY.stride and preserve idx calculation so
cvt_and_store_f32_to_e4m3 receives the correct out_ptr for any layout.
flashinfer/cute_dsl/norm.py
Outdated
def rmsnorm_cute(
    input: torch.Tensor,
    weight: torch.Tensor,
    out: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
enable_pdl is unused in most wrappers.
Ruff flags ARG001 for these functions. If the parameter is only for API parity, make the intent explicit (e.g., _ = enable_pdl or a targeted # noqa: ARG001). Otherwise, plumb it through once those kernels support PDL.
✅ Example (apply similarly to other wrappers)
def rmsnorm_cute(
input: torch.Tensor,
weight: torch.Tensor,
out: torch.Tensor,
eps: float = 1e-6,
weight_bias: float = 0.0,
enable_pdl: bool = False,
) -> None:
+ _ = enable_pdl  # reserved for future PDL support
  """CuTe DSL RMSNorm implementation.

Also applies to: 1949-1957, 1975-1982, 1997-2006
🧰 Tools
🪛 Ruff (0.14.14)
1881-1881: Unused function argument: enable_pdl
(ARG001)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 1875 - 1882, The parameter
enable_pdl in rmsnorm_cute is unused and triggers ARG001; explicitly mark it as
intentionally unused by adding a no-op assignment (e.g., _ = enable_pdl) or a
targeted noqa comment inside rmsnorm_cute to show API-parity intent, and apply
the same change to the other wrapper functions mentioned in the review so each
unused enable_pdl is acknowledged rather than left unused.
flashinfer/cute_dsl/norm.py
Outdated
def rmsnorm_cute(
    input: torch.Tensor,
    weight: torch.Tensor,
    out: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """CuTe DSL RMSNorm implementation.

    Supports arbitrary stride - no need to call contiguous().
    Last dimension must be contiguous (stride[-1] == 1).
    """
    H = input.shape[-1]
    if input.dim() == 3:
        M = input.shape[0] * input.shape[1]
        input_2d = input.view(M, H)
        out_2d = out.view(M, H)
    else:
        M = input.shape[0]
        input_2d = input
        out_2d = out

    dtype_str = _torch_dtype_to_str(input.dtype)
    kernel = _get_compiled_rmsnorm_kernel(dtype_str, H, weight_bias)
    kernel(input_2d, weight, out_2d, M, eps)


def qk_rmsnorm_cute(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim].

    Supports arbitrary stride - no need to call contiguous().
    Each warp processes one (batch, head) pair independently using warp-only reduction.

    Args:
        input: Input tensor of shape [batch_size, num_heads, head_dim].
            Last dimension must be contiguous (stride[-1] == 1).
        weight: Weight tensor of shape [head_dim].
        output: Output tensor (same shape as input).
        eps: Small constant for numerical stability.
        weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma).
        enable_pdl: Enable Programmatic Dependent Launch for SM90+ GPUs.
    """
    assert input.dim() == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]"

    batch_size, num_heads, head_dim = input.shape
    M = batch_size * num_heads

    # Kernel configuration
    num_warps = 4

    # Calculate grid size based on SM count and estimated occupancy
    num_sms = get_num_sm(input.device)
    blocks_per_sm = 16  # Theoretical max for 128-thread blocks
    max_blocks = num_sms * blocks_per_sm
    needed_blocks = (M + num_warps - 1) // num_warps
    num_blocks = min(max_blocks, needed_blocks)

    dtype_str = _torch_dtype_to_str(input.dtype)
    kernel = _get_compiled_qk_rmsnorm_kernel(
        dtype_str, head_dim, weight_bias, num_warps
    )

    # Pass 3D tensors directly - kernel handles arbitrary stride
    kernel(input, weight, output, batch_size, num_heads, eps, num_blocks)


def rmsnorm_quant_cute(
    out: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: float,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """CuTe DSL RMSNorm + FP8 quantization implementation.

    Supports arbitrary stride - no need to call contiguous().
    Last dimension must be contiguous (stride[-1] == 1).
    """
    H = input.shape[-1]
    M = input.shape[0]

    dtype_str = _torch_dtype_to_str(input.dtype)
    out_dtype_str = _torch_dtype_to_str(out.dtype)
    kernel = _get_compiled_rmsnorm_quant_kernel(
        dtype_str, out_dtype_str, H, weight_bias
    )
    kernel(out, input, weight, M, scale, eps)


def fused_add_rmsnorm_cute(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """CuTe DSL Fused Add + RMSNorm implementation.

    Supports arbitrary stride - no need to call contiguous().
    Last dimension must be contiguous (stride[-1] == 1).
    """
    H = input.shape[-1]
    M = input.shape[0]

    dtype_str = _torch_dtype_to_str(input.dtype)
    kernel = _get_compiled_fused_add_rmsnorm_kernel(dtype_str, H, weight_bias)
    kernel(input, residual, weight, M, eps)


def fused_add_rmsnorm_quant_cute(
    out: torch.Tensor,
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    scale: float,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """CuTe DSL Fused Add + RMSNorm + FP8 quantization implementation.

    Supports arbitrary stride - no need to call contiguous().
    Last dimension must be contiguous (stride[-1] == 1).
    """
    H = input.shape[-1]
    M = input.shape[0]

    dtype_str = _torch_dtype_to_str(input.dtype)
    out_dtype_str = _torch_dtype_to_str(out.dtype)
    kernel = _get_compiled_fused_add_rmsnorm_quant_kernel(
        dtype_str, out_dtype_str, H, weight_bias
    )
    kernel(
        out,
        input,
        residual,
        weight,
        M,
        scale,
        eps,
    )


def layernorm_cute(
    out: torch.Tensor,
    input: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> None:
    """CuTe DSL LayerNorm implementation.

    Supports arbitrary stride - no need to call contiguous().
    Last dimension must be contiguous (stride[-1] == 1).
    """
    H = input.shape[-1]
    M = input.shape[0]

    dtype_str = _torch_dtype_to_str(input.dtype)
    gamma_dtype_str = _torch_dtype_to_str(gamma.dtype)
    kernel = _get_compiled_layernorm_kernel(dtype_str, gamma_dtype_str, H)
    kernel(out, input, gamma, beta, M, eps)
🛠️ Refactor suggestion | 🟠 Major
Add @flashinfer_api on public CuTe-DSL API wrappers.
The public Python entry points here (e.g., rmsnorm_cute, qk_rmsnorm_cute, rmsnorm_quant_cute, fused_add_rmsnorm_cute, fused_add_rmsnorm_quant_cute, layernorm_cute) should be decorated for API-call tracing. Please apply @flashinfer_api (and import it from the standard location used elsewhere in the repo).
As per coding guidelines: Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats).
🧰 Tools
🪛 Ruff (0.14.14)
1881-1881: Unused function argument: enable_pdl
(ARG001)
1909-1909: Unused function argument: enable_pdl
(ARG001)
1956-1956: Unused function argument: enable_pdl
(ARG001)
1981-1981: Unused function argument: enable_pdl
(ARG001)
2005-2005: Unused function argument: enable_pdl
(ARG001)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 1875 - 2051, The public CuTe-DSL
wrapper functions (rmsnorm_cute, qk_rmsnorm_cute, rmsnorm_quant_cute,
fused_add_rmsnorm_cute, fused_add_rmsnorm_quant_cute, layernorm_cute) need the
`@flashinfer_api` decorator added and the decorator imported from the project’s
standard utilities; add a single import for flashinfer_api near other imports
and prepend `@flashinfer_api` above each of these function definitions so all
public entry points are traced for API-call logging (keep existing signatures
and bodies unchanged).
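A sketch of the decorator applied to one wrapper, reusing the import path shown in the earlier proposed fix; the exact module that exports flashinfer_api is an assumption here:

import torch
from ..api_logging import flashinfer_api  # assumed location of the decorator

@flashinfer_api
def rmsnorm_cute(
    input: torch.Tensor,
    weight: torch.Tensor,
    out: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    ...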
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/norm.py`:
- Around line 468-471: The unconditional imports of rmsnorm_fp4quant and
add_rmsnorm_fp4quant from .cute_dsl can raise when cutlass/cute_dsl isn't
available; wrap these imports in a try/except (or check
cute_dsl.is_cute_dsl_available()) and fall back to setting rmsnorm_fp4quant =
None and add_rmsnorm_fp4quant = None so the module can still import and the
existing _USE_CUDA_NORM/CUDA fallback logic can decide at runtime; update any
usage sites to handle None accordingly.
flashinfer/norm.py
Outdated
  # CuTe-DSL fused RMSNorm + FP4 Quantization kernels
- # These require CuTe-DSL to be available and SM100+ (Blackwell) GPUs
- try:
-     from .cute_dsl import rmsnorm_fp4quant, add_rmsnorm_fp4quant
- except ImportError:
-     # CuTe-DSL not available
-     rmsnorm_fp4quant = None  # type: ignore[misc,assignment]
-     add_rmsnorm_fp4quant = None  # type: ignore[misc,assignment]
+ # These require SM100+ (Blackwell) GPUs
+ from .cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant
+ from .cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
Wrap FP4 imports in try/except to prevent import failures on systems without cutlass.
The rmsnorm_fp4quant and add_rmsnorm_fp4quant imports are unconditional, but cute_dsl/__init__.py only exports them when is_cute_dsl_available() returns True (i.e., when cutlass is installed). Without error handling, importing flashinfer.norm will fail on systems without cutlass, even though the module provides CUDA-based fallbacks via _USE_CUDA_NORM. This breaks backward compatibility.
Suggested pattern for graceful degradation
 # CuTe-DSL fused RMSNorm + FP4 Quantization kernels
 # These require SM100+ (Blackwell) GPUs
-from .cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant
-from .cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
+try:
+    from .cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant
+    from .cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
+except ImportError:
+    rmsnorm_fp4quant = None
+    add_rmsnorm_fp4quant = None

🤖 Prompt for AI Agents
In `@flashinfer/norm.py` around lines 468 - 471, The unconditional imports of
rmsnorm_fp4quant and add_rmsnorm_fp4quant from .cute_dsl can raise when
cutlass/cute_dsl isn't available; wrap these imports in a try/except (or check
cute_dsl.is_cute_dsl_available()) and fall back to setting rmsnorm_fp4quant =
None and add_rmsnorm_fp4quant = None so the module can still import and the
existing _USE_CUDA_NORM/CUDA fallback logic can decide at runtime; update any
usage sites to handle None accordingly.
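A hypothetical call-site guard matching the None fallback suggested above (the error message text is illustrative only):

if rmsnorm_fp4quant is None:
    raise RuntimeError(
        "rmsnorm_fp4quant requires the nvidia-cutlass-dsl (CuTe-DSL) package"
    )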
wanna ask you @yzh119 about the reason we put these things all in cute_dsl
We should categorize kernels by functionality, not by source. All kernels inside … For this specific PR, let me make …
cc @kahyunnam
[CANCELED] Pipeline #42752005: canceled
/bot run
@yzh119 , and …
[FAILED] Pipeline #42861129: 3/20 passed |
Actionable comments posted: 7
🤖 Fix all issues with AI agents
In `@flashinfer/__init__.py`:
- Around line 103-107: The package-level exports for the quantized norm variants
are missing: import rmsnorm_fp4quant and add_rmsnorm_fp4quant from
flashinfer.norm (already attempted in the try block) and then assign them to the
public names used elsewhere (e.g., expose rmsnorm_fp4quant as rmsnorm_quant and
add_rmsnorm_fp4quant as fused_add_rmsnorm_quant) so flashinfer.rmsnorm_quant and
flashinfer.fused_add_rmsnorm_quant resolve; update the try block in
flashinfer.__init__.py to perform these assignments (keep the existing
ImportError/AttributeError handling).
In `@flashinfer/cute_dsl/__init__.py`:
- Around line 21-28: The is_cute_dsl_available() function should be decorated
with `@flashinfer_api` and `@functools.cache` to enable API logging and cache the
module discovery; update the top imports to import functools (or
functools.cache) and import the flashinfer_api decorator (or from its module) so
you can apply `@flashinfer_api` and `@functools.cache` directly above def
is_cute_dsl_available to avoid repeated find_spec calls and ensure API logging.
In `@flashinfer/norm/kernels/fused_add_rmsnorm.py`:
- Around line 379-397: The output pointer arithmetic currently assumes
contiguous rows ("out_offset = bidx * H + idx") which breaks when the output
tensor has a non-contiguous row stride; replace the hardcoded linear computation
with one that uses the actual row stride symbol (sym_row_stride_y) so stores
honor arbitrary strides: compute out_offset = bidx * sym_row_stride_y + idx
(casting sym_row_stride_y to the same integer type used for offsets if
necessary), use that out_offset with get_ptr_as_int64(mY, Int32(out_offset)) (or
the appropriate Int64 cast) and then call cvt_and_store_f32_to_e4m3 as before;
update references around out_offset, mY, bidx, H, idx, sym_row_stride_y, and
tYrY_f32 accordingly.
In `@flashinfer/norm/kernels/layernorm.py`:
- Around line 90-100: The shared-memory calculation in _smem_size_in_bytes
currently includes space for sGamma/sBeta tiles in the input dtype that are
never read (the allocations/partitioning for sGamma and sBeta around the sMem
setup at the block-level, e.g. the code allocating sGamma/sBeta in the kernel
and the partitioning logic at lines ~188-198 and ~261-263), causing unnecessary
SMEM use; remove the unused sGamma/sBeta allocations/partitioning in the kernel
and update _smem_size_in_bytes to drop the "+ self.cols_per_tile * elem_bytes *
2" term so the function only accounts for the f32 gamma/beta tiles
(cols_per_tile_f32 * 4 * 2) and reduction buffers (2 * num_warps * 4).
In `@flashinfer/norm/kernels/rmsnorm.py`:
- Around line 591-603: The code assumes output is contiguous by computing
out_offset = bidx * H + idx which breaks arbitrary-stride support; change the
offset calculation to use the output tensor's row stride (the symbolic stride
declared as sym_row_stride_y) instead of H so writes go to the correct memory
for non-contiguous mY; locate the write site (out_offset, get_ptr_as_int64(mY,
...), cvt_and_store_f32_to_e4m3) and replace the bidx*H term with
bidx*row_stride (use the row stride obtained from mY/sym_row_stride_y) or
otherwise read the stride from mY before the loop, leaving the rest of the
flat/index math unchanged.
- Around line 812-819: The code currently uses input.view(M, H) and out.view(M,
H) which will raise for non-contiguous 3D inputs; replace those calls with
input.reshape(M, H) and out.reshape(M, H) (or input.contiguous().view(M, H) if
you want to force an in-place-like view) so arbitrary strides are supported as
the docstring claims, updating the input_2d/out_2d assignments in the
input.dim() == 3 branch accordingly; alternatively if you prefer not to copy,
update the docstring to state the 3D input must be contiguous.
- Around line 886-895: The rmsnorm_quant path currently assumes 2D inputs
(calculating M = input.shape[0]) which breaks for 3D tensors; modify the code
around H/M and the kernel call (the block using
_get_compiled_rmsnorm_quant_kernel and kernel(out, input, weight, M, scale,
eps)) to mirror rmsnorm_cute: if input.dim() == 3 set M = input.shape[0] *
input.shape[1] and pass flattened tensors (reshape input and out to (M, H) or
otherwise view/contiguously flatten the first two dims) before calling the
compiled kernel, otherwise keep the existing 2D behavior; ensure
weight/scale/eps usage remains compatible with the flattened shape.
🧹 Nitpick comments (2)
flashinfer/norm/kernels/rmsnorm.py (2)
804-804: Unused `enable_pdl` parameter.

The `enable_pdl` parameter is accepted but never used. The compiled kernel has PDL hardcoded to `False` at Line 710. If PDL support is planned for the future, consider documenting this as a placeholder; otherwise, remove it to avoid confusion.

♻️ Option: Add a note or remove unused param

 def rmsnorm_cute(
     input: torch.Tensor,
     weight: torch.Tensor,
     out: torch.Tensor,
     eps: float = 1e-6,
     weight_bias: float = 0.0,
-    enable_pdl: bool = False,
+    enable_pdl: bool = False,  # TODO: PDL support not yet implemented
 ) -> None:

Or remove from all three API functions if not planned.

898-911: `__all__` is not sorted (style nit).

Static analysis flagged unsorted `__all__`. The current grouping by category (classes → getters → APIs) is logical, but if the project prefers alphabetical sorting, consider applying it.

♻️ Alphabetically sorted version

 __all__ = [
-    # Kernel classes
-    "RMSNormKernel", "QKRMSNormKernel", "RMSNormQuantKernel",
-    # Compiled kernel getters
+    "RMSNormKernel",
     "_get_compiled_rmsnorm_kernel",
     "_get_compiled_qk_rmsnorm_kernel",
     "_get_compiled_rmsnorm_quant_kernel",
-    # CuTe DSL APIs
-    "rmsnorm_cute", "qk_rmsnorm_cute",
+    "rmsnorm_cute",
     "rmsnorm_quant_cute",
 ]
try:
    from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant
    from .norm import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
except (ImportError, AttributeError):
    pass  # nvidia-cutlass-dsl not installed
Expose quantized norm APIs at the package level.
Line 97-101 exports rmsnorm and fused_add_rmsnorm, but the new quantized variants (rmsnorm_quant, fused_add_rmsnorm_quant) from flashinfer.norm are still missing at the top level. Consider exporting them here so flashinfer.rmsnorm_quant works consistently.
✅ Suggested export additions
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
+from .norm import fused_add_rmsnorm_quant as fused_add_rmsnorm_quant
from .norm import layernorm as layernorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm
+from .norm import rmsnorm_quant as rmsnorm_quant

As per coding guidelines: Export new operations in flashinfer/__init__.py to make them available at package level.
🤖 Prompt for AI Agents
In `@flashinfer/__init__.py` around lines 103 - 107, The package-level exports for
the quantized norm variants are missing: import rmsnorm_fp4quant and
add_rmsnorm_fp4quant from flashinfer.norm (already attempted in the try block)
and then assign them to the public names used elsewhere (e.g., expose
rmsnorm_fp4quant as rmsnorm_quant and add_rmsnorm_fp4quant as
fused_add_rmsnorm_quant) so flashinfer.rmsnorm_quant and
flashinfer.fused_add_rmsnorm_quant resolve; update the try block in
flashinfer.__init__.py to perform these assignments (keep the existing
ImportError/AttributeError handling).
import importlib.util


# Conditionally import CuTe-DSL kernels


def is_cute_dsl_available() -> bool:
    return (
        importlib.util.find_spec("cutlass") is not None
        and importlib.util.find_spec("cutlass.cute") is not None
    )
Decorate and cache is_cute_dsl_available().
Line 24-28 defines a public API (exported in __all__) but it isn’t logged or cached. Adding @flashinfer_api and @functools.cache avoids repeated module discovery and aligns with the API logging/caching policy.
✅ Suggested update
-import importlib.util
+import functools
+import importlib.util
+
+from ..api_logging import flashinfer_api
@@
-def is_cute_dsl_available() -> bool:
+@functools.cache
+@flashinfer_api
+def is_cute_dsl_available() -> bool:
return (
importlib.util.find_spec("cutlass") is not None
and importlib.util.find_spec("cutlass.cute") is not None
     )

As per coding guidelines: Python API functions should use @functools.cache decorator for module caching to avoid recompilation; Use @flashinfer_api decorator on Python functions for API logging with crash-safe input capture before execution.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
import functools
import importlib.util

from ..api_logging import flashinfer_api


@functools.cache
@flashinfer_api
def is_cute_dsl_available() -> bool:
    return (
        importlib.util.find_spec("cutlass") is not None
        and importlib.util.find_spec("cutlass.cute") is not None
    )
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/__init__.py` around lines 21 - 28, The
is_cute_dsl_available() function should be decorated with `@flashinfer_api` and
`@functools.cache` to enable API logging and cache the module discovery; update
the top imports to import functools (or functools.cache) and import the
flashinfer_api decorator (or from its module) so you can apply `@flashinfer_api`
and `@functools.cache` directly above def is_cute_dsl_available to avoid repeated
find_spec calls and ensure API logging.
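As a side note on the caching behavior, here is a generic Python sketch (not FlashInfer code) of how functools.cache memoizes the probe so importlib.util.find_spec runs only once per module name:

import functools
import importlib.util


@functools.cache
def _has_module(name: str) -> bool:
    # find_spec is executed only on the first call for a given name;
    # subsequent calls return the memoized boolean.
    return importlib.util.find_spec(name) is not None


print(_has_module("math"), _has_module("math"))  # second call is served from the cache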
        # Phase 4: Clamp and store to FP8 output using PTX scalar stores
        # (CuTe FP8 conversion requires vectorized ops, so we use PTX for scalar stores)
        # Store y to register tensor for element-wise access
        tYrY_f32 = cute.make_rmem_tensor(tXgX.shape, Float32)
        tYrY_f32.store(y)

        col_offset = tidx * vec_size
        for v in cutlass.range_constexpr(num_vec_blocks):
            for e in cutlass.range_constexpr(vec_size):
                idx = col_offset + v * threads_per_row * vec_size + e
                if idx < H:
                    # Clamp and convert - use flat index for register tensor
                    flat_idx = v * vec_size + e
                    clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX))
                    clamped = min(clamped, Float32(FLOAT8_E4M3_MAX))
                    # Use PTX to convert and store FP8 byte
                    out_offset = bidx * H + idx
                    out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
                    cvt_and_store_f32_to_e4m3(clamped, out_ptr)
🧩 Analysis chain

Verification scripts inspected flashinfer/norm/kernels/fused_add_rmsnorm.py (the quantized kernel's output phase, its docstring, the compiled-kernel getter that declares sym_row_stride_y, and the FusedAddRMSNormQuantKernel setup), get_ptr_as_int64 in flashinfer/norm/utils.py, the FP4 counterpart in flashinfer/cute_dsl/rmsnorm_fp4quant.py, the wrapper and _USE_CUDA_NORM dispatch in flashinfer/norm/__init__.py, and the non-contiguous parametrization in tests/utils/test_norm.py, and found no stride-aware offset computation on the FP8 scalar-store path.
Output offset calculation ignores tensor stride, contradicting docstring claim.
Line 395 computes out_offset = bidx * H + idx, which assumes the output tensor has stride H in the first dimension. This contradicts the docstring (line 555) claiming "Supports arbitrary stride." For non-contiguous output tensors where stride[0] ≠ H, this produces incorrect memory addresses.
The kernel is compiled with symbolic stride sym_row_stride_y (line 464), but the scalar store phase uses a hardcoded linear offset that ignores it. Either enforce contiguous output (updating the docstring accordingly) or compute the pointer using the actual tensor stride.
Suggested fix
def fused_add_rmsnorm_quant_cute(
out: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: float,
eps: float = 1e-6,
weight_bias: float = 0.0,
enable_pdl: bool = False,
) -> None:
"""CuTe DSL Fused Add + RMSNorm + FP8 quantization implementation.
- Supports arbitrary stride - no need to call contiguous().
- Last dimension must be contiguous (stride[-1] == 1).
+ Output must be contiguous (row-major). Last dimension of inputs must be contiguous
+ (stride[-1] == 1).
"""
H = input.shape[-1]
M = input.shape[0]
+ if not out.is_contiguous():
+        raise ValueError("Output tensor must be contiguous for FP8 scalar store operations")

🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/fused_add_rmsnorm.py` around lines 379 - 397, The
output pointer arithmetic currently assumes contiguous rows ("out_offset = bidx
* H + idx") which breaks when the output tensor has a non-contiguous row stride;
replace the hardcoded linear computation with one that uses the actual row
stride symbol (sym_row_stride_y) so stores honor arbitrary strides: compute
out_offset = bidx * sym_row_stride_y + idx (casting sym_row_stride_y to the same
integer type used for offsets if necessary), use that out_offset with
get_ptr_as_int64(mY, Int32(out_offset)) (or the appropriate Int64 cast) and then
call cvt_and_store_f32_to_e4m3 as before; update references around out_offset,
mY, bidx, H, idx, sym_row_stride_y, and tYrY_f32 accordingly.
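To make the stride issue concrete, a small standalone PyTorch sketch (illustrative only, not kernel code): for a row-sliced output, element (row, col) lives at row * stride(0) + col in the underlying storage, which differs from row * H + col whenever the row stride is not H.

import torch

H = 4
base = torch.zeros(3, 2 * H)
out = base[:, :H]  # non-contiguous rows: out.stride(0) == 2 * H, not H

row, col = 1, 2
correct = row * out.stride(0) + col   # offset a stride-aware store would use
wrong = row * H + col                 # what `bidx * H + idx` computes

base.view(-1)[correct] = 1.0          # write through the true flat offset
assert out[row, col].item() == 1.0    # lands in the intended element
assert correct != wrong               # the contiguous formula points elsewhere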
    def _smem_size_in_bytes(self) -> int:
        # Shared memory for:
        # - gamma/beta f32 tiles: cols_per_tile_f32 * 4 * 2
        # - gamma/beta input dtype tiles: cols_per_tile * elem_bytes * 2
        # - reduction buffers: 2 * num_warps * 4
        elem_bytes = self.dtype.width // 8
        return (
            self.cols_per_tile_f32 * 4 * 2
            + self.cols_per_tile * elem_bytes * 2
            + 2 * self.num_warps * 4
        )
Remove unused shared-memory tiles to avoid SMEM blowups.
Lines 188-198 allocate sGamma/sBeta in the input dtype but they’re never read; _smem_size_in_bytes (Line 90-100) accounts for them, which can push SMEM usage over limits for large H and reduce occupancy. Suggest removing these allocations (and their partitioning) and updating the SMEM calculation.
🧹 Suggested cleanup
def _smem_size_in_bytes(self) -> int:
# Shared memory for:
# - gamma/beta f32 tiles: cols_per_tile_f32 * 4 * 2
- # - gamma/beta input dtype tiles: cols_per_tile * elem_bytes * 2
# - reduction buffers: 2 * num_warps * 4
- elem_bytes = self.dtype.width // 8
- return (
- self.cols_per_tile_f32 * 4 * 2
- + self.cols_per_tile * elem_bytes * 2
- + 2 * self.num_warps * 4
- )
+ return self.cols_per_tile_f32 * 4 * 2 + 2 * self.num_warps * 4
@@
- # Shared memory tiles for gamma, beta in input dtype (for matching shape with x)
- sGamma = smem.allocate_tensor(
- mX.element_type,
- cute.make_ordered_layout(tiler_mn, order=(1, 0)),
- byte_alignment=16,
- )
- sBeta = smem.allocate_tensor(
- mX.element_type,
- cute.make_ordered_layout(tiler_mn, order=(1, 0)),
- byte_alignment=16,
- )
@@
- thr_copy_load.partition_D(sGamma)
-        thr_copy_load.partition_D(sBeta)

Also applies to: 188-198, 261-263
🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/layernorm.py` around lines 90 - 100, The
shared-memory calculation in _smem_size_in_bytes currently includes space for
sGamma/sBeta tiles in the input dtype that are never read (the
allocations/partitioning for sGamma and sBeta around the sMem setup at the
block-level, e.g. the code allocating sGamma/sBeta in the kernel and the
partitioning logic at lines ~188-198 and ~261-263), causing unnecessary SMEM
use; remove the unused sGamma/sBeta allocations/partitioning in the kernel and
update _smem_size_in_bytes to drop the "+ self.cols_per_tile * elem_bytes * 2"
term so the function only accounts for the f32 gamma/beta tiles
(cols_per_tile_f32 * 4 * 2) and reduction buffers (2 * num_warps * 4).
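For a rough sense of the shared-memory savings, a back-of-the-envelope sketch (the H and num_warps values below are assumptions for illustration, not taken from the kernel config):

# Approximate SMEM accounting for one layernorm tile, mirroring _smem_size_in_bytes.
H = 8192            # assumed hidden size, cols_per_tile ~= cols_per_tile_f32 ~= H
num_warps = 8       # assumed warps per block
elem_bytes = 2      # fp16/bf16 input

f32_gamma_beta = H * 4 * 2           # kept: f32 gamma/beta tiles
unused_tiles = H * elem_bytes * 2    # removable: input-dtype gamma/beta tiles
reduction = 2 * num_warps * 4        # reduction buffers

print("current:", f32_gamma_beta + unused_tiles + reduction, "bytes")
print("trimmed:", f32_gamma_beta + reduction, "bytes")
print("saved  :", unused_tiles, "bytes")  # 32 KiB at H=8192 with fp16 inputs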
        col_offset = tidx * vec_size
        for v in cutlass.range_constexpr(num_vec_blocks):
            for e in cutlass.range_constexpr(vec_size):
                idx = col_offset + v * threads_per_row * vec_size + e
                if idx < H:
                    # Clamp and convert - use flat index for register tensor
                    flat_idx = v * vec_size + e
                    clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX))
                    clamped = min(clamped, Float32(FLOAT8_E4M3_MAX))
                    # Use PTX to convert and store FP8 byte
                    out_offset = bidx * H + idx
                    out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
                    cvt_and_store_f32_to_e4m3(clamped, out_ptr)
Output stride assumption contradicts arbitrary stride claim.
The offset calculation out_offset = bidx * H + idx assumes the output tensor is contiguous (row stride = H). However, the compiled kernel getter at Line 749 declares a symbolic row stride (sym_row_stride_y), and the API docstring at Line 883 claims arbitrary stride support.
For non-contiguous output tensors, this will write to incorrect memory locations.
🛠️ Suggested fix: Use tensor stride from mY
- out_offset = bidx * H + idx
- out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
+ # Use 2D indexing to respect output stride
+ out_ptr = get_ptr_as_int64(
+ cute.local_tile(mY, tiler_mn, (bidx, 0)),
+ Int32(idx)
+            )

Alternatively, if contiguous output is required, update the docstring and remove symbolic stride from the compiled kernel getter.
🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/rmsnorm.py` around lines 591 - 603, The code assumes
output is contiguous by computing out_offset = bidx * H + idx which breaks
arbitrary-stride support; change the offset calculation to use the output
tensor's row stride (the symbolic stride declared as sym_row_stride_y) instead
of H so writes go to the correct memory for non-contiguous mY; locate the write
site (out_offset, get_ptr_as_int64(mY, ...), cvt_and_store_f32_to_e4m3) and
replace the bidx*H term with bidx*row_stride (use the row stride obtained from
mY/sym_row_stride_y) or otherwise read the stride from mY before the loop,
leaving the rest of the flat/index math unchanged.
    if input.dim() == 3:
        M = input.shape[0] * input.shape[1]
        input_2d = input.view(M, H)
        out_2d = out.view(M, H)
    else:
        M = input.shape[0]
        input_2d = input
        out_2d = out
view() will fail for non-contiguous 3D inputs despite docstring claim.
The docstring at Line 808 states "Supports arbitrary stride," but input.view(M, H) requires the tensor to be contiguous. For a non-contiguous 3D input (e.g., from slicing or transpose), this will raise a RuntimeError.
🛠️ Suggested fix: Use reshape() or contiguous().view()
Option 1 – Use reshape() (may copy if needed):
if input.dim() == 3:
M = input.shape[0] * input.shape[1]
- input_2d = input.view(M, H)
- out_2d = out.view(M, H)
+ input_2d = input.reshape(M, H)
+        out_2d = out.reshape(M, H)

Option 2 – Update docstring to clarify the contiguity requirement for 3D inputs:
- Supports arbitrary stride - no need to call contiguous().
+ Supports arbitrary row stride for 2D inputs.
+    For 3D inputs, the tensor must be contiguous (or use qk_rmsnorm_cute).

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if input.dim() == 3:
        M = input.shape[0] * input.shape[1]
        input_2d = input.reshape(M, H)
        out_2d = out.reshape(M, H)
    else:
        M = input.shape[0]
        input_2d = input
        out_2d = out
🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/rmsnorm.py` around lines 812 - 819, The code
currently uses input.view(M, H) and out.view(M, H) which will raise for
non-contiguous 3D inputs; replace those calls with input.reshape(M, H) and
out.reshape(M, H) (or input.contiguous().view(M, H) if you want to force an
in-place-like view) so arbitrary strides are supported as the docstring claims,
updating the input_2d/out_2d assignments in the input.dim() == 3 branch
accordingly; alternatively if you prefer not to copy, update the docstring to
state the 3D input must be contiguous.
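A quick standalone PyTorch check of the failure mode (illustrative, not this file's code): view raises for a layout it cannot express without copying, while reshape falls back to a copy.

import torch

x = torch.randn(2, 3, 8).transpose(0, 1)     # non-contiguous, shape (3, 2, 8)
M, H = x.shape[0] * x.shape[1], x.shape[-1]

try:
    x.view(M, H)                              # RuntimeError: incompatible strides
except RuntimeError as err:
    print("view failed:", err)

y = x.reshape(M, H)                           # succeeds; copies when a view is impossible
assert y.shape == (M, H)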
    H = input.shape[-1]
    M = input.shape[0]

    dtype_str = _torch_dtype_to_str(input.dtype)
    out_dtype_str = _torch_dtype_to_str(out.dtype)
    kernel = _get_compiled_rmsnorm_quant_kernel(
        dtype_str, out_dtype_str, H, weight_bias
    )
    kernel(out, input, weight, M, scale, eps)
Missing 3D input handling unlike rmsnorm_cute.
rmsnorm_cute handles both 2D and 3D inputs (Lines 812-819), but rmsnorm_quant_cute only handles 2D. If a 3D tensor is passed, M = input.shape[0] will be incorrect (should be input.shape[0] * input.shape[1]), leading to silent incorrect normalization.
Consider adding consistent 3D handling or documenting that only 2D inputs are supported.
🛠️ Suggested fix: Add 3D handling
H = input.shape[-1]
- M = input.shape[0]
+ if input.dim() == 3:
+ M = input.shape[0] * input.shape[1]
+ input = input.reshape(M, H)
+ out = out.reshape(M, H)
+ else:
+ M = input.shape[0]🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/rmsnorm.py` around lines 886 - 895, The rmsnorm_quant
path currently assumes 2D inputs (calculating M = input.shape[0]) which breaks
for 3D tensors; modify the code around H/M and the kernel call (the block using
_get_compiled_rmsnorm_quant_kernel and kernel(out, input, weight, M, scale,
eps)) to mirror rmsnorm_cute: if input.dim() == 3 set M = input.shape[0] *
input.shape[1] and pass flattened tensors (reshape input and out to (M, H) or
otherwise view/contiguously flatten the first two dims) before calling the
compiled kernel, otherwise keep the existing 2D behavior; ensure
weight/scale/eps usage remains compatible with the flattened shape.
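A minimal sketch of the flattening being asked for (generic PyTorch; torch.float8_e4m3fn requires a recent PyTorch build):

import torch

x = torch.randn(4, 16, 128)                               # (B, S, H)
out = torch.empty(4, 16, 128, dtype=torch.float8_e4m3fn)  # FP8 output buffer

H = x.shape[-1]
if x.dim() == 3:
    M = x.shape[0] * x.shape[1]            # B * S rows to normalize
    x2d, out2d = x.reshape(M, H), out.reshape(M, H)
else:
    M = x.shape[0]
    x2d, out2d = x, out

assert x2d.shape == (64, 128) and out2d.shape == (64, 128)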
|
@flashinfer-bot run |
|
/bot run |
|
[FAILED] Pipeline #43133707: 8/20 passed |
bkryu
left a comment
Unit tests are now coming back as passing. Left a comment about PDL support. Perf implications are unclear.
@yzh119, if you'd like, I can approve and draft a followup PR that adds PDL support and maybe some perf tuning if I find necessary. What do you think?
        # PDL: Signal dependent kernels (SM90+ only)
        if enable_pdl:
            cute.arch.griddepcontrol_launch_dependents()
PDL seems to be added only for QKRMSNormKernel, not for the other kernel classes. The others simply seem to ignore the enable_pdl flag.
Do we want to add them in the current or a subsequent PR?
It should be fixed in fcd5c5d.
|
I found some CPU overhead that makes the host+device execution time slower than the CUDA implementation's, even though the CuTe-DSL kernel itself is faster. @cyx-6 is investigating whether we are using tvm-ffi correctly. I suggest holding off until we fix the CPU overhead.
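For whoever picks this up, one generic way to separate host-side launch overhead from device time (a benchmarking sketch, not tied to this PR's code; the comparison only indicates where the time goes, it is not an exact decomposition):

import time
import torch


def wall_time_us(fn, iters=1000):
    fn()
    torch.cuda.synchronize()                     # warm up and drain the queue
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e6   # host launch + device time


def device_time_us(fn, iters=200):
    total = 0.0
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        end.synchronize()
        total += start.elapsed_time(end) * 1e3   # microseconds for this call only
    return total / iters


# If wall_time_us(fn) is much larger than device_time_us(fn), the gap is host-side
# (Python dispatch, FFI marshalling, launch latency) rather than kernel time.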
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@flashinfer/norm/kernels/fused_add_rmsnorm.py`:
- Around line 549-605: The new CuTe kernels fused_add_rmsnorm_cute and
fused_add_rmsnorm_quant_cute currently rely on runtime JIT via
_get_compiled_fused_add_rmsnorm_kernel and
_get_compiled_fused_add_rmsnorm_quant_kernel (cached), so they won't be AOT
precompiled; update flashinfer/aot.py to register pre-generation calls that
invoke those _get_compiled_* functions for the common configurations (iterate
common H values and combinations of input/out dtypes, weight_bias values, and
enable_pdl flag) so the kernels are compiled at AOT time and packaged, e.g., add
a function that calls _get_compiled_fused_add_rmsnorm_kernel(...) and
_get_compiled_fused_add_rmsnorm_quant_kernel(...) for each desired tuple and
ensure that function is invoked by the AOT build registration.
In `@flashinfer/norm/kernels/layernorm.py`:
- Around line 242-289: The kernel currently hardcodes Float32 for gamma/beta
(e.g., copy_atom_load_f32, tiled_copy_load_f32,
tGgGamma/tGsGamma/tGgBeta/tGsBeta and the register fragments) while
LayerNormKernel.__init__ does not accept gamma_dtype, so gamma_dtype_str passed
from layernorm_cute()/_get_compiled_layernorm_kernel() never reaches the
implementation; either enforce float32 at the API boundary or (preferred) modify
LayerNormKernel.__init__ to accept a gamma_dtype parameter and thread it through
to the places that currently use Float32: replace usages of Float32 and
mX.element_type in the shared-memory copies and register fragment creation
(copy_atom_load_f32, tiled_copy_load_f32, tGgGamma/tGsGamma/tGgBeta/tGsBeta, and
tXrGamma/tXrBeta) with the provided gamma_dtype (and ensure predicate/partition
shapes use corresponding layouts), and update _get_compiled_layernorm_kernel()
and layernorm_cute() to pass the gamma_dtype into LayerNormKernel so the kernel
honors non-float32 gamma/beta types.
🧹 Nitpick comments (4)
flashinfer/norm/kernels/rmsnorm.py (1)
924-937: Consider sorting `__all__` alphabetically.

Static analysis suggests sorting `__all__`. The current logical grouping by category (kernel classes → getters → APIs) is also reasonable. This is a minor style preference.

♻️ Optional: alphabetically sorted __all__

 __all__ = [
-    # Kernel classes
-    "RMSNormKernel", "QKRMSNormKernel", "RMSNormQuantKernel",
-    # Compiled kernel getters
+    "RMSNormKernel",
     "_get_compiled_rmsnorm_kernel",
     "_get_compiled_qk_rmsnorm_kernel",
     "_get_compiled_rmsnorm_quant_kernel",
-    # CuTe DSL APIs
-    "rmsnorm_cute", "qk_rmsnorm_cute",
+    "rmsnorm_cute",
     "rmsnorm_quant_cute",
 ]

flashinfer/norm/kernels/layernorm.py (1)

153-161: Silence the unused `M` argument in the kernel signature.

♻️ Suggested tweak

 def kernel(
     self,
     mY: cute.Tensor,
     mX: cute.Tensor,
     mGamma: cute.Tensor,
     mBeta: cute.Tensor,
-    M: Int32,
+    _M: Int32,
     eps: Float32,
     enable_pdl: cutlass.Constexpr[bool],
     tv_layout: cute.Layout,
     tiler_mn: cute.Shape,

flashinfer/norm/kernels/fused_add_rmsnorm.py (2)

119-126: Silence the unused `M` argument in the kernel signature.

♻️ Suggested tweak

 def kernel(
     self,
     mX: cute.Tensor,
     mR: cute.Tensor,
     mW: cute.Tensor,
-    M: Int32,
+    _M: Int32,
     eps: Float32,
     enable_pdl: cutlass.Constexpr[bool],
     tv_layout: cute.Layout,

300-309: Silence the unused `M` argument in the quant kernel signature.

♻️ Suggested tweak

 def kernel(
     self,
     mY: cute.Tensor,
     mX: cute.Tensor,
     mR: cute.Tensor,
     mW: cute.Tensor,
-    M: Int32,
+    _M: Int32,
     scale: Float32,
     eps: Float32,
     enable_pdl: cutlass.Constexpr[bool],
def fused_add_rmsnorm_cute(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """CuTe DSL Fused Add + RMSNorm implementation.

    Supports arbitrary stride - no need to call contiguous().
    Last dimension must be contiguous (stride[-1] == 1).
    """

    H = input.shape[-1]
    M = input.shape[0]

    dtype_str = _torch_dtype_to_str(input.dtype)
    kernel = _get_compiled_fused_add_rmsnorm_kernel(
        dtype_str, H, weight_bias, enable_pdl
    )
    kernel(input, residual, weight, M, eps)


def fused_add_rmsnorm_quant_cute(
    out: torch.Tensor,
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    scale: float,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """CuTe DSL Fused Add + RMSNorm + FP8 quantization implementation.

    Supports arbitrary stride - no need to call contiguous().
    Last dimension must be contiguous (stride[-1] == 1).
    """

    H = input.shape[-1]
    M = input.shape[0]

    dtype_str = _torch_dtype_to_str(input.dtype)
    out_dtype_str = _torch_dtype_to_str(out.dtype)
    kernel = _get_compiled_fused_add_rmsnorm_quant_kernel(
        dtype_str, out_dtype_str, H, weight_bias, enable_pdl
    )
    kernel(
        out,
        input,
        residual,
        weight,
        M,
        scale,
        eps,
    )
🧩 Analysis chain

Verification scripts inspected flashinfer/aot.py, flashinfer/jit/norm.py, and the FusedAddRMSNormKernel / _get_compiled_fused_add_rmsnorm_kernel definitions in flashinfer/norm/kernels/fused_add_rmsnorm.py to check how the CuTe kernels are compiled, cached, and (not) registered for AOT builds.
CuTe kernels need AOT pre-compilation registration.
The new fused_add_rmsnorm_cute and fused_add_rmsnorm_quant_cute kernels use runtime JIT compilation via @functools.cache and are not pre-registered in the AOT build. They will be compiled on first use rather than pre-built into packages. Add AOT pre-generation variants in flashinfer/aot.py to pre-compile kernel configurations for common H (hidden dimension) values and parameter combinations (dtype, weight_bias, enable_pdl).
🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/fused_add_rmsnorm.py` around lines 549 - 605, The new
CuTe kernels fused_add_rmsnorm_cute and fused_add_rmsnorm_quant_cute currently
rely on runtime JIT via _get_compiled_fused_add_rmsnorm_kernel and
_get_compiled_fused_add_rmsnorm_quant_kernel (cached), so they won't be AOT
precompiled; update flashinfer/aot.py to register pre-generation calls that
invoke those _get_compiled_* functions for the common configurations (iterate
common H values and combinations of input/out dtypes, weight_bias values, and
enable_pdl flag) so the kernels are compiled at AOT time and packaged, e.g., add
a function that calls _get_compiled_fused_add_rmsnorm_kernel(...) and
_get_compiled_fused_add_rmsnorm_quant_kernel(...) for each desired tuple and
ensure that function is invoked by the AOT build registration.
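A sketch of what such a pre-generation hook might look like (the H values, dtype strings, and the place it would be registered in aot.py are assumptions for illustration; the getter signatures follow the wrappers above):

from itertools import product

from flashinfer.norm.kernels.fused_add_rmsnorm import (
    _get_compiled_fused_add_rmsnorm_kernel,
    _get_compiled_fused_add_rmsnorm_quant_kernel,
)

COMMON_H = (2048, 4096, 8192)        # assumed common hidden sizes
DTYPES = ("float16", "bfloat16")     # assumed dtype-string spelling
OUT_DTYPE = "float8_e4m3"            # assumed FP8 output dtype string


def precompile_fused_add_rmsnorm_kernels() -> None:
    """Populate the @functools.cache entries at AOT-build time."""
    for dtype, h, enable_pdl in product(DTYPES, COMMON_H, (False, True)):
        _get_compiled_fused_add_rmsnorm_kernel(dtype, h, 0.0, enable_pdl)
        _get_compiled_fused_add_rmsnorm_quant_kernel(dtype, OUT_DTYPE, h, 0.0, enable_pdl)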
        # Copy atom for gamma/beta (float32) - load to shared memory
        copy_atom_load_f32 = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            Float32,
            num_bits_per_copy=copy_bits_f32,
        )

        tiled_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn)
        tiled_copy_load_f32 = cute.make_tiled_copy(
            copy_atom_load_f32, tv_layout_f32, tiler_mn_f32
        )

        thr_copy_load = tiled_copy_load.get_slice(tidx)
        thr_copy_load_f32 = tiled_copy_load_f32.get_slice(tidx)

        # Partitions for input
        tXgX = thr_copy_load.partition_S(gX)
        tXgY = thr_copy_load.partition_D(gY)
        tXcX = thr_copy_load.partition_S(cX)

        # Partitions for gamma/beta (float32)
        tGgGamma = thr_copy_load_f32.partition_S(mGamma_2d)
        tGsGamma = thr_copy_load_f32.partition_D(sGamma_f32)
        tGgBeta = thr_copy_load_f32.partition_S(mBeta_2d)
        tGsBeta = thr_copy_load_f32.partition_D(sBeta_f32)
        tGcGamma = thr_copy_load_f32.partition_S(cGamma)

        # Partitions for gamma/beta (input dtype)
        thr_copy_load.partition_D(sGamma)
        thr_copy_load.partition_D(sBeta)

        # Register fragments - initialize to zero for proper handling of out-of-bounds threads
        tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type)
        tXrGamma = cute.make_rmem_tensor(tXgX.shape, mX.element_type)
        tXrBeta = cute.make_rmem_tensor(tXgX.shape, mX.element_type)
        tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type))
        tXrGamma.store(cute.zeros_like(tXrGamma, dtype=mX.element_type))
        tXrBeta.store(cute.zeros_like(tXrBeta, dtype=mX.element_type))

        tXpX = predicate_k(tXcX, limit=H)
        tGpGamma = predicate_k(tGcGamma, limit=H)

        # Phase 1: Load input from global to register
        cute.copy(copy_atom_load, tXgX, tXrX, pred=tXpX)

        # Phase 1b: Load gamma/beta global -> shared (float32)
        cute.copy(copy_atom_load_f32, tGgGamma, tGsGamma, pred=tGpGamma)
        cute.copy(copy_atom_load_f32, tGgBeta, tGsBeta, pred=tGpGamma)
🧩 Analysis chain

Verification scripts inspected flashinfer/norm/kernels/layernorm.py: the kernel body around the gamma/beta copies, _get_compiled_layernorm_kernel, layernorm_cute, LayerNormKernel.__init__ and __call__, and every use of gamma_dtype and Float32 in the file.
Enforce float32 gamma/beta or update the kernel to honor their dtype.
The kernel hardcodes Float32 for all gamma/beta operations (shared memory allocation at lines 184-193, copy atom at lines 243-247, and register fragments at lines 328-331). However, layernorm_cute() accepts any gamma dtype and passes gamma_dtype_str to _get_compiled_layernorm_kernel(), which uses it to create fake tensors for compilation. The issue is that LayerNormKernel.__init__() doesn't accept or use gamma_dtype, so the dtype information never reaches the kernel implementation. If callers pass non-float32 gamma/beta, the kernel will still treat them as float32, causing type mismatches and incorrect behavior.
Either enforce float32 inputs at the API boundary or refactor LayerNormKernel to accept and use gamma_dtype consistently throughout.
Also applies to: 364-388, 430-450
🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/layernorm.py` around lines 242 - 289, The kernel
currently hardcodes Float32 for gamma/beta (e.g., copy_atom_load_f32,
tiled_copy_load_f32, tGgGamma/tGsGamma/tGgBeta/tGsBeta and the register
fragments) while LayerNormKernel.__init__ does not accept gamma_dtype, so
gamma_dtype_str passed from layernorm_cute()/_get_compiled_layernorm_kernel()
never reaches the implementation; either enforce float32 at the API boundary or
(preferred) modify LayerNormKernel.__init__ to accept a gamma_dtype parameter
and thread it through to the places that currently use Float32: replace usages
of Float32 and mX.element_type in the shared-memory copies and register fragment
creation (copy_atom_load_f32, tiled_copy_load_f32,
tGgGamma/tGsGamma/tGgBeta/tGsBeta, and tXrGamma/tXrBeta) with the provided
gamma_dtype (and ensure predicate/partition shapes use corresponding layouts),
and update _get_compiled_layernorm_kernel() and layernorm_cute() to pass the
gamma_dtype into LayerNormKernel so the kernel honors non-float32 gamma/beta
types.
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    scale: float,
With reference to #2459, scale would need to be a tensor for this to be CUDA-graph compatible; a similar change would be required for rmsnorm_quant_cute as well.
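For context on why a device tensor is needed, a generic CUDA-graph sketch (torch.mul stands in for the quant kernel; the actual kernel signature change is what this thread is proposing): a Python float captured into a graph is frozen, while a 1-element tensor's contents can be updated between replays.

import torch

scale = torch.tensor([0.5], device="cuda")       # 1-element tensor instead of a Python float
static_in = torch.randn(8, 128, device="cuda")
static_out = torch.empty_like(static_in)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    torch.mul(static_in, scale, out=static_out)  # reads scale from device memory at replay

scale.fill_(0.25)                                # update in place, no re-capture needed
g.replay()                                       # this replay sees scale == 0.25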
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@flashinfer/norm/kernels/fused_add_rmsnorm.py`:
- Around line 529-530: The code in fused_add_rmsnorm_cute and
fused_add_rmsnorm_quant_cute computes M = input.shape[0], which breaks for 3D
tensors (B,S,H); change M to represent the number of rows over which to
normalize by computing the product of all dimensions except the last (e.g.,
mirror the rmsnorm_cute approach: derive H = input.shape[-1] and set M to
input.numel() // H or multiply input.shape[:-1]) so that for (B,S,H) you get M =
B*S; apply the same fix in both fused_add_rmsnorm_cute and
fused_add_rmsnorm_quant_cute to keep behavior consistent with rmsnorm_cute.
In `@flashinfer/norm/kernels/layernorm.py`:
- Around line 196-206: Remove the dead SMEM and register allocations: delete the
shared-memory tile allocations sGamma and sBeta (created via
smem.allocate_tensor) and the register fragments tXrGamma and tXrBeta (the
zero-initialized tensors) along with any no-op partitioning code that references
them; then update the kernel's _smem_size_in_bytes calculation to subtract the
removed SMEM footprint so reported shared-memory usage is correct. Locate
references by the symbols sGamma, sBeta, tXrGamma, tXrBeta and ensure
gamma_reg/beta_reg usage (the Float32 tensors loaded later) remains unchanged.
- Around line 152-166: The kernel signature declares an unused parameter M in
the function kernel, which is only required by the host launch (grid=[M,1,1]) in
__call__; remove M from the kernel parameters to avoid confusion (edit the
kernel definition to drop the M: Int32 argument and update any internal
references if present), or if the CuTe DSL requires a placeholder, rename it to
_M or _ to indicate it's intentionally unused; ensure the __call__ still
computes grid=[M,1,1] and that callers/launch site do not attempt to pass M as a
device/kernel argument.
In `@flashinfer/norm/kernels/rmsnorm.py`:
- Around line 862-870: The runtime call to the compiled kernel has the wrong
argument order; change the call in rmsnorm_quant_cute so it passes (input,
weight, out, M, scale, eps) to match the kernel signature produced by
_get_compiled_rmsnorm_quant_kernel and RMSNormQuantKernel.__call__(mX, mW, mY,
...); specifically, replace kernel(out, input, weight, M, scale, eps) with
kernel(input, weight, out, M, scale, eps) so types/shapes align (input: Float16
MxH, weight: Float16 H, out: Float8).
🧹 Nitpick comments (2)
flashinfer/norm/kernels/layernorm.py (2)
436-443: Minor: `_get_compiled_layernorm_kernel` is exported in `__all__` despite the underscore-prefix convention indicating a private API.

If this is intentionally public (e.g., used by the norm package's dispatch layer), consider dropping the underscore. Otherwise, remove it from `__all__`.

348-352: Naming nit: `copy_atom_load` is reused for the store path.

`copy_atom_load` (defined at line 236 for loading input) is reused here to copy results from registers to global memory. While `CopyUniversalOp` works bidirectionally, the name is misleading in this context. Consider renaming to `copy_atom` or introducing a separate `copy_atom_store` for clarity.
    H = input.shape[-1]
    M = input.shape[0]
M = input.shape[0] silently breaks for 3D inputs.
Both fused_add_rmsnorm_cute (line 530) and fused_add_rmsnorm_quant_cute (line 556) compute M = input.shape[0]. For a 3D tensor (B, S, H), M would be B instead of B * S, producing incorrect normalization. The sibling rmsnorm_cute (line 787–790) handles this case. Consider adding consistent 3D handling.
🛠️ Example fix for fused_add_rmsnorm_cute
H = input.shape[-1]
- M = input.shape[0]
+ if input.dim() == 3:
+ M = input.shape[0] * input.shape[1]
+ input = input.reshape(M, H)
+ residual = residual.reshape(M, H)
+ else:
+        M = input.shape[0]

Apply the same pattern for fused_add_rmsnorm_quant_cute.
Also applies to: 555-556
🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/fused_add_rmsnorm.py` around lines 529 - 530, The
code in fused_add_rmsnorm_cute and fused_add_rmsnorm_quant_cute computes M =
input.shape[0], which breaks for 3D tensors (B,S,H); change M to represent the
number of rows over which to normalize by computing the product of all
dimensions except the last (e.g., mirror the rmsnorm_cute approach: derive H =
input.shape[-1] and set M to input.numel() // H or multiply input.shape[:-1]) so
that for (B,S,H) you get M = B*S; apply the same fix in both
fused_add_rmsnorm_cute and fused_add_rmsnorm_quant_cute to keep behavior
consistent with rmsnorm_cute.
    @cute.kernel
    def kernel(
        self,
        mY: cute.Tensor,
        mX: cute.Tensor,
        mGamma: cute.Tensor,
        mBeta: cute.Tensor,
        M: Int32,
        eps: Float32,
        enable_pdl: cutlass.Constexpr[bool],
        tv_layout: cute.Layout,
        tiler_mn: cute.Shape,
        tv_layout_f32: cute.Layout,
        tiler_mn_f32: cute.Shape,
    ):
M parameter is unused inside the kernel body.
M is declared at line 159 but never referenced within the kernel. It's only needed in __call__ for the grid launch dimension (grid=[M, 1, 1]). If the CuTe DSL framework doesn't require it in the kernel signature, consider removing it to avoid confusion.
🧰 Tools
🪛 Ruff (0.14.14)
[warning] 159-159: Unused method argument: M
(ARG002)
🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/layernorm.py` around lines 152 - 166, The kernel
signature declares an unused parameter M in the function kernel, which is only
required by the host launch (grid=[M,1,1]) in __call__; remove M from the kernel
parameters to avoid confusion (edit the kernel definition to drop the M: Int32
argument and update any internal references if present), or if the CuTe DSL
requires a placeholder, rename it to _M or _ to indicate it's intentionally
unused; ensure the __call__ still computes grid=[M,1,1] and that callers/launch
site do not attempt to pass M as a device/kernel argument.
        # Shared memory tiles for gamma, beta in input dtype (for matching shape with x)
        sGamma = smem.allocate_tensor(
            mX.element_type,
            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
            byte_alignment=16,
        )
        sBeta = smem.allocate_tensor(
            mX.element_type,
            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
            byte_alignment=16,
        )
🛠️ Refactor suggestion | 🟠 Major
Remove dead code: sGamma/sBeta (SMEM tiles) and tXrGamma/tXrBeta (register fragments) are never read.
In addition to the previously flagged unused sGamma/sBeta SMEM allocations (lines 197–206) and their no-op partitioning (lines 270–271), the register tensors tXrGamma and tXrBeta (lines 275–279) are also dead code. They are allocated and zero-initialized but never read — the actual gamma/beta values are loaded into the separate gamma_reg/beta_reg Float32 tensors at lines 328–343.
All of these are vestiges of a prior approach and waste both SMEM and register resources.
🧹 Suggested cleanup
- # Shared memory tiles for gamma, beta in input dtype (for matching shape with x)
- sGamma = smem.allocate_tensor(
- mX.element_type,
- cute.make_ordered_layout(tiler_mn, order=(1, 0)),
- byte_alignment=16,
- )
- sBeta = smem.allocate_tensor(
- mX.element_type,
- cute.make_ordered_layout(tiler_mn, order=(1, 0)),
- byte_alignment=16,
- )
@@
- # Partitions for gamma/beta (input dtype)
- thr_copy_load.partition_D(sGamma)
- thr_copy_load.partition_D(sBeta)
@@
# Register fragments - initialize to zero for proper handling of out-of-bounds threads
tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type)
- tXrGamma = cute.make_rmem_tensor(tXgX.shape, mX.element_type)
- tXrBeta = cute.make_rmem_tensor(tXgX.shape, mX.element_type)
tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type))
- tXrGamma.store(cute.zeros_like(tXrGamma, dtype=mX.element_type))
- tXrBeta.store(cute.zeros_like(tXrBeta, dtype=mX.element_type))And update _smem_size_in_bytes accordingly:
def _smem_size_in_bytes(self) -> int:
- elem_bytes = self.dtype.width // 8
- return (
- self.cols_per_tile_f32 * 4 * 2
- + self.cols_per_tile * elem_bytes * 2
- + 2 * self.num_warps * 4
- )
+        return self.cols_per_tile_f32 * 4 * 2 + 2 * self.num_warps * 4

Also applies to: 270-271, 273-279
🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/layernorm.py` around lines 196 - 206, Remove the dead
SMEM and register allocations: delete the shared-memory tile allocations sGamma
and sBeta (created via smem.allocate_tensor) and the register fragments tXrGamma
and tXrBeta (the zero-initialized tensors) along with any no-op partitioning
code that references them; then update the kernel's _smem_size_in_bytes
calculation to subtract the removed SMEM footprint so reported shared-memory
usage is correct. Locate references by the symbols sGamma, sBeta, tXrGamma,
tXrBeta and ensure gamma_reg/beta_reg usage (the Float32 tensors loaded later)
remains unchanged.
    H = input.shape[-1]
    M = input.shape[0]

    dtype_str = _torch_dtype_to_str(input.dtype)
    out_dtype_str = _torch_dtype_to_str(out.dtype)
    kernel = _get_compiled_rmsnorm_quant_kernel(
        dtype_str, out_dtype_str, H, weight_bias, enable_pdl
    )
    kernel(out, input, weight, M, scale, eps)
Critical: Argument order mismatch in rmsnorm_quant_cute – will produce wrong results or crash.
The compiled kernel expects arguments in the order (input, weight, output, M, scale, eps) because _get_compiled_rmsnorm_quant_kernel compiles with (x_fake, w_fake, y_fake, ...) matching RMSNormQuantKernel.__call__(mX, mW, mY, ...). But the runtime call passes (out, input, weight, ...):
- Position 0: expects Float16 input → receives Float8 out
- Position 1: expects Float16 weight of shape (H,) → receives Float16 input of shape (M, H)
- Position 2: expects Float8 output → receives Float16 weight
Compare with fused_add_rmsnorm_quant_cute (in fused_add_rmsnorm.py) which correctly places output first in both __call__ and compile.
🐛 Fix: reorder runtime arguments to match compiled kernel signature
kernel = _get_compiled_rmsnorm_quant_kernel(
dtype_str, out_dtype_str, H, weight_bias, enable_pdl
)
- kernel(out, input, weight, M, scale, eps)
+    kernel(input, weight, out, M, scale, eps)

🤖 Prompt for AI Agents
In `@flashinfer/norm/kernels/rmsnorm.py` around lines 862 - 870, The runtime call
to the compiled kernel has the wrong argument order; change the call in
rmsnorm_quant_cute so it passes (input, weight, out, M, scale, eps) to match the
kernel signature produced by _get_compiled_rmsnorm_quant_kernel and
RMSNormQuantKernel.__call__(mX, mW, mY, ...); specifically, replace kernel(out,
input, weight, M, scale, eps) with kernel(input, weight, out, M, scale, eps) so
types/shapes align (input: Float16 MxH, weight: Float16 H, out: Float8).
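A cheap guard that would catch this class of mismatch at dispatch time (a sketch; the helper name and the exact dtype set are illustrative, not existing FlashInfer code):

import torch


def _check_rmsnorm_quant_args(input: torch.Tensor, weight: torch.Tensor, out: torch.Tensor) -> None:
    # The compiled kernel expects (input, weight, out, ...), so validate before dispatch
    # instead of relying on the FFI layer to fail with an opaque error.
    if input.dtype not in (torch.float16, torch.bfloat16):
        raise TypeError(f"input must be fp16/bf16, got {input.dtype}")
    if out.dtype != torch.float8_e4m3fn:
        raise TypeError(f"out must be float8_e4m3fn, got {out.dtype}")
    if weight.dim() != 1 or weight.shape[0] != input.shape[-1]:
        raise ValueError("weight must be 1-D with length equal to the hidden size")
    if out.shape != input.shape:
        raise ValueError("out must match input shape")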
📌 Description
We prioritize the DSL over CUDA for kernel development because of its faster JIT compilation speed.
This PR is the first in a series that refactors the simple normalization kernels to CuTe-DSL.
The CUDA code should be ready to remove after we finish end-to-end testing.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes
API