1#ifndef BASE_MACROS_CUDA_H_
2#define BASE_MACROS_CUDA_H_
#include <cstdio>
#include <cstdlib>
#include <stdexcept>

#include <cuda_runtime.h>
#include <thrust/complex.h>
9#define THREADS_PER_BLOCK 256
20 using type = thrust::complex<float>;
26 using type = thrust::complex<double>;
/**
 * @brief Translate a BLAS-style transpose flag into the cuBLAS operation enum.
 *
 * @param trans  'N' (no transpose), 'T' (transpose) or 'C' (conjugate
 *               transpose); the lowercase letters are accepted as well.
 * @return The matching cublasOperation_t value.
 * @throws std::runtime_error if @p trans is not one of the accepted letters.
 */
static inline cublasOperation_t GetCublasOperation(const char& trans)
{
    switch (trans)
    {
    case 'N':
    case 'n':
        return CUBLAS_OP_N;
    case 'T':
    case 't':
        return CUBLAS_OP_T;
    case 'C':
    case 'c':
        return CUBLAS_OP_C;
    default:
        // Reject unknown flags instead of handing cuBLAS a default-initialised
        // operation, matching cublas_fill_mode / cublas_diag_type.
        throw std::runtime_error("GetCublasOperation: unknown trans");
    }
}
/**
 * @brief Convert a BLAS/LAPACK UPLO character into cuBLAS' fill-mode enum.
 *
 * @param uplo  'U'/'u' selects the upper triangle, 'L'/'l' the lower one.
 * @return CUBLAS_FILL_MODE_UPPER or CUBLAS_FILL_MODE_LOWER.
 * @throws std::runtime_error for any other character.
 */
static inline cublasFillMode_t cublas_fill_mode(const char& uplo)
{
    switch (uplo)
    {
    case 'U':
    case 'u':
        return CUBLAS_FILL_MODE_UPPER;
    case 'L':
    case 'l':
        return CUBLAS_FILL_MODE_LOWER;
    default:
        throw std::runtime_error("cublas_fill_mode: unknown uplo");
    }
}
/**
 * @brief Convert a BLAS/LAPACK DIAG character into cuBLAS' diagonal-type enum.
 *
 * @param diag  'U'/'u' maps to CUBLAS_DIAG_UNIT, 'N'/'n' to CUBLAS_DIAG_NON_UNIT.
 * @return The matching cublasDiagType_t value.
 * @throws std::runtime_error for any other character.
 */
static inline cublasDiagType_t cublas_diag_type(const char& diag)
{
    switch (diag)
    {
    case 'U':
    case 'u':
        return CUBLAS_DIAG_UNIT;
    case 'N':
    case 'n':
        return CUBLAS_DIAG_NON_UNIT;
    default:
        throw std::runtime_error("cublas_diag_type: unknown diag");
    }
}
/**
 * @brief Convert a LAPACK-style JOBZ character into cuSOLVER's eigen-mode enum.
 *
 * @param jobz  'N'/'n' -> CUSOLVER_EIG_MODE_NOVECTOR (no eigenvectors),
 *              'V'/'v' -> CUSOLVER_EIG_MODE_VECTOR.
 * @return The matching cusolverEigMode_t value.
 * @throws std::runtime_error for any other character.
 */
static inline cusolverEigMode_t cublas_eig_mode(const char& jobz)
{
    if (jobz == 'N' || jobz == 'n')
    {
        return CUSOLVER_EIG_MODE_NOVECTOR;
    }
    if (jobz == 'V' || jobz == 'v')
    {
        return CUSOLVER_EIG_MODE_VECTOR;
    }
    // Name the parameter actually being validated so the diagnostic is useful.
    throw std::runtime_error("cublas_eig_mode: unknown jobz");
}
/**
 * @brief Convert a LAPACK-style ITYPE value into cuSOLVER's eigen-problem type.
 *
 * @param itype  1 -> CUSOLVER_EIG_TYPE_1 (A*x = lambda*B*x),
 *               2 -> CUSOLVER_EIG_TYPE_2 (A*B*x = lambda*x),
 *               3 -> CUSOLVER_EIG_TYPE_3 (B*A*x = lambda*x).
 * @return The matching cusolverEigType_t value.
 * @throws std::runtime_error for any other value.
 */
