Add cutlass python dsl executor for quack-kernels#2719
Add cutlass python dsl executor for quack-kernels#2719
quack-kernels#2719Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR adds support for the CUTLASS DSL executor (cutlass_dsl_ex) to Thunder, integrating the quack library for optimized operations like softmax, cross_entropy, layer_norm, and RMS norm on NVIDIA SM9.0/10.0 GPUs.
- Introduces a new
cutlass_dsl_exexecutor with quack operation implementations - Adds comprehensive test coverage for quack operations
- Adds benchmark suites for performance comparison against nvfuser and torch_compile
- Registers the new executor in Thunder's executor registry
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| thunder/executors/cutlass_dsl_ex.py | New file implementing the cutlass_dsl executor with quack operations for softmax, cross_entropy, layer_norm, and RMS norm |
| thunder/extend/init.py | Registers cutlass_dsl_ex in the get_all_executors function |
| thunder/tests/test_extend.py | Updates test to include cutlass_dsl executor in the expected executors list |
| thunder/tests/test_cutlass_dsl_ex.py | New test file with comprehensive tests for quack operations |
| thunder/benchmarks/targets.py | Adds benchmark classes and test functions for quack operations |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if requires_reshpae := a.ndim > 2: | ||
| a = a.view(-1, original_shape[-1]) | ||
| ret = softmax_fwd(a) | ||
| if requires_reshpae: |
There was a problem hiding this comment.
Corrected spelling of 'requires_reshpae' to 'requires_reshape'.
| if requires_reshpae := a.ndim > 2: | |
| a = a.view(-1, original_shape[-1]) | |
| ret = softmax_fwd(a) | |
| if requires_reshpae: | |
| if requires_reshape := a.ndim > 2: | |
| a = a.view(-1, original_shape[-1]) | |
| ret = softmax_fwd(a) | |
| if requires_reshape: |
| if requires_reshpae := a.ndim > 2: | ||
| a = a.view(-1, original_shape[-1]) | ||
| ret = softmax_fwd(a) | ||
| if requires_reshpae: |
There was a problem hiding this comment.
Corrected spelling of 'requires_reshpae' to 'requires_reshape'.
| if requires_reshpae := a.ndim > 2: | |
| a = a.view(-1, original_shape[-1]) | |
| ret = softmax_fwd(a) | |
| if requires_reshpae: | |
| if requires_reshape := a.ndim > 2: | |
| a = a.view(-1, original_shape[-1]) | |
| ret = softmax_fwd(a) | |
| if requires_reshape: |
| a.ndim != 2 | ||
| or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} | ||
| and target.ndim == 1 | ||
| and target.dytpe in {dtypes.int32, dtypes.int64} |
There was a problem hiding this comment.
Corrected spelling of 'dytpe' to 'dtype'.
| and target.dytpe in {dtypes.int32, dtypes.int64} | |
| and target.dtype in {dtypes.int32, dtypes.int64} |
| def quack_softmax_backward_meta(g: TensorProxy, a: TensorProxy) -> TensorProxy: | ||
| return TensorProxy(like=g) | ||
|
|
||
| quack_softmax_backward = cutlass_dsl_ex.register_operator( |
There was a problem hiding this comment.
The global variable 'quack_softmax_backward' is not used.
| quack_softmax_backward = cutlass_dsl_ex.register_operator( | |
| cutlass_dsl_ex.register_operator( |
| return thunder.jit(fn, executors=[nvfuserex]) | ||
|
|
||
|
|
||
| class BaseBenchmarkForQuack(Benchmark, metaclass=UserFacingBenchmarkMeta): |
There was a problem hiding this comment.
This class does not call Benchmark.init during initialization. (BaseBenchmarkForQuack.init may be missing a call to a base class init)
| weight: TensorProxy | None = None, | ||
| bias: TensorProxy | None = None, | ||
| eps: Number = 1e-5, | ||
| ) -> bool: | ||
| if ( | ||
| a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32} | ||
| or weight.ndim != 1 | ||
| or a.shape[-1] != weight.shape[0] | ||
| or weight.dtype not in {dtypes.float32} |
There was a problem hiding this comment.
Can weight be None? In that case this would need to check before trying to access .ndim
There was a problem hiding this comment.
good catch. will check it
thunder/executors/cutlass_dsl_ex.py
Outdated
|
|
||
| quack_version: LooseVersion | ||
| try: | ||
| import quack |
There was a problem hiding this comment.
do we need to add this into requirements to install it?
There was a problem hiding this comment.
I'd not think we should do so. Because pip install quack-kernels seems to install cuda python packages such as nvidia-cutlass-dsl and I don't know how to having requirements.txt install cuda python packages that respect users local environments
|
|
||
| expected = F.cross_entropy(ref_x, targets, reduction="none") | ||
| actual = jitted(x, targets, reduction="none") | ||
| torch.testing.assert_close(expected, actual) |
There was a problem hiding this comment.
It seems the backward is not tested
There was a problem hiding this comment.
I've not managed to have backward work
Starting with quack's softmax Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
it seems that quack's cross-entropy function upcasts inputs to fp32, thus updating test and meta function Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
What does this PR do?
As per title, this adds cutlass python dsl executor.
In this PR, the kernels defined in https://github.com/Dao-AILab/quack, except matmul, are registered. Also, backward is not integrated.