openGPMP
Open Source Mathematics Package
kfold.cpp
Go to the documentation of this file.
1 /*************************************************************************
2  *
3  * Project
4  * _____ _____ __ __ _____
5  * / ____| __ \| \/ | __ \
6  * ___ _ __ ___ _ __ | | __| |__) | \ / | |__) |
7  * / _ \| '_ \ / _ \ '_ \| | |_ | ___/| |\/| | ___/
8  *| (_) | |_) | __/ | | | |__| | | | | | | |
9  * \___/| .__/ \___|_| |_|\_____|_| |_| |_|_|
10  * | |
11  * |_|
12  *
13  * Copyright (C) Akiel Aries, <akiel@akiel.org>, et al.
14  *
15  * This software is licensed as described in the file LICENSE, which
16  * you should have received as part of this distribution. The terms
17  * among other details are referenced in the official documentation
18  * seen here : https://akielaries.github.io/openGPMP/ along with
19  * important files seen in this project.
20  *
21  * You may opt to use, copy, modify, merge, publish, distribute
22  * and/or sell copies of the Software, and permit persons to whom
23  * the Software is furnished to do so, under the terms of the
24  * LICENSE file. As this is an Open Source effort, all implementations
25  * must be of the same methodology.
26  *
27  *
28  *
29  * This software is distributed on an AS IS basis, WITHOUT
30  * WARRANTY OF ANY KIND, either express or implied.
31  *
32  ************************************************************************/
33 #include <algorithm>
34 #include <functional>
35 #include <numeric>
36 #include <openGPMP/ml/kfold.hpp>
37 #include <stdexcept>
38 
39 gpmp::ml::Kfold::Kfold(int fold) : k(fold) {
40 }
41 
43 }
44 
45 void gpmp::ml::Kfold::split_indices(int num_instances) {
46  if (k <= 1 || num_instances <= 0) {
47  throw std::invalid_argument(
48  "Invalid value of k or number of instances.");
49  }
50 
51  // Initialize fold_indices with k empty vectors
52  fold_indices.clear();
53  fold_indices.resize(k);
54 
55  // Assign indices to each fold
56  int fold_size = num_instances / k;
57  int remainder = num_instances % k;
58  int start_index = 0;
59  for (int i = 0; i < k; ++i) {
60  int fold_length = fold_size + (i < remainder ? 1 : 0);
61  std::vector<int> indices(fold_length);
62  std::iota(indices.begin(), indices.end(), start_index);
63  fold_indices[i] = indices;
64  start_index += fold_length;
65  }
66 }
67 
69  std::function<void(int, int)> train_and_test_func) {
70  if (fold_indices.empty()) {
71  throw std::logic_error("Indices not split. Call split_indices before "
72  "performing cross-validation.");
73  }
74 
75  for (int i = 0; i < k; ++i) {
76  // Use fold i as the validation set, train on all other folds
77  std::vector<int> validation_indices = fold_indices[i];
78  for (int j = 0; j < k; ++j) {
79  if (j != i) {
80  for (int index : fold_indices[j]) {
81  train_and_test_func(j, index);
82  }
83  }
84  }
85  }
86 }
void split_indices(int num_instances)
Splits the indices of instances into k folds for cross-validation.
Definition: kfold.cpp:45
~Kfold()
Destructor for the Kfold class.
Definition: kfold.cpp:42
Kfold(int k)
Constructor for the Kfold class.
Definition: kfold.cpp:39
void perform_cross_validation(std::function< void(int, int)> train_and_test_func)
Performs k-fold cross-validation using the provided training and testing function.
Definition: kfold.cpp:68