1#ifndef BASE_MACROS_ROCM_H_
2#define BASE_MACROS_ROCM_H_
5#include <hip/hip_runtime.h>
6#include <hipblas/hipblas.h>
7#include <hipsolver/hipsolver.h>
9#if defined(__HCC__) || defined(__HIP__)
10#include <thrust/complex.h>
13#define THREADS_PER_BLOCK 256
15#if defined(__HCC__) || defined(__HIP__)
25 using type = thrust::complex<float>;
31 using type = thrust::complex<double>;
35static inline hipblasOperation_t GetHipblasOperation(
const char& trans)
37 hipblasOperation_t hip_trans = {};
40 hip_trans = HIPBLAS_OP_N;
42 else if (trans ==
'T')
44 hip_trans = HIPBLAS_OP_T;
46 else if (trans ==
'C')
48 hip_trans = HIPBLAS_OP_C;
53 hip_trans = HIPBLAS_OP_N;
96static inline hipblasFillMode_t hipblas_fill_mode(
const char& uplo)
98 if (uplo ==
'U' || uplo ==
'u')
99 return HIPBLAS_FILL_MODE_UPPER;
100 else if (uplo ==
'L' || uplo ==
'l')
101 return HIPBLAS_FILL_MODE_LOWER;
103 throw std::runtime_error(
"hipblas_fill_mode: unknown uplo");
106static inline hipblasDiagType_t hipblas_diag_type(
const char& diag)
108 if (diag ==
'U' || diag ==
'u')
109 return HIPBLAS_DIAG_UNIT;
110 else if (diag ==
'N' || diag ==
'n')
111 return HIPBLAS_DIAG_NON_UNIT;
113 throw std::runtime_error(
"hipblas_diag_type: unknown diag");
116static inline hipsolverEigMode_t hipblas_eig_mode(
const char& jobz)
118 if (jobz ==
'N' || jobz ==
'n')
119 return HIPSOLVER_EIG_MODE_NOVECTOR;
120 else if (jobz ==
'V' || jobz ==
'v')
121 return HIPSOLVER_EIG_MODE_VECTOR;
123 throw std::runtime_error(
"hipblas_eig_mode: unknown diag");
126static inline hipsolverEigType_t hipblas_eig_type(
const int& itype)
129 return HIPSOLVER_EIG_TYPE_1;
131 return HIPSOLVER_EIG_TYPE_2;
133 throw std::runtime_error(
"hipblas_eig_mode: unknown diag");
136static inline hipsolverFillMode_t hipsolver_fill_mode(
const char& uplo)
138 if (uplo ==
'U' || uplo ==
'u')
139 return HIPSOLVER_FILL_MODE_UPPER;
140 else if (uplo ==
'L' || uplo ==
'l')
141 return HIPSOLVER_FILL_MODE_LOWER;
143 throw std::runtime_error(
"hipsolver_fill_mode: unknown uplo");
147static const char* hipsolverGetErrorEnum(hipsolverStatus_t error)
151 case HIPSOLVER_STATUS_SUCCESS:
152 return "HIPSOLVER_STATUS_SUCCESS";
153 case HIPSOLVER_STATUS_NOT_INITIALIZED:
154 return "HIPSOLVER_STATUS_NOT_INITIALIZED";
155 case HIPSOLVER_STATUS_ALLOC_FAILED:
156 return "HIPSOLVER_STATUS_ALLOC_FAILED";
157 case HIPSOLVER_STATUS_INVALID_VALUE:
158 return "HIPSOLVER_STATUS_INVALID_VALUE";
159 case HIPSOLVER_STATUS_ARCH_MISMATCH:
160 return "HIPSOLVER_STATUS_ARCH_MISMATCH";
161 case HIPSOLVER_STATUS_MAPPING_ERROR:
162 return "HIPSOLVER_STATUS_MAPPING_ERROR";
163 case HIPSOLVER_STATUS_EXECUTION_FAILED:
164 return "HIPSOLVER_STATUS_EXECUTION_FAILED";
165 case HIPSOLVER_STATUS_INTERNAL_ERROR:
166 return "HIPSOLVER_STATUS_INTERNAL_ERROR";
167 case HIPSOLVER_STATUS_NOT_SUPPORTED:
168 return "HIPSOLVER_STATUS_NOT_SUPPORTED ";
169 case HIPSOLVER_STATUS_INVALID_ENUM:
170 return "HIPSOLVER_STATUS_INVALID_ENUM";
172 return "Unknown hipsolverStatus_t message";
178 if (code != HIPSOLVER_STATUS_SUCCESS)
180 fprintf(stderr,
"hipSOLVER Assert: %s %s %d\n", hipsolverGetErrorEnum(code),
file, line);
187static const char* hipblasGetErrorEnum(hipblasStatus_t error)
191 case HIPBLAS_STATUS_SUCCESS:
192 return "HIPBLAS_STATUS_SUCCESS";
193 case HIPBLAS_STATUS_NOT_INITIALIZED:
194 return "HIPBLAS_STATUS_NOT_INITIALIZED";
195 case HIPBLAS_STATUS_ALLOC_FAILED:
196 return "HIPBLAS_STATUS_ALLOC_FAILED";
197 case HIPBLAS_STATUS_INVALID_VALUE:
198 return "HIPBLAS_STATUS_INVALID_VALUE";
199 case HIPBLAS_STATUS_ARCH_MISMATCH:
200 return "HIPBLAS_STATUS_ARCH_MISMATCH";
201 case HIPBLAS_STATUS_MAPPING_ERROR:
202 return "HIPBLAS_STATUS_MAPPING_ERROR";
203 case HIPBLAS_STATUS_EXECUTION_FAILED:
204 return "HIPBLAS_STATUS_EXECUTION_FAILED";
205 case HIPBLAS_STATUS_INTERNAL_ERROR:
206 return "HIPBLAS_STATUS_INTERNAL_ERROR";
214 if (code != HIPBLAS_STATUS_SUCCESS)
216 fprintf(stderr,
"Unexpected hipBLAS Error: %s %s %d\n", hipblasGetErrorEnum(code),
file, line);
// Check a hipSOLVER / hipBLAS status, tagging any report with the caller's
// file and line via the matching *Assert helper.
#define hipsolverErrcheck(res) { hipsolverAssert((res), __FILE__, __LINE__); }
#define hipblasErrcheck(res) { hipblasAssert((res), __FILE__, __LINE__); }
/**
 * @brief Check the result of a HIP runtime call.
 *
 * @p res is evaluated exactly once and its value is cached. A call such as
 * `hipMalloc(...)` is therefore never re-run while the diagnostic is built.
 * On failure, the call site and the HIP error name and description are
 * printed to stderr, and the process exits with the error code.
 */
