
Trainer

Overview

The txeo::Trainer class is an abstract base class that provides the interface for training machine learning models in txeo. It handles training/evaluation data, common training parameters, and the training lifecycle.

Derived classes must implement the predict() and train() methods.
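The split between a public training entry point and protected hooks that derived classes fill in can be sketched without txeo itself. The classes below (TrainerSketch, EchoTrainer) are hypothetical and use std::vector<double> in place of txeo::Tensor<T>. They only illustrate the pattern of fit() delegating to a pure virtual train() and then recording is_trained(); they are not the library's actual implementation.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for txeo::Trainer (not the library's code).
// std::vector<double> replaces txeo::Tensor<T> to keep the sketch self-contained.
class TrainerSketch {
  public:
    virtual ~TrainerSketch() = default;

    // Public entry point: runs the derived class's training loop, then
    // records that the model has been trained.
    void fit(std::size_t epochs) {
        train(epochs);
        is_trained_ = true;
    }

    bool is_trained() const { return is_trained_; }

    // Derived classes must supply the prediction logic.
    virtual std::vector<double> predict(const std::vector<double> &input) = 0;

  protected:
    // Derived classes must supply the training loop.
    virtual void train(std::size_t epochs) = 0;

  private:
    bool is_trained_{false};
};

// Toy derived class: counts the epochs it is asked to run and echoes its
// input back as the "prediction".
class EchoTrainer : public TrainerSketch {
  public:
    std::size_t epochs_run{0};

    std::vector<double> predict(const std::vector<double> &input) override {
        return input;
    }

  protected:
    void train(std::size_t epochs) override {
        for (std::size_t e = 0; e < epochs; ++e)
            ++epochs_run;  // a real model would update its parameters here
    }
};
```

For example, after EchoTrainer t; t.fit(5); the call t.is_trained() returns true and t.epochs_run equals 5. Keeping fit() non-virtual and train() protected means every derived class goes through the same bookkeeping.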


Features

  • Abstract base class with pure virtual methods.
  • Manages training and evaluation datasets.
  • Supports early stopping.
  • Tracks whether the model has been trained.

Template Parameter

  • T: The numeric type used in tensors (e.g., float, double).

Constructors

Trainer(data)

Initializes the trainer with a data table object.

Trainer(const txeo::DataTable<T> &data);

Public Methods

fit(epochs, metric)

Trains the model for a fixed number of epochs.

void fit(size_t epochs, txeo::LossFunc metric);

fit(epochs, metric, patience)

Trains the model with early stopping.

void fit(size_t epochs, txeo::LossFunc metric, size_t patience);
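Early stopping with a patience parameter commonly works like this: training continues while the evaluation loss keeps improving, and stops once it has failed to improve for patience consecutive epochs. The function below is a minimal sketch of that rule under this assumption. epochs_until_stop is a hypothetical helper, and the precomputed eval_losses vector stands in for losses that would be measured during real training.

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// Sketch of patience-based early stopping (assumed semantics, not txeo's code).
// eval_losses[i] stands in for the evaluation loss measured after epoch i.
// Returns the number of epochs actually run.
std::size_t epochs_until_stop(const std::vector<double> &eval_losses,
                              std::size_t patience) {
    double best = std::numeric_limits<double>::infinity();
    std::size_t epochs_without_improvement = 0;
    for (std::size_t epoch = 0; epoch < eval_losses.size(); ++epoch) {
        if (eval_losses[epoch] < best) {
            best = eval_losses[epoch];  // new best loss: reset the counter
            epochs_without_improvement = 0;
        } else if (++epochs_without_improvement >= patience) {
            return epoch + 1;  // patience exhausted: stop early
        }
    }
    return eval_losses.size();  // used the full epoch budget
}
```

With losses {1.0, 0.8, 0.9, 0.85, 0.95, 0.5} and patience 3, training runs 5 epochs and never reaches the improvement in the sixth; with patience 4 it runs all 6.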

fit(epochs, metric, patience, type)

Trains the model with early stopping and feature normalization.

void fit(size_t epochs, txeo::LossFunc metric, size_t patience, txeo::NormalizationType type);
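Feature normalization rescales each input column before training. The sketch below shows two common schemes, min-max scaling to [0, 1] and z-score standardization. The NormKind enum and the normalize_column function are hypothetical; this page does not list which schemes txeo::NormalizationType actually offers.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical normalization options; the real txeo::NormalizationType may differ.
enum class NormKind { MinMax, ZScore };

// Normalizes one feature column in place. Assumes a non-empty column.
void normalize_column(std::vector<double> &col, NormKind kind) {
    if (kind == NormKind::MinMax) {
        const auto [min_it, max_it] = std::minmax_element(col.begin(), col.end());
        const double min_v = *min_it;
        const double range = *max_it - min_v;
        for (double &v : col)
            v = range == 0.0 ? 0.0 : (v - min_v) / range;  // maps to [0, 1]
    } else {
        const double n = static_cast<double>(col.size());
        double mean = 0.0;
        for (double v : col) mean += v;
        mean /= n;
        double var = 0.0;
        for (double v : col) var += (v - mean) * (v - mean);
        const double sd = std::sqrt(var / n);  // population standard deviation
        for (double &v : col)
            v = sd == 0.0 ? 0.0 : (v - mean) / sd;  // zero mean, unit variance
    }
}
```

In practice the scaling statistics are usually computed on the training split and reused for evaluation and prediction inputs, so that all data passes through the same transformation.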

predict(input)

Pure virtual method to generate predictions from a trained model. Must be implemented in derived classes.

virtual txeo::Tensor<T> predict(const txeo::Tensor<T> &input) = 0;

compute_test_loss(metric)

Computes the loss of the trained model on the test data, using the given loss function.

virtual T compute_test_loss(txeo::LossFunc metric) const;
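Conceptually, a test loss compares the model's predictions on the test features against the test targets using the chosen metric. As an illustration, the hypothetical mse function below computes the mean squared error, one typical loss. It operates on std::vector<double> rather than txeo tensors, and which metrics txeo::LossFunc actually provides is not stated on this page.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Mean squared error between predictions and targets: an illustrative
// stand-in for one possible txeo::LossFunc metric.
double mse(const std::vector<double> &pred, const std::vector<double> &target) {
    if (pred.size() != target.size() || pred.empty())
        throw std::invalid_argument("pred and target must be non-empty and equal-sized");
    double sum = 0.0;
    for (std::size_t i = 0; i < pred.size(); ++i) {
        const double diff = pred[i] - target[i];
        sum += diff * diff;  // squared error for sample i
    }
    return sum / static_cast<double>(pred.size());
}
```

For predictions {1, 2, 3} against targets {1, 2, 5}, only the last sample differs (by 2), so the loss is 4 / 3.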

is_trained()

Returns true if the model has been trained.

bool is_trained() const;

data_table()

Returns a const reference to the txeo::DataTable object owned by the trainer.

const txeo::DataTable<T> &data_table() const;

enable_feature_norm()

Enables normalization of feature data (input) using the given normalization type.

void enable_feature_norm(txeo::NormalizationType type);

disable_feature_norm()

Disables normalization of feature data (input).

void disable_feature_norm();

Exceptions

TrainerError

Exception type thrown by Trainer operations.

class TrainerError : public std::runtime_error { /* ... */ };
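Because TrainerError derives from std::runtime_error, callers can catch it either specifically or through std::exception. The sketch below mimics that hierarchy with a hypothetical TrainerErrorSketch class and a hypothetical check_trained guard; which conditions actually make txeo throw TrainerError is not documented on this page.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for txeo::TrainerError: inherits std::runtime_error's
// constructors so it can carry a message.
class TrainerErrorSketch : public std::runtime_error {
  public:
    using std::runtime_error::runtime_error;
};

// Illustrative guard (not documented txeo behavior): refuse to proceed
// when the model has not been trained yet.
void check_trained(bool is_trained) {
    if (!is_trained)
        throw TrainerErrorSketch("model has not been trained");
}
```

A caller that only handles std::exception still receives the message through what(), while code that needs trainer-specific handling can catch the derived type first.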

Example Usage

class MyTrainer : public txeo::Trainer<float> {
  public:
    using txeo::Trainer<float>::Trainer;

    txeo::Tensor<float> predict(const txeo::Tensor<float>& input) override {
        // your prediction logic; it must return a tensor of predictions
        return input; // placeholder only: echoes the input back
    }

  protected:
    void train(size_t epochs, txeo::LossFunc loss_func) override {
        // your training logic
    }
};
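The train() placeholder in the example is where a concrete model runs its optimization loop. As one possible illustration, the hypothetical LinearModel below fits y ≈ w * x + b by full-batch gradient descent on the mean squared error for a fixed number of epochs. It uses plain std::vector<double> instead of txeo types and is not part of the library.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical toy model: y ≈ w * x + b, fitted by full-batch gradient
// descent on the mean squared error. Does not use txeo types.
struct LinearModel {
    double w{0.0};
    double b{0.0};

    void train(const std::vector<double> &x, const std::vector<double> &y,
               std::size_t epochs, double lr) {
        const double n = static_cast<double>(x.size());
        for (std::size_t e = 0; e < epochs; ++e) {
            double grad_w = 0.0, grad_b = 0.0;
            for (std::size_t i = 0; i < x.size(); ++i) {
                const double err = (w * x[i] + b) - y[i];
                grad_w += 2.0 * err * x[i] / n;  // d(MSE)/dw
                grad_b += 2.0 * err / n;         // d(MSE)/db
            }
            w -= lr * grad_w;  // step against the gradient
            b -= lr * grad_b;
        }
    }

    double predict(double x) const { return w * x + b; }
};
```

For data generated by y = 2x + 1 at x = 0, 1, 2, 3, calling train(x, y, 2000, 0.05) drives w and b very close to 2 and 1. The learning rate must be small enough for the updates to converge; too large a value makes them diverge.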

For detailed API references, see the individual method documentation for txeo::Trainer.