ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
cuda.h
Go to the documentation of this file.
1#ifndef BASE_MACROS_CUDA_H_
2#define BASE_MACROS_CUDA_H_
3
#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <stdexcept>

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cusolverDn.h>
#include <thrust/complex.h>
8
// Default number of threads per block used by kernels that include this header.
#define THREADS_PER_BLOCK 256
10
/**
 * Type trait that maps an element type to the type used with Thrust.
 *
 * The primary template is the identity mapping (`type` is `T` itself).
 * `std::complex<float>` and `std::complex<double>` are mapped to
 * `thrust::complex<float>` and `thrust::complex<double>` respectively.
 *
 * Usage: `typename GetTypeThrust<T>::type`.
 */
template <typename T>
struct GetTypeThrust
{
    using type = T;
};

template <>
struct GetTypeThrust<std::complex<float>>
{
    using type = thrust::complex<float>;
};

template <>
struct GetTypeThrust<std::complex<double>>
{
    using type = thrust::complex<double>;
};
28
/**
 * Translate a BLAS-style transpose character into a cuBLAS operation enum.
 *
 * Both upper- and lower-case letters are accepted, matching the other
 * character-to-enum helpers in this header (cublas_fill_mode, ...).
 *
 * @param trans 'N'/'n' -> CUBLAS_OP_N, 'T'/'t' -> CUBLAS_OP_T,
 *              'C'/'c' -> CUBLAS_OP_C.
 * @return The matching cublasOperation_t.
 * @throws std::runtime_error if `trans` is none of the characters above.
 */
static inline cublasOperation_t GetCublasOperation(const char& trans)
{
    switch (trans)
    {
    case 'N':
    case 'n':
        return CUBLAS_OP_N;
    case 'T':
    case 't':
        return CUBLAS_OP_T;
    case 'C':
    case 'c':
        return CUBLAS_OP_C;
    default:
        // Previously an unrecognised character (including lower-case 't'/'c')
        // fell through to a value-initialized enum, silently changing the
        // requested operation. Fail loudly instead.
        throw std::runtime_error("GetCublasOperation: unknown trans");
    }
}
46
/**
 * Type trait that maps a C++ element type to its cudaDataType enum value,
 * exposed as the static member `cuda_data_type`.
 *
 * NOTE(review): the primary template falls back to CUDA_R_32F for every type
 * that has no explicit specialization below, so an unsupported T is silently
 * treated as 32-bit float. Add a specialization when introducing a new type.
 */
template <typename T>
struct GetTypeCuda
{
    static constexpr cudaDataType cuda_data_type = cudaDataType::CUDA_R_32F;
};

// Specializations of GetTypeCuda for the supported element types.
template <>
struct GetTypeCuda<int>
{
    static constexpr cudaDataType cuda_data_type = cudaDataType::CUDA_R_32I;
};

template <>
struct GetTypeCuda<float>
{
    static constexpr cudaDataType cuda_data_type = cudaDataType::CUDA_R_32F;
};

template <>
struct GetTypeCuda<double>
{
    static constexpr cudaDataType cuda_data_type = cudaDataType::CUDA_R_64F;
};

template <>
struct GetTypeCuda<int64_t>
{
    static constexpr cudaDataType cuda_data_type = cudaDataType::CUDA_R_64I;
};

template <>
struct GetTypeCuda<std::complex<float>>
{
    static constexpr cudaDataType cuda_data_type = cudaDataType::CUDA_C_32F;
};

template <>
struct GetTypeCuda<std::complex<double>>
{
    static constexpr cudaDataType cuda_data_type = cudaDataType::CUDA_C_64F;
};
83
/**
 * Translate a LAPACK-style triangle selector into a cuBLAS fill mode.
 *
 * @param uplo 'U'/'u' selects the upper triangle, 'L'/'l' the lower one.
 * @return CUBLAS_FILL_MODE_UPPER or CUBLAS_FILL_MODE_LOWER.
 * @throws std::runtime_error for any other character.
 */
static inline cublasFillMode_t cublas_fill_mode(const char& uplo)
{
    switch (uplo)
    {
    case 'U':
    case 'u':
        return CUBLAS_FILL_MODE_UPPER;
    case 'L':
    case 'l':
        return CUBLAS_FILL_MODE_LOWER;
    default:
        break;
    }
    throw std::runtime_error("cublas_fill_mode: unknown uplo");
}
93
/**
 * Translate a LAPACK-style diagonal selector into a cuBLAS diagonal type.
 *
 * @param diag 'U'/'u' for a unit diagonal, 'N'/'n' for a non-unit diagonal.
 * @return CUBLAS_DIAG_UNIT or CUBLAS_DIAG_NON_UNIT.
 * @throws std::runtime_error for any other character.
 */
