1#ifndef BLAS_CONNECTOR_H
2#define BLAS_CONNECTOR_H
// External single-precision `sscal_` prototype (BLAS xSCAL naming; presumably scales
// X in place by *alpha over *N elements with stride *incX). All arguments are pointers.
void sscal_(const int *N, const float *alpha, float *X, const int *incX);
// External double-precision `dscal_` prototype (BLAS xSCAL naming; presumably scales
// X in place by *alpha over *N elements with stride *incX). All arguments are pointers.
void dscal_(const int *N, const double *alpha, double *X, const int *incX);
// External complex<float> `cscal_` prototype (BLAS xSCAL naming; presumably scales
// X in place by the complex factor *alpha). All arguments are pointers.
void cscal_(const int *N, const std::complex<float> *alpha, std::complex<float> *X, const int *incX);
// External complex<double> `zscal_` prototype (BLAS xSCAL naming; presumably scales
// X in place by the complex factor *alpha). All arguments are pointers.
void zscal_(const int *N, const std::complex<double> *alpha, std::complex<double> *X, const int *incX);
// External single-precision `saxpy_` prototype (BLAS xAXPY naming; presumably
// Y := alpha * X + Y). X is read-only, Y is the in/out vector.
void saxpy_(const int *N, const float *alpha, const float *X, const int *incX, float *Y, const int *incY);
// External double-precision `daxpy_` prototype (BLAS xAXPY naming; presumably
// Y := alpha * X + Y). X is read-only, Y is the in/out vector.
void daxpy_(const int *N, const double *alpha, const double *X, const int *incX, double *Y, const int *incY);
// External complex<float> `caxpy_` prototype (BLAS xAXPY naming; presumably
// Y := alpha * X + Y). X is read-only, Y is the in/out vector.
void caxpy_(const int *N,
            const std::complex<float> *alpha,
            const std::complex<float> *X, const int *incX,
            std::complex<float> *Y, const int *incY);
// External complex<double> `zaxpy_` prototype (BLAS xAXPY naming; presumably
// Y := alpha * X + Y). X is read-only, Y is the in/out vector.
void zaxpy_(const int *N,
            const std::complex<double> *alpha,
            const std::complex<double> *X, const int *incX,
            std::complex<double> *Y, const int *incY);
// External single-precision `scopy_` prototype (BLAS xCOPY naming; presumably copies
// *n elements from `a` (stride *incx) into `b` (stride *incy)).
void scopy_(const int *n, const float *a, const int *incx, float *b, const int *incy);
// External double-precision `dcopy_` prototype (BLAS xCOPY naming; presumably copies
// *n elements from `a` (stride *incx) into `b` (stride *incy)).
void dcopy_(const int *n, const double *a, const int *incx, double *b, const int *incy);
// External complex<float> `ccopy_` prototype (BLAS xCOPY naming; presumably copies
// *n elements from `a` (stride *incx) into `b` (stride *incy)).
void ccopy_(const int *n, const std::complex<float> *a, const int *incx, std::complex<float> *b, const int *incy);
// External complex<double> `zcopy_` prototype (BLAS xCOPY naming; presumably copies
// *n elements from `a` (stride *incx) into `b` (stride *incy)).
void zcopy_(const int *n, const std::complex<double> *a, const int *incx, std::complex<double> *b, const int *incy);
// External single-precision `sdot_` prototype (BLAS xDOT naming; presumably returns
// the dot product of X and Y). Both vectors are read-only; the result is returned by value.
float sdot_(const int *N, const float *X, const int *incX, const float *Y, const int *incY);
// External double-precision `ddot_` prototype (BLAS xDOT naming; presumably returns
// the dot product of X and Y). Both vectors are read-only; the result is returned by value.
double ddot_(const int *N, const double *X, const int *incX, const double *Y, const int *incY);
// External single-precision `snrm2_` prototype (BLAS xNRM2 naming; presumably returns
// the Euclidean norm of X). X is read-only.
float snrm2_(const int *n, const float *X, const int *incX);
// External double-precision `dnrm2_` prototype (BLAS xNRM2 naming; presumably returns
// the Euclidean norm of X). X is read-only.
double dnrm2_(const int *n, const double *X, const int *incX);
// External `scnrm2_` prototype: complex<float> input, real `float` result
// (BLAS xNRM2 naming; presumably the Euclidean norm of X).
float scnrm2_(const int *n, const std::complex<float> *X, const int *incX);
// External `dznrm2_` prototype: complex<double> input, real `double` result
// (BLAS xNRM2 naming; presumably the Euclidean norm of X).
double dznrm2_(const int *n, const std::complex<double> *X, const int *incX);
// External single-precision `sgemv_` prototype (BLAS xGEMV naming; presumably
// y := alpha * op(a) * x + beta * y with op chosen by *transa). `y` is the in/out vector.
void sgemv_(const char *transa, const int *m, const int *n,
            const float *alpha, const float *a, const int *lda,
            const float *x, const int *incx,
            const float *beta, float *y, const int *incy);
