diff --git a/HISTORY.md b/HISTORY.md index e93c7ddd7..09c06c537 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -10,6 +10,7 @@ - Add visualization of QDax repertoires (#353) - Improve cvt_archive_heatmap flexibility (#354) - Speed up 2D cvt_archive_heatmap by order of magnitude (#355) +- Clip Voronoi regions in cvt_archive_heatmap (#356) #### Documentation diff --git a/docs/conf.py b/docs/conf.py index 15de98353..fc47c1985 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -256,4 +256,5 @@ "scipy": ("https://docs.scipy.org/doc/scipy/", None), "sklearn": ("https://scikit-learn.org/stable/", None), "qdax": ("https://qdax.readthedocs.io/en/latest/", None), + "shapely": ("https://shapely.readthedocs.io/en/stable/", None), } diff --git a/pinned_reqs/extras_visualize.txt b/pinned_reqs/extras_visualize.txt index fa0a7c0af..e4e8448bc 100644 --- a/pinned_reqs/extras_visualize.txt +++ b/pinned_reqs/extras_visualize.txt @@ -12,3 +12,4 @@ scikit-learn==1.3.0 scipy==1.10.1 threadpoolctl==3.0.0 matplotlib==3.7.2 +shapely==2.0.1 diff --git a/ribs/visualize.py b/ribs/visualize.py index 809e7693f..9e23d3d07 100644 --- a/ribs/visualize.py +++ b/ribs/visualize.py @@ -23,6 +23,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import shapely from matplotlib.cm import ScalarMappable from scipy.spatial import Voronoi # pylint: disable=no-name-in-module @@ -270,6 +271,7 @@ def cvt_archive_heatmap(archive, vmax=None, cbar="auto", cbar_kwargs=None, + clip=False, plot_centroids=False, plot_samples=False, ms=1): @@ -343,6 +345,15 @@ def cvt_archive_heatmap(archive, the colorbar on the specified Axes. cbar_kwargs (dict): Additional kwargs to pass to :func:`~matplotlib.pyplot.colorbar`. + clip (bool, shapely.Polygon): Clip the heatmap cells to a given polygon. + By default, we draw the cells along the outer edges of the heatmap + as polygons that extend beyond the archive bounds, but these + polygons are hidden because we set the axis limits to be the archive + bounds. Passing `clip=True` will clip the heatmap such that these + "outer edge" polygons are within the archive bounds. An arbitrary + polygon can also be passed in to clip the heatmap to a custom shape. + See `#356 `_ for more + info. plot_centroids (bool): Whether to plot the cluster centroids. plot_samples (bool): Whether to plot the samples used when generating the clusters. @@ -376,6 +387,11 @@ def cvt_archive_heatmap(archive, upper_bounds = np.flip(upper_bounds) centroids = np.flip(centroids, axis=1) + # If clip is on, make it default to an archive bounding box. + if clip and not isinstance(clip, shapely.Polygon): + clip = shapely.box(lower_bounds[0], lower_bounds[1], upper_bounds[0], + upper_bounds[1]) + if plot_samples: samples = archive.samples if transpose_measures: @@ -443,18 +459,36 @@ def cvt_archive_heatmap(archive, if -1 in region or len(region) == 0: continue - if objective is None: - # Transparent white (RGBA format) -- this ensures that if a figure - # is saved with a transparent background, the empty cells will also - # be transparent. - facecolors.append(np.array([1.0, 1.0, 1.0, 0.0])) - facecolor_cmap_mask.append(False) + if clip: + # Clip the cell vertices to the polygon. Clipping may cause some + # cells to split into two or more polygons, especially if the clip + # polygon has holes. + polygon = shapely.Polygon(vor.vertices[region]) + intersection = polygon.intersection(clip) + if isinstance(intersection, shapely.MultiPolygon): + for polygon in intersection.geoms: + vertices.append(polygon.exterior.coords) + n_splits = len(intersection.geoms) + else: + # The intersection is a single Polygon. + vertices.append(intersection.exterior.coords) + n_splits = 1 else: - facecolors.append(np.empty(4)) - facecolor_cmap_mask.append(True) - facecolor_objs.append(objective) - - vertices.append(vor.vertices[region]) + vertices.append(vor.vertices[region]) + n_splits = 1 + + # Repeat values for each split. + for _ in range(n_splits): + if objective is None: + # Transparent white (RGBA format) -- this ensures that if a + # figure is saved with a transparent background, the empty cells + # will also be transparent. + facecolors.append(np.array([1.0, 1.0, 1.0, 0.0])) + facecolor_cmap_mask.append(False) + else: + facecolors.append(np.empty(4)) + facecolor_cmap_mask.append(True) + facecolor_objs.append(objective) # Compute facecolors from the cmap. We first normalize the objectives and # clip them to [0, 1]. diff --git a/setup.py b/setup.py index 9d8e1e86a..efa6f03e5 100644 --- a/setup.py +++ b/setup.py @@ -3,10 +3,10 @@ from setuptools import find_packages, setup -with open("README.md") as readme_file: +with open("README.md", encoding="utf-8") as readme_file: readme = readme_file.read() -with open("HISTORY.md") as history_file: +with open("HISTORY.md", encoding="utf-8") as history_file: history = history_file.read() # NOTE: Update pinned_reqs whenever install_requires or extras_require changes. @@ -24,12 +24,16 @@ ] extras_require = { - "visualize": ["matplotlib>=3.0.0",], + "visualize": [ + "matplotlib>=3.0.0", + "shapely>=2.0.0", + ], # All dependencies except for dev. Don't worry if there are duplicate # dependencies, since setuptools automatically handles duplicates. "all": [ ### visualize ### "matplotlib>=3.0.0", + "shapely>=2.0.0", ], "dev": [ "pip>=20.3", diff --git a/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_clip.png b/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_clip.png new file mode 100644 index 000000000..4db4550e3 Binary files /dev/null and b/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_clip.png differ diff --git a/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_clip_polygon.png b/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_clip_polygon.png new file mode 100644 index 000000000..9ea02525a Binary files /dev/null and b/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_clip_polygon.png differ diff --git a/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_clip_polygon_with_hole.png b/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_clip_polygon_with_hole.png new file mode 100644 index 000000000..5009fb475 Binary files /dev/null and b/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_clip_polygon_with_hole.png differ diff --git a/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_noclip.png b/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_noclip.png new file mode 100644 index 000000000..6bfd32d8d Binary files /dev/null and b/tests/visualize/baseline_images/visualize_test/cvt_archive_heatmap_noclip.png differ diff --git a/tests/visualize/visualize_test.py b/tests/visualize/visualize_test.py index 1c8bd0b3a..11d95a764 100644 --- a/tests/visualize/visualize_test.py +++ b/tests/visualize/visualize_test.py @@ -16,6 +16,7 @@ import matplotlib.pyplot as plt import numpy as np import pytest +import shapely from matplotlib.testing.decorators import image_comparison from ribs.archives import CVTArchive, GridArchive, SlidingBoundariesArchive @@ -673,6 +674,99 @@ def test_cvt_archive_heatmap_voronoi_style(cvt_archive): cvt_archive_heatmap(cvt_archive, lw=3.0, ec="grey") +# +# cvt_archive_heatmap clip tests +# + + +@image_comparison(baseline_images=["cvt_archive_heatmap_noclip"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_cvt_archive_heatmap_noclip(cvt_archive): + plt.figure(figsize=(8, 6)) + cvt_archive_heatmap(cvt_archive, clip=False) + plt.xlim(-1.5, 1.5) + plt.ylim(-1.5, 1.5) + + +@image_comparison(baseline_images=["cvt_archive_heatmap_clip"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_cvt_archive_heatmap_clip(cvt_archive): + plt.figure(figsize=(8, 6)) + cvt_archive_heatmap(cvt_archive, clip=True) + plt.xlim(-1.5, 1.5) + plt.ylim(-1.5, 1.5) + + +@image_comparison(baseline_images=["cvt_archive_heatmap_clip_polygon"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_cvt_archive_heatmap_clip_polygon(cvt_archive): + plt.figure(figsize=(8, 6)) + cvt_archive_heatmap( + cvt_archive, + clip=shapely.Polygon(shell=np.array([ + [-0.75, -0.375], + [-0.75, 0.375], + [-0.375, 0.75], + [0.375, 0.75], + [0.75, 0.375], + [0.75, -0.375], + [0.375, -0.75], + [-0.375, -0.75], + ]),), + ) + plt.xlim(-1.5, 1.5) + plt.ylim(-1.5, 1.5) + + +@image_comparison( + baseline_images=["cvt_archive_heatmap_clip_polygon_with_hole"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_cvt_archive_heatmap_clip_polygon_with_hole(cvt_archive): + """This test will force some cells to be split in two.""" + plt.figure(figsize=(8, 6)) + cvt_archive_heatmap( + cvt_archive, + clip=shapely.Polygon( + shell=np.array([ + [-0.75, -0.375], + [-0.75, 0.375], + [-0.375, 0.75], + [0.375, 0.75], + [0.75, 0.375], + [0.75, -0.375], + [0.375, -0.75], + [-0.375, -0.75], + ]), + holes=[ + # Two holes that split some cells into two parts, and some cells + # into three parts. + np.array([ + [-0.5, 0], + [-0.5, 0.05], + [0.5, 0.05], + [0.5, 0], + ]), + np.array([ + [-0.5, 0.125], + [-0.5, 0.175], + [0.5, 0.175], + [0.5, 0.125], + ]), + ], + ), + ) + plt.xlim(-1.5, 1.5) + plt.ylim(-1.5, 1.5) + + # # Parallel coordinate plot test #