-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bbbbef4
commit 6d9a755
Showing
2 changed files
with
125 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#ifndef MEAN_SQUARED_ERROR_H | ||
#define MEAN_SQUARED_ERROR_H | ||
|
||
#ifdef __cplusplus | ||
#include <iostream> | ||
#include <algorithm> | ||
#include <vector> | ||
#include <cmath> | ||
#include <cassert> | ||
#include <numbers> | ||
#endif | ||
|
||
/** | ||
* @brief losses namespace that contains a couple of useful losses in machine learning | ||
* | ||
*/ | ||
namespace losses { | ||
|
||
/** | ||
* @brief sigmoid activation function | ||
* @param x: double, the input parameter | ||
* @return double: the sigmoid output | ||
*/ | ||
double sigmoid(double x) { | ||
return 1.0 / (1.0 + exp(-x)); | ||
} | ||
|
||
/** | ||
* @brief mean squared error function | ||
* @param y: vector, the original labels | ||
* @param y_hat: vector, the predicted labels | ||
* @return double: the mean squared error | ||
*/ | ||
double mean_squared_error(std::vector<double> const& y, std::vector<double> const& y_hat) { | ||
assert(y.size() == y_hat.size()); | ||
size_t n = y.size(); | ||
double mse = 0.0; | ||
for(size_t i = 0; i<n; i++) { | ||
mse += std::powf(y[i] - y_hat[i], 2); | ||
} | ||
return mse / double(n); | ||
} | ||
|
||
/** | ||
* @brief root mean squared error function | ||
* @param y: vector, the original labels | ||
* @param y_hat: vector, the predicted labels | ||
* @return double: the root mean squared error | ||
*/ | ||
double root_mean_squared_error(std::vector<double> const& y, std::vector<double> const& y_hat) { | ||
return std::sqrt(mean_squared_error(y, y_hat)); | ||
} | ||
|
||
/** | ||
* @brief mean absolute error function | ||
* @param y: vector, the original labels | ||
* @param y_hat: vector, the predicted labels | ||
* @return double: the mean absolute error | ||
*/ | ||
double mean_absolute_error(std::vector<double> const& y, std::vector<double> const& y_hat) { | ||
assert(y.size() == y_hat.size()); | ||
size_t n = y.size(); | ||
double mae = 0.0; | ||
for(size_t i = 0; i<n; i++) { | ||
mae += std::abs(y[i] - y_hat[i]); | ||
} | ||
return mae / double(n); | ||
} | ||
|
||
/** | ||
* @brief binary crossentropy loss function | ||
* @param y: vector, the original labels | ||
* @param y_hat: vector, the predicted labels | ||
* @return double: the binary crossentropy loss | ||
*/ | ||
double binary_crossentropy_loss(std::vector<double> const& y, std::vector<double> const& y_hat) { | ||
assert(y.size() == y_hat.size()); | ||
for(auto & x : y) { | ||
assert(x == 0.0 || x == 1.0); | ||
} | ||
|
||
size_t n = y.size(); | ||
double bce = 0.0, eps = 1e-15; | ||
for(size_t i = 0; i<n; i++) { | ||
double prob = sigmoid(y_hat[i]); | ||
double clipped_y_hat = std::clamp(prob, eps, 1 - eps); | ||
bce += (y[i]*log(clipped_y_hat) + (1-y[i])*log(1 - clipped_y_hat)); | ||
} | ||
return -bce / double(n); | ||
} | ||
|
||
} | ||
|
||
|
||
#endif | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#define CATCH_CONFIG_MAIN | ||
#include "../../../src/machine_learning/loss_functions/losses.h" | ||
#include "../../../third_party/catch.hpp" | ||
|
||
using namespace losses; | ||
|
||
TEST_CASE("Testing mean squared error") { | ||
|
||
std::vector<double> v1 { 1.23, 4.25, 4.4, 1.231, 5.567 }; | ||
std::vector<double> v2 { 4.56, 4.123, 1.234, 6.432, 5.555 }; | ||
|
||
REQUIRE(mean_squared_error(v1, v2) == Approx(9.6358263926).epsilon(1e-6)); | ||
} | ||
|
||
TEST_CASE("Testing mean absolute error") { | ||
std::vector<double> v1 { 1.23, 4.25, 4.4, 1.231, 5.567 }; | ||
std::vector<double> v2 { 4.56, 4.123, 1.234, 6.432, 5.555 }; | ||
|
||
REQUIRE(mean_absolute_error(v1, v2) == Approx(2.36720).epsilon(1e-6)); | ||
} | ||
|
||
TEST_CASE("Testing binary crossentropy loss") { | ||
std::vector<double> v1 { 0., 0., 0., 0., 1., 0., 0., 0., 0., 1. }; | ||
std::vector<double> v2 { 1.8055, 0.9193, -0.2527, 1.0489, 0.5396, -1.2046, -0.9479, 0.8274, | ||
-0.0548, -0.1902 }; | ||
|
||
REQUIRE(binary_crossentropy_loss(v1, v2) == Approx(0.8834657559).epsilon(1e-6)); | ||
} | ||
|