Skip to content

Commit

Permalink
More updates
Browse files Browse the repository at this point in the history
  • Loading branch information
btjanaka committed Sep 13, 2023
1 parent b2fed61 commit 417b6c1
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 32 deletions.
43 changes: 11 additions & 32 deletions ribs/visualize/_cvt_archive_3d_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,16 @@ def cvt_archive_3d_plot(
*,
measure_order=None,
cmap="magma",
aspect=None,
lw=0.5,
ec=(0.0, 0.0, 0.0, 0.1),
cell_alpha=1.0,
vmin=None,
vmax=None,
cbar="auto",
cbar_kwargs=None,
# TODO
rasterized=False,
plot_centroids=False,
plot_samples=False,
ms=1,
pcm_kwargs=None,
):
"""Plots a :class:`~ribs.archives.CVTArchive` with 3D measure space.
Expand Down Expand Up @@ -88,15 +85,13 @@ def cvt_archive_3d_plot(
:class:`~matplotlib.colors.Colormap`, a list of RGB or RGBA colors
(i.e. an :math:`N \\times 3` or :math:`N \\times 4` array), or a
:class:`~matplotlib.colors.Colormap` object.
aspect ('auto', 'equal', float): The aspect ratio of the heatmap (i.e.
height/width). Defaults to ``'auto'`` for 2D and ``0.5`` for 1D.
``'equal'`` is the same as ``aspect=1``. See
:meth:`matplotlib.axes.Axes.set_aspect` for more info.
lw (float): Line width when plotting the Voronoi diagram.
ec (matplotlib color): Edge color of the cells in the Voronoi diagram.
See `here
<https://matplotlib.org/stable/tutorials/colors/colors.html>`_ for
more info on specifying colors in Matplotlib.
cell_alpha: Alpha value for the cell colors. Set to 1.0 for opaque
cells; set to 0.0 for fully transparent cells.
vmin (float): Minimum objective value to use in the plot. If ``None``,
the minimum objective value in the archive is used.
vmax (float): Maximum objective value to use in the plot. If ``None``,
Expand All @@ -108,41 +103,26 @@ def cvt_archive_3d_plot(
the colorbar on the specified Axes.
cbar_kwargs (dict): Additional kwargs to pass to
:func:`~matplotlib.pyplot.colorbar`.
# TODO
rasterized (bool): Whether to rasterize the heatmap. This can be useful
for saving to a vector format like PDF. Essentially, only the
heatmap will be converted to a raster graphic so that the archive
cells will not have to be individually rendered. Meanwhile, the
surrounding axes, particularly text labels, will remain in vector
format.
plot_centroids (bool): Whether to plot the cluster centroids.
plot_samples (bool): Whether to plot the samples used when generating
the clusters.
ms (float): Marker size for both centroids and samples.
pcm_kwargs (dict): Additional kwargs to pass to
:func:`~matplotlib.pyplot.pcolormesh`. Only applicable to 1D
heatmaps. linewidth and edgecolor are set with the ``lw`` and
``ec`` args.
Raises:
ValueError: The archive's measure dimension must be 1D or 2D.
ValueError: ``measure_order`` is not a permutation of ``[0, 1, 2]``.
ValueError: ``plot_samples`` is passed in but the archive does not have
samples (e.g., due to using custom centroids during construction).
"""
# We don't have an aspect arg here so we just pass None.
validate_heatmap_visual_args(
aspect, cbar, archive.measure_dim, [3],
None, cbar, archive.measure_dim, [3],
"This plot can only be made for a 3D CVTArchive")

if plot_samples and archive.samples is None:
raise ValueError("Samples are not available for this archive, but "
"`plot_samples` was passed in.")

if aspect is None:
aspect = "auto"

# Try getting the colormap early in case it fails.
cmap = retrieve_cmap(cmap)

