
Predictor

txeo's Predictor class handles inference tasks using TensorFlow SavedModels. It loads models, performs predictions, and provides metadata about model inputs, outputs, and devices.

Constructors

Initialization with path to model

explicit Predictor(std::filesystem::path model_path);

Constructs a Predictor object from a TensorFlow SavedModel directory containing a .pb file.

Example (Python Model Freezing):

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# Load the SavedModel and pick its serving signature.
model = tf.saved_model.load("path/to/trained_model")
concrete_func = model.signatures["serving_default"]

# Freeze variables into constants so the graph is self-contained.
frozen_func = convert_variables_to_constants_v2(concrete_func)

# Write the frozen graph as a binary .pb file.
tf.io.write_graph(
    frozen_func.graph.as_graph_def(),
    "path/to/frozen_model",
    "frozen.pb",
    as_text=False
)

Methods

get_input_metadata

Returns input tensor metadata (names and shapes).

const TensorInfo &get_input_metadata() const noexcept;

get_output_metadata

Returns output tensor metadata (names and shapes).

const TensorInfo &get_output_metadata() const noexcept;

get_input_metadata_shape

Returns the shape of the named input tensor, or std::nullopt if no input with that name exists.

std::optional<txeo::TensorShape> get_input_metadata_shape(const std::string &name) const;

get_output_metadata_shape

Returns the shape of the named output tensor, or std::nullopt if no output with that name exists.

std::optional<txeo::TensorShape> get_output_metadata_shape(const std::string &name) const;

get_devices

Returns available compute devices.

std::vector<DeviceInfo> get_devices() const;

predict

Runs inference on a single input tensor and returns the corresponding output tensor.

txeo::Tensor<T> predict(const txeo::Tensor<T> &input) const;

Example:

Tensor<float> input({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f});
auto output = predictor.predict(input);

predict_batch

Performs batch inference with multiple named inputs.

std::vector<txeo::Tensor<T>> predict_batch(const TensorIdent &inputs) const;

Example:

std::vector<std::pair<std::string, txeo::Tensor<float>>> inputs {
    {"image", image_tensor},
    {"metadata", meta_tensor}
};
auto results = predictor.predict_batch(inputs);

enable_xla

Enables or disables XLA (Accelerated Linear Algebra) compilation.

void enable_xla(bool enable);

Note: Prefer enabling XLA before the first inference call.


Structures

DeviceInfo

Member        Description
name          Device name
device_type   Type of device (CPU/GPU)
memory_limit  Memory limit in bytes

Exceptions

PredictorError

Exception thrown when predictor operations fail.

class PredictorError : public std::runtime_error;

For detailed API references, see individual method documentation at txeo::Predictor.