Add hyper parameters for SQIL.
ernestum committed Feb 26, 2024
1 parent 7311d1c commit 5769fc6
Showing 1 changed file with 44 additions and 0 deletions.
tuning/hp_search_spaces.py: 44 additions & 0 deletions
@@ -18,7 +18,9 @@

import optuna
import sacred
import stable_baselines3 as sb3

import imitation.scripts.train_imitation
import imitation.scripts.train_preference_comparisons


@@ -33,12 +35,16 @@ class RunSacredAsTrial:
"""The sacred experiment to run."""
sacred_ex: sacred.Experiment


"""A function that returns a list of named configs to pass to sacred.run."""
suggest_named_configs: Callable[[optuna.Trial], List[str]]

"""A function that returns a dict of config updates to pass to sacred.run."""
suggest_config_updates: Callable[[optuna.Trial], Mapping[str, Any]]

"""Command name to pass to sacred.run."""
command_name: str = None

def __call__(
self,
trial: optuna.Trial,
@@ -61,6 +67,7 @@ def __call__(

        experiment: sacred.Experiment = self.sacred_ex
        result = experiment.run(
            command_name=self.command_name,
            config_updates=config_updates,
            named_configs=named_configs,
            options=run_options,
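
The new command_name field is forwarded to sacred's Experiment.run, which selects which registered command of the experiment gets executed. A minimal standalone sketch of that sacred mechanism (the experiment name and command bodies below are illustrative, not taken from this repository):

import sacred

ex = sacred.Experiment("sketch")  # hypothetical experiment, for illustration only

@ex.command
def sqil():
    # Stand-in for the real training command.
    print("running the sqil command")

@ex.main
def main():
    print("running the default command")

if __name__ == "__main__":
    # Passing command_name picks the decorated command to execute,
    # which is how RunSacredAsTrial forwards self.command_name above.
    ex.run(command_name="sqil")
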
@@ -115,4 +122,41 @@ def __call__(
            },
        },
    ),
    sqil=RunSacredAsTrial(
        sacred_ex=imitation.scripts.train_imitation.train_imitation_ex,
        command_name="sqil",
        suggest_named_configs=lambda _: [],
        suggest_config_updates=lambda trial: {
            "seed": trial.number,
            "demonstrations": {
                "n_expert_demos": 100,
                "source": "generated",
            },
            "sqil": {
                "total_timesteps": 1e6,
                "train_kwargs": {},
            },
            "rl": {
                "rl_cls": sb3.DQN,
                "rl_kwargs": {
                    "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-2, log=True),
                    "buffer_size": trial.suggest_int("buffer_size", 1000, 100000),
                    "learning_starts": trial.suggest_int("learning_starts", 1000, 10000),
                    "batch_size": trial.suggest_int("batch_size", 32, 128),
                    "tau": trial.suggest_float("tau", 0., 1.),
                    "gamma": trial.suggest_float("gamma", 0.9, 0.999),
                    "train_freq": trial.suggest_int("train_freq", 1, 40),
                    "gradient_steps": trial.suggest_int("gradient_steps", 1, 10),
                    "target_update_interval": trial.suggest_int("target_update_interval", 1, 10000),
                    "exploration_fraction": trial.suggest_float("exploration_fraction", 0.01, 0.5),
                    "exploration_final_eps": trial.suggest_float("exploration_final_eps", 0.01, 1.0),
                    "exploration_initial_eps": trial.suggest_float("exploration_initial_eps", 0.01, 0.5),
                    "max_grad_norm": trial.suggest_float("max_grad_norm", 0.1, 10.0),
                },
            },
        },
    ),
)
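
The ranges suggested under rl_kwargs correspond one-to-one to stable-baselines3 DQN constructor arguments. A self-contained sketch of how such Optuna-suggested values plug into a DQN objective, outside the sacred/imitation machinery (environment id, parameter subset, and budgets below are illustrative placeholders, not part of this commit):

import optuna
import stable_baselines3 as sb3
from stable_baselines3.common.evaluation import evaluate_policy


def dqn_objective(trial: optuna.Trial) -> float:
    # Sample a few of the hyperparameters tuned above and train a small DQN.
    model = sb3.DQN(
        "MlpPolicy",
        "CartPole-v1",  # placeholder environment, unrelated to the SQIL search space
        learning_rate=trial.suggest_float("learning_rate", 1e-6, 1e-2, log=True),
        buffer_size=trial.suggest_int("buffer_size", 1000, 100000),
        batch_size=trial.suggest_int("batch_size", 32, 128),
        gamma=trial.suggest_float("gamma", 0.9, 0.999),
        train_freq=trial.suggest_int("train_freq", 1, 40),
    )
    model.learn(total_timesteps=10_000)  # placeholder training budget
    mean_reward, _ = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
    return mean_reward


study = optuna.create_study(direction="maximize")
study.optimize(dqn_objective, n_trials=20)
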
