From 84f73a5fb2cdd4078847a7b795eb051990b2037b Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Fri, 5 Jul 2024 18:16:06 +0100 Subject: [PATCH] Optuna main script --- sleap_sweep/main.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 sleap_sweep/main.py diff --git a/sleap_sweep/main.py b/sleap_sweep/main.py new file mode 100644 index 0000000..98c1a20 --- /dev/null +++ b/sleap_sweep/main.py @@ -0,0 +1,42 @@ +import optuna +import sleap + +from sleap_sweep.train.config import create_cfg + + +def objective(trial: optuna.Trial) -> float: + # define parameters to optimise + initial_learning_rate_suggest = trial.suggest_float( + "initial_learning_rate", 1e-5, 1e-2, log=True + ) # initially: initial_learning_rate= 1e-04 + + # create config with selected params + cfg = create_cfg({"initial_learning_rate": initial_learning_rate_suggest}) + + # create a SLEAP Trainer for that config + trainer = sleap.nn.training.Trainer.from_config(cfg) + + # train model + trainer.setup() # is this needed? + trainer.train() + + # return validation metric to optimise + val_metrics = sleap.load_metrics(cfg.outputs.run_name, split="val") + val_metric_optim = 0.5 * ( + val_metrics["vis.precision"] + val_metrics["vis.recall"] + ) + + return val_metric_optim + + +def main(): + study = optuna.create_study() + + # The optimization finishes after evaluating 1000 times or 3 seconds. + study.optimize(objective, n_trials=1000, timeout=3) + + print(f"Best params is {study.best_params} with value {study.best_value}") + + +if __name__ == "__main__": + main()