Skip to content

Commit

Permalink
Added loss functions
Browse files Browse the repository at this point in the history
  • Loading branch information
spirosmaggioros committed Sep 3, 2024
1 parent bbbbef4 commit 6d9a755
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 0 deletions.
96 changes: 96 additions & 0 deletions src/machine_learning/loss_functions/losses.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#ifndef MEAN_SQUARED_ERROR_H
#define MEAN_SQUARED_ERROR_H

#ifdef __cplusplus
#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>
#include <cassert>
#include <numbers>
#endif

/**
* @brief losses namespace that contains a couple of useful losses in machine learning
*
*/
namespace losses {

/**
 * @brief sigmoid activation function: maps any real input into (0, 1)
 * @param x: double, the input parameter (a raw logit)
 * @return double: 1 / (1 + e^{-x})
 */
// `inline` is required: this is a free function defined in a header, and
// without it every TU that includes losses.h emits its own definition (ODR violation).
inline double sigmoid(double x) {
    return 1.0 / (1.0 + std::exp(-x));
}

/**
 * @brief mean squared error function
 * @param y: vector, the original labels
 * @param y_hat: vector, the predicted labels (same length as y)
 * @return double: (1/n) * sum_i (y_i - y_hat_i)^2
 */
inline double mean_squared_error(std::vector<double> const& y, std::vector<double> const& y_hat) {
    assert(y.size() == y_hat.size());
    size_t n = y.size();
    double mse = 0.0;
    for (size_t i = 0; i < n; i++) {
        // std::powf operates on float and would silently truncate the double
        // operands; squaring directly keeps full double precision and is faster
        // than a generic pow() call for an exponent of 2.
        const double diff = y[i] - y_hat[i];
        mse += diff * diff;
    }
    return mse / double(n);
}

/**
 * @brief root mean squared error function
 * @param y: vector, the original labels
 * @param y_hat: vector, the predicted labels (same length as y)
 * @return double: sqrt(mean_squared_error(y, y_hat))
 */
// `inline` prevents multiple-definition (ODR) errors when this header is
// included from more than one translation unit.
inline double root_mean_squared_error(std::vector<double> const& y, std::vector<double> const& y_hat) {
    return std::sqrt(mean_squared_error(y, y_hat));
}

/**
 * @brief mean absolute error function
 * @param y: vector, the original labels
 * @param y_hat: vector, the predicted labels (same length as y)
 * @return double: (1/n) * sum_i |y_i - y_hat_i|
 */
// `inline` prevents multiple-definition (ODR) errors when this header is
// included from more than one translation unit.
inline double mean_absolute_error(std::vector<double> const& y, std::vector<double> const& y_hat) {
    assert(y.size() == y_hat.size());
    size_t n = y.size();
    double mae = 0.0;
    for (size_t i = 0; i < n; i++) {
        mae += std::abs(y[i] - y_hat[i]);
    }
    return mae / double(n);
}

/**
 * @brief binary crossentropy loss function
 * @param y: vector, the original labels; every element must be exactly 0.0 or 1.0
 * @param y_hat: vector, the predicted values as RAW LOGITS — sigmoid is applied
 *               internally before the log terms are evaluated
 * @return double: -(1/n) * sum_i [ y_i*log(p_i) + (1-y_i)*log(1-p_i) ],
 *                 where p_i = sigmoid(y_hat_i)
 */
// `inline` prevents multiple-definition (ODR) errors when this header is
// included from more than one translation unit.
inline double binary_crossentropy_loss(std::vector<double> const& y, std::vector<double> const& y_hat) {
    assert(y.size() == y_hat.size());
    // labels must be binary, otherwise the loss is not well defined
    for (const double label : y) {
        assert(label == 0.0 || label == 1.0);
    }

    size_t n = y.size();
    double bce = 0.0, eps = 1e-15;
    for (size_t i = 0; i < n; i++) {
        double prob = sigmoid(y_hat[i]);
        // clamp away from {0, 1} so neither log() argument can reach 0 (-inf)
        double clipped_y_hat = std::clamp(prob, eps, 1 - eps);
        bce += (y[i] * std::log(clipped_y_hat) + (1 - y[i]) * std::log(1 - clipped_y_hat));
    }
    return -bce / double(n);
}

}


#endif

29 changes: 29 additions & 0 deletions tests/machine_learning/loss_functions/losses.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#define CATCH_CONFIG_MAIN
#include "../../../src/machine_learning/loss_functions/losses.h"
#include "../../../third_party/catch.hpp"

using namespace losses;

TEST_CASE("Testing mean squared error") {

    // reference value 9.6358263926 is the hand-computed mean of the
    // squared residuals over these five samples
    std::vector<double> labels { 1.23, 4.25, 4.4, 1.231, 5.567 };
    std::vector<double> predictions { 4.56, 4.123, 1.234, 6.432, 5.555 };

    REQUIRE(mean_squared_error(labels, predictions) == Approx(9.6358263926).epsilon(1e-6));
}

TEST_CASE("Testing mean absolute error") {
    // reference value 2.36720 is the hand-computed mean of the
    // absolute residuals over these five samples
    std::vector<double> labels { 1.23, 4.25, 4.4, 1.231, 5.567 };
    std::vector<double> predictions { 4.56, 4.123, 1.234, 6.432, 5.555 };

    REQUIRE(mean_absolute_error(labels, predictions) == Approx(2.36720).epsilon(1e-6));
}

TEST_CASE("Testing binary crossentropy loss") {
    // binary ground-truth labels and raw logits (sigmoid is applied
    // inside the loss); reference value computed externally
    std::vector<double> labels { 0., 0., 0., 0., 1., 0., 0., 0., 0., 1. };
    std::vector<double> logits { 1.8055, 0.9193, -0.2527, 1.0489, 0.5396, -1.2046, -0.9479, 0.8274,
                                 -0.0548, -0.1902 };

    REQUIRE(binary_crossentropy_loss(labels, logits) == Approx(0.8834657559).epsilon(1e-6));
}

0 comments on commit 6d9a755

Please sign in to comment.