ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
blas.h
Go to the documentation of this file.
1#ifndef BASE_THIRD_PARTY_BLAS_H_
2#define BASE_THIRD_PARTY_BLAS_H_
3
4#include <complex>
5
6#if defined(__CUDA)
8#elif defined(__ROCM)
10#endif
11
extern "C"
{
// Declarations of the Fortran BLAS entry points used by BlasConnector.
// Each symbol carries the trailing underscore of the Fortran name and takes
// every argument, scalars included, by pointer.

// level 1: vector-vector operations, O(n) data and O(n) work.

// Peize Lin add ?scal 2016-08-04, to compute x=a*x
void sscal_(const int *N, const float *alpha, float *x, const int *incx);
void dscal_(const int *N, const double *alpha, double *x, const int *incx);
void cscal_(const int *N, const std::complex<float> *alpha, std::complex<float> *x, const int *incx);
void zscal_(const int *N, const std::complex<double> *alpha, std::complex<double> *x, const int *incx);

// Peize Lin add ?axpy 2016-08-04, to compute y=a*x+y
void saxpy_(const int *N, const float *alpha, const float *x, const int *incx, float *y, const int *incy);
void daxpy_(const int *N, const double *alpha, const double *x, const int *incx, double *y, const int *incy);
void caxpy_(const int *N, const std::complex<float> *alpha, const std::complex<float> *x, const int *incx, std::complex<float> *y, const int *incy);
void zaxpy_(const int *N, const std::complex<double> *alpha, const std::complex<double> *x, const int *incx, std::complex<double> *y, const int *incy);

// ?copy: copies a into b.
// NOTE(review): the element count is declared as `long` here while every other
// routine in this header uses `int`; reference BLAS declares N as INTEGER.
// Confirm this matches the integer width of the linked BLAS library.
void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy);
void zcopy_(long const *n, const std::complex<double> *a, int const *incx, std::complex<double> *b, int const *incy);

// Conjugated complex dot products. The result is written through the last
// argument instead of being returned.
//reason for passing results as argument instead of returning it:
//see https://www.numbercrunch.de/blog/2014/07/lost-in-translation/
void cdotc_(const int *n, const std::complex<float> *zx, const int *incx,
    const std::complex<float> *zy, const int *incy, std::complex<float> *result);
void zdotc_(const int *n, const std::complex<double> *zx, const int *incx,
    const std::complex<double> *zy, const int *incy, std::complex<double> *result);
// Peize Lin add ?dot 2017-10-27, to compute d=x*y
float sdot_(const int *N, const float *x, const int *incx, const float *y, const int *incy);
double ddot_(const int *N, const double *x, const int *incx, const double *y, const int *incy);

// Peize Lin add ?nrm2 2018-06-12, to compute out = ||x||_2 = \sqrt{ \sum_i x_i**2 }
float snrm2_( const int *n, const float *x, const int *incx );
double dnrm2_( const int *n, const double *x, const int *incx );
double dznrm2_( const int *n, const std::complex<double> *x, const int *incx );

// level 2: matrix-vector operations, O(n^2) data and O(n^2) work.
void sgemv_(const char*const transa, const int*const m, const int*const n,
    const float*const alpha, const float*const a, const int*const lda, const float*const x, const int*const incx,
    const float*const beta, float*const y, const int*const incy);
void dgemv_(const char*const transa, const int*const m, const int*const n,
    const double*const alpha, const double*const a, const int*const lda, const double*const x, const int*const incx,
    const double*const beta, double*const y, const int*const incy);

void cgemv_(const char *trans, const int *m, const int *n, const std::complex<float> *alpha,
    const std::complex<float> *a, const int *lda, const std::complex<float> *x, const int *incx,
    const std::complex<float> *beta, std::complex<float> *y, const int *incy);

void zgemv_(const char *trans, const int *m, const int *n, const std::complex<double> *alpha,
    const std::complex<double> *a, const int *lda, const std::complex<double> *x, const int *incx,
    const std::complex<double> *beta, std::complex<double> *y, const int *incy);

