Commit
Merge pull request #127 from mila-iqia/not_hard_coding_best_checkpoint_values

removed hardcoded params for best checkpoint
mirkobronzi authored Sep 10, 2024
2 parents bae84ec + 728d1de commit 5ff13e5
Showing 2 changed files with 18 additions and 5 deletions.
7 changes: 4 additions & 3 deletions amlrt_project/train.py
```diff
@@ -161,14 +161,15 @@ def train_impl(model, datamodule, output, hyper_params, use_progress_bar,
     check_and_log_hp(['max_epoch'], hyper_params)

     best_model_path = os.path.join(output, BEST_MODEL_NAME)
+    best_checkpoint_params = hyper_params['best_checkpoint']
     best_checkpoint_callback = ModelCheckpoint(
         dirpath=best_model_path,
         filename='model',
         save_top_k=1,
         verbose=use_progress_bar,
-        monitor="val_loss",
-        mode="min",
-        every_n_epochs=1,
+        monitor=best_checkpoint_params['metric'],
+        mode=best_checkpoint_params['mode'],
+        every_n_epochs=best_checkpoint_params['every_n_epochs']
     )

     last_model_path = os.path.join(output, LAST_MODEL_NAME)
```
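The change above reads the checkpoint settings from the `best_checkpoint` section of the hyper-parameters instead of hard-coding them. A minimal sketch of the dict shape this code now expects, using a hypothetical `check_keys` stand-in for the project's `check_and_log_hp` helper (the real `ModelCheckpoint` from pytorch_lightning is not instantiated here):

```python
# Sketch of the hyper_params structure train_impl now expects.
# check_keys is a hypothetical stand-in for the project's check_and_log_hp.
hyper_params = {
    'max_epoch': 10,
    'best_checkpoint': {
        'metric': 'val_loss',
        'mode': 'min',
        'every_n_epochs': 1,
    },
}

def check_keys(required, params):
    """Raise if any required hyper-parameter key is missing."""
    missing = [k for k in required if k not in params]
    if missing:
        raise ValueError(f"missing hyper-parameters: {missing}")

check_keys(['metric', 'mode', 'every_n_epochs'], hyper_params['best_checkpoint'])
best_checkpoint_params = hyper_params['best_checkpoint']
print(best_checkpoint_params['metric'])  # -> val_loss
```

Validating the keys up front keeps a missing config entry from surfacing later as a cryptic `KeyError` inside the training loop.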
16 changes: 14 additions & 2 deletions examples/config.yaml
```diff
@@ -22,8 +22,20 @@ hidden_dim: 256
 num_classes: 10
 architecture: simple_mlp

+# here we centralize the metric and the mode to use in both early stopping and
+# best checkpoint selection. If instead you want to use a different metric/mode,
+# remove this section and define them directly in the early_stopping / best_checkpoint blocks.
+metric_to_use: 'val_loss'
+mode_to_use: 'min'
+
 # early stopping
 early_stopping:
-  metric: val_loss
-  mode: min
+  metric: ${metric_to_use}
+  mode: ${mode_to_use}
   patience: 3
+
+# best checkpoint params
+best_checkpoint:
+  metric: ${metric_to_use}
+  mode: ${mode_to_use}
+  every_n_epochs: 1
```
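The `${metric_to_use}` syntax is variable interpolation in the style of OmegaConf/Hydra; assuming the project loads its YAML with such a library, each reference resolves to the top-level value, so early stopping and best-checkpoint selection stay in sync. A toy resolver, purely for illustration of how the references expand:

```python
import re

# Config mirroring the YAML above, with OmegaConf-style ${key} references.
config = {
    'metric_to_use': 'val_loss',
    'mode_to_use': 'min',
    'early_stopping': {'metric': '${metric_to_use}', 'mode': '${mode_to_use}', 'patience': 3},
    'best_checkpoint': {'metric': '${metric_to_use}', 'mode': '${mode_to_use}', 'every_n_epochs': 1},
}

def resolve(node, root):
    """Recursively replace ${key} strings with the top-level value (toy resolver)."""
    if isinstance(node, dict):
        return {k: resolve(v, root) for k, v in node.items()}
    if isinstance(node, str):
        m = re.fullmatch(r'\$\{(\w+)\}', node)
        if m:
            return root[m.group(1)]
    return node

resolved = resolve(config, config)
print(resolved['early_stopping']['metric'])  # -> val_loss
```

Changing `metric_to_use` in one place now updates both blocks, which is exactly the duplication this commit removes.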
