From e69dbd64c3c1595ccb9f2935f0557520bf418369 Mon Sep 17 00:00:00 2001
From: Maximilian Ernestus
Date: Thu, 29 Feb 2024 20:01:39 +0100
Subject: [PATCH] More formatting fixes.

---
 src/imitation/scripts/config/tuning.py |   7 +-
 tuning/hp_search_spaces.py             | 149 ++++++++++++++++++-------
 tuning/rerun_best_trial.py             |  13 ++-
 tuning/tune.py                         |  10 +-
 4 files changed, 130 insertions(+), 49 deletions(-)

diff --git a/src/imitation/scripts/config/tuning.py b/src/imitation/scripts/config/tuning.py
index faf0517cb..22d1d82fb 100644
--- a/src/imitation/scripts/config/tuning.py
+++ b/src/imitation/scripts/config/tuning.py
@@ -2,7 +2,6 @@
 
 import ray.tune as tune
 import sacred
-from torch import nn
 
 from imitation.algorithms import dagger as dagger_alg
 from imitation.scripts.parallel import parallel_ex
@@ -200,11 +199,11 @@ def pc():
         "config_updates": {
             "active_selection_oversampling": tune.randint(1, 11),
             "comparison_queue_size": tune.randint(
-                1, 1001
+                1, 1001,
             ),  # upper bound determined by total_comparisons=1000
             "exploration_frac": tune.uniform(0.0, 0.5),
             "fragment_length": tune.randint(
-                1, 1001
+                1, 1001,
             ),  # trajectories are 1000 steps long
             "gatherer_kwargs": {
                 "temperature": tune.uniform(0.0, 2.0),
@@ -218,7 +217,7 @@ def pc():
                 "discount_factor": tune.uniform(0.95, 1.0),
             },
             "query_schedule": tune.choice(
-                ["hyperbolic", "constant", "inverse_quadratic"]
+                ["hyperbolic", "constant", "inverse_quadratic"],
             ),
             "trajectory_generator_kwargs": {
                 "switch_prob": tune.uniform(0.1, 1),
diff --git a/tuning/hp_search_spaces.py b/tuning/hp_search_spaces.py
index 85da0e3f7..5a4e7db1d 100644
--- a/tuning/hp_search_spaces.py
+++ b/tuning/hp_search_spaces.py
@@ -14,14 +14,13 @@
 """
 
 import dataclasses
-from typing import Any, Callable, Dict, List, Mapping
+from typing import Any, Callable, Dict, List, Mapping, Optional
 
 import optuna
 import sacred
-import stable_baselines3 as sb3
 
 import imitation.scripts.train_imitation
-import imitation.scripts.train_preference_comparisons
+import imitation.scripts.train_preference_comparisons as train_pc_script
 
 
 @dataclasses.dataclass
@@ -42,10 +41,13 @@ class RunSacredAsTrial:
     suggest_config_updates: Callable[[optuna.Trial], Mapping[str, Any]]
 
     """Command name to pass to sacred.run."""
-    command_name: str = None
+    command_name: Optional[str] = None
 
     def __call__(
-        self, trial: optuna.Trial, run_options: Dict, extra_named_configs: List[str]
+        self,
+        trial: optuna.Trial,
+        run_options: Dict,
+        extra_named_configs: List[str],
     ) -> float:
         """Run the sacred experiment and return the performance.
 
@@ -53,8 +55,13 @@ def __call__(
         Args:
             trial: The optuna trial to sample hyperparameters for.
             run_options: Options to pass to sacred.run(options=).
             extra_named_configs: Additional named configs to pass to sacred.run.
-        """
 
+        Returns:
+            The performance of the trial.
+
+        Raises:
+            RuntimeError: If the trial fails.
+        """
         config_updates = self.suggest_config_updates(trial)
         named_configs = self.suggest_named_configs(trial) + extra_named_configs
@@ -71,15 +78,16 @@ def __call__(
         )
         if result.status != "COMPLETED":
             raise RuntimeError(
-                f"Trial failed with {result.fail_trace()} and status {result.status}."
+                f"Trial failed with {result.fail_trace()} and status {result.status}.",
             )
         return result.result["imit_stats"]["monitor_return_mean"]
 
 
-"""A mapping from algorithm names to functions that run the algorithm as an optuna trial."""
+"""A mapping from algorithm names to functions that run the algorithm as an optuna
+trial."""
 objectives_by_algo = dict(
     pc=RunSacredAsTrial(
-        sacred_ex=imitation.scripts.train_preference_comparisons.train_preference_comparisons_ex,
+        sacred_ex=train_pc_script.train_preference_comparisons_ex,
         suggest_named_configs=lambda _: ["reward.reward_ensemble"],
         suggest_config_updates=lambda trial: {
             "seed": trial.number,
@@ -88,46 +96,69 @@ def __call__(
             "total_timesteps": 2e7,
             "total_comparisons": 1000,
             "active_selection": True,
             "active_selection_oversampling": trial.suggest_int(
-                "active_selection_oversampling", 1, 11
+                "active_selection_oversampling",
+                1,
+                11,
             ),
             "comparison_queue_size": trial.suggest_int(
-                "comparison_queue_size", 1, 1001
+                "comparison_queue_size",
+                1,
+                1001,
             ),  # upper bound determined by total_comparisons=1000
             "exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
             "fragment_length": trial.suggest_int(
-                "fragment_length", 1, 1001
+                "fragment_length",
+                1,
+                1001,
             ),  # trajectories are 1000 steps long
             "gatherer_kwargs": {
                 "temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
                 "discount_factor": trial.suggest_float(
-                    "gatherer_discount_factor", 0.95, 1.0
+                    "gatherer_discount_factor",
+                    0.95,
+                    1.0,
                 ),
                 "sample": trial.suggest_categorical("gatherer_sample", [True, False]),
             },
             "initial_epoch_multiplier": trial.suggest_float(
-                "initial_epoch_multiplier", 1, 200.0
+                "initial_epoch_multiplier",
+                1,
+                200.0,
             ),
             "initial_comparison_frac": trial.suggest_float(
-                "initial_comparison_frac", 0.01, 1.0
+                "initial_comparison_frac",
+                0.01,
+                1.0,
             ),
             "num_iterations": trial.suggest_int("num_iterations", 1, 51),
             "preference_model_kwargs": {
                 "noise_prob": trial.suggest_float(
-                    "preference_model_noise_prob", 0.0, 0.1
+                    "preference_model_noise_prob",
+                    0.0,
+                    0.1,
                 ),
                 "discount_factor": trial.suggest_float(
-                    "preference_model_discount_factor", 0.95, 1.0
+                    "preference_model_discount_factor",
+                    0.95,
+                    1.0,
                 ),
             },
             "query_schedule": trial.suggest_categorical(
-                "query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
+                "query_schedule",
+                [
+                    "hyperbolic",
+                    "constant",
+                    "inverse_quadratic",
+                ],
             ),
             "trajectory_generator_kwargs": {
                 "switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
                 "random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
             },
             "transition_oversampling": trial.suggest_float(
-                "transition_oversampling", 0.9, 2.0
+                "transition_oversampling",
+                0.9,
+                2.0,
             ),
             "reward_trainer_kwargs": {
                 "epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
@@ -135,14 +166,17 @@ def __call__(
             },
             "rl": {
                 "rl_kwargs": {
                     "ent_coef": trial.suggest_float(
-                        "rl_ent_coef", 1e-7, 1e-3, log=True
+                        "rl_ent_coef",
+                        1e-7,
+                        1e-3,
+                        log=True,
                     ),
                 },
             },
         },
     ),
     pc_classic_control=RunSacredAsTrial(
-        sacred_ex=imitation.scripts.train_preference_comparisons.train_preference_comparisons_ex,
+        sacred_ex=train_pc_script.train_preference_comparisons_ex,
         suggest_named_configs=lambda _: ["reward.reward_ensemble"],
         suggest_config_updates=lambda trial: {
             "seed": trial.number,
@@ -151,46 +185,69 @@ def __call__(
             "total_timesteps": 2e7,
             "total_comparisons": 1000,
             "active_selection": True,
             "active_selection_oversampling": trial.suggest_int(
-                "active_selection_oversampling", 1, 11
+                "active_selection_oversampling",
+                1,
+                11,
             ),
             "comparison_queue_size": trial.suggest_int(
-                "comparison_queue_size", 1, 1001
"comparison_queue_size", 1, 1001 + "comparison_queue_size", + 1, + 1001, ), # upper bound determined by total_comparisons=1000 "exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5), "fragment_length": trial.suggest_int( - "fragment_length", 1, 201 + "fragment_length", + 1, + 201, ), # trajectories are 1000 steps long "gatherer_kwargs": { "temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0), "discount_factor": trial.suggest_float( - "gatherer_discount_factor", 0.95, 1.0 + "gatherer_discount_factor", + 0.95, + 1.0, ), "sample": trial.suggest_categorical("gatherer_sample", [True, False]), }, "initial_epoch_multiplier": trial.suggest_float( - "initial_epoch_multiplier", 1, 200.0 + "initial_epoch_multiplier", + 1, + 200.0, ), "initial_comparison_frac": trial.suggest_float( - "initial_comparison_frac", 0.01, 1.0 + "initial_comparison_frac", + 0.01, + 1.0, ), "num_iterations": trial.suggest_int("num_iterations", 1, 51), "preference_model_kwargs": { "noise_prob": trial.suggest_float( - "preference_model_noise_prob", 0.0, 0.1 + "preference_model_noise_prob", + 0.0, + 0.1, ), "discount_factor": trial.suggest_float( - "preference_model_discount_factor", 0.95, 1.0 + "preference_model_discount_factor", + 0.95, + 1.0, ), }, "query_schedule": trial.suggest_categorical( - "query_schedule", ["hyperbolic", "constant", "inverse_quadratic"] + "query_schedule", + [ + "hyperbolic", + "constant", + "inverse_quadratic", + ], ), "trajectory_generator_kwargs": { "switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1), "random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9), }, "transition_oversampling": trial.suggest_float( - "transition_oversampling", 0.9, 2.0 + "transition_oversampling", + 0.9, + 2.0, ), "reward_trainer_kwargs": { "epochs": trial.suggest_int("reward_trainer_epochs", 1, 11), @@ -198,7 +255,10 @@ def __call__( "rl": { "rl_kwargs": { "ent_coef": trial.suggest_float( - "rl_ent_coef", 1e-7, 1e-3, log=True + "rl_ent_coef", + 1e-7, + 1e-3, + log=True, ), }, }, @@ -217,11 +277,16 @@ def __call__( "rl": { "rl_kwargs": { "learning_rate": trial.suggest_float( - "learning_rate", 1e-6, 1e-2, log=True + "learning_rate", + 1e-6, + 1e-2, + log=True, ), "buffer_size": trial.suggest_int("buffer_size", 1000, 100000), "learning_starts": trial.suggest_int( - "learning_starts", 1000, 10000 + "learning_starts", + 1000, + 10000, ), "batch_size": trial.suggest_int("batch_size", 32, 128), "tau": trial.suggest_float("tau", 0.0, 1.0), @@ -229,16 +294,24 @@ def __call__( "train_freq": trial.suggest_int("train_freq", 1, 40), "gradient_steps": trial.suggest_int("gradient_steps", 1, 10), "target_update_interval": trial.suggest_int( - "target_update_interval", 1, 10000 + "target_update_interval", + 1, + 10000, ), "exploration_fraction": trial.suggest_float( - "exploration_fraction", 0.01, 0.5 + "exploration_fraction", + 0.01, + 0.5, ), "exploration_final_eps": trial.suggest_float( - "exploration_final_eps", 0.01, 1.0 + "exploration_final_eps", + 0.01, + 1.0, ), "exploration_initial_eps": trial.suggest_float( - "exploration_initial_eps", 0.01, 0.5 + "exploration_initial_eps", + 0.01, + 0.5, ), "max_grad_norm": trial.suggest_float("max_grad_norm", 0.1, 10.0), }, diff --git a/tuning/rerun_best_trial.py b/tuning/rerun_best_trial.py index d18ffb8ad..7b878a02e 100644 --- a/tuning/rerun_best_trial.py +++ b/tuning/rerun_best_trial.py @@ -1,7 +1,6 @@ """Script to re-run the best trials from a previous hyperparameter tuning run.""" import argparse import random -from typing import List, 
 
 import hp_search_spaces
 import optuna
@@ -11,7 +10,7 @@ def make_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
         description="Re-run the best trial from a previous tuning run.",
-        epilog=f"Example usage:\n" f"python rerun_best_trials.py tuning_run.json\n",
+        epilog="Example usage:\npython rerun_best_trial.py tuning_run.json\n",
         formatter_class=argparse.RawDescriptionHelpFormatter,
     )
     parser.add_argument(
@@ -40,6 +39,12 @@ def infer_algo_name(study: optuna.Study) -> str:
     """Infer the algo name from the study name.
 
     Assumes that the study name is of the form "tuning_{algo}_with_{named_configs}".
+
+    Args:
+        study: The optuna study.
+
+    Returns:
+        The algorithm name.
     """
     assert study.study_name.startswith("tuning_")
     assert "_with_" in study.study_name
@@ -51,7 +56,7 @@ def main():
     args = parser.parse_args()
     study: optuna.Study = optuna.load_study(
         storage=optuna.storages.JournalStorage(
-            optuna.storages.JournalFileStorage(args.journal_log)
+            optuna.storages.JournalFileStorage(args.journal_log),
         ),
         # in our case, we have one journal file per study so the study name can be
         # inferred
@@ -73,7 +78,7 @@ def main():
     )
     if result.status != "COMPLETED":
         raise RuntimeError(
-            f"Trial failed with {result.fail_trace()} and status {result.status}."
+            f"Trial failed with {result.fail_trace()} and status {result.status}.",
         )
 
 
diff --git a/tuning/tune.py b/tuning/tune.py
index 7a77d400d..4a2e710c5 100644
--- a/tuning/tune.py
+++ b/tuning/tune.py
@@ -14,7 +14,8 @@ def make_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
         description="Tune hyperparameters for imitation learning algorithms.",
-        epilog=f"Example usage:\n{example_usage}\n\nPossible named configs:\n{possible_named_configs}",
+        epilog=f"Example usage:\n{example_usage}\n\n"
+        f"Possible named configs:\n{possible_named_configs}",
         formatter_class=argparse.RawDescriptionHelpFormatter,
     )
     parser.add_argument(
@@ -33,7 +34,10 @@ def make_parser() -> argparse.ArgumentParser:
         "Use this to select the environment to tune on.",
     )
     parser.add_argument(
-        "--num_trials", type=int, default=100, help="Number of trials to run."
+        "--num_trials",
+        type=int,
+        default=100,
+        help="Number of trials to run.",
     )
     parser.add_argument(
         "-j",
@@ -49,7 +53,7 @@ def make_parser() -> argparse.ArgumentParser:
 
 def make_study(args: argparse.Namespace) -> optuna.Study:
     if args.journal_log is not None:
         storage = optuna.storages.JournalStorage(
-            optuna.storages.JournalFileStorage(args.journal_log)
+            optuna.storages.JournalFileStorage(args.journal_log),
        )
    else:
        storage = None
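
Note on how the files touched above fit together (a sketch, not part of the patch): tuning/tune.py builds an optuna study whose objective is one of the objectives_by_algo entries from tuning/hp_search_spaces.py, and tuning/rerun_best_trial.py later reloads that study from the same journal file. The snippet below shows that flow end to end; the journal file name, study name, sacred run options, and trial count are illustrative assumptions, not values taken from this patch.

# Minimal sketch: drive one RunSacredAsTrial objective from optuna directly.
import optuna

from hp_search_spaces import objectives_by_algo

objective = objectives_by_algo["pc"]  # entries are keyed by algorithm name

# One journal file per study, matching the assumption in rerun_best_trial.py.
storage = optuna.storages.JournalStorage(
    optuna.storages.JournalFileStorage("tuning_pc.log"),  # hypothetical path
)
study = optuna.create_study(
    # "tuning_{algo}_with_{named_configs}" is the format infer_algo_name() expects.
    study_name="tuning_pc_with_reward.reward_ensemble",
    storage=storage,
    direction="maximize",  # the objective returns monitor_return_mean
)
study.optimize(
    lambda trial: objective(
        trial,
        run_options={"--loglevel": "INFO"},  # hypothetical; forwarded to sacred.run(options=)
        extra_named_configs=[],
    ),
    n_trials=100,  # same as the --num_trials default in tune.py
)
print(study.best_trial.params)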