static inline cublasDiagType_t cublas_diag_type(const char& diag)
{
    const bool is_unit = (diag == 'U') || (diag == 'u');
    if (is_unit)
    {
        return CUBLAS_DIAG_UNIT;
    }

    const bool is_non_unit = (diag == 'N') || (diag == 'n');
    if (!is_non_unit)
    {
        throw std::runtime_error("cublas_diag_type: unknown diag");
    }
    return CUBLAS_DIAG_NON_UNIT;
}
103
/**
 * Translate a LAPACK-style `jobz` flag into a cuSOLVER eigen-mode enum.
 *
 * @param jobz 'N'/'n' computes eigenvalues only, 'V'/'v' also eigenvectors.
 * @return CUSOLVER_EIG_MODE_NOVECTOR or CUSOLVER_EIG_MODE_VECTOR.
 * @throws std::runtime_error for any other character.
 */
static inline cusolverEigMode_t cublas_eig_mode(const char& jobz)
{
    if (jobz == 'N' || jobz == 'n')
    {
        return CUSOLVER_EIG_MODE_NOVECTOR;
    }
    if (jobz == 'V' || jobz == 'v')
    {
        return CUSOLVER_EIG_MODE_VECTOR;
    }
    // The message previously said "unknown diag"; the argument is jobz.
    throw std::runtime_error("cublas_eig_mode: unknown jobz");
}
113
/**
 * Translate a LAPACK-style `itype` (generalized eigenproblem form, as in
 * ?sygvd / ?hegvd) into a cuSOLVER eigen-type enum.
 *
 * @param itype 1: A*x = lambda*B*x, 2: A*B*x = lambda*x, 3: B*A*x = lambda*x.
 * @return CUSOLVER_EIG_TYPE_1, CUSOLVER_EIG_TYPE_2 or CUSOLVER_EIG_TYPE_3.
 * @throws std::runtime_error for any other value.
 */
static inline cusolverEigType_t cublas_eig_type(const int& itype)
{
    switch (itype)
    {
    case 1:
        return CUSOLVER_EIG_TYPE_1;
    case 2:
        return CUSOLVER_EIG_TYPE_2;
    case 3:
        // Previously rejected although cuSOLVER supports this problem type.
        return CUSOLVER_EIG_TYPE_3;
    default:
        // The message previously named cublas_eig_mode / "diag".
        throw std::runtime_error("cublas_eig_type: unknown itype");
    }
}
123
// cuSOLVER API errors

/**
 * Return the enumerator name of a cuSOLVER status code as a C string.
 *
 * @param error Status returned by a cuSOLVER call.
 * @return A string literal naming `error`, or
 *         "Unknown cusolverStatus_t message" for codes not listed here.
 */
static const char* cusolverGetErrorEnum(cusolverStatus_t error)
{
    switch (error)
    {
    case CUSOLVER_STATUS_SUCCESS:
        return "CUSOLVER_STATUS_SUCCESS";
    case CUSOLVER_STATUS_NOT_INITIALIZED:
        return "CUSOLVER_STATUS_NOT_INITIALIZED";
    case CUSOLVER_STATUS_ALLOC_FAILED:
        return "CUSOLVER_STATUS_ALLOC_FAILED";
    case CUSOLVER_STATUS_INVALID_VALUE:
        return "CUSOLVER_STATUS_INVALID_VALUE";
    case CUSOLVER_STATUS_ARCH_MISMATCH:
        return "CUSOLVER_STATUS_ARCH_MISMATCH";
    case CUSOLVER_STATUS_MAPPING_ERROR:
        return "CUSOLVER_STATUS_MAPPING_ERROR";
    case CUSOLVER_STATUS_EXECUTION_FAILED:
        return "CUSOLVER_STATUS_EXECUTION_FAILED";
    case CUSOLVER_STATUS_INTERNAL_ERROR:
        return "CUSOLVER_STATUS_INTERNAL_ERROR";
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
        return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
    case CUSOLVER_STATUS_NOT_SUPPORTED:
        // Previously returned with a stray trailing space.
        return "CUSOLVER_STATUS_NOT_SUPPORTED";
    case CUSOLVER_STATUS_ZERO_PIVOT:
        return "CUSOLVER_STATUS_ZERO_PIVOT";
    case CUSOLVER_STATUS_INVALID_LICENSE:
        return "CUSOLVER_STATUS_INVALID_LICENSE";
    default:
        return "Unknown cusolverStatus_t message";
    }
}
157
/**
 * Abort the process if a cuSOLVER call did not succeed.
 *
 * On failure, prints the status name together with the caller's file and
 * line to stderr, then calls exit() with the status code.
 *
 * @param code Status returned by a cuSOLVER call.
 * @param file Source file of the call site (normally __FILE__).
 * @param line Source line of the call site (normally __LINE__).
 */
inline void cusolverAssert(cusolverStatus_t code, const char* file, int line)
{
    if (code == CUSOLVER_STATUS_SUCCESS)
    {
        return;
    }
    const char* const status_name = cusolverGetErrorEnum(code);
    fprintf(stderr, " Unexpected cuSOLVER Error: %s %s %d\n", status_name, file, line);
    exit(code);
}
166
// cuBLAS API errors

