Enable plotting custom data in visualizations (#374)
## Description

<!-- Provide a brief description of the PR's purpose here. -->

This PR seeks to solve two issues:

1. Plotting data from old archives. For instance, if we previously
called `as_pandas()` on an archive and stored an old dataframe, we may
want to plot that dataframe afterwards, e.g., during our data analysis.
2. Plotting custom data that is not necessarily the objective. For
instance, #195 requests plotting metrics that are not necessarily the
objective. Another reasonable case is that we have stored some metric in
our metadata and wish to plot it instead of the objective.

I propose that these two issues are really the same issue. In
particular, both of these are asking to _visualize custom data that are
not currently in the archive_. Issue 1 wants to plot old data, and Issue
2 wants to plot data with a different objective.

Thus, this PR adds a single parameter, `df`, that can be used to change
the data that is plotted. Essentially, when this parameter is provided,
the archive only provides configurations like the upper/lower bounds of
the measure space and the cell boundaries, while `df` provides the
content that is plotted. This `df` may be retrieved from an earlier call
to `as_pandas` on the archive, thus resolving Issue 1. Furthermore,
users can replace `df["objective"]` on their own, thus resolving
Issue 2. This feature also allows a user to plot data after performing
operations on the dataframe; for instance, one could filter the
dataframe and plot the X highest-performing solutions.
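
As a rough illustration of the workflow described above, the sketch below prepares a dataframe the way a user might before passing it as `df`. It uses a plain pandas `DataFrame` as a stand-in for the frame returned by `as_pandas(include_metadata=True)`; the values are made up, and the `measure_0` column name is an assumption rather than something taken from this PR.

```python
import pandas as pd

# Stand-in for a frame saved from an earlier
# `archive.as_pandas(include_metadata=True)` call. The values and the
# `measure_0` column name are assumptions made for illustration.
df = pd.DataFrame({
    "index": [0, 1, 2, 3],
    "objective": [-1.0, -4.0, -9.0, -0.25],
    "measure_0": [1.0, 2.0, 3.0, 0.5],
    "metadata": [1.0, 4.0, 9.0, 0.25],
})

# Issue 2: plot a custom metric by overwriting the "objective" column.
df["objective"] = df["metadata"]

# Keep only the X best rows before plotting (here X = 2).
top = df.nlargest(2, "objective")

# With a real archive, the result would then be passed via the new
# parameter, e.g.:
# grid_archive_heatmap(archive, df=top)
```

The archive itself is still required in the final call, since it supplies the bounds and cell layout while `df` supplies only the values to draw.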

### Caveats

- It is a bit cumbersome to ask users to manually set
`df["objective"]`. The only alternative I could think of was to
pass in a callable that takes in the dataframe or an EliteBatch and
returns new values to plot. However, this is equally cumbersome, as
users would need to understand the callable format in addition to the
dataframe structure.
- It is possible that users pass in a "corrupted" dataframe, e.g., one
that has out-of-bounds indices. However, for now, I believe it is
reasonable to assume users will only get dataframes from `as_pandas()`
and will not perform operations that introduce such entries. In the
future, we can add validation checks if this becomes an issue.
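
For reference, a future validation check for the second caveat might look roughly like the sketch below. The helper name `check_index_bounds` is hypothetical and is not part of this PR; it assumes the valid indices are the integers from 0 up to (but not including) the number of cells in the archive, i.e. `archive.cells`.

```python
def check_index_bounds(indices, cells):
    """Hypothetical helper: rejects indices outside the range [0, cells)."""
    # Collect every offending index so the error message is actionable.
    bad = [int(i) for i in indices if not 0 <= int(i) < cells]
    if bad:
        raise ValueError(
            f"df contains out-of-bounds indices {bad}; "
            f"the archive only has {cells} cells")
    return indices


# Indices 0..99 are valid for a 100-cell archive, so this passes.
check_index_bounds([0, 5, 99], cells=100)
```

Such a check could be called from `validate_df` if the archive were also passed in, but that design is left open here.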

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Add `df` param to all visualization functions
- [x] Add util for preprocessing data param
- [x] Add tests for all visualization functions

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in
      [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog
      in `HISTORY.md`
- [x] This PR is ready to go
btjanaka authored Sep 15, 2023
1 parent 4f8637c commit b6de0f9
Showing 18 changed files with 138 additions and 21 deletions.
3 changes: 2 additions & 1 deletion HISTORY.md
@@ -17,7 +17,8 @@
- Add `rasterized` arg for heatmaps (#359)
- Support 1D cvt_archive_heatmap ({pr}`362`)
- Add 3D plots for CVTArchive ({pr}`371`)
- Add visualization of 3D QDax repertoires ({pr}`372`)
- Add visualization of 3D QDax repertoires ({pr}`373`)
- Enable plotting custom data in visualizations ({pr}`374`)

#### Documentation

16 changes: 12 additions & 4 deletions ribs/visualize/_cvt_archive_3d_plot.py
@@ -5,14 +5,15 @@
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.spatial import Voronoi # pylint: disable=no-name-in-module

from ribs.visualize._utils import (retrieve_cmap, set_cbar,
from ribs.visualize._utils import (retrieve_cmap, set_cbar, validate_df,
validate_heatmap_visual_args)


def cvt_archive_3d_plot(
archive,
ax=None,
*,
df=None,
measure_order=None,
cmap="magma",
lw=0.5,
@@ -150,6 +151,13 @@ def cvt_archive_3d_plot(
archive (CVTArchive): A 3D :class:`~ribs.archives.CVTArchive`.
ax (matplotlib.axes.Axes): Axes on which to plot the heatmap.
If ``None``, we will create a new 3D axis.
df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from
this argument instead of the data currently in the archive. This
data can be obtained by, for instance, calling
:meth:`ribs.archives.ArchiveBase.as_pandas()` and modifying the
resulting :class:`ArchiveDataFrame`. Note that, at a minimum, the
data must contain columns for index, objective, and measures. To
display a custom metric, replace the "objective" column.
measure_order (array-like of int): Specifies the axes order for plotting
the measures. By default, the first measure (measure 0) in the
archive appears on the x-axis, the second (measure 1) on y-axis, and
@@ -208,8 +216,8 @@ def cvt_archive_3d_plot(
# Try getting the colormap early in case it fails.
cmap = retrieve_cmap(cmap)

# Retrieve data from archive.
df = archive.as_pandas()
# Retrieve archive data.
df = archive.as_pandas() if df is None else validate_df(df)
objective_batch = df.objective_batch()
measures_batch = df.measures_batch()
lower_bounds = archive.lower_bounds
@@ -289,7 +297,7 @@ def cvt_archive_3d_plot(
objs = [] # Also record objective for each ridge so we can color it.

# Map from centroid index to objective.
pt_to_obj = {elite.index: elite.objective for elite in archive}
pt_to_obj = dict(zip(df.index_batch(), objective_batch))

# The points in the Voronoi diagram are indexed by their placement in the
# input list. Above, when we called Voronoi, `centroids` were placed first,
17 changes: 13 additions & 4 deletions ribs/visualize/_cvt_archive_heatmap.py
@@ -7,7 +7,7 @@
from scipy.spatial import Voronoi # pylint: disable=no-name-in-module

from ribs.visualize._utils import (archive_heatmap_1d, retrieve_cmap, set_cbar,
validate_heatmap_visual_args)
validate_df, validate_heatmap_visual_args)

# Matplotlib functions tend to have a ton of args.
# pylint: disable = too-many-arguments
@@ -16,6 +16,7 @@
def cvt_archive_heatmap(archive,
ax=None,
*,
df=None,
transpose_measures=False,
cmap="magma",
aspect=None,
@@ -98,6 +99,13 @@ def cvt_archive_heatmap(archive,
archive (CVTArchive): A 1D or 2D :class:`~ribs.archives.CVTArchive`.
ax (matplotlib.axes.Axes): Axes on which to plot the heatmap.
If ``None``, the current axis will be used.
df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from
this argument instead of the data currently in the archive. This
data can be obtained by, for instance, calling
:meth:`ribs.archives.ArchiveBase.as_pandas()` and modifying the
resulting :class:`ArchiveDataFrame`. Note that, at a minimum, the
data must contain columns for index, objective, and measures. To
display a custom metric, replace the "objective" column.
transpose_measures (bool): By default, the first measure in the archive
will appear along the x-axis, and the second will be along the
y-axis. To switch this behavior (i.e. to transpose the axes), set
@@ -173,6 +181,9 @@ def cvt_archive_heatmap(archive,
# Try getting the colormap early in case it fails.
cmap = retrieve_cmap(cmap)

# Retrieve archive data.
df = archive.as_pandas() if df is None else validate_df(df)

if archive.measure_dim == 1:
# Read in pcm kwargs -- the linewidth and edgecolor are overwritten by
# our arguments.
@@ -195,8 +206,6 @@ def cvt_archive_heatmap(archive,
[archive.upper_bounds[0]],
))

df = archive.as_pandas()

# centroid_sort_idx tells us which index to place the centroid at such
# that it is sorted, i.e., it maps from the indices in the centroid
# array to the cell indices. This means that if you index with it, e.g.,
@@ -279,7 +288,7 @@ def cvt_archive_heatmap(archive,
# the region index of each point.
region_obj = [None] * len(vor.regions)
min_obj, max_obj = np.inf, -np.inf
pt_to_obj = {elite.index: elite.objective for elite in archive}
pt_to_obj = dict(zip(df.index_batch(), df.objective_batch()))
for pt_idx, region_idx in enumerate(
vor.point_region[:-4]): # Exclude faraway_pts.
if region_idx != -1 and pt_idx in pt_to_obj:
15 changes: 12 additions & 3 deletions ribs/visualize/_grid_archive_heatmap.py
@@ -3,7 +3,7 @@
import numpy as np

from ribs.visualize._utils import (archive_heatmap_1d, retrieve_cmap, set_cbar,
validate_heatmap_visual_args)
validate_df, validate_heatmap_visual_args)

# Matplotlib functions tend to have a ton of args.
# pylint: disable = too-many-arguments
@@ -12,6 +12,7 @@
def grid_archive_heatmap(archive,
ax=None,
*,
df=None,
transpose_measures=False,
cmap="magma",
aspect=None,
@@ -86,6 +87,13 @@ def grid_archive_heatmap(archive,
archive (GridArchive): A 1D or 2D :class:`~ribs.archives.GridArchive`.
ax (matplotlib.axes.Axes): Axes on which to plot the heatmap.
If ``None``, the current axis will be used.
df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from
this argument instead of the data currently in the archive. This
data can be obtained by, for instance, calling
:meth:`ribs.archives.ArchiveBase.as_pandas()` and modifying the
resulting :class:`ArchiveDataFrame`. Note that, at a minimum, the
data must contain columns for index, objective, and measures. To
display a custom metric, replace the "objective" column.
transpose_measures (bool): By default, the first measure in the archive
will appear along the x-axis, and the second will be along the
y-axis. To switch this behavior (i.e. to transpose the axes), set
@@ -138,8 +146,10 @@ def grid_archive_heatmap(archive,
# Try getting the colormap early in case it fails.
cmap = retrieve_cmap(cmap)

# Retrieve archive data.
df = archive.as_pandas() if df is None else validate_df(df)

if archive.measure_dim == 1:
df = archive.as_pandas()
cell_objectives = np.full(archive.cells, np.nan)
cell_idx = archive.int_to_grid_index(df.index_batch()).squeeze()
cell_objectives[cell_idx] = df.objective_batch()
@@ -161,7 +171,6 @@ def grid_archive_heatmap(archive,

elif archive.measure_dim == 2:
# Retrieve data from archive.
df = archive.as_pandas()
objective_batch = df.objective_batch()
lower_bounds = archive.lower_bounds
upper_bounds = archive.upper_bounds
12 changes: 10 additions & 2 deletions ribs/visualize/_parallel_axes_plot.py
@@ -4,7 +4,7 @@
import numpy as np
from matplotlib.cm import ScalarMappable

from ribs.visualize._utils import retrieve_cmap, set_cbar
from ribs.visualize._utils import retrieve_cmap, set_cbar, validate_df

# Matplotlib functions tend to have a ton of args.
# pylint: disable = too-many-arguments
@@ -13,6 +13,7 @@
def parallel_axes_plot(archive,
ax=None,
*,
df=None,
measure_order=None,
cmap="magma",
linewidth=1.5,
@@ -78,6 +79,13 @@ def parallel_axes_plot(archive,
archive (ArchiveBase): Any ribs archive.
ax (matplotlib.axes.Axes): Axes on which to create the plot.
If ``None``, the current axis will be used.
df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from
this argument instead of the data currently in the archive. This
data can be obtained by, for instance, calling
:meth:`ribs.archives.ArchiveBase.as_pandas()` and modifying the
resulting :class:`ArchiveDataFrame`. Note that, at a minimum, the
data must contain columns for index, objective, and measures. To
display a custom metric, replace the "objective" column.
measure_order (list of int or list of (int, str)): If this is a list
of ints, it specifies the axes order for measures (e.g. ``[2, 0,
1]``). If this is a list of tuples, each tuple takes the form
@@ -155,7 +163,7 @@ def parallel_axes_plot(archive,
upper_bounds = archive.upper_bounds[cols]

host_ax = plt.gca() if ax is None else ax # Try to get current axis.
df = archive.as_pandas(include_solutions=False)
df = archive.as_pandas() if df is None else validate_df(df)
vmin = df["objective"].min() if vmin is None else vmin
vmax = df["objective"].max() if vmax is None else vmax
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
14 changes: 11 additions & 3 deletions ribs/visualize/_sliding_boundaries_archive_heatmap.py
@@ -2,7 +2,7 @@
import matplotlib.pyplot as plt
import numpy as np

from ribs.visualize._utils import (retrieve_cmap, set_cbar,
from ribs.visualize._utils import (retrieve_cmap, set_cbar, validate_df,
validate_heatmap_visual_args)

# Matplotlib functions tend to have a ton of args.
@@ -12,6 +12,7 @@
def sliding_boundaries_archive_heatmap(archive,
ax=None,
*,
df=None,
transpose_measures=False,
cmap="magma",
aspect="auto",
@@ -64,6 +65,13 @@ def sliding_boundaries_archive_heatmap(archive,
:class:`~ribs.archives.SlidingBoundariesArchive`.
ax (matplotlib.axes.Axes): Axes on which to plot the heatmap.
If ``None``, the current axis will be used.
df (ribs.archives.ArchiveDataFrame): If provided, we will plot data from
this argument instead of the data currently in the archive. This
data can be obtained by, for instance, calling
:meth:`ribs.archives.ArchiveBase.as_pandas()` and modifying the
resulting :class:`ArchiveDataFrame`. Note that, at a minimum, the
data must contain columns for index, objective, and measures. To
display a custom metric, replace the "objective" column.
transpose_measures (bool): By default, the first measure in the archive
will appear along the x-axis, and the second will be along the
y-axis. To switch this behavior (i.e. to transpose the axes), set
@@ -110,8 +118,8 @@ def sliding_boundaries_archive_heatmap(archive,
# Try getting the colormap early in case it fails.
cmap = retrieve_cmap(cmap)

# Retrieve data from archive.
df = archive.as_pandas()
# Retrieve archive data.
df = archive.as_pandas() if df is None else validate_df(df)
measures_batch = df.measures_batch()
x = measures_batch[:, 0]
y = measures_batch[:, 1]
14 changes: 13 additions & 1 deletion ribs/visualize/_utils.py
@@ -3,6 +3,8 @@
import matplotlib.pyplot as plt
import numpy as np

from ribs.archives import ArchiveDataFrame


def retrieve_cmap(cmap):
"""Retrieves colormap from Matplotlib."""
@@ -39,6 +41,17 @@ def validate_heatmap_visual_args(aspect, cbar, measure_dim, valid_dims,
"or matplotlib.axes.Axes")


def validate_df(df):
"""Helper to validate the df passed into visualization functions."""

# Cast to an ArchiveDataFrame in case someone passed in a regular DataFrame
# or other object.
if not isinstance(df, ArchiveDataFrame):
df = ArchiveDataFrame(df)

return df


def set_cbar(t, ax, cbar, cbar_kwargs):
"""Sets cbar on the Axes given cbar arg."""
cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
@@ -119,5 +132,4 @@ def archive_heatmap_1d(

# Create color bar.
set_cbar(t, ax, cbar, cbar_kwargs)

return ax
12 changes: 10 additions & 2 deletions tests/visualize/conftest.py
@@ -40,17 +40,21 @@ def add_uniform_sphere_2d(archive, x_range, y_range):
The solutions are the same as the measures (the (x,y) coordinates).
x_range and y_range are tuples of (lower_bound, upper_bound).
The metadata contains the positive sphere function for the points.
"""
xxs, yys = np.meshgrid(
np.linspace(x_range[0], x_range[1], 100),
np.linspace(y_range[0], y_range[1], 100),
)
xxs, yys = xxs.ravel(), yys.ravel()
coords = np.stack((xxs, yys), axis=1)
sphere = xxs**2 + yys**2
archive.add(
solution_batch=coords,
objective_batch=-(xxs**2 + yys**2), # Negative sphere.
objective_batch=-sphere, # Negative sphere.
measures_batch=coords,
metadata_batch=sphere, # Positive sphere.
)


@@ -60,6 +64,8 @@ def add_uniform_sphere_3d(archive, x_range, y_range, z_range):
The solutions are the same as the measures (the (x,y,z) coordinates).
x_range, y_range, and z_range are tuples of (lower_bound, upper_bound).
The metadata contains the positive sphere function for the points.
"""
xxs, yys, zzs = np.meshgrid(
np.linspace(x_range[0], x_range[1], 40),
@@ -68,8 +74,10 @@ def add_uniform_sphere_3d(archive, x_range, y_range, z_range):
)
xxs, yys, zzs = xxs.ravel(), yys.ravel(), zzs.ravel()
coords = np.stack((xxs, yys, zzs), axis=1)
sphere = xxs**2 + yys**2 + zzs**2
archive.add(
solution_batch=coords,
objective_batch=-(xxs**2 + yys**2 + zzs**2), # Negative sphere.
objective_batch=-sphere, # Negative sphere.
measures_batch=coords,
metadata_batch=sphere, # Positive sphere.
)
11 changes: 11 additions & 0 deletions tests/visualize/cvt_archive_3d_plot_test.py
@@ -200,3 +200,14 @@ def test_transparent(cvt_archive_3d):
def test_plot_elites(cvt_archive_3d):
plt.figure(figsize=(8, 6))
cvt_archive_3d_plot(cvt_archive_3d, cell_alpha=0.0, plot_elites=True)


@image_comparison(baseline_images=["plot_metadata_with_df"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_plot_metadata_with_df(cvt_archive_3d):
plt.figure(figsize=(8, 6))
df = cvt_archive_3d.as_pandas(include_metadata=True)
df["objective"] = df["metadata"]
cvt_archive_3d_plot(cvt_archive_3d, df=df)
11 changes: 11 additions & 0 deletions tests/visualize/cvt_archive_heatmap_test.py
@@ -219,6 +219,17 @@ def test_rasterized(cvt_archive_2d):
cvt_archive_heatmap(cvt_archive_2d, rasterized=True)


@image_comparison(baseline_images=["plot_metadata_with_df"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_plot_metadata_with_df(cvt_archive_2d):
plt.figure(figsize=(8, 6))
df = cvt_archive_2d.as_pandas(include_metadata=True)
df["objective"] = df["metadata"]
cvt_archive_heatmap(cvt_archive_2d, df=df)


#
# Tests for `clip` parameter
#
10 changes: 10 additions & 0 deletions tests/visualize/grid_archive_heatmap_test.py
@@ -194,6 +194,16 @@ def test_rasterized(grid_archive_2d):
grid_archive_heatmap(grid_archive_2d, rasterized=True)


@image_comparison(baseline_images=["plot_metadata_with_df"],
remove_text=False,
extensions=["png"])
def test_plot_metadata_with_df(grid_archive_2d):
plt.figure(figsize=(8, 6))
df = grid_archive_2d.as_pandas(include_metadata=True)
df["objective"] = df["metadata"]
grid_archive_heatmap(grid_archive_2d, df=df)


#
# 1D tests
#
10 changes: 10 additions & 0 deletions tests/visualize/parallel_axes_plot_test.py
@@ -86,3 +86,13 @@ def test_3d_sorted(grid_archive_3d):
def test_3d_vertical_cbar(grid_archive_3d):
plt.figure(figsize=(8, 6))
parallel_axes_plot(grid_archive_3d, cbar_kwargs={"orientation": "vertical"})


@image_comparison(baseline_images=["plot_metadata_with_df"],
remove_text=False,
extensions=["png"])
def test_plot_metadata_with_df(grid_archive_3d):
plt.figure(figsize=(8, 6))
df = grid_archive_3d.as_pandas(include_metadata=True)
df["objective"] = df["metadata"]
parallel_axes_plot(grid_archive_3d, df=df)