5#ifndef ATEN_CORE_TENSOR_TYPES_H_
6#define ATEN_CORE_TENSOR_TYPES_H_
18#include <unordered_map>
19#include <initializer_list>
23#if defined(__CUDACC__)
25#elif defined(__HIPCC__)
31template <
typename T,
int Accuracy>
32static inline bool element_compare(
T& a,
T& b) {
34 return (a == b) || (std::norm(a - b) < 1e-7);
36 else if (Accuracy <= 8) {
37 return (a == b) || (std::norm(a - b) < 1e-15);
136 using type = base_device::DEVICE_CPU;
141 using type = base_device::DEVICE_GPU;
223#if defined(__CUDACC__) || defined(__HIPCC__)
230struct DataTypeToEnum<thrust::
complex<double>> {
std::complex< double > complex
Definition diago_cusolver.cpp:13
#define T
Definition exp.cpp:237
DataType
Enumeration of data types for tensors. The DataType enum lists the supported data types for tensors....
Definition tensor_types.h:50
@ DT_COMPLEX
32-bit complex */
@ DT_INT64
64-bit integer */
@ DT_FLOAT
Single-precision floating point */.
@ DT_INT
32-bit integer */
@ DT_DOUBLE
Double-precision floating point */.
@ DT_INVALID
Invalid data type */.
DeviceType
The type of memory used by an allocator.
Definition tensor_types.h:73
@ UnKnown
Memory type is unknown.
@ CpuDevice
Memory type is CPU.
@ GpuDevice
Memory type is GPU(CUDA or ROCm).
std::ostream & operator<<(std::ostream &os, const Tensor &tensor)
Overloaded operator<< for the Tensor class.
Definition tensor.cpp:329
base_device::DEVICE_CPU type
Definition tensor_types.h:136
base_device::DEVICE_GPU type
Definition tensor_types.h:141
Definition tensor_types.h:130
T type
Definition tensor_types.h:131
A tag type for identifying CPU and GPU devices.
Definition tensor_types.h:68
Definition tensor_types.h:69
Template struct for mapping a DataType to its corresponding enum value.
Definition tensor_types.h:194
static constexpr DataType value
Definition tensor_types.h:195
Template struct for mapping a Device Type to its corresponding enum value.
Definition tensor_types.h:158
static constexpr DeviceType value
Definition tensor_types.h:159
double type
Definition tensor_types.h:109
float type
Definition tensor_types.h:99
Template struct to determine the return type based on the input type.
Definition tensor_types.h:88
T type
Definition tensor_types.h:89
Definition tensor_types.h:113
T type
Definition tensor_types.h:114