Skip to content

Commit

Permalink
Add 3D plots for CVTArchive (#371)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

This PR introduces a visualization function that plots CVTArchive in 3D.
There are a couple of variations, including:

**Solid cells:**


![3d](https://github.com/icaros-usc/pyribs/assets/38124174/0f6e4247-b054-49c0-b90f-79c7365ca9a7)

**Translucent cells:**


![cell_alpha](https://github.com/icaros-usc/pyribs/assets/38124174/957ede30-8269-410a-adac-e4aa1c463fab)

**Wireframe (transparent cells):**


![transparent](https://github.com/icaros-usc/pyribs/assets/38124174/1ec70793-1d93-456b-8726-2865f551c464)

**Wireframe with scatterplot:**


![plot_elites](https://github.com/icaros-usc/pyribs/assets/38124174/a101cefd-7387-4114-83c9-26efba29fd34)

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Prototype different plots
- [x] Decide on API for activating different variations
- [x] Tests
- [x] Documentation examples

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in

[CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog
in
      `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka authored Sep 14, 2023
1 parent 8efa0f2 commit 251a9b0
Show file tree
Hide file tree
Showing 19 changed files with 606 additions and 12 deletions.
3 changes: 2 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
- Removes cbar_orientaton and cbar_pad args for parallel_axes_plot
- Add `rasterized` arg for heatmaps (#359)
- Support 1D cvt_archive_heatmap ({pr}`362`)
- Add 3D plots for CVTArchive ({pr}`371`)

#### Documentation

- Use dask instead of multiprocessing for lunar lander tutorial ({pr}`346`)
- pip install swig before gymnasium[box2d] in lunar lander tutorial ({pr}`346`)
- Fix lunar lander dependency issues ({pr}`366`, {pr}`377`)
- Fix lunar lander dependency issues ({pr}`366`, {pr}`367`)
- Simplify DQD tutorial imports ({pr}`369`)

#### Improvements
Expand Down
11 changes: 7 additions & 4 deletions ribs/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
.. autosummary::
:toctree:
ribs.visualize.grid_archive_heatmap
ribs.visualize.cvt_archive_3d_plot
ribs.visualize.cvt_archive_heatmap
ribs.visualize.sliding_boundaries_archive_heatmap
ribs.visualize.grid_archive_heatmap
ribs.visualize.parallel_axes_plot
ribs.visualize.qdax_repertoire_heatmap
ribs.visualize.sliding_boundaries_archive_heatmap
"""
from ribs.visualize._cvt_archive_3d_plot import cvt_archive_3d_plot
from ribs.visualize._cvt_archive_heatmap import cvt_archive_heatmap
from ribs.visualize._grid_archive_heatmap import grid_archive_heatmap
from ribs.visualize._parallel_axes_plot import parallel_axes_plot
Expand All @@ -29,9 +31,10 @@
sliding_boundaries_archive_heatmap

__all__ = [
"grid_archive_heatmap",
"cvt_archive_3d_plot",
"cvt_archive_heatmap",
"sliding_boundaries_archive_heatmap",
"grid_archive_heatmap",
"parallel_axes_plot",
"qdax_repertoire_heatmap",
"sliding_boundaries_archive_heatmap",
]
369 changes: 369 additions & 0 deletions ribs/visualize/_cvt_archive_3d_plot.py

Large diffs are not rendered by default.

27 changes: 23 additions & 4 deletions tests/visualize/args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import pytest

from ribs.archives import CVTArchive, GridArchive, SlidingBoundariesArchive
from ribs.visualize import (cvt_archive_heatmap, grid_archive_heatmap,
from ribs.visualize import (cvt_archive_3d_plot, cvt_archive_heatmap,
grid_archive_heatmap,
sliding_boundaries_archive_heatmap)


@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding"])
def test_heatmap_fails_on_unsupported_dims(archive_type):
@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding", "cvt_3d"])
def test_fails_on_unsupported_dims(archive_type):
archive = {
"grid":
lambda: GridArchive(
Expand All @@ -26,17 +27,25 @@ def test_heatmap_fails_on_unsupported_dims(archive_type):
"sliding":
lambda: SlidingBoundariesArchive(
solution_dim=2, dims=[20, 20, 20], ranges=[(-1, 1)] * 3),
"cvt_3d":
lambda: CVTArchive(
solution_dim=2,
cells=100,
ranges=[(-1, 1)] * 4,
samples=100,
),
}[archive_type]()

with pytest.raises(ValueError):
{
"grid": grid_archive_heatmap,
"cvt": cvt_archive_heatmap,
"sliding": sliding_boundaries_archive_heatmap,
"cvt_3d": cvt_archive_3d_plot,
}[archive_type](archive)


@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding"])
@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding", "cvt_3d"])
@pytest.mark.parametrize(
"invalid_arg_cbar",
["None", 3.2, True,
Expand All @@ -59,16 +68,26 @@ def test_heatmap_fails_on_invalid_cbar_option(archive_type, invalid_arg_cbar):
dims=[20, 20, 20],
ranges=[(-1, 1)] * 3,
),
"cvt_3d":
lambda: CVTArchive(
solution_dim=2,
cells=100,
ranges=[(-1, 1)] * 3,
samples=100,
),
}[archive_type]()

with pytest.raises(ValueError):
{
"grid": grid_archive_heatmap,
"cvt": cvt_archive_heatmap,
"sliding": sliding_boundaries_archive_heatmap,
"cvt_3d": cvt_archive_3d_plot,
}[archive_type](archive=archive, cbar=invalid_arg_cbar)


# Note: cvt_3d is not included because cvt_archive_3d_plot does not have an
# aspect parameter.
@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding"])
@pytest.mark.parametrize(
"invalid_arg_aspect",
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
202 changes: 202 additions & 0 deletions tests/visualize/cvt_archive_3d_plot_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""Tests for cvt_archive_3d_plot.
See README.md for instructions on writing tests.
"""
import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.testing.decorators import image_comparison

from ribs.archives import CVTArchive
from ribs.visualize import cvt_archive_3d_plot

from .conftest import add_uniform_sphere_3d

# pylint: disable = redefined-outer-name

# Tolerance for root mean square difference between the pixels of the images,
# where 255 is the max value. We have a pretty high tolerance for
# `cvt_archive_3d_plot` since 3D rendering tends to vary a bit.
CVT_IMAGE_TOLERANCE = 1.0

#
# Fixtures
#


@pytest.fixture(scope="module")
def cvt_archive_3d():
"""Deterministically-created CVTArchive."""
ranges = np.array([(-1, 1), (-1, 1), (-1, 1)])
archive = CVTArchive(
solution_dim=3,
cells=500,
ranges=ranges,
samples=10_000,
seed=42,
)
add_uniform_sphere_3d(archive, *ranges)
return archive


@pytest.fixture(scope="module")
def cvt_archive_3d_rect():
"""Same as above, but the dimensions have different ranges."""
ranges = [(-1, 1), (-2, 0), (1, 3)]
archive = CVTArchive(
solution_dim=3,
cells=500,
ranges=ranges,
samples=10_000,
seed=42,
)
add_uniform_sphere_3d(archive, *ranges)
return archive


#
# Argument validation tests
#


def test_no_samples_error():
# This archive has no samples since custom centroids were passed in.
archive = CVTArchive(solution_dim=2,
cells=2,
ranges=[(-1, 1), (-1, 1)],
custom_centroids=[[0, 0], [1, 1]])

# Thus, plotting samples on this archive should fail.
with pytest.raises(ValueError):
cvt_archive_3d_plot(archive, plot_samples=True)


#
# Tests on archive with (-1, 1) range.
#


@image_comparison(baseline_images=["3d"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_3d(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d)


@image_comparison(baseline_images=["3d"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_3d_custom_axis(cvt_archive_3d):
ax = plt.axes(projection="3d")
cvt_archive_3d_plot(cvt_archive_3d, ax=ax)


@image_comparison(baseline_images=["3d_rect"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_3d_rect(cvt_archive_3d_rect):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d_rect)


@image_comparison(baseline_images=["3d_rect_reorder"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_3d_rect_reorder(cvt_archive_3d_rect):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d_rect, measure_order=[1, 2, 0])


@image_comparison(baseline_images=["limits"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_limits(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, vmin=-1.0, vmax=-0.5)


@image_comparison(baseline_images=["listed_cmap"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_listed_cmap(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, cmap=[[1, 0, 0], [0, 1, 0], [0, 0, 1]])


@image_comparison(baseline_images=["coolwarm_cmap"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_coolwarm_cmap(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, cmap="coolwarm")


@image_comparison(baseline_images=["vmin_equals_vmax"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_vmin_equals_vmax(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, vmin=-0.95, vmax=-0.95)


@image_comparison(baseline_images=["plot_centroids"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_plot_centroids(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, plot_centroids=True, cell_alpha=0.1)


@image_comparison(baseline_images=["plot_samples"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_plot_samples(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, plot_samples=True, cell_alpha=0.1)


@image_comparison(baseline_images=["voronoi_style"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_voronoi_style(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, lw=3.0, ec="grey")


@image_comparison(baseline_images=["cell_alpha"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cell_alpha(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, cell_alpha=0.1)


@image_comparison(baseline_images=["transparent"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_transparent(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, cell_alpha=0.0)


@image_comparison(baseline_images=["plot_elites"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_plot_elites(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, cell_alpha=0.0, plot_elites=True)
6 changes: 3 additions & 3 deletions tests/visualize/cvt_archive_heatmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# pylint: disable = redefined-outer-name

# Tolerance for root mean square difference between the pixels of the images,
# where 255 is the max value. We only have tolerance for `cvt_archive_heatmap`
# since it is a bit more finicky than the other plots.
# where 255 is the max value. We have tolerance for `cvt_archive_heatmap` since
# it is a bit more finicky than the other plots.
CVT_IMAGE_TOLERANCE = 0.1

#
Expand Down Expand Up @@ -95,7 +95,7 @@ def test_no_samples_error():

# Thus, plotting samples on this archive should fail.
with pytest.raises(ValueError):
cvt_archive_heatmap(archive, lw=3.0, ec="grey", plot_samples=True)
cvt_archive_heatmap(archive, plot_samples=True)


#
Expand Down

0 comments on commit 251a9b0

Please sign in to comment.