ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
einsum_op.h
Go to the documentation of this file.
1#ifndef ATEN_KERNELS_EINSUM_OP_H_
2#define ATEN_KERNELS_EINSUM_OP_H_
3
4#include <ATen/core/tensor.h>
6
7namespace container {
8
10 bool conj_x = false;
11 bool conj_y = false;
12 float alpha = 1.0;
13 float beta = 0.0;
14 Tensor* out = nullptr;
15
16 EinsumOption(bool conj_x_ = false, bool conj_y_ = false, float alpha_ = 1.0, float beta_ = 0.0, Tensor* out_ = nullptr)
17 : conj_x(conj_x_), conj_y(conj_y_), alpha(alpha_), beta(beta_), out(out_) {}
18};
19
20namespace einsum_utils {
21struct BCast;
22
23// Dummy axis label used to denote an ellipsis in an input or output subscript.
24constexpr int kEllipsisLabel = -1;
25
26// Each dimension is categorized into exactly one of five types based on
27// whether its corresponding label is present in the input and/or the output
28// subscripts.
30 // Batch dimensions are those present in two inputs as well as the output.
31 // They are part of the batch dimensions during Tensor contraction. Such
32 // dimensions may be broadcasting dimensions (those mapping to ellipsis)
33 // or explicit batch dimensions corresponding to named axis labels.
35 kBatch = 1,
36 // Free dimensions are present in exactly one of the inputs, and also the
37 // output. These are non-contracted axes in the Tensor contraction.
38 kFree = 2,
39 // Contract dimensions are present in two inputs, but not the output. These
40 // dimensions are contracted in Tensor contraction.
42 // Reduce dimensions are present in exactly one input; and not in the output
43 // and are summed over prior to Tensor contraction.
45};
46
47// Parses and validates an einsum equation in explicit form.
49 const std::string& equation,
50 std::vector<std::string>& input_subscripts,
51 std::string& output_subscript);
52
53// Parses and validates the equation and the input shapes. Single character
54// labels are integerized, and we populate input and output label subscripts
55// and corresponding counts. Also create the mapping from (named) labels to
56// their EinsumDimensionType.
58 const std::string& equation,
59 std::vector<EinsumDimensionType>& label_types,
60 std::vector<std::vector<int>>& input_labels,
61 std::vector<int>& output_labels,
62 std::vector<std::vector<int>>& input_label_counts,
63 std::vector<int>& output_label_counts,
64 std::vector<bool>& input_has_ellipsis,
65 bool& output_has_ellipsis);
66
68 const std::vector<const Tensor*>& inputs,
69 std::vector<EinsumDimensionType>& label_types,
70 std::vector<std::vector<int>>& input_labels,
71 std::vector<int>& output_labels,
72 std::vector<std::vector<int>>& input_label_counts,
73 std::vector<int>& output_label_counts,
74 const std::vector<bool>& input_has_ellipsis,
75 const bool output_has_ellipsis,
76 std::unordered_map<int, int64_t>& label_to_dim_sizes);
77
78// This function records the mapping of a label to its corresponding dimension for a specific axis in the input tensor.
79// It also validates that the label and dimension mapping is consistent with previous recordings, ensuring that the
80// same label is not mapped to different dimensions along different axes.
82 const int label,
83 const int axis,
84 const Tensor& input,
85 std::unordered_map<int, int64_t>& label_to_dim_sizes);
86
// Declaration only; the generated index at the end of this listing places the
// definition at einsum_op.cpp:689.
// `input`, `label_types` and `label_counts` are taken by const reference;
// `labels`, `free_labels`, `swap_free_and_contract` and `output` are non-const
// references, so the callee is able to write to them.
// NOTE(review): presumably this performs the "summed over prior to Tensor
// contraction" step described for reduce dimensions above -- confirm against
// the definition. The meaning of the bool return value is not visible here.
87bool ReduceOperand(
88 const Tensor& input,
89 const std::vector<EinsumDimensionType>& label_types,
90 std::vector<int>& labels,
91 const std::vector<int>& label_counts,
92 std::vector<int>& free_labels,
93 int& swap_free_and_contract,
94 Tensor& output);
95
106 std::vector<Tensor>& inputs,
107 const std::vector<int>& swap_free_and_contract,
108 const EinsumOption& option,
109 Tensor& output);
110
// Declaration only; the generated index at the end of this listing places the
// definition at einsum_op.cpp:968.
// `output` is the only non-const Tensor parameter, so the result is expected to
// be delivered through it (the function itself returns void).
// `label_to_dim_sizes` is also passed by non-const reference.
// NOTE(review): whether the callee actually modifies `label_to_dim_sizes` cannot
// be determined from this header -- confirm, and make it const if it does not.
111void ProcessOutput(
112 const Tensor& input,
113 const std::vector<einsum_utils::EinsumDimensionType>& label_types,
114 const std::vector<std::vector<int>>& free_labels,
115 std::unordered_map<int, int64_t>& label_to_dim_sizes,
116 const std::vector<int>& output_labels,
117 const std::vector<int>& output_label_counts,
118 Tensor& output);
119
120} // namespace einsum_utils
121
122namespace op {
123
124// TODO: implement this method this week!
125// piapia pat face
126
// Variadic einsum implementation shared by the einsum() overloads below.
//
// The template only participates in overload resolution when
// std::common_type<Tensors...>::type is Tensor (see the enable_if return type).
//
// Parameters:
//   equation - the einsum equation string, forwarded to the parsing step.
//   option   - EinsumOption forwarded to the contraction step.
//   tensors  - the operands; at most two are accepted.
// Returns: the Tensor named `output`, by value.
// Throws:
//   std::invalid_argument - more than two tensors were supplied.
//   std::runtime_error    - the number of parsed input label lists differs from
//                           the number of tensors supplied.
//
// NOTE(review): this listing is missing several original source lines (161, 173,
// 185, 196, 200 and 203), i.e. the callee names of the helper calls and,
// presumably, the declaration of `output`. The argument lists that remain match
// the einsum_utils declarations ParseEinsumEquation, ProcessDimensions,
// ReduceOperand, ContractOperands and ProcessOutput, in that order -- confirm
// against the original header before editing this function.
141template <typename... Tensors>
142typename std::enable_if<std::is_same<
143 typename std::common_type<Tensors...>::type, Tensor>::value, Tensor>::type
144 einsum_impl(const std::string& equation, const EinsumOption& option, const Tensors&... tensors)
145{
146 // Check the number of input tensors: only one or two operands are supported.
147 constexpr int num_inputs = sizeof...(Tensors);
148 if (num_inputs > 2) {
149 throw std::invalid_argument("Einstein notation only support two or less tensors!");
150 }
// Collect the addresses of the operands so they can be handled uniformly.
// NOTE(review): reinterpret_cast is only safe if every pack element really is a
// Tensor; the enable_if above constrains the common type only -- confirm.
151 const std::vector<const Tensor*> inputs{reinterpret_cast<const Tensor*>(&tensors)...};
152 // Containers filled in by the equation-parsing step below.
153 std::vector<std::vector<int>> input_labels = {};
154 std::vector<int> output_labels = {};
155 std::vector<einsum_utils::EinsumDimensionType> label_types = {};
156 std::vector<std::vector<int>> input_label_counts = {};
157 std::vector<int> output_label_counts = {};
158 std::vector<bool> input_has_ellipsis = {};
159 bool output_has_ellipsis = {};
160
// NOTE(review): callee line missing from this listing; the arguments match
// einsum_utils::ParseEinsumEquation.
162 equation, label_types,
163 input_labels, output_labels,
164 input_label_counts, output_label_counts,
165 input_has_ellipsis, output_has_ellipsis);
166
// The equation must describe exactly as many inputs as tensors were passed.
167 if (input_labels.size() != num_inputs) {
168 throw std::runtime_error("The number of input tensors does not match the number of input labels!");
169 }
170
171 std::unordered_map<int, int64_t> label_to_dim_sizes = {};
172
// NOTE(review): callee line missing from this listing; the arguments match
// einsum_utils::ProcessDimensions.
174 inputs, label_types,
175 input_labels, output_labels,
176 input_label_counts, output_label_counts,
177 input_has_ellipsis, output_has_ellipsis,
178 label_to_dim_sizes);
179
// Per-input scratch state; each reduced input starts as a DT_FLOAT tensor
// constructed with an empty shape list.
180 std::vector<std::vector<int>> free_labels(num_inputs);
181 std::vector<int> swap_free_and_contract(num_inputs);
182 std::vector<Tensor> inputs_reduced(num_inputs, Tensor(DataType::DT_FLOAT, {}));
183
// One helper call per operand.
// NOTE(review): callee line missing from this listing; the arguments match
// einsum_utils::ReduceOperand.
184 for (int ii = 0; ii < num_inputs; ++ii) {
186 *inputs[ii], label_types,
187 input_labels[ii], input_label_counts[ii],
188 free_labels[ii], swap_free_and_contract[ii],
189 inputs_reduced[ii]);
190 }
191
192 // After reduction, the inputs should be reshaped to Tensors suitable for
193 // contraction. If num_inputs is 1, the reduced input is simply forwarded to
194 // the output.
195 Tensor contraction_output_reshaped;
// NOTE(review): callee line missing from this listing; the arguments match
// einsum_utils::ContractOperands.
197 inputs_reduced, swap_free_and_contract,
198 option, contraction_output_reshaped);
199
201 // Copy the batch labels from the contraction output. Recover the batch
202 // shape, which may have been broadcasted.
// NOTE(review): the declaration of `output` and the callee line are missing from
// this listing; the arguments match einsum_utils::ProcessOutput.
204 contraction_output_reshaped, label_types,
205 free_labels, label_to_dim_sizes,
206 output_labels, output_label_counts,
207 output);
208
// NOTE(review): if `output` is a function-local Tensor, std::move here prevents
// named return value optimization; a plain `return output;` would be preferable.
209 return std::move(output);
210}
211
212// The conj params only take effect for matmul equations.
213inline static Tensor einsum(const std::string& equation, const Tensor& A) {
214 const EinsumOption& option = {};
215 return std::move(op::einsum_impl(equation, option, A));
216}
217
218inline static Tensor einsum(const std::string& equation, const Tensor& A, const Tensor& B, const EinsumOption& option = {}) {
219 return std::move(op::einsum_impl(equation, option, A, B));
220}
221
222} // namespace op
223} // namespace container
224
225#endif // ATEN_KERNELS_EINSUM_OP_H_
A multi-dimensional array of elements of a single data type.
Definition tensor.h:32
Definition output.h:13
bool ValidateEinsumEquation(const std::string &equation, std::vector< std::string > &input_subscripts, std::string &output_subscript)
Check the validity of the input equation.
Definition einsum_op.cpp:442
bool ReduceOperand(const Tensor &input, const std::vector< EinsumDimensionType > &label_types, std::vector< int > &labels, const std::vector< int > &label_counts, std::vector< int > &free_labels, int &swap_free_and_contract, Tensor &output)
Definition einsum_op.cpp:689
constexpr int kEllipsisLabel
Definition einsum_op.h:24
bool ParseEinsumEquation(const std::string &equation, std::vector< EinsumDimensionType > &label_types, std::vector< std::vector< int > > &input_labels, std::vector< int > &output_labels, std::vector< std::vector< int > > &input_label_counts, std::vector< int > &output_label_counts, std::vector< bool > &input_has_ellipsis, bool &output_has_ellipsis)
Definition einsum_op.cpp:507
bool ContractOperands(std::vector< Tensor > &inputs, const std::vector< int > &swap_free_and_contract, const EinsumOption &option, Tensor &output)
A function to perform contraction operation on multiple Tensors.
Definition einsum_op.cpp:918
EinsumDimensionType
Definition einsum_op.h:29
@ kFree
Definition einsum_op.h:38
@ kBatch
Definition einsum_op.h:35
@ kBroadcasting
Definition einsum_op.h:34
@ kReduce
Definition einsum_op.h:44
@ kContract
Definition einsum_op.h:41
bool ProcessDimensions(const std::vector< const Tensor * > &inputs, std::vector< EinsumDimensionType > &label_types, std::vector< std::vector< int > > &input_labels, std::vector< int > &output_labels, std::vector< std::vector< int > > &input_label_counts, std::vector< int > &output_label_counts, const std::vector< bool > &input_has_ellipsis, const bool output_has_ellipsis, std::unordered_map< int, int64_t > &label_to_dim_sizes)
Definition einsum_op.cpp:588
bool RecordLabelToDimension(const int label, const int axis, const Tensor &input, std::unordered_map< int, int64_t > &label_to_dim_sizes)
Definition einsum_op.cpp:572
void ProcessOutput(const Tensor &input, const std::vector< einsum_utils::EinsumDimensionType > &label_types, const std::vector< std::vector< int > > &free_labels, std::unordered_map< int, int64_t > &label_to_dim_sizes, const std::vector< int > &output_labels, const std::vector< int > &output_label_counts, Tensor &output)
Definition einsum_op.cpp:968
std::enable_if< std::is_same< typename std::common_type< Tensors... >::type, Tensor >::value, Tensor >::type einsum_impl(const std::string &equation, const EinsumOption &option, const Tensors &... tensors)
Computes the Einstein summation convention on the given tensors based on the expression passed as a s...
Definition einsum_op.h:144
Definition tensor.cpp:8
@ DT_FLOAT
Single-precision floating point.
Definition einsum_op.h:9
Tensor * out
Definition einsum_op.h:14
bool conj_y
Definition einsum_op.h:11
float alpha
Definition einsum_op.h:12
EinsumOption(bool conj_x_=false, bool conj_y_=false, float alpha_=1.0, float beta_=0.0, Tensor *out_=nullptr)
Definition einsum_op.h:16
float beta
Definition einsum_op.h:13
bool conj_x
Definition einsum_op.h:10
const std::map< std::string, std::vector< double > > op
Definition vdwd3_autoset_xcparam.cpp:372