1#ifndef ATEN_CORE_TENSOR_H_
2#define ATEN_CORE_TENSOR_H_
107 template <
typename T>
112 this->data<T>(), values.begin(), this->NumElements()))
134 const TensorShape&
shape()
const;
163 template <
typename T>
172 std::cerr <<
"Tensor data type does not match requested type." << std::endl;
199 return sizeof(float);
201 return sizeof(int32_t);
203 return sizeof(int64_t);
205 return sizeof(double);
207 return sizeof(std::complex<float>);
209 return sizeof(std::complex<double>);
211 std::cerr <<
"Unsupported data type!" << std::endl;
230 template <
typename DEVICE>
242 output.data<T_>(), this->data<T_>(), this->NumElements()))
268 template <
typename DEVICE,
typename T>
270 if (num_elements == -1) {
274 "The number of elements of the input data must match the number of elements of the tensor.")
278 this->data<T_>(),
data, num_elements))
288 template <
typename T>
298 output.data<
T>(), this->data<T_>(), this->NumElements()))
336 Tensor slice(
const std::vector<int>&
start,
const std::vector<int>& size)
const;
373 template <
typename T,
typename... Indices>
376 throw std::invalid_argument(
"Incorrect number of indices.");
383 return *
reinterpret_cast<T*
>(data<T>() + linearIndex);
397 template <
typename T>
400 throw std::invalid_argument(
"Invalid call, inner_most_ptr only support tensor rank <= 2!");
403 throw std::invalid_argument(
"Invalid index, index of the inner-most must less than the inner-most shape size!");
406 return data<T>() + index;
499 template <
typename T,
size_t N,
typename index_t =
int64_t>
503 "Accessor is used to access the data of a tensor with rank > 0, for scalars use *data<T>()");
506 "The rank of the tensor must match the rank of the accessor.")
510 template<
typename T,
size_t N,
typename index_t =
int>
532 explicit operator bool()
const {
576 template <
typename... Indices>
579 size_t linearIndex = 0;
580 size_t indexArray[] = {
static_cast<size_t>(indices)... };
582 for (
int ii =
static_cast<int>(
shape_.
ndim()) - 1; ii >= 0; --ii) {
583 linearIndex += indexArray[ii] * stride;
621std::ostream&
operator<<(std::ostream& os,
const Tensor& tensor);
An abstract base class for memory allocators.
Definition allocator.h:17
bool unref() const
Decreases the reference count by one.
Definition refcount.cpp:13
void ref() const
Increases the reference count by one.
Definition refcount.cpp:9
Definition tensor_accessor.h:68
Interface to access the raw ref-counted data buffer.
Definition tensor_buffer.h:13
T * base() const
Reinterpret the buffer as an array of type T.
Definition tensor_buffer.h:94
A class for representing the shape of a tensor.
Definition tensor_shape.h:13
int64_t dim_size(int dim) const
Get the size of a dimension in the tensor.
Definition tensor_shape.cpp:31
const std::vector< int64_t > & dims() const
Get all dimension sizes in the tensor.
Definition tensor_shape.cpp:36
int64_t NumElements() const
Returns the total number of elements in the shape.
Definition tensor_shape.cpp:51
unsigned int ndim() const
Get the ndim of the tensor.
Definition tensor_shape.cpp:46
A multi-dimensional array of elements of a single data type.
Definition tensor.h:32
Tensor()
Creates a 1-dimensional, 0-element float tensor.
Definition tensor.cpp:10
Tensor cast() const
Method to transform data from a given tensor object to the output tensor with a given data type.
Definition tensor.h:289
Tensor to_device() const
Method to transform data from a given tensor object to the output tensor with a given device type.
Definition tensor.h:231
void * data() const
Get a pointer to the data buffer of the tensor.
Definition tensor.cpp:73
static base::core::Allocator * GetAllocator(DeviceType device)
Get the Allocator object according to the given device type.
Definition tensor.cpp:79
size_t calculateLinearIndex(Indices... indices) const
Calculates the linear index corresponding to the given indices.
Definition tensor.h:577
T & get_value(Indices... indices) const
Get the element at the specified indices.
Definition tensor.h:374
T * inner_most_ptr(const int &index) const
Get the pointer to the specified row.
Definition tensor.h:398
int64_t NumElements() const
Get the total number of elements in the tensor.
Definition tensor.cpp:70
void sync(const Tensor &rhs)
Synchronize the current Tensor with another Tensor.
Definition tensor.cpp:296
void zero()
Set all elements in current tensor object to zero.
Definition tensor.cpp:97
bool AllocateFrom(const Tensor &other, const TensorShape &shape)
Copies data from another Tensor with memory allocation and specified shape.
Definition tensor.cpp:286
TensorAccessor< T, N, index_t > accessor() &&=delete
DeviceType device_type() const
Get the device type of the tensor.
Definition tensor.cpp:64
void reshape(TensorShape shape)
Reshape the tensor to the given shape.
Definition tensor.cpp:103
bool operator==(const Tensor &other) const
Equality comparison operator for tensors.
Definition tensor.cpp:253
void resize(const TensorShape &new_shape)
Resize the tensor to the new shape.
Definition tensor.cpp:207
TensorBuffer * buffer_
The TensorBuffer object that holds the data of the tensor.
Definition tensor.h:562
TensorAccessor< T, N, index_t > accessor() const &
Accessor function for a multi-dimensional tensor.
Definition tensor.h:500
const TensorBuffer & buffer() const
Get the TensorBuffer object that holds the data of the tensor.
Definition tensor.cpp:76
DataType data_type() const
Get the data type of the tensor.
Definition tensor.cpp:61
Tensor slice(const std::vector< int > &start, const std::vector< int > &size) const
Return a new Tensor slice starting at the specified indices with the given size.
Definition tensor.cpp:147
static size_t SizeOfType(DataType data_type)
Returns the size of a single element for a given data type.
Definition tensor.h:196
void set_value(T value)
Definition tensor.h:537
Tensor shaped(const TensorShape &shape) const
Return a Tensor with the given shape.
Definition tensor.cpp:139
void copy_from_device(const T *data, int64_t num_elements=-1)
Copies data from a given device to the current tensor object.
Definition tensor.h:269
Tensor(std::initializer_list< T > values, DeviceType device=DeviceType::CpuDevice)
Constructor for the Tensor class using an initializer list of values.
Definition tensor.h:108
DataType data_type_
The data type of the tensor.
Definition tensor.h:547
Tensor & operator=(const Tensor &other)
Assignment operator overload for the Tensor class.
Definition tensor.cpp:220
Tensor operator[](const int &index) const
Access a sub-Tensor based on an index.
Definition tensor.cpp:312
~Tensor()
Definition tensor.cpp:55
TensorShape shape_
The shape of the tensor.
Definition tensor.h:557
DeviceType device_
The device type of the tensor.
Definition tensor.h:552
bool CopyFrom(const Tensor &other)
Copy the data from another tensor into this tensor.
Definition tensor.cpp:273
T * data() const
Get a typed pointer to the data buffer of the tensor.
Definition tensor.h:164
const TensorShape & shape() const
Get the shape of the tensor.
Definition tensor.cpp:67
void CopyFromInternal(const Tensor &other, const TensorShape &shape)
Definition tensor.h:591
#define N
Definition exp.cpp:24
#define T
Definition exp.cpp:237
#define REQUIRES_OK(expr,...)
Definition macros.h:60
#define TEMPLATE_CZ_2(TYPE_ENUM, DEVICE_ENUM,...)
Definition macros.h:227
#define TEMPLATE_ALL_2(TYPE_ENUM, DEVICE_ENUM,...)
Definition macros.h:214
DataType
Enumeration of data types for tensors. The DataType enum lists the supported data types for tensors....
Definition tensor_types.h:50
@ DT_COMPLEX
32-bit complex
@ DT_INT64
64-bit integer
@ DT_FLOAT
Single-precision floating point.
@ DT_INT
32-bit integer
@ DT_DOUBLE
Double-precision floating point.
DeviceType
The type of memory used by an allocator.
Definition tensor_types.h:73
@ CpuDevice
Memory type is CPU.
std::ostream & operator<<(std::ostream &os, const Tensor &tensor)
Overloaded operator<< for the Tensor class.
Definition tensor.cpp:329
Template struct for mapping a DataType to its corresponding enum value.
Definition tensor_types.h:194
Template struct for mapping a Device Type to its corresponding enum value.
Definition tensor_types.h:158
Casts memory between devices.
Definition memory.h:107
A functor to set memory to a constant value.
Definition memory.h:37
Synchronizes memory between devices.
Definition memory.h:58
This file contains the definition of the DataType enum class.
iclock::time_point start
Definition test_partition.cpp:22