/**
 * Return the enumerator name of a cuBLAS status code as a C string.
 *
 * @param error Status returned by a cuBLAS call.
 * @return A string literal naming `error`, or "Unknown" for codes not
 *         listed here.
 */
static const char* cublasGetErrorEnum(cublasStatus_t error)
{
    switch (error)
    {
    case CUBLAS_STATUS_SUCCESS:
        return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:
        return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
        return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
        return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
        return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:
        return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
        return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
        return "CUBLAS_STATUS_INTERNAL_ERROR";
    // The two statuses below were previously reported as "Unknown".
    case CUBLAS_STATUS_NOT_SUPPORTED:
        return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR:
        return "CUBLAS_STATUS_LICENSE_ERROR";
    default:
        return "Unknown";
    }
}
192
/**
 * Abort the process if a cuBLAS call did not succeed.
 *
 * On failure, prints the status name together with the caller's file and
 * line to stderr, then calls exit() with the status code.
 *
 * @param res  Status returned by a cuBLAS call.
 * @param file Source file of the call site (normally __FILE__).
 * @param line Source line of the call site (normally __LINE__).
 */
inline void cublasAssert(cublasStatus_t res, const char* file, int line)
{
    if (res == CUBLAS_STATUS_SUCCESS)
    {
        return;
    }
    const char* const status_name = cublasGetErrorEnum(res);
    fprintf(stderr, " Unexpected cuBLAS Error: %s %s %d\n", status_name, file, line);
    exit(res);
}
201
// cusolverErrcheck(res): forward the status `res` of a cuSOLVER call, together
// with the call site's file and line, to cusolverAssert.
// Wrapped in do { ... } while (0) so `cusolverErrcheck(x);` is a single
// statement and stays correct inside an unbraced if/else (the previous bare
// `{ ... }` body turned `if (c) cusolverErrcheck(x); else ...` into a syntax error).
#define cusolverErrcheck(res)                      \
    do                                             \
    {                                              \
        cusolverAssert((res), __FILE__, __LINE__); \
    } while (0)
206
// cublasErrcheck(res): forward the status `res` of a cuBLAS call, together
// with the call site's file and line, to cublasAssert.
// Wrapped in do { ... } while (0) so `cublasErrcheck(x);` is a single
// statement and stays correct inside an unbraced if/else (the previous bare
// `{ ... }` body turned `if (c) cublasErrcheck(x); else ...` into a syntax error).
#define cublasErrcheck(res)                      \
    do                                           \
    {                                            \
        cublasAssert((res), __FILE__, __LINE__); \
    } while (0)
211
// CUDA API errors

// cudaErrcheck(res): evaluate the CUDA runtime call `res` exactly once and, if
// it did not return cudaSuccess, print file/line plus the error name and
// description to stderr and exit() with the error code.
// The status is captured in a local first: the previous version expanded
// `res` up to four times, so a failing call (e.g. cudaMalloc, cudaMemcpy,
// cudaDeviceSynchronize) was re-executed while reporting its own failure.
// Wrapped in do { ... } while (0) so it is safe inside an unbraced if/else.
#define cudaErrcheck(res)                                                                           \
    do                                                                                              \
    {                                                                                               \
        const cudaError_t cuda_errcheck_status_ = (res);                                            \
        if (cuda_errcheck_status_ != cudaSuccess)                                                   \
        {                                                                                           \
            fprintf(stderr, " Unexpected Device Error %s:%d: %s, %s\n", __FILE__, __LINE__,         \
                    cudaGetErrorName(cuda_errcheck_status_), cudaGetErrorString(cuda_errcheck_status_)); \
            exit(cuda_errcheck_status_);                                                            \
        }                                                                                           \
    } while (0)
222
// cudaCheckOnDebug(): when __DEBUG is defined, wait for all outstanding device
// work via cudaDeviceSynchronize() and pass the result to cudaErrcheck (which
// reports and exits on failure). Without __DEBUG it expands to nothing.
#if defined(__DEBUG)
#define cudaCheckOnDebug() cudaErrcheck(cudaDeviceSynchronize())
#else
#define cudaCheckOnDebug()
#endif
228
229#endif // BASE_MACROS_CUDA_H_
void cublasAssert(cublasStatus_t res, const char *file, int line)
Definition cuda.h:193
void cusolverAssert(cusolverStatus_t code, const char *file, int line)
Definition cuda.h:158
std::complex< double > complex
Definition diago_cusolver.cpp:13
#define T
Definition exp.cpp:237
file(GLOB ATen_CORE_SRCS "*.cpp") set(ATen_CPU_SRCS $
Definition CMakeLists.txt:1
Definition cuda.h:49
static constexpr cudaDataType cuda_data_type
Definition cuda.h:50
thrust::complex< double > type
Definition cuda.h:26
thrust::complex< float > type
Definition cuda.h:20
Definition cuda.h:13
T type
Definition cuda.h:14