46 if (k <= 1 || num_instances <= 0) {
47 throw std::invalid_argument(
48 "Invalid value of k or number of instances.");
53 fold_indices.resize(k);
56 int fold_size = num_instances / k;
57 int remainder = num_instances % k;
59 for (
int i = 0; i < k; ++i) {
60 int fold_length = fold_size + (i < remainder ? 1 : 0);
61 std::vector<int> indices(fold_length);
62 std::iota(indices.begin(), indices.end(), start_index);
63 fold_indices[i] = indices;
64 start_index += fold_length;
69 std::function<
void(
int,
int)> train_and_test_func) {
70 if (fold_indices.empty()) {
71 throw std::logic_error(
"Indices not split. Call split_indices before "
72 "performing cross-validation.");
75 for (
int i = 0; i < k; ++i) {
77 std::vector<int> validation_indices = fold_indices[i];
78 for (
int j = 0; j < k; ++j) {
80 for (
int index : fold_indices[j]) {
81 train_and_test_func(j, index);
void split_indices(int num_instances)
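Splits the indices of instances into k folds for cross-validation. Each fold is a contiguous block of indices, and fold sizes differ by at most one: the first `num_instances % k` folds each receive one extra index. Throws `std::invalid_argument` if k <= 1 or num_instances <= 0.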
~Kfold()
Destructor for the Kfold class.
Kfold(int k)
Constructor for the Kfold class.
void perform_cross_validation(std::function< void(int, int)> train_and_test_func)
Performs k-fold cross-validation using the provided training and testing function. Throws `std::logic_error` if split_indices() has not been called first.