Keep the best model for each replicate #24

Draft · wants to merge 5 commits into master
51 changes: 51 additions & 0 deletions study/manager.py
@@ -1,10 +1,15 @@
import os
import glob
import logging
import sqlite3
from copy import copy as shallow_copy
from itertools import chain
from types import NoneType
from typing import Optional

import json
import joblib
import pandas as pd
import numpy as np
import optuna
from optuna.samplers import TPESampler
@@ -63,6 +68,8 @@ def __init__(self, data_config: DataConfig, model_config: ModelConfig, study_con
        self.db_connection: Optional[sqlite3.Connection] = None
        self.db_cursor: Optional[sqlite3.Cursor] = None

        self.model_output_path = None

    """ Utils """
    def init_logger(self):
        """Generates the logger for this study"""
@@ -187,6 +194,12 @@ def save_results(self, replicate_n, trial: optuna.Trial, objective_val, metrics)
            else:
                new_entry_components[p_label] = v

        # Save the model's params dict as a JSON file
        self.model_output_path = self.study_config.output_path.parent / self.data_config.label
        os.makedirs(self.model_output_path, exist_ok=True)
        with open(self.model_output_path / f'model_rep{replicate_n}_trial{trial.number}.json', 'w') as f:
            json.dump(new_entry_components, f)

        # Re-order the values so they save cleanly into the dataset
        ordered_values = [new_entry_components[k] for k in self.db_cols]
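A quick load-back sketch for the params sidecar written above (a hypothetical helper, not part of this diff; it only assumes the model_rep{R}_trial{T}.json naming used here):

import json
from pathlib import Path

def load_trial_params(model_dir: Path, replicate: int, trial: int) -> dict:
    """Load the hyperparameter dict saved for one (replicate, trial) pair."""
    # Assumed layout: <output_path.parent>/<label>/model_rep{R}_trial{T}.json
    with open(model_dir / f"model_rep{replicate}_trial{trial}.json") as f:
        return json.load(f)

# e.g. load_trial_params(Path("out/my_label"), replicate=0, trial=1)  # hypothetical path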

@@ -215,6 +228,39 @@ def run(self):
            # Run a sub-study using this data
            self.run_replicate(i, train_idx, test_idx, x, y, s)

            self.keep_the_best_model_for_replicate(i)

    def keep_the_best_model_for_replicate(self, i):
        """
        Keep only the best model for the given replicate and delete the rest
        """
        # TODO: the DB code below could probably be improved
        con = self.db_connection
        tables = pd.read_sql("SELECT name FROM sqlite_master WHERE type = 'table'", con=con).loc[:, 'name']
        for t in tables:
            # Pull the dataframe from the database
            try:
                df_rep = pd.read_sql(f"SELECT * FROM {t}", con=con)
            except Exception:
                self.logger.debug(f"Failed to read table {t}, ignoring it")
                continue
            # Find the trial with the lowest test log-loss for this replicate
            best_model_trial_num = df_rep.loc[df_rep['log_loss (test)'].idxmin()].trial
            # Keep only that trial's files (e.g. `model_rep0_trial1.pkl`) and delete the rest
            model_pattern = f"model_rep{i}_trial*.pkl"
            model_files = glob.glob(os.path.join(self.model_output_path, model_pattern))
            for model_file in model_files:
                # Extract the trial number from the filename
                trial_num = int(model_file.split("_trial")[-1].split(".pkl")[0])

                if trial_num != best_model_trial_num:
                    os.remove(model_file)  # Delete the non-best model...
                    os.remove(model_file.replace(".pkl", ".json"))  # ...and its params JSON
                    self.logger.debug(f"Deleted: {model_file}")
                else:
                    self.logger.debug(f"Kept: {model_file}")

    def prepare_run(self):
        # Control for RNG before proceeding
        init_seed = self.study_config.random_seed
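The pruning step above can be sanity-checked in isolation. A minimal sketch under the same assumptions (the model_rep{R}_trial{T}.pkl naming convention, with the best trial number passed in directly rather than read from the study DB):

import glob
import os

def prune_replicate(model_dir: str, replicate: int, best_trial: int) -> None:
    """Delete every saved model for this replicate except the best trial's."""
    for pkl in glob.glob(os.path.join(model_dir, f"model_rep{replicate}_trial*.pkl")):
        # Recover the trial number from the filename suffix
        trial = int(pkl.split("_trial")[-1].split(".pkl")[0])
        if trial == best_trial:
            continue
        os.remove(pkl)  # drop the non-best model...
        sidecar = pkl.replace(".pkl", ".json")
        if os.path.exists(sidecar):  # ...and its params JSON, if present
            os.remove(sidecar)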
@@ -298,6 +344,11 @@ def opt_func(trial: optuna.Trial):
            # Save the metric values to the DB
            self.save_results(rep, trial, objective_value, metric_dict)

            # Save the fitted model for this trial
            joblib.dump(model_manager._model,
                        filename=self.model_output_path / f"model_rep{rep}_trial{trial.number}.pkl")
            self.logger.debug(f"Saved model for replicate {rep}, trial {trial.number}")

            # Return the objective value so Optuna can optimize on it
            return objective_value

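After a full run, each replicate directory should be left with exactly one surviving .pkl/.json pair. A hedged sketch of loading it back (hypothetical helper; assumes the naming convention above and that joblib round-trips the fitted model):

import json
import joblib
from pathlib import Path

def load_best_model(model_dir: Path, replicate: int):
    """Return (model, params) for the single model kept for this replicate."""
    pkl = next(model_dir.glob(f"model_rep{replicate}_trial*.pkl"))  # exactly one expected
    with open(pkl.with_suffix(".json")) as f:
        params = json.load(f)
    return joblib.load(pkl), params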