Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/cpu/cpu_provider_shared.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
std::unique_ptr<EinsumTypedComputeProcessor<float>> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<float>>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
std::unique_ptr<EinsumTypedComputeProcessor<double>> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<double>>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
std::unique_ptr<EinsumTypedComputeProcessor<MLFloat16>> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) override { return std::make_unique<EinsumTypedComputeProcessor<MLFloat16>>(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) override { return p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func); }
// Host-side bridges, one per element type (float, double, MLFloat16): each hands the Transpose, MatMul, ReduceSum, DataCopy and Zeroing helpers straight to the processor `p`.
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); }
void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) override { p->SetDeviceHelpers(device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, device_zeroing_func); }
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<float>* p) override { return p->Run(); }
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<double>* p) override { return p->Run(); }
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<MLFloat16>* p) override { return p->Run(); }
Expand Down
11 changes: 6 additions & 5 deletions onnxruntime/core/providers/cpu/cpu_provider_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ struct ProviderHostCPU {
virtual std::unique_ptr<EinsumTypedComputeProcessor<float>> EinsumTypedComputeProcessor_float__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0;
virtual std::unique_ptr<EinsumTypedComputeProcessor<double>> EinsumTypedComputeProcessor_double__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0;
virtual std::unique_ptr<EinsumTypedComputeProcessor<MLFloat16>> EinsumTypedComputeProcessor_MLFloat16__Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) = 0;
// Installs the device helper callbacks (Transpose, MatMul, ReduceSum, DataCopy, Zeroing) on the processor `p`.
// One overload per supported element type: float, double and MLFloat16.
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<float>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<float>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<float>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<double>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<double>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<double>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0;
virtual void EinsumTypedComputeProcessor__SetDeviceHelpers(EinsumTypedComputeProcessor<MLFloat16>* p, const EinsumOp::DeviceHelpers::Transpose& device_transpose_func, const EinsumOp::DeviceHelpers::MatMul<MLFloat16>& device_matmul_func, const EinsumOp::DeviceHelpers::ReduceSum<MLFloat16>& device_reduce_sum_func, const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func, const EinsumOp::DeviceHelpers::Zeroing& device_zeroing_func) = 0;
virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<float>* p) = 0;
virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<double>* p) = 0;
virtual Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<MLFloat16>* p) = 0;
Expand Down Expand Up @@ -296,8 +296,9 @@ struct EinsumTypedComputeProcessor {
void SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func,
const EinsumOp::DeviceHelpers::MatMul<T>& device_matmul_func,
const EinsumOp::DeviceHelpers::ReduceSum<T>& device_reduce_sum_func,
const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) {
g_host_cpu.EinsumTypedComputeProcessor__SetDeviceHelpers(this, device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func);
const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func) {
g_host_cpu.EinsumTypedComputeProcessor__SetDeviceHelpers(this, device_transpose_func, device_matmul_func, device_reduce_sum_func, device_data_copy_func, zero_input_buffer_func);
}

