37 #include <unordered_map>
46 const std::vector<std::vector<double>> &_training_data,
47 const std::vector<int> &_labels) {
48 if (_training_data.size() != _labels.size() || _training_data.empty()) {
49 throw std::invalid_argument(
"Invalid training data or labels.");
52 this->training_data = _training_data;
53 this->labels = _labels;
57 if (training_data.empty() || labels.empty()) {
58 throw std::logic_error(
59 "Model not trained. Call train() before predict.");
62 if (input_vector.size() != training_data[0].size()) {
63 throw std::invalid_argument(
"Invalid input vector size.");
67 if (k <= 0 ||
static_cast<size_t>(k) > training_data.size()) {
68 throw std::invalid_argument(
"Invalid value of k.");
72 std::vector<std::pair<double, int>> distances;
73 for (
size_t i = 0; i < training_data.size(); ++i) {
75 calculateEuclideanDistance(input_vector, training_data[i]);
76 distances.emplace_back(distance, labels[i]);
83 [](
const std::pair<double, int> &a,
const std::pair<double, int> &b) {
84 return a.first < b.first;
88 std::unordered_map<int, int> label_counts;
89 for (
int i = 0; i < k; ++i) {
90 label_counts[distances[i].second]++;
95 int predicted_label = -1;
96 for (
const auto &entry : label_counts) {
97 if (entry.second > max_votes) {
98 max_votes = entry.second;
99 predicted_label = entry.first;
103 return predicted_label;
108 const std::vector<double> &vec2) {
109 double distance = 0.0;
110 for (
size_t i = 0; i < vec1.size(); ++i) {
111 double diff = vec1[i] - vec2[i];
112 distance += diff * diff;
114 return std::sqrt(distance);
int predict(const std::vector< double > &input_vector, int k)
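Predicts the label of a given input vector using the k-nearest-neighbors (KNN) algorithm: the label that occurs most often among the k training samples closest to the input (by Euclidean distance) is returned.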
double calculateEuclideanDistance(const std::vector< double > &vec1, const std::vector< double > &vec2)
Calculates the Euclidean distance between two vectors, which are expected to have the same length.
void train(const std::vector< std::vector< double >> &training_data, const std::vector< int > &labels)
Trains the KNN model by storing the given training data and labels; no further computation happens until predict() is called.
~KNN()
Destructor for the KNN class.
KNN()
Constructor for the KNN class.
K-Nearest Neighbors (KNN) classifier.