diff --git a/README.md b/README.md
index 4bec2e89d..d359871a2 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,8 @@ Currently, we have implementations of the algorithms below. 'Discrete' and 'Cont
 
 You can find [the documentation here](https://imitation.readthedocs.io/en/latest/).
 
+You can read the latest benchmark results [here](https://imitation.readthedocs.io/en/latest/main-concepts/benchmark_summary.html).
+
 ## Installation
 
 ### Prerequisites
diff --git a/benchmarking/README.md b/benchmarking/README.md
index e98c72cdd..5566a684c 100644
--- a/benchmarking/README.md
+++ b/benchmarking/README.md
@@ -1,4 +1,4 @@
-# Benchmarking imitation
+# Benchmarking `imitation`
 
 The imitation library is benchmarked by running the algorithms BC, DAgger, AIRL
 and GAIL on five different environments from the
@@ -120,7 +120,7 @@ python sacred_output_to_csv.py output/sacred > summary.csv
 
 This generates a csv file like this:
 
-```csv
+```
 algo, env, score, expert_score
 gail, seals/Walker2d-v1, 2298.883520464286, 2502.8930135576925
 gail, seals/Swimmer-v1, 287.33667667857145, 295.40472964423077
@@ -187,7 +187,7 @@ where:
 If `your_runs_dir` contains runs for more than one algorithm, you will have to
 disambiguate using the `--algo` option.
 
-# Tuning Hyperparameters
+## Tuning Hyperparameters
 
 The hyperparameters of any algorithm in imitation can be tuned using `src/imitation/scripts/tuning.py`.
 The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
diff --git a/docs/conf.py b/docs/conf.py
index 60b018e30..852f8de72 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -18,7 +18,10 @@
 
 # -- Project information -----------------------------------------------------
 
+import io
 import os
+import urllib.request
+import zipfile
 from importlib import metadata
 
 project = "imitation"
@@ -132,3 +135,17 @@ def setup(app):
         "autodoc-process-docstring",
         no_namedtuple_attrib_docstring,
     )
+
+
+# -- Download the latest benchmark summary -------------------------------------
+download_url = (
+    "https://github.com/HumanCompatibleAI/imitation/releases/latest/"
+    "download/benchmark_runs.zip"
+)
+
+# Download the benchmark data, extract the summary and place it in the documentation
+with urllib.request.urlopen(download_url) as url:
+    with zipfile.ZipFile(io.BytesIO(url.read())) as z:
+        with z.open("benchmark_runs/summary.md") as f:
+            with open("main-concepts/benchmark_summary.md", "wb") as out:
+                out.write(f.read())
diff --git a/docs/index.rst b/docs/index.rst
index 73cfa4974..3b3a9e1be 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -57,6 +57,8 @@ If you use ``imitation`` in your research project, please cite our paper to help
    main-concepts/trajectories
    main-concepts/reward_networks
    main-concepts/variable_horizon
+   main-concepts/benchmarks
+   main-concepts/benchmark_summary
 
 
 .. toctree::
diff --git a/docs/main-concepts/benchmarks.md b/docs/main-concepts/benchmarks.md
new file mode 120000
index 000000000..c01e09764
--- /dev/null
+++ b/docs/main-concepts/benchmarks.md
@@ -0,0 +1 @@
+../../benchmarking/README.md
\ No newline at end of file
diff --git a/src/imitation/algorithms/mce_irl.py b/src/imitation/algorithms/mce_irl.py
index 21ebd8642..26ecdc5cc 100644
--- a/src/imitation/algorithms/mce_irl.py
+++ b/src/imitation/algorithms/mce_irl.py
@@ -26,6 +26,7 @@
 import torch as th
 from seals import base_envs
 from stable_baselines3.common import policies
+from stable_baselines3.common import type_aliases as sb3_types
 
 from imitation.algorithms import base
 from imitation.data import rollout, types
@@ -196,7 +197,7 @@ def set_pi(self, pi: np.ndarray) -> None:
         assert np.all(pi >= 0), "policy has negative probabilities"
         self.pi = pi
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False):
+    def _predict(self, observation: sb3_types.PyTorchObs, deterministic: bool = False):
         raise NotImplementedError("Should never be called as predict overridden.")
 
     def forward(
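For context, the sketch below pulls the download step added to `docs/conf.py` out into a standalone script: it fetches `benchmark_runs.zip` from the latest GitHub release and writes the bundled `summary.md` to `main-concepts/benchmark_summary.md`, which the new toctree entries reference. Running it outside a Sphinx build, from the `docs/` directory, is an assumption made here purely for illustration.

```python
# Standalone sketch of the benchmark-summary download added to docs/conf.py.
# Assumption: run from the docs/ directory, so main-concepts/ already exists.
import io
import urllib.request
import zipfile

# Asset attached to the latest release of HumanCompatibleAI/imitation.
download_url = (
    "https://github.com/HumanCompatibleAI/imitation/releases/latest/"
    "download/benchmark_runs.zip"
)

# Download the archive into memory, extract the Markdown summary, and write it
# where the documentation toctree expects to find benchmark_summary.
with urllib.request.urlopen(download_url) as response:
    with zipfile.ZipFile(io.BytesIO(response.read())) as archive:
        with archive.open("benchmark_runs/summary.md") as summary:
            with open("main-concepts/benchmark_summary.md", "wb") as out:
                out.write(summary.read())
```

Because the original code runs at module level in `conf.py`, it executes on every documentation build and therefore needs network access to GitHub at build time.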