ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
cublas.h
Go to the documentation of this file.
1#ifndef BASE_THIRD_PARTY_CUBLAS_H_
2#define BASE_THIRD_PARTY_CUBLAS_H_
3
#include <complex>
#include <cstddef>

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <base/macros/cuda.h>
7
8namespace container {
9namespace cuBlasConnector {
10
// dot: inner product of two strided vectors, issued through cuBLAS.
//   handle      cuBLAS handle the call is issued on.
//   n           number of elements taken from each vector.
//   x, incx     first operand and its stride.
//   y, incy     second operand and its stride.
//   result      receives the scalar result.
// Real overloads forward to cublas{S,D}dot. Complex overloads forward to
// cublas{C,Z}dotc, i.e. the FIRST operand is conjugated: sum_i conj(x_i) * y_i.
// NOTE(review): whether `result` must be host or device memory depends on the
// handle's pointer mode, which is not set here — confirm at call sites.
static inline
void dot(cublasHandle_t& handle, const int& n, const float *x, const int& incx, const float *y, const int& incy, float* result)
{
    cublasErrcheck(cublasSdot(handle, n,
                              x, incx,
                              y, incy,
                              result));
}
static inline
void dot(cublasHandle_t& handle, const int& n, const double *x, const int& incx, const double *y, const int& incy, double* result)
{
    cublasErrcheck(cublasDdot(handle, n,
                              x, incx,
                              y, incy,
                              result));
}
static inline
void dot(cublasHandle_t& handle, const int& n, const std::complex<float> *x, const int& incx, const std::complex<float> *y, const int& incy, std::complex<float>* result)
{
    // The casts rely on std::complex<float> and cuComplex sharing a layout.
    const cuComplex* lhs = reinterpret_cast<const cuComplex*>(x);
    const cuComplex* rhs = reinterpret_cast<const cuComplex*>(y);
    cuComplex* out = reinterpret_cast<cuComplex*>(result);
    cublasErrcheck(cublasCdotc(handle, n, lhs, incx, rhs, incy, out));
}
static inline
void dot(cublasHandle_t& handle, const int& n, const std::complex<double> *x, const int& incx, const std::complex<double> *y, const int& incy, std::complex<double>* result)
{
    // The casts rely on std::complex<double> and cuDoubleComplex sharing a layout.
    const cuDoubleComplex* lhs = reinterpret_cast<const cuDoubleComplex*>(x);
    const cuDoubleComplex* rhs = reinterpret_cast<const cuDoubleComplex*>(y);
    cuDoubleComplex* out = reinterpret_cast<cuDoubleComplex*>(result);
    cublasErrcheck(cublasZdotc(handle, n, lhs, incx, rhs, incy, out));
}
31
// axpy: y <- alpha * x + y over n strided elements, via cublas{S,D,C,Z}axpy.
//   x, incx   input vector and its stride.
//   y, incy   vector updated in place and its stride.
// alpha is handed to cuBLAS as the address of the host-side argument
// (assumes the handle is in host pointer mode).
static inline
void axpy(cublasHandle_t& handle, const int& n, const float& alpha, const float *x, const int& incx, float *y, const int& incy)
{
    const float* const scale = &alpha;
    cublasErrcheck(cublasSaxpy(handle, n, scale, x, incx, y, incy));
}
static inline
void axpy(cublasHandle_t& handle, const int& n, const double& alpha, const double *x, const int& incx, double *y, const int& incy)
{
    const double* const scale = &alpha;
    cublasErrcheck(cublasDaxpy(handle, n, scale, x, incx, y, incy));
}
static inline
void axpy(cublasHandle_t& handle, const int& n, const std::complex<float>& alpha, const std::complex<float> *x, const int& incx, std::complex<float> *y, const int& incy)
{
    // std::complex<float> storage is reinterpreted as cuComplex for the C API.
    const cuComplex* scale = reinterpret_cast<const cuComplex*>(&alpha);
    const cuComplex* src = reinterpret_cast<const cuComplex*>(x);
    cuComplex* dst = reinterpret_cast<cuComplex*>(y);
    cublasErrcheck(cublasCaxpy(handle, n, scale, src, incx, dst, incy));
}
static inline
void axpy(cublasHandle_t& handle, const int& n, const std::complex<double>& alpha, const std::complex<double> *x, const int& incx, std::complex<double> *y, const int& incy)
{
    // std::complex<double> storage is reinterpreted as cuDoubleComplex for the C API.
    const cuDoubleComplex* scale = reinterpret_cast<const cuDoubleComplex*>(&alpha);
    const cuDoubleComplex* src = reinterpret_cast<const cuDoubleComplex*>(x);
    cuDoubleComplex* dst = reinterpret_cast<cuDoubleComplex*>(y);
    cublasErrcheck(cublasZaxpy(handle, n, scale, src, incx, dst, incy));
}
52
// scal: x <- alpha * x over n strided elements, via cublas{S,D,C,Z}scal.
//   x, incx   vector scaled in place and its stride.
// alpha is handed to cuBLAS as the address of the host-side argument
// (assumes the handle is in host pointer mode).
static inline
void scal(cublasHandle_t& handle, const int& n, const float& alpha, float *x, const int& incx)
{
    const float* const factor = &alpha;
    cublasErrcheck(cublasSscal(handle, n, factor, x, incx));
}
static inline
void scal(cublasHandle_t& handle, const int& n, const double& alpha, double *x, const int& incx)
{
    const double* const factor = &alpha;
    cublasErrcheck(cublasDscal(handle, n, factor, x, incx));
}
static inline
void scal(cublasHandle_t& handle, const int& n, const std::complex<float>& alpha, std::complex<float> *x, const int& incx)
{
    const cuComplex* factor = reinterpret_cast<const cuComplex*>(&alpha);
    cuComplex* vec = reinterpret_cast<cuComplex*>(x);
    cublasErrcheck(cublasCscal(handle, n, factor, vec, incx));
}
static inline
void scal(cublasHandle_t& handle, const int& n, const std::complex<double>& alpha, std::complex<double> *x, const int& incx)
{
    const cuDoubleComplex* factor = reinterpret_cast<const cuDoubleComplex*>(&alpha);
    cuDoubleComplex* vec = reinterpret_cast<cuDoubleComplex*>(x);
    cublasErrcheck(cublasZscal(handle, n, factor, vec, incx));
}
73
// gemv: y <- alpha * op(A) * x + beta * y, via cublas{S,D,C,Z}gemv.
//   trans        character selecting op(A); translated by GetCublasOperation.
//   m, n, lda    dimensions and leading dimension of A as cuBLAS expects them
//                (column-major).
//   x, incx      input vector and its stride.
//   y, incy      vector updated in place and its stride.
// alpha/beta are passed by the address of the host-side arguments.
static inline
void gemv(cublasHandle_t& handle, const char& trans, const int& m, const int& n,
          const float& alpha, const float *A, const int& lda, const float *x, const int& incx,
          const float& beta, float *y, const int& incy)
{
    const auto op = GetCublasOperation(trans);
    cublasErrcheck(cublasSgemv(handle, op, m, n,
                               &alpha, A, lda,
                               x, incx,
                               &beta, y, incy));
}
static inline
void gemv(cublasHandle_t& handle, const char& trans, const int& m, const int& n,
          const double& alpha, const double *A, const int& lda, const double *x, const int& incx,
          const double& beta, double *y, const int& incy)
{
    const auto op = GetCublasOperation(trans);
    cublasErrcheck(cublasDgemv(handle, op, m, n,
                               &alpha, A, lda,
                               x, incx,
                               &beta, y, incy));
}
static inline
void gemv(cublasHandle_t& handle, const char& trans, const int& m, const int& n,
          const std::complex<float>& alpha, const std::complex<float> *A, const int& lda, const std::complex<float> *x, const int& incx,
          const std::complex<float>& beta, std::complex<float> *y, const int& incy)
{
    const auto op = GetCublasOperation(trans);
    const cuComplex* a_scale = reinterpret_cast<const cuComplex*>(&alpha);
    const cuComplex* b_scale = reinterpret_cast<const cuComplex*>(&beta);
    const cuComplex* mat = reinterpret_cast<const cuComplex*>(A);
    const cuComplex* vin = reinterpret_cast<const cuComplex*>(x);
    cuComplex* vout = reinterpret_cast<cuComplex*>(y);
    cublasErrcheck(cublasCgemv(handle, op, m, n, a_scale, mat, lda, vin, incx, b_scale, vout, incy));
}
static inline
void gemv(cublasHandle_t& handle, const char& trans, const int& m, const int& n,
          const std::complex<double>& alpha, const std::complex<double> *A, const int& lda, const std::complex<double> *x, const int& incx,
          const std::complex<double>& beta, std::complex<double> *y, const int& incy)
{
    const auto op = GetCublasOperation(trans);
    const cuDoubleComplex* a_scale = reinterpret_cast<const cuDoubleComplex*>(&alpha);
    const cuDoubleComplex* b_scale = reinterpret_cast<const cuDoubleComplex*>(&beta);
    const cuDoubleComplex* mat = reinterpret_cast<const cuDoubleComplex*>(A);
    const cuDoubleComplex* vin = reinterpret_cast<const cuDoubleComplex*>(x);
    cuDoubleComplex* vout = reinterpret_cast<cuDoubleComplex*>(y);
    cublasErrcheck(cublasZgemv(handle, op, m, n, a_scale, mat, lda, vin, incx, b_scale, vout, incy));
}
104
// gemv_batched: batched GEMV implemented as a host-side loop of single gemv
// calls, for ii in [0, batch_size):
//     y[ii] <- alpha * op(A[ii]) * x[ii] + beta * y[ii]
//   A, x, y    arrays of batch_size pointers indexed on the host; each element
//              is forwarded to gemv unchanged (presumably device buffers —
//              confirm at call sites).
//   incx       stride used for every x[ii].
//   incy       stride used for every y[ii].
// A non-positive batch_size issues no cuBLAS calls.
template <typename T>
static inline
void gemv_batched(cublasHandle_t& handle, const char& trans, const int& m, const int& n,
                  const T& alpha, T** A, const int& lda, T** x, const int& incx,
                  const T& beta, T** y, const int& incy, const int& batch_size)
{
    for (int ii = 0; ii < batch_size; ++ii) {
        // x[ii] must be traversed with its own stride incx (not incy).
        cuBlasConnector::gemv(handle, trans, m, n, alpha, A[ii], lda, x[ii], incx, beta, y[ii], incy);
    }
}
116
// gemv_batched_strided: batched GEMV over operands laid out at fixed strides
// from a base pointer, issued as a host-side loop of single gemv calls.
// For ii in [0, batch_size):
//     y + ii*stride_y <- alpha * op(A + ii*stride_a) * (x + ii*stride_x)
//                        + beta * (y + ii*stride_y)
//   stride_a, stride_x, stride_y   element offsets between consecutive batch items.
// Offsets are computed in std::ptrdiff_t: with int arithmetic, ii * stride can
// exceed INT_MAX for large batches/strides and overflow (undefined behavior).
// A non-positive batch_size issues no cuBLAS calls.
template <typename T>
static inline
void gemv_batched_strided(cublasHandle_t& handle, const char& transa, const int& m, const int& n,
                          const T& alpha, const T* A, const int& lda, const int& stride_a, const T* x, const int& incx, const int& stride_x,
                          const T& beta, T* y, const int& incy, const int& stride_y, const int& batch_size)
{
    for (int ii = 0; ii < batch_size; ++ii) {
        const std::ptrdiff_t item = static_cast<std::ptrdiff_t>(ii);
        cuBlasConnector::gemv(handle, transa, m, n, alpha,
                              A + item * stride_a, lda,
                              x + item * stride_x, incx,
                              beta,
                              y + item * stride_y, incy);
    }
}
128
// gemm: C <- alpha * op(A) * op(B) + beta * C, via cublas{S,D,C,Z}gemm.
//   transa, transb   characters selecting op(A)/op(B); translated by
//                    GetCublasOperation.
//   m, n, k          GEMM dimensions in cuBLAS (column-major) convention.
//   lda, ldb, ldc    leading dimensions of A, B and C.
// alpha/beta are passed by the address of the host-side arguments.
static inline
void gemm(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
          const float& alpha, const float* A, const int& lda, const float* B, const int& ldb,
          const float& beta, float* C, const int& ldc)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasSgemm(handle, op_a, op_b, m, n, k,
                               &alpha,
                               A, lda,
                               B, ldb,
                               &beta,
                               C, ldc));
}
static inline
void gemm(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
          const double& alpha, const double* A, const int& lda, const double* B, const int& ldb,
          const double& beta, double* C, const int& ldc)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasDgemm(handle, op_a, op_b, m, n, k,
                               &alpha,
                               A, lda,
                               B, ldb,
                               &beta,
                               C, ldc));
}
static inline
void gemm(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
          const std::complex<float>& alpha, const std::complex<float>* A, const int& lda, const std::complex<float>* B, const int& ldb,
          const std::complex<float>& beta, std::complex<float>* C, const int& ldc)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    const cuComplex* a_scale = reinterpret_cast<const cuComplex*>(&alpha);
    const cuComplex* b_scale = reinterpret_cast<const cuComplex*>(&beta);
    const cuComplex* mat_a = reinterpret_cast<const cuComplex*>(A);
    const cuComplex* mat_b = reinterpret_cast<const cuComplex*>(B);
    cuComplex* mat_c = reinterpret_cast<cuComplex*>(C);
    cublasErrcheck(cublasCgemm(handle, op_a, op_b, m, n, k,
                               a_scale, mat_a, lda, mat_b, ldb,
                               b_scale, mat_c, ldc));
}
static inline
void gemm(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
          const std::complex<double>& alpha, const std::complex<double>* A, const int& lda, const std::complex<double>* B, const int& ldb,
          const std::complex<double>& beta, std::complex<double>* C, const int& ldc)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    const cuDoubleComplex* a_scale = reinterpret_cast<const cuDoubleComplex*>(&alpha);
    const cuDoubleComplex* b_scale = reinterpret_cast<const cuDoubleComplex*>(&beta);
    const cuDoubleComplex* mat_a = reinterpret_cast<const cuDoubleComplex*>(A);
    const cuDoubleComplex* mat_b = reinterpret_cast<const cuDoubleComplex*>(B);
    cuDoubleComplex* mat_c = reinterpret_cast<cuDoubleComplex*>(C);
    cublasErrcheck(cublasZgemm(handle, op_a, op_b, m, n, k,
                               a_scale, mat_a, lda, mat_b, ldb,
                               b_scale, mat_c, ldc));
}
171
// allocate_: copies a host-resident array of batch_size pointers into a newly
// cudaMalloc'd buffer and returns that buffer.
//   in           host array of batch_size pointers (read via cudaMemcpyHostToDevice).
//   batch_size   number of pointers to copy.
// Ownership: the caller must release the returned buffer with cudaFree.
template <typename T>
static inline
T** allocate_(T** in, const int& batch_size)
{
    const auto bytes = sizeof(T*) * batch_size;
    void* buffer = nullptr;
    cudaErrcheck(cudaMalloc(&buffer, bytes));
    cudaErrcheck(cudaMemcpy(buffer, in, bytes, cudaMemcpyHostToDevice));
    return static_cast<T**>(buffer);
}
181
// gemm_batched: for each batch item i, C[i] <- alpha * op(A[i]) * op(B[i]) + beta * C[i],
// via cublas{S,D,C,Z}gemmBatched.
//   A, B, C      host arrays of batch_size matrix pointers. They are staged into
//                temporary device buffers with allocate_, because cuBLAS reads
//                the pointer arrays from device memory.
//   transa/transb, m, n, k, lda/ldb/ldc   as for gemm.
// NOTE(review): each call performs three cudaMalloc/cudaMemcpy/cudaFree
// round-trips for the pointer tables; consider caching them on hot paths.
static inline
void gemm_batched(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                  const float& alpha, float** A, const int& lda, float** B, const int& ldb,
                  const float& beta, float** C, const int& ldc, const int& batch_size)
{
    float** a_table = allocate_(A, batch_size);
    float** b_table = allocate_(B, batch_size);
    float** c_table = allocate_(C, batch_size);

    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasSgemmBatched(handle, op_a, op_b, m, n, k,
                                      &alpha,
                                      a_table, lda,
                                      b_table, ldb,
                                      &beta,
                                      c_table, ldc,
                                      batch_size));

    cudaErrcheck(cudaFree(a_table));
    cudaErrcheck(cudaFree(b_table));
    cudaErrcheck(cudaFree(c_table));
}
static inline
void gemm_batched(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                  const double& alpha, double** A, const int& lda, double** B, const int& ldb,
                  const double& beta, double** C, const int& ldc, const int& batch_size)
{
    double** a_table = allocate_(A, batch_size);
    double** b_table = allocate_(B, batch_size);
    double** c_table = allocate_(C, batch_size);

    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasDgemmBatched(handle, op_a, op_b, m, n, k,
                                      &alpha,
                                      a_table, lda,
                                      b_table, ldb,
                                      &beta,
                                      c_table, ldc,
                                      batch_size));

    cudaErrcheck(cudaFree(a_table));
    cudaErrcheck(cudaFree(b_table));
    cudaErrcheck(cudaFree(c_table));
}
static inline
void gemm_batched(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                  const std::complex<float>& alpha, std::complex<float>** A, const int& lda, std::complex<float>** B, const int& ldb,
                  const std::complex<float>& beta, std::complex<float>** C, const int& ldc, const int& batch_size)
{
    std::complex<float>** a_table = allocate_(A, batch_size);
    std::complex<float>** b_table = allocate_(B, batch_size);
    std::complex<float>** c_table = allocate_(C, batch_size);

    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasCgemmBatched(handle, op_a, op_b, m, n, k,
                                      reinterpret_cast<const cuComplex*>(&alpha),
                                      reinterpret_cast<cuComplex**>(a_table), lda,
                                      reinterpret_cast<cuComplex**>(b_table), ldb,
                                      reinterpret_cast<const cuComplex*>(&beta),
                                      reinterpret_cast<cuComplex**>(c_table), ldc,
                                      batch_size));

    cudaErrcheck(cudaFree(a_table));
    cudaErrcheck(cudaFree(b_table));
    cudaErrcheck(cudaFree(c_table));
}
static inline
void gemm_batched(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                  const std::complex<double>& alpha, std::complex<double>** A, const int& lda, std::complex<double>** B, const int& ldb,
                  const std::complex<double>& beta, std::complex<double>** C, const int& ldc, const int& batch_size)
{
    std::complex<double>** a_table = allocate_(A, batch_size);
    std::complex<double>** b_table = allocate_(B, batch_size);
    std::complex<double>** c_table = allocate_(C, batch_size);

    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasZgemmBatched(handle, op_a, op_b, m, n, k,
                                      reinterpret_cast<const cuDoubleComplex*>(&alpha),
                                      reinterpret_cast<cuDoubleComplex**>(a_table), lda,
                                      reinterpret_cast<cuDoubleComplex**>(b_table), ldb,
                                      reinterpret_cast<const cuDoubleComplex*>(&beta),
                                      reinterpret_cast<cuDoubleComplex**>(c_table), ldc,
                                      batch_size));

    cudaErrcheck(cudaFree(a_table));
    cudaErrcheck(cudaFree(b_table));
    cudaErrcheck(cudaFree(c_table));
}
248
// gemm_batched_strided: batched GEMM where batch item i uses A + i*stride_a,
// B + i*stride_b and C + i*stride_c, via cublas{S,D,C,Z}gemmStridedBatched:
//     C_i <- alpha * op(A_i) * op(B_i) + beta * C_i,  i in [0, batch_size)
//   stride_a, stride_b, stride_c   element offsets between consecutive items.
//   transa/transb, m, n, k, lda/ldb/ldc   as for gemm.
// The whole batch is a single cuBLAS call; the int strides are widened to the
// 64-bit stride type of the cuBLAS API on the call.
static inline
void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                          const float& alpha, const float* A, const int& lda, const int& stride_a, const float* B, const int& ldb, const int& stride_b,
                          const float& beta, float* C, const int& ldc, const int& stride_c, const int& batch_size)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasSgemmStridedBatched(handle, op_a, op_b, m, n, k,
                                             &alpha, A, lda, stride_a, B, ldb, stride_b,
                                             &beta, C, ldc, stride_c, batch_size));
}
static inline
void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                          const double& alpha, const double* A, const int& lda, const int& stride_a, const double* B, const int& ldb, const int& stride_b,
                          const double& beta, double* C, const int& ldc, const int& stride_c, const int& batch_size)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    cublasErrcheck(cublasDgemmStridedBatched(handle, op_a, op_b, m, n, k,
                                             &alpha, A, lda, stride_a, B, ldb, stride_b,
                                             &beta, C, ldc, stride_c, batch_size));
}
static inline
void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                          const std::complex<float>& alpha, const std::complex<float>* A, const int& lda, const int& stride_a, const std::complex<float>* B, const int& ldb, const int& stride_b,
                          const std::complex<float>& beta, std::complex<float>* C, const int& ldc, const int& stride_c, const int& batch_size)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    const cuComplex* a_scale = reinterpret_cast<const cuComplex*>(&alpha);
    const cuComplex* b_scale = reinterpret_cast<const cuComplex*>(&beta);
    const cuComplex* mat_a = reinterpret_cast<const cuComplex*>(A);
    const cuComplex* mat_b = reinterpret_cast<const cuComplex*>(B);
    cuComplex* mat_c = reinterpret_cast<cuComplex*>(C);
    cublasErrcheck(cublasCgemmStridedBatched(handle, op_a, op_b, m, n, k,
                                             a_scale, mat_a, lda, stride_a, mat_b, ldb, stride_b,
                                             b_scale, mat_c, ldc, stride_c, batch_size));
}
static inline
void gemm_batched_strided(cublasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
                          const std::complex<double>& alpha, const std::complex<double>* A, const int& lda, const int& stride_a, const std::complex<double>* B, const int& ldb, const int& stride_b,
                          const std::complex<double>& beta, std::complex<double>* C, const int& ldc, const int& stride_c, const int& batch_size)
{
    const auto op_a = GetCublasOperation(transa);
    const auto op_b = GetCublasOperation(transb);
    const cuDoubleComplex* a_scale = reinterpret_cast<const cuDoubleComplex*>(&alpha);
    const cuDoubleComplex* b_scale = reinterpret_cast<const cuDoubleComplex*>(&beta);
    const cuDoubleComplex* mat_a = reinterpret_cast<const cuDoubleComplex*>(A);
    const cuDoubleComplex* mat_b = reinterpret_cast<const cuDoubleComplex*>(B);
    cuDoubleComplex* mat_c = reinterpret_cast<cuDoubleComplex*>(C);
    cublasErrcheck(cublasZgemmStridedBatched(handle, op_a, op_b, m, n, k,
                                             a_scale, mat_a, lda, stride_a, mat_b, ldb, stride_b,
                                             b_scale, mat_c, ldc, stride_c, batch_size));
}
317
318} // namespace cuBlasConnector
319} // namespace container
320
321#endif // BASE_THIRD_PARTY_CUBLAS_H_
#define cublasErrcheck(res)
Definition cuda.h:207
#define cudaErrcheck(res)
Definition cuda.h:213
#define T
Definition exp.cpp:237
Definition tensor.cpp:8