ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
cublas.h
Go to the documentation of this file.
1#ifndef BASE_THIRD_PARTY_CUBLAS_H_
2#define BASE_THIRD_PARTY_CUBLAS_H_
3
4#include <cuda_runtime.h>
5#include <cublas_v2.h>
6#include <base/macros/cuda.h>
7
8namespace container {
9namespace cuBlasConnector {
10
// copy(): y <- x for n elements, reading x with stride incx and writing y
// with stride incy. One overload per cuBLAS precision (S / D / C / Z).
// The complex overloads reinterpret std::complex<T> pointers as
// cuComplex / cuDoubleComplex, relying on both sharing the (real, imag)
// element layout. x and y are handed straight to cuBLAS, so they are
// presumably device pointers (TODO confirm against callers).
// Any cuBLAS status is checked through cublasErrcheck.
static inline void copy(cublasHandle_t& handle, const int& n,
                        const float* x, const int& incx,
                        float* y, const int& incy)
{
    cublasErrcheck(cublasScopy(handle, n,
                               x, incx,
                               y, incy));
}

static inline void copy(cublasHandle_t& handle, const int& n,
                        const double* x, const int& incx,
                        double* y, const int& incy)
{
    cublasErrcheck(cublasDcopy(handle, n,
                               x, incx,
                               y, incy));
}

static inline void copy(cublasHandle_t& handle, const int& n,
                        const std::complex<float>* x, const int& incx,
                        std::complex<float>* y, const int& incy)
{
    const auto* src = reinterpret_cast<const cuComplex*>(x);
    auto* dst = reinterpret_cast<cuComplex*>(y);
    cublasErrcheck(cublasCcopy(handle, n, src, incx, dst, incy));
}

static inline void copy(cublasHandle_t& handle, const int& n,
                        const std::complex<double>* x, const int& incx,
                        std::complex<double>* y, const int& incy)
{
    const auto* src = reinterpret_cast<const cuDoubleComplex*>(x);
    auto* dst = reinterpret_cast<cuDoubleComplex*>(y);
    cublasErrcheck(cublasZcopy(handle, n, src, incx, dst, incy));
}
31
// nrm2(): *result <- Euclidean (2-)norm of n elements of x taken with stride incx.
// The complex overloads call cublasScnrm2 / cublasDznrm2, which produce a REAL
// norm; that is why their `result` is float* / double* rather than complex.
// NOTE(review): whether `result` must point to host or device memory depends on
// the handle's cuBLAS pointer mode, which this file never sets.
static inline void nrm2(cublasHandle_t& handle, const int& n,
                        const float* x, const int& incx, float* result)
{
    cublasErrcheck(cublasSnrm2(handle, n, x, incx,
                               result));
}

static inline void nrm2(cublasHandle_t& handle, const int& n,
                        const double* x, const int& incx, double* result)
{
    cublasErrcheck(cublasDnrm2(handle, n, x, incx,
                               result));
}

static inline void nrm2(cublasHandle_t& handle, const int& n,
                        const std::complex<float>* x, const int& incx, float* result)
{
    const auto* vec = reinterpret_cast<const cuComplex*>(x);
    cublasErrcheck(cublasScnrm2(handle, n, vec, incx, result));
}

static inline void nrm2(cublasHandle_t& handle, const int& n,
                        const std::complex<double>* x, const int& incx, double* result)
{
    const auto* vec = reinterpret_cast<const cuDoubleComplex*>(x);
    cublasErrcheck(cublasDznrm2(handle, n, vec, incx, result));
}
52
// dot(): *result <- inner product of x and y over n elements
// (strides incx / incy).
// The complex overloads call cublasCdotc / cublasZdotc, so they compute the
// CONJUGATED product sum(conj(x[i]) * y[i]), not the unconjugated "dotu" form.
// NOTE(review): host vs device placement of `result` follows the handle's
// pointer mode, which is not configured in this file.
static inline void dot(cublasHandle_t& handle, const int& n,
                       const float* x, const int& incx,
                       const float* y, const int& incy, float* result)
{
    cublasErrcheck(cublasSdot(handle, n,
                              x, incx,
                              y, incy,
                              result));
}

static inline void dot(cublasHandle_t& handle, const int& n,
                       const double* x, const int& incx,
                       const double* y, const int& incy, double* result)
{
    cublasErrcheck(cublasDdot(handle, n,
                              x, incx,
                              y, incy,
                              result));
}

static inline void dot(cublasHandle_t& handle, const int& n,
                       const std::complex<float>* x, const int& incx,
                       const std::complex<float>* y, const int& incy,
                       std::complex<float>* result)
{
    const auto* lhs = reinterpret_cast<const cuComplex*>(x);
    const auto* rhs = reinterpret_cast<const cuComplex*>(y);
    auto* out = reinterpret_cast<cuComplex*>(result);
    cublasErrcheck(cublasCdotc(handle, n, lhs, incx, rhs, incy, out));
}

static inline void dot(cublasHandle_t& handle, const int& n,
                       const std::complex<double>* x, const int& incx,
                       const std::complex<double>* y, const int& incy,
                       std::complex<double>* result)
{
    const auto* lhs = reinterpret_cast<const cuDoubleComplex*>(x);
    const auto* rhs = reinterpret_cast<const cuDoubleComplex*>(y);
    auto* out = reinterpret_cast<cuDoubleComplex*>(result);
    cublasErrcheck(cublasZdotc(handle, n, lhs, incx, rhs, incy, out));
}
73
// axpy(): y <- alpha * x + y over n elements (strides incx / incy).
// alpha is passed to cuBLAS as the address of the host-side reference
// parameter, which assumes the handle is in host pointer mode
// (NOTE(review): pointer mode is not set in this file).
static inline void axpy(cublasHandle_t& handle, const int& n, const float& alpha,
                        const float* x, const int& incx,
                        float* y, const int& incy)
{
    cublasErrcheck(cublasSaxpy(handle, n, &alpha,
                               x, incx,
                               y, incy));
}

static inline void axpy(cublasHandle_t& handle, const int& n, const double& alpha,
                        const double* x, const int& incx,
                        double* y, const int& incy)
{
    cublasErrcheck(cublasDaxpy(handle, n, &alpha,
                               x, incx,
                               y, incy));
}

static inline void axpy(cublasHandle_t& handle, const int& n, const std::complex<float>& alpha,
                        const std::complex<float>* x, const int& incx,
                        std::complex<float>* y, const int& incy)
{
    const auto* scale = reinterpret_cast<const cuComplex*>(&alpha);
    const auto* src = reinterpret_cast<const cuComplex*>(x);
    auto* acc = reinterpret_cast<cuComplex*>(y);
    cublasErrcheck(cublasCaxpy(handle, n, scale, src, incx, acc, incy));
}

static inline void axpy(cublasHandle_t& handle, const int& n, const std::complex<double>& alpha,
                        const std::complex<double>* x, const int& incx,
                        std::complex<double>* y, const int& incy)
{
    const auto* scale = reinterpret_cast<const cuDoubleComplex*>(&alpha);
    const auto* src = reinterpret_cast<const cuDoubleComplex*>(x);
    auto* acc = reinterpret_cast<cuDoubleComplex*>(y);
    cublasErrcheck(cublasZaxpy(handle, n, scale, src, incx, acc, incy));
}
94
// scal(): x <- alpha * x, in place, over n elements taken with stride incx.
// Only same-type scaling is provided (complex alpha for complex x); alpha is
// forwarded by address, assuming host pointer mode on the handle
// (NOTE(review): not configured in this file).
static inline void scal(cublasHandle_t& handle, const int& n, const float& alpha,
                        float* x, const int& incx)
{
    cublasErrcheck(cublasSscal(handle, n, &alpha,
                               x, incx));
}

static inline void scal(cublasHandle_t& handle, const int& n, const double& alpha,
                        double* x, const int& incx)
{
    cublasErrcheck(cublasDscal(handle, n, &alpha,
                               x, incx));
}

static inline void scal(cublasHandle_t& handle, const int& n, const std::complex<float>& alpha,
                        std::complex<float>* x, const int& incx)
{
    const auto* factor = reinterpret_cast<const cuComplex*>(&alpha);
    auto* vec = reinterpret_cast<cuComplex*>(x);
    cublasErrcheck(cublasCscal(handle, n, factor, vec, incx));
}

static inline void scal(cublasHandle_t& handle, const int& n, const std::complex<double>& alpha,
                        std::complex<double>* x, const int& incx)
{
    const auto* factor = reinterpret_cast<const cuDoubleComplex*>(&alpha);
    auto* vec = reinterpret_cast<cuDoubleComplex*>(x);
    cublasErrcheck(cublasZscal(handle, n, factor, vec, incx));
}
115
// gemv(): y <- alpha * op(A) * x + beta * y using cuBLAS (column-major storage).
// op(A) is selected from the `trans` character through GetCublasOperation,
// which is declared outside this file (accepted characters presumably
// 'N' / 'T' / 'C' -- confirm in base/macros/cuda.h).
// A is m x n with leading dimension lda; x and y use strides incx / incy.
// alpha and beta are forwarded by address (host pointer mode assumed).
static inline void gemv(cublasHandle_t& handle, const char& trans, const int& m, const int& n,
                        const float& alpha, const float* A, const int& lda,
                        const float* x, const int& incx,
                        const float& beta, float* y, const int& incy)
{
    const auto op = GetCublasOperation(trans);
    cublasErrcheck(cublasSgemv(handle, op, m, n,
                               &alpha, A, lda,
                               x, incx,
                               &beta, y, incy));
}

static inline void gemv(cublasHandle_t& handle, const char& trans, const int& m, const int& n,
                        const double& alpha, const double* A, const int& lda,
                        const double* x, const int& incx,
                        const double& beta, double* y, const int& incy)
{
    const auto op = GetCublasOperation(trans);
    cublasErrcheck(cublasDgemv(handle, op, m, n,
                               &alpha, A, lda,
                               x, incx,
                               &beta, y, incy));
}

static inline void gemv(cublasHandle_t& handle, const char& trans, const int& m, const int& n,
                        const std::complex<float>& alpha, const std::complex<float>* A, const int& lda,
                        const std::complex<float>* x, const int& incx,
                        const std::complex<float>& beta, std::complex<float>* y, const int& incy)
{
    using cu_t = cuComplex;
    const auto op = GetCublasOperation(trans);
    cublasErrcheck(cublasCgemv(handle, op, m, n,
                               reinterpret_cast<const cu_t*>(&alpha),
                               reinterpret_cast<const cu_t*>(A), lda,
                               reinterpret_cast<const cu_t*>(x), incx,
                               reinterpret_cast<const cu_t*>(&beta),
                               reinterpret_cast<cu_t*>(y), incy));
}

static inline void gemv(cublasHandle_t& handle, const char& trans, const int& m, const int& n,
                        const std::complex<double>& alpha, const std::complex<double>* A, const int& lda,
                        const std::complex<double>* x, const int& incx,
                        const std::complex<double>& beta, std::complex<double>* y, const int& incy)
{
    using cu_t = cuDoubleComplex;
    const auto op = GetCublasOperation(trans);
    cublasErrcheck(cublasZgemv(handle, op, m, n,
                               reinterpret_cast<const cu_t*>(&alpha),
                               reinterpret_cast<const cu_t*>(A), lda,
                               reinterpret_cast<const cu_t*>(x), incx,
                               reinterpret_cast<const cu_t*>(&beta),
                               reinterpret_cast<cu_t*>(y), incy));
}
146
// gemv_batched(): for every ii in [0, batch_size):
//     y[ii] <- alpha * op(A[ii]) * x[ii] + beta * y[ii]
// A, x and y are HOST arrays of per-batch pointers (they are indexed on the
// host below); each entry is forwarded to the single-matrix gemv overload, so
// the pointed-to data presumably lives in device memory (TODO confirm).
// Implemented as a host loop issuing batch_size independent cuBLAS gemv calls.
template <typename T>
static inline
void gemv_batched(cublasHandle_t& handle, const char& trans, const int& m, const int& n,
                  const T& alpha, T** A, const int& lda, T** x, const int& incx,
                  const T& beta, T** y, const int& incy, const int& batch_size)
{
    for (int ii = 0; ii < batch_size; ++ii) {
        // x[ii] must be read with its OWN stride incx (it was previously passed
        // incy, which silently mis-strided x whenever incx != incy).
        cuBlasConnector::gemv(handle, trans, m, n,
                              alpha, A[ii], lda,
                              x[ii], incx,
                              beta, y[ii], incy);
    }
}
158
// gemv_batched_strided(): for every ii in [0, batch_size):
//     y_ii <- alpha * op(A_ii) * x_ii + beta * y_ii
// where A_ii = A + ii * stride_a, x_ii = x + ii * stride_x, y_ii = y + ii * stride_y
// (strides counted in elements of T). Implemented as a host loop of
// batch_size single-matrix gemv calls.
template <typename T>
static inline
void gemv_batched_strided(cublasHandle_t& handle, const char& transa, const int& m, const int& n,
                          const T& alpha, const T* A, const int& lda, const int& stride_a, const T* x, const int& incx, const int& stride_x,
                          const T& beta, T* y, const int& incy, const int& stride_y, const int& batch_size)
{
    for (int ii = 0; ii < batch_size; ++ii) {
        // Widen before multiplying: ii * stride in plain int overflows once the
        // total offset exceeds INT_MAX (cuBLAS itself uses long long strides).
        const long long batch = static_cast<long long>(ii);
        cuBlasConnector::gemv(handle, transa, m, n,
                              alpha, A + batch * stride_a, lda,
                              x + batch * stride_x, incx,
                              beta, y + batch * stride_y, incy);
    }
}
170
// gemm(): C <- alpha * op(A) * op(B) + beta * C using cuBLAS (column-major).
// op(A) is m x k, op(B) is k x n and C is m x n; op() is chosen from the
// transa / transb characters through GetCublasOperation (declared outside
// this file). alpha and beta are forwarded by address (host pointer mode
// assumed -- NOTE(review): not set in this file).
static inline void gemm(cublasHandle_t& handle, const char& transa, const char& transb,
                        const int& m, const int& n, const int& k,
                        const float& alpha, const float* A, const int& lda,
                        const float* B, const int& ldb,
                        const float& beta, float* C, const int& ldc)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasSgemm(handle, op_a, op_b, m, n, k,
                               &alpha, A, lda, B, ldb,
                               &beta, C, ldc));
}

static inline void gemm(cublasHandle_t& handle, const char& transa, const char& transb,
                        const int& m, const int& n, const int& k,
                        const double& alpha, const double* A, const int& lda,
                        const double* B, const int& ldb,
                        const double& beta, double* C, const int& ldc)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasDgemm(handle, op_a, op_b, m, n, k,
                               &alpha, A, lda, B, ldb,
                               &beta, C, ldc));
}

static inline void gemm(cublasHandle_t& handle, const char& transa, const char& transb,
                        const int& m, const int& n, const int& k,
                        const std::complex<float>& alpha, const std::complex<float>* A, const int& lda,
                        const std::complex<float>* B, const int& ldb,
                        const std::complex<float>& beta, std::complex<float>* C, const int& ldc)
{
    using cu_t = cuComplex;
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasCgemm(handle, op_a, op_b, m, n, k,
                               reinterpret_cast<const cu_t*>(&alpha),
                               reinterpret_cast<const cu_t*>(A), lda,
                               reinterpret_cast<const cu_t*>(B), ldb,
                               reinterpret_cast<const cu_t*>(&beta),
                               reinterpret_cast<cu_t*>(C), ldc));
}

static inline void gemm(cublasHandle_t& handle, const char& transa, const char& transb,
                        const int& m, const int& n, const int& k,
                        const std::complex<double>& alpha, const std::complex<double>* A, const int& lda,
                        const std::complex<double>* B, const int& ldb,
                        const std::complex<double>& beta, std::complex<double>* C, const int& ldc)
{
    using cu_t = cuDoubleComplex;
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasZgemm(handle, op_a, op_b, m, n, k,
                               reinterpret_cast<const cu_t*>(&alpha),
                               reinterpret_cast<const cu_t*>(A), lda,
                               reinterpret_cast<const cu_t*>(B), ldb,
                               reinterpret_cast<const cu_t*>(&beta),
                               reinterpret_cast<cu_t*>(C), ldc));
}
213
// allocate_(): copies a host-side table of batch_size pointers (`in`) into a
// newly cudaMalloc'ed device buffer and returns that buffer.
// Ownership passes to the caller, who must release it with cudaFree.
// Both runtime calls are checked through cudaErrcheck.
template <typename T>
static inline
T** allocate_(T** in, const int& batch_size)
{
    const auto table_bytes = sizeof(T*) * batch_size;
    T** device_table = nullptr;
    cudaErrcheck(cudaMalloc(reinterpret_cast<void**>(&device_table), table_bytes));
    cudaErrcheck(cudaMemcpy(device_table, in, table_bytes, cudaMemcpyHostToDevice));
    return device_table;
}
223
// upload_gemm_batched_pointers_(): uploads the three HOST-side pointer tables
// of a batched GEMM into ONE device allocation laid out as
//     [ A[0..batch_size) | B[0..batch_size) | C[0..batch_size) ]
// and returns its base. Using one allocation instead of one per operand cuts
// the synchronous cudaMalloc / cudaFree runtime calls per gemm_batched call
// from three each to one each. The caller must cudaFree the returned pointer.
template <typename T>
static inline
T** upload_gemm_batched_pointers_(T** A, T** B, T** C, const int& batch_size)
{
    const size_t count = static_cast<size_t>(batch_size);
    const size_t table_bytes = sizeof(T*) * count;
    T** d_tables = nullptr;
    cudaErrcheck(cudaMalloc(reinterpret_cast<void**>(&d_tables), 3 * table_bytes));
    cudaErrcheck(cudaMemcpy(d_tables,             A, table_bytes, cudaMemcpyHostToDevice));
    cudaErrcheck(cudaMemcpy(d_tables + count,     B, table_bytes, cudaMemcpyHostToDevice));
    cudaErrcheck(cudaMemcpy(d_tables + 2 * count, C, table_bytes, cudaMemcpyHostToDevice));
    return d_tables;
}