static inline cusolverEigType_t cublas_eig_type(const int& itype)
{
    switch (itype)
    {
    case 1:
        return CUSOLVER_EIG_TYPE_1;
    case 2:
        return CUSOLVER_EIG_TYPE_2;
    case 3:
        // Also accepted by the cuSOLVER generalized symmetric/Hermitian solvers.
        return CUSOLVER_EIG_TYPE_3;
    default:
        throw std::runtime_error("cublas_eig_type: unknown itype");
    }
}
/**
 * @brief Map a cuSOLVER status code to its enumerator name for diagnostics.
 *
 * @param error  Status returned by a cuSOLVER call.
 * @return A string literal naming @p error, or a generic message for
 *         values not listed below.
 */
static const char* cusolverGetErrorEnum(cusolverStatus_t error)
{
    switch (error)
    {
    case CUSOLVER_STATUS_SUCCESS:
        return "CUSOLVER_STATUS_SUCCESS";
    case CUSOLVER_STATUS_NOT_INITIALIZED:
        return "CUSOLVER_STATUS_NOT_INITIALIZED";
    case CUSOLVER_STATUS_ALLOC_FAILED:
        return "CUSOLVER_STATUS_ALLOC_FAILED";
    case CUSOLVER_STATUS_INVALID_VALUE:
        return "CUSOLVER_STATUS_INVALID_VALUE";
    case CUSOLVER_STATUS_ARCH_MISMATCH:
        return "CUSOLVER_STATUS_ARCH_MISMATCH";
    case CUSOLVER_STATUS_MAPPING_ERROR:
        return "CUSOLVER_STATUS_MAPPING_ERROR";
    case CUSOLVER_STATUS_EXECUTION_FAILED:
        return "CUSOLVER_STATUS_EXECUTION_FAILED";
    case CUSOLVER_STATUS_INTERNAL_ERROR:
        return "CUSOLVER_STATUS_INTERNAL_ERROR";
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
        return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
    case CUSOLVER_STATUS_NOT_SUPPORTED:
        return "CUSOLVER_STATUS_NOT_SUPPORTED";
    case CUSOLVER_STATUS_ZERO_PIVOT:
        return "CUSOLVER_STATUS_ZERO_PIVOT";
    case CUSOLVER_STATUS_INVALID_LICENSE:
        return "CUSOLVER_STATUS_INVALID_LICENSE";
    default:
        break;
    }
    // Statuses not listed above (e.g. ones added in newer cuSOLVER releases).
    return "Unknown cusolverStatus_t message";
}
// cusolverAssert body fragment: when the cuSOLVER status `code` is anything
// other than CUSOLVER_STATUS_SUCCESS, write the status name (from
// cusolverGetErrorEnum) and the caller's file/line to stderr.
// NOTE(review): the signature line and any statements after the fprintf are
// not part of this fragment — confirm whether the function terminates after
// reporting.
160 if (code != CUSOLVER_STATUS_SUCCESS)
162 fprintf(stderr,
" Unexpected cuSOLVER Error: %s %s %d\n", cusolverGetErrorEnum(code),
file, line);
/**
 * @brief Map a cuBLAS status code to its enumerator name for diagnostics.
 *
 * @param error  Status returned by a cuBLAS call.
 * @return A string literal naming @p error, or a generic message for
 *         values not listed below.
 */
