diff --git a/HISTORY.md b/HISTORY.md index a3752359d..e6a00869d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -16,12 +16,13 @@ - Removes cbar_orientaton and cbar_pad args for parallel_axes_plot - Add `rasterized` arg for heatmaps (#359) - Support 1D cvt_archive_heatmap ({pr}`362`) +- Add 3D plots for CVTArchive ({pr}`371`) #### Documentation - Use dask instead of multiprocessing for lunar lander tutorial ({pr}`346`) - pip install swig before gymnasium[box2d] in lunar lander tutorial ({pr}`346`) -- Fix lunar lander dependency issues ({pr}`366`, {pr}`377`) +- Fix lunar lander dependency issues ({pr}`366`, {pr}`367`) - Simplify DQD tutorial imports ({pr}`369`) #### Improvements diff --git a/ribs/visualize/__init__.py b/ribs/visualize/__init__.py index 0d4a4bde9..8e207721e 100644 --- a/ribs/visualize/__init__.py +++ b/ribs/visualize/__init__.py @@ -15,12 +15,14 @@ .. autosummary:: :toctree: - ribs.visualize.grid_archive_heatmap + ribs.visualize.cvt_archive_3d_plot ribs.visualize.cvt_archive_heatmap - ribs.visualize.sliding_boundaries_archive_heatmap + ribs.visualize.grid_archive_heatmap ribs.visualize.parallel_axes_plot ribs.visualize.qdax_repertoire_heatmap + ribs.visualize.sliding_boundaries_archive_heatmap """ +from ribs.visualize._cvt_archive_3d_plot import cvt_archive_3d_plot from ribs.visualize._cvt_archive_heatmap import cvt_archive_heatmap from ribs.visualize._grid_archive_heatmap import grid_archive_heatmap from ribs.visualize._parallel_axes_plot import parallel_axes_plot @@ -29,9 +31,10 @@ sliding_boundaries_archive_heatmap __all__ = [ - "grid_archive_heatmap", + "cvt_archive_3d_plot", "cvt_archive_heatmap", - "sliding_boundaries_archive_heatmap", + "grid_archive_heatmap", "parallel_axes_plot", "qdax_repertoire_heatmap", + "sliding_boundaries_archive_heatmap", ] diff --git a/ribs/visualize/_cvt_archive_3d_plot.py b/ribs/visualize/_cvt_archive_3d_plot.py new file mode 100644 index 000000000..2b96509f0 --- /dev/null +++ b/ribs/visualize/_cvt_archive_3d_plot.py @@ -0,0 +1,369 @@ +"""Provides cvt_archive_3d_plot.""" +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.cm import ScalarMappable +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +from scipy.spatial import Voronoi # pylint: disable=no-name-in-module + +from ribs.visualize._utils import (retrieve_cmap, set_cbar, + validate_heatmap_visual_args) + + +def cvt_archive_3d_plot( + archive, + ax=None, + *, + measure_order=None, + cmap="magma", + lw=0.5, + ec=(0.0, 0.0, 0.0, 0.1), + cell_alpha=1.0, + vmin=None, + vmax=None, + cbar="auto", + cbar_kwargs=None, + plot_elites=False, + elite_ms=100, + elite_alpha=0.5, + plot_centroids=False, + plot_samples=False, + ms=1, +): + """Plots a :class:`~ribs.archives.CVTArchive` with 3D measure space. + + This function relies on Matplotlib's `mplot3d + `_ toolkit. + By default, this function plots a 3D Voronoi diagram of the cells in the + archive and shades each cell based on its objective value. It is also + possible to plot a "wireframe" with only the cells' boundaries, along with a + dot inside each cell indicating its objective value. + + Depending on how many cells are in the archive, ``ms`` and ``lw`` may need + to be tuned. If there are too many cells, the Voronoi diagram and centroid + markers will make the entire image appear black. In that case, try turning + off the centroids with ``plot_centroids=False`` or even removing the lines + completely with ``lw=0``. + + Examples: + + .. plot:: + :context: close-figs + + 3D Plot with Solid Cells + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from ribs.archives import CVTArchive + >>> from ribs.visualize import cvt_archive_3d_plot + >>> # Populate the archive with the negative sphere function. + >>> archive = CVTArchive(solution_dim=2, + ... cells=500, + ... ranges=[(-2, 0), (-2, 0), (-2, 0)]) + >>> x = np.random.uniform(-2, 0, 5000) + >>> y = np.random.uniform(-2, 0, 5000) + >>> z = np.random.uniform(-2, 0, 5000) + >>> archive.add(solution_batch=np.stack((x, y), axis=1), + ... objective_batch=-(x**2 + y**2 + z**2), + ... measures_batch=np.stack((x, y, z), axis=1)) + >>> # Plot the archive. + >>> plt.figure(figsize=(8, 6)) + >>> cvt_archive_3d_plot(archive) + >>> plt.title("Negative sphere function with 3D measures") + >>> plt.show() + + .. plot:: + :context: close-figs + + 3D Plot with Translucent Cells + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from ribs.archives import CVTArchive + >>> from ribs.visualize import cvt_archive_3d_plot + >>> # Populate the archive with the negative sphere function. + >>> archive = CVTArchive(solution_dim=2, + ... cells=500, + ... ranges=[(-2, 0), (-2, 0), (-2, 0)]) + >>> x = np.random.uniform(-2, 0, 5000) + >>> y = np.random.uniform(-2, 0, 5000) + >>> z = np.random.uniform(-2, 0, 5000) + >>> archive.add(solution_batch=np.stack((x, y), axis=1), + ... objective_batch=-(x**2 + y**2 + z**2), + ... measures_batch=np.stack((x, y, z), axis=1)) + >>> # Plot the archive. + >>> plt.figure(figsize=(8, 6)) + >>> cvt_archive_3d_plot(archive, cell_alpha=0.1) + >>> plt.title("Negative sphere function with 3D measures") + >>> plt.show() + + .. plot:: + :context: close-figs + + 3D "Wireframe" (Shading Turned Off) + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from ribs.archives import CVTArchive + >>> from ribs.visualize import cvt_archive_3d_plot + >>> # Populate the archive with the negative sphere function. + >>> archive = CVTArchive(solution_dim=2, + ... cells=100, + ... ranges=[(-2, 0), (-2, 0), (-2, 0)]) + >>> x = np.random.uniform(-2, 0, 1000) + >>> y = np.random.uniform(-2, 0, 1000) + >>> z = np.random.uniform(-2, 0, 1000) + >>> archive.add(solution_batch=np.stack((x, y), axis=1), + ... objective_batch=-(x**2 + y**2 + z**2), + ... measures_batch=np.stack((x, y, z), axis=1)) + >>> # Plot the archive. + >>> plt.figure(figsize=(8, 6)) + >>> cvt_archive_3d_plot(archive, cell_alpha=0.0) + >>> plt.title("Negative sphere function with 3D measures") + >>> plt.show() + + .. plot:: + :context: close-figs + + 3D Wireframe with Elites as Scatter Plot + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from ribs.archives import CVTArchive + >>> from ribs.visualize import cvt_archive_3d_plot + >>> # Populate the archive with the negative sphere function. + >>> archive = CVTArchive(solution_dim=2, + ... cells=100, + ... ranges=[(-2, 0), (-2, 0), (-2, 0)]) + >>> x = np.random.uniform(-2, 0, 1000) + >>> y = np.random.uniform(-2, 0, 1000) + >>> z = np.random.uniform(-2, 0, 1000) + >>> archive.add(solution_batch=np.stack((x, y), axis=1), + ... objective_batch=-(x**2 + y**2 + z**2), + ... measures_batch=np.stack((x, y, z), axis=1)) + >>> # Plot the archive. + >>> plt.figure(figsize=(8, 6)) + >>> cvt_archive_3d_plot(archive, cell_alpha=0.0, plot_elites=True) + >>> plt.title("Negative sphere function with 3D measures") + >>> plt.show() + + Args: + archive (CVTArchive): A 3D :class:`~ribs.archives.CVTArchive`. + ax (matplotlib.axes.Axes): Axes on which to plot the heatmap. + If ``None``, we will create a new 3D axis. + measure_order (array-like of int): Specifies the axes order for plotting + the measures. By default, the first measure (measure 0) in the + archive appears on the x-axis, the second (measure 1) on y-axis, and + third (measure 2) on z-axis. This argument is an array of length 3 + that specifies which measure should appear on the x, y, and z axes. + For instance, [2, 1, 0] will put measure 2 on the x-axis, measure 1 + on the y-axis, and measure 0 on the z-axis. + cmap (str, list, matplotlib.colors.Colormap): The colormap to use when + plotting intensity. Either the name of a + :class:`~matplotlib.colors.Colormap`, a list of RGB or RGBA colors + (i.e. an :math:`N \\times 3` or :math:`N \\times 4` array), or a + :class:`~matplotlib.colors.Colormap` object. + lw (float): Line width when plotting the Voronoi diagram. + ec (matplotlib color): Edge color of the cells in the Voronoi diagram. + See `here + `_ for + more info on specifying colors in Matplotlib. + cell_alpha: Alpha value for the cell colors. Set to 1.0 for opaque + cells; set to 0.0 for fully transparent cells. + vmin (float): Minimum objective value to use in the plot. If ``None``, + the minimum objective value in the archive is used. + vmax (float): Maximum objective value to use in the plot. If ``None``, + the maximum objective value in the archive is used. + cbar ('auto', None, matplotlib.axes.Axes): By default, this is set to + ``'auto'`` which displays the colorbar on the archive's current + :class:`~matplotlib.axes.Axes`. If ``None``, then colorbar is not + displayed. If this is an :class:`~matplotlib.axes.Axes`, displays + the colorbar on the specified Axes. + cbar_kwargs (dict): Additional kwargs to pass to + :func:`~matplotlib.pyplot.colorbar`. + plot_elites (bool): If True, we will plot a scatter plot of the elites + in the archive. The elites will be colored according to their + objective value. + elite_ms (float): Marker size for plotting elites. + elite_alpha (float): Alpha value for plotting elites. + plot_centroids (bool): Whether to plot the cluster centroids. + plot_samples (bool): Whether to plot the samples used when generating + the clusters. + ms (float): Marker size for both centroids and samples. + + Raises: + ValueError: The archive's measure dimension must be 1D or 2D. + ValueError: ``measure_order`` is not a permutation of ``[0, 1, 2]``. + ValueError: ``plot_samples`` is passed in but the archive does not have + samples (e.g., due to using custom centroids during construction). + """ + # We don't have an aspect arg here so we just pass None. + validate_heatmap_visual_args( + None, cbar, archive.measure_dim, [3], + "This plot can only be made for a 3D CVTArchive") + + if plot_samples and archive.samples is None: + raise ValueError("Samples are not available for this archive, but " + "`plot_samples` was passed in.") + + # Try getting the colormap early in case it fails. + cmap = retrieve_cmap(cmap) + + # Retrieve data from archive. + df = archive.as_pandas() + objective_batch = df.objective_batch() + measures_batch = df.measures_batch() + lower_bounds = archive.lower_bounds + upper_bounds = archive.upper_bounds + centroids = archive.centroids + samples = archive.samples + + if measure_order is not None: + if sorted(measure_order) != [0, 1, 2]: + raise ValueError( + "measure_order should be a permutation of [0, 1, 2] but " + f"received {measure_order}") + measures_batch = measures_batch[:, measure_order] + lower_bounds = lower_bounds[measure_order] + upper_bounds = upper_bounds[measure_order] + centroids = centroids[:, measure_order] + samples = samples[:, measure_order] + + # Compute objective value range. + min_obj = np.min(objective_batch) if vmin is None else vmin + max_obj = np.max(objective_batch) if vmax is None else vmax + + # If the min and max are the same, we set a sensible default range. + if min_obj == max_obj: + min_obj, max_obj = min_obj - 0.01, max_obj + 0.01 + + # Default ax behavior. + if ax is None: + ax = plt.axes(projection="3d") + + ax.set_xlim(lower_bounds[0], upper_bounds[0]) + ax.set_ylim(lower_bounds[1], upper_bounds[1]) + ax.set_zlim(lower_bounds[2], upper_bounds[2]) + + # Create reflections of the points so that we can easily find the polygons + # at the edge of the Voronoi diagram. See here for the basic idea in 2D: + # https://stackoverflow.com/questions/28665491/getting-a-bounded-polygon-coordinates-from-voronoi-cells + # + # Note that this indeed results in us creating a Voronoi diagram with 7 + # times the cells we need. However, the Voronoi creation is still pretty + # fast. + # + # Note that the above StackOverflow approach proposes filtering the points + # after creating the Voronoi diagram by checking if they are outside the + # upper or lower bounds. We found that this approach works fine, but it + # requires subtracting an epsilon from the lower bounds and adding an + # epsilon to the upper bounds, to allow for some margin of error due to + # numerical stability. Otherwise, some of the edge polygons will be clipped. + # Below, we do not filter with this method; instead, we just check whether + # the point on each side of the ridge is one of the original centroids. + xmin, ymin, zmin = lower_bounds + xmax, ymax, zmax = upper_bounds + ( + xmin_reflec, + ymin_reflec, + zmin_reflec, + xmax_reflec, + ymax_reflec, + zmax_reflec, + ) = [centroids.copy() for _ in range(6)] + + xmin_reflec[:, 0] = xmin - (centroids[:, 0] - xmin) + ymin_reflec[:, 1] = ymin - (centroids[:, 1] - ymin) + zmin_reflec[:, 2] = zmin - (centroids[:, 2] - zmin) + xmax_reflec[:, 0] = xmax + (xmax - centroids[:, 0]) + ymax_reflec[:, 1] = ymax + (ymax - centroids[:, 1]) + zmax_reflec[:, 2] = zmax + (zmax - centroids[:, 2]) + + vor = Voronoi( + np.concatenate((centroids, xmin_reflec, ymin_reflec, zmin_reflec, + xmax_reflec, ymax_reflec, zmax_reflec))) + + # Collect the vertices of the ridges of each cell -- the boundary between + # two points in a Voronoi diagram is referred to as a ridge; in 3D, the + # ridge is a planar polygon; in 2D, the ridge is a line. + vertices = [] + objs = [] # Also record objective for each ridge so we can color it. + + # Map from centroid index to objective. + pt_to_obj = {elite.index: elite.objective for elite in archive} + + # The points in the Voronoi diagram are indexed by their placement in the + # input list. Above, when we called Voronoi, `centroids` were placed first, + # so the centroid points all have indices less than len(centroids). + max_centroid_idx = len(centroids) + + for ridge_points, ridge_vertices in zip(vor.ridge_points, + vor.ridge_vertices): + a, b = ridge_points + # Record the ridge. We are only interested in a ridge if it involves one + # of our centroid points, hence the check for max_idx. + # + # Note that we record the ridge twice if a and b are both valid points, + # so we end up plotting the same polygon twice. Unclear how to resolve + # this, but it seems to show up fine as is. + if a < max_centroid_idx: + vertices.append(vor.vertices[ridge_vertices]) + # NaN indicates the cell was not filled and thus had no objective. + objs.append(pt_to_obj.get(a, np.nan)) + if b < max_centroid_idx: + vertices.append(vor.vertices[ridge_vertices]) + objs.append(pt_to_obj.get(b, np.nan)) + + # Collect and normalize objs that need to be passed through cmap. + objs = np.asarray(objs) + cmap_idx = ~np.isnan(objs) + cmap_objs = objs[cmap_idx] + normalized_objs = np.clip( + (np.asarray(cmap_objs) - min_obj) / (max_obj - min_obj), 0.0, 1.0) + + # Create an array of facecolors in RGBA format that defaults to transparent + # white. + facecolors = np.full((len(objs), 4), [1.0, 1.0, 1.0, 0.0]) + + # Set colors based on objectives. Also set alpha, which is located in index + # 3 since this is RGBA format. + facecolors[cmap_idx] = cmap(normalized_objs) + facecolors[cmap_idx, 3] = cell_alpha + + ax.add_collection( + Poly3DCollection( + vertices, + edgecolor=[ec for _ in vertices], + facecolor=facecolors, + lw=lw, + )) + + if plot_elites: + ax.scatter(measures_batch[:, 0], + measures_batch[:, 1], + measures_batch[:, 2], + s=elite_ms, + c=objective_batch, + cmap=cmap, + vmin=vmin, + vmax=vmax, + lw=0.0, + alpha=elite_alpha) + if plot_samples: + ax.plot(samples[:, 0], + samples[:, 1], + samples[:, 2], + "o", + c="grey", + ms=ms) + if plot_centroids: + ax.plot(centroids[:, 0], + centroids[:, 1], + centroids[:, 2], + "o", + c="black", + ms=ms) + + # Create color bar. + mappable = ScalarMappable(cmap=cmap) + mappable.set_clim(min_obj, max_obj) + set_cbar(mappable, ax, cbar, cbar_kwargs) diff --git a/tests/visualize/args_test.py b/tests/visualize/args_test.py index 16d4fbbb7..11f5ed4ff 100644 --- a/tests/visualize/args_test.py +++ b/tests/visualize/args_test.py @@ -6,12 +6,13 @@ import pytest from ribs.archives import CVTArchive, GridArchive, SlidingBoundariesArchive -from ribs.visualize import (cvt_archive_heatmap, grid_archive_heatmap, +from ribs.visualize import (cvt_archive_3d_plot, cvt_archive_heatmap, + grid_archive_heatmap, sliding_boundaries_archive_heatmap) -@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding"]) -def test_heatmap_fails_on_unsupported_dims(archive_type): +@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding", "cvt_3d"]) +def test_fails_on_unsupported_dims(archive_type): archive = { "grid": lambda: GridArchive( @@ -26,6 +27,13 @@ def test_heatmap_fails_on_unsupported_dims(archive_type): "sliding": lambda: SlidingBoundariesArchive( solution_dim=2, dims=[20, 20, 20], ranges=[(-1, 1)] * 3), + "cvt_3d": + lambda: CVTArchive( + solution_dim=2, + cells=100, + ranges=[(-1, 1)] * 4, + samples=100, + ), }[archive_type]() with pytest.raises(ValueError): @@ -33,10 +41,11 @@ def test_heatmap_fails_on_unsupported_dims(archive_type): "grid": grid_archive_heatmap, "cvt": cvt_archive_heatmap, "sliding": sliding_boundaries_archive_heatmap, + "cvt_3d": cvt_archive_3d_plot, }[archive_type](archive) -@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding"]) +@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding", "cvt_3d"]) @pytest.mark.parametrize( "invalid_arg_cbar", ["None", 3.2, True, @@ -59,6 +68,13 @@ def test_heatmap_fails_on_invalid_cbar_option(archive_type, invalid_arg_cbar): dims=[20, 20, 20], ranges=[(-1, 1)] * 3, ), + "cvt_3d": + lambda: CVTArchive( + solution_dim=2, + cells=100, + ranges=[(-1, 1)] * 3, + samples=100, + ), }[archive_type]() with pytest.raises(ValueError): @@ -66,9 +82,12 @@ def test_heatmap_fails_on_invalid_cbar_option(archive_type, invalid_arg_cbar): "grid": grid_archive_heatmap, "cvt": cvt_archive_heatmap, "sliding": sliding_boundaries_archive_heatmap, + "cvt_3d": cvt_archive_3d_plot, }[archive_type](archive=archive, cbar=invalid_arg_cbar) +# Note: cvt_3d is not included because cvt_archive_3d_plot does not have an +# aspect parameter. @pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding"]) @pytest.mark.parametrize( "invalid_arg_aspect", diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/3d.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/3d.png new file mode 100644 index 000000000..833e7deb8 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/3d.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/3d_rect.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/3d_rect.png new file mode 100644 index 000000000..6030cba0e Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/3d_rect.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/3d_rect_reorder.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/3d_rect_reorder.png new file mode 100644 index 000000000..6f8bc8281 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/3d_rect_reorder.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/cell_alpha.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/cell_alpha.png new file mode 100644 index 000000000..e561c16b8 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/cell_alpha.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/coolwarm_cmap.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/coolwarm_cmap.png new file mode 100644 index 000000000..808043005 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/coolwarm_cmap.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/limits.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/limits.png new file mode 100644 index 000000000..4b6a64690 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/limits.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/listed_cmap.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/listed_cmap.png new file mode 100644 index 000000000..b2a9fee99 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/listed_cmap.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_centroids.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_centroids.png new file mode 100644 index 000000000..9c74c5403 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_centroids.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_elites.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_elites.png new file mode 100644 index 000000000..59c5b4b8e Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_elites.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_samples.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_samples.png new file mode 100644 index 000000000..9998fe7a6 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_samples.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/transparent.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/transparent.png new file mode 100644 index 000000000..933e82e2b Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/transparent.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/vmin_equals_vmax.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/vmin_equals_vmax.png new file mode 100644 index 000000000..edb6ab8d7 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/vmin_equals_vmax.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/voronoi_style.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/voronoi_style.png new file mode 100644 index 000000000..8086f695a Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/voronoi_style.png differ diff --git a/tests/visualize/cvt_archive_3d_plot_test.py b/tests/visualize/cvt_archive_3d_plot_test.py new file mode 100644 index 000000000..d91f8494d --- /dev/null +++ b/tests/visualize/cvt_archive_3d_plot_test.py @@ -0,0 +1,202 @@ +"""Tests for cvt_archive_3d_plot. + +See README.md for instructions on writing tests. +""" +import matplotlib.pyplot as plt +import numpy as np +import pytest +from matplotlib.testing.decorators import image_comparison + +from ribs.archives import CVTArchive +from ribs.visualize import cvt_archive_3d_plot + +from .conftest import add_uniform_sphere_3d + +# pylint: disable = redefined-outer-name + +# Tolerance for root mean square difference between the pixels of the images, +# where 255 is the max value. We have a pretty high tolerance for +# `cvt_archive_3d_plot` since 3D rendering tends to vary a bit. +CVT_IMAGE_TOLERANCE = 1.0 + +# +# Fixtures +# + + +@pytest.fixture(scope="module") +def cvt_archive_3d(): + """Deterministically-created CVTArchive.""" + ranges = np.array([(-1, 1), (-1, 1), (-1, 1)]) + archive = CVTArchive( + solution_dim=3, + cells=500, + ranges=ranges, + samples=10_000, + seed=42, + ) + add_uniform_sphere_3d(archive, *ranges) + return archive + + +@pytest.fixture(scope="module") +def cvt_archive_3d_rect(): + """Same as above, but the dimensions have different ranges.""" + ranges = [(-1, 1), (-2, 0), (1, 3)] + archive = CVTArchive( + solution_dim=3, + cells=500, + ranges=ranges, + samples=10_000, + seed=42, + ) + add_uniform_sphere_3d(archive, *ranges) + return archive + + +# +# Argument validation tests +# + + +def test_no_samples_error(): + # This archive has no samples since custom centroids were passed in. + archive = CVTArchive(solution_dim=2, + cells=2, + ranges=[(-1, 1), (-1, 1)], + custom_centroids=[[0, 0], [1, 1]]) + + # Thus, plotting samples on this archive should fail. + with pytest.raises(ValueError): + cvt_archive_3d_plot(archive, plot_samples=True) + + +# +# Tests on archive with (-1, 1) range. +# + + +@image_comparison(baseline_images=["3d"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_3d(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d) + + +@image_comparison(baseline_images=["3d"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_3d_custom_axis(cvt_archive_3d): + ax = plt.axes(projection="3d") + cvt_archive_3d_plot(cvt_archive_3d, ax=ax) + + +@image_comparison(baseline_images=["3d_rect"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_3d_rect(cvt_archive_3d_rect): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d_rect) + + +@image_comparison(baseline_images=["3d_rect_reorder"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_3d_rect_reorder(cvt_archive_3d_rect): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d_rect, measure_order=[1, 2, 0]) + + +@image_comparison(baseline_images=["limits"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_limits(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d, vmin=-1.0, vmax=-0.5) + + +@image_comparison(baseline_images=["listed_cmap"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_listed_cmap(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d, cmap=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + +@image_comparison(baseline_images=["coolwarm_cmap"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_coolwarm_cmap(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d, cmap="coolwarm") + + +@image_comparison(baseline_images=["vmin_equals_vmax"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_vmin_equals_vmax(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d, vmin=-0.95, vmax=-0.95) + + +@image_comparison(baseline_images=["plot_centroids"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_plot_centroids(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d, plot_centroids=True, cell_alpha=0.1) + + +@image_comparison(baseline_images=["plot_samples"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_plot_samples(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d, plot_samples=True, cell_alpha=0.1) + + +@image_comparison(baseline_images=["voronoi_style"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_voronoi_style(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d, lw=3.0, ec="grey") + + +@image_comparison(baseline_images=["cell_alpha"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_cell_alpha(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d, cell_alpha=0.1) + + +@image_comparison(baseline_images=["transparent"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_transparent(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d, cell_alpha=0.0) + + +@image_comparison(baseline_images=["plot_elites"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_plot_elites(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + cvt_archive_3d_plot(cvt_archive_3d, cell_alpha=0.0, plot_elites=True) diff --git a/tests/visualize/cvt_archive_heatmap_test.py b/tests/visualize/cvt_archive_heatmap_test.py index 798b3619b..a1ee505cd 100644 --- a/tests/visualize/cvt_archive_heatmap_test.py +++ b/tests/visualize/cvt_archive_heatmap_test.py @@ -16,8 +16,8 @@ # pylint: disable = redefined-outer-name # Tolerance for root mean square difference between the pixels of the images, -# where 255 is the max value. We only have tolerance for `cvt_archive_heatmap` -# since it is a bit more finicky than the other plots. +# where 255 is the max value. We have tolerance for `cvt_archive_heatmap` since +# it is a bit more finicky than the other plots. CVT_IMAGE_TOLERANCE = 0.1 # @@ -95,7 +95,7 @@ def test_no_samples_error(): # Thus, plotting samples on this archive should fail. with pytest.raises(ValueError): - cvt_archive_heatmap(archive, lw=3.0, ec="grey", plot_samples=True) + cvt_archive_heatmap(archive, plot_samples=True) #