Expand Down Expand Up @@ -182,9 +162,6 @@ def cvt_archive_3d_plot(
ax.set_ylim(lower_bounds[1], upper_bounds[1])
ax.set_zlim(lower_bounds[2], upper_bounds[2])

# TODO: aspect?
ax.set_aspect(aspect)

# Create reflections of the points so that we can easily find the polygons
# at the edge of the Voronoi diagram. See here for the basic idea in 2D:
# https://stackoverflow.com/questions/28665491/getting-a-bounded-polygon-coordinates-from-voronoi-cells
Expand Down Expand Up @@ -261,19 +238,21 @@ def cvt_archive_3d_plot(
normalized_objs = np.clip(
(np.asarray(cmap_objs) - min_obj) / (max_obj - min_obj), 0.0, 1.0)

# TODO: facecolor alpha

# Create an array of facecolors that defaults to transparent white.
# Create an array of facecolors in RGBA format that defaults to transparent
# white.
facecolors = np.full((len(objs), 4), [1.0, 1.0, 1.0, 0.0])

# Set colors based on objectives. Also set alpha, which is located in index
# 3 since this is RGBA format.
facecolors[cmap_idx] = cmap(normalized_objs)
facecolors[cmap_idx, 3] = cell_alpha

ax.add_collection(
Poly3DCollection(
vertices,
edgecolor=[ec for _ in vertices],
facecolor=facecolors,
lw=lw,
rasterized=rasterized,
))

# Plot the sample points and centroids.
Expand Down
174 changes: 174 additions & 0 deletions tests/visualize/cvt_archive_3d_plot_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Tests for cvt_archive_3d_plot.
See README.md for instructions on writing tests.
"""
import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.testing.decorators import image_comparison

from ribs.archives import CVTArchive
from ribs.visualize import cvt_archive_3d_plot

from .conftest import add_uniform_sphere_3d

# pylint: disable = redefined-outer-name

# Tolerance for root mean square difference between the pixels of the images,
# where 255 is the max value. We have a pretty high tolerance for
# `cvt_archive_3d_plot` since 3D rendering tends to vary a bit.
CVT_IMAGE_TOLERANCE = 1.0

#
# Fixtures
#


@pytest.fixture(scope="module")
def cvt_archive_3d():
"""Deterministically-created CVTArchive."""
archive = CVTArchive(
solution_dim=3,
cells=500,
ranges=np.array([(-1, 1), (-1, 1), (-1, 1)]),
samples=10_000,
seed=42,
)
add_uniform_sphere_3d(archive, (-1, 1), (-1, 1), (-1, 1))
return archive


#
# Argument validation tests
#


def test_no_samples_error():
# This archive has no samples since custom centroids were passed in.
archive = CVTArchive(solution_dim=2,
cells=2,
ranges=[(-1, 1), (-1, 1)],
custom_centroids=[[0, 0], [1, 1]])

# Thus, plotting samples on this archive should fail.
with pytest.raises(ValueError):
cvt_archive_3d_plot(archive, plot_samples=True)


#
# Tests on archive with (-1, 1) range.
#


@image_comparison(baseline_images=["3d"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_3d(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d)


@image_comparison(baseline_images=["3d"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_3d_custom_axis(cvt_archive_3d):
ax = plt.axes(projection="3d")
cvt_archive_3d_plot(cvt_archive_3d, ax=ax)


# @image_comparison(baseline_images=["2d_long"],
# remove_text=False,
# extensions=["png"],
# tol=CVT_IMAGE_TOLERANCE)
# def test_2d_long(cvt_archive_2d_long):
# plt.figure(figsize=(8, 6))
# cvt_archive_heatmap(cvt_archive_2d_long)

# @image_comparison(baseline_images=["2d_long_square"],
# remove_text=False,
# extensions=["png"],
# tol=CVT_IMAGE_TOLERANCE)
# def test_2d_long_square(cvt_archive_2d_long):
# plt.figure(figsize=(8, 6))
# cvt_archive_heatmap(cvt_archive_2d_long, aspect="equal")

# @image_comparison(baseline_images=["2d_long_transpose"],
# remove_text=False,
# extensions=["png"],
# tol=CVT_IMAGE_TOLERANCE)
# def test_2d_long_transpose(cvt_archive_2d_long):
# plt.figure(figsize=(8, 6))
# cvt_archive_heatmap(cvt_archive_2d_long, transpose_measures=True)

# @image_comparison(baseline_images=["limits"],
# remove_text=False,
# extensions=["png"],
# tol=CVT_IMAGE_TOLERANCE)
# def test_limits(cvt_archive_2d):
# plt.figure(figsize=(8, 6))
# cvt_archive_heatmap(cvt_archive_2d, vmin=-1.0, vmax=-0.5)


@image_comparison(baseline_images=["listed_cmap"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_listed_cmap(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, cmap=[[1, 0, 0], [0, 1, 0], [0, 0, 1]])


@image_comparison(baseline_images=["coolwarm_cmap"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_coolwarm_cmap(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, cmap="coolwarm")


@image_comparison(baseline_images=["vmin_equals_vmax"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_vmin_equals_vmax(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, vmin=-0.95, vmax=-0.95)


@image_comparison(baseline_images=["plot_centroids"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_plot_centroids(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, plot_centroids=True, cell_alpha=0.1)


@image_comparison(baseline_images=["plot_samples"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_plot_samples(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, plot_samples=True, cell_alpha=0.1)


@image_comparison(baseline_images=["voronoi_style"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_voronoi_style(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, lw=3.0, ec="grey")


@image_comparison(baseline_images=["cell_alpha"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cell_alpha(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, cell_alpha=0.1)

0 comments on commit 417b6c1

Please sign in to comment.