ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
linalg.h
Go to the documentation of this file.
1#ifndef ATEN_KERNELS_LINALG_H_
2#define ATEN_KERNELS_LINALG_H_
3
#include <cstdint>
#include <vector>

#include <ATen/core/tensor.h>
5
7
8namespace container {
9namespace kernels {
10
11template <typename T, typename Device>
12struct add {
13 // z = alpha * x + beta * y
14 void operator()(
15 const int& num_element,
16 const T& alpha,
17 const T* x,
18 const T& beta,
19 const T* y,
20 T* z);
21};
22
23template <typename T, typename Device>
24struct mul {
25 void operator()(
26 const int& num_element,
27 const T& alpha,
28 const T* x,
29 T* y);
30 // z = alpha * x * y
31 void operator()(
32 const int& num_element,
33 const T& alpha,
34 const T* x,
35 const T* y,
36 T* z);
37};
38
39template <typename T, typename Device>
40struct div {
41 // z = alpha * x / y
42 void operator()(
43 const int& num_element,
44 const T& alpha,
45 const T* x,
46 const T* y,
47 T* z);
48};
49
50template <typename T, typename Device>
51struct fma {
52 // out = alpha * x * y + beta * z
53 void operator()(
54 const int& num_element,
55 const T& alpha,
56 const T* x,
57 const T* y,
58 const T& beta,
59 const T* z,
60 T* out);
61};
62
63template <typename T, typename Device, bool Conjugate = false>
64struct transpose {
65 void operator()(
66 const std::vector<int>& perm,
67 const std::vector<int64_t>& p_shape,
68 const std::vector<int64_t>& q_shape,
69 const T* p,
70 T* q);
71};
72
73
74template <typename T, typename Device>
75struct stride {
76 void operator()(
77 const std::vector<int64_t>& stride,
78 const std::vector<int64_t>& p_shape,
79 const std::vector<int64_t>& q_shape,
80 const T* p,
81 T* q);
82};
83
84template <typename T, typename Device>
85struct inflate {
86 void operator()(
87 const std::vector<int64_t>& inflate,
88 const std::vector<int64_t>& p_shape,
89 const std::vector<int64_t>& q_shape,
90 const T* p,
91 T* q);
92};
93
94
95template <typename T, typename Device>
96struct reduce {
97 void operator()(
98 const int64_t& num_element,
99 const int64_t& inner_most_dim,
100 const T* p,
101 T* q);
102};
103
104
105} // namespace op
106} // namespace container
107
108#endif // ATEN_KERNELS_LINALG_H_
#define T
Definition exp.cpp:237
Definition tensor.cpp:8
Definition linalg.h:12
void operator()(const int &num_element, const T &alpha, const T *x, const T &beta, const T *y, T *z)
Definition linalg.cpp:32
Definition linalg.h:40
void operator()(const int &num_element, const T &alpha, const T *x, const T *y, T *z)
Definition linalg.cpp:62
Definition linalg.h:51
void operator()(const int &num_element, const T &alpha, const T *x, const T *y, const T &beta, const T *z, T *out)
Definition linalg.cpp:72
Definition linalg.h:85
void operator()(const std::vector< int64_t > &inflate, const std::vector< int64_t > &p_shape, const std::vector< int64_t > &q_shape, const T *p, T *q)
Definition linalg.cpp:169
Definition linalg.h:24
void operator()(const int &num_element, const T &alpha, const T *x, T *y)
Definition linalg.cpp:42
Definition linalg.h:96
void operator()(const int64_t &num_element, const int64_t &inner_most_dim, const T *p, T *q)
Definition linalg.cpp:215
Definition linalg.h:75
void operator()(const std::vector< int64_t > &stride, const std::vector< int64_t > &p_shape, const std::vector< int64_t > &q_shape, const T *p, T *q)
Definition linalg.cpp:128
Definition linalg.h:64
void operator()(const std::vector< int > &perm, const std::vector< int64_t > &p_shape, const std::vector< int64_t > &q_shape, const T *p, T *q)
Definition linalg.cpp:82