ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
hipblas.h
Go to the documentation of this file.
1#ifndef BASE_THIRD_PARTY_HIPBLAS_H_
2#define BASE_THIRD_PARTY_HIPBLAS_H_
3
#include <complex>
#include <cstddef>

#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <base/macros/rocm.h>
7
8namespace container {
9namespace hipBlasConnector {
10
11static inline
12void dot(hipblasHandle_t& handle, const int& n, const float *x, const int& incx, const float *y, const int& incy, float* result)
13{
14 hipblasErrcheck(hipblasSdot(handle, n, x, incx, y, incy, result));
15}
16static inline
17void dot(hipblasHandle_t& handle, const int& n, const double *x, const int& incx, const double *y, const int& incy, double* result)
18{
19 hipblasErrcheck(hipblasDdot(handle, n, x, incx, y, incy, result));
20}
21static inline
22void dot(hipblasHandle_t& handle, const int& n, const std::complex<float> *x, const int& incx, const std::complex<float> *y, const int& incy, std::complex<float>* result)
23{
24 hipblasErrcheck(hipblasCdotc(handle, n, reinterpret_cast<const hipblasComplex*>(x), incx, reinterpret_cast<const hipblasComplex*>(y), incy, reinterpret_cast<hipblasComplex*>(result)));
25}
26static inline
27void dot(hipblasHandle_t& handle, const int& n, const std::complex<double> *x, const int& incx, const std::complex<double> *y, const int& incy, std::complex<double>* result)
28{
29 hipblasErrcheck(hipblasZdotc(handle, n, reinterpret_cast<const hipblasDoubleComplex*>(x), incx, reinterpret_cast<const hipblasDoubleComplex*>(y), incy, reinterpret_cast<hipblasDoubleComplex*>(result)));
30}
31
32static inline
33void axpy(hipblasHandle_t& handle, const int& n, const float& alpha, const float *x, const int& incx, float *y, const int& incy)
34{
35 hipblasErrcheck(hipblasSaxpy(handle, n, &alpha, x, incx, y, incy));
36}
37static inline
38void axpy(hipblasHandle_t& handle, const int& n, const double& alpha, const double *x, const int& incx, double *y, const int& incy)
39{
40 hipblasErrcheck(hipblasDaxpy(handle, n, &alpha, x, incx, y, incy));
41}
42static inline
43void axpy(hipblasHandle_t& handle, const int& n, const std::complex<float>& alpha, const std::complex<float> *x, const int& incx, std::complex<float> *y, const int& incy)
44{
45 hipblasErrcheck(hipblasCaxpy(handle, n, reinterpret_cast<const hipblasComplex*>(&alpha), reinterpret_cast<const hipblasComplex*>(x), incx, reinterpret_cast<hipblasComplex*>(y), incy));
46}
47static inline
48void axpy(hipblasHandle_t& handle, const int& n, const std::complex<double>& alpha, const std::complex<double> *x, const int& incx, std::complex<double> *y, const int& incy)
49{
50 hipblasErrcheck(hipblasZaxpy(handle, n, reinterpret_cast<const hipblasDoubleComplex*>(&alpha), reinterpret_cast<const hipblasDoubleComplex*>(x), incx, reinterpret_cast<hipblasDoubleComplex*>(y), incy));
51}
52
53static inline
54void scal(hipblasHandle_t& handle, const int& n, const float& alpha, float *x, const int& incx)
55{
56 hipblasErrcheck(hipblasSscal(handle, n, &alpha, x, incx));
57}
58static inline
59void scal(hipblasHandle_t& handle, const int& n, const double& alpha, double *x, const int& incx)
60{
61 hipblasErrcheck(hipblasDscal(handle, n, &alpha, x, incx));
62}
63static inline
64void scal(hipblasHandle_t& handle, const int& n, const std::complex<float>& alpha, std::complex<float> *x, const int& incx)
65{
66 hipblasErrcheck(hipblasCscal(handle, n, reinterpret_cast<const hipblasComplex*>(&alpha), reinterpret_cast<hipblasComplex*>(x), incx));
67}
68static inline
69void scal(hipblasHandle_t& handle, const int& n, const std::complex<double>& alpha, std::complex<double> *x, const int& incx)
70{
71 hipblasErrcheck(hipblasZscal(handle, n, reinterpret_cast<const hipblasDoubleComplex*>(&alpha), reinterpret_cast<hipblasDoubleComplex*>(x), incx));
72}
73
74static inline
75void gemv(hipblasHandle_t& handle, const char& trans, const int& m, const int& n,
76 const float& alpha, const float *A, const int& lda, const float *x, const int& incx,
77 const float& beta, float *y, const int& incy)
78{
79 hipblasErrcheck(hipblasSgemv(handle, GetHipblasOperation(trans), m, n, &alpha, A, lda, x, incx, &beta, y, incy));
80}
81static inline
82void gemv(hipblasHandle_t& handle, const char& trans, const int& m, const int& n,
83 const double& alpha, const double *A, const int& lda, const double *x, const int& incx,
84 const double& beta, double *y, const int& incy)
85{
86 hipblasErrcheck(hipblasDgemv(handle, GetHipblasOperation(trans), m, n, &alpha, A, lda, x, incx, &beta, y, incy));
87}
88static inline
89void gemv(hipblasHandle_t& handle, const char& trans, const int& m, const int& n,
90 const std::complex<float>& alpha, const std::complex<float> *A, const int& lda, const std::complex<float> *x, const int& incx,
91 const std::complex<float>& beta, std::complex<float> *y, const int& incy)
92{
93 hipblasErrcheck(hipblasCgemv(handle, GetHipblasOperation(trans), m, n, reinterpret_cast<const hipblasComplex*>(&alpha),
94 reinterpret_cast<const hipblasComplex*>(A), lda, reinterpret_cast<const hipblasComplex*>(x), incx, reinterpret_cast<const hipblasComplex*>(&beta), reinterpret_cast<hipblasComplex*>(y), incy));
95}
96static inline
97void gemv(hipblasHandle_t& handle, const char& trans, const int& m, const int& n,
98 const std::complex<double>& alpha, const std::complex<double> *A, const int& lda, const std::complex<double> *x, const int& incx,
99 const std::complex<double>& beta, std::complex<double> *y, const int& incy)
100{
101 hipblasErrcheck(hipblasZgemv(handle, GetHipblasOperation(trans), m, n, reinterpret_cast<const hipblasDoubleComplex*>(&alpha),
102 reinterpret_cast<const hipblasDoubleComplex*>(A), lda, reinterpret_cast<const hipblasDoubleComplex*>(x), incx, reinterpret_cast<const hipblasDoubleComplex*>(&beta), reinterpret_cast<hipblasDoubleComplex*>(y), incy));
103}
104
105template <typename T>
106static inline
107void gemv_batched(hipblasHandle_t& handle, const char& trans, const int& m, const int& n,
108 const T& alpha, T** A, const int& lda, T** x, const int& incx,
109 const T& beta, T** y, const int& incy, const int& batch_size)
110{
111 for (int ii = 0; ii < batch_size; ++ii) {
112 // Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
113 hipBlasConnector::gemv(handle, trans, m, n, alpha, A[ii], lda, x[ii], incy, beta, y[ii], incy);
114 }
115}
116
117template <typename T>
118static inline
119void gemv_batched_strided(hipblasHandle_t& handle, const char& transa, const int& m, const int& n,
120 const T& alpha, const T* A, const int& lda, const int& stride_a, const T* x, const int& incx, const int& stride_x,
121 const T& beta, T* y, const int& incy, const int& stride_y, const int& batch_size)
122{
123 for (int ii = 0; ii < batch_size; ii++) {
124 // Call the single GEMV for each pair of matrix A[ii] and vector x[ii]
125 hipBlasConnector::gemv(handle, transa, m, n, alpha, A + ii * stride_a, lda, x + ii * stride_x, incx, beta, y + ii * stride_y, incy);
126 }
127}
128
129static inline
130void gemm(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
131 const float& alpha, const float* A, const int& lda, const float* B, const int& ldb,
132 const float& beta, float* C, const int& ldc)
133{
134 hipblasErrcheck(hipblasSgemm(handle, GetHipblasOperation(transa), GetHipblasOperation(transb),
135 m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc));
136}
137static inline
138void gemm(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
139 const double& alpha, const double* A, const int& lda, const double* B, const int& ldb,
140 const double& beta, double* C, const int& ldc)
141{
142 hipblasErrcheck(hipblasDgemm(handle, GetHipblasOperation(transa), GetHipblasOperation(transb),
143 m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc));
144}
145static inline
146void gemm(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
147 const std::complex<float>& alpha, const std::complex<float>* A, const int& lda, const std::complex<float>* B, const int& ldb,
148 const std::complex<float>& beta, std::complex<float>* C, const int& ldc)
149{
150 hipblasErrcheck(hipblasCgemm(handle, GetHipblasOperation(transa), GetHipblasOperation(transb),
151 m, n, k,
152 reinterpret_cast<const hipblasComplex*>(&alpha),
153 reinterpret_cast<const hipblasComplex*>(A), lda,
154 reinterpret_cast<const hipblasComplex*>(B), ldb,
155 reinterpret_cast<const hipblasComplex*>(&beta),
156 reinterpret_cast<hipblasComplex*>(C), ldc));
157}
158static inline
159void gemm(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
160 const std::complex<double>& alpha, const std::complex<double>* A, const int& lda, const std::complex<double>* B, const int& ldb,
161 const std::complex<double>& beta, std::complex<double>* C, const int& ldc)
162{
163 hipblasErrcheck(hipblasZgemm(handle, GetHipblasOperation(transa), GetHipblasOperation(transb),
164 m, n, k,
165 reinterpret_cast<const hipblasDoubleComplex*>(&alpha),
166 reinterpret_cast<const hipblasDoubleComplex*>(A), lda,
167 reinterpret_cast<const hipblasDoubleComplex*>(B), ldb,
168 reinterpret_cast<const hipblasDoubleComplex*>(&beta),
169 reinterpret_cast<hipblasDoubleComplex*>(C), ldc));
170}
171
172template <typename T>
173static inline
174T** allocate_(T** in, const int& batch_size)
175{
176 T** out = nullptr;
177 hipErrcheck(hipMalloc(reinterpret_cast<void **>(&out), sizeof(T*) * batch_size));
178 hipErrcheck(hipMemcpy(out, in, sizeof(T*) * batch_size, hipMemcpyHostToDevice));
179 return out;
180}
181
182static inline
183void gemm_batched(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
184 const float& alpha, float** A, const int& lda, float** B, const int& ldb,
185 const float& beta, float** C, const int& ldc, const int& batch_size)
186{
187 float** d_A = allocate_(A, batch_size);
188 float** d_B = allocate_(B, batch_size);
189 float** d_C = allocate_(C, batch_size);
190 hipblasErrcheck(hipblasSgemmBatched(handle, GetHipblasOperation(transa), GetHipblasOperation(transb),
191 m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc, batch_size));
192 hipErrcheck(hipFree(d_A));
193 hipErrcheck(hipFree(d_B));
194 hipErrcheck(hipFree(d_C));
195}
196static inline
197void gemm_batched(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
198 const double& alpha, double** A, const int& lda, double** B, const int& ldb,
199 const double& beta, double** C, const int& ldc, const int& batch_size)
200{
201 double** d_A = allocate_(A, batch_size);
202 double** d_B = allocate_(B, batch_size);
203 double** d_C = allocate_(C, batch_size);
204 hipblasErrcheck(hipblasDgemmBatched(handle, GetHipblasOperation(transa), GetHipblasOperation(transb),
205 m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc, batch_size));
206 hipErrcheck(hipFree(d_A));
207 hipErrcheck(hipFree(d_B));
208 hipErrcheck(hipFree(d_C));
209}
210static inline
211void gemm_batched(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
212 const std::complex<float>& alpha, std::complex<float>** A, const int& lda, std::complex<float>** B, const int& ldb,
213 const std::complex<float>& beta, std::complex<float>** C, const int& ldc, const int& batch_size)
214{
215 std::complex<float>** d_A = allocate_(A, batch_size);
216 std::complex<float>** d_B = allocate_(B, batch_size);
217 std::complex<float>** d_C = allocate_(C, batch_size);
218 hipblasErrcheck(hipblasCgemmBatched(handle, GetHipblasOperation(transa), GetHipblasOperation(transb),
219 m, n, k,
220 reinterpret_cast<const hipblasComplex*>(&alpha),
221 reinterpret_cast<hipblasComplex**>(d_A), lda,
222 reinterpret_cast<hipblasComplex**>(d_B), ldb,
223 reinterpret_cast<const hipblasComplex*>(&beta),
224 reinterpret_cast<hipblasComplex**>(d_C), ldc, batch_size));
225 hipErrcheck(hipFree(d_A));
226 hipErrcheck(hipFree(d_B));
227 hipErrcheck(hipFree(d_C));
228}
229static inline
230void gemm_batched(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
231 const std::complex<double>& alpha, std::complex<double>** A, const int& lda, std::complex<double>** B, const int& ldb,
232 const std::complex<double>& beta, std::complex<double>** C, const int& ldc, const int& batch_size)
233{
234 std::complex<double>** d_A = allocate_(A, batch_size);
235 std::complex<double>** d_B = allocate_(B, batch_size);
236 std::complex<double>** d_C = allocate_(C, batch_size);
237 hipblasErrcheck(hipblasZgemmBatched(handle, GetHipblasOperation(transa), GetHipblasOperation(transb),
238 m, n, k,
239 reinterpret_cast<const hipblasDoubleComplex*>(&alpha),
240 reinterpret_cast<hipblasDoubleComplex**>(d_A), lda,
241 reinterpret_cast<hipblasDoubleComplex**>(d_B), ldb,
242 reinterpret_cast<const hipblasDoubleComplex*>(&beta),
243 reinterpret_cast<hipblasDoubleComplex**>(d_C), ldc, batch_size));
244 hipErrcheck(hipFree(d_A));
245 hipErrcheck(hipFree(d_B));
246 hipErrcheck(hipFree(d_C));
247}
248
249static inline
250void gemm_batched_strided(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
251 const float& alpha, const float* A, const int& lda, const int& stride_a, const float* B, const int& ldb, const int& stride_b,
252 const float& beta, float* C, const int& ldc, const int& stride_c, const int& batch_size)
253{
254 hipblasErrcheck(hipblasSgemmStridedBatched(
255 handle,
256 GetHipblasOperation(transa),
257 GetHipblasOperation(transb),
258 m, n, k,
259 &alpha,
260 A, lda, stride_a,
261 B, ldb, stride_b,
262 &beta,
263 C, ldc, stride_c,
264 batch_size));
265}
266static inline
267void gemm_batched_strided(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
268 const double& alpha, const double* A, const int& lda, const int& stride_a, const double* B, const int& ldb, const int& stride_b,
269 const double& beta, double* C, const int& ldc, const int& stride_c, const int& batch_size)
270{
271 hipblasErrcheck(hipblasDgemmStridedBatched(
272 handle,
273 GetHipblasOperation(transa),
274 GetHipblasOperation(transb),
275 m, n, k,
276 &alpha,
277 A, lda, stride_a,
278 B, ldb, stride_b,
279 &beta,
280 C, ldc, stride_c,
281 batch_size));
282}
283static inline
284void gemm_batched_strided(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
285 const std::complex<float>& alpha, const std::complex<float>* A, const int& lda, const int& stride_a, const std::complex<float>* B, const int& ldb, const int& stride_b,
286 const std::complex<float>& beta, std::complex<float>* C, const int& ldc, const int& stride_c, const int& batch_size)
287{
288 hipblasErrcheck(hipblasCgemmStridedBatched(
289 handle,
290 GetHipblasOperation(transa),
291 GetHipblasOperation(transb),
292 m, n, k,
293 reinterpret_cast<const hipblasComplex*>(&alpha),
294 reinterpret_cast<const hipblasComplex*>(A), lda, stride_a,
295 reinterpret_cast<const hipblasComplex*>(B), ldb, stride_b,
296 reinterpret_cast<const hipblasComplex*>(&beta),
297 reinterpret_cast<hipblasComplex*>(C), ldc, stride_c,
298 batch_size));
299}
300static inline
301void gemm_batched_strided(hipblasHandle_t& handle, const char& transa, const char& transb, const int& m, const int& n, const int& k,
302 const std::complex<double>& alpha, const std::complex<double>* A, const int& lda, const int& stride_a, const std::complex<double>* B, const int& ldb, const int& stride_b,
303 const std::complex<double>& beta, std::complex<double>* C, const int& ldc, const int& stride_c, const int& batch_size)
304{
305 hipblasErrcheck(hipblasZgemmStridedBatched(
306 handle,
307 GetHipblasOperation(transa),
308 GetHipblasOperation(transb),
309 m, n, k,
310 reinterpret_cast<const hipblasDoubleComplex*>(&alpha),
311 reinterpret_cast<const hipblasDoubleComplex*>(A), lda, stride_a,
312 reinterpret_cast<const hipblasDoubleComplex*>(B), ldb, stride_b,
313 reinterpret_cast<const hipblasDoubleComplex*>(&beta),
314 reinterpret_cast<hipblasDoubleComplex*>(C), ldc, stride_c,
315 batch_size));
316}
317
318} // namespace hipBlasConnector
319} // namespace container
320
321#endif // BASE_THIRD_PARTY_HIPBLAS_H_
#define T
Definition exp.cpp:237
Definition tensor.cpp:8
#define hipblasErrcheck(res)
Definition rocm.h:227
#define hipErrcheck(res)
Definition rocm.h:233