66 void fit(
const std::vector<std::vector<double>> &X_train,
67 const std::vector<std::vector<double>> &y_train,
75 std::vector<double>
predict(
const std::vector<double> &input_vector);
142 const std::vector<std::vector<double>> &y);
150 double compute_loss(
const std::vector<std::vector<double>> &X,
151 const std::vector<std::vector<double>> &y);
160 const std::vector<std::vector<double>> &y,
161 double learning_rate);
Bayesian Neural Network class.
std::vector< double > hidden_biases
Biases for the hidden layer.
void fit(const std::vector< std::vector< double >> &X_train, const std::vector< std::vector< double >> &y_train, int epochs=1000)
Train the Bayesian Neural Network.
void update_weights(const std::vector< std::vector< double >> &X, const std::vector< std::vector< double >> &y, double learning_rate)
Update weights using stochastic gradient descent.
int hidden_size
Number of hidden units in the network.
int input_size
Number of input features.
BNN(int input_size, int hidden_size, int output_size, double prior_variance=1.0)
Constructor for the BNN class.
double log_likelihood(const std::vector< std::vector< double >> &X, const std::vector< std::vector< double >> &y)
Compute the log-likelihood of the data.
double compute_loss(const std::vector< std::vector< double >> &X, const std::vector< std::vector< double >> &y)
Compute the negative log posterior (loss function).
double activation_function(double x)
Activation function for the hidden layer.
std::mt19937 rng
Mersenne Twister random number generator.
double prior_log_likelihood()
Compute the log-likelihood of the prior distribution.
std::vector< std::vector< double > > input_to_hidden_weights
Weights from input to hidden layer.
double prior_variance
Variance for the prior distribution.
std::vector< std::vector< double > > hidden_to_output_weights
Weights from hidden to output layer.
int output_size
Number of output units in the network.
std::vector< double > output_biases
Biases for the output layer.
std::vector< double > predict(const std::vector< double > &input_vector)
Predict the output for a given input.