ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
tensor_shape.h
Go to the documentation of this file.
1#ifndef ATEN_CORE_TENSOR_SHAPE_H_
2#define ATEN_CORE_TENSOR_SHAPE_H_
3
4#include <vector>
5#include <iostream>
6#include <initializer_list>
7
8namespace container {
9
14public:
19
24 TensorShape(std::initializer_list<int64_t> dims);
25
30 TensorShape(const std::vector<int64_t>& dims);
31
36 TensorShape(const TensorShape& other);
37
43 int64_t dim_size(int dim) const;
44
49 const std::vector<int64_t>& dims() const;
50
51 const std::vector<int64_t>& strides() const;
52
57 unsigned int ndim() const;
58
64 void set_dim_size(int dim, int64_t size);
65
70 void add_dim(int64_t size);
71
76 void remove_dim(int dim);
77
83 int64_t NumElements() const;
84
90 bool operator==(const TensorShape& other) const;
91
97 bool operator!=(const TensorShape& other) const;
98
99private:
100 std::vector<int64_t> dims_ = {}; // Save dimension sizes of the tensor
101 // Note: strides are not always equal to the dimension sizes.
102 // The strides specify the number of elements to step in each dimension when traversing a tensor.
103 // There could be sparse regions in the tensor, in which case the strides will be larger than the dimension sizes.
104 // For example, given a 2D tensor with shape [3, 4],
105 // the strides could be [6, 1] if the actual data is stored in a 1D array with size 18 [3, 6].
106 // The strides could also be [12, 3] if the actual data is stored in a 1D array with size 36 [3, 12].
107 // strides can only be modified by the TensorMap object.
108 std::vector<int64_t> strides_ = {}; // Save dimension strides of the tensor
109
113 std::vector<int64_t> get_strides_(const std::vector<int64_t>& dim);
114};
115
122std::ostream& operator<<(std::ostream& os, const TensorShape& shape);
123
124} // container
125
126#endif // ATEN_CORE_TENSOR_SHAPE_H_
A class for representing the shape of a tensor.
Definition tensor_shape.h:13
int64_t dim_size(int dim) const
Get the size of a dimension in the tensor.
Definition tensor_shape.cpp:31
void set_dim_size(int dim, int64_t size)
Modify the size of a dimension in the tensor.
Definition tensor_shape.cpp:60
std::vector< int64_t > strides_
Definition tensor_shape.h:108
std::vector< int64_t > get_strides_(const std::vector< int64_t > &dim)
Compute the strides of the tensor.
Definition tensor_shape.cpp:90
void add_dim(int64_t size)
Add a new dimension to the tensor.
Definition tensor_shape.cpp:66
const std::vector< int64_t > & dims() const
Get all dimension sizes in the tensor.
Definition tensor_shape.cpp:36
bool operator==(const TensorShape &other) const
Overload the == operator to compare two TensorShape objects.
Definition tensor_shape.cpp:81
void remove_dim(int dim)
Remove a dimension from the tensor.
Definition tensor_shape.cpp:72
int64_t NumElements() const
Returns the total number of elements in the shape.
Definition tensor_shape.cpp:51
TensorShape()
Default constructor.
Definition tensor_shape.cpp:16
unsigned int ndim() const
Get the ndim of the tensor.
Definition tensor_shape.cpp:46
std::vector< int64_t > dims_
Definition tensor_shape.h:100
bool operator!=(const TensorShape &other) const
Overload the != operator to compare two TensorShape objects.
Definition tensor_shape.cpp:86
const std::vector< int64_t > & strides() const
Definition tensor_shape.cpp:41
Definition tensor.cpp:8
std::ostream & operator<<(std::ostream &os, const Tensor &tensor)
Overloaded operator<< for the Tensor class.
Definition tensor.cpp:329