ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
blas.h
Go to the documentation of this file.
1#ifndef ATEN_KERNELS_BLAS_H_
2#define ATEN_KERNELS_BLAS_H_
3
4#include <ATen/core/tensor.h>
6
8
9namespace container {
10namespace kernels {
11
/**
 * @brief Device-dispatched functor interface for a BLAS-style dot product.
 *
 * Only the declaration appears in this header; the definitions for each
 * (T, Device) pair presumably live in per-backend source files.
 * NOTE(review): the parameter list mirrors reference BLAS `?dot`
 * (result = sum_i x[i*incx] * y[i*incy]) — confirm against the implementation.
 *
 * @tparam T      Element type.
 * @tparam Device Tag type selecting the backend implementation.
 */
template <typename T, typename Device>
struct blas_dot {
    /**
     * @param n      Number of elements to combine.
     * @param x      First input vector (read-only).
     * @param incx   Stride between consecutive elements of x.
     * @param y      Second input vector (read-only).
     * @param incy   Stride between consecutive elements of y.
     * @param result Non-const output pointer; presumably receives the dot product.
     */
    void operator()(
        const int& n,
        const T* x,
        const int& incx,
        const T* y,
        const int& incy,
        T* result);
};
22
23
/**
 * @brief Device-dispatched functor interface for BLAS-style vector scaling.
 *
 * Declaration only; backend definitions are provided elsewhere.
 * NOTE(review): mirrors reference BLAS `?scal` (x[i*incx] *= *alpha) —
 * confirm against the implementation.
 *
 * @tparam T      Element type.
 * @tparam Device Tag type selecting the backend implementation.
 */
template <typename T, typename Device>
struct blas_scal {
    /**
     * @param n     Number of elements.
     * @param alpha Pointer to the scale factor (passed by pointer, not by value).
     * @param x     Vector; non-const, presumably scaled in place.
     * @param incx  Stride between consecutive elements of x.
     */
    void operator()(
        const int& n,
        const T* alpha,
        T* x,
        const int& incx);
};
32
33
/**
 * @brief Device-dispatched functor interface for a BLAS-style AXPY update.
 *
 * Declaration only; backend definitions are provided elsewhere.
 * NOTE(review): mirrors reference BLAS `?axpy` (y = (*alpha) * x + y) —
 * confirm against the implementation.
 *
 * @tparam T      Element type.
 * @tparam Device Tag type selecting the backend implementation.
 */
template <typename T, typename Device>
struct blas_axpy {
    /**
     * @param n     Number of elements.
     * @param alpha Pointer to the scalar multiplier applied to x.
     * @param x     Input vector (read-only).
     * @param incx  Stride between consecutive elements of x.
     * @param y     Vector; non-const, presumably accumulated into in place.
     * @param incy  Stride between consecutive elements of y.
     */
    void operator()(
        const int& n,
        const T* alpha,
        const T* x,
        const int& incx,
        T* y,
        const int& incy);
};
44
45
/**
 * @brief Device-dispatched functor interface for a BLAS-style
 *        matrix-vector product.
 *
 * Declaration only; backend definitions are provided elsewhere.
 * NOTE(review): mirrors reference BLAS `?gemv`
 * (y = (*alpha) * op(A) * x + (*beta) * y, op selected by `trans`);
 * the storage order of A is not visible here — confirm against the implementation.
 *
 * @tparam T      Element type.
 * @tparam Device Tag type selecting the backend implementation.
 */
template <typename T, typename Device>
struct blas_gemv {
    /**
     * @param trans Transpose selector for A (BLAS-style character flag).
     * @param m     Row count of A.
     * @param n     Column count of A.
     * @param alpha Pointer to the scalar applied to op(A) * x.
     * @param A     Input matrix (read-only).
     * @param lda   Leading dimension of A.
     * @param x     Input vector (read-only).
     * @param incx  Stride between consecutive elements of x.
     * @param beta  Pointer to the scalar applied to the incoming y.
     * @param y     Vector; non-const, presumably updated in place.
     * @param incy  Stride between consecutive elements of y.
     */
    void operator()(
        const char& trans,
        const int& m,
        const int& n,
        const T* alpha,
        const T* A,
        const int& lda,
        const T* x,
        const int& incx,
        const T* beta,
        T* y,
        const int& incy);
};
61
62
// NOTE(review): this extraction drops the struct header and the
// `void operator()(` line (embedded numbering jumps from 63 to 66), so the
// struct name is not visible here. Per the generated documentation, the member
// signature is:
//   void operator()(const char &trans, const int &m, const int &n,
//                   const T *alpha, T **A, const int &lda, T **x,
//                   const int &incx, const T *beta, T **y, const int &incy,
//                   const int &batch_size)
// It appears to be a batched GEMV. A, x and y are arrays of pointers,
// presumably one entry per batch (batch_size entries) — confirm against the
// implementation.
63template <typename T, typename Device>
66 const char& trans,
67 const int& m,
68 const int& n,
69 const T* alpha,
70 T** A,
71 const int& lda,
72 T** x,
73 const int& incx,
74 const T* beta,
75 T** y,
76 const int& incy,
77 const int& batch_size);
78};
79
80
// NOTE(review): this extraction drops the struct header and the
// `void operator()(` line (embedded numbering jumps from 81 to 84), so the
// struct name is not visible here. Per the generated documentation, the member
// signature is:
//   void operator()(const char &trans, const int &m, const int &n,
//                   const T *alpha, const T *A, const int &lda,
//                   const int64_t &stride_a, const T *x, const int &incx,
//                   const int64_t &stride_x, const T *beta, T *y,
//                   const int &incy, const int64_t &stride_y,
//                   const int &batch_size)
// It appears to be a strided-batched GEMV. Each operand is a single buffer,
// and consecutive batches are presumably offset by stride_a / stride_x /
// stride_y elements — confirm against the implementation.
81template <typename T, typename Device>
84 const char& trans,
85 const int& m,
86 const int& n,
87 const T* alpha,
88 const T* A,
89 const int& lda,
90 const int64_t& stride_a,
91 const T* x,
92 const int& incx,
93 const int64_t& stride_x,
94 const T* beta,
95 T* y,
96 const int& incy,
97 const int64_t& stride_y,
98 const int& batch_size);
99};
100
101
/**
 * @brief Device-dispatched functor interface for a BLAS-style
 *        matrix-matrix product.
 *
 * Declaration only; backend definitions are provided elsewhere.
 * NOTE(review): mirrors reference BLAS `?gemm`
 * (C = (*alpha) * op(A) * op(B) + (*beta) * C, op selected by transa/transb);
 * the storage order is not visible here — confirm against the implementation.
 *
 * @tparam T      Element type.
 * @tparam Device Tag type selecting the backend implementation.
 */
