1#ifndef BASE_MACROS_CUDA_H_
2#define BASE_MACROS_CUDA_H_
5#include <cuda_runtime.h>
7#include <thrust/complex.h>
11#define THREADS_PER_BLOCK 256
22 using type = thrust::complex<float>;
28 using type = thrust::complex<double>;
// Maps a transpose selector character to a cublasOperation_t value.
// The result variable starts value-initialized ({}); the visible branches
// assign CUBLAS_OP_T for 'T' and CUBLAS_OP_C for 'C'.
// NOTE(review): the condition guarding the CUBLAS_OP_N assignment and the
// tail of this function (return statement / closing brace) are not visible
// in this listing -- confirm against the full file.
// NOTE(review): the visible comparisons are upper-case only, whereas the
// cublas_* helpers further down accept both upper and lower case -- confirm
// whether that difference is intended.
31static inline cublasOperation_t GetCublasOperation(
const char& trans)
33 cublasOperation_t cutrans = {};
36 cutrans = CUBLAS_OP_N;
38 else if (trans ==
'T')
40 cutrans = CUBLAS_OP_T;
42 else if (trans ==
'C')
44 cutrans = CUBLAS_OP_C;
86static inline cublasFillMode_t cublas_fill_mode(
const char& uplo)
88 if (uplo ==
'U' || uplo ==
'u')
89 return CUBLAS_FILL_MODE_UPPER;
90 else if (uplo ==
'L' || uplo ==
'l')
91 return CUBLAS_FILL_MODE_LOWER;
93 throw std::runtime_error(
"cublas_fill_mode: unknown uplo");
96static inline cublasDiagType_t cublas_diag_type(
const char& diag)
98 if (diag ==
'U' || diag ==
'u')
99 return CUBLAS_DIAG_UNIT;
100 else if (diag ==
'N' || diag ==
'n')
101 return CUBLAS_DIAG_NON_UNIT;
103 throw std::runtime_error(
"cublas_diag_type: unknown diag");
106static inline cusolverEigMode_t cublas_eig_mode(
const char& jobz)
108 if (jobz ==
'N' || jobz ==
'n')
109 return CUSOLVER_EIG_MODE_NOVECTOR;
110 else if (jobz ==
'V' || jobz ==
'v')
111 return CUSOLVER_EIG_MODE_VECTOR;
113 throw std::runtime_error(
"cublas_eig_mode: unknown diag");
116static inline cusolverEigType_t cublas_eig_type(
const int& itype)
119 return CUSOLVER_EIG_TYPE_1;
121 return CUSOLVER_EIG_TYPE_2;
123 throw std::runtime_error(
"cublas_eig_mode: unknown diag");
137static inline cusolverEigRange_t cublas_eig_range(
const char& range)
139 if (range ==
'A' || range ==
'a')
140 return CUSOLVER_EIG_RANGE_ALL;
141 else if (range ==
'V' || range ==
'v')
142 return CUSOLVER_EIG_RANGE_V;
143 else if (range ==
'I' || range ==
'i')
144 return CUSOLVER_EIG_RANGE_I;
146 throw std::runtime_error(
"cublas_eig_range: unknown range '" + std::string(1, range) +
"'");
std::complex< double > complex
Definition diago_cusolver.cpp:15
#define T
Definition exp.cpp:237
static constexpr cudaDataType cuda_data_type
Definition cuda.h:52
thrust::complex< double > type
Definition cuda.h:28
thrust::complex< float > type
Definition cuda.h:22
T type
Definition cuda.h:16