Skip to content

Commit

Permalink
Merge pull request #29 from Az-r-ow/callbacks
Browse files Browse the repository at this point in the history
fix: wrapping up callbacks on early exit training
  • Loading branch information
Az-r-ow authored Mar 20, 2024
2 parents 6a64ef4 + f8217cf commit 8c05817
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/NeuralNet/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ double Network::train(std::vector<std::vector<double>> inputs,
try {
return onlineTraining(inputs, labels, epochs, callbacks);
} catch (const std::exception &e) {
trainingCheckpoint("onTrainEnd", callbacks);
std::cerr << "Training Interrupted : " << e.what() << '\n';
return loss;
}
Expand All @@ -68,6 +69,7 @@ double Network::train(std::vector<std::vector<std::vector<double>>> inputs,
try {
return onlineTraining(inputs, labels, epochs, callbacks);
} catch (const std::exception &e) {
trainingCheckpoint("onTrainEnd", callbacks); // wrap up callbacks
std::cerr << "Training Interrupted : " << e.what() << '\n';
return loss;
}
Expand All @@ -81,6 +83,7 @@ double Network::train(
try {
return this->trainer(trainingData, epochs, callbacks);
} catch (const std::exception &e) {
trainingCheckpoint("onTrainEnd", callbacks);
std::cerr << "Training Interrupted : " << e.what() << '\n';
return loss;
}
Expand All @@ -94,6 +97,7 @@ double Network::train(
try {
return this->trainer(trainingData, epochs, callbacks);
} catch (const std::exception &e) {
trainingCheckpoint("onTrainEnd", callbacks);
std::cerr << "Training Interrupted : " << e.what() << '\n';
return loss;
}
Expand Down
3 changes: 2 additions & 1 deletion src/NeuralNet/Network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ class Network : public Model {
* @param epochs An integer specifying the number of times the training
* algorithm should iterate over the dataset.
* @param callbacks A vector of `Callback` that will be called during training
* stages * @return A double value that represents the average loss of the
* stages
* @return A double value that represents the average loss of the
* training process. This can be used to gauge the effectiveness of the
* process.
*
Expand Down

0 comments on commit 8c05817

Please sign in to comment.