// Runs the computation by delegating to the CPU provider host (g_host_cpu) and returns its Status.
Status Run() { return g_host_cpu.EinsumTypedComputeProcessor__Run(this); }
Expand Down
12 changes: 8 additions & 4 deletions onnxruntime/core/providers/cpu/math/einsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor.SetDeviceHelpers(EinsumOp::DeviceHelpers::CpuDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::MatMul<float>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum<float>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);
return einsum_compute_processor.Run();
} else if (inputs[0]->IsDataType<int32_t>()) {
auto einsum_compute_processor = EinsumTypedComputeProcessor<int32_t>(context,
Expand All @@ -78,7 +79,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor.SetDeviceHelpers(EinsumOp::DeviceHelpers::CpuDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::MatMul<int32_t>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum<int32_t>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);

return einsum_compute_processor.Run();
} else if (inputs[0]->IsDataType<double>()) {
Expand All @@ -92,7 +94,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor.SetDeviceHelpers(EinsumOp::DeviceHelpers::CpuDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::MatMul<double>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum<double>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);
return einsum_compute_processor.Run();
} else if (inputs[0]->IsDataType<int64_t>()) {
auto einsum_compute_processor = EinsumTypedComputeProcessor<int64_t>(context,
Expand All @@ -104,7 +107,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor.SetDeviceHelpers(EinsumOp::DeviceHelpers::CpuDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::MatMul<int64_t>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::ReduceSum<int64_t>,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CpuDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CpuDeviceHelpers::Zeroing);

return einsum_compute_processor.Run();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ Status DataCopy(const Tensor& input, Tensor& output, void* /*einsum_cuda_assets*
return Status::OK();
}

// CPU specific Zeroing helper
// Sets every byte of `input`'s buffer to zero. The second argument (EP-specific
// assets) is not used by the CPU implementation. Always returns Status::OK().
Status Zeroing(Tensor& input, void* /*einsum_cuda_assets*/) {
  // A zero-sized tensor may not have a backing buffer, and handing memset a null
  // pointer is undefined behavior even when the length is 0 - so only call it
  // when there is actually something to clear.
  const auto num_bytes = input.SizeInBytes();
  if (num_bytes > 0) {
    memset(input.MutableDataRaw(), 0, num_bytes);
  }
  return Status::OK();
}

// CPU specific Transpose helper
Status Transpose(const gsl::span<const size_t>& permutation, const Tensor& input,
Tensor& output, const TensorShape* input_shape_override, void* /*einsum_cuda_assets*/) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ namespace DeviceHelpers {
// Data copy op - Copies raw data from the source tensor's buffer to the destination tensor's buffer
using DataCopy = std::function<Status(const Tensor& input, Tensor& output, void* einsum_cuda_assets)>;

// Zeroing op - Sets all bytes in the tensor's buffer to zero
// `input` is the tensor whose buffer is overwritten in place; note that despite the
// parameter name, EinsumTypedComputeProcessor<T>::Run() passes the kernel's output tensor here.
// `einsum_cuda_assets` carries EP-specific assets (the CPU implementation ignores it).
// The returned Status is propagated by the caller as the result of the op.
using Zeroing = std::function<Status(Tensor& input, void* einsum_cuda_assets)>;

// Transpose op - Transposes given input based on data in `permutation`
using Transpose = std::function<Status(const gsl::span<const size_t>& permutation, const Tensor& input,
Tensor& output, const TensorShape* input_shape_override,
Expand Down Expand Up @@ -63,6 +66,8 @@ namespace CpuDeviceHelpers {

Status DataCopy(const Tensor& input, Tensor& output, void* einsum_cuda_assets);

// Zero-fills the buffer of `input` on CPU; `einsum_cuda_assets` is unused by this implementation.
Status Zeroing(Tensor& input, void* einsum_cuda_assets);

Status Transpose(const gsl::span<const size_t>& permutation, const Tensor& input,
Tensor& output, const TensorShape* input_shape_override, void* einsum_cuda_assets);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,13 @@ template <typename T>
void EinsumTypedComputeProcessor<T>::SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func,
const EinsumOp::DeviceHelpers::MatMul<T>& device_matmul_func,
const EinsumOp::DeviceHelpers::ReduceSum<T>& device_reduce_sum_func,
const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func) {
const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func) {
device_transpose_func_ = device_transpose_func;
device_matmul_func_ = device_matmul_func;
device_reduce_sum_func_ = device_reduce_sum_func;
device_data_copy_func_ = device_data_copy_func;
zero_input_buffer_func_ = zero_input_buffer_func;
}

template <typename T>
Expand All @@ -357,6 +359,20 @@ Status EinsumTypedComputeProcessor<T>::Run() {

auto num_inputs = context_->InputCount();

// Early exit for empty inputs: nothing needs to be computed in that case.
{
// True when at least one input tensor's shape has a total element count of zero.
bool has_empty_input = std::any_of(raw_inputs.begin(), raw_inputs.end(), [](const auto& input) {
return input->Shape().Size() == 0;
});

// Skip all the work, fill with zeros if needed
if (has_empty_input) {
// The output is still requested with the dims computed by the preprocessor.
const auto output_dims = einsum_compute_preprocessor_.GetOutputDims();
Tensor& output = *context_->Output(0, output_dims);

// Zero-fill the output through the EP-specific helper and return its Status directly;
// none of the remaining processing in Run() is executed.
return zero_input_buffer_func_(output, einsum_ep_assets_);
}
}

// Pre-process the first input so as to reduce any dims that only it has
std::unique_ptr<const Tensor> result;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class EinsumTypedComputeProcessor {
void SetDeviceHelpers(const EinsumOp::DeviceHelpers::Transpose& device_transpose_func,
const EinsumOp::DeviceHelpers::MatMul<T>& device_matmul_func,
const EinsumOp::DeviceHelpers::ReduceSum<T>& device_reduce_sum_func,
const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func);
const EinsumOp::DeviceHelpers::DataCopy& device_data_copy_func,
const EinsumOp::DeviceHelpers::Zeroing& zero_input_buffer_func);

Status Run();

Expand Down Expand Up @@ -64,6 +65,7 @@ class EinsumTypedComputeProcessor {
EinsumOp::DeviceHelpers::MatMul<T> device_matmul_func_;
EinsumOp::DeviceHelpers::ReduceSum<T> device_reduce_sum_func_;
EinsumOp::DeviceHelpers::DataCopy device_data_copy_func_;
EinsumOp::DeviceHelpers::Zeroing zero_input_buffer_func_;

// Holds EP-specific assets required for (auxiliary) ops that need to be executed on non-CPU EPs
void* einsum_ep_assets_;
Expand Down
9 changes: 6 additions & 3 deletions onnxruntime/core/providers/cuda/math/einsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul<float>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum<float>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing);
return einsum_compute_processor->Run();
} else if (inputs[0]->IsDataType<double>()) {
auto einsum_compute_processor = EinsumTypedComputeProcessor<double>::Create(context, allocator, tp,
Expand All @@ -63,7 +64,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul<double>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum<double>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing);
return einsum_compute_processor->Run();
} else if (inputs[0]->IsDataType<MLFloat16>()) {
auto einsum_compute_processor = EinsumTypedComputeProcessor<MLFloat16>::Create(context, allocator, tp,
Expand All @@ -73,7 +75,8 @@ Status Einsum::DeviceCompute(OpKernelContext* context, const std::vector<const T
einsum_compute_processor->SetDeviceHelpers(EinsumOp::DeviceHelpers::CudaDeviceHelpers::Transpose,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::MatMul<MLFloat16>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::ReduceSum<MLFloat16>,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy);
EinsumOp::DeviceHelpers::CudaDeviceHelpers::DataCopy,
EinsumOp::DeviceHelpers::CudaDeviceHelpers::Zeroing);
return einsum_compute_processor->Run();
}

Expand Down
Loading
Loading