TensorShape
The `TensorShape` class defines the dimensional structure of tensors in the txeo library. It describes a tensor's axes and the size along each one, and serves as the foundation for tensor creation, manipulation, and indexing.
Overview
A `TensorShape` represents the dimensions of a tensor:
- Axes: positions numbered from zero, each associated with one dimension of the tensor.
- Dimensions: the size along each axis.

Examples:
- Scalar (no axes): `TensorShape()`
- Vector: `{3}`
- Matrix: `{3, 4}`
- 3-dimensional tensor: `{2, 3, 4}`
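To make the axis/dimension vocabulary concrete, here is a minimal standalone sketch that models a shape as a plain `std::vector<std::int64_t>` of dimensions. This is an assumption for illustration only; the `Dims` alias and both helpers are hypothetical and do not use txeo. The number of axes is the vector's length, and the element count is the product of the dimensions, which is 1 for a scalar because the product over zero axes is empty.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a shape: one entry per axis, holding that axis's size.
using Dims = std::vector<std::int64_t>;

// The number of axes is simply the number of stored dimensions.
std::size_t number_of_axes(const Dims& dims) { return dims.size(); }

// Total element count: the product of all dimensions (1 for a scalar, whose product is empty).
std::int64_t element_count(const Dims& dims) {
    std::int64_t count = 1;
    for (std::int64_t d : dims) {
        assert(d >= 0 && "dimensions must be non-negative");
        count *= d;
    }
    return count;
}
```

With this model, `{}` has zero axes and one element, `{3, 4}` has two axes and 12 elements, and `{2, 3, 4}` has three axes and 24 elements.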
API Reference

| Method | Description |
|---|---|
| `axes_dims()` | Returns the dimension of each axis |
| `clone()` | Returns a deep copy of the shape |
| `insert_axis(axis, dim)` | Inserts an axis of size `dim` at position `axis` |
| `number_of_axes()` | Synonym for `size()` |
| `push_axis_back(dim)` | Adds an axis of size `dim` at the end |
| `remove_all_axes()` | Removes all axes |
| `remove_axis(axis)` | Removes the specified axis |
| `set_dim(axis, dim)` | Changes the size of the specified axis |
| `size()` | Returns the number of axes |
| `stride()` | Returns strides for efficient indexing |
Creating Tensor Shapes
Constructing from dimensions
```cpp
txeo::TensorShape shape({2, 3, 4}); // Creates a shape with dimensions 2x3x4
```
Creating uniform dimensions
```cpp
txeo::TensorShape shape(3, 5); // Shape with three axes, each of dimension 5
```
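The argument order of the uniform constructor (axis count first, then the shared size) matches `std::vector`'s `(count, value)` constructor. Assuming that analogy, the standalone sketch below spells out, with plain standard-library vectors rather than txeo, which dimensions each constructor call above describes. The braced case at the end is a `std::vector` pitfall only; the sketch says nothing about how `txeo::TensorShape` itself treats braced arguments.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Same dimensions as txeo::TensorShape({2, 3, 4}): one entry per axis.
const std::vector<std::int64_t> listed_dims{2, 3, 4};

// Same dimensions as txeo::TensorShape(3, 5): three axes, each of size 5.
// Axis count comes first, then the shared size, as in std::vector's (count, value) constructor.
const std::vector<std::int64_t> uniform_dims(3, 5);

// Pitfall of the std::vector analogy: braces select the initializer-list
// constructor, giving two axes {3, 5} rather than three axes of size 5.
const std::vector<std::int64_t> braced_dims{3, 5};
```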
Common Operations
Accessing Dimensions
`axes_dims()` returns the size of every axis, in axis order:
```cpp
std::vector<int64_t> dims = shape.axes_dims(); // e.g. {2, 3, 4} for a 2x3x4 shape
```
Comparing shapes
Shapes can be compared with `==`:
```cpp
if (shape1 == shape2) {
    std::cout << "Shapes are equal.\n";
}
```
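Shape equality is usually understood as the same number of axes with equal sizes on every axis, compared position by position. Assuming `operator==` follows that rule, two shapes with the same element count can still differ. The standalone sketch below makes the rule explicit on plain dimension vectors; `same_shape` is a hypothetical helper and does not call txeo.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Dims = std::vector<std::int64_t>;

// Hypothetical equality rule: same axis count and same size on each axis, in order.
bool same_shape(const Dims& a, const Dims& b) {
    if (a.size() != b.size()) return false;
    for (std::size_t i = 0; i < a.size(); ++i) {
        if (a[i] != b[i]) return false;
    }
    return true;
}
```

This is exactly what `std::vector`'s own `operator==` does; the loop is spelled out only to show that `{3, 4}` and `{4, 3}` compare unequal even though both hold 12 elements.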
Manipulating Axes
Inserting an Axis
```cpp
shape.insert_axis(1, 5); // Inserts a new axis at position 1 with dimension 5
```
Removing an Axis
```cpp
shape.remove_axis(2); // Removes the axis at position 2
```
Changing a Dimension
```cpp
shape.set_dim(0, 10); // Sets the size of axis 0 to 10
```
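The three calls above can be traced on the 2x3x4 shape from the constructor example. The sketch below models them on a `std::vector<std::int64_t>` with hypothetical helpers `insert_axis`, `remove_axis`, and `set_dim` that throw `std::out_of_range` for bad positions, standing in for `txeo::TensorShapeError`. Treating position `size()` as a valid insertion point (an append) is an assumption of the sketch, not documented txeo behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Dims = std::vector<std::int64_t>;

// Inserts a new axis of size `dim` before position `axis` (axis == size() appends).
void insert_axis(Dims& dims, std::size_t axis, std::int64_t dim) {
    if (axis > dims.size()) throw std::out_of_range("insert_axis: axis out of range");
    dims.insert(dims.begin() + static_cast<std::ptrdiff_t>(axis), dim);
}

// Removes the axis at `axis`; later axes shift one position to the left.
void remove_axis(Dims& dims, std::size_t axis) {
    if (axis >= dims.size()) throw std::out_of_range("remove_axis: axis out of range");
    dims.erase(dims.begin() + static_cast<std::ptrdiff_t>(axis));
}

// Changes the size of an existing axis.
void set_dim(Dims& dims, std::size_t axis, std::int64_t dim) {
    dims.at(axis) = dim;  // at() throws std::out_of_range for a bad axis
}
```

Starting from `{2, 3, 4}`, the sequence gives `{2, 5, 3, 4}`, then `{2, 5, 4}`, then `{10, 5, 4}`.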
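Removing All Axes
Axes can be removed one at a time. Each removal shifts the remaining axes one position to the left, so to empty a two-axis shape, remove position 0 twice:
```cpp
shape.remove_axis(0); // Removes axis 0; the former axis 1 becomes axis 0
shape.remove_axis(0); // Removes the remaining axis
```
or simply:
```cpp
shape.remove_all_axes(); // Leaves a shape with no axes, i.e. a scalar shape
```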
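Because the remaining axes shift left after each removal, the position passed to `remove_axis` always refers to the current numbering, not the original one. The self-contained sketch below demonstrates this on a `std::vector<std::int64_t>`; `remove_axis`, `remove_all_axes`, and `remove_0_then_1_throws` are hypothetical helpers, and `std::out_of_range` stands in for `txeo::TensorShapeError`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Dims = std::vector<std::int64_t>;

// Hypothetical helper: removes one axis; later axes move one position left.
void remove_axis(Dims& dims, std::size_t axis) {
    if (axis >= dims.size()) throw std::out_of_range("remove_axis: axis out of range");
    dims.erase(dims.begin() + static_cast<std::ptrdiff_t>(axis));
}

// Removes every axis one at a time by always deleting the current first axis.
void remove_all_axes(Dims& dims) {
    while (!dims.empty()) remove_axis(dims, 0);
}

// Returns true if removing axis 0 and then axis 1 throws for the given shape.
bool remove_0_then_1_throws(Dims dims) {
    try {
        remove_axis(dims, 0);
        remove_axis(dims, 1);
    } catch (const std::out_of_range&) {
        return true;
    }
    return false;
}
```

On `{10, 5}`, removing position 0 leaves `{5}`, so a following request for position 1 is out of range. On `{10, 5, 4}` the same two calls succeed but leave `{5}`, so the shape is still not empty.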
Stride
The stride of an axis is the number of elements to skip in the tensor's flat storage to move one position along that axis:
```cpp
auto strides = shape.stride();
for (auto s : strides)
    std::cout << s << ' ';
```
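For a row-major (C-order) layout, the stride of an axis is the product of the sizes of all axes to its right, so the last axis has stride 1. Assuming `stride()` follows that convention (the txeo reference is the authority on its exact return type and contents), the standalone sketch below computes strides for `{2, 3, 4}` and uses them to turn a multi-index into a flat offset. Both helpers are hypothetical and independent of txeo.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Dims = std::vector<std::int64_t>;

// Row-major strides: stride[i] = product of dims[i+1..n-1]; the last axis has stride 1.
std::vector<std::int64_t> row_major_strides(const Dims& dims) {
    std::vector<std::int64_t> strides(dims.size(), 1);
    // Walk from the second-to-last axis back to axis 0.
    for (std::size_t i = dims.size(); i-- > 1;) {
        strides[i - 1] = strides[i] * dims[i];
    }
    return strides;
}

// Flat offset of a multi-index: the sum of index[i] * stride[i].
std::int64_t flat_offset(const std::vector<std::int64_t>& index,
                         const std::vector<std::int64_t>& strides) {
    assert(index.size() == strides.size());
    std::int64_t offset = 0;
    for (std::size_t i = 0; i < index.size(); ++i) offset += index[i] * strides[i];
    return offset;
}
```

For `{2, 3, 4}` the strides are `{12, 4, 1}`, so the element at index `(1, 2, 3)` sits at flat offset 1·12 + 2·4 + 3·1 = 23, the last of the 24 elements.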
Examples
Checking Shape Equality
```cpp
#include <iostream>
#include "txeo/TensorShape.h"

int main() {
    txeo::TensorShape shape1({3, 4});
    txeo::TensorShape shape2({3, 4});
    if (shape1 == shape2) {
        std::cout << "Shapes match!" << std::endl;
    }
    return 0;
}
```
Modifying a TensorShape
```cpp
#include <iostream>
#include "txeo/TensorShape.h"

int main() {
    txeo::TensorShape shape({2, 3});
    shape.push_axis_back(4); // shape is now {2, 3, 4}
    shape.set_dim(1, 5);     // Updates axis 1 from 3 to 5: shape is now {2, 5, 4}
    std::cout << "Updated Shape: " << shape << std::endl;
    return 0;
}
```
Exception Handling
All invalid operations throw a `txeo::TensorShapeError`:
```cpp
try {
    shape.insert_axis(10, 2); // Invalid for any shape with fewer than 10 axes
} catch (const txeo::TensorShapeError &e) {
    std::cerr << "Error: " << e.what() << '\n';
}
```
For details on each method, see the `txeo::TensorShape` API documentation.