Skip to content

Commit

Permalink
Fix cvt archive heatmap again
Browse files Browse the repository at this point in the history
  • Loading branch information
btjanaka committed Sep 12, 2023
1 parent 2668139 commit 86afa05
Showing 1 changed file with 41 additions and 5 deletions.
46 changes: 41 additions & 5 deletions ribs/visualize/_cvt_archive_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,32 @@ def cvt_archive_heatmap(archive,
>>> # Plot a heatmap of the archive.
>>> plt.figure(figsize=(8, 6))
>>> cvt_archive_heatmap(archive)
>>> plt.title("Negative sphere function")
>>> plt.title("Negative sphere function with 2D measures")
>>> plt.xlabel("x coords")
>>> plt.ylabel("y coords")
>>> plt.show()
.. plot::
:context: close-figs
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from ribs.archives import CVTArchive
>>> from ribs.visualize import cvt_archive_heatmap
>>> # Populate the archive with the negative sphere function.
>>> archive = CVTArchive(solution_dim=2,
... cells=20, ranges=[(-1, 1)])
>>> x = np.linspace(-1, 1, 100)
>>> archive.add(solution_batch=np.stack((x, x), axis=1),
... objective_batch=-x**2,
... measures_batch=x[:, None])
>>> # Plot a heatmap of the archive.
>>> plt.figure(figsize=(8, 6))
>>> cvt_archive_heatmap(archive)
>>> plt.title("Negative sphere function with 1D measures")
>>> plt.xlabel("x coords")
>>> plt.show()
Args:
archive (CVTArchive): A 1D or 2D :class:`~ribs.archives.CVTArchive`.
ax (matplotlib.axes.Axes): Axes on which to plot the heatmap.
Expand Down Expand Up @@ -172,10 +193,25 @@ def cvt_archive_heatmap(archive,
))

df = archive.as_pandas()
objective_batch = df.objective_batch()
cell_objectives = np.full((1, archive.cells), np.nan)
cell_idx = centroid_sort_idx[df.index_batch()]
cell_objectives[0, cell_idx] = objective_batch

# centroid_sort_idx tells us which index to place the centroid at such
# that it is sorted, i.e., it maps from the indices in the centroid
# array to the cell indices. This means that if you index with it, e.g.,
# arr[centroid_sort_idx], you get a sorted array.
#
# For computing cell_objectives, we need to know the inverse mapping,
# i.e., the mapping from cell indices to centroid indices. This way,
# when we index with it, we get the original order of centroids. This
# original order then matches with the objectives in objective_batch.
inv_idx = np.zeros_like(centroid_sort_idx, dtype=np.int32)
for i, x in enumerate(centroid_sort_idx):
inv_idx[x] = i

# We only want inverse indexes that are actually used in the archive.
selected_inv_idx = inv_idx[df.index_batch()]

cell_objectives = np.full(archive.cells, np.nan)
cell_objectives[selected_inv_idx] = df.objective_batch()

ax = archive_heatmap_1d(archive, cell_boundaries, cell_objectives, ax,
cmap, aspect, vmin, vmax, cbar, cbar_kwargs,
Expand Down

0 comments on commit 86afa05

Please sign in to comment.