9namespace cuBlasConnector {
// copy, float overload (signature only in this excerpt).
// Takes a cuBLAS handle, an element count n, a source pointer x with increment incx
// and a destination pointer y with increment incy.
// NOTE(review): the body is not visible here; presumably it forwards to cublasScopy the
// same way the std::complex overloads forward to cublasCcopy/cublasZcopy — confirm.
12void copy(cublasHandle_t& handle,
const int& n,
const float *x,
const int& incx,
float *y,
const int& incy)
// copy, double overload (signature only in this excerpt).
// Same parameter layout as the other copy overloads: handle, n, x/incx, y/incy.
// NOTE(review): the body is not visible here; presumably it forwards to cublasDcopy — confirm.
17void copy(cublasHandle_t& handle,
const int& n,
const double *x,
const int& incx,
double *y,
const int& incy)
22void copy(cublasHandle_t& handle,
const int& n,
const std::complex<float> *x,
const int& incx, std::complex<float> *y,
const int& incy)
24 cublasErrcheck(cublasCcopy(handle, n,
reinterpret_cast<const cuComplex*
>(x), incx,
reinterpret_cast<cuComplex*
>(y), incy));
27void copy(cublasHandle_t& handle,
const int& n,
const std::complex<double> *x,
const int& incx, std::complex<double> *y,
const int& incy)
29 cublasErrcheck(cublasZcopy(handle, n,
reinterpret_cast<const cuDoubleComplex*
>(x), incx,
reinterpret_cast<cuDoubleComplex*
>(y), incy));
// nrm2, float overload (signature only in this excerpt).
// Takes a cuBLAS handle, an element count n, an input pointer x with increment incx
// and a float* result.
// NOTE(review): the body is not visible here; presumably it forwards to cublasSnrm2 — confirm.
33void nrm2(cublasHandle_t& handle,
const int& n,
const float *x,
const int& incx,
float* result)
// nrm2, double overload (signature only in this excerpt).
// Same parameter layout as the float overload, with a double* result.
// NOTE(review): the body is not visible here; presumably it forwards to cublasDnrm2 — confirm.
38void nrm2(cublasHandle_t& handle,
const int& n,
const double *x,
const int& incx,
double* result)
43void nrm2(cublasHandle_t& handle,
const int& n,
const std::complex<float> *x,
const int& incx,
float* result)
45 cublasErrcheck(cublasScnrm2(handle, n,
reinterpret_cast<const cuComplex*
>(x), incx, result));
48void nrm2(cublasHandle_t& handle,
const int& n,
const std::complex<double> *x,
const int& incx,
double* result)
50 cublasErrcheck(cublasDznrm2(handle, n,
reinterpret_cast<const cuDoubleComplex*
>(x), incx, result));
// dot, float overload (signature only in this excerpt).
// Takes a cuBLAS handle, an element count n, two input pointers x/y with their
// increments incx/incy, and a float* result.
// NOTE(review): the body is not visible here; presumably it forwards to cublasSdot — confirm.
54void dot(cublasHandle_t& handle,
const int& n,
const float *x,
const int& incx,
const float *y,
const int& incy,
float* result)
// dot, double overload (signature only in this excerpt).
// Same parameter layout as the float overload, with a double* result.
// NOTE(review): the body is not visible here; presumably it forwards to cublasDdot — confirm.
59void dot(cublasHandle_t& handle,
const int& n,
const double *x,
const int& incx,
const double *y,
const int& incy,
double* result)
64void dot(cublasHandle_t& handle,
const int& n,
const std::complex<float> *x,
const int& incx,
const std::complex<float> *y,
const int& incy, std::complex<float>* result)
66 cublasErrcheck(cublasCdotc(handle, n,
reinterpret_cast<const cuComplex*
>(x), incx,
reinterpret_cast<const cuComplex*
>(y), incy,
reinterpret_cast<cuComplex*
>(result)));
69void dot(cublasHandle_t& handle,
const int& n,
const std::complex<double> *x,
const int& incx,
const std::complex<double> *y,
const int& incy, std::complex<double>* result)
71 cublasErrcheck(cublasZdotc(handle, n,
reinterpret_cast<const cuDoubleComplex*
>(x), incx,
reinterpret_cast<const cuDoubleComplex*
>(y), incy,
reinterpret_cast<cuDoubleComplex*
>(result)));
// axpy, float overload (signature only in this excerpt).
// Takes a cuBLAS handle, an element count n, a scalar alpha, an input pointer x with
// increment incx and an in/out pointer y with increment incy.
// NOTE(review): the body is not visible here; presumably it forwards to cublasSaxpy — confirm.
75void axpy(cublasHandle_t& handle,
const int& n,
const float& alpha,
const float *x,
const int& incx,
float *y,
const int& incy)
// axpy, double overload (signature only in this excerpt).
// Same parameter layout as the float overload.
// NOTE(review): the body is not visible here; presumably it forwards to cublasDaxpy — confirm.
80void axpy(cublasHandle_t& handle,
const int& n,
const double& alpha,
const double *x,
const int& incx,
double *y,
const int& incy)
85void axpy(cublasHandle_t& handle,
const int& n,
const std::complex<float>& alpha,
const std::complex<float> *x,
const int& incx, std::complex<float> *y,
const int& incy)
87 cublasErrcheck(cublasCaxpy(handle, n,
reinterpret_cast<const cuComplex*
>(&alpha),
reinterpret_cast<const cuComplex*
>(x), incx,
reinterpret_cast<cuComplex*
>(y), incy));
90void axpy(cublasHandle_t& handle,
const int& n,
const std::complex<double>& alpha,
const std::complex<double> *x,
const int& incx, std::complex<double> *y,
const int& incy)
92 cublasErrcheck(cublasZaxpy(handle, n,
reinterpret_cast<const cuDoubleComplex*
>(&alpha),
reinterpret_cast<const cuDoubleComplex*
>(x), incx,
reinterpret_cast<cuDoubleComplex*
>(y), incy));
// scal, float overload (signature only in this excerpt).
// Takes a cuBLAS handle, an element count n, a scalar alpha and an in/out pointer x
// with increment incx.
// NOTE(review): the body is not visible here; presumably it forwards to cublasSscal — confirm.
96void scal(cublasHandle_t& handle,
const int& n,
const float& alpha,
float *x,
const int& incx)
// scal, double overload (signature only in this excerpt).
// Same parameter layout as the float overload.
// NOTE(review): the body is not visible here; presumably it forwards to cublasDscal — confirm.
101void scal(cublasHandle_t& handle,
const int& n,
const double& alpha,
double *x,
const int& incx)
106void scal(cublasHandle_t& handle,
const int& n,
const std::complex<float>& alpha, std::complex<float> *x,
const int& incx)
108 cublasErrcheck(cublasCscal(handle, n,
reinterpret_cast<const cuComplex*
>(&alpha),
reinterpret_cast<cuComplex*
>(x), incx));
111void scal(cublasHandle_t& handle,
const int& n,
const std::complex<double>& alpha, std::complex<double> *x,
const int& incx)
113 cublasErrcheck(cublasZscal(handle, n,
reinterpret_cast<const cuDoubleComplex*
>(&alpha),
reinterpret_cast<cuDoubleComplex*
>(x), incx));
117void gemv(cublasHandle_t& handle,
const char& trans,
const int& m,
const int& n,
118 const float& alpha,
const float *A,
const int& lda,
const float *x,
const int& incx,
119 const float& beta,
float *y,
const int& incy)
121 cublasErrcheck(cublasSgemv(handle, GetCublasOperation(trans), m, n, &alpha, A, lda, x, incx, &beta, y, incy));
124void gemv(cublasHandle_t& handle,
const char& trans,
const int& m,
const int& n,
125 const double& alpha,
const double *A,
const int& lda,
const double *x,
const int& incx,
126 const double& beta,
double *y,
const int& incy)
128 cublasErrcheck(cublasDgemv(handle, GetCublasOperation(trans), m, n, &alpha, A, lda, x, incx, &beta, y, incy));
131void gemv(cublasHandle_t& handle,
const char& trans,
const int& m,
const int& n,
132 const std::complex<float>& alpha,
const std::complex<float> *A,
const int& lda,
const std::complex<float> *x,
const int& incx,
133 const std::complex<float>& beta, std::complex<float> *y,
const int& incy)
135 cublasErrcheck(cublasCgemv(handle, GetCublasOperation(trans), m, n,
reinterpret_cast<const cuComplex*
>(&alpha),
136 reinterpret_cast<const cuComplex*
>(A), lda,
reinterpret_cast<const cuComplex*
>(x), incx,
reinterpret_cast<const cuComplex*
>(&beta),
reinterpret_cast<cuComplex*
>(y), incy));
139void gemv(cublasHandle_t& handle,
const char& trans,
const int& m,
const int& n,
140 const std::complex<double>& alpha,
const std::complex<double> *A,
const int& lda,
const std::complex<double> *x,
const int& incx,
141 const std::complex<double>& beta, std::complex<double> *y,
const int& incy)
143 cublasErrcheck(cublasZgemv(handle, GetCublasOperation(trans), m, n,
reinterpret_cast<const cuDoubleComplex*
>(&alpha),
144 reinterpret_cast<const cuDoubleComplex*
>(A), lda,
reinterpret_cast<const cuDoubleComplex*
>(x), incx,
reinterpret_cast<const cuDoubleComplex*
>(&beta),
reinterpret_cast<cuDoubleComplex*
>(y), incy));
149void gemv_batched(cublasHandle_t& handle,
const char& trans,
const int& m,
const int& n,
150 const T& alpha,
T** A,
const int& lda,
T** x,
const int& incx,
151 const T& beta,
T** y,
const int& incy,
const int& batch_size)
153 for (
int ii = 0; ii < batch_size; ++ii) {
155 cuBlasConnector::gemv(handle, trans, m, n, alpha, A[ii], lda, x[ii], incy, beta, y[ii], incy);
161void gemv_batched_strided(cublasHandle_t& handle,
const char& transa,
const int& m,
const int& n,
162 const T& alpha,
const T* A,
const int& lda,
const int& stride_a,
const T* x,
const int& incx,
const int& stride_x,
163 const T& beta,
T* y,
const int& incy,
const int& stride_y,
const int& batch_size)
165 for (
int ii = 0; ii < batch_size; ii++) {
167 cuBlasConnector::gemv(handle, transa, m, n, alpha, A + ii * stride_a, lda, x + ii * stride_x, incx, beta, y + ii * stride_y, incy);
172void gemm(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
173 const float& alpha,
const float* A,
const int& lda,
const float* B,
const int& ldb,
174 const float& beta,
float* C,
const int& ldc)
176 cublasErrcheck(cublasSgemm(handle, GetCublasOperation(transa), GetCublasOperation(transb),
177 m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc));
180void gemm(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
181 const double& alpha,
const double* A,
const int& lda,
const double* B,
const int& ldb,
182 const double& beta,
double* C,
const int& ldc)
184 cublasErrcheck(cublasDgemm(handle, GetCublasOperation(transa), GetCublasOperation(transb),
185 m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc));
188void gemm(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
189 const std::complex<float>& alpha,
const std::complex<float>* A,
const int& lda,
const std::complex<float>* B,
const int& ldb,
190 const std::complex<float>& beta, std::complex<float>* C,
const int& ldc)
192 cublasErrcheck(cublasCgemm(handle, GetCublasOperation(transa), GetCublasOperation(transb),
194 reinterpret_cast<const cuComplex*
>(&alpha),
195 reinterpret_cast<const cuComplex*
>(A), lda,
196 reinterpret_cast<const cuComplex*
>(B), ldb,
197 reinterpret_cast<const cuComplex*
>(&beta),
198 reinterpret_cast<cuComplex*
>(C), ldc));
201void gemm(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
202 const std::complex<double>& alpha,
const std::complex<double>* A,
const int& lda,
const std::complex<double>* B,
const int& ldb,
203 const std::complex<double>& beta, std::complex<double>* C,
const int& ldc)
205 cublasErrcheck(cublasZgemm(handle, GetCublasOperation(transa), GetCublasOperation(transb),
207 reinterpret_cast<const cuDoubleComplex*
>(&alpha),
208 reinterpret_cast<const cuDoubleComplex*
>(A), lda,
209 reinterpret_cast<const cuDoubleComplex*
>(B), ldb,
210 reinterpret_cast<const cuDoubleComplex*
>(&beta),
211 reinterpret_cast<cuDoubleComplex*
>(C), ldc));
// allocate_: builds a device-side copy of a table of batch_size pointers.
// NOTE(review): the template header, the declaration of `out` and the return statement
// are not visible in this excerpt; presumably `out` is a local T** that is returned and
// the caller is responsible for releasing it — confirm.
216T** allocate_(
T** in,
const int& batch_size)
// Reserve sizeof(T*) * batch_size bytes with cudaMalloc; the address is stored in `out`.
219 cudaErrcheck(cudaMalloc(
reinterpret_cast<void **
>(&out),
sizeof(
T*) * batch_size));
// Copy the pointer table from `in` to `out` as a host-to-device transfer.
220 cudaErrcheck(cudaMemcpy(out, in,
sizeof(
T*) * batch_size, cudaMemcpyHostToDevice));
// gemm_batched, float overload.
// Passes the pointer tables A, B and C through allocate_ and hands the resulting
// d_A/d_B/d_C tables to cublasSgemmBatched together with the shared m, n, k, alpha,
// leading dimensions, beta and batch_size.
225void gemm_batched(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
226 const float& alpha,
float** A,
const int& lda,
float** B,
const int& ldb,
227 const float& beta,
float** C,
const int& ldc,
const int& batch_size)
// allocate_ (defined in this file) cudaMallocs a table and fills it with a
// host-to-device cudaMemcpy, so A, B and C are read as host arrays of pointers.
229 float** d_A = allocate_(A, batch_size);
230 float** d_B = allocate_(B, batch_size);
231 float** d_C = allocate_(C, batch_size);
// NOTE(review): no release of d_A/d_B/d_C is visible in this excerpt — confirm the
// tables are freed after the call.
232 cublasErrcheck(cublasSgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
233 m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc, batch_size));
// gemm_batched, double overload.
// Passes the pointer tables A, B and C through allocate_ and hands the resulting
// d_A/d_B/d_C tables to cublasDgemmBatched together with the shared m, n, k, alpha,
// leading dimensions, beta and batch_size.
239void gemm_batched(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
240 const double& alpha,
double** A,
const int& lda,
double** B,
const int& ldb,
241 const double& beta,
double** C,
const int& ldc,
const int& batch_size)
// allocate_ (defined in this file) cudaMallocs a table and fills it with a
// host-to-device cudaMemcpy, so A, B and C are read as host arrays of pointers.
243 double** d_A = allocate_(A, batch_size);
244 double** d_B = allocate_(B, batch_size);
245 double** d_C = allocate_(C, batch_size);
// NOTE(review): no release of d_A/d_B/d_C is visible in this excerpt — confirm the
// tables are freed after the call.
246 cublasErrcheck(cublasDgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
247 m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc, batch_size));
// gemm_batched, std::complex<float> overload.
// Passes the pointer tables A, B and C through allocate_ and hands the resulting
// d_A/d_B/d_C tables, reinterpreted as cuComplex** tables, to cublasCgemmBatched.
253void gemm_batched(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
254 const std::complex<float>& alpha, std::complex<float>** A,
const int& lda, std::complex<float>** B,
const int& ldb,
255 const std::complex<float>& beta, std::complex<float>** C,
const int& ldc,
const int& batch_size)
// allocate_ (defined in this file) cudaMallocs a table and fills it with a
// host-to-device cudaMemcpy, so A, B and C are read as host arrays of pointers.
257 std::complex<float>** d_A = allocate_(A, batch_size);
258 std::complex<float>** d_B = allocate_(B, batch_size);
259 std::complex<float>** d_C = allocate_(C, batch_size);
// NOTE(review): the m, n, k arguments and any release of d_A/d_B/d_C are not visible in
// this excerpt — confirm both against the full file.
260 cublasErrcheck(cublasCgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
262 reinterpret_cast<const cuComplex*
>(&alpha),
263 reinterpret_cast<cuComplex**
>(d_A), lda,
264 reinterpret_cast<cuComplex**
>(d_B), ldb,
265 reinterpret_cast<const cuComplex*
>(&beta),
266 reinterpret_cast<cuComplex**
>(d_C), ldc, batch_size));
// gemm_batched, std::complex<double> overload.
// Passes the pointer tables A, B and C through allocate_ and hands the resulting
// d_A/d_B/d_C tables, reinterpreted as cuDoubleComplex** tables, to cublasZgemmBatched.
272void gemm_batched(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
273 const std::complex<double>& alpha, std::complex<double>** A,
const int& lda, std::complex<double>** B,
const int& ldb,
274 const std::complex<double>& beta, std::complex<double>** C,
const int& ldc,
const int& batch_size)
// allocate_ (defined in this file) cudaMallocs a table and fills it with a
// host-to-device cudaMemcpy, so A, B and C are read as host arrays of pointers.
276 std::complex<double>** d_A = allocate_(A, batch_size);
277 std::complex<double>** d_B = allocate_(B, batch_size);
278 std::complex<double>** d_C = allocate_(C, batch_size);
// NOTE(review): the m, n, k arguments and any release of d_A/d_B/d_C are not visible in
// this excerpt — confirm both against the full file.
279 cublasErrcheck(cublasZgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
281 reinterpret_cast<const cuDoubleComplex*
>(&alpha),
282 reinterpret_cast<cuDoubleComplex**
>(d_A), lda,
283 reinterpret_cast<cuDoubleComplex**
>(d_B), ldb,
284 reinterpret_cast<const cuDoubleComplex*
>(&beta),
285 reinterpret_cast<cuDoubleComplex**
>(d_C), ldc, batch_size));
// gemm_batched_strided, float overload.
// Takes single base pointers A, B and C plus per-operand strides (stride_a, stride_b,
// stride_c) and a batch_size, in addition to the usual gemm arguments.
// NOTE(review): only the two GetCublasOperation arguments of the body are visible in
// this excerpt; presumably they feed cublasSgemmStridedBatched — confirm.
292void gemm_batched_strided(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
293 const float& alpha,
const float* A,
const int& lda,
const int& stride_a,
const float* B,
const int& ldb,
const int& stride_b,
294 const float& beta,
float* C,
const int& ldc,
const int& stride_c,
const int& batch_size)
298 GetCublasOperation(transa),
299 GetCublasOperation(transb),
// gemm_batched_strided, double overload.
// Takes single base pointers A, B and C plus per-operand strides (stride_a, stride_b,
// stride_c) and a batch_size, in addition to the usual gemm arguments.
// NOTE(review): only the two GetCublasOperation arguments of the body are visible in
// this excerpt; presumably they feed cublasDgemmStridedBatched — confirm.
309void gemm_batched_strided(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
310 const double& alpha,
const double* A,
const int& lda,
const int& stride_a,
const double* B,
const int& ldb,
const int& stride_b,
311 const double& beta,
double* C,
const int& ldc,
const int& stride_c,
const int& batch_size)
315 GetCublasOperation(transa),
316 GetCublasOperation(transb),
// gemm_batched_strided, std::complex<float> overload.
// The visible argument list translates transa/transb with GetCublasOperation and passes
// alpha/beta (by address) and A/B/C reinterpreted as cuComplex pointers, each matrix
// followed by its leading dimension and stride.
// NOTE(review): the called routine, the m/n/k arguments and the trailing batch_size are
// not visible in this excerpt; presumably this is cublasCgemmStridedBatched — confirm.
326void gemm_batched_strided(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
327 const std::complex<float>& alpha,
const std::complex<float>* A,
const int& lda,
const int& stride_a,
const std::complex<float>* B,
const int& ldb,
const int& stride_b,
328 const std::complex<float>& beta, std::complex<float>* C,
const int& ldc,
const int& stride_c,
const int& batch_size)
332 GetCublasOperation(transa),
333 GetCublasOperation(transb),
335 reinterpret_cast<const cuComplex*
>(&alpha),
336 reinterpret_cast<const cuComplex*
>(A), lda, stride_a,
337 reinterpret_cast<const cuComplex*
>(B), ldb, stride_b,
338 reinterpret_cast<const cuComplex*
>(&beta),
339 reinterpret_cast<cuComplex*
>(C), ldc, stride_c,
// gemm_batched_strided, std::complex<double> overload.
// The visible argument list translates transa/transb with GetCublasOperation and passes
// alpha/beta (by address) and A/B/C reinterpreted as cuDoubleComplex pointers, each
// matrix followed by its leading dimension and stride.
// NOTE(review): the called routine, the m/n/k arguments and the trailing batch_size are
// not visible in this excerpt; presumably this is cublasZgemmStridedBatched — confirm.
343void gemm_batched_strided(cublasHandle_t& handle,
const char& transa,
const char& transb,
const int& m,
const int& n,
const int& k,
344 const std::complex<double>& alpha,
const std::complex<double>* A,
const int& lda,
const int& stride_a,
const std::complex<double>* B,
const int& ldb,
const int& stride_b,
345 const std::complex<double>& beta, std::complex<double>* C,
const int& ldc,
const int& stride_c,
const int& batch_size)
349 GetCublasOperation(transa),
350 GetCublasOperation(transb),
352 reinterpret_cast<const cuDoubleComplex*
>(&alpha),
353 reinterpret_cast<const cuDoubleComplex*
>(A), lda, stride_a,
354 reinterpret_cast<const cuDoubleComplex*
>(B), ldb, stride_b,
355 reinterpret_cast<const cuDoubleComplex*
>(&beta),
356 reinterpret_cast<cuDoubleComplex*
>(C), ldc, stride_c,