void dsymv_(const char *uplo, const int *n,
    const double *alpha, const double *a, const int *lda,
    const double *x, const int *incx,
    const double *beta, double *y, const int *incy);

// A := alpha x * y.T + A
void dger_(const int* m,
    const int* n,
    const double* alpha,
    const double* x,
    const int* incx,
    const double* y,
    const int* incy,
    double* a,
    const int* lda);
void zgerc_(const int* m,
    const int* n,
    const std::complex<double>* alpha,
    const std::complex<double>* x,
    const int* incx,
    const std::complex<double>* y,
    const int* incy,
    std::complex<double>* a,
    const int* lda);

// level 3: matrix-matrix operations, O(n^2) data and O(n^3) work.

// Peize Lin add ?gemm 2017-10-27, to compute C = a * A.? * B.? + b * C
// A is general
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
    const float *alpha, const float *a, const int *lda, const float *b, const int *ldb,
    const float *beta, float *c, const int *ldc);
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
    const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
    const double *beta, double *c, const int *ldc);
void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
    const std::complex<float> *alpha, const std::complex<float> *a, const int *lda, const std::complex<float> *b, const int *ldb,
    const std::complex<float> *beta, std::complex<float> *c, const int *ldc);
void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
    const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
    const std::complex<double> *beta, std::complex<double> *c, const int *ldc);


//a is symmetric
void dsymm_(const char *side, const char *uplo, const int *m, const int *n,
    const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
    const double *beta, double *c, const int *ldc);
//a is hermitian
void zhemm_(char *side, char *uplo, int *m, int *n,std::complex<double> *alpha,
    std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, std::complex<double> *beta, std::complex<double> *c, int *ldc);

//solving triangular matrix with multiple right hand sides
void dtrsm_(char *side, char* uplo, char *transa, char *diag, int *m, int *n,
    double* alpha, double* a, int *lda, double*b, int *ldb);
