Skip to content

OlsGDTrainer

Overview

txeo::OlsGDTrainer is a concrete implementation of the txeo::Trainer<T> abstract class. It performs Ordinary Least Squares (OLS) linear regression using Gradient Descent.

Features

  • Implements gradient descent for linear regression
  • Supports learning rate tuning, convergence tolerance, and early stopping
  • Optionally uses the Barzilai-Borwein method for adaptive learning rate
  • Access to learned weight/bias matrix

Template Parameter

  • T: Floating-point type (e.g., float, double)

Example Usage

// Create training data (y = 2x + 1)
txeo::Matrix<double> X({{1.0}, {2.0}, {3.0}});
txeo::Matrix<double> y({{3.0}, {5.0}, {7.0}});

OlsGDTrainer<double> trainer(txeo::DataTable<double>(X, y));
trainer.set_tolerance(1e-5);
trainer.fit(1000, LossFunc::MSE, 10);

if (trainer.is_converged()) {
    auto weights = trainer.weight_bias();
    std::cout << "Model: y = " << weights(0,0) << "x + " << weights(1,0) << std::endl;

    txeo::Matrix<double> test_input(1,1,{4.0});
    auto prediction = trainer.predict(test_input);
    std::cout << "Prediction for x=4: " << prediction(0,0) << std::endl;
}

Constructors

Trainer(const txeo::DataTable<T> &data)

Creates a trainer using a data table object.

txeo::Trainer(const txeo::DataTable<T> &data);

Public Methods

predict(input)

Performs prediction on new input data.

txeo::Tensor<T> predict(const txeo::Tensor<T>& input);

learning_rate()

Returns the current learning rate.

T learning_rate() const;

set_learning_rate(value)

Sets the learning rate used in training.

void set_learning_rate(T value);

enable_variable_lr() / disable_variable_lr()

Toggles the use of the Barzilai-Borwein adaptive learning rate.

void enable_variable_lr();
void disable_variable_lr();

weight_bias()

Returns the model weight-bias matrix.

const txeo::Matrix<T>& weight_bias() const;

tolerance() / set_tolerance(value)

Gets or sets the convergence tolerance.

T tolerance() const;
void set_tolerance(const T& value);

is_converged()

Checks if convergence was reached during training.

bool is_converged() const;

min_loss()

Returns the minimum loss encountered during training.

T min_loss() const;

Exceptions

OlsGDTrainerError

Exception type used for runtime errors within the trainer.

class OlsGDTrainerError : public std::runtime_error;

Inheritance

  • Inherits from: txeo::Trainer<T>
  • Implements:
  • predict()
  • train()

For detailed API references, see individual method documentation at txeo::OlsGDTrainer.