From 952228471b729e4a6780317e9053bf20bac1a199 Mon Sep 17 00:00:00 2001 From: Mirko Bronzi Date: Tue, 10 Sep 2024 10:38:43 -0400 Subject: [PATCH 1/2] removed hardcoded params for best checkpoint --- amlrt_project/train.py | 7 ++++--- examples/config.yaml | 6 ++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/amlrt_project/train.py b/amlrt_project/train.py index 8579478..47ebc75 100644 --- a/amlrt_project/train.py +++ b/amlrt_project/train.py @@ -161,14 +161,15 @@ def train_impl(model, datamodule, output, hyper_params, use_progress_bar, check_and_log_hp(['max_epoch'], hyper_params) best_model_path = os.path.join(output, BEST_MODEL_NAME) + best_checkpoint_params = hyper_params['best_checkpoint'] best_checkpoint_callback = ModelCheckpoint( dirpath=best_model_path, filename='model', save_top_k=1, verbose=use_progress_bar, - monitor="val_loss", - mode="min", - every_n_epochs=1, + monitor=best_checkpoint_params['metric'], + mode=best_checkpoint_params['mode'], + every_n_epochs=best_checkpoint_params['every_n_epochs'] ) last_model_path = os.path.join(output, LAST_MODEL_NAME) diff --git a/examples/config.yaml b/examples/config.yaml index c3fcf64..6359ae1 100644 --- a/examples/config.yaml +++ b/examples/config.yaml @@ -27,3 +27,9 @@ early_stopping: metric: val_loss mode: min patience: 3 + +# best checkpoint params +best_checkpoint: + metric: val_loss + mode: min + every_n_epochs: 1 \ No newline at end of file From 728d1de891d5c49a2b0e30cc484c803eb265bf82 Mon Sep 17 00:00:00 2001 From: Mirko Bronzi Date: Tue, 10 Sep 2024 10:48:18 -0400 Subject: [PATCH 2/2] centralized the definition of the metric/mode --- examples/config.yaml | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/config.yaml b/examples/config.yaml index 6359ae1..97556dc 100644 --- a/examples/config.yaml +++ b/examples/config.yaml @@ -22,14 +22,20 @@ hidden_dim: 256 num_classes: 10 architecture: simple_mlp +# here we centralize the metric and the mode to 
use in both early stopping and +# best checkpoint selection. If instead you want to use a different metric/mode for each, +# remove this section and define them directly in the early_stopping / best_checkpoint blocks. +metric_to_use: 'val_loss' +mode_to_use: 'min' + # early stopping early_stopping: - metric: val_loss - mode: min + metric: ${metric_to_use} + mode: ${mode_to_use} patience: 3 # best checkpoint params best_checkpoint: - metric: val_loss - mode: min + metric: ${metric_to_use} + mode: ${mode_to_use} every_n_epochs: 1 \ No newline at end of file