diff --git a/cellbender/remove_background/train.py b/cellbender/remove_background/train.py index 8d77808..bf7bed6 100644 --- a/cellbender/remove_background/train.py +++ b/cellbender/remove_background/train.py @@ -152,7 +152,6 @@ def run_training(model: RemoveBackgroundPyroModel, # Initialize train and tests ELBO with empty lists. train_elbo = [] - test_elbo = [] lr = [] epoch_checkpoint_freq = 1000 # a large number... it will be recalculated @@ -212,16 +211,15 @@ def run_training(model: RemoveBackgroundPyroModel, if epoch % test_freq == 0: model.eval() total_epoch_loss_test = evaluate_epoch(svi, test_loader) - test_elbo.append(-total_epoch_loss_test) model.loss['test']['epoch'].append(epoch) model.loss['test']['elbo'].append(-total_epoch_loss_test) logger.info("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test)) # Check whether test ELBO has spiked beyond specified conditions. - if (epoch_elbo_fail_fraction is not None) and (len(test_elbo) > 2): - current_diff = max(0., test_elbo[-2] - test_elbo[-1]) - overall_diff = np.abs(test_elbo[-2] - test_elbo[0]) + if (epoch_elbo_fail_fraction is not None) and (len(model.loss['test']['elbo']) > 2): + current_diff = max(0., model.loss['test']['elbo'][-2] - model.loss['test']['elbo'][-1]) + overall_diff = np.abs(model.loss['test']['elbo'][-2] - model.loss['test']['elbo'][0]) fractional_spike = current_diff / overall_diff if fractional_spike > epoch_elbo_fail_fraction: raise ElboException( @@ -245,15 +243,15 @@ def run_training(model: RemoveBackgroundPyroModel, # Check on the final test ELBO to see if it meets criteria. if final_elbo_fail_fraction is not None: - best_test_elbo = max(test_elbo) - if test_elbo[-1] < best_test_elbo: - final_best_diff = best_test_elbo - test_elbo[-1] - initial_best_diff = best_test_elbo - test_elbo[0] + best_test_elbo = max(model.loss['test']['elbo']) + if model.loss['test']['elbo'][-1] < best_test_elbo: + final_best_diff = best_test_elbo - model.loss['test']['elbo'][-1] + initial_best_diff = best_test_elbo - model.loss['test']['elbo'][0] if (final_best_diff / initial_best_diff) > final_elbo_fail_fraction: raise ElboException( - f'Training failed because final test loss {test_elbo[-1]:.2f} ' + f"Training failed because final test loss {model.loss['test']['elbo'][-1]:.2f} " f'is not sufficiently close to best test loss {best_test_elbo:.2f}, ' - f'compared to the initial test loss {test_elbo[0]:.2f}. ' + f"compared to the initial test loss {model.loss['test']['elbo'][0]:.2f}. " f'Fractional difference is {final_best_diff / initial_best_diff:.2f}, ' f'which is > specified final_elbo_fail_fraction {final_elbo_fail_fraction:.2f}' ) @@ -284,14 +282,14 @@ def run_training(model: RemoveBackgroundPyroModel, logger.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S')) # Check final ELBO meets conditions. - if (final_elbo_fail_fraction is not None) and (len(test_elbo) > 1): - best_test_elbo = max(test_elbo) - if -test_elbo[-1] >= -best_test_elbo * (1 + final_elbo_fail_fraction): - raise ElboException(f'Training failed because final test loss ({-test_elbo[-1]:.4f}) ' + if (final_elbo_fail_fraction is not None) and (len(model.loss['test']['elbo']) > 1): + best_test_elbo = max(model.loss['test']['elbo']) + if -model.loss['test']['elbo'][-1] >= -best_test_elbo * (1 + final_elbo_fail_fraction): + raise ElboException(f"Training failed because final test loss ({-model.loss['test']['elbo'][-1]:.4f}) " f'exceeds best test loss ({-best_test_elbo:.4f}) by >= ' f'{100 * final_elbo_fail_fraction:.1f}%') # Free up all the GPU memory we can once training is complete. torch.cuda.empty_cache() - return train_elbo, test_elbo + return train_elbo, model.loss['test']['elbo']