Description
Describe the issue
I have a simple einsum equation 'ijk->ki' with input shape [2, 3, 4] and an expected output shape of [4, 2]. However, ONNX Runtime fails to compute it with the error: `onnxruntime::common::Status onnxruntime::EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy(const onnxruntime::Tensor&, onnxruntime::Tensor&, void*) output.SizeInBytes() == input.SizeInBytes() was false. Einsum op: The candidate output does not match the actual output's shape`.
Looking through the code, I believe
EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_output,
                                               const gsl::span<const int64_t>& ordered_subscript_indices_in_candidate)
implicitly requires that candidate_output have the same rank as ordered_subscript_indices_in_candidate. In this case, however, candidate_output still has rank 3 with shape {2, 1, 4}, while ordered_subscript_indices_in_candidate has length 2 with contents {0, 2}. This generates candidate_output_shape_without_reduced_dims as [2, 1], which is incorrect; it should be [2, 4].
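The shape bookkeeping described above can be illustrated with a small plain-Python sketch (not the actual ORT code): selecting the kept dimensions by their subscript indices yields the correct [2, 4], while naively taking the leading dimensions yields the incorrect [2, 1].

```python
# Candidate output as described above: shape {2, 1, 4}, where the middle
# axis (subscript 'j') has already been reduced to size 1.
candidate_shape = [2, 1, 4]

# Indices of the subscripts that survive in the output ('i' -> 0, 'k' -> 2).
ordered_subscript_indices_in_candidate = [0, 2]

# Correct: pick the kept axes by their index in the candidate shape.
kept = [candidate_shape[i] for i in ordered_subscript_indices_in_candidate]
print(kept)  # [2, 4]

# Incorrect (the behavior described above): taking the first
# len(indices) axes instead of indexing produces [2, 1].
wrong = candidate_shape[: len(ordered_subscript_indices_in_candidate)]
print(wrong)  # [2, 1]
```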
To reproduce
(py11) sjt@sjt:~/piano1$ cat test.py
```python
from onnx import helper, checker
from onnx import TensorProto
import numpy as np
import onnxruntime as rt
import onnx

def generate_onnx(node_name, op_set):
    Eqn = "ijk->ki"
    test_node = helper.make_node(
        node_name,
        inputs=["X"],
        outputs=["Y"],
        equation=Eqn,
    )
    inputs = [
        helper.make_tensor_value_info("X", TensorProto.DOUBLE, [2, 3, 4]),
    ]
    outputs = [helper.make_tensor_value_info("Y", TensorProto.DOUBLE, [4, 2])]
    graph = helper.make_graph(
        [test_node],
        "test",
        inputs,
        outputs,
    )
    model = helper.make_model(graph)
    model.opset_import[0].version = op_set
    checker.check_model(model)
    onnx.save(model, node_name + ".onnx")
    print(f"model is generated: {node_name}.onnx")

generate_onnx('Einsum', op_set=12)
input = np.random.randn(2, 3, 4)
sess = rt.InferenceSession("./Einsum.onnx")
output = sess.run(None, {'X': input})
```
(py11) sjt@sjt:~/piano1$ python test.py
model is generated: Einsum.onnx
2023-11-30 16:28:27.019210785 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Einsum node. Name:'' Status Message: /home/conda/feedstock_root/build_artifacts/onnxruntime_1697223480578/work/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc:18 onnxruntime::common::Status onnxruntime::EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy(const onnxruntime::Tensor&, onnxruntime::Tensor&, void*) output.SizeInBytes() == input.SizeInBytes() was false. Einsum op: The candidate output does not match the actual output's shape
Traceback (most recent call last):
File "/home/sjt/piano1/test.py", line 41, in <module>
output = sess.run(None, {'X': input})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sjt/anaconda3/envs/py11/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
return self._sess.run(output_names, input_feed, run_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Einsum node. Name:'' Status Message: /home/conda/feedstock_root/build_artifacts/onnxruntime_1697223480578/work/onnxruntime/core/providers/cpu/math/einsum_utils/einsum_auxiliary_ops.cc:18 onnxruntime::common::Status onnxruntime::EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy(const onnxruntime::Tensor&, onnxruntime::Tensor&, void*) output.SizeInBytes() == input.SizeInBytes() was false. Einsum op: The candidate output does not match the actual output's shape
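For reference, NumPy evaluates the same equation without issue and produces the expected [4, 2] output (a quick sanity check, independent of ONNX Runtime):

```python
import numpy as np

x = np.random.randn(2, 3, 4)

# 'ijk->ki': sum over the unlisted subscript 'j', then output axes (k, i).
y = np.einsum('ijk->ki', x)
print(y.shape)  # (4, 2)

# Equivalent formulation: reduce over axis 1 (the 'j' axis), then transpose.
y_ref = x.sum(axis=1).T
print(np.allclose(y, y_ref))  # True
```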
Urgency
No response
Platform
Linux
OS Version
Ubuntu 20.04.6 LTS
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.16.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response