more example stuff
TheEimer committed May 16, 2024
1 parent 243c25e commit 5d46bac
Showing 6 changed files with 11 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
@@ -16,7 +16,7 @@ on:
- development

env:
-  package-name: "ARLBench"
+  package-name: "arlbench"
  test-dir: tests
  extra-requires: "[dev,envpool]" # "" for no extra_requires

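The case change matters if the workflow feeds package-name into an install or import step (an assumption; the rest of the workflow is not shown), because Python module names are case-sensitive. A minimal illustration:

# Illustration only: module lookup is case-sensitive, so a CI variable
# naming the package must match the module name on disk exactly.
import importlib

importlib.import_module("arlbench")    # succeeds when arlbench is installed
# importlib.import_module("ARLBench")  # would raise ModuleNotFoundError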
2 changes: 0 additions & 2 deletions examples/configs/base.yaml
@@ -1,10 +1,8 @@
defaults:
- _self_
- /cluster: local
- /algorithm: dqn
- /environment: cc_cartpole
- search_space: dqn
- /experiments: cc_cartpole_dqn

hydra:
  run:
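For orientation (not part of the commit): each entry in the defaults list selects one option from a Hydra config group, and Hydra merges them in order, with _self_ marking where base.yaml's own keys land. A minimal sketch of reading the composed config, assuming it runs from the examples/ directory:

# Sketch only: compose the example config roughly the way @hydra.main would.
# The hydra: block in base.yaml may need special handling under the compose API.
from hydra import compose, initialize

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="base")
    print(cfg)  # the merge of _self_ and every listed config group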
5 changes: 2 additions & 3 deletions examples/configs/random_search.yaml
@@ -7,11 +7,10 @@ defaults:

hydra:
  sweeper:
-    n_trials: 256
+    n_trials: 16
    search_space: ${search_space}
    sweeper_kwargs:
-      max_parallelization: 0.1 # hence, only 25 jobs per batch
-      job_array_size_limit: 100
+      max_parallelization: 1 # run all of it at once
  run:
    dir: results/sobol/${algorithm}_${autorl.env_name}/${autorl.seed}
  sweep:
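The comments spell out the batching arithmetic: max_parallelization appears to be a fraction of n_trials (inferred from the old "0.1, hence only 25 jobs per batch" comment), so the new settings run every trial in a single batch:

# Batch sizes implied by the comments above; the sweeper's exact rounding
# is an assumption.
print(int(256 * 0.1))  # 25 jobs per batch under the old settings
print(int(16 * 1.0))   # 16 jobs, i.e. all trials at once, under the new ones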
2 changes: 1 addition & 1 deletion examples/configs/smac.yaml
@@ -7,7 +7,7 @@ defaults:

hydra:
  sweeper:
-    n_trials: 50
+    n_trials: 16
    budget_variable: autorl.n_total_timesteps
    search_space: ${search_space}
    sweeper_kwargs:
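For context, budget_variable names the config field the sweeper treats as the fidelity budget. A hypothetical illustration of that mechanism (the budget values are made up; the real schedule comes from the SMAC sweeper):

# Sketch only: the sweeper rewrites cfg.autorl.n_total_timesteps per trial
# rather than using one fixed training length.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"autorl": {"n_total_timesteps": 0}})
for budget in (10_000, 50_000, 100_000):  # made-up fidelity levels
    OmegaConf.update(cfg, "autorl.n_total_timesteps", budget)
    print(cfg.autorl.n_total_timesteps)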
13 changes: 3 additions & 10 deletions examples/run_arlbench.py
@@ -2,6 +2,8 @@

from __future__ import annotations

+import warnings
+warnings.filterwarnings("ignore")
import csv
import logging
import sys
@@ -10,23 +12,17 @@
import hydra
import jax
from arlbench.arlbench import run_arlbench
-from codecarbon import track_emissions
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="configs", config_name="base")
-@track_emissions(offline=True, country_iso_code="DEU")
def execute(cfg: DictConfig):
    """Helper function for nice logging and error handling."""
    logging.basicConfig(
        filename="job.log", format="%(asctime)s %(message)s", filemode="w"
    )
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
-    logger.info("Logging configured")
-    logger.info(f"JAX devices: {jax.devices()}")
-    logger.info(f"JAX device count: {jax.local_device_count()}")
-    logger.info(f"JAX default backend: {jax.default_backend()}")

    if cfg.jax_enable_x64:
        logger.info("Enabling x64 support for JAX.")
@@ -40,8 +36,6 @@ def execute(cfg: DictConfig):

def run(cfg: DictConfig, logger: logging.Logger):
    """Console script for arlbench."""
-    logger.info("Starting run with config:")
-    logger.info(str(OmegaConf.to_yaml(cfg)))

    # check if file done exists and if so, return
    try:
@@ -52,7 +46,6 @@ def run(cfg: DictConfig, logger: logging.Logger):
            csvreader = csv.reader(pf)
            performance = next(csvreader)
            performance = float(performance[0])
-            logger.info(f"Returning performance {performance}.")
            return performance
    except FileNotFoundError:
        pass
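The try/except above is a resume shortcut: if a finished run already wrote its result, the script returns the stored value instead of retraining. A self-contained sketch of the pattern (the filename and run_dir argument are assumptions; the lines that open the file sit in the collapsed part of the diff):

# Sketch of the early-return caching pattern, under an assumed file layout.
from __future__ import annotations

import csv
from pathlib import Path

def cached_performance(run_dir: str) -> float | None:
    """Return the stored performance if a previous run already wrote one."""
    try:
        with (Path(run_dir) / "performance.csv").open() as pf:  # assumed name
            performance = float(next(csv.reader(pf))[0])
        return performance
    except FileNotFoundError:
        return None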
4 changes: 4 additions & 0 deletions examples/run_reactive_schedule.py
@@ -35,6 +35,9 @@ def run(cfg: DictConfig, logger: logging.Logger):
        grad_norm, _ = statistics["grad_info"]

        # If grad norm doesn't change much, spike the learning rate
+        if last_grad_norm is not None:
+            print(i)
+            print(abs(grad_norm - last_grad_norm))
        if last_grad_norm is not None and abs(grad_norm - last_grad_norm) < tolerance:
            last_lr = cfg.hp_config.learning_rate
            cfg.hp_config.learning_rate *= 10
@@ -46,6 +49,7 @@ def run(cfg: DictConfig, logger: logging.Logger):
            cfg.hp_config.learning_rate = last_lr
            spiked = False
            logger.info(f"Resetting learning rate to {cfg.hp_config.learning_rate}")
+        last_grad_norm = grad_norm
    logger.info(f"Training finished with a total reward of {objectives['reward_mean']}")


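Taken together, the additions complete a reactive schedule: track the previous gradient norm, spike the learning rate tenfold when the norm stalls, and reset it one step later. A condensed, runnable sketch (the gradient norms, starting learning rate, and tolerance are stand-ins; setting spiked = True after the spike is inferred from the reset branch shown above):

# Sketch of the reactive learning-rate spike, with made-up numbers.
last_grad_norm, last_lr, spiked = None, None, False
learning_rate, tolerance = 1e-3, 0.01  # assumed starting values

for grad_norm in [1.0, 1.001, 10.0, 2.0]:  # stand-in gradient norms
    # If grad norm doesn't change much, spike the learning rate
    if last_grad_norm is not None and abs(grad_norm - last_grad_norm) < tolerance:
        last_lr = learning_rate
        learning_rate *= 10
        spiked = True
    elif spiked:
        learning_rate = last_lr  # reset on the step after the spike
        spiked = False
    last_grad_norm = grad_norm
    print(grad_norm, learning_rate)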