// gemm_batched(): for every i in [0, batch_size):
//     C[i] <- alpha * op(A[i]) * op(B[i]) + beta * C[i]
// via a single cu{S,D,C,Z}gemmBatched call. A, B, C are HOST arrays of
// per-matrix pointers; cuBLAS needs those tables in device memory, so they are
// uploaded first and freed after the call.
// NOTE(review): freeing the table right after the launch relies on cudaFree's
// implicit synchronization (same ordering as before this change) -- confirm
// this still holds if the handle is bound to a non-blocking stream.
static inline
void gemm_batched(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                  const float& alpha, float** A, const int& lda, float** B, const int& ldb,
                  const float& beta, float** C, const int& ldc, const int& batch_size)
{
    const size_t count = static_cast<size_t>(batch_size);
    float** d_tables = upload_gemm_batched_pointers_(A, B, C, batch_size);
    cublasErrcheck(cublasSgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
                                      m, n, k, &alpha,
                                      d_tables, lda,
                                      d_tables + count, ldb,
                                      &beta,
                                      d_tables + 2 * count, ldc, batch_size));
    cudaErrcheck(cudaFree(d_tables));
}

static inline
void gemm_batched(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                  const double& alpha, double** A, const int& lda, double** B, const int& ldb,
                  const double& beta, double** C, const int& ldc, const int& batch_size)
{
    const size_t count = static_cast<size_t>(batch_size);
    double** d_tables = upload_gemm_batched_pointers_(A, B, C, batch_size);
    cublasErrcheck(cublasDgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
                                      m, n, k, &alpha,
                                      d_tables, lda,
                                      d_tables + count, ldb,
                                      &beta,
                                      d_tables + 2 * count, ldc, batch_size));
    cudaErrcheck(cudaFree(d_tables));
}