static const char* cublasGetErrorEnum(cublasStatus_t error)
{
    switch (error)
    {
    case CUBLAS_STATUS_SUCCESS:
        return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:
        return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
        return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
        return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
        return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:
        return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
        return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
        return "CUBLAS_STATUS_INTERNAL_ERROR";
    // The remaining cublasStatus_t enumerators, so they are reported by name too.
    case CUBLAS_STATUS_NOT_SUPPORTED:
        return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR:
        return "CUBLAS_STATUS_LICENSE_ERROR";
    default:
        break;
    }
    // Fallback for values not listed above; wording mirrors cusolverGetErrorEnum.
    return "Unknown cublasStatus_t message";
}
// cublasAssert body fragment: when the cuBLAS status `res` is anything other
// than CUBLAS_STATUS_SUCCESS, write the status name (from cublasGetErrorEnum)
// and the caller's file/line to stderr.
// NOTE(review): the signature line and any statements after the fprintf are
// not part of this fragment — confirm whether the function terminates after
// reporting.
195 if (res != CUBLAS_STATUS_SUCCESS)
197 fprintf(stderr,
" Unexpected cuBLAS Error: %s %s %d\n", cublasGetErrorEnum(res),
file, line);
/**
 * @brief Check a cuSOLVER status via cusolverAssert, tagging it with the
 *        caller's __FILE__ / __LINE__.
 *
 * Wrapped in do { ... } while (0) so the macro behaves as one statement and is
 * safe inside un-braced if/else; call sites must end with ';'.
 */
#define cusolverErrcheck(res)                                                  \
    do                                                                         \
    {                                                                          \
        cusolverAssert((res), __FILE__, __LINE__);                             \
    } while (0)
/**
 * @brief Check a cuBLAS status via cublasAssert, tagging it with the
 *        caller's __FILE__ / __LINE__.
 *
 * Wrapped in do { ... } while (0) so the macro behaves as one statement and is
 * safe inside un-braced if/else; call sites must end with ';'.
 */
#define cublasErrcheck(res)                                                    \
    do                                                                         \
    {                                                                          \
        cublasAssert((res), __FILE__, __LINE__);                               \
    } while (0)
/**
 * @brief Check a CUDA runtime return code; on failure print the error name and
 *        description with the caller's __FILE__ / __LINE__ and terminate.
 *
 * The argument is evaluated exactly once and stored, so passing a call such as
 * cudaDeviceSynchronize() or cudaMalloc(...) does not re-execute it while the
 * message is being built. Wrapped in do { ... } while (0) so it behaves as one
 * statement (safe inside un-braced if/else); call sites must end with ';'.
 * EXIT_FAILURE is used because POSIX exit statuses keep only the low 8 bits of
 * the value, so a raw cudaError_t is not a reliable nonzero status.
 */
#define cudaErrcheck(res)                                                                          \
    do                                                                                             \
    {                                                                                              \
        const cudaError_t cuda_errcheck_status_ = (res);                                           \
        if (cuda_errcheck_status_ != cudaSuccess)                                                  \
        {                                                                                          \
            fprintf(stderr, " Unexpected Device Error %s:%d: %s, %s\n", __FILE__, __LINE__,        \
                    cudaGetErrorName(cuda_errcheck_status_),                                       \
                    cudaGetErrorString(cuda_errcheck_status_));                                    \
            exit(EXIT_FAILURE);                                                                    \
        }                                                                                          \
    } while (0)
// Debug variant: block until all previously issued device work has finished
// (cudaDeviceSynchronize) and pass the resulting status to cudaErrcheck, so
// errors from asynchronous kernel execution are reported close to the call site.
// NOTE(review): the preprocessor condition that selects between this and the
// empty definition below is not part of this fragment; as shown, the two
// #define lines would redefine each other — confirm the surrounding
// #ifdef/#else/#endif.
224#define cudaCheckOnDebug() cudaErrcheck(cudaDeviceSynchronize())
// Non-debug variant: expands to nothing, so no synchronization or check is done.
226#define cudaCheckOnDebug()
void cublasAssert(cublasStatus_t res, const char *file, int line)
Definition cuda.h:193
void cusolverAssert(cusolverStatus_t code, const char *file, int line)
Definition cuda.h:158
std::complex< double > complex
Definition diago_cusolver.cpp:13
#define T
Definition exp.cpp:237
file(GLOB ATen_CORE_SRCS "*.cpp") set(ATen_CPU_SRCS $
Definition CMakeLists.txt:1
static constexpr cudaDataType cuda_data_type
Definition cuda.h:50
thrust::complex< double > type
Definition cuda.h:26
thrust::complex< float > type
Definition cuda.h:20
T type
Definition cuda.h:14