ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
tensor_accessor.h
Go to the documentation of this file.
1#ifndef ATEN_CORE_TENSOR_ACCESSOR_H_
2#define ATEN_CORE_TENSOR_ACCESSOR_H_
3
4#include <cstddef> // Include the <cstddef> header file to define size_t
5#include <cstdint>
8
9namespace container {
10
11template <typename T>
13 using PtrType = T*;
14};
15
#if defined(__HIPCC__) || defined(__CUDACC__)
// Pointer-traits policy available only when compiling with nvcc or hipcc:
// the element pointer type carries the __restrict__ qualifier, i.e. the caller
// promises the pointed-to data is not aliased through another pointer.
template <class Elem>
struct RestrictPtrTraits {
  using PtrType = Elem* __restrict__;
};
#endif
22
23template <typename T, size_t N, typename index_t = int64_t,
24 template <typename U> class PtrTraits = DefaultPtrTraits>
26 public:
27
28 using PtrType = typename PtrTraits<T>::PtrType;
29
32 const index_t* sizes,
33 const index_t* strides)
35
37 return {sizes_, N};
38 }
39
41 return {strides_, N};
42 }
43
44 AT_HOST_DEVICE index_t stride(index_t idx) const {
45 return strides_[idx];
46 }
47
48 AT_HOST_DEVICE index_t size(index_t idx) const {
49 return sizes_[idx];
50 }
51
53 return data_;
54 }
55
56 AT_HOST_DEVICE const PtrType data() const {
57 return data_;
58 }
59
60 protected:
62 const index_t* sizes_;
63 const index_t* strides_;
64};
65
66template <typename T, size_t N, typename index_t = int64_t,
67 template <typename U> class PtrTraits = DefaultPtrTraits>
68class TensorAccessor : public TensorAccessorBase<T, N, index_t, PtrTraits> {
69 public:
70 using PtrType = typename PtrTraits<T>::PtrType;
71
72 AT_HOST_DEVICE TensorAccessor(PtrType data, const index_t* sizes, const index_t* strides)
73 : TensorAccessorBase<T, N, index_t, PtrTraits>(data, sizes, strides) {}
74
75 AT_HOST_DEVICE TensorAccessor<T, N - 1, index_t, PtrTraits> operator[](index_t idx) {
76 return TensorAccessor<T, N - 1, index_t, PtrTraits>(this->data_ + idx * this->strides_[0], this->sizes_ + 1, this->strides_ + 1);
77 }
78
79 AT_HOST_DEVICE const TensorAccessor<T, N - 1, index_t, PtrTraits> operator[](index_t idx) const {
80 return TensorAccessor<T, N - 1, index_t, PtrTraits>(this->data_ + idx * this->strides_[0], this->sizes_ + 1, this->strides_ + 1);
81 }
82};
83
84template <typename T, typename index_t,
85 template <typename U> class PtrTraits>
86class TensorAccessor<T, 1, index_t, PtrTraits> : public TensorAccessorBase<T, 1, index_t, PtrTraits> {
87 public:
88 using PtrType = typename PtrTraits<T>::PtrType;
89 AT_HOST_DEVICE TensorAccessor(T* data, const index_t* sizes, const index_t* strides)
90 : TensorAccessorBase<T, 1, index_t, PtrTraits>(data, sizes, strides) {}
91
92 AT_HOST_DEVICE T& operator[](index_t idx) {
93 return this->data_[idx * this->strides_[0]];
94 }
95
96 AT_HOST_DEVICE const T& operator[](index_t idx) const {
97 return this->data_[idx * this->strides_[0]];
98 }
99};
100
101
102} // namespace container
103
104#endif // ATEN_CORE_TENSOR_ACCESSOR_H_
Definition array_ref.h:12
Definition tensor_accessor.h:25
AT_HOST_DEVICE PtrType data()
Definition tensor_accessor.h:52
T * data_
Definition tensor_accessor.h:61
AT_HOST_DEVICE const PtrType data() const
Definition tensor_accessor.h:56
AT_HOST_DEVICE index_t size(index_t idx) const
Definition tensor_accessor.h:48
AT_HOST int_array_ref strides() const
Definition tensor_accessor.h:40
const index_t * strides_
Definition tensor_accessor.h:63
typename PtrTraits< T >::PtrType PtrType
Definition tensor_accessor.h:28
AT_HOST_DEVICE index_t stride(index_t idx) const
Definition tensor_accessor.h:44
AT_HOST_DEVICE TensorAccessorBase(PtrType data, const index_t *sizes, const index_t *strides)
Definition tensor_accessor.h:30
AT_HOST int_array_ref sizes() const
Definition tensor_accessor.h:36
const index_t * sizes_
Definition tensor_accessor.h:62
typename PtrTraits< T >::PtrType PtrType
Definition tensor_accessor.h:88
AT_HOST_DEVICE const T & operator[](index_t idx) const
Definition tensor_accessor.h:96
AT_HOST_DEVICE TensorAccessor(T *data, const index_t *sizes, const index_t *strides)
Definition tensor_accessor.h:89
AT_HOST_DEVICE T & operator[](index_t idx)
Definition tensor_accessor.h:92
Definition tensor_accessor.h:68
AT_HOST_DEVICE TensorAccessor< T, N - 1, index_t, PtrTraits > operator[](index_t idx)
Definition tensor_accessor.h:75
typename PtrTraits< T >::PtrType PtrType
Definition tensor_accessor.h:70
AT_HOST_DEVICE TensorAccessor(PtrType data, const index_t *sizes, const index_t *strides)
Definition tensor_accessor.h:72
AT_HOST_DEVICE const TensorAccessor< T, N - 1, index_t, PtrTraits > operator[](index_t idx) const
Definition tensor_accessor.h:79
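#define N
Definition exp.cpp:24
#define T
Definition exp.cpp:237
(Note: within tensor_accessor.h, T (the element type) and N (the number of dimensions) are template parameters of TensorAccessorBase/TensorAccessor. These two macro entries from exp.cpp appear to be name-based cross-links rather than definitions used by this header.)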
#define AT_HOST_DEVICE
Definition macros.h:40
#define AT_HOST
Definition macros.h:38
Definition tensor.cpp:8
Definition tensor_accessor.h:12
T * PtrType
Definition tensor_accessor.h:13