From 5769fc6ac2c1f70484a7a0e183755d2bab9bd433 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 26 Feb 2024 17:22:19 +0100 Subject: [PATCH] Add hyper parameters for SQIL. --- tuning/hp_search_spaces.py | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tuning/hp_search_spaces.py b/tuning/hp_search_spaces.py index aabd8e056..4f5a7ee9d 100644 --- a/tuning/hp_search_spaces.py +++ b/tuning/hp_search_spaces.py @@ -18,7 +18,9 @@ import optuna import sacred +import stable_baselines3 as sb3 +import imitation.scripts.train_imitation import imitation.scripts.train_preference_comparisons @@ -33,12 +35,16 @@ class RunSacredAsTrial: """The sacred experiment to run.""" sacred_ex: sacred.Experiment + """A function that returns a list of named configs to pass to sacred.run.""" suggest_named_configs: Callable[[optuna.Trial], List[str]] """A function that returns a dict of config updates to pass to sacred.run.""" suggest_config_updates: Callable[[optuna.Trial], Mapping[str, Any]] + """Command name to pass to sacred.run.""" + command_name: str = None + def __call__( self, trial: optuna.Trial, @@ -61,6 +67,7 @@ def __call__( experiment: sacred.Experiment = self.sacred_ex result = experiment.run( + command_name=self.command_name, config_updates=config_updates, named_configs=named_configs, options=run_options, @@ -115,4 +122,41 @@ def __call__( }, }, ), + sqil=RunSacredAsTrial( + sacred_ex=imitation.scripts.train_imitation.train_imitation_ex, + command_name="sqil", + suggest_named_configs=lambda _: [], + suggest_config_updates=lambda trial: { + "seed": trial.number, + "demonstrations": { + "n_expert_demos": 100, + "source": "generated", + }, + "sqil": { + "total_timesteps": 1e6, + "train_kwargs": { + + } + }, + "rl": { + "rl_cls": sb3.DQN, + "rl_kwargs": { + "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-2, log=True), + "buffer_size": trial.suggest_int("buffer_size", 1000, 100000), + "learning_starts": trial.suggest_int("learning_starts", 1000, 10000), + "batch_size": trial.suggest_int("batch_size", 32, 128), + "tau": trial.suggest_float("tau", 0., 1.), + "gamma": trial.suggest_float("gamma", 0.9, 0.999), + "train_freq": trial.suggest_int("train_freq", 1, 40), + "gradient_steps": trial.suggest_int("gradient_steps", 1, 10), + "target_update_interval": trial.suggest_int("target_update_interval", 1, 10000), + "exploration_fraction": trial.suggest_float("exploration_fraction", 0.01, 0.5), + "exploration_final_eps": trial.suggest_float("exploration_final_eps", 0.01, 1.0), + "exploration_initial_eps": trial.suggest_float("exploration_initial_eps", 0.01, 0.5), + "max_grad_norm": trial.suggest_float("max_grad_norm", 0.1, 10.0), + + }, + }, + }, + ), ) \ No newline at end of file