Fix solution retrieval in lunar lander eval #457

Merged on Jan 30, 2024 (2 commits)
8 changes: 8 additions & 0 deletions HISTORY.md
@@ -1,5 +1,13 @@
# History

+## 0.7.1 (Forthcoming)
+
+### Changelog
+
+#### Bugs
+
+- Fix solution retrieval in lunar lander eval ({pr}`457`)
+
## 0.7.0

To learn about this release, see our page on What's New in v0.7.0:
7 changes: 4 additions & 3 deletions examples/lunar_lander.py
@@ -61,7 +61,7 @@
import tqdm
from dask.distributed import Client, LocalCluster

-from ribs.archives import GridArchive
+from ribs.archives import ArchiveDataFrame, GridArchive
from ribs.emitters import EvolutionStrategyEmitter
from ribs.schedulers import Scheduler
from ribs.visualize import grid_archive_heatmap
@@ -318,7 +318,8 @@ def run_evaluation(outdir, env_seed):
retrieve the archive and save videos.
env_seed (int): Seed for the environment.
"""
-df = pd.read_csv(outdir / "archive.csv")
+df = ArchiveDataFrame(pd.read_csv(outdir / "archive.csv"))
+solutions = df.get_field("solution")
indices = np.random.permutation(len(df))[:10]

# Use a single env so that all the videos go to the same directory.
@@ -330,7 +331,7 @@ def run_evaluation(outdir, env_seed):
)

for idx in indices:
-model = np.array(df.loc[idx, "solution_0":])
+model = solutions[idx]
reward, impact_x_pos, impact_y_vel = simulate(model, env_seed,
video_env)
print(f"=== Index {idx} ===\n"
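
A minimal sketch of the failure mode this change appears to address, using plain pandas rather than pyribs. It assumes a flat archive layout in which other columns follow the `solution_*` columns; that column order is an assumption for illustration, not taken from the PR. The helper `get_solutions` is a hypothetical stand-in for selecting a field by name and is not the pyribs API. Under that assumption, the open-ended label slice `"solution_0":` picks up every trailing column, while selecting by column name returns only the solution parameters.

```python
import numpy as np
import pandas as pd

# Hypothetical flat archive; the column order here is an assumption.
df = pd.DataFrame({
    "solution_0": [0.1, 0.4],
    "solution_1": [0.2, 0.5],
    "solution_2": [0.3, 0.6],
    "objective": [10.0, 20.0],
    "measures_0": [-0.5, 0.5],
    "measures_1": [-1.0, -2.0],
})


def get_solutions(frame):
    """Hypothetical helper: stack every `solution_<i>` column into a 2D array."""
    cols = [c for c in frame.columns if c.startswith("solution_")]
    return frame[cols].to_numpy()


# Open-ended slice: takes solution_0 and every column to its right.
sliced = np.array(df.loc[0, "solution_0":])

# Name-based selection: only the solution columns, one row per archive entry.
solutions = get_solutions(df)

print(sliced.shape)        # (6,)
print(solutions[0].shape)  # (3,)
```

With the layout above, the sliced row has six values, so a model built from it would carry three extra entries, whereas the name-based row has the three intended parameters.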