diff --git a/src/models/train_eval_laat.py b/src/models/train_eval_laat.py index c3be11f..f4f9718 100644 --- a/src/models/train_eval_laat.py +++ b/src/models/train_eval_laat.py @@ -131,6 +131,10 @@ def train( except KeyboardInterrupt: print('*' * 20) print('Exiting from training early') + if not (MODEL_FOLDER / f"best_{model_save_fname}.pt").exists(): + logger.info(f"saving best model so far from current epoch...") + torch.save(model.state_dict(), f"{MODEL_FOLDER / f'best_{model_save_fname}.pt'}") + return evals return evals