Merge pull request #37 from Az-r-ow/bce
BCE fixes and various changes
Az-r-ow authored May 17, 2024
2 parents 05755c5 + c9eff7c commit 09dd99d
Showing 3 changed files with 65 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/NeuralNet/losses/BCE.hpp
@@ -9,14 +9,25 @@ namespace NeuralNet {
 class BCE : public Loss {
  public:
   static double cmpLoss(const Eigen::MatrixXd &o, const Eigen::MatrixXd &y) {
-    Eigen::MatrixXd loss = -(y.array() * o.array().log() +
-                             (1.0 - y.array()) * (1.0 - o.array()).log());
+    double threshold = 1.0e-5;
+    Eigen::MatrixXd oTrim = trim(o, threshold);
+    Eigen::MatrixXd yTrim = trim(y, threshold);
+
+    Eigen::MatrixXd loss =
+        -(yTrim.array() * oTrim.array().log() +
+          (1.0 - yTrim.array()) * (1.0 - oTrim.array()).log());
+
+    if (loss.array().isNaN().any())
+      throw std::runtime_error(
+          "NaN value encountered. Inputs might be too big");
+
     return loss.sum();
   }
 
   static Eigen::MatrixXd cmpLossGrad(const Eigen::MatrixXd &yHat,
                                      const Eigen::MatrixXd &y) {
-    return (yHat.array() - y.array()) / (yHat.array() * (1.0 - y.array()));
+    return (yHat.array() - y.array()) /
+           ((yHat.array() * (1.0 - yHat.array())) + 1e-9);
   }
 };
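
A note on the gradient change: the old denominator yHat * (1 - y) is exactly 0 whenever the label y is 1, and it is not the derivative of the loss for y = 0 either. Differentiating -(y * log(yHat) + (1 - y) * log(1 - yHat)) with respect to yHat gives (yHat - y) / (yHat * (1 - yHat)), which is what the new code computes; the added 1e-9 keeps the division finite when yHat itself sits at 0 or 1. Below is a minimal sketch (not part of this commit) that sanity-checks the new gradient against a central finite difference of cmpLoss; the include path and the NeuralNet:: qualification are assumptions rather than details taken from the diff:

#include <Eigen/Dense>
#include <iostream>
#include <losses/losses.hpp>  // assumed include path, as in the tests below

int main() {
  Eigen::MatrixXd o(1, 1), y(1, 1);
  o << 0.8;
  y << 1.0;

  // Central difference of cmpLoss around o = 0.8.
  const double h = 1e-6;
  Eigen::MatrixXd oPlus = o, oMinus = o;
  oPlus(0, 0) += h;
  oMinus(0, 0) -= h;
  double numeric = (NeuralNet::BCE::cmpLoss(oPlus, y) -
                    NeuralNet::BCE::cmpLoss(oMinus, y)) /
                   (2.0 * h);

  // Analytic gradient from the new formula:
  // (0.8 - 1) / (0.8 * 0.2 + 1e-9), roughly -1.25.
  double analytic = NeuralNet::BCE::cmpLossGrad(o, y)(0, 0);

  std::cout << numeric << " vs " << analytic << "\n";  // both close to -1.25
}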

13 changes: 13 additions & 0 deletions src/NeuralNet/utils/Functions.hpp
@@ -323,6 +323,19 @@ static Eigen::MatrixXd hardmax(const Eigen::MatrixXd &mat) {
   return hardmaxMatrix;
 }
 
+/**
+ * @brief Sets the matrix values that are smaller than the given threshold to 0
+ *
+ * @param logits Matrix of doubles
+ * @param threshold a double (default: 0.01)
+ *
+ * @return the matrix with the values below the threshold set to 0
+ */
+static Eigen::MatrixXd trim(const Eigen::MatrixXd &logits,
+                            double threshold = 0.01) {
+  return (logits.array() < threshold).select(0, logits);
+}
+
 /* SIGNAL HANDLING */
 static void signalHandler(int signum) {
   std::cout << "Interrupt signal (" << signum << ") received.\n";
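
A small illustration (not part of this commit) of what the new trim helper does; the include path and the NeuralNet:: qualification are assumptions:

#include <Eigen/Dense>
#include <iostream>
#include <utils/Functions.hpp>  // assumed include path

int main() {
  Eigen::MatrixXd m(1, 3);
  m << 1e-7, 0.009, 0.5;                   // two entries below the default threshold
  Eigen::MatrixXd t = NeuralNet::trim(m);  // default threshold 0.01
  std::cout << t << "\n";                  // prints "0 0 0.5": small entries zeroed
}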
38 changes: 38 additions & 0 deletions tests/test-losses.cpp
@@ -1,3 +1,4 @@
+#include <activations/activations.hpp>
 #include <catch2/catch_test_macros.hpp>
 #include <catch2/matchers/catch_matchers_floating_point.hpp>
 #include <losses/losses.hpp>
@@ -14,4 +15,41 @@ TEST_CASE("Testing Binary Cross-Entropy loss with random values", "[losses]") {
   double loss = BCE::cmpLoss(o, y);
 
   REQUIRE(loss >= 0);
 }
+
+TEST_CASE("Testing Binary Cross-Entropy derivation with pre-calculated values",
+          "[losses]") {
+  Eigen::MatrixXd o(2, 2);
+  Eigen::MatrixXd y(2, 2);
+
+  o << 0.5, 0.5, 0.2, 0.8;
+  y << 0, 1, 0, 1;
+
+  Eigen::MatrixXd grad = BCE::cmpLossGrad(o, y);
+
+  Eigen::MatrixXd exp(2, 2);
+
+  exp << 2.0, -2.0, 1.25, -1.25;
+
+  CHECK_MATRIX_APPROX(grad, exp, EPSILON);
+}
+
+TEST_CASE("Testing Binary Cross-Entropy with softmax activation", "[losses]") {
+  Eigen::MatrixXd i = Eigen::MatrixXd::Random(2, 2);
+  Eigen::MatrixXd y = Eigen::MatrixXd::Zero(2, 2);
+
+  y(0, 0) = 1;
+  y(1, 1) = 1;
+
+  Eigen::MatrixXd prob = Softmax::activate(i);
+
+  double loss = BCE::cmpLoss(prob, y);
+
+  CHECK(loss >= 0);
+
+  Eigen::MatrixXd grad = BCE::cmpLossGrad(prob, y);
+
+  bool hasNaN = grad.array().isNaN().any();
+
+  CHECK_FALSE(hasNaN);
+}
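
For reference, the expected values in the pre-calculated test follow directly from the new gradient formula (yHat - y) / (yHat * (1 - yHat)), the 1e-9 term being negligible at this scale: (0.5 - 0) / (0.5 * 0.5) = 2.0, (0.5 - 1) / 0.25 = -2.0, (0.2 - 0) / (0.2 * 0.8) = 1.25, and (0.8 - 1) / 0.16 = -1.25. The softmax test covers the complementary concern: softmax outputs can sit arbitrarily close to 0 or 1, which is where yHat * (1 - yHat) approaches 0, so checking the gradient for NaN values exercises the new 1e-9 guard.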
