Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR adds support for torch.nn.functional.scaled_grouped_mm, a PyTorch function for scaled grouped matrix multiplication.
Key changes:
- Implements the
scaled_grouped_mmtorchsymbol with input validation and shape inference - Adds three comprehensive test cases covering 2D×2D and 2D×3D tensor combinations with different scaling types
- Registers the operation in the torch executor with appropriate availability checking
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| thunder/torch/init.py | Adds scaled_grouped_mm function with shape validation, dtype checking, and output shape computation |
| thunder/tests/test_ops.py | Adds test cases for tensorwise and blockwise scaling scenarios with FP8 and MXFP8 dtypes |
| thunder/executors/torchex.py | Registers scaled_grouped_mm operation and implements checker function for executor |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull Request Overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (1)
thunder/tests/test_ops.py:1
- Both test_scaled_grouped_mm_3d2d_rowwise and test_scaled_grouped_mm_2d3d_rowwise test the same 2D @ 3D case (mat_a is 2D, mat_b after transpose is 3D). There is no test coverage for the 3D @ 2D case where mat_a would be 3D with shape (groups, m, k) and mat_b would be 2D with shape (k, n). Consider adding a test for this case or modifying one of the existing tests to cover it.
from collections.abc import Callable
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull Request Overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated no new comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
74c59d3 to
f428a90
Compare
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
626c7e6 to
269d056
Compare
What does this PR do?
As per title, adds https://docs.pytorch.org/docs/main/generated/torch.nn.functional.scaled_grouped_mm.html