From 7048c364f9898d3e5d0987ae4c10b6faa316d576 Mon Sep 17 00:00:00 2001 From: Bryon Tjanaka <38124174+btjanaka@users.noreply.github.com> Date: Wed, 6 Sep 2023 21:46:33 -0700 Subject: [PATCH] Speed up 2D cvt_archive_heatmap by order of magnitude (#355) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Currently, cvt_archive_heatmap plots individual polygons via `ax.fill`. We can speed this up by instead using a [`PolyCollection`](https://matplotlib.org/stable/api/collections_api.html#matplotlib.collections.PolyCollection) to add all the polygons at once. This is similar to using a `PatchCollection` as shown here: https://matplotlib.org/stable/gallery/shapes_and_collections/patch_collection.html. Benchmark for plotting a `CVTArchive` with 10,000 cells: - Before: 14.9 sec - After: 0.6 sec I used the following code to benchmark the implementation: ```python """Driver for cvt heatmap experiments.""" import time import fire import matplotlib.pyplot as plt import numpy as np from ribs.archives import CVTArchive from ribs.visualize import cvt_archive_heatmap def main(n_cells=10000): """Creates the archive and plots it.""" np.random.seed(42) archive = CVTArchive( solution_dim=3, cells=n_cells, ranges=[(-1, 1), (-1, 1)], custom_centroids=np.random.uniform(-1, 1, (n_cells, 2)), ) archive.add( np.random.uniform(-1, 1, (20000, 3)), np.random.standard_normal(20000), np.random.uniform(-1, 1, (20000, 2)), ) plt.figure(figsize=(8, 6)) start_time = time.time() cvt_archive_heatmap(archive) print("Plot time", time.time() - start_time) plt.savefig("cvt.png") if __name__ == "__main__": fire.Fire(main) ``` ## TODO - [x] Speed up 2D polygon plotting by using matplotlib PolyCollection — note that I initially used a PatchCollection with individual Polygon patches, but PolyCollection is much faster because we do not have to construct the individual `Polygon` patches in Python. 
- [x] Compute facecolors in a batch instead of individually - [x] Fix test errors — it seems the images changed slightly due to the new implementation, so we now allow a slight tolerance for cvt heatmap images ## Questions ## Status - [x] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md) - [x] I have formatted my code using `yapf` - [x] I have tested my code by running `pytest` - [x] I have linted my code with `pylint` - [x] I have added a one-line description of my change to the changelog in `HISTORY.md` - [x] This PR is ready to go --- HISTORY.md | 1 + ribs/visualize.py | 69 +++++++++++++++------ tests/visualize/visualize_test.py | 41 ++++++++---- tests/visualize_qdax/visualize_qdax_test.py | 3 +- 4 files changed, 81 insertions(+), 33 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index b8b1bfed8..e93c7ddd7 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -9,6 +9,7 @@ - Drop Python 3.7 support and upgrade dependencies (#350) - Add visualization of QDax repertoires (#353) - Improve cvt_archive_heatmap flexibility (#354) +- Speed up 2D cvt_archive_heatmap by order of magnitude (#355) #### Documentation diff --git a/ribs/visualize.py b/ribs/visualize.py index 41960c5c9..809e7693f 100644 --- a/ribs/visualize.py +++ b/ribs/visualize.py @@ -422,27 +422,56 @@ def cvt_archive_heatmap(archive, if min_obj == max_obj: min_obj, max_obj = min_obj - 0.01, max_obj + 0.01 - # Shade the regions. - # - # Note: by default, the first region will be an empty list -- see: - # https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.Voronoi.html - # However, this empty region is ignored by ax.fill since `polygon` is also - # an empty list in this case. + # Vertices of all cells. + vertices = [] + # The facecolor of each cell. Shape (n_regions, 4) for RGBA format, but we + # do not know n_regions in advance. 
+ facecolors = [] + # Boolean array indicating which of the facecolors needs to be computed with + # the cmap. The other colors correspond to empty cells. Shape (n_regions,) + facecolor_cmap_mask = [] + # The objective corresponding to the regions which must be passed through + # the cmap. Shape (sum(facecolor_cmap_mask),) + facecolor_objs = [] + + # Cycle through the regions to set up polygon vertices and facecolors. for region, objective in zip(vor.regions, region_obj): - # This check is O(n), but n is typically small, and creating - # `polygon` is also O(n) anyway. - if -1 not in region: - if objective is None: - # Transparent white (RGBA format) -- this ensures that if a - # figure is saved with a transparent background, the empty cells - # will also be transparent. - color = (1.0, 1.0, 1.0, 0.0) - else: - normalized_obj = np.clip( - (objective - min_obj) / (max_obj - min_obj), 0.0, 1.0) - color = cmap(normalized_obj) - polygon = vor.vertices[region] - ax.fill(*zip(*polygon), color=color, ec=ec, lw=lw) + # Checking for -1 is O(n), but n is typically small. + # + # We check length since the first region is an empty list by default: + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.Voronoi.html + if -1 in region or len(region) == 0: + continue + + if objective is None: + # Transparent white (RGBA format) -- this ensures that if a figure + # is saved with a transparent background, the empty cells will also + # be transparent. + facecolors.append(np.array([1.0, 1.0, 1.0, 0.0])) + facecolor_cmap_mask.append(False) + else: + facecolors.append(np.empty(4)) + facecolor_cmap_mask.append(True) + facecolor_objs.append(objective) + + vertices.append(vor.vertices[region]) + + # Compute facecolors from the cmap. We first normalize the objectives and + # clip them to [0, 1]. 
+ normalized_objs = np.clip( + (np.asarray(facecolor_objs) - min_obj) / (max_obj - min_obj), 0.0, 1.0) + facecolors = np.asarray(facecolors) + facecolors[facecolor_cmap_mask] = cmap(normalized_objs) + + # Plot the collection on the axes. Note that this is faster than plotting + # each polygon individually with ax.fill(). + ax.add_collection( + matplotlib.collections.PolyCollection( + vertices, + edgecolors=ec, + facecolors=facecolors, + linewidths=lw, + )) # Create a colorbar. mappable = ScalarMappable(cmap=cmap) diff --git a/tests/visualize/visualize_test.py b/tests/visualize/visualize_test.py index b4fc509b2..1c8bd0b3a 100644 --- a/tests/visualize/visualize_test.py +++ b/tests/visualize/visualize_test.py @@ -25,6 +25,11 @@ # pylint: disable = redefined-outer-name +# Tolerance for root mean square difference between the pixels of the images, +# where 255 is the max value. We only have tolerance for `cvt_archive_heatmap` +# since it is a bit more finicky than the other plots. +CVT_IMAGE_TOLERANCE = 0.1 + @pytest.fixture(autouse=True) def clean_matplotlib(): @@ -379,7 +384,8 @@ def test_heatmap_archive__grid_custom_cbar_axis(grid_archive): @image_comparison(baseline_images=["cvt_archive_heatmap"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_heatmap_archive__cvt(cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(cvt_archive) @@ -403,7 +409,8 @@ def test_heatmap_with_custom_axis__grid(grid_archive): @image_comparison(baseline_images=["cvt_archive_heatmap"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_heatmap_with_custom_axis__cvt(cvt_archive): _, ax = plt.subplots(figsize=(8, 6)) cvt_archive_heatmap(cvt_archive, ax=ax) @@ -427,7 +434,8 @@ def test_heatmap_long__grid(long_grid_archive): @image_comparison(baseline_images=["cvt_archive_heatmap_long"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def 
test_heatmap_long__cvt(long_cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(long_cvt_archive) @@ -451,7 +459,8 @@ def test_heatmap_long_square__grid(long_grid_archive): @image_comparison(baseline_images=["cvt_archive_heatmap_long_square"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_heatmap_long_square__cvt(long_cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(long_cvt_archive, aspect="equal") @@ -475,7 +484,8 @@ def test_heatmap_long_transpose__grid(long_grid_archive): @image_comparison(baseline_images=["cvt_archive_heatmap_long_transpose"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_heatmap_long_transpose__cvt(long_cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(long_cvt_archive, transpose_measures=True) @@ -502,7 +512,8 @@ def test_heatmap_with_limits__grid(grid_archive): @image_comparison(baseline_images=["cvt_archive_heatmap_with_limits"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_heatmap_with_limits__cvt(cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(cvt_archive, vmin=-1.0, vmax=-0.5) @@ -527,7 +538,8 @@ def test_heatmap_listed_cmap__grid(grid_archive): @image_comparison(baseline_images=["cvt_archive_heatmap_with_listed_cmap"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_heatmap_listed_cmap__cvt(cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(cvt_archive, cmap=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]) @@ -553,7 +565,8 @@ def test_heatmap_coolwarm_cmap__grid(grid_archive): @image_comparison(baseline_images=["cvt_archive_heatmap_with_coolwarm_cmap"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_heatmap_coolwarm_cmap__cvt(cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(cvt_archive, cmap="coolwarm") @@ -614,7 +627,8 @@ 
def test_sliding_archive_mismatch_xy_with_boundaries(): @image_comparison(baseline_images=["cvt_archive_heatmap_vmin_equals_vmax"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_cvt_archive_heatmap_vmin_equals_vmax(cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(cvt_archive, vmin=-0.5, vmax=-0.5) @@ -622,7 +636,8 @@ def test_cvt_archive_heatmap_vmin_equals_vmax(cvt_archive): @image_comparison(baseline_images=["cvt_archive_heatmap_with_centroids"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_cvt_archive_heatmap_with_centroids(cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(cvt_archive, plot_centroids=True) @@ -630,7 +645,8 @@ def test_cvt_archive_heatmap_with_centroids(cvt_archive): @image_comparison(baseline_images=["cvt_archive_heatmap_with_samples"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_cvt_archive_heatmap_with_samples(cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(cvt_archive, plot_samples=True) @@ -650,7 +666,8 @@ def test_cvt_archive_heatmap_no_samples_error(): @image_comparison(baseline_images=["cvt_archive_heatmap_voronoi_style"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) def test_cvt_archive_heatmap_voronoi_style(cvt_archive): plt.figure(figsize=(8, 6)) cvt_archive_heatmap(cvt_archive, lw=3.0, ec="grey") diff --git a/tests/visualize_qdax/visualize_qdax_test.py b/tests/visualize_qdax/visualize_qdax_test.py index ac13c7c3a..e1a9fe4a3 100644 --- a/tests/visualize_qdax/visualize_qdax_test.py +++ b/tests/visualize_qdax/visualize_qdax_test.py @@ -28,7 +28,8 @@ def clean_matplotlib(): @image_comparison(baseline_images=["qdax_repertoire_heatmap"], remove_text=False, - extensions=["png"]) + extensions=["png"], + tol=0.1) # See CVT_IMAGE_TOLERANCE in visualize_test.py def test_qdax_repertoire_heatmap(): 
plt.figure(figsize=(8, 6))