Open
Labels: contribution welcome, good first issue, module: torchlib
Description
import torch


class Model(torch.nn.Module):
    def forward(self, a, b):
        return torch.nn.functional.grouped_mm(
            a.to(torch.bfloat16), b.transpose(-2, -1).to(torch.bfloat16)
        ).to(torch.float32)


G, M, N, K = 4, 16, 32, 64
a = torch.randn(G, M, K, device="cpu", dtype=torch.float32)
b = torch.randn(G, N, K, device="cpu", dtype=torch.float32)
model = Model()
expected = model(a, b)

# torch.export captures the call as aten._grouped_mm.default
ep = torch.export.export(model, (a, b))
names = [str(n.target) for n in ep.graph.nodes]
assert "aten._grouped_mm.default" in names

# The ONNX export then fails on that node
epo = torch.onnx.export(
    model, (a, b), dynamic_shapes=({0: "G", 1: "M", 2: "K"}, {0: "G", 1: "N", 2: "K"})
)

<class 'torch.onnx._internal.exporter._errors.DispatchError'>: No ONNX function found for <OpOverload(op='aten._grouped_mm', overload='default')>. Failure message: No decompositions registered for the real-valued input
This issue was encountered while exporting microsoft/OptiMind-SFT.
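A possible starting point for a torchlib implementation, assuming the PyTorch semantics where 3D × 3D inputs without an `offs` tensor (the case in this repro) reduce to a plain batched matrix multiply: one (M, K) @ (K, N) product per group, which would map onto a single batched ONNX MatMul node. The sketch below only models those numerics in pure Python so they can serve as a test oracle; the helpers `matmul`, `transpose`, and `grouped_mm_3d` are hypothetical, it is not torchlib code, and the 2D cases that use `offs` are not covered.

```python
# Hypothetical reference sketch: assumed semantics of aten._grouped_mm for
# 3D x 3D inputs with no `offs` tensor, using plain Python lists.
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Plain (M, K) @ (K, N) matrix product."""
    k = len(y)
    if any(len(row) != k for row in x):
        raise ValueError("inner dimensions do not match")
    n = len(y[0])
    return [[sum(row[i] * y[i][j] for i in range(k)) for j in range(n)] for row in x]


def transpose(x: Matrix) -> Matrix:
    """Swap the last two dimensions of a single matrix."""
    return [list(col) for col in zip(*x)]


def grouped_mm_3d(a: List[Matrix], b: List[Matrix]) -> List[Matrix]:
    """a: (G, M, K), b: (G, K, N) -> (G, M, N), one matmul per group.

    Assumed equivalent to a batched ONNX MatMul for this case.
    """
    if len(a) != len(b):
        raise ValueError("group counts differ")
    return [matmul(a_g, b_g) for a_g, b_g in zip(a, b)]


# Mirror the repro: b is stored as (G, N, K) and transposed per group.
a = [[[1.0, 2.0], [3.0, 4.0]]]                  # G=1, M=2, K=2
b_nk = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]   # G=1, N=3, K=2
print(grouped_mm_3d(a, [transpose(b_g) for b_g in b_nk]))
```

Because each group is independent, the per-group loop is what a single broadcasting MatMul over the leading G dimension would compute, which is why no decomposition beyond MatMul (plus the surrounding casts) should be needed for this input pattern.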