diff --git a/delira/training/base_experiment.py b/delira/training/base_experiment.py index ebe02bbe..e672d77b 100644 --- a/delira/training/base_experiment.py +++ b/delira/training/base_experiment.py @@ -2,7 +2,6 @@ import logging import pickle import os -from datetime import datetime import warnings import copy @@ -17,6 +16,7 @@ from delira.models import AbstractNetwork from delira.utils import DeliraConfig +from delira.training.utils import generate_save_path from delira.training.base_trainer import BaseNetworkTrainer from delira.training.predictor import Predictor @@ -50,6 +50,7 @@ def __init__(self, checkpoint_freq=1, trainer_cls=BaseNetworkTrainer, predictor_cls=Predictor, + unique_name=True, **kwargs): """ @@ -87,6 +88,9 @@ def __init__(self, the trainer class to use for training the model predictor_cls : subclass of :class:`Predictor` the predictor class to use for testing the model + unique_name : boolean + if the name is not unique an experiment with the same + name will be continued **kwargs : additional keyword arguments @@ -109,12 +113,11 @@ def __init__(self, if save_path is None: save_path = os.path.abspath(".") - self.save_path = os.path.join(save_path, name, - str(datetime.now().strftime( - "%y-%m-%d_%H-%M-%S"))) - - if os.path.isdir(self.save_path): - logger.warning("Save Path %s already exists") + if unique_name: + self.save_path = generate_save_path( + os.path.join(save_path, name)) + else: + self.save_path = os.path.join(save_path, name) os.makedirs(self.save_path, exist_ok=True) diff --git a/delira/training/utils.py b/delira/training/utils.py index 23dd02eb..5d608c11 100644 --- a/delira/training/utils.py +++ b/delira/training/utils.py @@ -1,5 +1,7 @@ import collections import numpy as np +import os +from datetime import datetime def recursively_convert_elements(element, check_type, conversion_fn): @@ -98,3 +100,19 @@ def convert_to_numpy_identity(*args, **kwargs): _correct_zero_shape) return args, kwargs + + +def generate_save_path(save_path): + i = 0 + now = datetime.now() + date_str = '{}_{:02d}_{:02d}_'.format( + now.year, now.month, now.day) + while True: + new_path = os.path.join(save_path, '{}{:03d}'.format(date_str, i)) + i += 1 + if not os.path.isdir(new_path): + break + if i: + print('Save path is a duplicate and got changed to {}' + .format(new_path)) + return new_path