// External double-precision `dgemv_` prototype (BLAS xGEMV naming; presumably
// y := alpha * op(a) * x + beta * y with op chosen by *transa). `y` is the in/out vector.
void dgemv_(const char *transa, const int *m, const int *n,
            const double *alpha, const double *a, const int *lda,
            const double *x, const int *incx,
            const double *beta, double *y, const int *incy);
// External complex<float> `cgemv_` prototype (BLAS xGEMV naming; presumably
// y := alpha * op(a) * x + beta * y with op chosen by *trans). `y` is the in/out vector.
void cgemv_(const char *trans, const int *m, const int *n,
            const std::complex<float> *alpha,
            const std::complex<float> *a, const int *lda,
            const std::complex<float> *x, const int *incx,
            const std::complex<float> *beta,
            std::complex<float> *y, const int *incy);
// External complex<double> `zgemv_` prototype (BLAS xGEMV naming; presumably
// y := alpha * op(a) * x + beta * y with op chosen by *trans). `y` is the in/out vector.
void zgemv_(const char *trans, const int *m, const int *n,
            const std::complex<double> *alpha,
            const std::complex<double> *a, const int *lda,
            const std::complex<double> *x, const int *incx,
            const std::complex<double> *beta,
            std::complex<double> *y, const int *incy);
// External double-precision `dsymv_` prototype (BLAS xSYMV naming; presumably
// y := alpha * a * x + beta * y for a symmetric `a` whose stored triangle is named by *uplo).
void dsymv_(const char *uplo, const int *n,
            const double *alpha, const double *a, const int *lda,
            const double *x, const int *incx,
            const double *beta, double *y, const int *incy);
// External double-precision `dger_` prototype (BLAS xGER naming; presumably the
// rank-1 update a := alpha * x * y^T + a). `a` is the in/out matrix with leading dimension *lda.
void dger_(const int *m, const int *n,
           const double *alpha,
           const double *x, const int *incx,
           const double *y, const int *incy,
           double *a, const int *lda);
// External complex<double> `zgerc_` prototype (BLAS xGERC naming; presumably the
// conjugated rank-1 update a := alpha * x * y^H + a). `a` is the in/out matrix.
void zgerc_(const int *m, const int *n,
            const std::complex<double> *alpha,
            const std::complex<double> *x, const int *incx,
            const std::complex<double> *y, const int *incy,
            std::complex<double> *a, const int *lda);
// External single-precision `sgemm_` prototype (BLAS xGEMM naming; presumably
// c := alpha * op(a) * op(b) + beta * c, with op chosen by *transa / *transb).
// `c` is the in/out matrix; lda/ldb/ldc are the leading dimensions.
void sgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const float *alpha,
            const float *a, const int *lda,
            const float *b, const int *ldb,
            const float *beta,
            float *c, const int *ldc);
// External double-precision `dgemm_` prototype (BLAS xGEMM naming; presumably
// c := alpha * op(a) * op(b) + beta * c, with op chosen by *transa / *transb).
// `c` is the in/out matrix; lda/ldb/ldc are the leading dimensions.
void dgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const double *alpha,
            const double *a, const int *lda,
            const double *b, const int *ldb,
            const double *beta,
            double *c, const int *ldc);
