Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ const std::vector<int64_t>& EinsumComputePreprocessor::GetMappedSubscriptIndices
return subscript_indices_to_output_indices_;
}

int64_t EinsumComputePreprocessor::GetNumSubscriptIndices() const {
size_t EinsumComputePreprocessor::GetNumSubscriptIndices() const {
return num_subscript_indices_;
}

Expand All @@ -73,7 +73,7 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
"Number of subscripts in the input equation does not match number of input tensors");
}

int64_t input_index = 0;
size_t input_index = 0;

// Holds mapping between input indices to its corresponding subscript labels for each input
input_subscript_indices_.reserve(inputs_.size());
Expand All @@ -84,7 +84,7 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
subscript_indices_to_dim_value_.reserve(10);

for (const auto& subscript : left_equation_split) {
const auto& shape = inputs_[onnxruntime::narrow<size_t>(input_index)]->Shape();
const auto& shape = inputs_[input_index]->Shape();
const auto& dims = shape.GetDims();
size_t rank = dims.size();
size_t dim_counter = 0;
Expand Down Expand Up @@ -237,13 +237,13 @@ Status EinsumComputePreprocessor::PostProcessBroadcastedDims() {
}
}

std::vector<int64_t> temp_index_to_last_input(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
std::vector<int64_t> temp_index_to_last_input(num_subscript_indices_, -1);
for (size_t i = 0; i < subscript_indices_to_last_input_.size(); ++i) {
temp_index_to_last_input[i + num_of_ellipsis_dims_] = subscript_indices_to_last_input_[i];
}
subscript_indices_to_last_input_ = std::move(temp_index_to_last_input);

std::vector<int64_t> temp_index_to_dim_value(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
std::vector<int64_t> temp_index_to_dim_value(num_subscript_indices_, -1);
for (size_t i = 0; i < subscript_indices_to_dim_value_.size(); ++i) {
temp_index_to_dim_value[i + num_of_ellipsis_dims_] = subscript_indices_to_dim_value_[i];
}
Expand Down Expand Up @@ -338,7 +338,7 @@ Status EinsumComputePreprocessor::CalculateOutputShape() {
bool is_in_middle_of_ellipsis = false;
int64_t ellipsis_char_count = 0;

subscript_indices_to_output_indices_.resize(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
subscript_indices_to_output_indices_.resize(num_subscript_indices_, -1);

std::array<int64_t, EinsumOp::num_of_letters> output_letter_to_count;
output_letter_to_count.fill(0);
Expand Down Expand Up @@ -407,24 +407,24 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
// As part of input preprocessing we "homogenize" them by
// 1) Making them all of the same rank
// 2) The axes order in all the inputs are to be made the same
int64_t input_iter = 0;
size_t input_iter = 0;
for (const auto* input : inputs_) {
// Eventually will hold the "preprocessed" version of the original input
std::unique_ptr<Tensor> preprocessed;

const auto& input_dims = input->Shape().GetDims();
const auto& current_subscript_indices = input_subscript_indices_[onnxruntime::narrow<size_t>(input_iter)];
const auto& current_subscript_indices = input_subscript_indices_[input_iter];

// If all has gone well, we will have a subscript index (subscript label) for each dim of the input
if (input_dims.size() != current_subscript_indices.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Rank of the input must match number of subscript labels corresponding to the input");
}

std::vector<int64_t> subscript_indices_to_input_index(onnxruntime::narrow<size_t>(num_subscript_indices_), -1);
std::vector<int64_t> subscript_indices_to_input_index(num_subscript_indices_, -1);

// This is the input dims after re-ordering so that all inputs have same axes order
TensorShapeVector homogenized_input_dims(onnxruntime::narrow<size_t>(num_subscript_indices_), 1);
TensorShapeVector homogenized_input_dims(num_subscript_indices_, 1);

// Preprocessed dim rank may not be the same as original input rank if we need to parse diagonals along the way
// (which reduces rank in the preprocessed input by 1 for each diagonal we parse)
Expand All @@ -437,7 +437,7 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
subscript_indices_to_input_index[onnxruntime::narrow<size_t>(subscript_index)] = dim_index_in_preprocessed_input++;
homogenized_input_dims[onnxruntime::narrow<size_t>(subscript_index)] = input_dims[onnxruntime::narrow<size_t>(dim_index_in_original_input)];
} else { // Diagonal needs to be parsed along the repeated axes
preprocessed = device_diagonal_func_(preprocessed ? *preprocessed : *inputs_[onnxruntime::narrow<size_t>(input_iter)],
preprocessed = device_diagonal_func_(preprocessed ? *preprocessed : *inputs_[input_iter],
subscript_indices_to_input_index[onnxruntime::narrow<size_t>(subscript_index)],
dim_index_in_preprocessed_input,
allocator_, einsum_ep_assets_);
Expand All @@ -454,10 +454,10 @@ Status EinsumComputePreprocessor::PreprocessInputs() {
}

// (Identify no-op transpose and prevent triggering the transpose)
if (EinsumOp::IsTransposeRequired(preprocessed ? preprocessed->Shape().GetDims().size() : inputs_[onnxruntime::narrow<size_t>(input_iter)]->Shape().GetDims().size(),
if (EinsumOp::IsTransposeRequired(preprocessed ? preprocessed->Shape().GetDims().size() : inputs_[input_iter]->Shape().GetDims().size(),
permutation)) {
preprocessed = EinsumOp::Transpose(preprocessed ? *preprocessed : *inputs_[onnxruntime::narrow<size_t>(input_iter)],
preprocessed ? preprocessed->Shape().GetDims() : inputs_[onnxruntime::narrow<size_t>(input_iter)]->Shape().GetDims(),
preprocessed = EinsumOp::Transpose(preprocessed ? *preprocessed : *inputs_[input_iter],
preprocessed ? preprocessed->Shape().GetDims() : inputs_[input_iter]->Shape().GetDims(),
permutation, allocator_, einsum_ep_assets_, device_transpose_func_);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class EinsumComputePreprocessor final {
const std::vector<int64_t>& GetMappedSubscriptIndicesToOutputindices() const;

// Get the number of subscript indices (subscript labels) in the einsum equation
int64_t GetNumSubscriptIndices() const;
size_t GetNumSubscriptIndices() const;

// Pass-in device specific functions
// (Pass-in CPU implementation or CUDA implementation function depending on the kernel using this class)
Expand Down Expand Up @@ -185,7 +185,7 @@ class EinsumComputePreprocessor final {
// num_subscript_indices_ = 3 (i, j, k)
// E.g. 2 : With equation -> '...ij', 'jk' -> '...ik'
// num_subscript_indices_ = 3 (i, j, k) + number of dims specified by an ellipsis (across all inputs)
int64_t num_subscript_indices_ = 0;
size_t num_subscript_indices_ = 0;

// Hold the count corresponding to the letter seen
// `0` means the corresponding letter wasn't seen at all
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ namespace onnxruntime {
template <typename T>
void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_output,
const gsl::span<const int64_t>& ordered_subscript_indices_in_candidate) {
ORT_ENFORCE(candidate_output.Shape().NumDimensions() == ordered_subscript_indices_in_candidate.size(),
"Einsum op: The candidate output's rank has to be the same number of elements as "
"the ordered subscript indices in the candidate output. Hitting this error points to an "
"internal bug in the Einsum op's implementation. "
"Please open a bug report with appropriate repro steps");

const std::vector<int64_t>& subscript_indices_to_output_indices =
einsum_compute_preprocessor_.GetMappedSubscriptIndicesToOutputindices();
const auto output_dims = einsum_compute_preprocessor_.GetOutputDims();
Expand Down Expand Up @@ -75,7 +81,7 @@ void EinsumTypedComputeProcessor<T>::FinalizeOutput(const Tensor& candidate_outp
static bool IsTransposeReshapeForEinsum(const gsl::span<const size_t>& perm,
gsl::span<const int64_t> input_dims,
TensorShapeVector& new_shape) {
// As long as the dims with values > 1 stay in the same order, it's a reshape.
// As long as the dims with values > 1 stay in the same relative order, it's a reshape.
// Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
size_t last_permuted_axis = 0;
for (size_t i = 0; i < perm.size(); ++i) {
Expand Down Expand Up @@ -361,17 +367,19 @@ Status EinsumTypedComputeProcessor<T>::Run() {
std::unique_ptr<const Tensor> result;

{
TensorShapeVector reduced_dims;
TensorShapeVector preserved_dims; // dims which were not reduced
reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving.
preserved_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving.
TensorShapeVector reduced_dims; // All dims of the input that are reduced using the `ReduceSum` op
reduced_dims.reserve(num_subscript_labels); // num_subscript_labels is the upper bound. No harm in over-reserving

TensorShapeVector all_dims; // All dimension indices from 0 to num_subscript_labels - 1
all_dims.reserve(num_subscript_labels); // num_subscript_labels is the number of elements

for (size_t i = 0; i < onnxruntime::narrow<size_t>(num_subscript_labels); ++i) {
for (size_t i = 0; i < num_subscript_labels; ++i) {
if (mapped_indices_to_last_input_index[i] == 0) {
reduced_dims.push_back(i);
} else {
preserved_dims.push_back(i);
}

// ReduceSum operation preserves even the reduced dims with reduced dim shape value being 1
all_dims.push_back(i);
}

// Reduce the dims that are last seen in the first input alone
Expand All @@ -391,7 +399,7 @@ Status EinsumTypedComputeProcessor<T>::Run() {
if (num_inputs == 1) {
// Finalize the output by applying any transpose required to get
// it to the required output ordering and move it to the op's output
FinalizeOutput(result ? *result : *raw_inputs[0], preserved_dims);
FinalizeOutput(result ? *result : *raw_inputs[0], all_dims);

return Status::OK();
}
Expand All @@ -403,9 +411,9 @@ Status EinsumTypedComputeProcessor<T>::Run() {
// Keep processing each input pair-wise
for (int input = 1; input < num_inputs; ++input) {
TensorShapeVector reduced_dims;
reduced_dims.reserve(onnxruntime::narrow<size_t>(num_subscript_labels)); // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
for (int64_t dim = 0; dim < num_subscript_labels; ++dim) {
if (mapped_indices_to_last_input_index[onnxruntime::narrow<size_t>(dim)] == input) {
reduced_dims.reserve(num_subscript_labels); // num_subscript_labels is the upper bound. No harm in over-reserving by a small margin.
for (size_t dim = 0; dim < num_subscript_labels; ++dim) {
if (mapped_indices_to_last_input_index[dim] == input) {
// This is the last input we are seeing this dimension (and it doesn't occur in the output), so reduce along the dimension
reduced_dims.push_back(dim);
}
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/test/providers/cpu/math/einsum_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ TEST(Einsum, ExplicitEinsumAsBatchedReduceOp_3D_input_1) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
}

// Reduce over axis 'j' and transpose the remaining axes (ijk -> ki):
// for a (2, 3, 4) input this sums across the middle dimension and
// emits the result with the batch and last axes swapped, shape (4, 2).
TEST(Einsum, ExplicitEinsumAsReduceWithTransposeOp_3D_input_0) {
  OpTester tester("Einsum", 12, onnxruntime::kOnnxDomain);
  tester.AddAttribute<std::string>("equation", "ijk->ki");
  tester.AddInput<float>("x", {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f});
  tester.AddOutput<float>("y", {4, 2}, {15.f, 15.f, 18.f, 18.f, 21.f, 21.f, 24.f, 24.f});
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100());
}

// Implicit
// Cannot do implicit reduction

Expand Down
Loading