openGPMP
Open Source Mathematics Package
gpmp::ml::Kfold Class Reference

Represents a k-fold cross-validation utility.

#include <kfold.hpp>

Public Member Functions

 Kfold (int k)
 Constructor for the Kfold class.
 
 ~Kfold ()
 Destructor for the Kfold class.
 
void split_indices (int num_instances)
 Splits the indices of instances into k folds for cross-validation.
 
void perform_cross_validation (std::function< void(int, int)> train_and_test_func)
 Performs k-fold cross-validation using the provided training and testing function.
 

Private Attributes

int k
 
std::vector< std::vector< int > > fold_indices
 

Detailed Description

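Represents a k-fold cross-validation utility. The instance indices 0 to num_instances - 1 are split into k contiguous folds whose sizes differ by at most one; each fold is then held out in turn while the remaining folds are used for training.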

Definition at line 48 of file kfold.hpp.
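For comparison, the conventional k-fold loop gives each round a pair of index sets: the held-out validation fold and the concatenation of the remaining folds for training. The sketch below builds those pairs from an explicit fold layout. It is a minimal illustration under stated assumptions, not openGPMP API: `TrainValidationSplit` and `make_train_validation_splits` are hypothetical names, and because Kfold keeps fold_indices private, the folds are passed in as an argument.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// One cross-validation round: indices to train on and indices to validate on.
// Hypothetical type for illustration; not part of openGPMP.
struct TrainValidationSplit {
    std::vector<int> train;
    std::vector<int> validation;
};

// Given k folds of instance indices, build k rounds. In round i, fold i is held
// out for validation and all other folds are concatenated (in fold order) for training.
// Hypothetical helper for illustration; not part of openGPMP.
std::vector<TrainValidationSplit>
make_train_validation_splits(const std::vector<std::vector<int>>& folds) {
    std::vector<TrainValidationSplit> rounds;
    rounds.reserve(folds.size());
    for (std::size_t i = 0; i < folds.size(); ++i) {
        TrainValidationSplit current;
        current.validation = folds[i];  // held-out fold
        for (std::size_t j = 0; j < folds.size(); ++j) {
            if (j != i) {
                // Append every instance of training fold j
                current.train.insert(current.train.end(), folds[j].begin(), folds[j].end());
            }
        }
        rounds.push_back(std::move(current));
    }
    return rounds;
}
```

For the layout that split_indices builds for 10 instances and k = 3 ({0,1,2,3}, {4,5,6}, {7,8,9}), the first round validates on {0,1,2,3} and trains on {4,5,6,7,8,9}.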

Constructor & Destructor Documentation

◆ Kfold()

gpmp::ml::Kfold::Kfold ( int  k)

Constructor for the Kfold class.

Parameters
k  The number of folds for cross-validation (split_indices requires k > 1)

Definition at line 39 of file kfold.cpp.

    : k(fold) {
}

◆ ~Kfold()

gpmp::ml::Kfold::~Kfold ( )

Destructor for the Kfold class.

Definition at line 42 of file kfold.cpp.

{
}

Member Function Documentation

◆ perform_cross_validation()

void gpmp::ml::Kfold::perform_cross_validation ( std::function< void(int, int)>  train_and_test_func)

Performs k-fold cross-validation using the provided training and testing function.

Parameters
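train_and_test_func  The function that trains and tests the model. It takes two integers: the index j of a training fold and the index of an instance in that fold. In each of the k rounds, with fold i held out, it is called once for every instance of every fold j != i, so each instance index is passed k - 1 times in total. The held-out fold index i is not passed to the function.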
Exceptions
std::logic_error  If split_indices has not been called before performing cross-validation

Definition at line 68 of file kfold.cpp.

{
    if (fold_indices.empty()) {
        throw std::logic_error("Indices not split. Call split_indices before "
                               "performing cross-validation.");
    }

    for (int i = 0; i < k; ++i) {
        // Use fold i as the validation set, train on all other folds
        // (validation_indices is copied here but not used below)
        std::vector<int> validation_indices = fold_indices[i];
        for (int j = 0; j < k; ++j) {
            if (j != i) {
                // Pass (training fold, instance index) for every instance in fold j
                for (int index : fold_indices[j]) {
                    train_and_test_func(j, index);
                }
            }
        }
    }
}
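To see exactly which arguments the callback receives, the nested loops of perform_cross_validation can be re-created standalone. The sketch below is an assumption-labeled illustration, not openGPMP code: `record_callback_calls` is a hypothetical helper that runs the same loops over an explicit fold layout and records each (fold, instance) pair instead of training a model.

```cpp
#include <functional>
#include <utility>
#include <vector>

// Standalone re-creation of the loops in Kfold::perform_cross_validation.
// Hypothetical helper for illustration; it records each call instead of training.
std::vector<std::pair<int, int>>
record_callback_calls(const std::vector<std::vector<int>>& fold_indices) {
    std::vector<std::pair<int, int>> calls;
    std::function<void(int, int)> train_and_test_func = [&calls](int fold, int index) {
        calls.emplace_back(fold, index);  // remember (training fold, instance index)
    };
    const int k = static_cast<int>(fold_indices.size());
    for (int i = 0; i < k; ++i) {      // fold i is held out in this round
        for (int j = 0; j < k; ++j) {
            if (j != i) {              // every other fold is a training fold
                for (int index : fold_indices[j]) {
                    train_and_test_func(j, index);
                }
            }
        }
    }
    return calls;
}
```

For folds {0,1}, {2,3}, {4}, the callback is invoked 10 times, that is (k - 1) times the 5 instances; the first calls are (1, 2), (1, 3), (2, 4). Since the held-out fold number never reaches the callback, the callback alone cannot tell which fold is being validated in the current round.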

◆ split_indices()

void gpmp::ml::Kfold::split_indices ( int  num_instances)

Splits the indices of instances into k folds for cross-validation.

Parameters
num_instances  The total number of instances
Exceptions
std::invalid_argument  If k <= 1 or num_instances <= 0

Definition at line 45 of file kfold.cpp.

{
    if (k <= 1 || num_instances <= 0) {
        throw std::invalid_argument(
            "Invalid value of k or number of instances.");
    }

    // Initialize fold_indices with k empty vectors
    fold_indices.clear();
    fold_indices.resize(k);

    // Assign indices to each fold
    int fold_size = num_instances / k;
    int remainder = num_instances % k;
    int start_index = 0;
    for (int i = 0; i < k; ++i) {
        int fold_length = fold_size + (i < remainder ? 1 : 0);
        std::vector<int> indices(fold_length);
        std::iota(indices.begin(), indices.end(), start_index);
        fold_indices[i] = indices;
        start_index += fold_length;
    }
}
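The splitting rule can be checked in isolation. The sketch below re-creates the logic of split_indices as a free function; `split_into_folds` is a hypothetical name used only for illustration and is not part of openGPMP. Each fold gets num_instances / k consecutive indices, and the first num_instances % k folds receive one extra index.

```cpp
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical free-function re-creation of Kfold::split_indices, for illustration only.
std::vector<std::vector<int>> split_into_folds(int k, int num_instances) {
    if (k <= 1 || num_instances <= 0) {
        throw std::invalid_argument("Invalid value of k or number of instances.");
    }
    std::vector<std::vector<int>> folds(k);
    const int fold_size = num_instances / k;  // base size of every fold
    const int remainder = num_instances % k;  // first `remainder` folds get one extra index
    int start_index = 0;
    for (int i = 0; i < k; ++i) {
        const int fold_length = fold_size + (i < remainder ? 1 : 0);
        folds[i].resize(fold_length);
        // Fill with consecutive indices start_index, start_index + 1, ...
        std::iota(folds[i].begin(), folds[i].end(), start_index);
        start_index += fold_length;
    }
    return folds;
}
```

With 10 instances and k = 3, fold_size is 3 and remainder is 1, so the folds are {0,1,2,3}, {4,5,6} and {7,8,9}. The validation only rejects k <= 1 and num_instances <= 0, so when num_instances is smaller than k the trailing folds are left empty (for example, k = 4 with 2 instances gives {0}, {1}, {}, {}).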

Member Data Documentation

◆ fold_indices

std::vector<std::vector<int> > gpmp::ml::Kfold::fold_indices
private

The indices of instances grouped into k folds.

Definition at line 84 of file kfold.hpp.

◆ k

int gpmp::ml::Kfold::k
private

The number of folds for cross-validation.

Definition at line 82 of file kfold.hpp.


The documentation for this class was generated from the following files: