ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
tensor_utils.h
Go to the documentation of this file.
1#ifndef ATEN_CORE_TENSOR_UTILS_H_
2#define ATEN_CORE_TENSOR_UTILS_H_
3
#include <cmath>
#include <complex>
#include <iomanip>
#include <ios>
#include <ostream>
#include <string>

#include <ATen/core/tensor.h>
6
7namespace container {
8
/**
 * @brief Removes trailing zeros from a string.
 *
 * Only '0' characters at the very end are stripped. A trailing decimal point is
 * kept ("2.000" -> "2."), and the string is not required to contain a decimal
 * point at all ("100" -> "1").
 *
 * @param str The string to trim (taken by value; the caller's string is untouched).
 * @return The trimmed string, or "0" if @p str is empty or consists only of zeros.
 */
__inline__
std::string removeTrailingZeros(std::string str) {
    const std::string::size_type last_kept = str.find_last_not_of('0');
    if (last_kept == std::string::npos) {
        return "0";
    }
    str.erase(last_kept + 1);
    return str;
}

/**
 * @brief Calculates the length of the longest integer and fractional part of an array of numbers.
 *
 * - integer_count receives the largest number of characters needed for the integer
 *   part of any element. This includes one column for the '-' sign of negative
 *   values. An element whose magnitude is below 1 (including zero) counts as the
 *   single digit "0".
 * - fraction_count receives the largest length of the fractional part as produced
 *   by std::to_string with trailing zeros removed, counting the decimal point
 *   itself (0.25 -> "0.25" -> ".25" -> 3). Elements without a fractional part
 *   contribute 0.
 * - Non-finite elements (NaN, +/-inf) are ignored: they have no decimal digits,
 *   and log10 of them cannot be converted to int.
 *
 * @param arr            Pointer to @p size elements.
 * @param size           Number of elements to inspect.
 * @param integer_count  [out] Widest integer part, as described above.
 * @param fraction_count [out] Widest fractional part, as described above.
 * @return integer_count + fraction_count.
 */
template<typename T>
__inline__
int _get_digit_places(
    const T* arr,
    int size,
    int& integer_count,
    int& fraction_count)
{
    integer_count = 0;
    fraction_count = 0;

    for (int i = 0; i < size; i++) {
        // Work in double so that negating/abs of integral minimums cannot overflow.
        const double value = static_cast<double>(arr[i]);
        if (!std::isfinite(value)) {
            continue;
        }

        const double magnitude = std::fabs(value);
        // log10(0) is -inf (converting it to int is undefined), and |x| < 1 still
        // prints a leading "0", so such values occupy exactly one digit.
        int digits = (magnitude < 1.0) ? 1 : static_cast<int>(std::log10(magnitude) + 1);
        if (value < 0) {
            digits += 1; // room for the '-' sign
        }
        if (digits > integer_count) {
            integer_count = digits;
        }

        const T fraction = arr[i] - std::floor(arr[i]);
        if (fraction == 0) {
            continue;
        }
        const std::string str = removeTrailingZeros(std::to_string(fraction));
        const std::string::size_type point = str.find('.');
        if (point == std::string::npos) {
            // No decimal point (e.g. a locale using ','): nothing reliable to count.
            continue;
        }
        const int fraction_digits = static_cast<int>(str.length() - point);
        if (fraction_digits > fraction_count) {
            fraction_count = fraction_digits;
        }
    }

    return integer_count + fraction_count;
}

/**
 * @brief Overload of _get_digit_places for arrays of std::complex<T>.
 *
 * std::complex<T> is layout-compatible with T[2]. The real and imaginary parts are
 * therefore scanned as one flat array of 2 * size scalars, and the reported counts
 * cover both parts.
 */
template<typename T>
__inline__
int _get_digit_places(
    const std::complex<T>* arr,
    int size,
    int& integer_count,
    int& fraction_count)
{
    const T* scalars = reinterpret_cast<const T*>(arr);
    return _get_digit_places<T>(scalars, size * 2, integer_count, fraction_count);
}
112
/**
 * @brief Output wrapper for a data value with given formatting parameters.
 *
 * Writes @p data right-aligned in a field of @p digit_width characters, in fixed
 * notation with @p fraction_count digits after the decimal point. std::fixed and
 * std::setprecision are sticky stream state, so the stream's previous format flags
 * and precision are restored afterwards. This keeps the formatting from leaking
 * into whatever the caller writes next.
 *
 * @param os             Stream to write to.
 * @param data           Value to write.
 * @param digit_width    Minimum field width.
 * @param fraction_count Number of digits after the decimal point.
 */
template <typename T>
__inline__
void _output_wrapper(
    std::ostream& os,
    const T data,
    const int& digit_width,
    const int& fraction_count)
{
    const std::ios_base::fmtflags saved_flags = os.flags();
    const std::streamsize saved_precision = os.precision();
    os << std::setw(digit_width)
       << std::setprecision(fraction_count) << std::fixed << data;
    os.flags(saved_flags);
    os.precision(saved_precision);
}
135
/**
 * @brief Output wrapper for a complex value with given formatting parameters.
 *
 * Writes the value as "{real, imag}". Each part is right-aligned in a field of
 * @p digit_width characters, in fixed notation with @p fraction_count digits after
 * the decimal point. The stream's previous format flags and precision are restored
 * afterwards so the formatting does not leak into later output.
 *
 * @param os             Stream to write to.
 * @param data           Complex value to write.
 * @param digit_width    Minimum field width of each part.
 * @param fraction_count Number of digits after the decimal point.
 */
template <typename T>
__inline__
void _output_wrapper(
    std::ostream& os,
    const std::complex<T> data,
    const int& digit_width,
    const int& fraction_count)
{
    const std::ios_base::fmtflags saved_flags = os.flags();
    const std::streamsize saved_precision = os.precision();
    // std::setw is reset after every insertion, so it is re-issued for the imaginary
    // part. std::fixed and the precision persist for both parts.
    os << "{"
       << std::setw(digit_width) << std::setprecision(fraction_count) << std::fixed
       << data.real()
       << ", "
       << std::setw(digit_width) << data.imag()
       << "}";
    os.flags(saved_flags);
    os.precision(saved_precision);
}
167
177template <>
178__inline__
180 std::ostream& os,
181 const int data,
182 const int& digit_width,
183 const int& fraction_count)
184{
185 os << std::setw(digit_width - 1) \
186 << std::setprecision(fraction_count) << std::fixed << data;
187}
188
203template <typename T>
204__inline__
206 std::ostream& os,
207 const T * data,
208 const TensorShape& shape,
209 const int64_t& num_elements)
210{
211 int integer_count = 0, fraction_count = 0;
212 int digit_width = _get_digit_places(data, num_elements, integer_count, fraction_count) + 1;
213 if (shape.ndim() == 1) {
214 os << "[";
215 for (int i = 0; i < num_elements; ++i) {
216 _output_wrapper(os, data[i], digit_width, fraction_count);
217 if (i != num_elements - 1) {
218 os << ",";
219 }
220 }
221 os << "]";
222 }
223 else if (shape.ndim() == 2) {
224 os << "[";
225 for (int i = 0; i < shape.dim_size(0); ++i) {
226 if (i != 0) os << " ";
227 os << "[";
228 for (int j = 0; j < shape.dim_size(1); ++j) {
229 _output_wrapper(os, data[i * shape.dim_size(1) + j], digit_width, fraction_count);
230 if (j != shape.dim_size(1) - 1) {
231 os << ", ";
232 }
233 }
234 os << "]";
235 if (i != shape.dim_size(0) - 1) os << ",\n";
236 }
237 os << "]";
238 }
239 else if (shape.ndim() == 3) {
240 os << "[";
241 for (int i = 0; i < shape.dim_size(0); ++i) {
242 if (i != 0) os << " ";
243 os << "[";
244 for (int j = 0; j < shape.dim_size(1); ++j) {
245 if (j != 0) os << " ";
246 os << "[";
247 for (int k = 0; k < shape.dim_size(2); ++k) {
248 _output_wrapper(os, data[i * shape.dim_size(1) * shape.dim_size(2) + j * shape.dim_size(2) + k], digit_width, fraction_count);
249 if (k != shape.dim_size(2) - 1) {
250 os << ", ";
251 }
252 }
253 os << "]";
254 if (j != shape.dim_size(1) - 1) os << ",\n";
255 }
256 os << "]";
257 if (i != shape.dim_size(0) - 1) os << ",\n\n";
258 }
259 os << "]";
260 }
261 else {
262 for (int64_t i = 0; i < num_elements; ++i) {
263 _output_wrapper(os, data[i], 0, 0);
264 if (i < num_elements - 1) {
265 os << ", ";
266 }
267 }
268 }
269}
270
271template <typename T>
273 if (tensor.device_type() == DeviceType::CpuDevice) {
274 return reinterpret_cast<T*>(tensor.data())[0];
275 }
276 else {
277 T result = 0;
278 TEMPLATE_ALL_2(tensor.data_type(), tensor.device_type(),
280 &result, reinterpret_cast<T*>(tensor.data()), 1))
281 return result;
282 }
283}
284
285} // namespace container
286
287#endif // ATEN_CORE_TENSOR_UTILS_H_
A class for representing the shape of a tensor.
Definition tensor_shape.h:13
int64_t dim_size(int dim) const
Get the size of a dimension in the tensor.
Definition tensor_shape.cpp:31
unsigned int ndim() const
Get the ndim of the tensor.
Definition tensor_shape.cpp:46
A multi-dimensional array of elements of a single data type.
Definition tensor.h:32
void * data() const
Get a pointer to the data buffer of the tensor.
Definition tensor.cpp:73
DeviceType device_type() const
Get the device type of the tensor.
Definition tensor.cpp:64
DataType data_type() const
Get the data type of the tensor.
Definition tensor.cpp:61
#define T
Definition exp.cpp:237
#define TEMPLATE_ALL_2(TYPE_ENUM, DEVICE_ENUM,...)
Definition macros.h:214
Definition tensor.cpp:8
T extract(const container::Tensor &tensor)
Definition tensor_utils.h:272
@ CpuDevice
Memory type is CPU.
__inline__ void _internal_output(std::ostream &os, const T *data, const TensorShape &shape, const int64_t &num_elements)
Outputs tensor data to a given output stream.
Definition tensor_utils.h:205
__inline__ std::string removeTrailingZeros(std::string str)
Removes trailing zeros from a string.
Definition tensor_utils.h:20
__inline__ void _output_wrapper(std::ostream &os, const T data, const int &digit_width, const int &fraction_count)
Output wrapper for a data value with given formatting parameters.
Definition tensor_utils.h:126
__inline__ int _get_digit_places(const T *arr, int size, int &integer_count, int &fraction_count)
Calculates the length of the longest integer and fractional part of an array of numbers.
Definition tensor_utils.h:49
Synchronizes memory between devices.
Definition memory.h:58