Merge pull request #36 from Az-r-ow/bce
BCE
Az-r-ow authored May 14, 2024
2 parents 2212d31 + 5bf80d9 commit 47c3bdb
Showing 8 changed files with 37 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/NeuralNet/Network.cpp
@@ -35,6 -35,9 @@ void Network::setLoss(LOSS loss) {
this->cmpLoss = MCE::cmpLoss;
this->cmpLossGrad = MCE::cmpLossGrad;
break;
case LOSS::BCE:
this->cmpLoss = BCE::cmpLoss;
this->cmpLossGrad = BCE::cmpLossGrad;
break;  // without this break, the BCE case falls through to the assert in default
default:
assert(false && "Loss not defined");
break;
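For context, setLoss appears to install plain function pointers to the chosen loss's static methods rather than a polymorphic loss object. A minimal standalone sketch of that dispatch pattern, assuming only Eigen (DemoBCE and main are illustrative stand-ins, not repo code):

    #include <Eigen/Dense>
    #include <iostream>

    // Stand-in with the same static-method shape as the BCE class in this diff.
    struct DemoBCE {
      static double cmpLoss(const Eigen::MatrixXd &o, const Eigen::MatrixXd &y) {
        Eigen::ArrayXXd term = y.array() * o.array().log() +
                               (1.0 - y.array()) * (1.0 - o.array()).log();
        return -term.sum();
      }
    };

    int main() {
      // The same idea as `this->cmpLoss = BCE::cmpLoss;` in the switch:
      // store a pointer to the selected loss function.
      double (*cmpLoss)(const Eigen::MatrixXd &, const Eigen::MatrixXd &) =
          DemoBCE::cmpLoss;

      Eigen::MatrixXd o(2, 1), y(2, 1);
      o << 0.9, 0.2;  // predictions
      y << 1.0, 0.0;  // labels
      std::cout << cmpLoss(o, y) << "\n";  // -(log 0.9 + log 0.8) ≈ 0.3285
    }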
23 changes: 23 additions & 0 deletions src/NeuralNet/losses/BCE.hpp
@@ -0,0 +1,23 @@
#pragma once

#include "Loss.hpp"

namespace NeuralNet {
/**
* Binary Cross-Entropy
*/
class BCE : public Loss {
public:
static double cmpLoss(const Eigen::MatrixXd &o, const Eigen::MatrixXd &y) {
// BCE: -(y * log(o) + (1 - y) * log(1 - o)). The second term must be
// weighted by (1 - y) and take log(1 - o), not log(1 - y).
Eigen::MatrixXd loss =
-(y.array() * o.array().log() +
(1.0 - y.array()) * (1.0 - o.array()).log());
return loss.sum();
}

static Eigen::MatrixXd cmpLossGrad(const Eigen::MatrixXd &yHat,
const Eigen::MatrixXd &y) {
// dL/dyHat = (yHat - y) / (yHat * (1 - yHat)); dividing by (1 - y)
// instead would blow up whenever y = 1.
return (yHat.array() - y.array()) / (yHat.array() * (1.0 - yHat.array()));
}
};

} // namespace NeuralNet
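For reference, the gradient form follows directly from the loss: with L = -(y * log(yHat) + (1 - y) * log(1 - yHat)), differentiating w.r.t. yHat gives dL/dyHat = -(y/yHat - (1 - y)/(1 - yHat)) = (yHat - y) / (yHat * (1 - yHat)); note the denominator uses the predictions yHat, not the labels y. Below is a small standalone finite-difference check of the corrected gradient against the corrected loss; it assumes only Eigen and is not part of the commit:

    #include <Eigen/Dense>
    #include <iostream>

    int main() {
      Eigen::MatrixXd yHat(2, 1), y(2, 1);
      yHat << 0.9, 0.2;
      y << 1.0, 0.0;

      // Analytic gradient: (yHat - y) / (yHat * (1 - yHat)).
      Eigen::MatrixXd grad =
          (yHat.array() - y.array()) / (yHat.array() * (1.0 - yHat.array()));

      // Same formula as the corrected cmpLoss above.
      auto loss = [](const Eigen::MatrixXd &o, const Eigen::MatrixXd &t) {
        return -(t.array() * o.array().log() +
                 (1.0 - t.array()) * (1.0 - o.array()).log())
                    .sum();
      };

      // Numeric gradient via central differences; both columns should match
      // (-1.1111 for the first sample, 1.25 for the second).
      const double eps = 1e-6;
      for (Eigen::Index i = 0; i < yHat.rows(); ++i) {
        Eigen::MatrixXd up = yHat, dn = yHat;
        up(i, 0) += eps;
        dn(i, 0) -= eps;
        double numeric = (loss(up, y) - loss(dn, y)) / (2 * eps);
        std::cout << "analytic " << grad(i, 0) << "  numeric " << numeric << "\n";
      }
    }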
4 changes: 2 additions & 2 deletions src/NeuralNet/losses/Loss.hpp
@@ -18,12 +18,12 @@ class Loss {
/**
* @brief This function computes the loss gradient w.r.t. the outputs
*
* @param o The outputs from the output layer
* @param yHat The outputs from the output layer
* @param y The labels (expected vals)
*
* @return The current iteration's gradient
*/
static Eigen::MatrixXd cmpLossGrad(const Eigen::MatrixXd &o,
static Eigen::MatrixXd cmpLossGrad(const Eigen::MatrixXd &yHat,
const Eigen::MatrixXd &y);
};
} // namespace NeuralNet
3 changes: 2 additions & 1 deletion src/NeuralNet/losses/losses.hpp
@@ -5,5 +5,6 @@

#pragma once

#include "MCE.hpp"
#include "BCE.hpp" // Binary Cross-Entropy
#include "MCE.hpp" // Multiclass Cross-Entropy
#include "Quadratic.hpp"
3 changes: 2 additions & 1 deletion src/NeuralNet/utils/Enums.hpp
@@ -17,6 +17,7 @@ enum class WEIGHT_INIT {

enum class LOSS {
MCE, // Multi-class Cross Entropy
QUADRATIC
QUADRATIC,
BCE // Binary Cross-Entropy
};
} // namespace NeuralNet
3 changes: 2 additions & 1 deletion src/bindings/NeuralNetPy.cpp
@@ -67,7 +67,8 @@ PYBIND11_MODULE(NeuralNetPy, m) {

py::enum_<LOSS>(m, "LOSS")
.value("QUADRATIC", LOSS::QUADRATIC)
.value("MCE", LOSS::MCE);
.value("MCE", LOSS::MCE)
.value("BCE", LOSS::BCE);

py::module optimizers_m = m.def_submodule("optimizers", R"pbdoc(
Optimizers
3 changes: 2 additions & 1 deletion tests/test-callbacks.cpp
@@ -2,6 +2,7 @@
#include <callbacks/Callback.hpp>
#include <callbacks/EarlyStopping.hpp>
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers.hpp>
#include <utils/Variants.hpp>
#include <vector>

@@ -11,7 +12,7 @@ TEST_CASE(
"EarlyStopping callback throws exception when the metric is not found",
"[callback]") {
std::shared_ptr<Callback> earlyStopping =
std::make_shared<EarlyStopping>("LOSS", 0.1);
std::make_shared<EarlyStopping>("NOT_A_METRIC", 0.1);

Network network;

2 changes: 1 addition & 1 deletion tests/test-network.cpp
@@ -214,7 +214,7 @@ SCENARIO("The network updates the weights and biases as pre-calculated") {

Network checkpoint;

Model::load_from_file("checkpoint-0.bin", checkpoint);
Model::load_from_file("N9NeuralNet7NetworkE-checkpoint-0.bin", checkpoint);

REQUIRE(checkpoint.getNumLayers() == network.getNumLayers());

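A note on the new checkpoint filename: the prefix N9NeuralNet7NetworkE looks like the Itanium-mangled name of NeuralNet::Network, which is what typeid(...).name() returns under GCC/Clang, so the checkpoint callback presumably builds filenames from it. A quick illustrative check (not repo code):

    #include <iostream>
    #include <typeinfo>

    namespace NeuralNet {
    class Network {};
    }  // namespace NeuralNet

    int main() {
      // On GCC/Clang this prints "N9NeuralNet7NetworkE": N...E wraps a nested
      // name, and each segment is prefixed with its length (9 for "NeuralNet",
      // 7 for "Network"). MSVC would print an unmangled name instead.
      std::cout << typeid(NeuralNet::Network).name() << "\n";
    }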
