9namespace cuBlasConnector {
// Real single-precision overload of dot.
// NOTE(review): the body of this overload is not visible in this excerpt;
// presumably it forwards to cublasSdot, the way the complex overloads forward
// to cublasCdotc / cublasZdotc. Confirm against the full file.
// handle: cuBLAS handle; n: element count; x/incx and y/incy: the two input
// vectors (read-only) and their strides; result: non-const pointer that
// presumably receives the scalar result.
12void dot(cublasHandle_t& handle,
const int& n,
const float *x,
const int& incx,
const float *y,
const int& incy,
float* result)
// Real double-precision overload of dot.
// NOTE(review): the body of this overload is not visible in this excerpt;
// presumably it forwards to cublasDdot, mirroring the complex overloads.
// Confirm against the full file.
// handle: cuBLAS handle; n: element count; x/incx and y/incy: the two input
// vectors (read-only) and their strides; result: non-const pointer that
// presumably receives the scalar result.
17void dot(cublasHandle_t& handle,
const int& n,
const double *x,
const int& incx,
const double *y,
const int& incy,
double* result)
22void dot(cublasHandle_t& handle,
const int& n,
const std::complex<float> *x,
const int& incx,
const std::complex<float> *y,
const int& incy, std::complex<float>* result)
24 cublasErrcheck(cublasCdotc(handle, n,
reinterpret_cast<const cuComplex*
>(x), incx,
reinterpret_cast<const cuComplex*
>(y), incy,
reinterpret_cast<cuComplex*
>(result)));
27void dot(cublasHandle_t& handle,
const int& n,
const std::complex<double> *x,
const int& incx,
const std::complex<double> *y,
const int& incy, std::complex<double>* result)
29 cublasErrcheck(cublasZdotc(handle, n,
reinterpret_cast<const cuDoubleComplex*
>(x), incx,
reinterpret_cast<const cuDoubleComplex*
>(y), incy,
reinterpret_cast<cuDoubleComplex*
>(result)));
// Real single-precision overload of axpy (BLAS naming: y = alpha * x + y).
// NOTE(review): the body is not visible in this excerpt; presumably it
// forwards to cublasSaxpy, as the complex overloads forward to
// cublasCaxpy / cublasZaxpy. Confirm against the full file.
// x (stride incx) is read-only; y (stride incy) is non-const and presumably
// updated in place.
33void axpy(cublasHandle_t& handle,
const int& n,
const float& alpha,
const float *x,
const int& incx,
float *y,
const int& incy)
// Real double-precision overload of axpy (BLAS naming: y = alpha * x + y).
// NOTE(review): the body is not visible in this excerpt; presumably it
// forwards to cublasDaxpy. Confirm against the full file.
// x (stride incx) is read-only; y (stride incy) is non-const and presumably
// updated in place.
38void axpy(cublasHandle_t& handle,
const int& n,
const double& alpha,
const double *x,
const int& incx,
double *y,
const int& incy)
43void axpy(cublasHandle_t& handle,
const int& n,
const std::complex<float>& alpha,
const std::complex<float> *x,
const int& incx, std::complex<float> *y,
const int& incy)
45 cublasErrcheck(cublasCaxpy(handle, n,
reinterpret_cast<const cuComplex*
>(&alpha),
reinterpret_cast<const cuComplex*
>(x), incx,
reinterpret_cast<cuComplex*
>(y), incy));
48void axpy(cublasHandle_t& handle,
const int& n,
const std::complex<double>& alpha,
const std::complex<double> *x,
const int& incx, std::complex<double> *y,
const int& incy)
50 cublasErrcheck(cublasZaxpy(handle, n,
reinterpret_cast<const cuDoubleComplex*
>(&alpha),
reinterpret_cast<const cuDoubleComplex*
>(x), incx,
reinterpret_cast<cuDoubleComplex*
>(y), incy));
// Real single-precision overload of scal (BLAS naming: x = alpha * x).
// NOTE(review): the body is not visible in this excerpt; presumably it
// forwards to cublasSscal, as the complex overloads forward to
// cublasCscal / cublasZscal. Confirm against the full file.
// x (stride incx) is non-const and presumably scaled in place.
54void scal(cublasHandle_t& handle,
const int& n,
const float& alpha,
float *x,
const int& incx)
// Real double-precision overload of scal (BLAS naming: x = alpha * x).
// NOTE(review): the body is not visible in this excerpt; presumably it
// forwards to cublasDscal. Confirm against the full file.
// x (stride incx) is non-const and presumably scaled in place.
59void scal(cublasHandle_t& handle,
const int& n,
const double& alpha,
double *x,
const int& incx)
64void scal(cublasHandle_t& handle,
const int& n,
const std::complex<float>& alpha, std::complex<float> *x,
const int& incx)
66 cublasErrcheck(cublasCscal(handle, n,
reinterpret_cast<const cuComplex*
>(&alpha),
reinterpret_cast<cuComplex*
>(x), incx));
69void scal(cublasHandle_t& handle,
const int& n,
const std::complex<double>& alpha, std::complex<double> *x,
const int& incx)
71 cublasErrcheck(cublasZscal(handle, n,
reinterpret_cast<const cuDoubleComplex*
>(&alpha),
reinterpret_cast<cuDoubleComplex*
>(x), incx));
75void gemv(cublasHandle_t& handle,
const char& trans,
const int& m,
const int& n,
76 const float& alpha,
const float *A,
const int& lda,
const float *x,
const int& incx,
77 const float& beta,
float *y,
const int& incy)
79 cublasErrcheck(cublasSgemv(handle, GetCublasOperation(trans), m, n, &alpha, A, lda, x, incx, &beta, y, incy));
82void gemv(cublasHandle_t& handle,
const char& trans,
const int& m,
const int& n,
83 const double& alpha,
const double *A,
const int& lda,
const double *x,
const int& incx,
84 const double& beta,
double *y,
const int& incy)
86 cublasErrcheck(cublasDgemv(handle, GetCublasOperation(trans), m, n, &alpha, A, lda, x, incx, &beta, y, incy));
89void gemv(cublasHandle_t& handle,
const char& trans,
const int& m,
const int& n,
90 const std::complex<float>& alpha,
const std::complex<float> *A,
const int& lda,
const std::complex<float> *x,
const int& incx,
91 const std::complex<float>& beta, std::complex<float> *y,
const int& incy)
93 cublasErrcheck(cublasCgemv(handle, GetCublasOperation(trans), m, n,
reinterpret_cast<const cuComplex*
>(&alpha),
94 reinterpret_cast<const cuComplex*
>(A), lda,
reinterpret_cast<const cuComplex*
>(x), incx,
reinterpret_cast<const cuComplex*
>(&beta),
reinterpret_cast<cuComplex*
>(y), incy));
97void gemv(cublasHandle_t& handle,
const char& trans,
const int& m,
const int& n,
98 const std::complex<double>& alpha,
const std::complex<double> *A,
const int& lda,
const std::complex<double> *x,
const int& incx,
99 const std::complex<double>& beta, std::complex<double> *y,
const int& incy)
101 cublasErrcheck(cublasZgemv(handle, GetCublasOperation(trans), m, n,
reinterpret_cast<const cuDoubleComplex*
>(&alpha),
102 reinterpret_cast<const cuDoubleComplex*
>(A), lda,
reinterpret_cast<const cuDoubleComplex*
>(x), incx,
reinterpret_cast<const cuDoubleComplex*
>(&beta),
reinterpret_cast<cuDoubleComplex*
>(y), incy));
107void gemv_batched(cublasHandle_t& handle,
const char& trans,
const int& m,
const int& n,
108 const T& alpha,
T** A,
const int& lda,
T** x,
const int& incx,
109 const T& beta,
T** y,
const int& incy,
const int& batch_size)
111 for (
int ii = 0; ii < batch_size; ++ii) {
113 cuBlasConnector::gemv(handle, trans, m, n, alpha, A[ii], lda, x[ii], incy, beta, y[ii], incy);
119void gemv_batched_strided(cublasHandle_t& handle,
const char& transa,
const int& m,
const int& n,
120 const T& alpha,
const T* A,
const int& lda,
const int& stride_a,
const T* x,
const int& incx,
const int& stride_x,
121 const T& beta,
T* y,
const int& incy,
const int& stride_y,
const int& batch_size)
123 for (
int ii = 0; ii < batch_size; ii++) {
125 cuBlasConnector::gemv(handle, transa, m, n, alpha, A + ii * stride_a, lda, x + ii * stride_x, incx, beta, y + ii * stride_y, incy);
130void gemm(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
131 const float& alpha,
const float* A,
const int& lda,
const float* B,
const int& ldb,
132 const float& beta,
float* C,
const int& ldc)
134 cublasErrcheck(cublasSgemm(handle, GetCublasOperation(transa), GetCublasOperation(transb),
135 m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc));
138void gemm(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
139 const double& alpha,
const double* A,
const int& lda,
const double* B,
const int& ldb,
140 const double& beta,
double* C,
const int& ldc)
142 cublasErrcheck(cublasDgemm(handle, GetCublasOperation(transa), GetCublasOperation(transb),
143 m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc));
// Single-precision complex GEMM via cublasCgemm:
//   C = alpha * op(A) * op(B) + beta * C
// op() is selected by transa/transb through GetCublasOperation (defined elsewhere).
// std::complex<float> arguments are reinterpreted as cuComplex, which relies
// on the two types sharing the same (real, imag) layout.
// NOTE(review): the argument line between the transb operation and &alpha
// (expected by the cuBLAS signature: m, n, k) is not visible in this excerpt.
// Confirm it is present.
146void gemm(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
147 const std::complex<float>& alpha,
const std::complex<float>* A,
const int& lda,
const std::complex<float>* B,
const int& ldb,
148 const std::complex<float>& beta, std::complex<float>* C,
const int& ldc)
150 cublasErrcheck(cublasCgemm(handle, GetCublasOperation(transa), GetCublasOperation(transb),
152 reinterpret_cast<const cuComplex*
>(&alpha),
153 reinterpret_cast<const cuComplex*
>(A), lda,
154 reinterpret_cast<const cuComplex*
>(B), ldb,
155 reinterpret_cast<const cuComplex*
>(&beta),
// C is the only written operand (non-const cast).
156 reinterpret_cast<cuComplex*
>(C), ldc));
// Double-precision complex GEMM via cublasZgemm:
//   C = alpha * op(A) * op(B) + beta * C
// op() is selected by transa/transb through GetCublasOperation (defined elsewhere).
// std::complex<double> arguments are reinterpreted as cuDoubleComplex, which
// relies on the two types sharing the same (real, imag) layout.
// NOTE(review): the argument line between the transb operation and &alpha
// (expected by the cuBLAS signature: m, n, k) is not visible in this excerpt.
// Confirm it is present.
159void gemm(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
160 const std::complex<double>& alpha,
const std::complex<double>* A,
const int& lda,
const std::complex<double>* B,
const int& ldb,
161 const std::complex<double>& beta, std::complex<double>* C,
const int& ldc)
163 cublasErrcheck(cublasZgemm(handle, GetCublasOperation(transa), GetCublasOperation(transb),
165 reinterpret_cast<const cuDoubleComplex*
>(&alpha),
166 reinterpret_cast<const cuDoubleComplex*
>(A), lda,
167 reinterpret_cast<const cuDoubleComplex*
>(B), ldb,
168 reinterpret_cast<const cuDoubleComplex*
>(&beta),
// C is the only written operand (non-const cast).
169 reinterpret_cast<cuDoubleComplex*
>(C), ldc));
// Copies a host array of batch_size pointers (`in`) into a newly cudaMalloc'd
// device buffer (`out`), so the pointer array itself lives in device memory as
// the cuBLAS pointer-array batched routines require.
// NOTE(review): the template header, the declaration of `out` and the return
// statement are not visible in this excerpt. Presumably `out` is returned and
// the caller must cudaFree it. Confirm against the full file.
174T** allocate_(
T** in,
const int& batch_size)
// Allocate room for batch_size pointers on the device.
177 cudaErrcheck(cudaMalloc(
reinterpret_cast<void **
>(&out),
sizeof(
T*) * batch_size));
// Copy the pointer values (not the data they point to) host -> device.
178 cudaErrcheck(cudaMemcpy(out, in,
sizeof(
T*) * batch_size, cudaMemcpyHostToDevice));
// Pointer-array batched GEMM, single precision, via cublasSgemmBatched:
//   for each i in [0, batch_size): C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]
// A, B and C are host arrays of matrix pointers. allocate_ copies each pointer
// array into a device buffer (d_A, d_B, d_C) before the call.
// NOTE(review): lines after the cublasSgemmBatched call are not visible in
// this excerpt. Confirm d_A, d_B and d_C are cudaFree'd so the device pointer
// arrays are not leaked on every call.
183void gemm_batched(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
184 const float& alpha,
float** A,
const int& lda,
float** B,
const int& ldb,
185 const float& beta,
float** C,
const int& ldc,
const int& batch_size)
187 float** d_A = allocate_(A, batch_size);
188 float** d_B = allocate_(B, batch_size);
189 float** d_C = allocate_(C, batch_size);
190 cublasErrcheck(cublasSgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
191 m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc, batch_size));
// Pointer-array batched GEMM, double precision, via cublasDgemmBatched:
//   for each i in [0, batch_size): C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]
// A, B and C are host arrays of matrix pointers. allocate_ copies each pointer
// array into a device buffer (d_A, d_B, d_C) before the call.
// NOTE(review): lines after the cublasDgemmBatched call are not visible in
// this excerpt. Confirm d_A, d_B and d_C are cudaFree'd afterwards.
197void gemm_batched(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
198 const double& alpha,
double** A,
const int& lda,
double** B,
const int& ldb,
199 const double& beta,
double** C,
const int& ldc,
const int& batch_size)
201 double** d_A = allocate_(A, batch_size);
202 double** d_B = allocate_(B, batch_size);
203 double** d_C = allocate_(C, batch_size);
204 cublasErrcheck(cublasDgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
205 m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc, batch_size));
// Pointer-array batched GEMM, single-precision complex, via cublasCgemmBatched.
// A, B and C are host arrays of matrix pointers. allocate_ copies each pointer
// array into a device buffer (d_A, d_B, d_C), which is then reinterpreted as
// cuComplex** (assumes std::complex<float> and cuComplex share a layout).
// NOTE(review): two parts are not visible in this excerpt: the argument line
// after the transb operation (expected: m, n, k) and any cleanup after the
// call. Confirm both, in particular that d_A/d_B/d_C are cudaFree'd.
211void gemm_batched(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
212 const std::complex<float>& alpha, std::complex<float>** A,
const int& lda, std::complex<float>** B,
const int& ldb,
213 const std::complex<float>& beta, std::complex<float>** C,
const int& ldc,
const int& batch_size)
215 std::complex<float>** d_A = allocate_(A, batch_size);
216 std::complex<float>** d_B = allocate_(B, batch_size);
217 std::complex<float>** d_C = allocate_(C, batch_size);
218 cublasErrcheck(cublasCgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
220 reinterpret_cast<const cuComplex*
>(&alpha),
221 reinterpret_cast<cuComplex**
>(d_A), lda,
222 reinterpret_cast<cuComplex**
>(d_B), ldb,
223 reinterpret_cast<const cuComplex*
>(&beta),
224 reinterpret_cast<cuComplex**
>(d_C), ldc, batch_size));
// Pointer-array batched GEMM, double-precision complex, via cublasZgemmBatched.
// A, B and C are host arrays of matrix pointers. allocate_ copies each pointer
// array into a device buffer (d_A, d_B, d_C), which is then reinterpreted as
// cuDoubleComplex** (assumes std::complex<double> and cuDoubleComplex share a
// layout).
// NOTE(review): two parts are not visible in this excerpt: the argument line
// after the transb operation (expected: m, n, k) and any cleanup after the
// call. Confirm both, in particular that d_A/d_B/d_C are cudaFree'd.
230void gemm_batched(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
231 const std::complex<double>& alpha, std::complex<double>** A,
const int& lda, std::complex<double>** B,
const int& ldb,
232 const std::complex<double>& beta, std::complex<double>** C,
const int& ldc,
const int& batch_size)
234 std::complex<double>** d_A = allocate_(A, batch_size);
235 std::complex<double>** d_B = allocate_(B, batch_size);
236 std::complex<double>** d_C = allocate_(C, batch_size);
237 cublasErrcheck(cublasZgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
239 reinterpret_cast<const cuDoubleComplex*
>(&alpha),
240 reinterpret_cast<cuDoubleComplex**
>(d_A), lda,
241 reinterpret_cast<cuDoubleComplex**
>(d_B), ldb,
242 reinterpret_cast<const cuDoubleComplex*
>(&beta),
243 reinterpret_cast<cuDoubleComplex**
>(d_C), ldc, batch_size));
// Strided batched GEMM, single precision: batch_size independent products.
// Consecutive A/B/C matrices are presumably stride_a/stride_b/stride_c
// elements apart.
// NOTE(review): only the signature and the transa/transb conversions are
// visible in this excerpt. Presumably the call is cublasSgemmStridedBatched.
// Confirm against the full file.
250void gemm_batched_strided(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
251 const float& alpha,
const float* A,
const int& lda,
const int& stride_a,
const float* B,
const int& ldb,
const int& stride_b,
252 const float& beta,
float* C,
const int& ldc,
const int& stride_c,
const int& batch_size)
256 GetCublasOperation(transa),
257 GetCublasOperation(transb),
// Strided batched GEMM, double precision: batch_size independent products.
// Consecutive A/B/C matrices are presumably stride_a/stride_b/stride_c
// elements apart.
// NOTE(review): only the signature and the transa/transb conversions are
// visible in this excerpt. Presumably the call is cublasDgemmStridedBatched.
// Confirm against the full file.
267void gemm_batched_strided(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
268 const double& alpha,
const double* A,
const int& lda,
const int& stride_a,
const double* B,
const int& ldb,
const int& stride_b,
269 const double& beta,
double* C,
const int& ldc,
const int& stride_c,
const int& batch_size)
273 GetCublasOperation(transa),
274 GetCublasOperation(transb),
// Strided batched GEMM, single-precision complex. Each matrix operand is
// passed with its leading dimension and batch stride (lda/stride_a,
// ldb/stride_b, ldc/stride_c). Complex values are reinterpreted as cuComplex
// (assumes a layout shared with std::complex<float>).
// NOTE(review): several parts are not visible in this excerpt: the opening of
// the cuBLAS call (presumably cublasCgemmStridedBatched), the m/n/k line and
// the trailing batch_size argument. Confirm against the full file.
284void gemm_batched_strided(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
285 const std::complex<float>& alpha,
const std::complex<float>* A,
const int& lda,
const int& stride_a,
const std::complex<float>* B,
const int& ldb,
const int& stride_b,
286 const std::complex<float>& beta, std::complex<float>* C,
const int& ldc,
const int& stride_c,
const int& batch_size)
290 GetCublasOperation(transa),
291 GetCublasOperation(transb),
293 reinterpret_cast<const cuComplex*
>(&alpha),
294 reinterpret_cast<const cuComplex*
>(A), lda, stride_a,
295 reinterpret_cast<const cuComplex*
>(B), ldb, stride_b,
296 reinterpret_cast<const cuComplex*
>(&beta),
297 reinterpret_cast<cuComplex*
>(C), ldc, stride_c,
// Strided batched GEMM, double-precision complex. Each matrix operand is
// passed with its leading dimension and batch stride (lda/stride_a,
// ldb/stride_b, ldc/stride_c). Complex values are reinterpreted as
// cuDoubleComplex (assumes a layout shared with std::complex<double>).
// NOTE(review): several parts are not visible in this excerpt: the opening of
// the cuBLAS call (presumably cublasZgemmStridedBatched), the m/n/k line and
// the end of the call, which continues past this excerpt. Confirm against the
// full file.
301void gemm_batched_strided(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
302 const std::complex<double>& alpha,
const std::complex<double>* A,
const int& lda,
const int& stride_a,
const std::complex<double>* B,
const int& ldb,
const int& stride_b,
303 const std::complex<double>& beta, std::complex<double>* C,
const int& ldc,
const int& stride_c,
const int& batch_size)
307 GetCublasOperation(transa),
308 GetCublasOperation(transb),
310 reinterpret_cast<const cuDoubleComplex*
>(&alpha),
311 reinterpret_cast<const cuDoubleComplex*
>(A), lda, stride_a,
312 reinterpret_cast<const cuDoubleComplex*
>(B), ldb, stride_b,
313 reinterpret_cast<const cuDoubleComplex*
>(&beta),
314 reinterpret_cast<cuDoubleComplex*
>(C), ldc, stride_c,