// External complex<float> `cgemm_` prototype (BLAS xGEMM naming; presumably
// c := alpha * op(a) * op(b) + beta * c, with op chosen by *transa / *transb).
// `c` is the in/out matrix; lda/ldb/ldc are the leading dimensions.
void cgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const std::complex<float> *alpha,
            const std::complex<float> *a, const int *lda,
            const std::complex<float> *b, const int *ldb,
            const std::complex<float> *beta,
            std::complex<float> *c, const int *ldc);
// External complex<double> `zgemm_` prototype (BLAS xGEMM naming; presumably
// c := alpha * op(a) * op(b) + beta * c, with op chosen by *transa / *transb).
// `c` is the in/out matrix; lda/ldb/ldc are the leading dimensions.
void zgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const std::complex<double> *alpha,
            const std::complex<double> *a, const int *lda,
            const std::complex<double> *b, const int *ldb,
            const std::complex<double> *beta,
            std::complex<double> *c, const int *ldc);
// External single-precision `ssymm_` prototype (BLAS xSYMM naming; presumably
// c := alpha * a * b + beta * c or c := alpha * b * a + beta * c depending on *side,
// with symmetric `a` stored in the triangle named by *uplo). `c` is the in/out matrix.
void ssymm_(const char *side, const char *uplo,
            const int *m, const int *n,
            const float *alpha,
            const float *a, const int *lda,
            const float *b, const int *ldb,
            const float *beta,
            float *c, const int *ldc);
// External double-precision `dsymm_` prototype (BLAS xSYMM naming; presumably
// c := alpha * a * b + beta * c or c := alpha * b * a + beta * c depending on *side,
// with symmetric `a` stored in the triangle named by *uplo). `c` is the in/out matrix.
void dsymm_(const char *side, const char *uplo,
            const int *m, const int *n,
            const double *alpha,
            const double *a, const int *lda,
            const double *b, const int *ldb,
            const double *beta,
            double *c, const int *ldc);
// External complex<float> `csymm_` prototype (BLAS xSYMM naming; presumably a
// symmetric matrix-matrix multiply accumulating into `c`, side/triangle chosen by
// *side and *uplo). `c` is the in/out matrix.
void csymm_(const char *side, const char *uplo,
            const int *m, const int *n,
            const std::complex<float> *alpha,
            const std::complex<float> *a, const int *lda,
            const std::complex<float> *b, const int *ldb,
            const std::complex<float> *beta,
            std::complex<float> *c, const int *ldc);
// External complex<double> `zsymm_` prototype (BLAS xSYMM naming; presumably a
// symmetric matrix-matrix multiply accumulating into `c`, side/triangle chosen by
// *side and *uplo). `c` is the in/out matrix.
void zsymm_(const char *side, const char *uplo,
            const int *m, const int *n,
            const std::complex<double> *alpha,
            const std::complex<double> *a, const int *lda,
            const std::complex<double> *b, const int *ldb,
            const std::complex<double> *beta,
            std::complex<double> *c, const int *ldc);
// External complex<float> `chemm_` prototype (BLAS xHEMM naming; presumably a
// Hermitian matrix-matrix multiply accumulating into `c`, side/triangle chosen by
// *side and *uplo). `c` is the in/out matrix.
void chemm_(const char *side, const char *uplo,
            const int *m, const int *n,
            const std::complex<float> *alpha,
            const std::complex<float> *a, const int *lda,
            const std::complex<float> *b, const int *ldb,
            const std::complex<float> *beta,
            std::complex<float> *c, const int *ldc);
// External complex<double> `zhemm_` prototype (BLAS xHEMM naming; presumably a
// Hermitian matrix-matrix multiply accumulating into `c`, side/triangle chosen by
// *side and *uplo). `c` is the in/out matrix.
void zhemm_(const char *side, const char *uplo,
            const int *m, const int *n,
            const std::complex<double> *alpha,
            const std::complex<double> *a, const int *lda,
            const std::complex<double> *b, const int *ldb,
            const std::complex<double> *beta,
            std::complex<double> *c, const int *ldc);