static inline
void gemm_batched(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                  const std::complex<float>& alpha, std::complex<float>** A, const int& lda, std::complex<float>** B, const int& ldb,
                  const std::complex<float>& beta, std::complex<float>** C, const int& ldc, const int& batch_size)
{
    const size_t count = static_cast<size_t>(batch_size);
    std::complex<float>** d_tables = upload_gemm_batched_pointers_(A, B, C, batch_size);
    // Element pointers are reinterpreted as cuComplex*; the table itself is unchanged.
    cuComplex** d_cu = reinterpret_cast<cuComplex**>(d_tables);
    cublasErrcheck(cublasCgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
                                      m, n, k,
                                      reinterpret_cast<const cuComplex*>(&alpha),
                                      d_cu, lda,
                                      d_cu + count, ldb,
                                      reinterpret_cast<const cuComplex*>(&beta),
                                      d_cu + 2 * count, ldc, batch_size));
    cudaErrcheck(cudaFree(d_tables));
}

static inline
void gemm_batched(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                  const std::complex<double>& alpha, std::complex<double>** A, const int& lda, std::complex<double>** B, const int& ldb,
                  const std::complex<double>& beta, std::complex<double>** C, const int& ldc, const int& batch_size)
{
    const size_t count = static_cast<size_t>(batch_size);
    std::complex<double>** d_tables = upload_gemm_batched_pointers_(A, B, C, batch_size);
    // Element pointers are reinterpreted as cuDoubleComplex*; the table itself is unchanged.
    cuDoubleComplex** d_cu = reinterpret_cast<cuDoubleComplex**>(d_tables);
    cublasErrcheck(cublasZgemmBatched(handle, GetCublasOperation(transa), GetCublasOperation(transb),
                                      m, n, k,
                                      reinterpret_cast<const cuDoubleComplex*>(&alpha),
                                      d_cu, lda,
                                      d_cu + count, ldb,
                                      reinterpret_cast<const cuDoubleComplex*>(&beta),
                                      d_cu + 2 * count, ldc, batch_size));
    cudaErrcheck(cudaFree(d_tables));
}
290
// gemm_batched_strided(): for every i in [0, batch_size):
//     C_i <- alpha * op(A_i) * op(B_i) + beta * C_i
// with A_i = A + i * stride_a, B_i = B + i * stride_b, C_i = C + i * stride_c
// (strides in elements). Issued as ONE cu{S,D,C,Z}gemmStridedBatched call,
// so no pointer tables are needed. alpha and beta are forwarded by address
// (host pointer mode assumed -- NOTE(review): not set in this file).
static inline void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char& transb,
                                        const int& m, const int& n, const int& k,
                                        const float& alpha,
                                        const float* A, const int& lda, const int& stride_a,
                                        const float* B, const int& ldb, const int& stride_b,
                                        const float& beta,
                                        float* C, const int& ldc, const int& stride_c,
                                        const int& batch_size)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasSgemmStridedBatched(handle, op_a, op_b, m, n, k,
                                             &alpha,
                                             A, lda, stride_a,
                                             B, ldb, stride_b,
                                             &beta,
                                             C, ldc, stride_c,
                                             batch_size));
}