#define hipErrcheck(res)                                                                    \
    {                                                                                       \
        const hipError_t hip_errcheck_status_ = (res);                                      \
        if (hip_errcheck_status_ != hipSuccess)                                             \
        {                                                                                   \
            fprintf(stderr, " Unexpected Device Error %s:%d: %s, %s\n", __FILE__, __LINE__, \
                    hipGetErrorName(hip_errcheck_status_),                                  \
                    hipGetErrorString(hip_errcheck_status_));                               \
            exit(hip_errcheck_status_);                                                     \
        }                                                                                   \
    }
244#define hipCheckOnDebug() hipErrcheck(hipDeviceSynchronize())
246#define hipCheckOnDebug()
std::complex< double > complex
Definition diago_cusolver.cpp:13
#define T
Definition exp.cpp:237
void hipblasAssert(hipblasStatus_t code, const char *file, int line, bool abort=true)
Definition rocm.h:212
void hipsolverAssert(hipsolverStatus_t code, const char *file, int line, bool abort=true)
Definition rocm.h:176
file(GLOB ATen_CORE_SRCS "*.cpp") set(ATen_CPU_SRCS $
Definition CMakeLists.txt:1
static constexpr hipDataType hip_data_type
Definition rocm.h:61
T type
Definition cuda.h:14