Commit

More formatting fixes.

ernestum committed Feb 29, 2024
1 parent 60fc75a commit e69dbd6
Showing 4 changed files with 130 additions and 49 deletions.
7 changes: 3 additions & 4 deletions src/imitation/scripts/config/tuning.py
@@ -2,7 +2,6 @@

import ray.tune as tune
import sacred
-from torch import nn

from imitation.algorithms import dagger as dagger_alg
from imitation.scripts.parallel import parallel_ex
@@ -200,11 +199,11 @@ def pc():
"config_updates": {
"active_selection_oversampling": tune.randint(1, 11),
"comparison_queue_size": tune.randint(
-1, 1001
+1, 1001,
), # upper bound determined by total_comparisons=1000
"exploration_frac": tune.uniform(0.0, 0.5),
"fragment_length": tune.randint(
-1, 1001
+1, 1001,
), # trajectories are 1000 steps long
"gatherer_kwargs": {
"temperature": tune.uniform(0.0, 2.0),
@@ -218,7 +217,7 @@ def pc():
"discount_factor": tune.uniform(0.95, 1.0),
},
"query_schedule": tune.choice(
["hyperbolic", "constant", "inverse_quadratic"]
["hyperbolic", "constant", "inverse_quadratic",]
),
"trajectory_generator_kwargs": {
"switch_prob": tune.uniform(0.1, 1),
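The tuning.py hunks above adjust a Ray Tune search space for the preference comparisons ("pc") algorithm. As a minimal sketch of how such a search space is typically consumed (not part of this commit; the objective function and the "mean_return" metric are hypothetical stand-ins for the real training run):

import ray.tune as tune

def objective(config):
    # Hypothetical stand-in for a real training run; returns a dummy metric.
    return {"mean_return": -abs(config["fragment_length"] - 500)}

search_space = {
    # tune.randint's upper bound is exclusive, hence 1001 for 1000-step trajectories.
    "fragment_length": tune.randint(1, 1001),
    "exploration_frac": tune.uniform(0.0, 0.5),
    "query_schedule": tune.choice(["hyperbolic", "constant", "inverse_quadratic"]),
}

tuner = tune.Tuner(objective, param_space=search_space)
results = tuner.fit()
print(results.get_best_result(metric="mean_return", mode="max").config)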
149 changes: 111 additions & 38 deletions tuning/hp_search_spaces.py
@@ -14,14 +14,13 @@
"""

import dataclasses
-from typing import Any, Callable, Dict, List, Mapping
+from typing import Any, Callable, Dict, List, Mapping, Optional

import optuna
import sacred
import stable_baselines3 as sb3

import imitation.scripts.train_imitation
-import imitation.scripts.train_preference_comparisons
+import imitation.scripts.train_preference_comparisons as train_pc_script


@dataclasses.dataclass
@@ -42,19 +41,27 @@ class RunSacredAsTrial:
suggest_config_updates: Callable[[optuna.Trial], Mapping[str, Any]]

"""Command name to pass to sacred.run."""
-command_name: str = None
+command_name: Optional[str] = None

def __call__(
-self, trial: optuna.Trial, run_options: Dict, extra_named_configs: List[str]
+self,
+trial: optuna.Trial,
+run_options: Dict,
+extra_named_configs: List[str],
) -> float:
"""Run the sacred experiment and return the performance.
Args:
trial: The optuna trial to sample hyperparameters for.
run_options: Options to pass to sacred.run(options=).
extra_named_configs: Additional named configs to pass to sacred.run.
"""
Returns:
The performance of the trial.
Raises:
RuntimeError: If the trial fails.
"""
config_updates = self.suggest_config_updates(trial)
named_configs = self.suggest_named_configs(trial) + extra_named_configs

@@ -71,15 +78,16 @@ def __call__(
)
if result.status != "COMPLETED":
raise RuntimeError(
f"Trial failed with {result.fail_trace()} and status {result.status}."
f"Trial failed with {result.fail_trace()} and status {result.status}.",
)
return result.result["imit_stats"]["monitor_return_mean"]


"""A mapping from algorithm names to functions that run the algorithm as an optuna trial."""
"""A mapping from algorithm names to functions that run the algorithm as an optuna
trial."""
objectives_by_algo = dict(
pc=RunSacredAsTrial(
-sacred_ex=imitation.scripts.train_preference_comparisons.train_preference_comparisons_ex,
+sacred_ex=train_pc_script.train_preference_comparisons_ex,
suggest_named_configs=lambda _: ["reward.reward_ensemble"],
suggest_config_updates=lambda trial: {
"seed": trial.number,
@@ -88,61 +96,87 @@ def __call__(
"total_comparisons": 1000,
"active_selection": True,
"active_selection_oversampling": trial.suggest_int(
"active_selection_oversampling", 1, 11
"active_selection_oversampling",
1,
11,
),
"comparison_queue_size": trial.suggest_int(
"comparison_queue_size", 1, 1001
"comparison_queue_size",
1,
1001,
), # upper bound determined by total_comparisons=1000
"exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
"fragment_length": trial.suggest_int(
"fragment_length", 1, 1001
"fragment_length",
1,
1001,
), # trajectories are 1000 steps long
"gatherer_kwargs": {
"temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
"discount_factor": trial.suggest_float(
"gatherer_discount_factor", 0.95, 1.0
"gatherer_discount_factor",
0.95,
1.0,
),
"sample": trial.suggest_categorical("gatherer_sample", [True, False]),
},
"initial_epoch_multiplier": trial.suggest_float(
"initial_epoch_multiplier", 1, 200.0
"initial_epoch_multiplier",
1,
200.0,
),
"initial_comparison_frac": trial.suggest_float(
"initial_comparison_frac", 0.01, 1.0
"initial_comparison_frac",
0.01,
1.0,
),
"num_iterations": trial.suggest_int("num_iterations", 1, 51),
"preference_model_kwargs": {
"noise_prob": trial.suggest_float(
"preference_model_noise_prob", 0.0, 0.1
"preference_model_noise_prob",
0.0,
0.1,
),
"discount_factor": trial.suggest_float(
"preference_model_discount_factor", 0.95, 1.0
"preference_model_discount_factor",
0.95,
1.0,
),
},
"query_schedule": trial.suggest_categorical(
"query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
"query_schedule",
[
"hyperbolic",
"constant",
"inverse_quadratic",
],
),
"trajectory_generator_kwargs": {
"switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
"random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
},
"transition_oversampling": trial.suggest_float(
"transition_oversampling", 0.9, 2.0
"transition_oversampling",
0.9,
2.0,
),
"reward_trainer_kwargs": {
"epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
},
"rl": {
"rl_kwargs": {
"ent_coef": trial.suggest_float(
"rl_ent_coef", 1e-7, 1e-3, log=True
"rl_ent_coef",
1e-7,
1e-3,
log=True,
),
},
},
},
),
pc_classic_control=RunSacredAsTrial(
-sacred_ex=imitation.scripts.train_preference_comparisons.train_preference_comparisons_ex,
+sacred_ex=train_pc_script.train_preference_comparisons_ex,
suggest_named_configs=lambda _: ["reward.reward_ensemble"],
suggest_config_updates=lambda trial: {
"seed": trial.number,
@@ -151,54 +185,80 @@ def __call__(
"total_comparisons": 1000,
"active_selection": True,
"active_selection_oversampling": trial.suggest_int(
"active_selection_oversampling", 1, 11
"active_selection_oversampling",
1,
11,
),
"comparison_queue_size": trial.suggest_int(
"comparison_queue_size", 1, 1001
"comparison_queue_size",
1,
1001,
), # upper bound determined by total_comparisons=1000
"exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
"fragment_length": trial.suggest_int(
"fragment_length", 1, 201
"fragment_length",
1,
201,
), # trajectories are 1000 steps long
"gatherer_kwargs": {
"temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
"discount_factor": trial.suggest_float(
"gatherer_discount_factor", 0.95, 1.0
"gatherer_discount_factor",
0.95,
1.0,
),
"sample": trial.suggest_categorical("gatherer_sample", [True, False]),
},
"initial_epoch_multiplier": trial.suggest_float(
"initial_epoch_multiplier", 1, 200.0
"initial_epoch_multiplier",
1,
200.0,
),
"initial_comparison_frac": trial.suggest_float(
"initial_comparison_frac", 0.01, 1.0
"initial_comparison_frac",
0.01,
1.0,
),
"num_iterations": trial.suggest_int("num_iterations", 1, 51),
"preference_model_kwargs": {
"noise_prob": trial.suggest_float(
"preference_model_noise_prob", 0.0, 0.1
"preference_model_noise_prob",
0.0,
0.1,
),
"discount_factor": trial.suggest_float(
"preference_model_discount_factor", 0.95, 1.0
"preference_model_discount_factor",
0.95,
1.0,
),
},
"query_schedule": trial.suggest_categorical(
"query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
"query_schedule",
[
"hyperbolic",
"constant",
"inverse_quadratic",
],
),
"trajectory_generator_kwargs": {
"switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
"random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
},
"transition_oversampling": trial.suggest_float(
"transition_oversampling", 0.9, 2.0
"transition_oversampling",
0.9,
2.0,
),
"reward_trainer_kwargs": {
"epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
},
"rl": {
"rl_kwargs": {
"ent_coef": trial.suggest_float(
"rl_ent_coef", 1e-7, 1e-3, log=True
"rl_ent_coef",
1e-7,
1e-3,
log=True,
),
},
},
Expand All @@ -217,28 +277,41 @@ def __call__(
"rl": {
"rl_kwargs": {
"learning_rate": trial.suggest_float(
"learning_rate", 1e-6, 1e-2, log=True
"learning_rate",
1e-6,
1e-2,
log=True,
),
"buffer_size": trial.suggest_int("buffer_size", 1000, 100000),
"learning_starts": trial.suggest_int(
"learning_starts", 1000, 10000
"learning_starts",
1000,
10000,
),
"batch_size": trial.suggest_int("batch_size", 32, 128),
"tau": trial.suggest_float("tau", 0.0, 1.0),
"gamma": trial.suggest_float("gamma", 0.9, 0.999),
"train_freq": trial.suggest_int("train_freq", 1, 40),
"gradient_steps": trial.suggest_int("gradient_steps", 1, 10),
"target_update_interval": trial.suggest_int(
"target_update_interval", 1, 10000
"target_update_interval",
1,
10000,
),
"exploration_fraction": trial.suggest_float(
"exploration_fraction", 0.01, 0.5
"exploration_fraction",
0.01,
0.5,
),
"exploration_final_eps": trial.suggest_float(
"exploration_final_eps", 0.01, 1.0
"exploration_final_eps",
0.01,
1.0,
),
"exploration_initial_eps": trial.suggest_float(
"exploration_initial_eps", 0.01, 0.5
"exploration_initial_eps",
0.01,
0.5,
),
"max_grad_norm": trial.suggest_float("max_grad_norm", 0.1, 10.0),
},
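hp_search_spaces.py wraps each sacred experiment as a callable that optuna can optimize directly. A minimal sketch of driving one of the objectives above from an optuna study (not part of this commit; the empty run_options and extra_named_configs are hypothetical placeholders — the real tuning script supplies its own):

import optuna

import hp_search_spaces

def optuna_objective(trial: optuna.Trial) -> float:
    # Each trial launches a full sacred training run and returns its mean return.
    return hp_search_spaces.objectives_by_algo["pc"](
        trial,
        run_options={},  # hypothetical; sacred run options go here
        extra_named_configs=[],  # e.g. an environment named config
    )

study = optuna.create_study(direction="maximize")
study.optimize(optuna_objective, n_trials=5)
print(study.best_trial.params)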
13 changes: 9 additions & 4 deletions tuning/rerun_best_trial.py
@@ -1,7 +1,6 @@
"""Script to re-run the best trials from a previous hyperparameter tuning run."""
import argparse
import random
-from typing import List, Tuple

import hp_search_spaces
import optuna
@@ -11,7 +10,7 @@
def make_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Re-run the best trial from a previous tuning run.",
epilog=f"Example usage:\n" f"python rerun_best_trials.py tuning_run.json\n",
epilog="Example usage:\npython rerun_best_trials.py tuning_run.json\n",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
@@ -40,6 +39,12 @@ def infer_algo_name(study: optuna.Study) -> str:
"""Infer the algo name from the study name.
Assumes that the study name is of the form "tuning_{algo}_with_{named_configs}".
+
+Args:
+study: The optuna study.
+
+Returns:
+algo name
"""
assert study.study_name.startswith("tuning_")
assert "_with_" in study.study_name
@@ -51,7 +56,7 @@ def main():
args = parser.parse_args()
study: optuna.Study = optuna.load_study(
storage=optuna.storages.JournalStorage(
-optuna.storages.JournalFileStorage(args.journal_log)
+optuna.storages.JournalFileStorage(args.journal_log),
),
# in our case, we have one journal file per study so the study name can be
# inferred
@@ -73,7 +78,7 @@
)
if result.status != "COMPLETED":
raise RuntimeError(
f"Trial failed with {result.fail_trace()} and status {result.status}."
f"Trial failed with {result.fail_trace()} and status {result.status}.",
)


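rerun_best_trial.py loads the study from an optuna journal file, as in the hunk above. A minimal sketch of inspecting such a study's best trial (the journal file name is hypothetical):

import optuna

storage = optuna.storages.JournalStorage(
    optuna.storages.JournalFileStorage("tuning_pc_with_reward_ensemble.journal"),
)
# One journal file per study, so the single study's name can be looked up.
study_name = optuna.get_all_study_summaries(storage)[0].study_name
study = optuna.load_study(study_name=study_name, storage=storage)

best = study.best_trial
print(best.number, best.value)
print(best.params)  # the sampled hyperparameters to re-run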