1#ifndef ATEN_KERNELS_LINALG_H_
2#define ATEN_KERNELS_LINALG_H_
11template <
typename T,
typename Device>
15 const int& num_element,
23template <
typename T,
typename Device>
26 const int& num_element,
32 const int& num_element,
39template <
typename T,
typename Device>
43 const int& num_element,
50template <
typename T,
typename Device>
54 const int& num_element,
63template <
typename T,
typename Device,
bool Conjugate = false>
66 const std::vector<int>& perm,
67 const std::vector<int64_t>& p_shape,
68 const std::vector<int64_t>& q_shape,
74template <
typename T,
typename Device>
77 const std::vector<int64_t>&
stride,
78 const std::vector<int64_t>& p_shape,
79 const std::vector<int64_t>& q_shape,
84template <
typename T,
typename Device>
87 const std::vector<int64_t>&
inflate,
88 const std::vector<int64_t>& p_shape,
89 const std::vector<int64_t>& q_shape,
95template <
typename T,
typename Device>
98 const int64_t& num_element,
99 const int64_t& inner_most_dim,
#define T
Definition exp.cpp:237
void operator()(const int &num_element, const T &alpha, const T *x, const T &beta, const T *y, T *z)
Definition linalg.cpp:32
void operator()(const int &num_element, const T &alpha, const T *x, const T *y, T *z)
Definition linalg.cpp:62
void operator()(const int &num_element, const T &alpha, const T *x, const T *y, const T &beta, const T *z, T *out)
Definition linalg.cpp:72
void operator()(const std::vector< int64_t > &inflate, const std::vector< int64_t > &p_shape, const std::vector< int64_t > &q_shape, const T *p, T *q)
Definition linalg.cpp:169
void operator()(const int &num_element, const T &alpha, const T *x, T *y)
Definition linalg.cpp:42
void operator()(const int64_t &num_element, const int64_t &inner_most_dim, const T *p, T *q)
Definition linalg.cpp:215
void operator()(const std::vector< int64_t > &stride, const std::vector< int64_t > &p_shape, const std::vector< int64_t > &q_shape, const T *p, T *q)
Definition linalg.cpp:128
void operator()(const std::vector< int > &perm, const std::vector< int64_t > &p_shape, const std::vector< int64_t > &q_shape, const T *p, T *q)
Definition linalg.cpp:82