thunder/core/prims.py (+123, -0)
@@ -270,6 +270,7 @@ class PrimIDs(Enum):
# Linear algebra prims (Mostly experimental)
MATMUL = auto()
_GROUPED_MM = auto() # Used for grouped matmuls
SCALED_GROUPED_MM = auto() # Used for scaled grouped matmuls
# NN prims (Experimental!)
CONVOLUTION = auto()
EMBEDDING = auto()
@@ -3792,6 +3793,128 @@ def _grouped_mm_meta(a: TensorProxy, b: TensorProxy, offsets: TensorProxy) -> TensorProxy:
)


def scaled_grouped_mm_meta(
a: TensorProxy,
b: TensorProxy,
scale_a: TensorProxy,
scale_b: TensorProxy,
offsets: None | TensorProxy = None,
bias: None | TensorProxy = None,
scale_result: None | TensorProxy = None,
out_dtype: None | dtypes.dtype = None,
) -> TensorProxy:
"""Meta function for scaled_grouped_mm primitive.

Similar to _grouped_mm but with scale tensors for quantization/dequantization.
Accepts the following shape combinations:
1. (m, k) x (k, n) -> (groups, m, n)
2. (groups, m, k) x (k, n) -> (m, n)
3. (m, k) x (groups, k, n) -> (m, n)

Args:
a: Input tensor of shape (groups, m, k) or (m, k)
b: Input tensor of shape (groups, k, n) or (k, n)
scale_a: Scale tensor for a
scale_b: Scale tensor for b
offsets: Optional offset tensor of shape (groups,)
bias: Optional bias tensor
scale_result: Optional scale tensor for result
out_dtype: Optional output dtype

Returns:
TensorProxy with shape (groups, m, n) or (m, n)
"""
# Validate types
utils.check_type(a, TensorProxy)
utils.check_type(b, TensorProxy)
utils.check_type(scale_a, TensorProxy)
utils.check_type(scale_b, TensorProxy)

# Accept 2D or 3D tensors
utils.check(a.ndim in (2, 3), lambda: f"Expected a to have 2 or 3 dimensions, got {a.ndim}")
utils.check(b.ndim in (2, 3), lambda: f"Expected b to have 2 or 3 dimensions, got {b.ndim}")

# Compute output shape using same logic as _grouped_mm
if offsets is not None:
utils.check_type(offsets, TensorProxy)
utils.check(offsets.ndim == 1, lambda: f"`offsets` must be a vector, got shape {offsets.shape}")

if a.ndim == 2 and b.ndim == 2:
utils.check(a.shape[1] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
out_shape = (offsets.shape[0], a.shape[0], b.shape[1])
elif a.ndim == 3 and b.ndim == 2:
utils.check(a.shape[2] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
utils.check(a.shape[0] == offsets.shape[0], lambda: f"Group count mismatch: {a.shape} vs {offsets.shape}")
out_shape = (a.shape[1], b.shape[1])
elif a.ndim == 2 and b.ndim == 3:
utils.check(a.shape[1] == b.shape[1], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
utils.check(b.shape[0] == offsets.shape[0], lambda: f"Group count mismatch: {b.shape} vs {offsets.shape}")
out_shape = (a.shape[0], b.shape[2])
else:
utils.check(False, lambda: f"Unexpected shape combination: {a.shape} and {b.shape}")
else:
# Without offsets, fall back to standard matmul shape logic
if a.ndim == 2 and b.ndim == 2:
utils.check(a.shape[1] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
out_shape = (a.shape[0], b.shape[1])
elif a.ndim == 3 and b.ndim == 2:
utils.check(a.shape[2] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
out_shape = (a.shape[0], a.shape[1], b.shape[1])
elif a.ndim == 2 and b.ndim == 3:
utils.check(a.shape[1] == b.shape[1], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
out_shape = (b.shape[0], a.shape[0], b.shape[2])
else:
utils.check(False, lambda: f"Unexpected shape combination: {a.shape} and {b.shape}")

# Validate scale tensors
# Scale tensors may be scalars (tensor-wise scaling), 1D (per-group or
# per-row scaling), or 2D (row-wise scaling per group), matching the cases
# the torch executor handles
utils.check(
scale_a.ndim <= 2,
lambda: f"Expected scale_a to be a scalar, 1D, or 2D tensor, got shape {scale_a.shape}",
)
utils.check(
scale_b.ndim <= 2,
lambda: f"Expected scale_b to be a scalar, 1D, or 2D tensor, got shape {scale_b.shape}",
)

# Validate bias if provided
if bias is not None:
utils.check_type(bias, TensorProxy)
utils.check_same_device(a, bias)
utils.check_same_dtype(a, bias)

# Validate scale_result if provided
if scale_result is not None:
utils.check_type(scale_result, TensorProxy)
utils.check(
scale_result.ndim <= 1,
lambda: f"Expected scale_result to be a scalar or 1D tensor, got shape {scale_result.shape}",
)

utils.check_same_dtype(a, b)
utils.check(a.dtype in dtypes.float_math_dtypes, lambda: f"`a` must be a floating-point math dtype, got {a.dtype}")
if offsets is not None:
utils.check(utils.is_integer_dtype(offsets.dtype), lambda: f"`offsets` must be integers, got {offsets.dtype}")

utils.check_same_device(a, b, scale_a, scale_b)
if offsets is not None:
utils.check_same_device(a, offsets)

# Determine output dtype
result_dtype = out_dtype if out_dtype is not None else a.dtype

return TensorProxy(like=a, shape=out_shape, dtype=result_dtype)


scaled_grouped_mm = make_prim(
PrimIDs.SCALED_GROUPED_MM,
"scaled_grouped_mm",
meta=scaled_grouped_mm_meta,
tags=(OpTags.MATMUL_OP,),
)
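
# For intuition, the offset-based cases above are concatenations of per-group
# matmuls. Below is a minimal reference sketch for case 3,
# (m, k) x (groups, k, n) -> (m, n), ignoring the scale arguments and assuming
# `offsets` holds cumulative row boundaries along the jagged dimension, as in
# torch._grouped_mm. The loop is illustrative only, not part of the primitive.

import torch

def _reference_grouped_mm_2d_3d(a, b, offsets):
    # Each group multiplies a slab of a's rows with its own b matrix
    out, start = [], 0
    for g, end in enumerate(offsets.tolist()):
        out.append(a[start:end] @ b[g])  # (rows_g, k) @ (k, n) -> (rows_g, n)
        start = end
    return torch.cat(out)  # (m, n), matching out_shape = (a.shape[0], b.shape[2])

a = torch.randn(6, 4)           # m=6, k=4
b = torch.randn(2, 4, 5)        # groups=2, k=4, n=5
offsets = torch.tensor([2, 6])  # rows 0:2 -> group 0, rows 2:6 -> group 1
assert _reference_grouped_mm_2d_3d(a, b, offsets).shape == (6, 5)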


def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorProxy:
utils.check_type(a, TensorProxy)
utils.check_type(permutation, tuple)
thunder/executors/torchex.py (+142, -0)
@@ -1552,6 +1552,119 @@ def _copy_with_setitem_impl(a, key, value):
"torch.ops.aten._adaptive_avg_pool2d_backward", like=ltorch.adaptive_avg_pool2d_backward
)
multi_dot = _register_torch_operation("torch.linalg.multi_dot", like=ltorch.multi_dot)
if hasattr(torch.nn.functional, "scaled_grouped_mm"):
# PyTorch 2.10+ introduced scaled_grouped_mm with ScalingType/SwizzleType enums
if hasattr(torch.nn.functional, "ScalingType"):
# PyTorch 2.10+: scaled_grouped_mm is a new API with specific requirements
def _scaled_grouped_mm_impl(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
offsets: None | torch.Tensor = None,
bias: None | torch.Tensor = None,
scale_result: None | torch.Tensor = None,
out_dtype: None | torch.dtype = None,
) -> torch.Tensor:
"""Wrapper for PyTorch 2.10+ scaled_grouped_mm API.

PyTorch 2.10 introduced scaled_grouped_mm with requirements for:
- mat_b to have transposed memory layout (create as [G, N, K] then .transpose(-2, -1))
- Specific scale formats based on scale tensor shapes (infer RowWise vs TensorWise)
- ScalingType and SwizzleType enums for quantization control
"""
num_groups = offsets.numel() if offsets is not None else 1

# Give b the memory layout PyTorch expects: the logical shape stays [G, K, N],
# but the K dimension becomes contiguous, as if the tensor had been created
# as [G, N, K] and then .transpose(-2, -1)'d
b_transposed = b.transpose(-2, -1).contiguous().transpose(-2, -1) if b.ndim >= 2 else b

# Infer scaling type and format scales appropriately
# For the 2D x 3D case: a is [M, K] and b is logically [G, K, N] (column-major in memory)
# PyTorch expects:
# - If scale is scalar (0D) or has 1 element: TensorWise
# - If scale is 1D with length matching rows: RowWise (needs reshaping)
# - If scale is 2D: RowWise

# Handle scale_a
if scale_a.numel() == 1:
# Scalar scale - TensorWise
scale_a_list = [scale_a.view(1)] * num_groups
scale_recipe_a = [torch.nn.functional.ScalingType.TensorWise] * num_groups
elif scale_a.dim() == 1:
# 1D scale: either one scale per group (TensorWise) or one scale per row (RowWise)
# For RowWise in the 2D x 3D case, scale_a has shape (num_groups * M,) and is reshaped to (num_groups, M, 1)
if a.dim() == 2:
# a is [M, K], scale_a should be expandable to rowwise format
scale_a_2d = scale_a.view(num_groups, -1, 1) # [G, M, 1]
scale_a_list = [scale_a_2d[i] for i in range(num_groups)]
scale_recipe_a = [torch.nn.functional.ScalingType.RowWise] * num_groups
else:
# Fallback to splitting
if scale_a.size(0) == num_groups:
scale_a_list = [scale_a[i : i + 1] for i in range(num_groups)]
scale_recipe_a = [torch.nn.functional.ScalingType.TensorWise] * num_groups
else:
# Try RowWise
scale_a_2d = scale_a.view(num_groups, -1, 1)
scale_a_list = [scale_a_2d[i] for i in range(num_groups)]
scale_recipe_a = [torch.nn.functional.ScalingType.RowWise] * num_groups
else:
# 2D scale - RowWise
scale_a_list = [
scale_a[i].unsqueeze(-1) if scale_a[i].dim() == 1 else scale_a[i] for i in range(scale_a.size(0))
]
scale_recipe_a = [torch.nn.functional.ScalingType.RowWise] * num_groups

# Handle scale_b
if scale_b.numel() == 1:
# Scalar scale - TensorWise
scale_b_list = [scale_b.view(1)] * num_groups
scale_recipe_b = [torch.nn.functional.ScalingType.TensorWise] * num_groups
elif scale_b.dim() == 1:
# 1D scale
if scale_b.size(0) == num_groups:
scale_b_list = [scale_b[i : i + 1] for i in range(num_groups)]
scale_recipe_b = [torch.nn.functional.ScalingType.TensorWise] * num_groups
else:
# RowWise: reshape to (num_groups, N, 1) to match b's column-major layout
scale_b_2d = scale_b.view(num_groups, -1, 1)
scale_b_list = [scale_b_2d[i] for i in range(num_groups)]
scale_recipe_b = [torch.nn.functional.ScalingType.RowWise] * num_groups
else:
# 2D scale - RowWise
# b is logically [G, K, N] (column-major in memory); for RowWise,
# each per-group scale is shaped [N, 1]
scale_b_list = [
scale_b[i].unsqueeze(-1) if scale_b[i].dim() == 1 else scale_b[i] for i in range(scale_b.size(0))
]
scale_recipe_b = [torch.nn.functional.ScalingType.RowWise] * num_groups

# Create swizzle parameters (no swizzle)
swizzle_a = [torch.nn.functional.SwizzleType.NO_SWIZZLE] * num_groups
swizzle_b = [torch.nn.functional.SwizzleType.NO_SWIZZLE] * num_groups

return torch.nn.functional.scaled_grouped_mm(
a,
b_transposed,
scale_a_list,
scale_recipe_a,
scale_b_list,
scale_recipe_b,
swizzle_a=swizzle_a,
swizzle_b=swizzle_b,
bias=bias,
offs=offsets,
output_dtype=out_dtype,
)
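
# For intuition on the recipes selected above, here is a hedged sketch of the
# dequantization they correspond to, using fp32 stand-ins for the fp8 operands
# so it runs anywhere:
#
#     import torch
#
#     a = torch.randn(4, 3)
#     b = torch.randn(3, 5)
#     scale_a = torch.tensor(0.5)  # TensorWise: one scale for all of a
#     scale_b = torch.rand(1, 5)   # RowWise on b: one scale per output column
#     # A scaled matmul is equivalent to dequantizing the operands first:
#     out = (a * scale_a) @ (b * scale_b)
#     assert torch.allclose(out, (a @ (b * scale_b)) * scale_a)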

scaled_grouped_mm = ex.register_operator(
"scaled_grouped_mm", like=ltorch.scaled_grouped_mm, fn=_scaled_grouped_mm_impl
)
else:
# scaled_grouped_mm exists but the ScalingType/SwizzleType enums do not
# (pre-2.10 prototype API): register the torch function directly
scaled_grouped_mm = _register_torch_operation("scaled_grouped_mm", module=torch.nn.functional)
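
# As a note on the layout requirement in the wrapper's docstring: a row-major
# (G, K, N) tensor gets the expected column-major layout, without changing its
# logical shape, via the transpose/contiguous round trip used above. A
# standalone sketch:

import torch

b = torch.randn(2, 4, 5)  # logical (G, K, N), row-major
b_col = b.transpose(-2, -1).contiguous().transpose(-2, -1)
assert b_col.shape == b.shape  # logical shape unchanged: (G, K, N)
assert b_col.stride(-2) == 1   # the K dimension is now contiguous in memory
assert torch.equal(b, b_col)   # same values, different memory layout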


def _max_pool_with_indices_helper(
@@ -1823,6 +1936,32 @@ def _grouped_mm_checker(a: TensorProxy, b: TensorProxy, offsets: TensorProxy) -> bool:
return a.dtype == dtypes.bfloat16 and b.dtype == dtypes.bfloat16 and offsets.dtype == dtypes.int32


def _scaled_grouped_mm_checker(
a: TensorProxy,
b: TensorProxy,
scale_a: TensorProxy,
scale_b: TensorProxy,
offsets: None | TensorProxy = None,
bias: None | TensorProxy = None,
scale_result: None | TensorProxy = None,
out_dtype: None | dtypes.dtype = None,
) -> bool:
if not hasattr(torch.nn.functional, "scaled_grouped_mm"):
return False

if not torch.cuda.is_available():
return False

capability = torch.cuda.get_device_capability()
if capability < (9, 0):
return False

# Packed fp4 inputs are not supported by this executor path; compare torch
# dtypes, since a.dtype and b.dtype are thunder dtypes
if torch.float4_e2m1fn_x2 in (to_torch_dtype(a.dtype), to_torch_dtype(b.dtype)):
return False

return True
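
# The capability gate above relies on Python's lexicographic tuple comparison,
# so Hopper-class (sm90) and newer devices pass:
#
#     # get_device_capability() returns (major, minor), compared lexicographically
#     assert (8, 0) < (9, 0)       # A100 (sm80): rejected
#     assert not (9, 0) < (9, 0)   # H100 (sm90): accepted
#     assert not (10, 0) < (9, 0)  # B200 (sm100): accepted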


_register_implementation(ltorch.baddbmm, baddbmm, checker=_always_executable)
_register_implementation(ltorch.bmm, bmm, checker=_always_executable)
if LooseVersion(torch.__version__) >= "2.8":
@@ -1846,6 +1985,9 @@ def _grouped_mm_checker(a: TensorProxy, b: TensorProxy, offsets: TensorProxy) -> bool:
ltorch.log_softmax_backward, checker=_always_executable, execution_transform=_log_softmax_backward_transform
)
_register_implementation(ltorch.max_pool1d, max_pool1d, checker=_always_executable)
if hasattr(torch.nn.functional, "scaled_grouped_mm"):
_register_implementation(prims.scaled_grouped_mm, scaled_grouped_mm, checker=_scaled_grouped_mm_checker)
_register_implementation(ltorch.scaled_grouped_mm, scaled_grouped_mm, checker=_scaled_grouped_mm_checker)


def max_pool2d_bwd_wrapper(