
Commit

feat: Evaluation script
becktepe committed Jun 4, 2024
1 parent 73ba82d commit 7acaa6b
Showing 53 changed files with 982 additions and 108 deletions.
11 changes: 9 additions & 2 deletions arlbench/autorl/autorl_env.py
@@ -9,8 +9,15 @@
import numpy as np
import pandas as pd

-from arlbench.core.algorithms import (DQN, PPO, SAC, Algorithm, AlgorithmState,
-    TrainResult, TrainReturnT)
+from arlbench.core.algorithms import (
+    DQN,
+    PPO,
+    SAC,
+    Algorithm,
+    AlgorithmState,
+    TrainResult,
+    TrainReturnT,
+)
from arlbench.core.environments import make_env
from arlbench.utils import config_space_to_gymnasium_space

5 changes: 3 additions & 2 deletions arlbench/autorl/checkpointing.py
@@ -11,8 +11,9 @@
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
-from flashbax.buffers.prioritised_trajectory_buffer import \
-    PrioritisedTrajectoryBufferState
+from flashbax.buffers.prioritised_trajectory_buffer import (
+    PrioritisedTrajectoryBufferState,
+)
from flashbax.buffers.sum_tree import SumTreeState
from flashbax.vault import Vault
from flax.core.frozen_dict import FrozenDict
35 changes: 27 additions & 8 deletions arlbench/core/algorithms/__init__.py
@@ -1,16 +1,35 @@
from collections.abc import Callable
from typing import Optional, Union

-from flashbax.buffers.prioritised_trajectory_buffer import \
-    PrioritisedTrajectoryBufferState
+from flashbax.buffers.prioritised_trajectory_buffer import (
+    PrioritisedTrajectoryBufferState,
+)

from .algorithm import Algorithm
-from .dqn import (DQN, DQNMetrics, DQNRunnerState, DQNState, DQNTrainingResult,
-    DQNTrainReturnT)
-from .ppo import (PPO, PPOMetrics, PPORunnerState, PPOState, PPOTrainingResult,
-    PPOTrainReturnT)
-from .sac import (SAC, SACMetrics, SACRunnerState, SACState, SACTrainingResult,
-    SACTrainReturnT)
+from .dqn import (
+    DQN,
+    DQNMetrics,
+    DQNRunnerState,
+    DQNState,
+    DQNTrainingResult,
+    DQNTrainReturnT,
+)
+from .ppo import (
+    PPO,
+    PPOMetrics,
+    PPORunnerState,
+    PPOState,
+    PPOTrainingResult,
+    PPOTrainReturnT,
+)
+from .sac import (
+    SAC,
+    SACMetrics,
+    SACRunnerState,
+    SACState,
+    SACTrainingResult,
+    SACTrainReturnT,
+)

TrainResult = Union[DQNTrainingResult, PPOTrainingResult, SACTrainingResult]
TrainMetrics = Union[DQNMetrics, PPOMetrics, SACMetrics]
9 changes: 6 additions & 3 deletions arlbench/core/algorithms/buffers.py
@@ -8,9 +8,12 @@
from flashbax import utils
from flashbax.buffers import sum_tree
from flashbax.buffers.prioritised_trajectory_buffer import (
-    Experience, PrioritisedTrajectoryBufferSample,
-    PrioritisedTrajectoryBufferState, _get_sample_trajectories,
-    get_invalid_indices)
+    Experience,
+    PrioritisedTrajectoryBufferSample,
+    PrioritisedTrajectoryBufferState,
+    _get_sample_trajectories,
+    get_invalid_indices,
+)
from flashbax.buffers.trajectory_buffer import calculate_uniform_item_indices

if TYPE_CHECKING:
10 changes: 8 additions & 2 deletions arlbench/core/algorithms/dqn/__init__.py
@@ -1,5 +1,11 @@
-from .dqn import (DQN, DQNMetrics, DQNRunnerState, DQNState, DQNTrainingResult,
-    DQNTrainReturnT)
+from .dqn import (
+    DQN,
+    DQNMetrics,
+    DQNRunnerState,
+    DQNState,
+    DQNTrainingResult,
+    DQNTrainReturnT,
+)

__all__ = [
"DQN",
22 changes: 15 additions & 7 deletions arlbench/core/algorithms/dqn/dqn.py
@@ -10,23 +10,31 @@
import jax.numpy as jnp
import numpy as np
import optax
-from ConfigSpace import (Categorical, Configuration, ConfigurationSpace,
-    EqualsCondition, Float, Integer)
+from ConfigSpace import (
+    Categorical,
+    Configuration,
+    ConfigurationSpace,
+    EqualsCondition,
+    Float,
+    Integer,
+)
from flax.training.train_state import TrainState

from arlbench.core import running_statistics
from arlbench.core.algorithms.algorithm import Algorithm
from arlbench.core.algorithms.buffers import uniform_sample
from arlbench.core.algorithms.common import TimeStep
-from arlbench.core.algorithms.prioritised_item_buffer import \
-    make_prioritised_item_buffer
+from arlbench.core.algorithms.prioritised_item_buffer import (
+    make_prioritised_item_buffer,
+)

from .models import CNNQ, MLPQ

if TYPE_CHECKING:
import chex
-from flashbax.buffers.prioritised_trajectory_buffer import \
-    PrioritisedTrajectoryBufferState
+from flashbax.buffers.prioritised_trajectory_buffer import (
+    PrioritisedTrajectoryBufferState,
+)
from flax.core.frozen_dict import FrozenDict

from arlbench.core.environments import Environment
@@ -780,7 +788,7 @@ def do_update(
buffer_state (PrioritisedTrajectoryBufferState): Buffer state.
Returns:
-tuple[chex.PRNGKey, DQNTrainState, PrioritisedTrajectoryBufferState, DQNMetrics]:
+tuple[chex.PRNGKey, DQNTrainState, PrioritisedTrajectoryBufferState, DQNMetrics]:
Random number generator key, training state, buffer state, and metrics.
"""

10 changes: 8 additions & 2 deletions arlbench/core/algorithms/ppo/__init__.py
@@ -1,5 +1,11 @@
-from .ppo import (PPO, PPOMetrics, PPORunnerState, PPOState, PPOTrainingResult,
-    PPOTrainReturnT)
+from .ppo import (
+    PPO,
+    PPOMetrics,
+    PPORunnerState,
+    PPOState,
+    PPOTrainingResult,
+    PPOTrainReturnT,
+)

__all__ = [
"PPO",
13 changes: 6 additions & 7 deletions arlbench/core/algorithms/ppo/ppo.py
@@ -10,8 +10,7 @@
import jax.numpy as jnp
import numpy as np
import optax
-from ConfigSpace import (Categorical, Configuration, ConfigurationSpace, Float,
-    Integer)
+from ConfigSpace import Categorical, Configuration, ConfigurationSpace, Float, Integer
from flax.training.train_state import TrainState

from arlbench.core import running_statistics
@@ -646,13 +645,13 @@ def _update_epoch(
"""One epoch of network updates using minibatches of the current transition batch.
Args:
-update_state (tuple[PPOTrainState, Transition, jnp.ndarray, jnp.ndarray, chex.PRNGKey]):
+update_state (tuple[PPOTrainState, Transition, jnp.ndarray, jnp.ndarray, chex.PRNGKey]):
(train_state, transition_batch, advantages, targets, rng) Current update state.
_ (None): Unused parameter (required for jax.lax.scan).
Returns:
tuple[tuple[PPOTrainState, Transition, jnp.ndarray, jnp.ndarray, chex.PRNGKey],
-tuple[tuple | None, tuple | None]]: Tuple of (train_state, transition_batch,
+tuple[tuple | None, tuple | None]]: Tuple of (train_state, transition_batch,
advantages, targets, rng) and (loss, grads) if tracked.
"""
train_state, traj_batch, advantages, targets, rng = update_state
@@ -715,11 +714,11 @@ def _update_minibatch(
Args:
train_state (PPOTrainState): PPO training state.
-batch_info (tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]):
+batch_info (tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]):
Minibatch of transitions, advantages and targets.
Returns:
-tuple[PPOTrainState, tuple[tuple | None, tuple | None]]:
+tuple[PPOTrainState, tuple[tuple | None, tuple | None]]:
Tuple of PPO train state and (loss, grads) if tracked.
"""
traj_batch, advantages, targets = batch_info
@@ -749,7 +748,7 @@ def _loss_fn(
targets (jnp.ndarray): Targets.
Returns:
-tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
+tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
Tuple of (total_loss, (value_loss, actor_loss, entropy)).
"""
# Rerun network
9 changes: 6 additions & 3 deletions arlbench/core/algorithms/prioritised_item_buffer.py
@@ -20,9 +20,12 @@
from flashbax.buffers.item_buffer import validate_item_buffer_args
from flashbax.buffers.prioritised_flat_buffer import validate_priority_exponent
from flashbax.buffers.prioritised_trajectory_buffer import (
-    PrioritisedTrajectoryBuffer, PrioritisedTrajectoryBufferSample,
-    PrioritisedTrajectoryBufferState, make_prioritised_trajectory_buffer,
-    validate_device)
+    PrioritisedTrajectoryBuffer,
+    PrioritisedTrajectoryBufferSample,
+    PrioritisedTrajectoryBufferState,
+    make_prioritised_trajectory_buffer,
+    validate_device,
+)
from flashbax.utils import add_dim_to_args

if TYPE_CHECKING:
10 changes: 8 additions & 2 deletions arlbench/core/algorithms/sac/__init__.py
@@ -1,5 +1,11 @@
-from .sac import (SAC, SACMetrics, SACRunnerState, SACState, SACTrainingResult,
-    SACTrainReturnT)
+from .sac import (
+    SAC,
+    SACMetrics,
+    SACRunnerState,
+    SACState,
+    SACTrainingResult,
+    SACTrainReturnT,
+)

__all__ = [
"SAC",
54 changes: 34 additions & 20 deletions arlbench/core/algorithms/sac/sac.py
@@ -10,24 +10,38 @@
import jax.numpy as jnp
import numpy as np
import optax
-from ConfigSpace import (Categorical, Configuration, ConfigurationSpace,
-    EqualsCondition, Float, Integer)
+from ConfigSpace import (
+    Categorical,
+    Configuration,
+    ConfigurationSpace,
+    EqualsCondition,
+    Float,
+    Integer,
+)
from flax.training.train_state import TrainState

from arlbench.core import running_statistics
from arlbench.core.algorithms.algorithm import Algorithm
from arlbench.core.algorithms.buffers import uniform_sample
from arlbench.core.algorithms.common import TimeStep
-from arlbench.core.algorithms.prioritised_item_buffer import \
-    make_prioritised_item_buffer
-
-from .models import (AlphaCoef, SACCNNActor, SACCNNCritic, SACMLPActor,
-    SACMLPCritic, SACVectorCritic)
+from arlbench.core.algorithms.prioritised_item_buffer import (
+    make_prioritised_item_buffer,
+)
+
+from .models import (
+    AlphaCoef,
+    SACCNNActor,
+    SACCNNCritic,
+    SACMLPActor,
+    SACMLPCritic,
+    SACVectorCritic,
+)

if TYPE_CHECKING:
import chex
-from flashbax.buffers.prioritised_trajectory_buffer import \
-    PrioritisedTrajectoryBufferState
+from flashbax.buffers.prioritised_trajectory_buffer import (
+    PrioritisedTrajectoryBufferState,
+)
from flax.core.frozen_dict import FrozenDict

from arlbench.core.environments import Environment
@@ -134,14 +148,14 @@ def __init__(
Args:
hpo_config (Configuration): Hyperparameter configuration.
env (Environment | AutoRLWrapper): Training environment.
-eval_env (Environment | AutoRLWrapper | None, optional): Evaluation environent
+eval_env (Environment | AutoRLWrapper | None, optional): Evaluation environent
(otherwise training environment is used for evaluation). Defaults to None.
cnn_policy (bool, optional): Use CNN network architecture. Defaults to False.
-nas_config (Configuration | None, optional): Neural architecture
+nas_config (Configuration | None, optional): Neural architecture
configuration. Defaults to None.
-track_trajectories (bool, optional): Track metrics such as loss and gradients
+track_trajectories (bool, optional): Track metrics such as loss and gradients
during training. Defaults to False.
-track_metrics (bool, optional): Track trajectories during training.
+track_metrics (bool, optional): Track trajectories during training.
Defaults to False.
Defaults to False.
"""
if nas_config is None:
@@ -610,7 +624,7 @@ def update_critic(
rng (chex.PRNGKey): Random number generator key.
Returns:
-tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]:
+tuple[SACTrainState, jnp.ndarray, jnp.ndarray, FrozenDict, chex.PRNGKey]:
Updated training state and metrics.
"""
rng, action_rng = jax.random.split(rng, 2)
@@ -690,7 +704,7 @@ def actor_loss(
alpha_params (FrozenDict): Alpha network parameters.
Returns:
-tuple[jnp.ndarray, jnp.ndarray]: Update training state and metrics.
+tuple[jnp.ndarray, jnp.ndarray]: Update training state and metrics.
"""
pi = self.actor_network.apply(actor_params, experience.last_obs)
actor_actions, log_prob = pi.sample_and_log_prob(seed=action_rng)
@@ -756,7 +770,7 @@ def _update_step(
_ (None): Unused parameter.
Returns:
-tuple[ tuple[SACRunnerState, PrioritisedTrajectoryBufferState],
+tuple[ tuple[SACRunnerState, PrioritisedTrajectoryBufferState],
tuple[SACMetrics | None, Transition | None], ]: Updated training state and metrics.
"""

@@ -785,7 +799,7 @@ def do_update(
buffer_state (PrioritisedTrajectoryBufferState): Buffer state.
Returns:
-tuple[ chex.PRNGKey, SACTrainState, SACTrainState, SACTrainState,
+tuple[ chex.PRNGKey, SACTrainState, SACTrainState, SACTrainState,
PrioritisedTrajectoryBufferState, SACMetrics]: Updated training states and metrics.
"""

@@ -811,12 +825,12 @@ def gradient_step(
"""Perform a gradient update step.
Args:
-carry (tuple[ chex.PRNGKey, SACTrainState, SACTrainState,
+carry (tuple[ chex.PRNGKey, SACTrainState, SACTrainState,
SACTrainState, PrioritisedTrajectoryBufferState, ]): Carry for jax.lax.scan():
_ (None): Unused parameter.
Returns:
-tuple[ tuple[ chex.PRNGKey, SACTrainState, SACTrainState, SACTrainState,
+tuple[ tuple[ chex.PRNGKey, SACTrainState, SACTrainState, SACTrainState,
PrioritisedTrajectoryBufferState, ], SACMetrics, ]: Updated training states and metrics.
"""
(
@@ -944,7 +958,7 @@ def dont_update(
buffer_state (PrioritisedTrajectoryBufferState): Buffer state.
Returns:
-tuple[ chex.PRNGKey, SACTrainState, SACTrainState, SACTrainState,
+tuple[ chex.PRNGKey, SACTrainState, SACTrainState, SACTrainState,
PrioritisedTrajectoryBufferState, SACMetrics]: Input training states and metrics.
"""
single_loss = jnp.array(
11 changes: 8 additions & 3 deletions arlbench/utils/__init__.py
@@ -1,6 +1,11 @@
-from .common import (config_space_to_gymnasium_space, config_space_to_yaml,
-    gymnasium_space_to_gymnax_space, recursive_concat,
-    save_defaults_to_yaml, tuple_concat)
+from .common import (
+    config_space_to_gymnasium_space,
+    config_space_to_yaml,
+    gymnasium_space_to_gymnax_space,
+    recursive_concat,
+    save_defaults_to_yaml,
+    tuple_concat,
+)

__all__ = [
"config_space_to_gymnasium_space",
2 changes: 1 addition & 1 deletion docs/basic_usage/objectives.rst
@@ -12,4 +12,4 @@ The following objectives are available at the moment:
- reward_mean: the mean evaluation reward across a number of evaluation episodes
- reward_std: the standard deviation of the evaluation rewards across a number of evaluation episodes
- runtime: the runtime of the training process
-- emissions: the CO2 emissions of the training process
+- emissions: the CO2 emissions of the training process, tracked using `CodeCarbon <https://github.com/mlco2/codecarbon>`_ (which does not currently support ARM)
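For reference, the snippet below is a minimal, hypothetical sketch of how these four objectives could be collected around a training run. The collect_objectives function and its train_and_evaluate argument are stand-ins and not part of ARLBench's actual API; only the CodeCarbon EmissionsTracker start/stop usage (returning kg CO2-equivalent) follows that library's documented interface.

import time

import numpy as np
from codecarbon import EmissionsTracker


def collect_objectives(train_and_evaluate, n_eval_episodes: int = 10) -> dict:
    """Run a (hypothetical) training/evaluation callable and report the objectives above."""
    tracker = EmissionsTracker()  # CO2 tracking via CodeCarbon; not supported on ARM
    tracker.start()
    start = time.time()

    # Stand-in for the actual training loop; returns one reward per evaluation episode.
    eval_rewards = train_and_evaluate(n_eval_episodes)

    runtime = time.time() - start
    emissions = tracker.stop()  # kg CO2-equivalent emitted during the run

    return {
        "reward_mean": float(np.mean(eval_rewards)),
        "reward_std": float(np.std(eval_rewards)),
        "runtime": runtime,
        "emissions": emissions,
    }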
Binary file modified docs/images/structure.png
Binary file modified docs/images/subsets.png

0 comments on commit 7acaa6b
