Skip to content

Commit 040ae1f

Browse files
hariharans29, github-actions[bot], and Copilot
authored
[Kernel] Fix bug in Einsum implementation when a lone operand had a reduction operation (#27225)
### Description When the lone operand had a reduction operation in Einsum, there was mismanagement of the shape, which led to a mismatch between the subscript indices in the output (as deciphered from the Einsum equation) and the rank of the candidate output produced after the reduction operation. This caused issues while finalizing the output of the Einsum op. This change fixes the bug and adds a test. ### Motivation and Context Resolves #18654 --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b75bfe1 commit 040ae1f

File tree

4 files changed

+44
-28
lines changed

4 files changed

+44
-28
lines changed

onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ const std::vector<int64_t>& EinsumComputePreprocessor::GetMappedSubscriptIndices
5656
return subscript_indices_to_output_indices_;
5757
}
5858

59-
int64_t EinsumComputePreprocessor::GetNumSubscriptIndices() const {
59+
size_t EinsumComputePreprocessor::GetNumSubscriptIndices() const {
6060
return num_subscript_indices_;
6161
}
6262

@@ -73,7 +73,7 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
7373
"Number of subscripts in the input equation does not match number of input tensors");
7474
}
7575

76-
int64_t input_index = 0;
76+
size_t input_index = 0;
7777

7878
// Holds mapping between input indices to its corresponding subscript labels for each input
7979
input_subscript_indices_.reserve(inputs_.size());
@@ -84,7 +84,7 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
8484
subscript_indices_to_dim_value_.reserve(10);
8585

8686
for (const auto& subscript : left_equation_split) {
87-
const auto& shape = inputs_[onnxruntime::narrow<size_t>(input_index)]->Shape();
87+
const auto& shape = inputs_[input_index]->Shape();
8888
const auto& dims = shape.GetDims();
8989
size_t rank = dims.size();
9090
size_t dim_counter = 0;
@@ -237,13 +237,13 @@ Status EinsumComputePreprocessor::PostProcessBroadcastedDims() {
237237
}
238238
}
239239

