Represents a k-fold cross-validation utility.
#include <kfold.hpp>
Definition at line 48 of file kfold.hpp.
◆ Kfold()
gpmp::ml::Kfold::Kfold(int k)
Constructor for the Kfold class.
- Parameters
  - k: The number of folds for cross-validation.
Definition at line 39 of file kfold.cpp.
◆ ~Kfold()
gpmp::ml::Kfold::~Kfold()
◆ perform_cross_validation()
void gpmp::ml::Kfold::perform_cross_validation(std::function<void(int, int)> train_and_test_func)
Performs k-fold cross-validation using the provided training and testing function.
- Parameters
  - train_and_test_func: The function that trains and tests the model on the given folds. It takes two integers: the index of the training fold and the index of the testing instance.
- Exceptions
  - std::logic_error: If split_indices has not been called before cross-validation is performed.
Definition at line 68 of file kfold.cpp.
71     throw std::logic_error("Indices not split. Call split_indices before "
72                            "performing cross-validation.");
75 for (int i = 0; i < k; ++i) {
78     for (int j = 0; j < k; ++j) {
81         train_and_test_func(j, index);
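The excerpt of perform_cross_validation omits lines 76-77 and 79-80, so it does not show how `index` is obtained. The self-contained sketch below follows one plausible reading, and that reading is an assumption:

- fold `i` is held out for testing;
- each instance `index` in that fold is visited;
- every other fold `j != i` is reported to the callback as a training fold.

`run_cross_validation` is a hypothetical free function, not part of gpmp. It takes the folds as an argument instead of reading the private `fold_indices` member.

```cpp
#include <cassert>
#include <functional>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical free-function version of the loop in perform_cross_validation.
// ASSUMPTION (lines 76-77 and 79-80 of kfold.cpp are not shown): fold i is the
// test fold, each of its instances is visited, and every other fold j != i is
// passed to the callback as a training fold.
void run_cross_validation(const std::vector<std::vector<int>> &fold_indices,
                          const std::function<void(int, int)> &train_and_test_func) {
    // Mirrors the documented std::logic_error when no folds exist yet.
    if (fold_indices.empty()) {
        throw std::logic_error("Indices not split. Call split_indices before "
                               "performing cross-validation.");
    }
    const int k = static_cast<int>(fold_indices.size());
    for (int i = 0; i < k; ++i) {               // i: held-out (test) fold
        for (int index : fold_indices[i]) {     // index: one test instance
            for (int j = 0; j < k; ++j) {
                if (j == i) {
                    continue;                   // never train on the test fold
                }
                train_and_test_func(j, index);  // (training fold, test instance)
            }
        }
    }
}
```

Under this reading, three folds of two instances each produce 3 × 2 × 2 = 12 callback invocations. The first two are (1, 0) and (2, 0).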
◆ split_indices()
void gpmp::ml::Kfold::split_indices(int num_instances)
Splits the indices of instances into k folds for cross-validation.
- Parameters
  - num_instances: The total number of instances.
- Exceptions
  - std::invalid_argument: If k <= 1 or num_instances <= 0.
Definition at line 45 of file kfold.cpp.
46 if (k <= 1 || num_instances <= 0) {
47     throw std::invalid_argument(
48         "Invalid value of k or number of instances.");
56 int fold_size = num_instances / k;
57 int remainder = num_instances % k;
59 for (int i = 0; i < k; ++i) {
60     int fold_length = fold_size + (i < remainder ? 1 : 0);
61     std::vector<int> indices(fold_length);
62     std::iota(indices.begin(), indices.end(), start_index);
64     start_index += fold_length;
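The arithmetic in the split_indices excerpt works as follows:

- every fold receives num_instances / k indices;
- each of the first num_instances % k folds receives one extra index.

The standalone sketch below reproduces that logic as a hypothetical free function, `split_into_folds`, which is not part of gpmp. It returns the folds instead of storing them in the private `fold_indices` member. The excerpt omits lines 49-55, 58 and 63. The sketch assumes that these lines initialise `start_index` to 0 and append each fold; that is an assumption, not something the listing shows.

```cpp
#include <cassert>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical free-function version of split_indices.
// ASSUMPTION: start_index starts at 0 and each fold is appended in order
// (lines 49-55, 58 and 63 of kfold.cpp are not shown).
std::vector<std::vector<int>> split_into_folds(int k, int num_instances) {
    if (k <= 1 || num_instances <= 0) {
        throw std::invalid_argument("Invalid value of k or number of instances.");
    }
    std::vector<std::vector<int>> folds;
    int fold_size = num_instances / k;  // base size of every fold
    int remainder = num_instances % k;  // the first `remainder` folds get one extra
    int start_index = 0;
    for (int i = 0; i < k; ++i) {
        int fold_length = fold_size + (i < remainder ? 1 : 0);
        std::vector<int> indices(fold_length);
        // Fill with consecutive instance indices: start_index, start_index + 1, ...
        std::iota(indices.begin(), indices.end(), start_index);
        folds.push_back(std::move(indices));
        start_index += fold_length;
    }
    assert(start_index == num_instances);  // every instance assigned exactly once
    return folds;
}
```

For 10 instances and k = 3, the folds are {0, 1, 2, 3}, {4, 5, 6} and {7, 8, 9}. In this sketch, k greater than num_instances is accepted and produces empty trailing folds: k = 4 with 2 instances gives fold lengths 1, 1, 0, 0. The omitted lines of kfold.cpp are not visible here, so it is unclear whether they reject that case.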
◆ fold_indices
std::vector<std::vector<int>> gpmp::ml::Kfold::fold_indices [private]
The indices of instances, grouped into k folds.
Definition at line 82 of file kfold.hpp.
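split_indices assigns consecutive indices, so the contents of fold_indices can be written in closed form. The derivation below follows the split_indices arithmetic, where q is fold_size and r is remainder. It assumes that start_index begins at 0, which the excerpt does not show. [i < r] is 1 when i < r and 0 otherwise.

```latex
% n = num_instances, q = fold_size, r = remainder
q = \left\lfloor n / k \right\rfloor, \qquad r = n \bmod k
\ell_i = q + [\, i < r \,], \qquad i = 0, \dots, k - 1
s_i = \sum_{j=0}^{i-1} \ell_j = i\,q + \min(i, r)
\texttt{fold\_indices}[i] = \{\, s_i,\ s_i + 1,\ \dots,\ s_i + \ell_i - 1 \,\}
\sum_{i=0}^{k-1} \ell_i = k\,q + r = n
```

For example, n = 10 and k = 3 give q = 3 and r = 1. The fold lengths are (4, 3, 3) and the start indices are (0, 4, 7).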
The documentation for this class was generated from the following files: