Line data Source code
1 : /************************************************************************* 2 : * 3 : * Project 4 : * _____ _____ __ __ _____ 5 : * / ____| __ \| \/ | __ \ 6 : * ___ _ __ ___ _ __ | | __| |__) | \ / | |__) | 7 : * / _ \| '_ \ / _ \ '_ \| | |_ | ___/| |\/| | ___/ 8 : *| (_) | |_) | __/ | | | |__| | | | | | | | 9 : * \___/| .__/ \___|_| |_|\_____|_| |_| |_|_| 10 : * | | 11 : * |_| 12 : * 13 : * Copyright (C) Akiel Aries, <akiel@akiel.org>, et al. 14 : * 15 : * This software is licensed as described in the file LICENSE, which 16 : * you should have received as part of this distribution. The terms 17 : * among other details are referenced in the official documentation 18 : * seen here : https://akielaries.github.io/openGPMP/ along with 19 : * important files seen in this project. 20 : * 21 : * You may opt to use, copy, modify, merge, publish, distribute 22 : * and/or sell copies of the Software, and permit persons to whom 23 : * the Software is furnished to do so, under the terms of the 24 : * LICENSE file. As this is an Open Source effort, all implementations 25 : * must be of the same methodology. 26 : * 27 : * 28 : * 29 : * This software is distributed on an AS IS basis, WITHOUT 30 : * WARRANTY OF ANY KIND, either express or implied. 31 : * 32 : ************************************************************************/ 33 : #include <algorithm> 34 : #include <functional> 35 : #include <numeric> 36 : #include <openGPMP/ml/kfold.hpp> 37 : #include <stdexcept> 38 : 39 0 : gpmp::ml::Kfold::Kfold(int fold) : k(fold) { 40 0 : } 41 : 42 0 : gpmp::ml::Kfold::~Kfold() { 43 0 : } 44 : 45 0 : void gpmp::ml::Kfold::split_indices(int num_instances) { 46 0 : if (k <= 1 || num_instances <= 0) { 47 0 : throw std::invalid_argument( 48 0 : "Invalid value of k or number of instances."); 49 : } 50 : 51 : // Initialize fold_indices with k empty vectors 52 0 : fold_indices.clear(); 53 0 : fold_indices.resize(k); 54 : 55 : // Assign indices to each fold 56 0 : int fold_size = num_instances / k; 57 0 : int remainder = num_instances % k; 58 0 : int start_index = 0; 59 0 : for (int i = 0; i < k; ++i) { 60 0 : int fold_length = fold_size + (i < remainder ? 1 : 0); 61 0 : std::vector<int> indices(fold_length); 62 0 : std::iota(indices.begin(), indices.end(), start_index); 63 0 : fold_indices[i] = indices; 64 0 : start_index += fold_length; 65 0 : } 66 0 : } 67 : 68 0 : void gpmp::ml::Kfold::perform_cross_validation( 69 : std::function<void(int, int)> train_and_test_func) { 70 0 : if (fold_indices.empty()) { 71 0 : throw std::logic_error("Indices not split. Call split_indices before " 72 0 : "performing cross-validation."); 73 : } 74 : 75 0 : for (int i = 0; i < k; ++i) { 76 : // Use fold i as the validation set, train on all other folds 77 0 : std::vector<int> validation_indices = fold_indices[i]; 78 0 : for (int j = 0; j < k; ++j) { 79 0 : if (j != i) { 80 0 : for (int index : fold_indices[j]) { 81 0 : train_and_test_func(j, index); 82 : } 83 : } 84 : } 85 0 : } 86 0 : }