Skip to content

Commit 99a3928

Browse files
committed
Address nit
1 parent 23303da commit 99a3928

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

onnxruntime/core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class EinsumComputePreprocessor final {
140140
const std::vector<int64_t>& GetMappedSubscriptIndicesToOutputindices() const;
141141

142142
// Get the number of subscript indices (subscript labels) in the einsum equation
143-
int64_t GetNumSubscriptIndices() const;
143+
size_t GetNumSubscriptIndices() const;
144144

145145
// Pass-in device specific functions
146146
// (Pass-in CPU implementation or CUDA implementation function depending on the kernel using this class)

onnxruntime/core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,12 @@ Status EinsumTypedComputeProcessor<T>::Run() {
368368

369369
{
370370
TensorShapeVector reduced_dims; // All dims of the input that are reduced using the `ReduceSum` op
371-
reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving
371+
reduced_dims.reserve(num_subscript_labels); // num_subscript_labels is the upper bound. No harm in over-reserving
372372

373373
TensorShapeVector all_dims; // All dimension indices from 0 to num_subscript_labels - 1
374-
all_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the number of elements
374+
all_dims.reserve(num_subscript_labels); // num_subscript_labels is the number of elements
375375

376-
for (size_t i = 0; i < onnxruntime::narrow<size_t>(num_subscript_labels); ++i) {
376+
for (size_t i = 0; i < num_subscript_labels; ++i) {
377377
if (mapped_indices_to_last_input_index[i] == 0) {
378378
reduced_dims.push_back(i);
379379
}
@@ -411,9 +411,9 @@ Status EinsumTypedComputeProcessor<T>::Run() {
411411
// Keep processing each input pair-wise
412412
for (int input = 1; input < num_inputs; ++input) {
413413
TensorShapeVector reduced_dims;
414-
reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
415-
for (int64_t dim = 0; dim < num_subscript_labels; ++dim) {
416-
if (mapped_indices_to_last_input_index[onnxruntime::narrow<size_t>(dim)] == input) {
414+
reduced_dims.reserve(num_subscript_labels); // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
415+
for (size_t dim = 0; dim < num_subscript_labels; ++dim) {
416+
if (mapped_indices_to_last_input_index[dim] == input) {
417417
// This is the last input we are seeing this dimension (and it doesn't occur in the output), so reduce along the dimension
418418
reduced_dims.push_back(dim);
419419
}

0 commit comments

Comments
 (0)