1#ifndef BASE_MACROS_CUDA_H_
2#define BASE_MACROS_CUDA_H_
#include <stdexcept>
#include <string>

#include <cuda_runtime.h>
#include <thrust/complex.h>
// Thread-count constant (256). The name suggests it is the threads-per-block
// value used when launching kernels.
// NOTE(review): confirm against the launch sites that use it.
9#define THREADS_PER_BLOCK 256
// Member of a type-mapping specialization (the enclosing struct header is not
// visible in this view): exposes thrust::complex<float> as `type`.
// NOTE(review): presumably this is the specialization for std::complex<float> — confirm.
20 using type = thrust::complex<float>;
// Double-precision counterpart: exposes thrust::complex<double> as `type`.
// NOTE(review): presumably the specialization for std::complex<double> — confirm.
26 using type = thrust::complex<double>;
/**
 * @brief Map a BLAS-style transpose flag to the corresponding cuBLAS operation.
 *
 * Both upper- and lower-case flags are accepted, consistent with the other
 * character-to-enum helpers in this header:
 *   'N' / 'n' -> CUBLAS_OP_N  (no transpose)
 *   'T' / 't' -> CUBLAS_OP_T  (transpose)
 *   'C' / 'c' -> CUBLAS_OP_C  (conjugate transpose)
 *
 * @param trans Transpose flag character.
 * @return The matching cublasOperation_t.
 * @throws std::runtime_error if @p trans is not one of the characters above.
 */
static inline cublasOperation_t GetCublasOperation(const char& trans)
{
    switch (trans)
    {
    case 'N':
    case 'n':
        return CUBLAS_OP_N;
    case 'T':
    case 't':
        return CUBLAS_OP_T;
    case 'C':
    case 'c':
        return CUBLAS_OP_C;
    default:
        // Reject unknown flags instead of falling back to a default-constructed
        // operation, which would silently run the wrong BLAS variant.
        throw std::runtime_error("GetCublasOperation: unknown trans '" + std::string(1, trans) + "'");
    }
}
/**
 * @brief Translate a LAPACK-style `uplo` character into a cuBLAS fill mode.
 *
 * @param uplo 'U'/'u' selects CUBLAS_FILL_MODE_UPPER,
 *             'L'/'l' selects CUBLAS_FILL_MODE_LOWER.
 * @return The matching cublasFillMode_t.
 * @throws std::runtime_error for any other character.
 */
static inline cublasFillMode_t cublas_fill_mode(const char& uplo)
{
    switch (uplo)
    {
    case 'U':
    case 'u':
        return CUBLAS_FILL_MODE_UPPER;
    case 'L':
    case 'l':
        return CUBLAS_FILL_MODE_LOWER;
    default:
        throw std::runtime_error("cublas_fill_mode: unknown uplo");
    }
}
/**
 * @brief Translate a LAPACK-style `diag` character into a cuBLAS diagonal type.
 *
 * @param diag 'U'/'u' selects CUBLAS_DIAG_UNIT,
 *             'N'/'n' selects CUBLAS_DIAG_NON_UNIT.
 * @return The matching cublasDiagType_t.
 * @throws std::runtime_error for any other character.
 */
static inline cublasDiagType_t cublas_diag_type(const char& diag)
{
    const bool wants_unit = (diag == 'U') || (diag == 'u');
    if (wants_unit)
    {
        return CUBLAS_DIAG_UNIT;
    }

    const bool wants_non_unit = (diag == 'N') || (diag == 'n');
    if (wants_non_unit)
    {
        return CUBLAS_DIAG_NON_UNIT;
    }

    throw std::runtime_error("cublas_diag_type: unknown diag");
}
/**
 * @brief Map a LAPACK-style `jobz` flag to a cuSOLVER eigen-mode.
 *
 * @param jobz 'N'/'n' -> CUSOLVER_EIG_MODE_NOVECTOR (eigenvalues only),
 *             'V'/'v' -> CUSOLVER_EIG_MODE_VECTOR (eigenvalues and eigenvectors).
 * @return The matching cusolverEigMode_t.
 * @throws std::runtime_error for any other character; the message names the
 *         parameter (jobz) and the offending value.
 */