static inline void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char& transb,
                                        const int& m, const int& n, const int& k,
                                        const double& alpha,
                                        const double* A, const int& lda, const int& stride_a,
                                        const double* B, const int& ldb, const int& stride_b,
                                        const double& beta,
                                        double* C, const int& ldc, const int& stride_c,
                                        const int& batch_size)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasDgemmStridedBatched(handle, op_a, op_b, m, n, k,
                                             &alpha,
                                             A, lda, stride_a,
                                             B, ldb, stride_b,
                                             &beta,
                                             C, ldc, stride_c,
                                             batch_size));
}

static inline void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char& transb,
                                        const int& m, const int& n, const int& k,
                                        const std::complex<float>& alpha,
                                        const std::complex<float>* A, const int& lda, const int& stride_a,
                                        const std::complex<float>* B, const int& ldb, const int& stride_b,
                                        const std::complex<float>& beta,
                                        std::complex<float>* C, const int& ldc, const int& stride_c,
                                        const int& batch_size)
{
    using cu_t = cuComplex;
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasCgemmStridedBatched(handle, op_a, op_b, m, n, k,
                                             reinterpret_cast<const cu_t*>(&alpha),
                                             reinterpret_cast<const cu_t*>(A), lda, stride_a,
                                             reinterpret_cast<const cu_t*>(B), ldb, stride_b,
                                             reinterpret_cast<const cu_t*>(&beta),
                                             reinterpret_cast<cu_t*>(C), ldc, stride_c,
                                             batch_size));
}

static inline void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char& transb,
                                        const int& m, const int& n, const int& k,
                                        const std::complex<double>& alpha,
                                        const std::complex<double>* A, const int& lda, const int& stride_a,
                                        const std::complex<double>* B, const int& ldb, const int& stride_b,
                                        const std::complex<double>& beta,
                                        std::complex<double>* C, const int& ldc, const int& stride_c,
                                        const int& batch_size)
{
    using cu_t = cuDoubleComplex;
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasZgemmStridedBatched(handle, op_a, op_b, m, n, k,
                                             reinterpret_cast<const cu_t*>(&alpha),
                                             reinterpret_cast<const cu_t*>(A), lda, stride_a,
                                             reinterpret_cast<const cu_t*>(B), ldb, stride_b,
                                             reinterpret_cast<const cu_t*>(&beta),
                                             reinterpret_cast<cu_t*>(C), ldc, stride_c,
                                             batch_size));
}
359
360} // namespace cuBlasConnector
361} // namespace container
362
363#endif // BASE_THIRD_PARTY_CUBLAS_H_
#define cublasErrcheck(res)
Definition cuda.h:230
#define cudaErrcheck(res)
Definition cuda.h:236
#define T
Definition exp.cpp:237
Definition tensor.cpp:8