ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
cuda.h
Go to the documentation of this file.
1#ifndef BASE_MACROS_CUDA_H_
2#define BASE_MACROS_CUDA_H_
3
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cusolverDn.h>
#include <thrust/complex.h>

#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <string>
8
// Fixed thread-block size constant (256 threads).
// NOTE(review): presumably the default block size for kernel launches in code that
// includes this header — usage is not visible here; confirm at call sites.
#define THREADS_PER_BLOCK 256
10
// Type trait mapping a host scalar type to the type to use with Thrust.
// Primary template: any type without an explicit specialization maps to itself
// (GetTypeThrust<T>::type == T).
template <typename T>
struct GetTypeThrust
{
    using type = T;
};
16
// Specialization: std::complex<float> is exposed to Thrust code as thrust::complex<float>.
template <>
struct GetTypeThrust<std::complex<float>> {
    typedef thrust::complex<float> type;
};
22
// Specialization: std::complex<double> is exposed to Thrust code as thrust::complex<double>.
template <>
struct GetTypeThrust<std::complex<double>> {
    typedef thrust::complex<double> type;
};
28
// Converts a BLAS-style transpose flag to the cuBLAS operation enum.
//   'N' / 'n' -> CUBLAS_OP_N (no transpose)
//   'T' / 't' -> CUBLAS_OP_T (transpose)
//   'C' / 'c' -> CUBLAS_OP_C (conjugate transpose)
// Throws std::runtime_error for any other character; the message includes the
// offending flag. This matches the error handling of the other char-flag converters
// (cublas_fill_mode, cublas_diag_type, cublas_eig_range).
static inline cublasOperation_t GetCublasOperation(const char& trans)
{
    switch (trans)
    {
    case 'N':
    case 'n':
        return CUBLAS_OP_N;
    case 'T':
    case 't':
        return CUBLAS_OP_T;
    case 'C':
    case 'c':
        return CUBLAS_OP_C;
    default:
        // Fail loudly: silently falling back to a default operation would run the
        // BLAS call with the wrong transpose and produce wrong numbers.
        throw std::runtime_error("GetCublasOperation: unknown trans '" + std::string(1, trans) + "'");
    }
}
46
// Type trait mapping a C++ scalar type to its cudaDataType enumerator, exposed as
// GetTypeCuda<T>::cuda_data_type.
// Primary template: any type without an explicit specialization falls back to
// CUDA_R_32F.
// NOTE(review): this fallback is silent — an unspecialized T (e.g. thrust::complex<float>)
// would be reported as real single precision. Consider a static_assert if no caller
// relies on the fallback.
template <typename T>
struct GetTypeCuda
{
    static constexpr cudaDataType cuda_data_type = cudaDataType::CUDA_R_32F;
};
// Specializations of GetTypeCuda for the supported scalar types.
// int -> CUDA_R_32I
template <>
struct GetTypeCuda<int> {
    static constexpr cudaDataType cuda_data_type = CUDA_R_32I;
};
// float -> CUDA_R_32F
template <>
struct GetTypeCuda<float> {
    static constexpr cudaDataType cuda_data_type = CUDA_R_32F;
};
// double -> CUDA_R_64F
template <>
struct GetTypeCuda<double> {
    static constexpr cudaDataType cuda_data_type = CUDA_R_64F;
};
// int64_t -> CUDA_R_64I
template <>
struct GetTypeCuda<int64_t> {
    static constexpr cudaDataType cuda_data_type = CUDA_R_64I;
};
// std::complex<float> -> CUDA_C_32F
template <>
struct GetTypeCuda<std::complex<float>> {
    static constexpr cudaDataType cuda_data_type = CUDA_C_32F;
};
// std::complex<double> -> CUDA_C_64F
template <>
struct GetTypeCuda<std::complex<double>> {
    static constexpr cudaDataType cuda_data_type = CUDA_C_64F;
};
83
// Converts a BLAS-style triangle selector to cublasFillMode_t.
//   'U' / 'u' -> CUBLAS_FILL_MODE_UPPER
//   'L' / 'l' -> CUBLAS_FILL_MODE_LOWER
// Any other character throws std::runtime_error.
static inline cublasFillMode_t cublas_fill_mode(const char& uplo)
{
    switch (uplo)
    {
    case 'U':
    case 'u':
        return CUBLAS_FILL_MODE_UPPER;
    case 'L':
    case 'l':
        return CUBLAS_FILL_MODE_LOWER;
    default:
        throw std::runtime_error("cublas_fill_mode: unknown uplo");
    }
}
93
// Converts a BLAS-style diagonal flag to cublasDiagType_t.
//   'U' / 'u' -> CUBLAS_DIAG_UNIT
//   'N' / 'n' -> CUBLAS_DIAG_NON_UNIT
// Any other character throws std::runtime_error.
static inline cublasDiagType_t cublas_diag_type(const char& diag)
{
    switch (diag)
    {
    case 'U':
    case 'u':
        return CUBLAS_DIAG_UNIT;
    case 'N':
    case 'n':
        return CUBLAS_DIAG_NON_UNIT;
    default:
        throw std::runtime_error("cublas_diag_type: unknown diag");
    }
}
103
// Converts a LAPACK-style jobz flag to cusolverEigMode_t.
//   'N' / 'n' -> CUSOLVER_EIG_MODE_NOVECTOR (eigenvalues only)
//   'V' / 'v' -> CUSOLVER_EIG_MODE_VECTOR   (eigenvalues and eigenvectors)
// Any other character throws std::runtime_error naming the parameter (jobz) and
// the offending character.
static inline cusolverEigMode_t cublas_eig_mode(const char& jobz)
{
    if (jobz == 'N' || jobz == 'n')
        return CUSOLVER_EIG_MODE_NOVECTOR;
    else if (jobz == 'V' || jobz == 'v')
        return CUSOLVER_EIG_MODE_VECTOR;
    else
        throw std::runtime_error("cublas_eig_mode: unknown jobz '" + std::string(1, jobz) + "'");
}
113
// Converts a LAPACK/cuSOLVER *sygvd-style problem type to cusolverEigType_t.
//   1 -> CUSOLVER_EIG_TYPE_1  (A*x = lambda*B*x)
//   2 -> CUSOLVER_EIG_TYPE_2  (A*B*x = lambda*x)
//   3 -> CUSOLVER_EIG_TYPE_3  (B*A*x = lambda*x)
// Any other value throws std::runtime_error whose message includes the value.
static inline cusolverEigType_t cublas_eig_type(const int& itype)
{
    switch (itype)
    {
    case 1:
        return CUSOLVER_EIG_TYPE_1;
    case 2:
        return CUSOLVER_EIG_TYPE_2;
    case 3:
        return CUSOLVER_EIG_TYPE_3;
    default:
        throw std::runtime_error("cublas_eig_type: unknown itype " + std::to_string(itype));
    }
}
123
// Converts a LAPACK-style eigenvalue-range selector to cusolverEigRange_t.
//   'A' / 'a' -> CUSOLVER_EIG_RANGE_ALL
//   'V' / 'v' -> CUSOLVER_EIG_RANGE_V
//   'I' / 'i' -> CUSOLVER_EIG_RANGE_I
// Any other character throws std::runtime_error; the message quotes the character.
static inline cusolverEigRange_t cublas_eig_range(const char& range)
{
    switch (range)
    {
    case 'A':
    case 'a':
        return CUSOLVER_EIG_RANGE_ALL;
    case 'V':
    case 'v':
        return CUSOLVER_EIG_RANGE_V;
    case 'I':
    case 'i':
        return CUSOLVER_EIG_RANGE_I;
    default:
        break;
    }
    const std::string offending(1, range);
    throw std::runtime_error("cublas_eig_range: unknown range '" + offending + "'");
}
146
147// cuSOLVER API errors
// Returns the enumerator name of a cuSOLVER status code as a string literal.
// Codes not listed here yield "Unknown cusolverStatus_t message".
// The returned pointer refers to static storage and must not be freed.
static const char* cusolverGetErrorEnum(cusolverStatus_t error)
{
    switch (error)
    {
    case CUSOLVER_STATUS_SUCCESS:
        return "CUSOLVER_STATUS_SUCCESS";
    case CUSOLVER_STATUS_NOT_INITIALIZED:
        return "CUSOLVER_STATUS_NOT_INITIALIZED";
    case CUSOLVER_STATUS_ALLOC_FAILED:
        return "CUSOLVER_STATUS_ALLOC_FAILED";
    case CUSOLVER_STATUS_INVALID_VALUE:
        return "CUSOLVER_STATUS_INVALID_VALUE";
    case CUSOLVER_STATUS_ARCH_MISMATCH:
        return "CUSOLVER_STATUS_ARCH_MISMATCH";
    case CUSOLVER_STATUS_MAPPING_ERROR:
        return "CUSOLVER_STATUS_MAPPING_ERROR";
    case CUSOLVER_STATUS_EXECUTION_FAILED:
        return "CUSOLVER_STATUS_EXECUTION_FAILED";
    case CUSOLVER_STATUS_INTERNAL_ERROR:
        return "CUSOLVER_STATUS_INTERNAL_ERROR";
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
        return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
    case CUSOLVER_STATUS_NOT_SUPPORTED:
        return "CUSOLVER_STATUS_NOT_SUPPORTED";
    case CUSOLVER_STATUS_ZERO_PIVOT:
        return "CUSOLVER_STATUS_ZERO_PIVOT";
    case CUSOLVER_STATUS_INVALID_LICENSE:
        return "CUSOLVER_STATUS_INVALID_LICENSE";
    default:
        return "Unknown cusolverStatus_t message";
    }
}
180
// Aborts the process on a failed cuSOLVER call.
// Does nothing when `code` is CUSOLVER_STATUS_SUCCESS. Otherwise, prints the status
// name plus the caller's file and line to stderr and calls exit() with `code`.
// Normally invoked through the cusolverErrcheck macro, which supplies __FILE__/__LINE__.
inline void cusolverAssert(cusolverStatus_t code, const char* file, int line)
{
    if (code == CUSOLVER_STATUS_SUCCESS)
    {
        return;
    }
    const char* const status_name = cusolverGetErrorEnum(code);
    fprintf(stderr, " Unexpected cuSOLVER Error: %s %s %d\n", status_name, file, line);
    exit(code);
}
189
// cuBLAS API errors
// Returns the enumerator name of a cuBLAS status code as a string literal.
// Codes not listed here yield "Unknown".
// The returned pointer refers to static storage and must not be freed.
static const char* cublasGetErrorEnum(cublasStatus_t error)
{
    switch (error)
    {
    case CUBLAS_STATUS_SUCCESS:
        return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:
        return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
        return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
        return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
        return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:
        return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
        return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
        return "CUBLAS_STATUS_INTERNAL_ERROR";
    // The next two codes are part of cublasStatus_t; without them they were
    // reported as "Unknown", which hides actionable failures.
    case CUBLAS_STATUS_NOT_SUPPORTED:
        return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR:
        return "CUBLAS_STATUS_LICENSE_ERROR";
    default:
        return "Unknown";
    }
}
215
// Aborts the process on a failed cuBLAS call.
// Does nothing when `res` is CUBLAS_STATUS_SUCCESS. Otherwise, prints the status name
// plus the caller's file and line to stderr and calls exit() with `res`.
// Normally invoked through the cublasErrcheck macro, which supplies __FILE__/__LINE__.
inline void cublasAssert(cublasStatus_t res, const char* file, int line)
{
    if (res == CUBLAS_STATUS_SUCCESS)
    {
        return;
    }
    const char* const status_name = cublasGetErrorEnum(res);
    fprintf(stderr, " Unexpected cuBLAS Error: %s %s %d\n", status_name, file, line);
    exit(res);
}
224
// Checks a cuSOLVER status expression via cusolverAssert, passing the call site's
// __FILE__ and __LINE__.
// The do { } while (0) wrapper makes `cusolverErrcheck(x);` a single statement, so it
// is safe as the body of an unbraced if/else. A bare { } block followed by ';' would
// orphan the else.
#define cusolverErrcheck(res)                      \
    do                                             \
    {                                              \
        cusolverAssert((res), __FILE__, __LINE__); \
    } while (0)
