1#ifndef ATEN_CORE_TENSOR_SHAPE_H_
2#define ATEN_CORE_TENSOR_SHAPE_H_
6#include <initializer_list>
49 const std::vector<int64_t>&
dims()
const;
51 const std::vector<int64_t>&
strides()
const;
57 unsigned int ndim()
const;
113 std::vector<int64_t>
get_strides_(
const std::vector<int64_t>& dim);
122std::ostream&
operator<<(std::ostream& os,
const TensorShape& shape);
A class for representing the shape of a tensor.
Definition tensor_shape.h:13
int64_t dim_size(int dim) const
Get the size of a dimension in the tensor.
Definition tensor_shape.cpp:31
void set_dim_size(int dim, int64_t size)
Modify the size of a dimension in the tensor.
Definition tensor_shape.cpp:60
std::vector< int64_t > strides_
Definition tensor_shape.h:108
std::vector< int64_t > get_strides_(const std::vector< int64_t > &dim)
Compute the strides corresponding to the given dimension sizes.
Definition tensor_shape.cpp:90
void add_dim(int64_t size)
Add a new dimension to the tensor.
Definition tensor_shape.cpp:66
const std::vector< int64_t > & dims() const
Get all dimension sizes in the tensor.
Definition tensor_shape.cpp:36
bool operator==(const TensorShape &other) const
Overload the == operator to compare two TensorShape objects.
Definition tensor_shape.cpp:81
void remove_dim(int dim)
Remove a dimension from the tensor.
Definition tensor_shape.cpp:72
int64_t NumElements() const
Returns the total number of elements in the shape.
Definition tensor_shape.cpp:51
TensorShape()
Default constructor.
Definition tensor_shape.cpp:16
unsigned int ndim() const
Get the number of dimensions (ndim) of the tensor.
Definition tensor_shape.cpp:46
std::vector< int64_t > dims_
Definition tensor_shape.h:100
bool operator!=(const TensorShape &other) const
Overload the != operator to compare two TensorShape objects.
Definition tensor_shape.cpp:86
const std::vector< int64_t > & strides() const
Get the strides of all dimensions in the tensor.
Definition tensor_shape.cpp:41
std::ostream & operator<<(std::ostream &os, const Tensor &tensor)
Overloaded operator<< for the Tensor class.
Definition tensor.cpp:329