240-
std::vector<int64_t> temp_index_to_last_input(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
240+
std::vector<int64_t> temp_index_to_last_input(num_subscript_indices_, -1);
241241
for (size_t i = 0; i < subscript_indices_to_last_input_.size(); ++i) {
242242
temp_index_to_last_input[i + num_of_ellipsis_dims_] = subscript_indices_to_last_input_[i];
243243
}
244244
subscript_indices_to_last_input_ = std::move(temp_index_to_last_input);
245245

246-
std::vector<int64_t> temp_index_to_dim_value(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
246+
std::vector<int64_t> temp_index_to_dim_value(num_subscript_indices_, -1);
247247
for (size_t i = 0; i < subscript_indices_to_dim_value_.size(); ++i) {
248248
temp_index_to_dim_value[i + num_of_ellipsis_dims_] = subscript_indices_to_dim_value_[i];
249249
}
@@ -338,7 +338,7 @@ Status EinsumComputePreprocessor::CalculateOutputShape() {
338338
bool is_in_middle_of_ellipsis = false;
339339
int64_t ellipsis_char_count = 0;
340340

341-
subscript_indices_to_output_indices_.resize(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
341+
subscript_indices_to_output_indices_.resize(num_subscript_indices_, -1);
342342

343343
std::array<int64_t, EinsumOp::num_of_letters> output_letter_to_count;
344344
output_letter_to_count.fill(0);
@@ -407,24 +407,24 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
407407
// As part of input preprocessing we "homogenize" them by
408408
// 1) Making them all of the same rank
409409
// 2) The axes order in all the inputs are to be made the same
410-
int64_t input_iter = 0;
410+
size_t input_iter = 0;
411411
for (const auto* input : inputs_) {
412412
// Eventually will hold the "preprocessed" version of the original input
413413
std::unique_ptr<Tensor> preprocessed;
414414

415415
const auto& input_dims = input->Shape().GetDims();
416-
const auto& current_subscript_indices = input_subscript_indices_[onnxruntime::narrow<size_t>(input_iter)];
416+
const auto& current_subscript_indices = input_subscript_indices_[input_iter];
417417

418418
// If all has gone well, we will have a subscript index (subscript label) for each dim of the input
419419
if (input_dims.size() != current_subscript_indices.size()) {
420420
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
421421
"Rank of the input must match number of subscript labels corresponding to the input");
422422
}
423423

424-
std::vector<int64_t> subscript_indices_to_input_index(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
424+
std::vector<int64_t> subscript_indices_to_input_index(num_subscript_indices_, -1);
425425

426426
// This is the input dims after re-ordering so that all inputs have same axes order
427-
TensorShapeVector homogenized_input_dims(onnxruntime::narrow<size_t>(num_subscript_indices_), 1);
427+
TensorShapeVector homogenized_input_dims(num_subscript_indices_, 1);
428428

429429
// Preprocessed dim rank may not be the same as original input rank if we need to parse diagonals along the way
430430
// (which reduces rank in the preprocessed input by 1 for each diagonal we parse)
@@ -437,7 +437,7 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
437437
subscript_indices_to_input_index[onnxruntime::narrow<size_t>(subscript_index)] = dim_index_in_preprocessed_input++;
438438
homogenized_input_dims[onnxruntime::narrow<size_t>(subscript_index)] = input_dims[onnxruntime::narrow<size_t>(dim_index_in_original_input)];
439439
} else { // Diagonal needs to be parsed along the repeated axes
440-
preprocessed = device_diagonal_func_(preprocessed ? *preprocessed : *inputs_[onnxruntime::narrow<size_t>(input_iter)],
440+
preprocessed = device_diagonal_func_(preprocessed ? *preprocessed : *inputs_[input_iter],
441441
subscript_indices_to_input_index[onnxruntime::narrow<size_t>(subscript_index)],
442442
dim_index_in_preprocessed_input,
443443
allocator_, einsum_ep_assets_);
@@ -454,10 +454,10 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
454454
}
455455

456456
// (Identify no-op transpose and prevent triggering the transpose)
457-
if (EinsumOp::IsTransposeRequired(preprocessed ? preprocessed->Shape().GetDims().size() : inputs_[onnxruntime::narrow<size_t>(input_iter)]->Shape().GetDims().size(),
457+
if (EinsumOp::IsTransposeRequired(preprocessed ? preprocessed->Shape().GetDims().size() : inputs_[input_iter]->Shape().GetDims().size(),
458458
permutation)) {
459-
preprocessed = EinsumOp::Transpose(preprocessed ? *preprocessed : *inputs_[onnxruntime::narrow<size_t>(input_iter)],
460-
preprocessed ? preprocessed->Shape().GetDims() : inputs_[onnxruntime::narrow<size_t>(input_iter)]->Shape().GetDims(),
459+
preprocessed = EinsumOp::Transpose(preprocessed ? *preprocessed : *inputs_[input_iter],
460+
preprocessed ? preprocessed->Shape().GetDims() : inputs_[input_iter]->Shape().GetDims(),
461461
permutation, allocator_, einsum_ep_assets_, device_transpose_func_);
462462
}
463463

onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class EinsumComputePreprocessor final {
140140
const std::vector<int64_t>& GetMappedSubscriptIndicesToOutputindices() const;
141141

142142
// Get the number of subscript indices (subscript labels) in the einsum equation
143-
int64_t GetNumSubscriptIndices() const;
143+
size_t GetNumSubscriptIndices() const;
144144

145145
// Pass-in device specific functions
146146
// (Pass-in CPU implementation or CUDA implementation function depending on the kernel using this class)
@@ -185,7 +185,7 @@ class EinsumComputePreprocessor final {
185185
// num_subscript_indices_ = 3 (i, j, k)
186186
// E.g. 2 : With equation -> '...ij', 'jk' -> '...ik'
187187
// num_subscript_indices_ = 3 (i, j, k) + number of dims specified by an ellipsis (across all inputs)
188-
int64_t num_subscript_indices_ = 0;
188+
size_t num_subscript_indices_ = 0;
189189

190190
// Hold the count corresponding to the letter seen
191191
// `0` means the corresponding letter wasn't seen at all

onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ namespace onnxruntime {
1010
template <typename T>
1111
void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_output,
1212
const gsl::span<const int64_t>& ordered_subscript_indices_in_candidate) {
13+
ORT_ENFORCE(candidate_output.Shape().NumDimensions() == ordered_subscript_indices_in_candidate.size(),
14+
"Einsum op: The candidate output's rank has to be the same number of elements as "
15+
"the ordered subscript indices in the candidate output. Hitting this error points to an "
16+
"internal bug in the Einsum op's implementation. "
17+
"Please open a bug report with appropriate repro steps");
18+
1319
const std::vector<int64_t>& subscript_indices_to_output_indices =
1420
einsum_compute_preprocessor_.GetMappedSubscriptIndicesToOutputindices();
1521
const auto output_dims = einsum_compute_preprocessor_.GetOutputDims();
@@ -75,7 +81,7 @@ void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_outp
7581
static bool IsTransposeReshapeForEinsum(const gsl::span<const size_t>& perm,
7682
gsl::span<const int64_t> input_dims,
7783
TensorShapeVector& new_shape) {
78-
// As long as the dims with values > 1 stay in the same order, it's a reshape.
84+
// As long as the dims with values > 1 stay in the same relative order, it's a reshape.
7985
// Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
8086
size_t last_permuted_axis = 0;
8187
for (size_t i = 0; i < perm.size(); ++i) {
@@ -377,17 +383,19 @@ Status EinsumTypedComputeProcessor<T>::Run() {
377383
std::unique_ptr<const Tensor> result;
378384

379385
{
380-
TensorShapeVector reduced_dims;
381-
TensorShapeVector preserved_dims; // dims which were not reduced
382-
reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving.
383-
preserved_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving.
386+
TensorShapeVector reduced_dims; // All dims of the input that are reduced using the `ReduceSum` op
387+
reduced_dims.reserve(num_subscript_labels); // num_subscript_labels is the upper bound. No harm in over-reserving
388+
389+
TensorShapeVector all_dims; // All dimension indices from 0 to num_subscript_labels - 1
390+
all_dims.reserve(num_subscript_labels); // num_subscript_labels is the number of elements
384391

385-
for (size_t i = 0; i < onnxruntime::narrow<size_t>(num_subscript_labels); ++i) {
392+
for (size_t i = 0; i < num_subscript_labels; ++i) {
386393
if (mapped_indices_to_last_input_index[i] == 0) {
387394
reduced_dims.push_back(i);
388-
} else {
389-
preserved_dims.push_back(i);
390395
}
396+
397+
// ReduceSum operation preserves even the reduced dims with reduced dim shape value being 1
398+
all_dims.push_back(i);
391399
}
392400

393401
// Reduce the dims that are last seen in the first input alone
@@ -407,7 +415,7 @@ Status EinsumTypedComputeProcessor<T>::Run() {
407415
if (num_inputs == 1) {
408416
// Finalize the output by applying any transpose required to get
409417
// it to the required output ordering and move it to the op's output
410-
FinalizeOutput(result ? *result : *raw_inputs[0], preserved_dims);
418+
FinalizeOutput(result ? *result : *raw_inputs[0], all_dims);
411419

412420
return Status::OK();
413421
}
@@ -419,9 +427,9 @@ Status EinsumTypedComputeProcessor<T>::Run() {
419427
// Keep processing each input pair-wise
420428
for (int input = 1; input < num_inputs; ++input) {
421429
TensorShapeVector reduced_dims;
422-
reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
423-
for (int64_t dim = 0; dim < num_subscript_labels; ++dim) {
424-
if (mapped_indices_to_last_input_index[onnxruntime::narrow<size_t>(dim)] == input) {
430+
reduced_dims.reserve(num_subscript_labels); // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
431+
for (size_t dim = 0; dim < num_subscript_labels; ++dim) {
432+
if (mapped_indices_to_last_input_index[dim] == input) {
425433
// This is the last input we are seeing this dimension (and it doesn't occur in the output), so reduce along the dimension
426434
reduced_dims.push_back(dim);
427435
}

onnxruntime/test/providers/cpu/math/einsum_test.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,14 @@ TEST(Einsum, ExplicitEinsumAsBatchedReduceOp_3D_input_1) {
114114
test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
115115
}
116116

117+
TEST(Einsum, ExplicitEinsumAsReduceWithTransposeOp_3D_input_0) {
118+
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
119+
test.AddAttribute<std::string>("equation", "ijk->ki");
120+
test.AddInput<float>("x", {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f});
121+
test.AddOutput<float>("y", {4, 2}, {15.f, 15.f, 18.f, 18.f, 21.f, 21.f, 24.f, 24.f});
122+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
123+
}
124+
117125
// Implicit
118126
// Cannot do implicit reduction
119127

0 commit comments

Comments (0)