diff --git a/study/manager.py b/study/manager.py
index bf9ef0a..6930cb8 100644
--- a/study/manager.py
+++ b/study/manager.py
@@ -1,3 +1,5 @@
+import os
+import glob
 import logging
 import sqlite3
 from copy import copy as shallow_copy
@@ -5,6 +7,9 @@
 from types import NoneType
 from typing import Optional
 
+import json
+import joblib
+import pandas as pd
 import numpy as np
 import optuna
 from optuna.samplers import TPESampler
@@ -63,6 +68,8 @@ def __init__(self, data_config: DataConfig, model_config: ModelConfig, study_con
         self.db_connection : Optional[sqlite3.Connection] = None
         self.db_cursor : Optional[sqlite3.Cursor] = None
 
+        self.model_output_path = None
+
     """ Utils """
     def init_logger(self):
         """Generates the logger for this study"""
@@ -187,6 +194,12 @@ def save_results(self, replicate_n, trial: optuna.Trial, objective_val, metrics)
             else:
                 new_entry_components[p_label] = v
 
+        # Save the model's params dict as a JSON file
+        self.model_output_path = self.study_config.output_path.parent / self.data_config.label
+        os.makedirs(self.model_output_path, exist_ok=True)
+        with open(self.model_output_path / f'model_rep{replicate_n}_trial{trial.number}.json', 'w') as f:
+            json.dump(new_entry_components, f)
+
         # Re-order the values so they can cleanly save into the dataset
         ordered_values = [new_entry_components[k] for k in self.db_cols]
 
@@ -215,6 +228,41 @@ def run(self):
             # Run a sub-study using this data
             self.run_replicate(i, train_idx, test_idx, x, y, s)
+            self.keep_the_best_model_for_replicate(i)
 
+    def keep_the_best_model_for_replicate(self, i):
+        """
+        Keep only the best model for the given replicate and delete the rest
+        """
+        # TODO: the DB code below could probably be improved
+        con = self.db_connection
+        tables = pd.read_sql("SELECT name FROM sqlite_master WHERE type='table'", con=con).loc[:, 'name']
+        for t in tables:
+            # Pull the dataframe from the database
+            try:
+                df_rep = pd.read_sql(f"SELECT * FROM {t}", con=con)
+            except Exception:
+                self.logger.debug(f"Failed to read table {t}, ignoring it")
+                continue
+            # Skip tables that don't record the test log-loss
+            if 'log_loss (test)' not in df_rep.columns:
+                continue
+            # Get the trial number of the best model for the given replicate
+            best_model_trial_num = df_rep.loc[df_rep['log_loss (test)'].idxmin()].trial
+            # Find all saved models for this replicate (e.g., `model_rep0_trial1.pkl`)
+            model_pattern = f"model_rep{i}_trial*.pkl"
+            model_files = glob.glob(os.path.join(self.model_output_path, model_pattern))
+            # Keep only the best model and delete the rest
+            for model_file in model_files:
+                # Extract the trial number from the filename
+                trial_num = int(model_file.split("_trial")[-1].split(".pkl")[0])
+
+                if trial_num != best_model_trial_num:
+                    os.remove(model_file)  # Delete non-best models
+                    os.remove(model_file.replace(".pkl", ".json"))  # Delete the matching params JSON
+                    self.logger.debug(f"Deleted: {model_file}")
+                else:
+                    self.logger.debug(f"Kept: {model_file}")
+
     def prepare_run(self):
         # Control for RNG before proceeding
         init_seed = self.study_config.random_seed
@@ -298,6 +346,11 @@ def opt_func(trial: optuna.Trial):
             # Save the metric values to the DB
             self.save_results(rep, trial, objective_value, metric_dict)
 
+            # Save the model for this trial
+            joblib.dump(model_manager._model,
+                        filename=self.model_output_path / f"model_rep{rep}_trial{trial.number}.pkl")
+            self.logger.debug(f"Saved model for replicate {rep} and trial {trial.number}")
+
             # Return the objective function so Optuna can run optimization based on it
             return objective_value