1#ifndef BLAS_CONNECTOR_H
2#define BLAS_CONNECTOR_H
16 void sscal_(
const int *
N,
const float *alpha,
float *X,
const int *incX);
17 void dscal_(
const int *
N,
const double *alpha,
double *X,
const int *incX);
18 void cscal_(
const int *
N,
const std::complex<float> *alpha, std::complex<float> *X,
const int *incX);
19 void zscal_(
const int *
N,
const std::complex<double> *alpha, std::complex<double> *X,
const int *incX);
22 void saxpy_(
const int *
N,
const float *alpha,
const float *X,
const int *incX,
float *Y,
const int *incY);
23 void daxpy_(
const int *
N,
const double *alpha,
const double *X,
const int *incX,
double *Y,
const int *incY);
24 void caxpy_(
const int *
N,
const std::complex<float> *alpha,
const std::complex<float> *X,
const int *incX, std::complex<float> *Y,
const int *incY);
25 void zaxpy_(
const int *
N,
const std::complex<double> *alpha,
const std::complex<double> *X,
const int *incX, std::complex<double> *Y,
const int *incY);
27 void dcopy_(
long const *n,
const double *a,
int const *incx,
double *b,
int const *incy);
28 void zcopy_(
long const *n,
const std::complex<double> *a,
int const *incx, std::complex<double> *b,
int const *incy);
35 float sdot_(
const int *
N,
const float *X,
const int *incX,
const float *Y,
const int *incY);
36 double ddot_(
const int *
N,
const double *X,
const int *incX,
const double *Y,
const int *incY);
39 float snrm2_(
const int *n,
const float *X,
const int *incX );
40 double dnrm2_(
const int *n,
const double *X,
const int *incX );
41 double dznrm2_(
const int *n,
const std::complex<double> *X,
const int *incX );
58 void sgemv_(
const char*
const transa,
const int*
const m,
const int*
const n,
59 const float*
const alpha,
const float*
const a,
const int*
const lda,
const float*
const x,
const int*
const incx,
60 const float*
const beta,
float*
const y,
const int*
const incy);
61 void dgemv_(
const char*
const transa,
const int*
const m,
const int*
const n,
62 const double*
const alpha,
const double*
const a,
const int*
const lda,
const double*
const x,
const int*
const incx,
63 const double*
const beta,
double*
const y,
const int*
const incy);
65 void cgemv_(
const char *trans,
const int *m,
const int *n,
const std::complex<float> *alpha,
66 const std::complex<float> *a,
const int *lda,
const std::complex<float> *x,
const int *incx,
67 const std::complex<float> *beta, std::complex<float> *y,
const int *incy);
69 void zgemv_(
const char *trans,
const int *m,
const int *n,
const std::complex<double> *alpha,
70 const std::complex<double> *a,
const int *lda,
const std::complex<double> *x,
const int *incx,
71 const std::complex<double> *beta, std::complex<double> *y,
const int *incy);
73 void dsymv_(
const char *uplo,
const int *n,
74 const double *alpha,
const double *a,
const int *lda,
75 const double *x,
const int *incx,
76 const double *beta,
double *y,
const int *incy);
90 const std::complex<double>* alpha,
91 const std::complex<double>* x,
93 const std::complex<double>* y,
95 std::complex<double>* a,
102 void sgemm_(
const char *transa,
const char *transb,
const int *m,
const int *n,
const int *k,
103 const float *alpha,
const float *a,
const int *lda,
const float *b,
const int *ldb,
104 const float *beta,
float *c,
const int *ldc);
105 void dgemm_(
const char *transa,
const char *transb,
const int *m,
const int *n,
const int *k,
106 const double *alpha,
const double *a,
const int *lda,
const double *b,
const int *ldb,
107 const double *beta,
double *c,
const int *ldc);
108 void cgemm_(
const char *transa,
const char *transb,
const int *m,
const int *n,
const int *k,
109 const std::complex<float> *alpha,
const std::complex<float> *a,
const int *lda,
const std::complex<float> *b,
const int *ldb,
110 const std::complex<float> *beta, std::complex<float> *c,
const int *ldc);
111 void zgemm_(
const char *transa,
const char *transb,
const int *m,
const int *n,
const int *k,
112 const std::complex<double> *alpha,
const std::complex<double> *a,
const int *lda,
const std::complex<double> *b,
const int *ldb,
113 const std::complex<double> *beta, std::complex<double> *c,
const int *ldc);
116 void ssymm_(
const char *side,
const char *uplo,
const int *m,
const int *n,
117 const float *alpha,
const float *a,
const int *lda,
const float *b,
const int *ldb,
118 const float *beta,
float *c,
const int *ldc);
119 void dsymm_(
const char *side,
const char *uplo,
const int *m,
const int *n,
120 const double *alpha,
const double *a,
const int *lda,
const double *b,
const int *ldb,
121 const double *beta,
double *c,
const int *ldc);
122 void csymm_(
const char *side,
const char *uplo,
const int *m,
const int *n,
123 const std::complex<float> *alpha,
const std::complex<float> *a,
const int *lda,
const std::complex<float> *b,
const int *ldb,
124 const std::complex<float> *beta, std::complex<float> *c,
const int *ldc);
125 void zsymm_(
const char *side,
const char *uplo,
const int *m,
const int *n,
126 const std::complex<double> *alpha,
const std::complex<double> *a,
const int *lda,
const std::complex<double> *b,
const int *ldb,
127 const std::complex<double> *beta, std::complex<double> *c,
const int *ldc);
130 void chemm_(
char *side,
char *uplo,
int *m,
int *n,std::complex<float> *alpha,
131 std::complex<float> *a,
int *lda, std::complex<float> *b,
int *ldb, std::complex<float> *beta, std::complex<float> *c,
int *ldc);
132 void zhemm_(
char *side,
char *uplo,
int *m,
int *n,std::complex<double> *alpha,
133 std::complex<double> *a,
int *lda, std::complex<double> *b,
int *ldb, std::complex<double> *beta, std::complex<double> *c,
int *ldc);
136 void dtrsm_(
char *side,
char* uplo,
char *transa,
char *diag,
int *m,
int *n,
137 double* alpha,
double* a,
int *lda,
double*b,
int *ldb);
138 void ztrsm_(
char *side,
char* uplo,
char *transa,
char *diag,
int *m,
int *n,
139 std::complex<double>* alpha, std::complex<double>* a,
int *lda, std::complex<double>*b,
int *ldb);
218 void gemm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
219 const float alpha,
const float *a,
const int lda,
const float *b,
const int ldb,
223 void gemm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
224 const double alpha,
const double *a,
const int lda,
const double *b,
const int ldb,
228 void gemm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
229 const std::complex<float> alpha,
const std::complex<float> *a,
const int lda,
const std::complex<float> *b,
const int ldb,
233 void gemm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
234 const std::complex<double> alpha,
const std::complex<double> *a,
const int lda,
const std::complex<double> *b,
const int ldb,
240 void gemm_cm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
241 const float alpha,
const float *a,
const int lda,
const float *b,
const int ldb,
245 void gemm_cm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
246 const double alpha,
const double *a,
const int lda,
const double *b,
const int ldb,
250 void gemm_cm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
251 const std::complex<float> alpha,
const std::complex<float> *a,
const int lda,
const std::complex<float> *b,
const int ldb,
255 void gemm_cm(
const char transa,
const char transb,
const int m,
const int n,
const int k,
256 const std::complex<double> alpha,
const std::complex<double> *a,
const int lda,
const std::complex<double> *b,
const int ldb,
264 void symm_cm(
const char side,
const char uplo,
const int m,
const int n,
265 const float alpha,
const float *a,
const int lda,
const float *b,
const int ldb,
269 void symm_cm(
const char side,
const char uplo,
const int m,
const int n,
270 const double alpha,
const double *a,
const int lda,
const double *b,
const int ldb,
274 void symm_cm(
const char side,
const char uplo,
const int m,
const int n,
275 const std::complex<float> alpha,
const std::complex<float> *a,
const int lda,
const std::complex<float> *b,
const int ldb,
279 void symm_cm(
const char side,
const char uplo,
const int m,
const int n,
280 const std::complex<double> alpha,
const std::complex<double> *a,
const int lda,
const std::complex<double> *b,
const int ldb,
287 void hemm_cm(
const char side,
const char uplo,
const int m,
const int n,
288 const float alpha,
const float *a,
const int lda,
const float *b,
const int ldb,
292 void hemm_cm(
const char side,
const char uplo,
const int m,
const int n,
293 const double alpha,
const double *a,
const int lda,
const double *b,
const int ldb,
297 void hemm_cm(
char side,
char uplo,
int m,
int n,
298 std::complex<float> alpha, std::complex<float> *a,
int lda, std::complex<float> *b,
int ldb,
302 void hemm_cm(
char side,
char uplo,
int m,
int n,
303 std::complex<double> alpha, std::complex<double> *a,
int lda, std::complex<double> *b,
int ldb,
308 void gemv(
const char trans,
const int m,
const int n,
309 const float alpha,
const float* A,
const int lda,
const float* X,
const int incx,
313 void gemv(
const char trans,
const int m,
const int n,
314 const double alpha,
const double* A,
const int lda,
const double* X,
const int incx,
318 void gemv(
const char trans,
const int m,
const int n,
319 const std::complex<float> alpha,
const std::complex<float> *A,
const int lda,
const std::complex<float> *X,
const int incx,
323 void gemv(
const char trans,
const int m,
const int n,
324 const std::complex<double> alpha,
const std::complex<double> *A,
const int lda,
const std::complex<double> *X,
const int incx,
347 template <
typename T>
351 template <
typename T>
371#include <cuda_runtime.h>
372#include "cublas_v2.h"
379 static cublasHandle_t cublas_handle =
nullptr;
381 void createGpuBlasHandle();
383 void destoryBLAShandle();
385 cublasOperation_t judge_trans(
bool is_complex,
const char& trans,
const char* name);
387 cublasSideMode_t judge_side(
const char& trans);
389 cublasFillMode_t judge_fill(
const char& trans);
400#define zgemm_ zgemm_i
401void zgemm_i(
const char *transa,
406 const std::complex<double> *alpha,
407 const std::complex<double> *a,
409 const std::complex<double> *b,
411 const std::complex<double> *beta,
412 std::complex<double> *c,
415#define zaxpy_ zaxpy_i
417 const std::complex<double> *alpha,
418 const std::complex<double> *X,
420 std::complex<double> *Y,
void chemm_(char *side, char *uplo, int *m, int *n, std::complex< float > *alpha, std::complex< float > *a, int *lda, std::complex< float > *b, int *ldb, std::complex< float > *beta, std::complex< float > *c, int *ldc)
double ddot_(const int *N, const double *X, const int *incX, const double *Y, const int *incY)
void dscal_(const int *N, const double *alpha, double *X, const int *incX)
void cscal_(const int *N, const std::complex< float > *alpha, std::complex< float > *X, const int *incX)
void zgerc_(const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *x, const int *incx, const std::complex< double > *y, const int *incy, std::complex< double > *a, const int *lda)
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, float *c, const int *ldc)
void dgemv_(const char *const transa, const int *const m, const int *const n, const double *const alpha, const double *const a, const int *const lda, const double *const x, const int *const incx, const double *const beta, double *const y, const int *const incy)
void sscal_(const int *N, const float *alpha, float *X, const int *incX)
void daxpy_(const int *N, const double *alpha, const double *X, const int *incX, double *Y, const int *incY)
void dsymm_(const char *side, const char *uplo, const int *m, const int *n, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void csymm_(const char *side, const char *uplo, const int *m, const int *n, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc)
void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
void dger_(const int *m, const int *n, const double *alpha, const double *x, const int *incx, const double *y, const int *incy, double *a, const int *lda)
void zscal_(const int *N, const std::complex< double > *alpha, std::complex< double > *X, const int *incX)
void dtrsm_(char *side, char *uplo, char *transa, char *diag, int *m, int *n, double *alpha, double *a, int *lda, double *b, int *ldb)
void zhemm_(char *side, char *uplo, int *m, int *n, std::complex< double > *alpha, std::complex< double > *a, int *lda, std::complex< double > *b, int *ldb, std::complex< double > *beta, std::complex< double > *c, int *ldc)
void zsymm_(const char *side, const char *uplo, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
void dsyrk_(const char *uplo, const char *trans, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *beta, double *c, const int *ldc)
void zaxpy_(const int *N, const std::complex< double > *alpha, const std::complex< double > *X, const int *incX, std::complex< double > *Y, const int *incY)
void cgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc)
void ztrsm_(char *side, char *uplo, char *transa, char *diag, int *m, int *n, std::complex< double > *alpha, std::complex< double > *a, int *lda, std::complex< double > *b, int *ldb)
void zcopy_(long const *n, const std::complex< double > *a, int const *incx, std::complex< double > *b, int const *incy)
void dcopy_(long const *n, const double *a, int const *incx, double *b, int const *incy)
float sdot_(const int *N, const float *X, const int *incX, const float *Y, const int *incY)
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void dsymv_(const char *uplo, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
double dznrm2_(const int *n, const std::complex< double > *X, const int *incX)
void caxpy_(const int *N, const std::complex< float > *alpha, const std::complex< float > *X, const int *incX, std::complex< float > *Y, const int *incY)
void ssymm_(const char *side, const char *uplo, const int *m, const int *n, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, float *c, const int *ldc)
float snrm2_(const int *n, const float *X, const int *incX)
void cgemv_(const char *trans, const int *m, const int *n, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *x, const int *incx, const std::complex< float > *beta, std::complex< float > *y, const int *incy)
void sgemv_(const char *const transa, const int *const m, const int *const n, const float *const alpha, const float *const a, const int *const lda, const float *const x, const int *const incx, const float *const beta, float *const y, const int *const incy)
void saxpy_(const int *N, const float *alpha, const float *X, const int *incX, float *Y, const int *incY)
double dnrm2_(const int *n, const double *X, const int *incX)
void zgemv_(const char *trans, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *x, const int *incx, const std::complex< double > *beta, std::complex< double > *y, const int *incy)
Definition blas_connector.h:147
static void copy(const long n, const double *a, const int incx, double *b, const int incy, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:325
static void vector_add_vector(const int &dim, std::complex< double > *result, const std::complex< double > *vector1, const double constant1, const std::complex< double > *vector2, const double constant2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static void vector_div_vector(const int &dim, T *result, const T *vector1, const T *vector2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static void axpy(const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:18
static void gemm_cm(const char transa, const char transb, const int m, const int n, const int k, const float alpha, const float *a, const int lda, const float *b, const int ldb, const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_matrix.cpp:190
static void gemv(const char trans, const int m, const int n, const float alpha, const float *A, const int lda, const float *X, const int incx, const float beta, float *Y, const int incy, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_matrix.cpp:501
static void vector_add_vector(const int &dim, double *result, const double *vector1, const double constant1, const double *vector2, const double constant2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static void vector_add_vector(const int &dim, std::complex< float > *result, const std::complex< float > *vector1, const float constant1, const std::complex< float > *vector2, const float constant2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static float nrm2(const int n, const float *X, const int, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:271
static float dotc(const int n, const float *const X, const int incX, const float *const Y, const int incY, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:224
static void gemm(const char transa, const char transb, const int m, const int n, const int k, const float alpha, const float *a, const int lda, const float *b, const int ldb, const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_matrix.cpp:20
static void symm_cm(const char side, const char uplo, const int m, const int n, const float alpha, const float *a, const int lda, const float *b, const int ldb, const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_matrix.cpp:361
static void hemm_cm(const char side, const char uplo, const int m, const int n, const float alpha, const float *a, const int lda, const float *b, const int ldb, const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_matrix.cpp:445
static void vector_add_vector(const int &dim, float *result, const float *vector1, const float constant1, const float *vector2, const float constant2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static void vector_mul_vector(const int &dim, T *result, const T *vector1, const T *vector2, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
static void scal(const int n, const float alpha, float *X, const int incX, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:80
static float dotu(const int n, const float *const X, const int incX, const float *const Y, const int incY, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:177
static float dot(const int n, const float *const X, const int incX, const float *const Y, const int incY, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:142
#define N
Definition exp.cpp:24
#define T
Definition exp.cpp:237
void zgemm_i(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
Definition gather_math_lib_info.cpp:16
void zaxpy_i(const int *N, const std::complex< double > *alpha, const std::complex< double > *X, const int *incX, std::complex< double > *Y, const int *incY)
Definition gather_math_lib_info.cpp:36
AbacusDevice_t
Definition types.h:12
@ CpuDevice
Definition types.h:14