template <typename T, typename Device>
struct blas_gemm {
    /**
     * @param transa Transpose selector for A (BLAS-style character flag).
     * @param transb Transpose selector for B (BLAS-style character flag).
     * @param m      Rows of op(A) and C.
     * @param n      Columns of op(B) and C.
     * @param k      Inner dimension shared by op(A) and op(B).
     * @param alpha  Pointer to the scalar applied to op(A) * op(B).
     * @param A      Input matrix (read-only).
     * @param lda    Leading dimension of A.
     * @param B      Input matrix (read-only).
     * @param ldb    Leading dimension of B.
     * @param beta   Pointer to the scalar applied to the incoming C.
     * @param C      Matrix; non-const, presumably updated in place.
     * @param ldc    Leading dimension of C.
     */
    void operator()(
        const char& transa,
        const char& transb,
        const int& m,
        const int& n,
        const int& k,
        const T* alpha,
        const T* A,
        const int& lda,
        const T* B,
        const int& ldb,
        const T* beta,
        T* C,
        const int& ldc);
};
119
120
// NOTE(review): this extraction drops the struct header and the
// `void operator()(` line (embedded numbering jumps from 121 to 124), so the
// struct name is not visible here. Per the generated documentation, the member
// signature is:
//   void operator()(const char &transa, const char &transb, const int &m,
//                   const int &n, const int &k, const T *alpha, T **A,
//                   const int &lda, T **B, const int &ldb, const T *beta,
//                   T **C, const int &ldc, const int &batch_size)
// It appears to be a batched GEMM. A, B and C are arrays of pointers,
// presumably one entry per batch (batch_size entries) — confirm against the
// implementation.
121template <typename T, typename Device>
124 const char& transa,
125 const char& transb,
126 const int& m,
127 const int& n,
128 const int& k,
129 const T* alpha,
130 T** A,
131 const int& lda,
132 T** B,
133 const int& ldb,
134 const T* beta,
135 T** C,
136 const int& ldc,
137 const int& batch_size);
138};
139
140
// NOTE(review): this extraction drops the struct header and the
// `void operator()(` line (embedded numbering jumps from 141 to 144), so the
// struct name is not visible here. Per the generated documentation, the member
// signature is:
//   void operator()(const char &transa, const char &transb, const int &m,
//                   const int &n, const int &k, const T *alpha, const T *A,
//                   const int &lda, const int &stride_a, const T *B,
//                   const int &ldb, const int &stride_b, const T *beta, T *C,
//                   const int &ldc, const int &stride_c,
//                   const int &batch_size)
// NOTE(review): these strides are `const int&`, while the strided-batched GEMV
// declaration in this header uses `const int64_t&` strides. With large
// matrices, an int stride may overflow. Consider aligning the types; this is
// an interface change, so coordinate it with the implementations.
141template <typename T, typename Device>
144 const char& transa,
145 const char& transb,
146 const int& m,
147 const int& n,
148 const int& k,
149 const T* alpha,
150 const T* A,
151 const int& lda,
152 const int& stride_a,
153 const T* B,
154 const int& ldb,
155 const int& stride_b,
156 const T* beta,
157 T* C,
158 const int& ldc,
159 const int& stride_c,
160 const int& batch_size);
161};
162
// GPU BLAS handle management; declared only for CUDA or ROCm builds.
#if __CUDA || __ROCM
/// Create the GPU BLAS handle.
/// NOTE(review): ownership and lifetime are defined by the implementation — confirm.
void createGpuBlasHandle();
/// Destroy the GPU BLAS handle created by createGpuBlasHandle().
void destroyGpuBlasHandle();
#endif // __CUDA || __ROCM
167
168} // namespace kernels
169} // namespace container
170
171#endif // ATEN_KERNELS_BLAS_H_
#define T
Definition exp.cpp:237
Definition tensor.cpp:8
Definition blas.h:35
void operator()(const int &n, const T *alpha, const T *x, const int &incx, T *y, const int &incy)
Definition blas.h:13
void operator()(const int &n, const T *x, const int &incx, const T *y, const int &incy, T *result)
void operator()(const char &transa, const char &transb, const int &m, const int &n, const int &k, const T *alpha, const T *A, const int &lda, const int &stride_a, const T *B, const int &ldb, const int &stride_b, const T *beta, T *C, const int &ldc, const int &stride_c, const int &batch_size)
void operator()(const char &transa, const char &transb, const int &m, const int &n, const int &k, const T *alpha, T **A, const int &lda, T **B, const int &ldb, const T *beta, T **C, const int &ldc, const int &batch_size)
Definition blas.h:103
void operator()(const char &transa, const char &transb, const int &m, const int &n, const int &k, const T *alpha, const T *A, const int &lda, const T *B, const int &ldb, const T *beta, T *C, const int &ldc)
void operator()(const char &trans, const int &m, const int &n, const T *alpha, const T *A, const int &lda, const int64_t &stride_a, const T *x, const int &incx, const int64_t &stride_x, const T *beta, T *y, const int &incy, const int64_t &stride_y, const int &batch_size)
void operator()(const char &trans, const int &m, const int &n, const T *alpha, T **A, const int &lda, T **x, const int &incx, const T *beta, T **y, const int &incy, const int &batch_size)
Definition blas.h:47
void operator()(const char &trans, const int &m, const int &n, const T *alpha, const T *A, const int &lda, const T *x, const int &incx, const T *beta, T *y, const int &incy)
Definition blas.h:25
void operator()(const int &n, const T *alpha, T *x, const int &incx)
This file contains the definition of the DataType enum class.