diff --git a/HISTORY.md b/HISTORY.md index 4af2ad497..fe2469b70 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -17,7 +17,8 @@ - Add `rasterized` arg for heatmaps (#359) - Support 1D cvt_archive_heatmap ({pr}`362`) - Add 3D plots for CVTArchive ({pr}`371`) -- Add visualization of 3D QDax repertoires ({pr}`372`) +- Add visualization of 3D QDax repertoires ({pr}`373`) +- Enable plotting custom data in visualizations ({pr}`374`) #### Documentation diff --git a/ribs/visualize/_cvt_archive_3d_plot.py b/ribs/visualize/_cvt_archive_3d_plot.py index 2b96509f0..31e2dddfa 100644 --- a/ribs/visualize/_cvt_archive_3d_plot.py +++ b/ribs/visualize/_cvt_archive_3d_plot.py @@ -5,7 +5,7 @@ from mpl_toolkits.mplot3d.art3d import Poly3DCollection from scipy.spatial import Voronoi # pylint: disable=no-name-in-module -from ribs.visualize._utils import (retrieve_cmap, set_cbar, +from ribs.visualize._utils import (retrieve_cmap, set_cbar, validate_df, validate_heatmap_visual_args) @@ -13,6 +13,7 @@ def cvt_archive_3d_plot( archive, ax=None, *, + df=None, measure_order=None, cmap="magma", lw=0.5, @@ -150,6 +151,13 @@ def cvt_archive_3d_plot( archive (CVTArchive): A 3D :class:`~ribs.archives.CVTArchive`. ax (matplotlib.axes.Axes): Axes on which to plot the heatmap. If ``None``, we will create a new 3D axis. + df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from + this argument instead of the data currently in the archive. This + data can be obtained by, for instance, calling + :meth:`ribs.archives.ArchiveBase.as_pandas()` and modifying the + resulting :class:`ArchiveDataFrame`. Note that, at a minimum, the + data must contain columns for index, objective, and measures. To + display a custom metric, replace the "objective" column. measure_order (array-like of int): Specifies the axes order for plotting the measures. By default, the first measure (measure 0) in the archive appears on the x-axis, the second (measure 1) on y-axis, and @@ -208,8 +216,8 @@ def cvt_archive_3d_plot( # Try getting the colormap early in case it fails. cmap = retrieve_cmap(cmap) - # Retrieve data from archive. - df = archive.as_pandas() + # Retrieve archive data. + df = archive.as_pandas() if df is None else validate_df(df) objective_batch = df.objective_batch() measures_batch = df.measures_batch() lower_bounds = archive.lower_bounds @@ -289,7 +297,7 @@ def cvt_archive_3d_plot( objs = [] # Also record objective for each ridge so we can color it. # Map from centroid index to objective. - pt_to_obj = {elite.index: elite.objective for elite in archive} + pt_to_obj = dict(zip(df.index_batch(), objective_batch)) # The points in the Voronoi diagram are indexed by their placement in the # input list. Above, when we called Voronoi, `centroids` were placed first, diff --git a/ribs/visualize/_cvt_archive_heatmap.py b/ribs/visualize/_cvt_archive_heatmap.py index 799ed9534..0648dbb90 100644 --- a/ribs/visualize/_cvt_archive_heatmap.py +++ b/ribs/visualize/_cvt_archive_heatmap.py @@ -7,7 +7,7 @@ from scipy.spatial import Voronoi # pylint: disable=no-name-in-module from ribs.visualize._utils import (archive_heatmap_1d, retrieve_cmap, set_cbar, - validate_heatmap_visual_args) + validate_df, validate_heatmap_visual_args) # Matplotlib functions tend to have a ton of args. # pylint: disable = too-many-arguments @@ -16,6 +16,7 @@ def cvt_archive_heatmap(archive, ax=None, *, + df=None, transpose_measures=False, cmap="magma", aspect=None, @@ -98,6 +99,13 @@ def cvt_archive_heatmap(archive, archive (CVTArchive): A 1D or 2D :class:`~ribs.archives.CVTArchive`. ax (matplotlib.axes.Axes): Axes on which to plot the heatmap. If ``None``, the current axis will be used. + df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from + this argument instead of the data currently in the archive. This + data can be obtained by, for instance, calling + :meth:`ribs.archives.ArchiveBase.as_pandas()` and modifying the + resulting :class:`ArchiveDataFrame`. Note that, at a minimum, the + data must contain columns for index, objective, and measures. To + display a custom metric, replace the "objective" column. transpose_measures (bool): By default, the first measure in the archive will appear along the x-axis, and the second will be along the y-axis. To switch this behavior (i.e. to transpose the axes), set @@ -173,6 +181,9 @@ def cvt_archive_heatmap(archive, # Try getting the colormap early in case it fails. cmap = retrieve_cmap(cmap) + # Retrieve archive data. + df = archive.as_pandas() if df is None else validate_df(df) + if archive.measure_dim == 1: # Read in pcm kwargs -- the linewidth and edgecolor are overwritten by # our arguments. @@ -195,8 +206,6 @@ def cvt_archive_heatmap(archive, [archive.upper_bounds[0]], )) - df = archive.as_pandas() - # centroid_sort_idx tells us which index to place the centroid at such # that it is sorted, i.e., it maps from the indices in the centroid # array to the cell indices. This means that if you index with it, e.g., @@ -279,7 +288,7 @@ def cvt_archive_heatmap(archive, # the region index of each point. region_obj = [None] * len(vor.regions) min_obj, max_obj = np.inf, -np.inf - pt_to_obj = {elite.index: elite.objective for elite in archive} + pt_to_obj = dict(zip(df.index_batch(), df.objective_batch())) for pt_idx, region_idx in enumerate( vor.point_region[:-4]): # Exclude faraway_pts. if region_idx != -1 and pt_idx in pt_to_obj: diff --git a/ribs/visualize/_grid_archive_heatmap.py b/ribs/visualize/_grid_archive_heatmap.py index 918951e64..703f713bd 100644 --- a/ribs/visualize/_grid_archive_heatmap.py +++ b/ribs/visualize/_grid_archive_heatmap.py @@ -3,7 +3,7 @@ import numpy as np from ribs.visualize._utils import (archive_heatmap_1d, retrieve_cmap, set_cbar, - validate_heatmap_visual_args) + validate_df, validate_heatmap_visual_args) # Matplotlib functions tend to have a ton of args. # pylint: disable = too-many-arguments @@ -12,6 +12,7 @@ def grid_archive_heatmap(archive, ax=None, *, + df=None, transpose_measures=False, cmap="magma", aspect=None, @@ -86,6 +87,13 @@ def grid_archive_heatmap(archive, archive (GridArchive): A 1D or 2D :class:`~ribs.archives.GridArchive`. ax (matplotlib.axes.Axes): Axes on which to plot the heatmap. If ``None``, the current axis will be used. + df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from + this argument instead of the data currently in the archive. This + data can be obtained by, for instance, calling + :meth:`ribs.archives.ArchiveBase.as_pandas()` and modifying the + resulting :class:`ArchiveDataFrame`. Note that, at a minimum, the + data must contain columns for index, objective, and measures. To + display a custom metric, replace the "objective" column. transpose_measures (bool): By default, the first measure in the archive will appear along the x-axis, and the second will be along the y-axis. To switch this behavior (i.e. to transpose the axes), set @@ -138,8 +146,10 @@ def grid_archive_heatmap(archive, # Try getting the colormap early in case it fails. cmap = retrieve_cmap(cmap) + # Retrieve archive data. + df = archive.as_pandas() if df is None else validate_df(df) + if archive.measure_dim == 1: - df = archive.as_pandas() cell_objectives = np.full(archive.cells, np.nan) cell_idx = archive.int_to_grid_index(df.index_batch()).squeeze() cell_objectives[cell_idx] = df.objective_batch() @@ -161,7 +171,6 @@ def grid_archive_heatmap(archive, elif archive.measure_dim == 2: # Retrieve data from archive. - df = archive.as_pandas() objective_batch = df.objective_batch() lower_bounds = archive.lower_bounds upper_bounds = archive.upper_bounds diff --git a/ribs/visualize/_parallel_axes_plot.py b/ribs/visualize/_parallel_axes_plot.py index 1afb7b687..28b57ca45 100644 --- a/ribs/visualize/_parallel_axes_plot.py +++ b/ribs/visualize/_parallel_axes_plot.py @@ -4,7 +4,7 @@ import numpy as np from matplotlib.cm import ScalarMappable -from ribs.visualize._utils import retrieve_cmap, set_cbar +from ribs.visualize._utils import retrieve_cmap, set_cbar, validate_df # Matplotlib functions tend to have a ton of args. # pylint: disable = too-many-arguments @@ -13,6 +13,7 @@ def parallel_axes_plot(archive, ax=None, *, + df=None, measure_order=None, cmap="magma", linewidth=1.5, @@ -78,6 +79,13 @@ def parallel_axes_plot(archive, archive (ArchiveBase): Any ribs archive. ax (matplotlib.axes.Axes): Axes on which to create the plot. If ``None``, the current axis will be used. + df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from + this argument instead of the data currently in the archive. This + data can be obtained by, for instance, calling + :meth:`ribs.archives.ArchiveBase.as_pandas()` and modifying the + resulting :class:`ArchiveDataFrame`. Note that, at a minimum, the + data must contain columns for index, objective, and measures. To + display a custom metric, replace the "objective" column. measure_order (list of int or list of (int, str)): If this is a list of ints, it specifies the axes order for measures (e.g. ``[2, 0, 1]``). If this is a list of tuples, each tuple takes the form @@ -155,7 +163,7 @@ def parallel_axes_plot(archive, upper_bounds = archive.upper_bounds[cols] host_ax = plt.gca() if ax is None else ax # Try to get current axis. - df = archive.as_pandas(include_solutions=False) + df = archive.as_pandas() if df is None else validate_df(df) vmin = df["objective"].min() if vmin is None else vmin vmax = df["objective"].max() if vmax is None else vmax norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=True) diff --git a/ribs/visualize/_sliding_boundaries_archive_heatmap.py b/ribs/visualize/_sliding_boundaries_archive_heatmap.py index aa81d7946..207150cc1 100644 --- a/ribs/visualize/_sliding_boundaries_archive_heatmap.py +++ b/ribs/visualize/_sliding_boundaries_archive_heatmap.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt import numpy as np -from ribs.visualize._utils import (retrieve_cmap, set_cbar, +from ribs.visualize._utils import (retrieve_cmap, set_cbar, validate_df, validate_heatmap_visual_args) # Matplotlib functions tend to have a ton of args. @@ -12,6 +12,7 @@ def sliding_boundaries_archive_heatmap(archive, ax=None, *, + df=None, transpose_measures=False, cmap="magma", aspect="auto", @@ -64,6 +65,13 @@ def sliding_boundaries_archive_heatmap(archive, :class:`~ribs.archives.SlidingBoundariesArchive`. ax (matplotlib.axes.Axes): Axes on which to plot the heatmap. If ``None``, the current axis will be used. + df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from + this argument instead of the data currently in the archive. This + data can be obtained by, for instance, calling + :meth:`ribs.archives.ArchiveBase.as_pandas()` and modifying the + resulting :class:`ArchiveDataFrame`. Note that, at a minimum, the + data must contain columns for index, objective, and measures. To + display a custom metric, replace the "objective" column. transpose_measures (bool): By default, the first measure in the archive will appear along the x-axis, and the second will be along the y-axis. To switch this behavior (i.e. to transpose the axes), set @@ -110,8 +118,8 @@ def sliding_boundaries_archive_heatmap(archive, # Try getting the colormap early in case it fails. cmap = retrieve_cmap(cmap) - # Retrieve data from archive. - df = archive.as_pandas() + # Retrieve archive data. + df = archive.as_pandas() if df is None else validate_df(df) measures_batch = df.measures_batch() x = measures_batch[:, 0] y = measures_batch[:, 1] diff --git a/ribs/visualize/_utils.py b/ribs/visualize/_utils.py index dfcbbf7aa..6b393aed8 100644 --- a/ribs/visualize/_utils.py +++ b/ribs/visualize/_utils.py @@ -3,6 +3,8 @@ import matplotlib.pyplot as plt import numpy as np +from ribs.archives import ArchiveDataFrame + def retrieve_cmap(cmap): """Retrieves colormap from Matplotlib.""" @@ -39,6 +41,17 @@ def validate_heatmap_visual_args(aspect, cbar, measure_dim, valid_dims, "or matplotlib.axes.Axes") +def validate_df(df): + """Helper to validate the df passed into visualization functions.""" + + # Cast to an ArchiveDataFrame in case someone passed in a regular DataFrame + # or other object. + if not isinstance(df, ArchiveDataFrame): + df = ArchiveDataFrame(df) + + return df + + def set_cbar(t, ax, cbar, cbar_kwargs): """Sets cbar on the Axes given cbar arg.""" cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs @@ -119,5 +132,4 @@ def archive_heatmap_1d( # Create color bar. set_cbar(t, ax, cbar, cbar_kwargs) - return ax diff --git a/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_metadata_with_df.png b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_metadata_with_df.png new file mode 100644 index 000000000..6fe001b00 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_3d_plot_test/plot_metadata_with_df.png differ diff --git a/tests/visualize/baseline_images/cvt_archive_heatmap_test/plot_metadata_with_df.png b/tests/visualize/baseline_images/cvt_archive_heatmap_test/plot_metadata_with_df.png new file mode 100644 index 000000000..1c8946919 Binary files /dev/null and b/tests/visualize/baseline_images/cvt_archive_heatmap_test/plot_metadata_with_df.png differ diff --git a/tests/visualize/baseline_images/grid_archive_heatmap_test/plot_metadata_with_df.png b/tests/visualize/baseline_images/grid_archive_heatmap_test/plot_metadata_with_df.png new file mode 100644 index 000000000..e353c8471 Binary files /dev/null and b/tests/visualize/baseline_images/grid_archive_heatmap_test/plot_metadata_with_df.png differ diff --git a/tests/visualize/baseline_images/parallel_axes_plot_test/plot_metadata_with_df.png b/tests/visualize/baseline_images/parallel_axes_plot_test/plot_metadata_with_df.png new file mode 100644 index 000000000..aeabf380f Binary files /dev/null and b/tests/visualize/baseline_images/parallel_axes_plot_test/plot_metadata_with_df.png differ diff --git a/tests/visualize/baseline_images/sliding_boundaries_archive_heatmap_test/plot_metadata_with_df.png b/tests/visualize/baseline_images/sliding_boundaries_archive_heatmap_test/plot_metadata_with_df.png new file mode 100644 index 000000000..6366796a9 Binary files /dev/null and b/tests/visualize/baseline_images/sliding_boundaries_archive_heatmap_test/plot_metadata_with_df.png differ diff --git a/tests/visualize/conftest.py b/tests/visualize/conftest.py index 91a7230f7..67910128e 100644 --- a/tests/visualize/conftest.py +++ b/tests/visualize/conftest.py @@ -40,6 +40,8 @@ def add_uniform_sphere_2d(archive, x_range, y_range): The solutions are the same as the measures (the (x,y) coordinates). x_range and y_range are tuples of (lower_bound, upper_bound). + + The metadata contains the positive sphere function for the points. """ xxs, yys = np.meshgrid( np.linspace(x_range[0], x_range[1], 100), @@ -47,10 +49,12 @@ def add_uniform_sphere_2d(archive, x_range, y_range): ) xxs, yys = xxs.ravel(), yys.ravel() coords = np.stack((xxs, yys), axis=1) + sphere = xxs**2 + yys**2 archive.add( solution_batch=coords, - objective_batch=-(xxs**2 + yys**2), # Negative sphere. + objective_batch=-sphere, # Negative sphere. measures_batch=coords, + metadata_batch=sphere, # Positive sphere. ) @@ -60,6 +64,8 @@ def add_uniform_sphere_3d(archive, x_range, y_range, z_range): The solutions are the same as the measures (the (x,y,z) coordinates). x_range, y_range, and z_range are tuples of (lower_bound, upper_bound). + + The metadata contains the positive sphere function for the points. """ xxs, yys, zzs = np.meshgrid( np.linspace(x_range[0], x_range[1], 40), @@ -68,8 +74,10 @@ def add_uniform_sphere_3d(archive, x_range, y_range, z_range): ) xxs, yys, zzs = xxs.ravel(), yys.ravel(), zzs.ravel() coords = np.stack((xxs, yys, zzs), axis=1) + sphere = xxs**2 + yys**2 + zzs**2 archive.add( solution_batch=coords, - objective_batch=-(xxs**2 + yys**2 + zzs**2), # Negative sphere. + objective_batch=-sphere, # Negative sphere. measures_batch=coords, + metadata_batch=sphere, # Positive sphere. ) diff --git a/tests/visualize/cvt_archive_3d_plot_test.py b/tests/visualize/cvt_archive_3d_plot_test.py index d91f8494d..e91f5a200 100644 --- a/tests/visualize/cvt_archive_3d_plot_test.py +++ b/tests/visualize/cvt_archive_3d_plot_test.py @@ -200,3 +200,14 @@ def test_transparent(cvt_archive_3d): def test_plot_elites(cvt_archive_3d): plt.figure(figsize=(8, 6)) cvt_archive_3d_plot(cvt_archive_3d, cell_alpha=0.0, plot_elites=True) + + +@image_comparison(baseline_images=["plot_metadata_with_df"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_plot_metadata_with_df(cvt_archive_3d): + plt.figure(figsize=(8, 6)) + df = cvt_archive_3d.as_pandas(include_metadata=True) + df["objective"] = df["metadata"] + cvt_archive_3d_plot(cvt_archive_3d, df=df) diff --git a/tests/visualize/cvt_archive_heatmap_test.py b/tests/visualize/cvt_archive_heatmap_test.py index a1ee505cd..d5a6827d6 100644 --- a/tests/visualize/cvt_archive_heatmap_test.py +++ b/tests/visualize/cvt_archive_heatmap_test.py @@ -219,6 +219,17 @@ def test_rasterized(cvt_archive_2d): cvt_archive_heatmap(cvt_archive_2d, rasterized=True) +@image_comparison(baseline_images=["plot_metadata_with_df"], + remove_text=False, + extensions=["png"], + tol=CVT_IMAGE_TOLERANCE) +def test_plot_metadata_with_df(cvt_archive_2d): + plt.figure(figsize=(8, 6)) + df = cvt_archive_2d.as_pandas(include_metadata=True) + df["objective"] = df["metadata"] + cvt_archive_heatmap(cvt_archive_2d, df=df) + + # # Tests for `clip` parameter # diff --git a/tests/visualize/grid_archive_heatmap_test.py b/tests/visualize/grid_archive_heatmap_test.py index 3ff574e5c..35e939aa6 100644 --- a/tests/visualize/grid_archive_heatmap_test.py +++ b/tests/visualize/grid_archive_heatmap_test.py @@ -194,6 +194,16 @@ def test_rasterized(grid_archive_2d): grid_archive_heatmap(grid_archive_2d, rasterized=True) +@image_comparison(baseline_images=["plot_metadata_with_df"], + remove_text=False, + extensions=["png"]) +def test_plot_metadata_with_df(grid_archive_2d): + plt.figure(figsize=(8, 6)) + df = grid_archive_2d.as_pandas(include_metadata=True) + df["objective"] = df["metadata"] + grid_archive_heatmap(grid_archive_2d, df=df) + + # # 1D tests # diff --git a/tests/visualize/parallel_axes_plot_test.py b/tests/visualize/parallel_axes_plot_test.py index e7fef5287..faaf4033d 100644 --- a/tests/visualize/parallel_axes_plot_test.py +++ b/tests/visualize/parallel_axes_plot_test.py @@ -86,3 +86,13 @@ def test_3d_sorted(grid_archive_3d): def test_3d_vertical_cbar(grid_archive_3d): plt.figure(figsize=(8, 6)) parallel_axes_plot(grid_archive_3d, cbar_kwargs={"orientation": "vertical"}) + + +@image_comparison(baseline_images=["plot_metadata_with_df"], + remove_text=False, + extensions=["png"]) +def test_plot_metadata_with_df(grid_archive_3d): + plt.figure(figsize=(8, 6)) + df = grid_archive_3d.as_pandas(include_metadata=True) + df["objective"] = df["metadata"] + parallel_axes_plot(grid_archive_3d, df=df) diff --git a/tests/visualize/sliding_boundaries_archive_heatmap_test.py b/tests/visualize/sliding_boundaries_archive_heatmap_test.py index 2efedf95c..d92e3042c 100644 --- a/tests/visualize/sliding_boundaries_archive_heatmap_test.py +++ b/tests/visualize/sliding_boundaries_archive_heatmap_test.py @@ -29,10 +29,12 @@ def add_random_sphere(archive, x_range, y_range): (x_range[1], y_range[1]), (1000, 2), ) + sphere = np.sum(np.square(solutions), axis=1) archive.add( solution_batch=solutions, - objective_batch=-np.sum(np.square(solutions), axis=1), + objective_batch=-sphere, measures_batch=solutions, + metadata_batch=sphere, ) @@ -158,3 +160,13 @@ def test_rasterized(sliding_archive_2d): sliding_boundaries_archive_heatmap(sliding_archive_2d, boundary_lw=1.0, rasterized=True) + + +@image_comparison(baseline_images=["plot_metadata_with_df"], + remove_text=False, + extensions=["png"]) +def test_plot_metadata_with_df(sliding_archive_2d): + plt.figure(figsize=(8, 6)) + df = sliding_archive_2d.as_pandas(include_metadata=True) + df["objective"] = df["metadata"] + sliding_boundaries_archive_heatmap(sliding_archive_2d, df=df)