ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
rocm.h
Go to the documentation of this file.
1#ifndef BASE_MACROS_ROCM_H_
2#define BASE_MACROS_ROCM_H_
3
#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <stdexcept>

#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hipsolver/hipsolver.h>
8
9#if defined(__HCC__) || defined(__HIP__)
10#include <thrust/complex.h>
11#endif // defined(__HCC__) || defined(__HIP__)
12
// Thread-block size constant exported to every file that includes this header.
// NOTE(review): presumably used as the default block size for HIP kernel
// launches elsewhere in the project -- confirm against the call sites.
#define THREADS_PER_BLOCK 256
14
#if defined(__HCC__) || defined(__HIP__)
/**
 * @brief Type trait translating a host element type into the type used with thrust.
 *
 * Every type maps to itself, except std::complex<float> and
 * std::complex<double>, which map to the matching thrust::complex type.
 * Only available when compiling with a HIP compiler.
 */
template <typename T>
struct GetTypeThrust { using type = T; };

// Single-precision complex: std::complex<float> -> thrust::complex<float>.
template <>
struct GetTypeThrust<std::complex<float>> { using type = thrust::complex<float>; };

// Double-precision complex: std::complex<double> -> thrust::complex<double>.
template <>
struct GetTypeThrust<std::complex<double>> { using type = thrust::complex<double>; };
#endif // defined(__HCC__) || defined(__HIP__)
34
35static inline hipblasOperation_t GetHipblasOperation(const char& trans)
36{
37 hipblasOperation_t hip_trans = {};
38 if (trans == 'N')
39 {
40 hip_trans = HIPBLAS_OP_N;
41 }
42 else if (trans == 'T')
43 {
44 hip_trans = HIPBLAS_OP_T;
45 }
46 else if (trans == 'C')
47 {
48 hip_trans = HIPBLAS_OP_C;
49 }
50 else
51 {
52 // Handle invalid input or provide a default behavior.
53 hip_trans = HIPBLAS_OP_N;
54 }
55 return hip_trans;
56}
57
58template <typename T>
60{
61 static constexpr hipDataType hip_data_type = HIP_R_32F;
62};
63
64// Specializations of GetTypeRocm for supported types.
65template <>
66struct GetTypeRocm<int>
67{
68 static constexpr hipDataType hip_data_type = HIP_R_32F;
69};
70template <>
71struct GetTypeRocm<float>
72{
73 static constexpr hipDataType hip_data_type = HIP_R_32F;
74};
75template <>
76struct GetTypeRocm<double>
77{
78 static constexpr hipDataType hip_data_type = HIP_R_64F;
79};
80template <>
81struct GetTypeRocm<int64_t>
82{
83 static constexpr hipDataType hip_data_type = HIP_R_64F;
84};
85template <>
86struct GetTypeRocm<std::complex<float>>
87{
88 static constexpr hipDataType hip_data_type = HIP_C_32F;
89};
90template <>
91struct GetTypeRocm<std::complex<double>>
92{
93 static constexpr hipDataType hip_data_type = HIP_C_64F;
94};
95
96static inline hipblasFillMode_t hipblas_fill_mode(const char& uplo)
97{
98 if (uplo == 'U' || uplo == 'u')
99 return HIPBLAS_FILL_MODE_UPPER;
100 else if (uplo == 'L' || uplo == 'l')
101 return HIPBLAS_FILL_MODE_LOWER;
102 else
103 throw std::runtime_error("hipblas_fill_mode: unknown uplo");
104}
105
106static inline hipblasDiagType_t hipblas_diag_type(const char& diag)
107{
108 if (diag == 'U' || diag == 'u')
109 return HIPBLAS_DIAG_UNIT;
110 else if (diag == 'N' || diag == 'n')
111 return HIPBLAS_DIAG_NON_UNIT;
112 else
113 throw std::runtime_error("hipblas_diag_type: unknown diag");
114}
115
116static inline hipsolverEigMode_t hipblas_eig_mode(const char& jobz)
117{
118 if (jobz == 'N' || jobz == 'n')
119 return HIPSOLVER_EIG_MODE_NOVECTOR;
120 else if (jobz == 'V' || jobz == 'v')
121 return HIPSOLVER_EIG_MODE_VECTOR;
122 else
123 throw std::runtime_error("hipblas_eig_mode: unknown diag");
124}
125
126static inline hipsolverEigType_t hipblas_eig_type(const int& itype)
127{
128 if (itype == 1)
129 return HIPSOLVER_EIG_TYPE_1;
130 else if (itype == 2)
131 return HIPSOLVER_EIG_TYPE_2;
132 else
133 throw std::runtime_error("hipblas_eig_mode: unknown diag");
134}
135
136static inline hipsolverFillMode_t hipsolver_fill_mode(const char& uplo)
137{
138 if (uplo == 'U' || uplo == 'u')
139 return HIPSOLVER_FILL_MODE_UPPER;
140 else if (uplo == 'L' || uplo == 'l')
141 return HIPSOLVER_FILL_MODE_LOWER;
142 else
143 throw std::runtime_error("hipsolver_fill_mode: unknown uplo");
144}
145
146// hipSOLVER API errors
147static const char* hipsolverGetErrorEnum(hipsolverStatus_t error)
148{
149 switch (error)
150 {
151 case HIPSOLVER_STATUS_SUCCESS:
152 return "HIPSOLVER_STATUS_SUCCESS";
153 case HIPSOLVER_STATUS_NOT_INITIALIZED:
154 return "HIPSOLVER_STATUS_NOT_INITIALIZED";
155 case HIPSOLVER_STATUS_ALLOC_FAILED:
156 return "HIPSOLVER_STATUS_ALLOC_FAILED";
157 case HIPSOLVER_STATUS_INVALID_VALUE:
158 return "HIPSOLVER_STATUS_INVALID_VALUE";
159 case HIPSOLVER_STATUS_ARCH_MISMATCH:
160 return "HIPSOLVER_STATUS_ARCH_MISMATCH";
161 case HIPSOLVER_STATUS_MAPPING_ERROR:
162 return "HIPSOLVER_STATUS_MAPPING_ERROR";
163 case HIPSOLVER_STATUS_EXECUTION_FAILED:
164 return "HIPSOLVER_STATUS_EXECUTION_FAILED";
165 case HIPSOLVER_STATUS_INTERNAL_ERROR:
166 return "HIPSOLVER_STATUS_INTERNAL_ERROR";
167 case HIPSOLVER_STATUS_NOT_SUPPORTED:
168 return "HIPSOLVER_STATUS_NOT_SUPPORTED ";
169 case HIPSOLVER_STATUS_INVALID_ENUM:
170 return "HIPSOLVER_STATUS_INVALID_ENUM";
171 default:
172 return "Unknown hipsolverStatus_t message";
173 }
174}
175
176inline void hipsolverAssert(hipsolverStatus_t code, const char* file, int line, bool abort = true)
177{
178 if (code != HIPSOLVER_STATUS_SUCCESS)
179 {
180 fprintf(stderr, "hipSOLVER Assert: %s %s %d\n", hipsolverGetErrorEnum(code), file, line);
181 if (abort)
182 exit(code);
183 }
184}
185
186// hipSOLVER API errors
187static const char* hipblasGetErrorEnum(hipblasStatus_t error)
188{
189 switch (error)
190 {
191 case HIPBLAS_STATUS_SUCCESS:
192 return "HIPBLAS_STATUS_SUCCESS";
193 case HIPBLAS_STATUS_NOT_INITIALIZED:
194 return "HIPBLAS_STATUS_NOT_INITIALIZED";
195 case HIPBLAS_STATUS_ALLOC_FAILED:
196 return "HIPBLAS_STATUS_ALLOC_FAILED";
197 case HIPBLAS_STATUS_INVALID_VALUE:
198 return "HIPBLAS_STATUS_INVALID_VALUE";
199 case HIPBLAS_STATUS_ARCH_MISMATCH:
200 return "HIPBLAS_STATUS_ARCH_MISMATCH";
201 case HIPBLAS_STATUS_MAPPING_ERROR:
202 return "HIPBLAS_STATUS_MAPPING_ERROR";
203 case HIPBLAS_STATUS_EXECUTION_FAILED:
204 return "HIPBLAS_STATUS_EXECUTION_FAILED";
205 case HIPBLAS_STATUS_INTERNAL_ERROR:
206 return "HIPBLAS_STATUS_INTERNAL_ERROR";
207 default:
208 return "Unknown";
209 }
210}
211
212inline void hipblasAssert(hipblasStatus_t code, const char* file, int line, bool abort = true)
213{
214 if (code != HIPBLAS_STATUS_SUCCESS)
215 {
216 fprintf(stderr, "Unexpected hipBLAS Error: %s %s %d\n", hipblasGetErrorEnum(code), file, line);
217 if (abort)
218 exit(code);
219 }
220}
221
// Check the status of a hipSOLVER call. On failure, hipsolverAssert (called
// with its default abort=true) prints the status name with the call site's
// file/line and exits.
#define hipsolverErrcheck(status_expr)                      \
    {                                                       \
        hipsolverAssert((status_expr), __FILE__, __LINE__); \
    }

