Add more benchmarking documentation (#822)
* Embed the benchmarking README in the Sphinx documentation.

* Add the benchmark summary from the release to the Sphinx documentation.

* Add link to benchmark results to README.md

* Fix formatting issues in docs/conf.py and improve inline comment.

* Fix typing issue due to new SB3 version.

* Black fixes.

* Improve formatting in benchmarking README.md
ernestum authored Dec 4, 2023
1 parent 1af5e4d commit 928d576
Showing 6 changed files with 27 additions and 4 deletions.
README.md: 2 changes (2 additions, 0 deletions)
@@ -22,6 +22,8 @@ Currently, we have implementations of the algorithms below. 'Discrete' and 'Cont

 You can find [the documentation here](https://imitation.readthedocs.io/en/latest/).

+You can read the latest benchmark results [here](https://imitation.readthedocs.io/en/latest/main-concepts/benchmark_summary.html).
+
 ## Installation

 ### Prerequisites
benchmarking/README.md: 6 changes (3 additions, 3 deletions)
@@ -1,4 +1,4 @@
-# Benchmarking imitation
+# Benchmarking `imitation`

 The imitation library is benchmarked by running the algorithms BC, DAgger, AIRL and GAIL
 on five different environments from the
@@ -120,7 +120,7 @@ python sacred_output_to_csv.py output/sacred > summary.csv

 This generates a csv file like this:

-```csv
+```
 algo, env, score, expert_score
 gail, seals/Walker2d-v1, 2298.883520464286, 2502.8930135576925
 gail, seals/Swimmer-v1, 287.33667667857145, 295.40472964423077
@@ -187,7 +187,7 @@ where:

 If `your_runs_dir` contains runs for more than one algorithm, you will have to
 disambiguate using the `--algo` option.

-# Tuning Hyperparameters
+## Tuning Hyperparameters

 The hyperparameters of any algorithm in imitation can be tuned using `src/imitation/scripts/tuning.py`.
 The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
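For readers working with the `summary.csv` shown in the hunk above, here is a minimal, hypothetical sketch of how one might load it and compare scores against expert scores. It assumes pandas is installed (not a dependency this commit touches); the column names are taken from the sample in the diff:

```python
# Hypothetical sketch: inspect summary.csv produced by sacred_output_to_csv.py.
# Assumes pandas; skipinitialspace handles the spaces after commas above.
import pandas as pd

df = pd.read_csv("summary.csv", skipinitialspace=True)

# Fraction of expert performance reached by each (algo, env) pair.
df["score_ratio"] = df["score"] / df["expert_score"]
print(df.sort_values("score_ratio", ascending=False).to_string(index=False))
```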
docs/conf.py: 17 changes (17 additions, 0 deletions)
@@ -18,7 +18,10 @@


 # -- Project information -----------------------------------------------------
+import io
 import os
+import urllib.request
+import zipfile
 from importlib import metadata

 project = "imitation"
@@ -132,3 +135,17 @@ def setup(app):
         "autodoc-process-docstring",
         no_namedtuple_attrib_docstring,
     )
+
+
+# -- Download the latest benchmark summary -------------------------------------
+download_url = (
+    "https://github.com/HumanCompatibleAI/imitation/releases/latest/"
+    "download/benchmark_runs.zip"
+)
+
+# Download the benchmark data, extract the summary and place it in the documentation
+with urllib.request.urlopen(download_url) as url:
+    with zipfile.ZipFile(io.BytesIO(url.read())) as z:
+        with z.open("benchmark_runs/summary.md") as f:
+            with open("main-concepts/benchmark_summary.md", "wb") as out:
+                out.write(f.read())
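Because the download above runs at conf.py import time, every docs build needs network access to the latest GitHub release; in exchange, the published summary stays in sync with that release. A self-contained sketch (not part of the commit) to sanity-check that the release artifact contains the file conf.py extracts, reusing the URL and archive path from the diff:

```python
# Standalone sketch, not part of the commit: verify the latest release
# artifact contains the summary file that docs/conf.py extracts.
import io
import urllib.request
import zipfile

download_url = (
    "https://github.com/HumanCompatibleAI/imitation/releases/latest/"
    "download/benchmark_runs.zip"
)

with urllib.request.urlopen(download_url) as url:
    archive = zipfile.ZipFile(io.BytesIO(url.read()))

print("benchmark_runs/summary.md" in archive.namelist())
```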
docs/index.rst: 2 changes (2 additions, 0 deletions)
@@ -57,6 +57,8 @@ If you use ``imitation`` in your research project, please cite our paper to help
    main-concepts/trajectories
    main-concepts/reward_networks
    main-concepts/variable_horizon
+   main-concepts/benchmarks
+   main-concepts/benchmark_summary


 .. toctree::
docs/main-concepts/benchmarks.md: 1 change (1 addition, 0 deletions)
src/imitation/algorithms/mce_irl.py: 3 changes (2 additions, 1 deletion)
@@ -26,6 +26,7 @@
 import torch as th
 from seals import base_envs
 from stable_baselines3.common import policies
+from stable_baselines3.common import type_aliases as sb3_types

 from imitation.algorithms import base
 from imitation.data import rollout, types
@@ -196,7 +197,7 @@ def set_pi(self, pi: np.ndarray) -> None:
         assert np.all(pi >= 0), "policy has negative probabilities"
         self.pi = pi

-    def _predict(self, observation: th.Tensor, deterministic: bool = False):
+    def _predict(self, observation: sb3_types.PyTorchObs, deterministic: bool = False):
         raise NotImplementedError("Should never be called as predict overridden.")

     def forward(
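Context for the `_predict` change above: recent Stable-Baselines3 versions type policy observations with the `PyTorchObs` alias so that dict observation spaces type-check as well as plain tensors. A paraphrased sketch of the alias (check `stable_baselines3.common.type_aliases` in the installed SB3 version for the exact definition):

```python
# Paraphrased sketch of the SB3 alias; the exact definition may vary
# slightly between SB3 releases.
from typing import Dict, Union

import torch as th

TensorDict = Dict[str, th.Tensor]  # one tensor per observation key
PyTorchObs = Union[th.Tensor, TensorDict]  # plain or dict observations
```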
