1#ifndef ATEN_KERNELS_EINSUM_OP_H_
2#define ATEN_KERNELS_EINSUM_OP_H_
16 EinsumOption(
bool conj_x_ =
false,
bool conj_y_ =
false,
float alpha_ = 1.0,
float beta_ = 0.0,
Tensor* out_ =
nullptr)
20namespace einsum_utils {
49 const std::string& equation,
50 std::vector<std::string>& input_subscripts,
51 std::string& output_subscript);
58 const std::string& equation,
59 std::vector<EinsumDimensionType>& label_types,
60 std::vector<std::vector<int>>& input_labels,
61 std::vector<int>& output_labels,
62 std::vector<std::vector<int>>& input_label_counts,
63 std::vector<int>& output_label_counts,
64 std::vector<bool>& input_has_ellipsis,
65 bool& output_has_ellipsis);
68 const std::vector<const Tensor*>& inputs,
69 std::vector<EinsumDimensionType>& label_types,
70 std::vector<std::vector<int>>& input_labels,
71 std::vector<int>& output_labels,
72 std::vector<std::vector<int>>& input_label_counts,
73 std::vector<int>& output_label_counts,
74 const std::vector<bool>& input_has_ellipsis,
75 const bool output_has_ellipsis,
76 std::unordered_map<int, int64_t>& label_to_dim_sizes);
85 std::unordered_map<int, int64_t>& label_to_dim_sizes);
89 const std::vector<EinsumDimensionType>& label_types,
90 std::vector<int>& labels,
91 const std::vector<int>& label_counts,
92 std::vector<int>& free_labels,
93 int& swap_free_and_contract,
106 std::vector<Tensor>& inputs,
107 const std::vector<int>& swap_free_and_contract,
113 const std::vector<einsum_utils::EinsumDimensionType>& label_types,
114 const std::vector<std::vector<int>>& free_labels,
115 std::unordered_map<int, int64_t>& label_to_dim_sizes,
116 const std::vector<int>& output_labels,
117 const std::vector<int>& output_label_counts,
141template <
typename... Tensors>
142typename std::enable_if<std::is_same<
143 typename std::common_type<Tensors...>::type,
Tensor>::value,
Tensor>::type
147 constexpr int num_inputs =
sizeof...(Tensors);
148 if (num_inputs > 2) {
149 throw std::invalid_argument(
"Einstein notation only support two or less tensors!");
151 const std::vector<const Tensor*> inputs{
reinterpret_cast<const Tensor*
>(&tensors)...};
153 std::vector<std::vector<int>> input_labels = {};
154 std::vector<int> output_labels = {};
155 std::vector<einsum_utils::EinsumDimensionType> label_types = {};
156 std::vector<std::vector<int>> input_label_counts = {};
157 std::vector<int> output_label_counts = {};
158 std::vector<bool> input_has_ellipsis = {};
159 bool output_has_ellipsis = {};
162 equation, label_types,
163 input_labels, output_labels,
164 input_label_counts, output_label_counts,
165 input_has_ellipsis, output_has_ellipsis);
167 if (input_labels.size() != num_inputs) {
168 throw std::runtime_error(
"The number of input tensors does not match the number of input labels!");
171 std::unordered_map<int, int64_t> label_to_dim_sizes = {};
175 input_labels, output_labels,
176 input_label_counts, output_label_counts,
177 input_has_ellipsis, output_has_ellipsis,
180 std::vector<std::vector<int>> free_labels(num_inputs);
181 std::vector<int> swap_free_and_contract(num_inputs);
184 for (
int ii = 0; ii < num_inputs; ++ii) {
186 *inputs[ii], label_types,
187 input_labels[ii], input_label_counts[ii],
188 free_labels[ii], swap_free_and_contract[ii],
195 Tensor contraction_output_reshaped;
197 inputs_reduced, swap_free_and_contract,
198 option, contraction_output_reshaped);
204 contraction_output_reshaped, label_types,
205 free_labels, label_to_dim_sizes,
206 output_labels, output_label_counts,
213inline static Tensor einsum(
const std::string& equation,
const Tensor& A) {
218inline static Tensor einsum(
const std::string& equation,
const Tensor& A,
const Tensor& B,
const EinsumOption& option = {}) {
A multi-dimensional array of elements of a single data type.
Definition tensor.h:32
bool ValidateEinsumEquation(const std::string &equation, std::vector< std::string > &input_subscripts, std::string &output_subscript)
Check the validity of the input equation.
Definition einsum_op.cpp:442
bool ReduceOperand(const Tensor &input, const std::vector< EinsumDimensionType > &label_types, std::vector< int > &labels, const std::vector< int > &label_counts, std::vector< int > &free_labels, int &swap_free_and_contract, Tensor &output)
Definition einsum_op.cpp:689
constexpr int kEllipsisLabel
Definition einsum_op.h:24
bool ParseEinsumEquation(const std::string &equation, std::vector< EinsumDimensionType > &label_types, std::vector< std::vector< int > > &input_labels, std::vector< int > &output_labels, std::vector< std::vector< int > > &input_label_counts, std::vector< int > &output_label_counts, std::vector< bool > &input_has_ellipsis, bool &output_has_ellipsis)
Definition einsum_op.cpp:507
bool ContractOperands(std::vector< Tensor > &inputs, const std::vector< int > &swap_free_and_contract, const EinsumOption &option, Tensor &output)
A function to perform contraction operation on multiple Tensors.
Definition einsum_op.cpp:918
EinsumDimensionType
Definition einsum_op.h:29
@ kFree
Definition einsum_op.h:38
@ kBatch
Definition einsum_op.h:35
@ kBroadcasting
Definition einsum_op.h:34
@ kReduce
Definition einsum_op.h:44
@ kContract
Definition einsum_op.h:41
bool ProcessDimensions(const std::vector< const Tensor * > &inputs, std::vector< EinsumDimensionType > &label_types, std::vector< std::vector< int > > &input_labels, std::vector< int > &output_labels, std::vector< std::vector< int > > &input_label_counts, std::vector< int > &output_label_counts, const std::vector< bool > &input_has_ellipsis, const bool output_has_ellipsis, std::unordered_map< int, int64_t > &label_to_dim_sizes)
Definition einsum_op.cpp:588
bool RecordLabelToDimension(const int label, const int axis, const Tensor &input, std::unordered_map< int, int64_t > &label_to_dim_sizes)
Definition einsum_op.cpp:572
void ProcessOutput(const Tensor &input, const std::vector< einsum_utils::EinsumDimensionType > &label_types, const std::vector< std::vector< int > > &free_labels, std::unordered_map< int, int64_t > &label_to_dim_sizes, const std::vector< int > &output_labels, const std::vector< int > &output_label_counts, Tensor &output)
Definition einsum_op.cpp:968
std::enable_if< std::is_same< typename std::common_type< Tensors... >::type, Tensor >::value, Tensor >::type einsum_impl(const std::string &equation, const EinsumOption &option, const Tensors &... tensors)
Computes the Einstein summation convention on the given tensors based on the expression passed as a string.
Definition einsum_op.h:144
@ DT_FLOAT
Single-precision floating point.
Tensor * out
Definition einsum_op.h:14
bool conj_y
Definition einsum_op.h:11
float alpha
Definition einsum_op.h:12
EinsumOption(bool conj_x_=false, bool conj_y_=false, float alpha_=1.0, float beta_=0.0, Tensor *out_=nullptr)
Definition einsum_op.h:16
float beta
Definition einsum_op.h:13
bool conj_x
Definition einsum_op.h:10
const std::map< std::string, std::vector< double > > op
Definition vdwd3_autoset_xcparam.cpp:372