static inline cusolverEigMode_t cublas_eig_mode(const char& jobz)
{
    if (jobz == 'N' || jobz == 'n')
        return CUSOLVER_EIG_MODE_NOVECTOR;
    else if (jobz == 'V' || jobz == 'v')
        return CUSOLVER_EIG_MODE_VECTOR;
    // Report the real parameter name and value (same format as cublas_eig_range).
    throw std::runtime_error("cublas_eig_mode: unknown jobz '" + std::string(1, jobz) + "'");
}
/**
 * @brief Map a LAPACK generalized-eigenproblem type (`itype`) to cuSOLVER.
 *
 *   1 -> CUSOLVER_EIG_TYPE_1 : A*x   = lambda*B*x
 *   2 -> CUSOLVER_EIG_TYPE_2 : A*B*x = lambda*x
 *   3 -> CUSOLVER_EIG_TYPE_3 : B*A*x = lambda*x
 *
 * @param itype Problem type, 1, 2 or 3.
 * @return The matching cusolverEigType_t.
 * @throws std::runtime_error for any other value; the message names this
 *         function and the offending itype.
 */
static inline cusolverEigType_t cublas_eig_type(const int& itype)
{
    switch (itype)
    {
    case 1:
        return CUSOLVER_EIG_TYPE_1;
    case 2:
        return CUSOLVER_EIG_TYPE_2;
    case 3:
        return CUSOLVER_EIG_TYPE_3;
    default:
        throw std::runtime_error("cublas_eig_type: unknown itype " + std::to_string(itype));
    }
}
/**
 * @brief Translate a LAPACK-style eigenvalue `range` selector into cuSOLVER's enum.
 *
 * @param range 'A'/'a' -> CUSOLVER_EIG_RANGE_ALL,
 *              'V'/'v' -> CUSOLVER_EIG_RANGE_V,
 *              'I'/'i' -> CUSOLVER_EIG_RANGE_I.
 * @return The matching cusolverEigRange_t.
 * @throws std::runtime_error for any other character; the message quotes it.
 */
static inline cusolverEigRange_t cublas_eig_range(const char& range)
{
    switch (range)
    {
    case 'A':
    case 'a':
        return CUSOLVER_EIG_RANGE_ALL;
    case 'V':
    case 'v':
        return CUSOLVER_EIG_RANGE_V;
    case 'I':
    case 'i':
        return CUSOLVER_EIG_RANGE_I;
    default:
        break;
    }

    const std::string offending(1, range);
    throw std::runtime_error("cublas_eig_range: unknown range '" + offending + "'");
}
/**
 * @brief Return the enumerator name of a cuSOLVER status code.
 *
 * Declared `static inline` (like the other helpers in this header) so that
 * translation units which include the header without using it do not get
 * unused-function warnings.
 *
 * @param error Status returned by a cuSOLVER call.
 * @return A string literal naming @p error, or
 *         "Unknown cusolverStatus_t message" for codes not listed here.
 */
static inline const char* cusolverGetErrorEnum(cusolverStatus_t error)
{
    switch (error)
    {
    case CUSOLVER_STATUS_SUCCESS:
        return "CUSOLVER_STATUS_SUCCESS";
    case CUSOLVER_STATUS_NOT_INITIALIZED:
        return "CUSOLVER_STATUS_NOT_INITIALIZED";
    case CUSOLVER_STATUS_ALLOC_FAILED:
        return "CUSOLVER_STATUS_ALLOC_FAILED";
    case CUSOLVER_STATUS_INVALID_VALUE:
        return "CUSOLVER_STATUS_INVALID_VALUE";
    case CUSOLVER_STATUS_ARCH_MISMATCH:
        return "CUSOLVER_STATUS_ARCH_MISMATCH";
    case CUSOLVER_STATUS_MAPPING_ERROR:
        return "CUSOLVER_STATUS_MAPPING_ERROR";
    case CUSOLVER_STATUS_EXECUTION_FAILED:
        return "CUSOLVER_STATUS_EXECUTION_FAILED";
    case CUSOLVER_STATUS_INTERNAL_ERROR:
        return "CUSOLVER_STATUS_INTERNAL_ERROR";
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
        return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
    case CUSOLVER_STATUS_NOT_SUPPORTED:
        // No trailing space: the name must match the enumerator exactly.
        return "CUSOLVER_STATUS_NOT_SUPPORTED";
    case CUSOLVER_STATUS_ZERO_PIVOT:
        return "CUSOLVER_STATUS_ZERO_PIVOT";
    case CUSOLVER_STATUS_INVALID_LICENSE:
        return "CUSOLVER_STATUS_INVALID_LICENSE";
    default:
        break;
    }
    return "Unknown cusolverStatus_t message";
}
// Body fragment of cusolverAssert(cusolverStatus_t code, const char* file, int line).
// When `code` is anything other than CUSOLVER_STATUS_SUCCESS, it prints the
// status name (from cusolverGetErrorEnum) and the caller-supplied file and line
// to stderr.
// NOTE(review): the signature and the lines after the fprintf are not visible
// in this view; confirm whether the function also exits/aborts after reporting.
183 if (code != CUSOLVER_STATUS_SUCCESS)
185 fprintf(stderr,
" Unexpected cuSOLVER Error: %s %s %d\n", cusolverGetErrorEnum(code),
file, line);
/**
 * @brief Return the enumerator name of a cuBLAS status code.
 *
 * Covers every cublasStatus_t value, including CUBLAS_STATUS_NOT_SUPPORTED and
 * CUBLAS_STATUS_LICENSE_ERROR. Any code not listed falls back to a fixed
 * message, mirroring cusolverGetErrorEnum. `static inline` avoids
 * unused-function warnings in translation units that include this header.
 *
 * @param error Status returned by a cuBLAS call.
 * @return A string literal naming @p error, or
 *         "Unknown cublasStatus_t message" for unlisted codes.
 */
