TensorShape
The TensorShape class defines the dimensional structure of tensors in the txeo library. It describes the dimensions and axes of tensors and serves as the foundation for tensor creation, manipulation, and indexing.
Overview
A TensorShape represents the dimensions of a tensor:
- Axes: positions indexed from zero, each associated with one tensor dimension.
- Dimensions: the size along each axis.
Examples:
- Scalar (no axes): TensorShape()
- Vector: {3}
- Matrix: {3, 4}
- 3-dimensional tensor: {2, 3, 4}
API Reference

| Method | Description |
|---|---|
| size() | Returns the number of axes |
| number_of_axes() | Synonym for size() |
| axes_dims() | Returns the dimension of each axis |
| stride() | Returns the strides used for efficient indexing |
| set_dim(axis, dim) | Changes the size of the specified axis |
| insert_axis(axis, dim) | Inserts an axis of size dim at position axis |
| push_axis_back(dim) | Appends an axis of size dim at the end |
| remove_axis(axis) | Removes the specified axis |
| remove_all_axes() | Removes all axes |
| clone() | Returns a deep copy of the shape |
Creating Tensor Shapes
Constructing from dimensions
```cpp
txeo::TensorShape shape({2, 3, 4}); // Creates a shape with dimensions 2x3x4
```
Creating uniform dimensions
```cpp
txeo::TensorShape shape(3, 5); // Shape with three axes, each of dimension 5
```
Common Operations
Accessing Dimensions
axes_dims() returns the dimension of every axis:
```cpp
std::vector<int64_t> dims = shape.axes_dims();
```
Comparing shapes
```cpp
if (shape1 == shape2) {
    std::cout << "Shapes are equal.";
}
```
Manipulating Axes
Inserting an Axis
```cpp
shape.insert_axis(1, 5); // Inserts a new axis at position 1 with dimension 5
```
Removing an Axis
```cpp
shape.remove_axis(2); // Removes the axis at position 2
```
Changing a Dimension
```cpp
shape.set_dim(0, 10); // Sets the dimension of the first axis to 10
```
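Removing All Axes
Axes can be removed one at a time. Each removal shifts the following axes down by one position, so for a shape with two axes both calls target position 0:
```cpp
shape.remove_axis(0); // Removes axis 0; the former axis 1 moves to position 0
shape.remove_axis(0); // Removes the remaining axis
```
or simply:
```cpp
shape.remove_all_axes();
```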
Stride
Strides give, for each axis, the step in the flattened storage needed to advance one position along that axis:
```cpp
auto strides = shape.stride();
for (auto s : strides)
    std::cout << s << ' ';
```
Examples
Checking Shape Equality
```cpp
#include <iostream>
#include "txeo/TensorShape.h"

int main() {
    txeo::TensorShape shape1({3, 4});
    txeo::TensorShape shape2({3, 4});
    if (shape1 == shape2) {
        std::cout << "Shapes match!" << std::endl;
    }
}
```
Modifying a TensorShape
```cpp
#include <iostream>
#include "txeo/TensorShape.h"

int main() {
    txeo::TensorShape shape({2, 3});
    shape.push_axis_back(4); // shape is now (2,3,4)
    shape.set_dim(1, 5);     // dimension of axis 1 changes from 3 to 5: (2,5,4)
    std::cout << "Updated Shape: " << shape << std::endl;
}
```
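Exception Handling
All invalid operations, such as inserting an axis at an out-of-range position, throw a txeo::TensorShapeError:
```cpp
try {
    shape.insert_axis(10, 2); // invalid: position 10 is out of range (assuming shape has fewer than 10 axes)
} catch (const txeo::TensorShapeError &e) {
    std::cerr << "Error: " << e.what();
}
```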
For the detailed API reference, see the documentation of the individual methods of txeo::TensorShape.