ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
ndarray.h
Go to the documentation of this file.
1
12#ifndef NDARRAY_H
13#define NDARRAY_H
14
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <type_traits>
#include <vector>
21// for heterogeneous computing, we can use ATen::Tensor
22//#include "./module_container/ATen/tensor.h"
23
/**
 * @brief Under the restriction of C++11, a simple alternative to std::vector<T> + std::mdspan.
 *
 * Elements live in one contiguous std::vector<T> in row-major order (the last
 * index varies fastest); shape_ stores the extent of every dimension. For an
 * index (i_0, ..., i_{d-1}) on a shape (s_0, ..., s_{d-1}) the flat offset is
 *     ((i_0 * s_1 + i_1) * s_2 + i_2) ... * s_{d-1} + i_{d-1}.
 * Index and shape validity are checked with assert() only, i.e. in debug builds.
 *
 * @tparam T element type
 */
template<typename T>
class NDArray
{
    // align with STL container implementation, there are several functions compulsory to be implemented
    // constructor: default, copy, move, initializer_list
    // operator: =, ==, !=, <, <=, >, >=
    // iterator: begin, cbegin, end, cend
    // capacity: size, empty, max_size, reserve, shrink_to_fit
    // element access: [], at, front, back, data
    // modifiers: clear, insert, erase, push_back, pop_back, resize, swap
    // allocator: get_allocator
public:
    // ------------------------------------------------------------------ constructors

    /// @brief Default construction is deleted: an NDArray always carries a shape.
    NDArray() = delete;

    /**
     * @brief Construct from the extent of every dimension, e.g. NDArray<double>({2, 3}).
     *
     * The element buffer is sized to the product of the extents and value-initialized.
     */
    NDArray(std::initializer_list<size_t> il)
        : shape_(il), data_(product(shape_.begin(), shape_.end())) {}

    /// @brief Same as above for int extents. NOTE: a negative extent converts to a huge size_t.
    NDArray(std::initializer_list<int> il)
        : shape_(il.begin(), il.end()), data_(product(shape_.begin(), shape_.end())) {}

    /// @brief Variadic form, e.g. NDArray<double> a(2, 3, 4); delegates to the initializer_list constructor.
    template<typename... Args>
    NDArray(const size_t idx, Args... args) : NDArray({idx, static_cast<size_t>(args)...}) {}

    /// @brief int-typed variadic form, selected for plain integer literals such as NDArray<double> a(2, 3).
    // not happy with this because size_t can have larger range
    template<typename... Args>
    NDArray(const int& idx, Args... args) : NDArray({idx, static_cast<int>(args)...}) {}

    /// @brief Copy shape and elements from @p other.
    // initializers listed in declaration order (shape_ before data_)
    NDArray(const NDArray& other) : shape_(other.shape_), data_(other.data_) {}

    /// @brief Take over shape and elements of @p other, leaving it in a valid but unspecified state.
    // noexcept lets containers of NDArray move (not copy) elements when they reallocate
    NDArray(NDArray&& other) noexcept : shape_(std::move(other.shape_)), data_(std::move(other.data_)) {}

    /// @brief Destructor; both std::vector members release their own storage.
    ~NDArray() = default;

    // ------------------------------------------------------------------ operators

    /// @brief Copy assignment: replaces shape and elements with copies of @p other's.
    NDArray& operator=(const NDArray& other)
    {
        if (this != &other)
        {
            shape_ = other.shape_;
            data_ = other.data_;
        }
        return *this;
    }

    /// @brief Move assignment: takes over @p other's shape and elements.
    NDArray& operator=(NDArray&& other) noexcept
    {
        if (this != &other)
        {
            shape_ = std::move(other.shape_);
            data_ = std::move(other.data_);
        }
        return *this;
    }

    /// @brief Two arrays are equal when both their elements and their shapes are equal.
    bool operator==(const NDArray& other) const { return data_ == other.data_ && shape_ == other.shape_; }
    /// @brief Negation of operator==.
    bool operator!=(const NDArray& other) const { return !(*this == other); }
    // other operators are not generally supported

    // ------------------------------------------------------------------ element access

    /// @brief Element at a multi-dimensional index; one index per dimension (validity asserted in debug builds only).
    template<typename... Args> T& at(const size_t idx, Args... args) { return data_[index(idx, args...)]; }
    template<typename... Args> const T& at(const size_t idx, Args... args) const { return data_[index(idx, args...)]; }
    /// @brief Same as at(): element access by multi-dimensional index.
    template<typename... Args> T& operator()(const size_t idx, Args... args) { return data_[index(idx, args...)]; }
    template<typename... Args> const T& operator()(const size_t idx, Args... args) const { return data_[index(idx, args...)]; }
    /// @brief First element in storage order (precondition: not empty).
    T& front() { return data_.front(); }
    const T& front() const { return data_.front(); }
    /// @brief Last element in storage order (precondition: not empty).
    T& back() { return data_.back(); }
    const T& back() const { return data_.back(); }
    /// @brief Pointer to the contiguous, row-major element buffer.
    T* data() { return data_.data(); }
    const T* data() const { return data_.data(); }

    // ------------------------------------------------------------------ iterators
    // iterators on the whole (flattened) data
    T* begin() { return data_.data(); }
    T* end() { return data_.data() + data_.size(); }
    // const overloads so that a const NDArray can be used in a range-for
    const T* begin() const { return data_.data(); }
    const T* end() const { return data_.data() + data_.size(); }
    const T* cbegin() const { return data_.data(); }
    const T* cend() const { return data_.data() + data_.size(); }
    // iterators on different dimensions: not implemented

    // ------------------------------------------------------------------ capacity
    /// @brief Total number of elements.
    size_t size() const { return data_.size(); }
    /// @brief Extent of dimension @p dim (throws std::out_of_range if dim >= rank, via vector::at).
    size_t size(const size_t& dim) const { return shape_.at(dim); }
    /// @brief True when the array holds no elements.
    bool empty() const { return data_.empty(); }

    // ------------------------------------------------------------------ multi-dimensional specific
    /// @brief Extents of all dimensions.
    const std::vector<size_t>& shape() const { return shape_; }

    /**
     * @brief Reinterpret the elements with a new shape; the rank may change.
     *
     * At most one extent may be -1; it is inferred so that the element count is
     * unchanged. All other extents must be non-negative and their product must
     * equal size() (asserted in debug builds). Elements are not moved.
     *
     * @param args new extents, e.g. reshape(3, -1) or reshape(static_cast<size_t>(6))
     */
    template<typename... Args>
    void reshape(Args... args)
    {
        // Convert explicitly: brace-initialising int64_t from size_t arguments would
        // be a narrowing error, and -1 must remain representable.
        std::vector<int64_t> dims{static_cast<int64_t>(args)...};
        // every extent is non-negative except for (at most) one -1 placeholder
        assert(std::all_of(dims.begin(), dims.end(), [](int64_t d) { return d >= -1; }));
        const auto n_infer = std::count(dims.begin(), dims.end(), static_cast<int64_t>(-1));
        assert(n_infer <= 1);
        if (n_infer == 1)
        {
            size_t known = 1;
            for (const int64_t d : dims)
            {
                if (d != -1) { known *= static_cast<size_t>(d); }
            }
            // a zero extent leaves the placeholder undetermined; avoid dividing by zero
            assert(known != 0 || data_.empty());
            const size_t inferred = (known == 0) ? 0 : data_.size() / known;
            *std::find(dims.begin(), dims.end(), static_cast<int64_t>(-1)) = static_cast<int64_t>(inferred);
        }
        assert(product(dims.begin(), dims.end()) == data_.size());
        // assign (not copy into the old buffer) so the rank may grow or shrink
        shape_.assign(dims.begin(), dims.end());
    }

    // interface to ATen::Tensor, but constraint to int, double, float, std::complex<float>, std::complex<double>
    // std::enable_if<
    //        std::is_same<T, int>::value
    //     || std::is_same<T, double>::value
    //     || std::is_same<T, float>::value
    //     || std::is_same<T, std::complex<float>>::value
    //     || std::is_same<T, std::complex<double>>::value, container::Tensor
    // >::type to_tensor() const
    // {
    //     container::TensorShape shape(shape_);
    //     container::Tensor result = container::Tensor(container::DataTypeToEnum<T>::value, shape);
    //     std::memcpy(result.data<T>(), data_.data(), data_.size() * sizeof(T));
    //     return result;
    // }

    /**
     * @brief Flat row-major offset of a multi-dimensional index.
     *
     * @param idx  index along dimension 0
     * @param args indices along dimensions 1..rank-1 (exactly one per dimension)
     * @return offset into data()
     */
    template<typename... Args>
    size_t index(const size_t idx, Args... args) const
    {
        // written as "+ 1 ==" so an empty shape_ does not wrap around
        assert(sizeof...(args) + 1 == shape_.size());
        const size_t indices[] = {idx, static_cast<size_t>(args)...};
        size_t flat = 0;
        for (size_t i = 0; i < shape_.size(); ++i)
        {
            assert(indices[i] < shape_[i]); // per-dimension bounds check (debug builds only)
            flat = flat * shape_[i] + indices[i]; // Horner's scheme for the row-major offset
        }
        assert(flat < data_.size());
        return flat;
    }

private:
    /// @brief Product of the elements in [first, last) computed in size_t (1 for an empty range).
    // The initial value must be size_t: with a plain int literal, std::accumulate
    // would keep an int accumulator and truncate element counts above INT_MAX.
    template<typename It>
    static size_t product(It first, It last)
    {
        return std::accumulate(first, last, static_cast<size_t>(1), std::multiplies<size_t>());
    }

    // declared before data_: the constructors size data_ from shape_
    std::vector<size_t> shape_;
    // for GPU-compatible data container, will be replaced by raw pointer
    std::vector<T> data_;
};
225
226#endif // NDARRAY_H
under the restriction of C++11, a simple alternative to std::vector<T> + std::mdspan....
Definition ndarray.h:31
const T & front() const
Definition ndarray.h:129
const T & operator()(const size_t idx, Args... args) const
Definition ndarray.h:126
NDArray(const NDArray &other)
Definition ndarray.h:54
NDArray(const int &idx, Args... args)
Definition ndarray.h:52
NDArray()=delete
Default construction is deleted; an NDArray must be constructed with a shape.
T & front()
Definition ndarray.h:128
T * data()
Definition ndarray.h:134
std::vector< T > data_
Definition ndarray.h:223
T * end()
Definition ndarray.h:140
NDArray(std::initializer_list< size_t > il)
Definition ndarray.h:48
NDArray(std::initializer_list< int > il)
Definition ndarray.h:49
std::vector< size_t > shape_
Definition ndarray.h:221
T & operator()(const size_t idx, Args... args)
() operator, element access by multi-dimensional index
Definition ndarray.h:125
bool operator!=(const NDArray &other) const
!= operator
Definition ndarray.h:105
const T & back() const
Definition ndarray.h:132
NDArray & operator=(NDArray &&other)
= operator, move assignment
Definition ndarray.h:80
size_t index(const size_t idx, Args... args) const
Compute the flat (row-major) offset of a multi-dimensional index; one index per dimension.
Definition ndarray.h:208
T & back()
Definition ndarray.h:131
const T * cend() const
Definition ndarray.h:142
T * begin()
Definition ndarray.h:139
bool empty() const
Definition ndarray.h:150
const T & at(const size_t idx, Args... args) const
Definition ndarray.h:117
void reshape(Args... args)
Definition ndarray.h:156
const std::vector< size_t > & shape() const
Definition ndarray.h:153
NDArray(const size_t idx, Args... args)
Definition ndarray.h:51
size_t size() const
Definition ndarray.h:147
NDArray & operator=(const NDArray &other)
= operator, copy assignment
Definition ndarray.h:68
NDArray(NDArray &&other)
Definition ndarray.h:56
const T * cbegin() const
Definition ndarray.h:141
const T * data() const
Definition ndarray.h:135
bool operator==(const NDArray &other) const
== operator
Definition ndarray.h:97
~NDArray()
Definition ndarray.h:59
T & at(const size_t idx, Args... args)
at function
Definition ndarray.h:116
size_t size(const size_t &dim) const
Definition ndarray.h:148
#define T
Definition exp.cpp:237