// External double-precision `dtrsm_` prototype (BLAS xTRSM naming; presumably solves a
// triangular system with multiple right-hand sides, overwriting `b` with the solution).
// `a` is read-only; `b` is in/out.
void dtrsm_(const char *side, const char *uplo, const char *transa, const char *diag,
            const int *m, const int *n,
            const double *alpha,
            const double *a, const int *lda,
            double *b, const int *ldb);
// External complex<double> `ztrsm_` prototype (BLAS xTRSM naming; presumably solves a
// triangular system with multiple right-hand sides, overwriting `b` with the solution).
// `a` is read-only; `b` is in/out.
void ztrsm_(const char *side, const char *uplo, const char *transa, const char *diag,
            const int *m, const int *n,
            const std::complex<double> *alpha,
            const std::complex<double> *a, const int *lda,
            std::complex<double> *b, const int *ldb);
// External complex<float> `cherk_` prototype (BLAS xHERK naming; presumably a Hermitian
// rank-k update of `c`). Note the scalars `alpha` and `beta` are real (`float`), while
// the matrices `a` and `c` are complex.
void cherk_(const char* uplo, const char* trans,
            const int* n, const int* k,
            const float* alpha,
            const std::complex<float>* a, const int* lda,
            const float* beta,
            std::complex<float>* c, const int* ldc);
// External complex<double> `zherk_` prototype (BLAS xHERK naming; presumably a Hermitian
// rank-k update of `c`). Note the scalars `alpha` and `beta` are real (`double`), while
// the matrices `a` and `c` are complex.
void zherk_(const char* uplo, const char* trans,
            const int* n, const int* k,
            const double* alpha,
            const std::complex<double>* a, const int* lda,
            const double* beta,
            std::complex<double>* c, const int* ldc);
// External double-precision `dsyrk_` prototype (BLAS xSYRK naming; presumably a symmetric
// rank-k update of `c`, with the stored triangle named by *uplo). `c` is the in/out matrix.
void dsyrk_(const char* uplo, const char* trans,
            const int* n, const int* k,
            const double* alpha,
            const double* a, const int* lda,
            const double* beta,
            double* c, const int* ldc);
