ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
blas.h
Go to the documentation of this file.
1#ifndef BASE_THIRD_PARTY_BLAS_H_
2#define BASE_THIRD_PARTY_BLAS_H_
3
4#include <complex>
5
6#if defined(__CUDA)
8#elif defined(__ROCM)
10#endif
11
// Prototypes of the Fortran BLAS routines used by this header.
// Every argument is passed by address, and the trailing underscore is part of
// the symbol name declared here.
extern "C"
{
// level 1: vector-vector operations, O(n) data and O(n) work.

// Peize Lin add ?scal 2016-08-04, to compute x=a*x
void sscal_(const int *N, const float *alpha, float *x, const int *incx);
void dscal_(const int *N, const double *alpha, double *x, const int *incx);
void cscal_(const int *N, const std::complex<float> *alpha, std::complex<float> *x, const int *incx);
void zscal_(const int *N, const std::complex<double> *alpha, std::complex<double> *x, const int *incx);

// Peize Lin add ?axpy 2016-08-04, to compute y=a*x+y
void saxpy_(const int *N, const float *alpha, const float *x, const int *incx, float *y, const int *incy);
void daxpy_(const int *N, const double *alpha, const double *x, const int *incx, double *y, const int *incy);
void caxpy_(const int *N, const std::complex<float> *alpha, const std::complex<float> *x, const int *incx, std::complex<float> *y, const int *incy);
void zaxpy_(const int *N, const std::complex<double> *alpha, const std::complex<double> *x, const int *incx, std::complex<double> *y, const int *incy);

// ?copy: copies a into b
void scopy_(const int *n, const float *a, const int *incx, float *b, int const *incy);
void dcopy_(const int *n, const double *a, const int *incx, double *b, int const *incy);
void ccopy_(const int *n, const std::complex<float> *a, const int *incx, std::complex<float> *b, int const *incy);
void zcopy_(const int *n, const std::complex<double> *a, const int *incx, std::complex<double> *b, int const *incy);


// Reason for passing the result as an argument instead of returning it:
// see https://www.numbercrunch.de/blog/2014/07/lost-in-translation/
// NOTE(review): the position of the result argument depends on the Fortran
// compiler's calling convention for complex-valued functions -- confirm it
// against the BLAS library actually linked before calling these two routines.
void cdotc_(const int *n, const std::complex<float> *zx, const int *incx,
            const std::complex<float> *zy, const int *incy, std::complex<float> *result);
void zdotc_(const int *n, const std::complex<double> *zx, const int *incx,
            const std::complex<double> *zy, const int *incy, std::complex<double> *result);
// Peize Lin add ?dot 2017-10-27, to compute d=x*y
float sdot_(const int *N, const float *x, const int *incx, const float *y, const int *incy);
double ddot_(const int *N, const double *x, const int *incx, const double *y, const int *incy);

// Peize Lin add ?nrm2 2018-06-12, to compute out = ||x||_2 = \sqrt{ \sum_i x_i**2 }
float snrm2_( const int *n, const float *x, const int *incx );
double dnrm2_( const int *n, const double *x, const int *incx );
float scnrm2_( const int *n, const std::complex<float> *x, const int *incx );
double dznrm2_( const int *n, const std::complex<double> *x, const int *incx );

// level 2: matrix-vector operations, O(n^2) data and O(n^2) work.
void sgemv_(const char*const transa, const int*const m, const int*const n,
            const float*const alpha, const float*const a, const int*const lda, const float*const x, const int*const incx,
            const float*const beta, float*const y, const int*const incy);
void dgemv_(const char*const transa, const int*const m, const int*const n,
            const double*const alpha, const double*const a, const int*const lda, const double*const x, const int*const incx,
            const double*const beta, double*const y, const int*const incy);

void cgemv_(const char *trans, const int *m, const int *n, const std::complex<float> *alpha,
            const std::complex<float> *a, const int *lda, const std::complex<float> *x, const int *incx,
            const std::complex<float> *beta, std::complex<float> *y, const int *incy);

void zgemv_(const char *trans, const int *m, const int *n, const std::complex<double> *alpha,
            const std::complex<double> *a, const int *lda, const std::complex<double> *x, const int *incx,
            const std::complex<double> *beta, std::complex<double> *y, const int *incy);

void dsymv_(const char *uplo, const int *n,
            const double *alpha, const double *a, const int *lda,
            const double *x, const int *incx,
            const double *beta, double *y, const int *incy);

// A := alpha x * y.T + A
void dger_(const int* m,
           const int* n,
           const double* alpha,
           const double* x,
           const int* incx,
           const double* y,
           const int* incy,
           double* a,
           const int* lda);
void zgerc_(const int* m,
            const int* n,
            const std::complex<double>* alpha,
            const std::complex<double>* x,
            const int* incx,
            const std::complex<double>* y,
            const int* incy,
            std::complex<double>* a,
            const int* lda);

// level 3: matrix-matrix operations, O(n^2) data and O(n^3) work.

// Peize Lin add ?gemm 2017-10-27, to compute C = a * A.? * B.? + b * C
// A is general
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
            const float *alpha, const float *a, const int *lda, const float *b, const int *ldb,
            const float *beta, float *c, const int *ldc);
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
            const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
            const double *beta, double *c, const int *ldc);
