ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
container::TensorShape Class Reference

A class for representing the shape of a tensor.

#include <tensor_shape.h>

Public Member Functions

 TensorShape ()
 Default constructor.
 
 TensorShape (std::initializer_list< int64_t > dims)
 Constructor with an initializer list of integers.
 
 TensorShape (const std::vector< int64_t > &dims)
 Constructor with a vector of integers.
 
 TensorShape (const TensorShape &other)
 Copy constructor.
 
int64_t dim_size (int dim) const
 Get the size of a dimension in the tensor.
 
const std::vector< int64_t > & dims () const
 Get all dimension sizes in the tensor.
 
const std::vector< int64_t > & strides () const
 Get the strides of the tensor.
 
unsigned int ndim () const
 Get the number of dimensions (rank) of the tensor.
 
void set_dim_size (int dim, int64_t size)
 Modify the size of a dimension in the tensor.
 
void add_dim (int64_t size)
 Add a new dimension to the tensor.
 
void remove_dim (int dim)
 Remove a dimension from the tensor.
 
int64_t NumElements () const
 Returns the total number of elements in the shape.
 
bool operator== (const TensorShape &other) const
 Overload the == operator to compare two TensorShape objects.
 
bool operator!= (const TensorShape &other) const
 Overload the != operator to compare two TensorShape objects.
 

Private Member Functions

std::vector< int64_t > get_strides_ (const std::vector< int64_t > &dim)
 Compute the strides of the tensor.
 

Private Attributes

std::vector< int64_t > dims_ = {}
 
std::vector< int64_t > strides_ = {}
 

Detailed Description

A class for representing the shape of a tensor.

Constructor & Destructor Documentation

◆ TensorShape() [1/4]

container::TensorShape::TensorShape ( )

Default constructor.

◆ TensorShape() [2/4]

container::TensorShape::TensorShape ( std::initializer_list< int64_t >  dims)

Constructor with an initializer list of integers.

Parameters
dims    An initializer list of integers representing the dimensions of the tensor.

◆ TensorShape() [3/4]

container::TensorShape::TensorShape ( const std::vector< int64_t > &  dims)

Constructor with a vector of integers.

Parameters
dims    A vector of integers representing the dimensions of the tensor.

◆ TensorShape() [4/4]

container::TensorShape::TensorShape ( const TensorShape &  other)

Copy constructor.

Parameters
other    The TensorShape object to be copied.

Member Function Documentation

◆ add_dim()

void container::TensorShape::add_dim ( int64_t  size)

Add a new dimension to the tensor.

Parameters
size    The size of the new dimension.

◆ dim_size()

int64_t container::TensorShape::dim_size ( int  dim) const

Get the size of a dimension in the tensor.

Parameters
dim    The index of the dimension.
Returns
The size of the specified dimension.

◆ dims()

const std::vector< int64_t > & container::TensorShape::dims ( ) const

Get all dimension sizes in the tensor.

Returns
A const reference to the vector of dimension sizes.

◆ get_strides_()

std::vector< int64_t > container::TensorShape::get_strides_ ( const std::vector< int64_t > &  dim)
private

Compute the strides of the tensor.
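The source does not say how the strides are derived from the dimension sizes. The following is a minimal sketch, assuming a contiguous row-major (C-order) layout with strides counted in elements. `row_major_strides` is a hypothetical free function used for illustration, not the ABACUS implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of ABACUS): strides, in elements, for a
// contiguous row-major layout. The last dimension has stride 1; every
// earlier stride is the product of all later dimension sizes.
std::vector<int64_t> row_major_strides(const std::vector<int64_t>& dims) {
    std::vector<int64_t> strides(dims.size(), 1);
    // Walk from the last dimension towards the first.
    for (std::size_t i = dims.size(); i-- > 1;) {
        assert(dims[i] >= 0);  // assumption: sizes are non-negative
        strides[i - 1] = strides[i] * dims[i];
    }
    return strides;
}
```

Under this assumption, a shape with dimensions {2, 3, 4} has strides {12, 4, 1}.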


◆ ndim()

unsigned int container::TensorShape::ndim ( ) const

Get the number of dimensions (rank) of the tensor.

Returns
The number of dimensions in the tensor.

◆ NumElements()

int64_t container::TensorShape::NumElements ( ) const

Returns the total number of elements in the shape.

Returns
The total number of elements, as an int64_t.
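The element count of a shape is the product of its dimension sizes. The following is a minimal sketch on a plain vector. `num_elements` is a hypothetical free function, and returning 1 for an empty dimension list (a rank-0 scalar) is an assumption that the source does not confirm.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of ABACUS): multiply all dimension sizes.
// An empty list yields 1, and any zero-sized dimension yields 0.
int64_t num_elements(const std::vector<int64_t>& dims) {
    int64_t count = 1;
    for (int64_t d : dims) {
        assert(d >= 0);  // assumption: sizes are non-negative
        count *= d;
    }
    return count;
}
```

For example, dimensions {2, 3, 4} give 24 elements.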

◆ operator!=()

bool container::TensorShape::operator!= ( const TensorShape &  other) const

Overload the != operator to compare two TensorShape objects.

Parameters
other    The other TensorShape object to be compared.
Returns
True if the two objects have different dimensions, false otherwise.

◆ operator==()

bool container::TensorShape::operator== ( const TensorShape &  other) const

Overload the == operator to compare two TensorShape objects.

Parameters
other    The other TensorShape object to be compared.
Returns
True if the two objects have the same dimensions, false otherwise.
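Having the same dimensions is a stricter condition than having the same element count. The following is a minimal sketch on plain vectors, assuming the comparison is element-wise over the dimension sizes. `same_shape` is a hypothetical stand-in for `operator==`, not the ABACUS implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for operator==: std::vector's own comparison
// checks the rank first and then every dimension size in order.
bool same_shape(const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
    return a == b;
}

// Usage: two shapes can hold the same number of elements and still differ.
inline void same_shape_examples() {
    assert(same_shape({2, 3}, {2, 3}));
    assert(!same_shape({2, 3}, {3, 2}));  // both hold 6 elements
    assert(!same_shape({6}, {2, 3}));     // different rank
}
```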

◆ remove_dim()

void container::TensorShape::remove_dim ( int  dim)

Remove a dimension from the tensor.

Parameters
dim    The index of the dimension to be removed.

◆ set_dim_size()

void container::TensorShape::set_dim_size ( int  dim, int64_t  size )

Modify the size of a dimension in the tensor.

Parameters
dim    The index of the dimension to be modified.
size    The new size of the dimension.
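The following is a minimal sketch of how the three mutators (add_dim, remove_dim, set_dim_size) could act on the dimension list. It assumes that add_dim appends at the end and that indices are zero-based and must be in range. The source confirms neither assumption, and the `sketch_*` free functions are hypothetical stand-ins operating on a plain vector.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins (not part of ABACUS) for the documented mutators.

// Assumption: the new dimension is appended as the last (innermost) one.
void sketch_add_dim(std::vector<int64_t>& dims, int64_t size) {
    dims.push_back(size);
}

// Assumption: dim is a zero-based index that must be in range.
void sketch_remove_dim(std::vector<int64_t>& dims, int dim) {
    assert(dim >= 0 && static_cast<std::size_t>(dim) < dims.size());
    dims.erase(dims.begin() + dim);
}

// Assumption: dim is a zero-based index that must be in range.
void sketch_set_dim_size(std::vector<int64_t>& dims, int dim, int64_t size) {
    assert(dim >= 0 && static_cast<std::size_t>(dim) < dims.size());
    dims[static_cast<std::size_t>(dim)] = size;
}
```

Starting from {2, 3}, appending 4 gives {2, 3, 4}; setting dimension 1 to 5 gives {2, 5, 4}; removing dimension 0 leaves {5, 4}.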

◆ strides()

const std::vector< int64_t > & container::TensorShape::strides ( ) const

Get the strides of the tensor.

Returns
A const reference to the vector of strides.

Member Data Documentation

◆ dims_

std::vector<int64_t> container::TensorShape::dims_ = {}
private

◆ strides_

std::vector<int64_t> container::TensorShape::strides_ = {}
private