// `gemm` overload set: float, double, std::complex<float>, std::complex<double>.
// Unlike the underscore-suffixed prototypes, scalars and sizes are taken by value.
// NOTE(review): every overload below stops after `ldb` — the remaining parameters and
// the closing `);` are not present in this listing, and the leading numbers (274, 275, ...)
// look like listing line numbers fused into the text. Restore from the original header.
// float overload
274 void gemm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
275 const float alpha,
const float *a,
const int lda,
const float *b,
const int ldb,
// double overload
279 void gemm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
280 const double alpha,
const double *a,
const int lda,
const double *b,
const int ldb,
// std::complex<float> overload
284 void gemm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
285 const std::complex<float> alpha,
const std::complex<float> *a,
const int lda,
const std::complex<float> *b,
const int ldb,
// std::complex<double> overload
289 void gemm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
290 const std::complex<double> alpha,
const std::complex<double> *a,
const int lda,
const std::complex<double> *b,
const int ldb,
// `gemm_cm` overload set: float, double, std::complex<float>, std::complex<double>.
// The `_cm` suffix presumably means column-major storage — TODO confirm against the implementation.
// NOTE(review): every overload below stops after `ldb` — the remaining parameters and
// the closing `);` are not present in this listing, and the leading numbers (296, 297, ...)
// look like listing line numbers fused into the text. Restore from the original header.
// float overload
296 void gemm_cm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
297 const float alpha,
const float *a,
const int lda,
const float *b,
const int ldb,
// double overload
301 void gemm_cm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
302 const double alpha,
const double *a,
const int lda,
const double *b,
const int ldb,
// std::complex<float> overload
306 void gemm_cm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
307 const std::complex<float> alpha,
const std::complex<float> *a,
const int lda,
const std::complex<float> *b,
const int ldb,
// std::complex<double> overload
311 void gemm_cm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
312 const std::complex<double> alpha,
const std::complex<double> *a,
const int lda,
const std::complex<double> *b,
const int ldb,
// `symm_cm` overload set: float, double, std::complex<float>, std::complex<double>.
// The `_cm` suffix presumably means column-major storage — TODO confirm against the implementation.
// NOTE(review): every overload below stops after `ldb` — the remaining parameters and
// the closing `);` are not present in this listing, and the leading numbers (320, 321, ...)
// look like listing line numbers fused into the text. Restore from the original header.
// float overload
320 void symm_cm(
const char side,
const char uplo,
const int m,
const int n,
321 const float alpha,
const float *a,
const int lda,
const float *b,
const int ldb,
// double overload
325 void symm_cm(
const char side,
const char uplo,
const int m,
const int n,
326 const double alpha,
const double *a,
const int lda,
const double *b,
const int ldb,
// std::complex<float> overload
330 void symm_cm(
const char side,
const char uplo,
const int m,
const int n,
331 const std::complex<float> alpha,
const std::complex<float> *a,
const int lda,
const std::complex<float> *b,
const int ldb,
// std::complex<double> overload
335 void symm_cm(
const char side,
const char uplo,
const int m,
const int n,
336 const std::complex<double> alpha,
const std::complex<double> *a,
const int lda,
const std::complex<double> *b,
const int ldb,
// `hemm_cm` overload set: float, double, std::complex<float>, std::complex<double>.
// The `_cm` suffix presumably means column-major storage — TODO confirm against the implementation.
// The two complex overloads differ from the real ones: their by-value parameters are not
// `const`, and `a` / `b` are pointers to non-const data.
// NOTE(review): every overload below stops after `ldb` — the remaining parameters and
// the closing `);` are not present in this listing, and the leading numbers (343, 344, ...)
// look like listing line numbers fused into the text. Restore from the original header.
// float overload
343 void hemm_cm(
const char side,
const char uplo,
const int m,
const int n,
344 const float alpha,
const float *a,
const int lda,
const float *b,
const int ldb,
// double overload
348 void hemm_cm(
const char side,
const char uplo,
const int m,
const int n,
349 const double alpha,
const double *a,
const int lda,
const double *b,
const int ldb,
// std::complex<float> overload (non-const `a` and `b`)
353 void hemm_cm(
char side,
char uplo,
int m,
int n,
354 std::complex<float> alpha, std::complex<float> *a,
int lda, std::complex<float> *b,
int ldb,
// std::complex<double> overload (non-const `a` and `b`)
358 void hemm_cm(
char side,
char uplo,
int m,
int n,
359 std::complex<double> alpha, std::complex<double> *a,
int lda, std::complex<double> *b,
int ldb,
// `gemv` overload set: float, double, std::complex<float>, std::complex<double>.
// Scalars and sizes are taken by value; `A` and `X` are read-only.
// NOTE(review): every overload below stops after `incx` — the remaining parameters and
// the closing `);` are not present in this listing, and the leading numbers (364, 365, ...)
// look like listing line numbers fused into the text. Restore from the original header.
// float overload
364 void gemv(
const char trans,
const int m,
const int n,
365 const float alpha,
const float* A,
const int lda,
const float* X,
const int incx,
// double overload
369 void gemv(
const char trans,
const int m,
const int n,
370 const double alpha,
const double* A,
const int lda,
const double* X,
const int incx,
// std::complex<float> overload
374 void gemv(
const char trans,
const int m,
const int n,
375 const std::complex<float> alpha,
const std::complex<float> *A,
const int lda,
const std::complex<float> *X,
const int incx,
// std::complex<double> overload
379 void gemv(
const char trans,
const int m,
const int n,
380 const std::complex<double> alpha,
const std::complex<double> *A,
const int lda,
const std::complex<double> *X,
const int incx,
409 template <
typename T>
413 template <
typename T>
433#include <cuda_runtime.h>
434#include "cublas_v2.h"
441 static cublasHandle_t cublas_handle =
nullptr;
// Declared here, defined elsewhere. Presumably creates the cuBLAS handle stored in
// `cublas_handle` — TODO confirm against the definition.
void createGpuBlasHandle();
// Declared here, defined elsewhere. Presumably releases the handle created by
// `createGpuBlasHandle` — TODO confirm against the definition.
// NOTE(review): "destory" looks like a misspelling of "destroy"; the name is kept
// because callers depend on it.
void destoryBLAShandle();
447 cublasOperation_t judge_trans(
bool is_complex,
const char& trans,
const char* name);
449 cublasSideMode_t judge_side(
const char& trans);
451 cublasFillMode_t judge_fill(
const char& trans);
// Redirect every later use of `zgemm_` to the `zgemm_i` wrapper declared below.
#define zgemm_ zgemm_i
// Wrapper with the same parameter list as `zgemm_`; it is defined outside this header
// (the cross-reference index points at gather_math_lib_info.cpp).
void zgemm_i(const char *transa, const char *transb,
             const int *m, const int *n, const int *k,
             const std::complex<double> *alpha,
             const std::complex<double> *a, const int *lda,
             const std::complex<double> *b, const int *ldb,
             const std::complex<double> *beta,
             std::complex<double> *c, const int *ldc);
// Redirect every later use of `zaxpy_` to the `zaxpy_i` wrapper declared below.
#define zaxpy_ zaxpy_i
// Wrapper with the same parameter list as `zaxpy_`; it is defined outside this header
// (the cross-reference index points at gather_math_lib_info.cpp).
void zaxpy_i(const int *N,
             const std::complex<double> *alpha,
             const std::complex<double> *X, const int *incX,
             std::complex<double> *Y, const int *incY);
double ddot_(const int *N, const double *X, const int *incX, const double *Y, const int *incY)
void ccopy_(const int *n, const std::complex< float > *a, const int *incx, std::complex< float > *b, const int *incy)
void dscal_(const int *N, const double *alpha, double *X, const int *incX)
void dtrsm_(const char *side, const char *uplo, const char *transa, const char *diag, const int *m, const int *n, const double *alpha, const double *a, const int *lda, double *b, const int *ldb)
void cscal_(const int *N, const std::complex< float > *alpha, std::complex< float > *X, const int *incX)
void zgerc_(const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *x, const int *incx, const std::complex< double > *y, const int *incy, std::complex< double > *a, const int *lda)
void ztrsm_(const char *side, const char *uplo, const char *transa, const char *diag, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, std::complex< double > *b, const int *ldb)
float scnrm2_(const int *n, const std::complex< float > *X, const int *incX)
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, float *c, const int *ldc)
void zhemm_(const char *side, const char *uplo, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
void scopy_(const int *n, const float *a, const int *incx, float *b, const int *incy)
void sscal_(const int *N, const float *alpha, float *X, const int *incX)
void daxpy_(const int *N, const double *alpha, const double *X, const int *incX, double *Y, const int *incY)
void cherk_(const char *uplo, const char *trans, const int *n, const int *k, const float *alpha, const std::complex< float > *a, const int *lda, const float *beta, std::complex< float > *c, const int *ldc)
void dsymm_(const char *side, const char *uplo, const int *m, const int *n, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void csymm_(const char *side, const char *uplo, const int *m, const int *n, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc)
void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
void dger_(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *a, const int *lda)
void zscal_(const int *N, const std::complex< double > *alpha, std::complex< double > *X, const int *incX)
void zherk_(const char *uplo, const char *trans, const int *n, const int *k, const double *alpha, const std::complex< double > *a, const int *lda, const double *beta, std::complex< double > *c, const int *ldc)
void chemm_(const char *side, const char *uplo, const int *m, const int *n, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc)
void zsymm_(const char *side, const char *uplo, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
void dsyrk_(const char *uplo, const char *trans, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *beta, double *c, const int *ldc)
void dgemv_(const char *transa, const int *m, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
void zaxpy_(const int *N, const std::complex< double > *alpha, const std::complex< double > *X, const int *incX, std::complex< double > *Y, const int *incY)
void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc)
void sgemv_(const char *transa, const int *m, const int *n, const float *alpha, const float *a, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy)
float sdot_(const int *N, const float *X, const int *incX, const float *Y, const int *incY)
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void dsymv_(const char *uplo, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
double dznrm2_(const int *n, const std::complex< double > *X, const int *incX)
void caxpy_(const int *N, const std::complex< float > *alpha, const std::complex< float > *X, const int *incX, std::complex< float > *Y, const int *incY)
void ssymm_(const char *side, const char *uplo, const int *m, const int *n, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, float *c, const int *ldc)
float snrm2_(const int *n, const float *X, const int *incX)
void cgemv_(const char *trans, const int *m, const int *n, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *x, const int *incx, const std::complex< float > *beta, std::complex< float > *y, const int *incy)
void saxpy_(const int *N, const float *alpha, const float *X, const int *incX, float *Y, const int *incY)
double dnrm2_(const int *n, const double *X, const int *incX)
void zcopy_(const int *n, const std::complex< double > *a, const int *incx, std::complex< double > *b, const int *incy)
void dcopy_(const int *n, const double *a, const int *incx, double *b, const int *incy)
void zgemv_(const char *trans, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *x, const int *incx, const std::complex< double > *beta, std::complex< double > *y, const int *incy)
Definition blas_connector.h:203
static void vector_add_vector(const int &dim, std::complex< double > *result, const std::complex< double > *vector1, const double constant1, const std::complex< double > *vector2, const double constant2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static void vector_div_vector(const int &dim, T *result, const T *vector1, const T *vector2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static void copy(const int n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:335
static void axpy(const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:18
static void gemm_cm(const char transa, const char transb, const int m, const int n, const int k, const float alpha, const float *a, const int lda, const float *b, const int ldb, const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_matrix.cpp:190
static void gemv(const char trans, const int m, const int n, const float alpha, const float *A, const int lda, const float *X, const int incx, const float beta, float *Y, const int incy, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_matrix.cpp:501
static void vector_add_vector(const int &dim, double *result, const double *vector1, const double constant1, const double *vector2, const double constant2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static void vector_add_vector(const int &dim, std::complex< float > *result, const std::complex< float > *vector1, const float constant1, const std::complex< float > *vector2, const float constant2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static float nrm2(const int n, const float *X, const int, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:271
static float dotc(const int n, const float *const X, const int incX, const float *const Y, const int incY, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:224
static void gemm(const char transa, const char transb, const int m, const int n, const int k, const float alpha, const float *a, const int lda, const float *b, const int ldb, const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_matrix.cpp:20
static void symm_cm(const char side, const char uplo, const int m, const int n, const float alpha, const float *a, const int lda, const float *b, const int ldb, const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_matrix.cpp:361
static void hemm_cm(const char side, const char uplo, const int m, const int n, const float alpha, const float *a, const int lda, const float *b, const int ldb, const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_matrix.cpp:445
static void vector_add_vector(const int &dim, float *result, const float *vector1, const float constant1, const float *vector2, const float constant2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static void vector_mul_vector(const int &dim, T *result, const T *vector1, const T *vector2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static void scal(const int n, const float alpha, float *X, const int incX, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:80
static float dotu(const int n, const float *const X, const int incX, const float *const Y, const int incY, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:177
static float dot(const int n, const float *const X, const int incX, const float *const Y, const int incY, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:142
#define N
Definition exp.cpp:24
#define T
Definition exp.cpp:237
void zgemm_i(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
Definition gather_math_lib_info.cpp:16
void zaxpy_i(const int *N, const std::complex< double > *alpha, const std::complex< double > *X, const int *incX, std::complex< double > *Y, const int *incY)
Definition gather_math_lib_info.cpp:36
AbacusDevice_t
Definition types.h:12
@ CpuDevice
Definition types.h:14