ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
linalg_op.h
Go to the documentation of this file.
1#ifndef ATEN_OPS_LINALG_H_
2#define ATEN_OPS_LINALG_H_
3
4#include <ATen/core/tensor.h>
5
7
8namespace container {
9namespace op {
10
17struct add_op {
29 void operator()(
30 const Tensor& x,
31 const Tensor& y,
32 Tensor& z);
33
34 template <typename T>
35 void operator()(
36 const T& alpha,
37 const Tensor& x,
38 const T& beta,
39 const Tensor& y,
40 Tensor& z);
41};
42
43struct mul_op {
44 // z = x * y
45 void operator()(
46 const Tensor& x,
47 const Tensor& y,
48 Tensor& z);
49
50 // y = alpha * x
51 template <typename T>
52 void operator()(
53 const T& alpha,
54 const Tensor& x,
55 Tensor& y);
56};
57
58struct div_op {
59 // z = x / y
60 void operator()(
61 const Tensor& x,
62 const Tensor& y,
63 Tensor& z);
64};
65
// NOTE(review): this listing lost the struct header of this template (Doxygen
// places the definition at linalg_op.h:67); only the template head and the
// call operator survived extraction. The numeric prefixes on each line are the
// original header's line numbers fused into the code.
// `Conjugate` presumably selects a conjugate transpose for complex element
// types -- TODO confirm against linalg_op.cpp:94.
66template <bool Conjugate = false>
// Perform the transpose operation on `input` according to `permutation`,
// writing the result into `output`. Presumably permutation[i] names the input
// axis that becomes output axis i -- TODO confirm.
87 void operator()(
88 const Tensor& input,
89 const std::vector<int>& permutation,
90 Tensor& output);
91};
92
103struct stride_op {
117 void operator()(
118 const Tensor& input,
119 const std::vector<int64_t>& stride,
120 Tensor& output);
121};
122
// NOTE(review): the struct header of this functor is missing from this listing
// (Doxygen places its definition at linalg_op.h:128 and describes it as
// "A functor for inflating a tensor."). The numeric prefixes on each line are
// the original header's line numbers fused into the code.
// Inflate `input` into `output` using `stride`. Presumably the inverse of
// stride_op, scattering input elements into output at the given strides --
// TODO confirm against linalg_op.cpp:116.
138 void operator()(
139 const Tensor& input,
140 const std::vector<int64_t>& stride,
141 Tensor& output);
142};
143
144
145struct reduce_op {
146 void operator()(
147 const Tensor& input,
148 const int64_t& inner_most_dim,
149 Tensor& output);
150};
151
152} // namespace op
153} // namespace container
154
155ct::Tensor operator+(const ct::Tensor& self, const ct::Tensor& other);
156ct::Tensor operator-(const ct::Tensor& self, const ct::Tensor& other);
157ct::Tensor operator*(const ct::Tensor& self, const ct::Tensor& other);
158ct::Tensor operator/(const ct::Tensor& self, const ct::Tensor& other);
159ct::Tensor& operator+=(ct::Tensor& self, const ct::Tensor& other);
160ct::Tensor& operator-=(ct::Tensor& self, const ct::Tensor& other);
161ct::Tensor& operator*=(ct::Tensor& self, const ct::Tensor& other);
162ct::Tensor& operator/=(ct::Tensor& self, const ct::Tensor& other);
163
164#endif // ATEN_OPS_LINALG_H_
A multi-dimensional array of elements of a single data type.
Definition tensor.h:32
Definition output.h:13
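#define T (spurious Doxygen cross-reference: the `T` in linalg_op.h is a template parameter, unrelated to this macro)
Definition exp.cpp:237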
ct::Tensor operator*(const ct::Tensor &self, const ct::Tensor &other)
Definition linalg_op.cpp:181
ct::Tensor operator+(const ct::Tensor &self, const ct::Tensor &other)
Definition linalg_op.cpp:152
ct::Tensor operator/(const ct::Tensor &self, const ct::Tensor &other)
Definition linalg_op.cpp:191
ct::Tensor & operator+=(ct::Tensor &self, const ct::Tensor &other)
Definition linalg_op.cpp:201
ct::Tensor & operator-=(ct::Tensor &self, const ct::Tensor &other)
Definition linalg_op.cpp:209
ct::Tensor & operator*=(ct::Tensor &self, const ct::Tensor &other)
Definition linalg_op.cpp:228
ct::Tensor operator-(const ct::Tensor &self, const ct::Tensor &other)
Definition linalg_op.cpp:162
ct::Tensor & operator/=(ct::Tensor &self, const ct::Tensor &other)
Definition linalg_op.cpp:236
Definition tensor.cpp:8
A functor to perform add operation on a Tensor.
Definition linalg_op.h:17
void operator()(const Tensor &x, const Tensor &y, Tensor &z)
Perform add operation on the input Tensors.
Definition linalg_op.cpp:11
Definition linalg_op.h:58
void operator()(const Tensor &x, const Tensor &y, Tensor &z)
Definition linalg_op.cpp:76
A functor for inflating a tensor.
Definition linalg_op.h:128
void operator()(const Tensor &input, const std::vector< int64_t > &stride, Tensor &output)
Inflate the input tensor.
Definition linalg_op.cpp:116
Definition linalg_op.h:43
void operator()(const Tensor &x, const Tensor &y, Tensor &z)
Definition linalg_op.cpp:44
Definition linalg_op.h:145
void operator()(const Tensor &input, const int64_t &inner_most_dim, Tensor &output)
Definition linalg_op.cpp:127
A functor to perform stride operation on a Tensor.
Definition linalg_op.h:103
void operator()(const Tensor &input, const std::vector< int64_t > &stride, Tensor &output)
Perform stride operation on the input Tensor.
Definition linalg_op.cpp:105
Definition linalg_op.h:67
void operator()(const Tensor &input, const std::vector< int > &permutation, Tensor &output)
Perform the transpose operation on the input tensor.
Definition linalg_op.cpp:94
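const std::map< std::string, std::vector< double > > op (spurious Doxygen cross-reference: `op` in linalg_op.h is the namespace container::op, unrelated to this variable)
Definition vdwd3_autoset_xcparam.cpp:372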