static inline const char* cublasGetErrorEnum(cublasStatus_t error)
{
    switch (error)
    {
    case CUBLAS_STATUS_SUCCESS:
        return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:
        return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
        return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
        return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
        return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:
        return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
        return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
        return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED:
        return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR:
        return "CUBLAS_STATUS_LICENSE_ERROR";
    default:
        break;
    }
    return "Unknown cublasStatus_t message";
}
// Body fragment of cublasAssert(cublasStatus_t res, const char* file, int line).
// When `res` is anything other than CUBLAS_STATUS_SUCCESS, it prints the status
// name (from cublasGetErrorEnum) and the caller-supplied file and line to stderr.
// NOTE(review): the signature and the lines after the fprintf are not visible
// in this view; confirm whether the function also exits/aborts after reporting.
218 if (res != CUBLAS_STATUS_SUCCESS)
220 fprintf(stderr,
" Unexpected cuBLAS Error: %s %s %d\n", cublasGetErrorEnum(res),
file, line);
// Check a cuSOLVER status. Forwards (res), together with the invoking
// __FILE__ and __LINE__, to cusolverAssert. The argument is passed once, so an
// expression such as a cuSOLVER call is evaluated a single time.
225#define cusolverErrcheck(res) \
227 cusolverAssert((res), __FILE__, __LINE__); \
// Check a cuBLAS status. Forwards (res), together with the invoking __FILE__
// and __LINE__, to cublasAssert. The argument is passed once, so an expression
// such as a cuBLAS call is evaluated a single time.
230#define cublasErrcheck(res) \
232 cublasAssert((res), __FILE__, __LINE__); \
// Check a CUDA runtime result.
// `res` is evaluated exactly once and captured in a local. A call such as
// cudaErrcheck(cudaMalloc(&p, n)) is therefore not re-executed when the error
// name and description are built. On failure, the error name, its description
// and the invoking __FILE__/__LINE__ are printed to stderr. This macro does not
// abort; execution continues after reporting. It expands to a braced block, so
// it may be used with or without a trailing semicolon as a standalone statement.
#define cudaErrcheck(res)                                                                                    \
    {                                                                                                        \
        const cudaError_t cuda_errcheck_status_ = (res);                                                     \
        if (cuda_errcheck_status_ != cudaSuccess)                                                            \
        {                                                                                                    \
            fprintf(stderr, " Unexpected Device Error %s:%d: %s, %s\n", __FILE__, __LINE__,                  \
                    cudaGetErrorName(cuda_errcheck_status_), cudaGetErrorString(cuda_errcheck_status_));      \
        }                                                                                                    \
    }
// Checking variant: synchronizes the device and passes the returned status to
// cudaErrcheck. Errors from earlier asynchronous work (e.g. kernel faults)
// surface at this synchronizing call and get reported here.
// NOTE(review): the preprocessor conditional that selects between this and the
// empty definition below is not visible in this view. Presumably it is a
// debug-build flag — confirm.
247#define cudaCheckOnDebug() cudaErrcheck(cudaDeviceSynchronize())
// Non-checking variant: expands to nothing, so no device synchronization occurs.
249#define cudaCheckOnDebug()
void cublasAssert(cublasStatus_t res, const char *file, int line)
Definition cuda.h:216
void cusolverAssert(cusolverStatus_t code, const char *file, int line)
Definition cuda.h:181
std::complex< double > complex
Definition diago_cusolver.cpp:13
#define T
Definition exp.cpp:237
file(GLOB ATen_CORE_SRCS "*.cpp") set(ATen_CPU_SRCS $
Definition CMakeLists.txt:1
static constexpr cudaDataType cuda_data_type
Definition cuda.h:50
thrust::complex< double > type
Definition cuda.h:26
thrust::complex< float > type
Definition cuda.h:20
T type
Definition cuda.h:14