229
// Checks a cuBLAS status expression via cublasAssert, passing the call site's
// __FILE__ and __LINE__.
// The do { } while (0) wrapper makes `cublasErrcheck(x);` a single statement, so it is
// safe as the body of an unbraced if/else.
#define cublasErrcheck(res)                      \
    do                                           \
    {                                            \
        cublasAssert((res), __FILE__, __LINE__); \
    } while (0)
234
235// CUDA API errors
// Checks a CUDA runtime status expression (a cudaError_t).
// On failure, prints the call site's file and line, the error name and the error
// description to stderr, then calls exit() with the error code.
// `res` is captured in a local so the checked call (e.g. cudaMalloc or
// cudaDeviceSynchronize) is evaluated exactly once. Expanding `res` at every use would
// re-issue the call on the error path, and the report could then describe a different
// result than the one that failed.
// The do { } while (0) wrapper makes `cudaErrcheck(x);` a single statement.
#define cudaErrcheck(res)                                                                                   \
    do                                                                                                      \
    {                                                                                                       \
        const cudaError_t cuda_errcheck_status_ = (res);                                                    \
        if (cuda_errcheck_status_ != cudaSuccess)                                                           \
        {                                                                                                   \
            fprintf(stderr, " Unexpected Device Error %s:%d: %s, %s\n", __FILE__, __LINE__,                 \
                    cudaGetErrorName(cuda_errcheck_status_), cudaGetErrorString(cuda_errcheck_status_));   \
            exit(cuda_errcheck_status_);                                                                    \
        }                                                                                                   \
    } while (0)
245
// Debug-only device check.
// When __DEBUG is defined, it synchronizes with the device (cudaDeviceSynchronize) and
// routes the result through cudaErrcheck, so asynchronous kernel failures surface at
// this point.
// In other builds it expands to nothing, so `cudaCheckOnDebug();` costs nothing.
#if defined(__DEBUG)
#define cudaCheckOnDebug() cudaErrcheck(cudaDeviceSynchronize())
#else
#define cudaCheckOnDebug()
#endif
251
252#endif // BASE_MACROS_CUDA_H_
void cublasAssert(cublasStatus_t res, const char *file, int line)
Definition cuda.h:216
void cusolverAssert(cusolverStatus_t code, const char *file, int line)
Definition cuda.h:181
std::complex< double > complex
Definition diago_cusolver.cpp:13
#define T
Definition exp.cpp:237
file(GLOB ATen_CORE_SRCS "*.cpp") set(ATen_CPU_SRCS $
Definition CMakeLists.txt:1
Definition cuda.h:49
static constexpr cudaDataType cuda_data_type
Definition cuda.h:50
thrust::complex< double > type
Definition cuda.h:26
thrust::complex< float > type
Definition cuda.h:20
Definition cuda.h:13
T type
Definition cuda.h:14