void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
            const std::complex<float> *alpha, const std::complex<float> *a, const int *lda, const std::complex<float> *b, const int *ldb,
            const std::complex<float> *beta, std::complex<float> *c, const int *ldc);
void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
            const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
            const std::complex<double> *beta, std::complex<double> *c, const int *ldc);


// a is symmetric
void dsymm_(const char *side, const char *uplo, const int *m, const int *n,
            const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
            const double *beta, double *c, const int *ldc);
// a is hermitian
void zhemm_(const char *side, const char *uplo,
            const int *m, const int *n,
            const std::complex<double> *alpha,
            const std::complex<double> *a, const int *lda,
            const std::complex<double> *b, const int *ldb,
            const std::complex<double> *beta,
            std::complex<double> *c, const int *ldc);

// solving triangular matrix with multiple right hand sides
void dtrsm_(const char *side, const char *uplo, const char *transa, const char *diag,
            const int *m, const int *n,
            const double *alpha,
            const double *a, const int *lda,
            double *b, const int *ldb);

void ztrsm_(const char *side, const char *uplo, const char *transa, const char *diag,
            const int *m, const int *n,
            const std::complex<double> *alpha,
            const std::complex<double> *a, const int *lda,
            std::complex<double> *b, const int *ldb);

}
136
137namespace container {
138
// Namespace BlasConnector provides thin connectors to the Fortran BLAS routines.
// All functions in this namespace are static inline functions.
// Usage example: BlasConnector::functionname(parameter list).
142namespace BlasConnector {
143
144static inline
145void axpy( const int& n, const float& alpha, const float *x, const int& incx, float *y, const int& incy)
146{
147 saxpy_(&n, &alpha, x, &incx, y, &incy);
148}
149static inline
150void axpy( const int& n, const double& alpha, const double *x, const int& incx, double *y, const int& incy)
151{
152 daxpy_(&n, &alpha, x, &incx, y, &incy);
153}
154static inline
155void axpy( const int& n, const std::complex<float>& alpha, const std::complex<float> *x, const int& incx, std::complex<float> *y, const int& incy)
156{
157 caxpy_(&n, &alpha, x, &incx, y, &incy);
158}
159static inline
160void axpy( const int& n, const std::complex<double>& alpha, const std::complex<double> *x, const int& incx, std::complex<double> *y, const int& incy)
161{
162 zaxpy_(&n, &alpha, x, &incx, y, &incy);
163}
164
165// Peize Lin add 2016-08-04
166// x=a*x
167static inline
168void scal( const int& n, const float& alpha, float *x, const int& incx)
169{
170 sscal_(&n, &alpha, x, &incx);
171}
172static inline
173void scal( const int& n, const double& alpha, double *x, const int& incx)
174{
175 dscal_(&n, &alpha, x, &incx);
176}
177static inline
178void scal( const int& n, const std::complex<float>& alpha, std::complex<float> *x, const int& incx)
179{
180 cscal_(&n, &alpha, x, &incx);
181}
182static inline
183void scal( const int& n, const std::complex<double>& alpha, std::complex<double> *x, const int& incx)
184{
185 zscal_(&n, &alpha, x, &incx);
186}
187
188// Peize Lin add 2017-10-27
189// d=x*y
190static inline
191float dot( const int& n, const float *x, const int& incx, const float *y, const int& incy)
192{
193 return sdot_(&n, x, &incx, y, &incy);
194}
195static inline
196double dot( const int& n, const double *x, const int& incx, const double *y, const int& incy)
197{
198 return ddot_(&n, x, &incx, y, &incy);
199}
200// Denghui Lu add 2023-8-01
// Conjugated dot product of two strided single-precision complex vectors:
// returns sum_i conj(x_i) * y_i over n elements.
//
// Implemented as a plain loop instead of calling cdotc_ (call kept below,
// commented out). Increments follow the BLAS convention: a negative increment
// traverses the vector backwards, starting at element (1 - n) * inc, so the
// pointer always refers to the lowest-addressed element. Returns (0, 0) when
// n <= 0.
static inline
std::complex<float> dot(const int& n, const std::complex<float> *x, const int& incx, const std::complex<float> *y, const int& incy)
{
    std::complex<float> result = {0, 0};
    // cdotc_(&n, x, &incx, y, &incy, &result);

    // Wide offsets so that (n - 1) * inc cannot overflow int.
    long long ix = (incx < 0) ? (1LL - n) * incx : 0LL;
    long long iy = (incy < 0) ? (1LL - n) * incy : 0LL;
    for (int ii = 0; ii < n; ++ii) {
        result += std::conj(x[ix]) * y[iy];
        ix += incx;
        iy += incy;
    }
    return result;
}
// Conjugated dot product of two strided double-precision complex vectors:
// returns sum_i conj(x_i) * y_i over n elements.
//
// Implemented as a plain loop instead of calling zdotc_ (call kept below,
// commented out). Increments follow the BLAS convention: a negative increment
// traverses the vector backwards, starting at element (1 - n) * inc, so the
// pointer always refers to the lowest-addressed element. Returns (0, 0) when
// n <= 0.
static inline
std::complex<double> dot(const int& n, const std::complex<double> *x, const int& incx, const std::complex<double> *y, const int& incy)
{
    std::complex<double> result = {0, 0};
    // zdotc_(&n, x, &incx, y, &incy, &result);

    // Wide offsets so that (n - 1) * inc cannot overflow int.
    long long ix = (incx < 0) ? (1LL - n) * incx : 0LL;
    long long iy = (incy < 0) ? (1LL - n) * incy : 0LL;
    for (int ii = 0; ii < n; ++ii) {
        result += std::conj(x[ix]) * y[iy];
        ix += incx;
        iy += incy;
    }
    return result;
}
221
222// Peize Lin add 2017-10-27, fix bug trans 2019-01-17
223// C = a * A.? * B.? + b * C
224static inline
225void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k,
226 const float& alpha, const float* A, const int& lda, const float* B, const int& ldb,
227 const float& beta, float* C, const int& ldc)
228{
229 sgemm_(&transa, &transb, &m, &n, &k,
230 &alpha, A, &lda, B, &ldb,
231 &beta, C, &ldc);
232}
233static inline
234void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k,
235 const double& alpha, const double* A, const int& lda, const double* B, const int& ldb,
236 const double& beta, double* C, const int& ldc)
237{
238 dgemm_(&transa, &transb, &m, &n, &k,
239 &alpha, A, &lda, B, &ldb,
240 &beta, C, &ldc);
241}
242static inline
243void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k,
244 const std::complex<float>& alpha, const std::complex<float>* A, const int& lda, const std::complex<float>* B, const int& ldb,
245 const std::complex<float>& beta, std::complex<float>* C, const int& ldc)
246{
247 cgemm_(&transa, &transb, &m, &n, &k,
248 &alpha, A, &lda, B, &ldb,
249 &beta, C, &ldc);
250}
251static inline
252void gemm(const char& transa, const char& transb, const int& m, const int& n, const int& k,
253 const std::complex<double>& alpha, const std::complex<double>* A, const int& lda, const std::complex<double>* B, const int& ldb,
254 const std::complex<double>& beta, std::complex<double>* C, const int& ldc)
255{
256 zgemm_(&transa, &transb, &m, &n, &k,
257 &alpha, A, &lda, B, &ldb,
258 &beta, C, &ldc);
259}
260
261template <typename T>
262static inline
263void gemm_batched(const char& transa, const char& transb, const int& m, const int& n, const int& k,
264 const T& alpha, T** A, const int& lda, T** B, const int& ldb,
265 const T& beta, T** C, const int& ldc, const int& batch_size)
266{
267 for (int ii = 0; ii < batch_size; ++ii) {
268 // Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
269 BlasConnector::gemm(transa, transb, m, n, k, alpha, A[ii], lda, B[ii], ldb, beta, C[ii], ldc);
270 }
271}
272
273template <typename T>
274static inline
275void gemm_batched_strided(const char& transa, const char& transb, const int& m, const int& n, const int& k,
276 const T& alpha, const T* A, const int& lda, const int& stride_a, const T* B, const int& ldb, const int& stride_b,
277 const T& beta, T* C, const int& ldc, const int& stride_c, const int& batch_size)
278{
279 for (int ii = 0; ii < batch_size; ii++) {
280 // Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
281 BlasConnector::gemm(transa, transb, m, n, k, alpha, A + ii * stride_a, lda, B + ii * stride_b, ldb, beta, C + ii * stride_c, ldc);
282 }
283}
284
285static inline
286void gemv(const char& trans, const int& m, const int& n,
287 const float& alpha, const float *A, const int& lda, const float *x, const int& incx,
288 const float& beta, float *y, const int& incy)
289{
290 sgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy);
291}
292static inline
293void gemv(const char& trans, const int& m, const int& n,
294 const double& alpha, const double *A, const int& lda, const double *x, const int& incx,
295 const double& beta, double *y, const int& incy)
296{
297 dgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy);
298}
299static inline
300void gemv(const char& trans, const int& m, const int& n,
301 const std::complex<float>& alpha, const std::complex<float> *A, const int& lda, const std::complex<float> *x, const int& incx,
302 const std::complex<float>& beta, std::complex<float> *y, const int& incy)
303{
304 cgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy);
305}
306static inline
307void gemv(const char& trans, const int& m, const int& n,
308 const std::complex<double>& alpha, const std::complex<double> *A, const int& lda, const std::complex<double> *x, const int& incx,
309 const std::complex<double>& beta, std::complex<double> *y, const int& incy)
310{
311 zgemv_(&trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy);
312}
313
314template <typename T>
315static inline
316void gemv_batched(const char& trans, const int& m, const int& n,
317 const T& alpha, T** A, const int& lda, T** x, const int& incx,
318 const T& beta, T** y, const int& incy, const int& batch_size)
319{
320 for (int ii = 0; ii < batch_size; ++ii) {
321 // Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
322 BlasConnector::gemv(trans, m, n, alpha, A[ii], lda, x[ii], incy, beta, y[ii], incy);
323 }
324}
325
326template <typename T>
327static inline
328void gemv_batched_strided(const char& transa, const int& m, const int& n,
329 const T& alpha, const T* A, const int& lda, const int& stride_a, const T* x, const int& incx, const int& stride_x,
330 const T& beta, T* y, const int& incy, const int& stride_y, const int& batch_size)
331{
332 for (int ii = 0; ii < batch_size; ii++) {
333 // Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
334 BlasConnector::gemv(transa, m, n, alpha, A + ii * stride_a, lda, x + ii * stride_x, incx, beta, y + ii * stride_y, incy);
335 }
336}
337
338// Peize Lin add 2018-06-12
339// out = ||x||_2
340static inline
341float nrm2( const int n, const float *x, const int incx )
342{
343 return snrm2_( &n, x, &incx );
344}
345static inline
346double nrm2( const int n, const double *x, const int incx )
347{
348 return dnrm2_( &n, x, &incx );
349}
350static inline
351double nrm2( const int n, const std::complex<float> *x, const int incx )
352{
353 return scnrm2_( &n, x, &incx );
354}
355static inline
356double nrm2( const int n, const std::complex<double> *x, const int incx )
357{
358 return dznrm2_( &n, x, &incx );
359}
360
361// copies a into b
362static inline
363void copy(const int n, const float *a, const int incx, float *b, const int incy)
364{
365 scopy_(&n, a, &incx, b, &incy);
366}
367static inline
368void copy(const int n, const double *a, const int incx, double *b, const int incy)
369
370{
371 dcopy_(&n, a, &incx, b, &incy);
372}
373static inline
374void copy(const int n, const std::complex<float> *a, const int incx, std::complex<float> *b, const int incy)
375{
376 ccopy_(&n, a, &incx, b, &incy);
377}
378static inline
379void copy(const int n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy)
380{
381 zcopy_(&n, a, &incx, b, &incy);
382}
383
384} // namespace BlasConnector
385} // namespace container
386
387#endif // BASE_THIRD_PARTY_BLAS_H_
void cdotc_(const int *n, const std::complex< float > *zx, const int *incx, const std::complex< float > *zy, const int *incy, std::complex< float > *result)
void saxpy_(const int *N, const float *alpha, const float *x, const int *incx, float *y, const int *incy)
void sgemv_(const char *const transa, const int *const m, const int *const n, const float *const alpha, const float *const a, const int *const lda, const float *const x, const int *const incx, const float *const eta, float *const y, const int *const incy)
void ccopy_(const int *n, const std::complex< float > *a, const int *incx, std::complex< float > *b, int const *incy)
void zcopy_(const int *n, const std::complex< double > *a, const int *incx, std::complex< double > *b, int const *incy)
void dtrsm_(const char *side, const char *uplo, const char *transa, const char *diag, const int *m, const int *n, const double *alpha, const double *a, const int *lda, double *b, const int *ldb)
void zgerc_(const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *x, const int *incx, const std::complex< double > *y, const int *incy, std::complex< double > *a, const int *lda)
void ztrsm_(const char *side, const char *uplo, const char *transa, const char *diag, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, std::complex< double > *b, const int *ldb)
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, float *c, const int *ldc)
void daxpy_(const int *N, const double *alpha, const double *x, const int *incx, double *y, const int *incy)
void zhemm_(const char *side, const char *uplo, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
void dgemv_(const char *const transa, const int *const m, const int *const n, const double *const alpha, const double *const a, const int *const lda, const double *const x, const int *const incx, const double *const beta, double *const y, const int *const incy)
void dsymm_(const char *side, const char *uplo, const int *m, const int *n, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
float sdot_(const int *N, const float *x, const int *incx, const float *y, const int *incy)
void zscal_(const int *N, const std::complex< double > *alpha, std::complex< double > *x, const int *incx)
void dger_(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *a, const int *lda)
void zdotc_(const int *n, const std::complex< double > *zx, const int *incx, const std::complex< double > *zy, const int *incy, std::complex< double > *result)
void caxpy_(const int *N, const std::complex< float > *alpha, const std::complex< float > *x, const int *incx, std::complex< float > *y, const int *incy)
void sscal_(const int *N, const float *alpha, float *x, const int *incx)
void scopy_(const int *n, const float *a, const int *incx, float *b, int const *incy)
void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc)
void dcopy_(const int *n, const double *a, const int *incx, double *b, int const *incy)
double ddot_(const int *N, const double *x, const int *incx, const double *y, const int *incy)
void dscal_(const int *N, const double *alpha, double *x, const int *incx)
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void dsymv_(const char *uplo, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
double dnrm2_(const int *n, const double *x, const int *incx)
void cscal_(const int *N, const std::complex< float > *alpha, std::complex< float > *x, const int *incx)
void cgemv_(const char *trans, const int *m, const int *n, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *x, const int *incx, const std::complex< float > *beta, std::complex< float > *y, const int *incy)
double dznrm2_(const int *n, const std::complex< double > *x, const int *incx)
void zaxpy_(const int *N, const std::complex< double > *alpha, const std::complex< double > *x, const int *incx, std::complex< double > *y, const int *incy)
void zgemv_(const char *trans, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *x, const int *incx, const std::complex< double > *beta, std::complex< double > *y, const int *incy)
float scnrm2_(const int *n, const std::complex< float > *x, const int *incx)
float snrm2_(const int *n, const float *x, const int *incx)
Definition blas_connector.h:203
#define N
Definition exp.cpp:24
#define T
Definition exp.cpp:237
Definition tensor.cpp:8