1#ifndef ATEN_CORE_TENSOR_ACCESSOR_H_
2#define ATEN_CORE_TENSOR_ACCESSOR_H_
16#if defined(__CUDACC__) || defined(__HIPCC__)
18struct RestrictPtrTraits {
19 using PtrType =
T* __restrict__;
23template <
typename T,
size_t N,
typename index_t = int64_t,
24 template <
typename U>
class PtrTraits = DefaultPtrTraits>
28 using PtrType =
typename PtrTraits<T>::PtrType;
66template <
typename T,
size_t N,
typename index_t = int64_t,
70 using PtrType =
typename PtrTraits<T>::PtrType;
84template <
typename T,
typename index_t,
85 template <
typename U>
class PtrTraits>
88 using PtrType =
typename PtrTraits<T>::PtrType;
Definition array_ref.h:12
Definition tensor_accessor.h:25
AT_HOST_DEVICE PtrType data()
Definition tensor_accessor.h:52
T * data_
Definition tensor_accessor.h:61
AT_HOST_DEVICE const PtrType data() const
Definition tensor_accessor.h:56
AT_HOST_DEVICE index_t size(index_t idx) const
Definition tensor_accessor.h:48
AT_HOST int_array_ref strides() const
Definition tensor_accessor.h:40
const index_t * strides_
Definition tensor_accessor.h:63
typename PtrTraits< T >::PtrType PtrType
Definition tensor_accessor.h:28
AT_HOST_DEVICE index_t stride(index_t idx) const
Definition tensor_accessor.h:44
AT_HOST_DEVICE TensorAccessorBase(PtrType data, const index_t *sizes, const index_t *strides)
Definition tensor_accessor.h:30
AT_HOST int_array_ref sizes() const
Definition tensor_accessor.h:36
const index_t * sizes_
Definition tensor_accessor.h:62
typename PtrTraits< T >::PtrType PtrType
Definition tensor_accessor.h:88
AT_HOST_DEVICE const T & operator[](index_t idx) const
Definition tensor_accessor.h:96
AT_HOST_DEVICE TensorAccessor(T *data, const index_t *sizes, const index_t *strides)
Definition tensor_accessor.h:89
AT_HOST_DEVICE T & operator[](index_t idx)
Definition tensor_accessor.h:92
Definition tensor_accessor.h:68
AT_HOST_DEVICE TensorAccessor< T, N - 1, index_t, PtrTraits > operator[](index_t idx)
Definition tensor_accessor.h:75
typename PtrTraits< T >::PtrType PtrType
Definition tensor_accessor.h:70
AT_HOST_DEVICE TensorAccessor(PtrType data, const index_t *sizes, const index_t *strides)
Definition tensor_accessor.h:72
AT_HOST_DEVICE const TensorAccessor< T, N - 1, index_t, PtrTraits > operator[](index_t idx) const
Definition tensor_accessor.h:79
#define N
Definition exp.cpp:24
#define T
Definition exp.cpp:237
#define AT_HOST_DEVICE
Definition macros.h:40
#define AT_HOST
Definition macros.h:38
Definition tensor_accessor.h:12
T * PtrType
Definition tensor_accessor.h:13