// Check the status of a hipBLAS call. On failure, hipblasAssert (called with
// its default abort=true) prints the status name with the call site's
// file/line and exits.
#define hipblasErrcheck(status_expr)                      \
    {                                                     \
        hipblasAssert((status_expr), __FILE__, __LINE__); \
    }
231
// ROCM API errors

// Check a HIP runtime call. On failure, print the call site together with the
// error name and description to stderr, then exit with the error code.
// The argument is evaluated exactly once and stored in a local. Previously
// `res` was expanded up to three times, so a call such as
// hipErrcheck(hipMalloc(...)) was re-executed while the error message was being
// built, and the printed error could differ from the one that triggered the check.
// The braces-only form is kept (rather than do { } while (0)) so existing call
// sites that omit the trailing semicolon still compile.
#define hipErrcheck(res)                                                                             \
    {                                                                                                \
        const hipError_t hip_errcheck_status_ = (res);                                               \
        if (hip_errcheck_status_ != hipSuccess)                                                      \
        {                                                                                            \
            fprintf(stderr, " Unexpected Device Error %s:%d: %s, %s\n", __FILE__, __LINE__,          \
                    hipGetErrorName(hip_errcheck_status_), hipGetErrorString(hip_errcheck_status_)); \
            exit(hip_errcheck_status_);                                                              \
        }                                                                                            \
    }
242
// Debug-only device check. When __DEBUG is defined, it synchronizes the device
// and routes the returned status through hipErrcheck. In other builds it
// expands to nothing.
#if defined(__DEBUG)
#define hipCheckOnDebug() hipErrcheck(hipDeviceSynchronize())
#else
#define hipCheckOnDebug()
#endif
248
249#endif // BASE_MACROS_ROCM_H_
std::complex< double > complex
Definition diago_cusolver.cpp:13
#define T
Definition exp.cpp:237
void hipblasAssert(hipblasStatus_t code, const char *file, int line, bool abort=true)
Definition rocm.h:212
void hipsolverAssert(hipsolverStatus_t code, const char *file, int line, bool abort=true)
Definition rocm.h:176
file(GLOB ATen_CORE_SRCS "*.cpp") set(ATen_CPU_SRCS $
Definition CMakeLists.txt:1
Definition rocm.h:60
static constexpr hipDataType hip_data_type
Definition rocm.h:61
Definition cuda.h:13
T type
Definition cuda.h:14