void ztrsm_(char *side, char* uplo, char *transa, char *diag, int *m, int *n,
    std::complex<double>* alpha, std::complex<double>* a, int *lda, std::complex<double>*b, int *ldb);

}
120
121namespace container {
122
123// Namespace BlasConnector provides thin C++ wrappers around the Fortran BLAS routines declared above.
124// All functions in this namespace are static inline free functions.
125// Usage example: BlasConnector::functionname(parameter list).
126namespace BlasConnector {
127
128static inline
129void axpy( const int& n, const float& alpha, const float *x, const int& incx, float *y, const int& incy)
130{
131 saxpy_(&n, &alpha, x, &incx, y, &incy);
132}
133static inline
134void axpy( const int& n, const double& alpha, const double *x, const int& incx, double *y, const int& incy)
135{
136 daxpy_(&n, &alpha, x, &incx, y, &incy);
137}
138static inline
139void axpy( const int& n, const std::complex<float>& alpha, const std::complex<float> *x, const int& incx, std::complex<float> *y, const int& incy)
140{
141 caxpy_(&n, &alpha, x, &incx, y, &incy);
142}
143static inline
144void axpy( const int& n, const std::complex<double>& alpha, const std::complex<double> *x, const int& incx, std::complex<double> *y, const int& incy)
145{
146 zaxpy_(&n, &alpha, x, &incx, y, &incy);
147}
148
149// Peize Lin add 2016-08-04
150// x=a*x
151static inline
152void scal( const int& n, const float& alpha, float *x, const int& incx)
153{
154 sscal_(&n, &alpha, x, &incx);
155}
156static inline
157void scal( const int& n, const double& alpha, double *x, const int& incx)
158{
159 dscal_(&n, &alpha, x, &incx);
160}
161static inline
162void scal( const int& n, const std::complex<float>& alpha, std::complex<float> *x, const int& incx)
163{
164 cscal_(&n, &alpha, x, &incx);
165}
166static inline
167void scal( const int& n, const std::complex<double>& alpha, std::complex<double> *x, const int& incx)
168{
169 zscal_(&n, &alpha, x, &incx);
170}
171
172// Peize Lin add 2017-10-27
173// d=x*y
174static inline
175float dot( const int& n, const float *x, const int& incx, const float *y, const int& incy)
176{
177 return sdot_(&n, x, &incx, y, &incy);
178}
179static inline
180double dot( const int& n, const double *x, const int& incx, const double *y, const int& incy)
181{
182 return ddot_(&n, x, &incx, y, &incy);
183}
184// Denghui Lu add 2023-8-01
// Conjugated dot product  sum_i conj(x_i) * y_i  of two single-precision
// complex vectors, computed with a plain loop instead of cdotc_.
//
//   n     number of elements to accumulate; n <= 0 yields (0, 0)
//   x, y  input vectors
//   incx  stride between consecutive elements of x
//   incy  stride between consecutive elements of y
//
// A negative stride follows the BLAS convention: the vector is traversed
// backwards, starting at offset (1 - n) * inc from the pointer, so no element
// in front of the pointer is touched. Offsets are kept in 64-bit integers so
// that large n * inc products do not overflow int.
static inline
std::complex<float> dot(const int& n, const std::complex<float> *x, const int& incx, const std::complex<float> *y, const int& incy)
{
    std::complex<float> result = {0, 0};
    if (n <= 0) {
        return result;
    }
    long long ix = (incx < 0) ? (1LL - n) * incx : 0LL;
    long long iy = (incy < 0) ? (1LL - n) * incy : 0LL;
    for (int ii = 0; ii < n; ++ii) {
        result += std::conj(x[ix]) * y[iy];
        ix += incx;
        iy += incy;
    }
    return result;
}
// Conjugated dot product  sum_i conj(x_i) * y_i  of two double-precision
// complex vectors, computed with a plain loop instead of zdotc_.
//
//   n     number of elements to accumulate; n <= 0 yields (0, 0)
//   x, y  input vectors
//   incx  stride between consecutive elements of x
//   incy  stride between consecutive elements of y
//
// A negative stride follows the BLAS convention: the vector is traversed
// backwards, starting at offset (1 - n) * inc from the pointer, so no element
// in front of the pointer is touched. Offsets are kept in 64-bit integers so
// that large n * inc products do not overflow int.
static inline
std::complex<double> dot(const int& n, const std::complex<double> *x, const int& incx, const std::complex<double> *y, const int& incy)
{
    std::complex<double> result = {0, 0};
    if (n <= 0) {
        return result;
    }
    long long ix = (incx < 0) ? (1LL - n) * incx : 0LL;
    long long iy = (incy < 0) ? (1LL - n) * incy : 0LL;
    for (int ii = 0; ii < n; ++ii) {
        result += std::conj(x[ix]) * y[iy];
        ix += incx;
        iy += incy;
    }
    return result;
}
205
206// Peize Lin add 2017-10-27, fix bug trans 2019-01-17
207// C = a * A.? * B.? + b * C
208static inline
209void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k,
210 const float& alpha, const float* A, const int& lda, const float* B, const int& ldb,
211 const float& beta, float* C, const int& ldc)
212{
213 sgemm_(&transa, &transb, &m, &n, &k,
214 &alpha, A, &lda, B, &ldb,
215 &beta, C, &ldc);
216}
217static inline
218void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k,
219 const double& alpha, const double* A, const int& lda, const double* B, const int& ldb,
220 const double& beta, double* C, const int& ldc)
221{
222 dgemm_(&transa, &transb, &m, &n, &k,
223 &alpha, A, &lda, B, &ldb,
224 &beta, C, &ldc);
225}
226static inline
227void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k,
228 const std::complex<float>& alpha, const std::complex<float>* A, const int& lda, const std::complex<float>* B, const int& ldb,
229 const std::complex<float>& beta, std::complex<float>* C, const int& ldc)
230{
231 cgemm_(&transa, &transb, &m, &n, &k,
232 &alpha, A, &lda, B, &ldb,
233 &beta, C, &ldc);
234}
235static inline
236void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k,
237 const std::complex<double>& alpha, const std::complex<double>* A, const int& lda, const std::complex<double>* B, const int& ldb,
238 const std::complex<double>& beta, std::complex<double>* C, const int& ldc)
239{
240 zgemm_(&transa, &transb, &m, &n, &k,
241 &alpha, A, &lda, B, &ldb,
242 &beta, C, &ldc);
243}
244
245template <typename T>
246static inline
247void gemm_batched(const char& transa, const char& transb, const int& m, const int& n, const int& k,
248 const T& alpha, T** A, const int& lda, T** B, const int& ldb,
249 const T& beta, T** C, const int& ldc, const int& batch_size)
250{
251 for (int ii = 0; ii < batch_size; ++ii) {
252 // Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
253 BlasConnector::gemm(transa, transb, m, n, k, alpha, A[ii], lda, B[ii], ldb, beta, C[ii], ldc);
254 }
255}
256
257template <typename T>
258static inline
259void gemm_batched_strided(const char& transa, const char& transb, const int& m, const int& n, const int& k,
260 const T& alpha, const T* A, const int& lda, const int& stride_a, const T* B, const int& ldb, const int& stride_b,
261 const T& beta, T* C, const int& ldc, const int& stride_c, const int& batch_size)
262{
263 for (int ii = 0; ii < batch_size; ii++) {
264 // Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
265 BlasConnector::gemm(transa, transb, m, n, k, alpha, A + ii * stride_a, lda, B + ii * stride_b, ldb, beta, C + ii * stride_c, ldc);
266 }
267}
268
269static inline
270void gemv(const char& trans, const int& m, const int& n,
271 const float& alpha, const float *A, const int& lda, const float *x, const int& incx,
272 const float& beta, float *y, const int& incy)
273{
274 sgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy);
275}
276static inline
277void gemv(const char& trans, const int& m, const int& n,
278 const double& alpha, const double *A, const int& lda, const double *x, const int& incx,
279 const double& beta, double *y, const int& incy)
280{
281 dgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy);
282}
283static inline
284void gemv(const char& trans, const int& m, const int& n,
285 const std::complex<float>& alpha, const std::complex<float> *A, const int& lda, const std::complex<float> *x, const int& incx,
286 const std::complex<float>& beta, std::complex<float> *y, const int& incy)
287{
288 cgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy);
289}
290static inline
291void gemv(const char& trans, const int& m, const int& n,
292 const std::complex<double>& alpha, const std::complex<double> *A, const int& lda, const std::complex<double> *x, const int& incx,
293 const std::complex<double>& beta, std::complex<double> *y, const int& incy)
294{
295 zgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy);
296}
297
298template <typename T>
299static inline
300void gemv_batched(const char& trans, const int& m, const int& n,
301 const T& alpha, T** A, const int& lda, T** x, const int& incx,
302 const T& beta, T** y, const int& incy, const int& batch_size)
303{
304 for (int ii = 0; ii < batch_size; ++ii) {
305 // Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
306 BlasConnector::gemv(trans, m, n, alpha, A[ii], lda, x[ii], incy, beta, y[ii], incy);
307 }
308}
309
310template <typename T>
311static inline
312void gemv_batched_strided(const char& transa, const int& m, const int& n,
313 const T& alpha, const T* A, const int& lda, const int& stride_a, const T* x, const int& incx, const int& stride_x,
314 const T& beta, T* y, const int& incy, const int& stride_y, const int& batch_size)
315{
316 for (int ii = 0; ii < batch_size; ii++) {
317 // Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
318 BlasConnector::gemv(transa, m, n, alpha, A + ii * stride_a, lda, x + ii * stride_x, incx, beta, y + ii * stride_y, incy);
319 }
320}
321
322// Peize Lin add 2018-06-12
323// out = ||x||_2
324static inline
325float nrm2( const int n, const float *x, const int incx )
326{
327 return snrm2_( &n, x, &incx );
328}
329static inline
330double nrm2( const int n, const double *x, const int incx )
331{
332 return dnrm2_( &n, x, &incx );
333}
334static inline
335double nrm2( const int n, const std::complex<double> *x, const int incx )
336{
337 return dznrm2_( &n, x, &incx );
338}
339
340// copies a into b
341static inline
342void copy(const long n, const double *a, const int incx, double *b, const int incy)
343{
344 dcopy_(&n, a, &incx, b, &incy);
345}
346static inline
347void copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy)
348{
349 zcopy_(&n, a, &incx, b, &incy);
350}
351
352} // namespace BlasConnector
353} // namespace container
354
355#endif // BASE_THIRD_PARTY_BLAS_H_
void cdotc_(const int *n, const std::complex< float > *zx, const int *incx, const std::complex< float > *zy, const int *incy, std::complex< float > *result)
void saxpy_(const int *N, const float *alpha, const float *x, const int *incx, float *y, const int *incy)
void sgemv_(const char *const transa, const int *const m, const int *const n, const float *const alpha, const float *const a, const int *const lda, const float *const x, const int *const incx, const float *const eta, float *const y, const int *const incy)
void zgerc_(const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *x, const int *incx, const std::complex< double > *y, const int *incy, std::complex< double > *a, const int *lda)
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, float *c, const int *ldc)
void daxpy_(const int *N, const double *alpha, const double *x, const int *incx, double *y, const int *incy)
void dgemv_(const char *const transa, const int *const m, const int *const n, const double *const alpha, const double *const a, const int *const lda, const double *const x, const int *const incx, const double *const beta, double *const y, const int *const incy)
void dsymm_(const char *side, const char *uplo, const int *m, const int *n, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
float sdot_(const int *N, const float *x, const int *incx, const float *y, const int *incy)
void zscal_(const int *N, const std::complex< double > *alpha, std::complex< double > *x, const int *incx)
void dger_(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *a, const int *lda)
void zdotc_(const int *n, const std::complex< double > *zx, const int *incx, const std::complex< double > *zy, const int *incy, std::complex< double > *result)
void dtrsm_(char *side, char *uplo, char *transa, char *diag, int *m, int *n, double *alpha, double *a, int *lda, double *b, int *ldb)
void caxpy_(const int *N, const std::complex< float > *alpha, const std::complex< float > *x, const int *incx, std::complex< float > *y, const int *incy)
void zhemm_(char *side, char *uplo, int *m, int *n, std::complex< double > *alpha, std::complex< double > *a, int *lda, std::complex< double > *b, int *ldb, std::complex< double > *beta, std::complex< double > *c, int *ldc)
void sscal_(const int *N, const float *alpha, float *x, const int *incx)
void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc)
void ztrsm_(char *side, char *uplo, char *transa, char *diag, int *m, int *n, std::complex< double > *alpha, std::complex< double > *a, int *lda, std::complex< double > *b, int *ldb)
void zcopy_(long const *n, const std::complex< double > *a, int const *incx, std::complex< double > *b, int const *incy)
void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy)
double ddot_(const int *N, const double *x, const int *incx, const double *y, const int *incy)
void dscal_(const int *N, const double *alpha, double *x, const int *incx)
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void dsymv_(const char *uplo, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
double dnrm2_(const int *n, const double *x, const int *incx)
void cscal_(const int *N, const std::complex< float > *alpha, std::complex< float > *x, const int *incx)
void cgemv_(const char *trans, const int *m, const int *n, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *x, const int *incx, const std::complex< float > *beta, std::complex< float > *y, const int *incy)
double dznrm2_(const int *n, const std::complex< double > *x, const int *incx)
void zaxpy_(const int *N, const std::complex< double > *alpha, const std::complex< double > *x, const int *incx, std::complex< double > *y, const int *incy)
void zgemv_(const char *trans, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *x, const int *incx, const std::complex< double > *beta, std::complex< double > *y, const int *incy)
float snrm2_(const int *n, const float *x, const int *incx)
Definition blas_connector.h:147
#define N
Definition exp.cpp:24
#define T
Definition exp.cpp:237
Definition tensor.cpp:8