Return occupied booleans in retrieve (#414)
## Description

<!-- Provide a brief description of the PR's purpose here. -->

Previously, we relied on sentinel values to indicate whether a given
cell was occupied. Since it is entirely possible that users want to use
these sentinel values in their fields, we now return a separate
`occupied` array that indicates which cells are occupied.
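
The ambiguity described above can be illustrated with a small sketch. `ToyArchive` and its methods are hypothetical stand-ins written for this example, not the pyribs API; the sketch assumes a single `metadata` field whose empty sentinel is `None`.

```python
class ToyArchive:
    """Hypothetical stand-in for an archive; this is NOT the pyribs API."""

    def __init__(self):
        self._metadata = {}  # Maps cell index -> metadata of the elite there.

    def add(self, cell, metadata):
        self._metadata[cell] = metadata

    def retrieve_sentinel_only(self, cell):
        # Old style: an empty cell is signalled only by the sentinel None.
        return self._metadata.get(cell)

    def retrieve_with_occupied(self, cell):
        # New style: a separate boolean says whether the cell holds an elite.
        return cell in self._metadata, self._metadata.get(cell)


archive = ToyArchive()
archive.add(3, None)  # A real elite whose metadata happens to be None.

# With sentinels alone, the occupied cell 3 and the empty cell 4 look the same.
old_hit = archive.retrieve_sentinel_only(3)
old_miss = archive.retrieve_sentinel_only(4)

# With the occupied flag, the two cases can be told apart.
new_hit = archive.retrieve_with_occupied(3)
new_miss = archive.retrieve_with_occupied(4)
```

Here `old_hit` and `old_miss` are both `None`, while `new_hit` and `new_miss` differ in their first element.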

Considerations:
- We chose not to support additional return types such as tuple and pandas
for now, since that flexibility is less essential in `retrieve` and can be
added fairly easily later.
- We still set sentinel values according to the field type, since arbitrary
values in a field could be confusing to anyone who looks at the data without
also checking the occupied array.
- The `threshold` field is now included in the outputs of `retrieve()`.
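
The type-dependent sentinel rule mentioned in the considerations can be sketched roughly as below. It follows the loop added to `retrieve` in this commit, but `fill_unoccupied` and the `visits` field are hypothetical names used only for illustration, and the arrays are made-up data rather than output from a real archive.

```python
import numpy as np


def fill_unoccupied(occupied, data):
    """Overwrites entries of unoccupied cells with a per-type empty value.

    Hypothetical helper; it modifies the arrays in ``data`` in place.
    """
    unoccupied = ~occupied
    for name, arr in data.items():
        if arr.dtype == object:
            fill_val = None
        elif name == "index":
            fill_val = -1
        elif np.issubdtype(arr.dtype, np.integer):
            fill_val = 0
        else:  # Floating-point and other fields.
            fill_val = np.nan
        arr[unoccupied] = fill_val
    return occupied, data


# Two queried cells: the first is occupied, the second is not.
occupied = np.array([True, False])
data = {
    "solution": np.array([[1.0, 2.0], [3.0, 4.0]]),
    "objective": np.array([0.5, 0.0]),
    "index": np.array([7, 0], dtype=np.int32),
    "visits": np.array([3, 9], dtype=np.int64),  # Hypothetical integer field.
    "metadata": np.array([{"a": 1}, "stale"], dtype=object),
}

occupied, data = fill_unoccupied(occupied, data)
```

After the call, the second row of every field holds its empty value (NaN, -1, 0, or `None`), and the first row is unchanged. An integer field's empty value of 0 cannot be told apart from a stored 0, so callers should consult `occupied` rather than the sentinels.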

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Implement new retrieve and retrieve_single methods
- [x] Fix tests
- [x] Fix usage in tutorials

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in
[CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog in
`HISTORY.md`
- [x] This PR is ready to go
btjanaka authored Nov 11, 2023
1 parent 05e4910 commit f1da318
Showing 5 changed files with 63 additions and 60 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Expand Up @@ -6,6 +6,7 @@

#### API

- **Backwards-incompatible:** Return occupied booleans in retrieve ({pr}`414`)
- **Backwards-incompatible:** Deprecate `as_pandas` in favor of
`data(return_type="pandas")` ({pr}`408`)
- **Backwards-incompatible:** Replace ArchiveDataFrame batch methods with
85 changes: 40 additions & 45 deletions ribs/archives/_archive_base.py
Expand Up @@ -520,30 +520,30 @@ def retrieve(self, measures_batch):
This method operates in batch, i.e., it takes in a batch of measures and
outputs the batched data for the elites::
elites = archive.retrieve(...)
occupied, elites = archive.retrieve(...)
elites["solution"] # Shape: (batch_size, solution_dim)
elites["objective"]
elites["measures"]
elites["index"]
elites["metadata"]
If the cell associated with ``elites["measures"][i]`` has an elite in
it, then ``elites["solution"][i]``, ``elites["objective"][i]``,
it, then ``occupied[i]`` will be True. Furthermore,
``elites["solution"][i]``, ``elites["objective"][i]``,
``elites["measures"][i]``, ``elites["index"][i]``, and
``elites["metadata"][i]`` will be set to the properties of the elite.
Note that ``elites["measures"][i]`` may not be equal to the
``measures_batch[i]`` passed as an argument, since the measures only
need to be in the same archive cell.
If the cell associated with ``measures_batch[i]`` *does not* have any
elite in it, then the corresponding outputs are set to empty values --
namely:
elite in it, then ``occupied[i]`` will be set to False. Furthermore, the
corresponding outputs will be set to empty values -- namely:
* ``elites["solution"][i]`` will be an array of NaN
* ``elites["objective"][i]`` will be NaN
* ``elites["measures"][i]`` will be an array of NaN
* ``elites["index"][i]`` will be -1
* ``elites["metadata"][i]`` will be None
* NaN for floating-point fields
* -1 for the "index" field
* 0 for integer fields
* None for object fields
If you need to retrieve a *single* elite associated with some measures,
consider using :meth:`retrieve_single`.
Expand All @@ -552,7 +552,11 @@ def retrieve(self, measures_batch):
measures_batch (array-like): (batch_size, :attr:`measure_dim`)
array of coordinates in measure space.
Returns:
dict: See above.
tuple: 2-element tuple of (occupied array, dict). The occupied array
indicates whether each of the cells indicated by the measures in
measures_batch has an elite, while the dict contains the data of
those elites. The dict maps from field name to the corresponding
array.
Raises:
ValueError: ``measures_batch`` is not of shape (batch_size,
:attr:`measure_dim`).
Expand All @@ -564,44 +568,38 @@ def retrieve(self, measures_batch):
check_finite(measures_batch, "measures_batch")

occupied, data = self._store.retrieve(self.index_of(measures_batch))
unoccupied = ~occupied

return {
# For each occupied_batch[i], this np.where selects
# self._solution_arr[index_batch][i] if occupied_batch[i] is True.
# Otherwise, it uses the alternate value (a solution array
# consisting of np.nan).
"solution":
np.where(occupied[:, None], data["solution"],
np.full(self._solution_dim, np.nan)),
# Here the alternative is just a scalar np.nan.
"objective":
np.where(occupied, data["objective"], np.nan),
# And here it is a measures array of np.nan.
"measures":
np.where(occupied[:, None], data["measures"],
np.full(self._measure_dim, np.nan)),
# Indices must be integers, so np.nan would not work, so we use -1.
"index":
np.where(occupied, data["index"], -1),
"metadata":
np.where(occupied, data["metadata"], None),
}
for name, arr in data.items():
if arr.dtype == object:
fill_val = None
elif name == "index":
fill_val = -1
elif np.issubdtype(arr.dtype, np.integer):
fill_val = 0
else: # Floating-point and other fields.
fill_val = np.nan

arr[unoccupied] = fill_val

return occupied, data

def retrieve_single(self, measures):
"""Retrieves the elite with measures in the same cell as the measures
specified.
While :meth:`retrieve` takes in a *batch* of measures, this method takes
in the measures for only *one* solution and returns a dict with single
entries.
in the measures for only *one* solution and returns a single bool and a
dict with single entries.
Args:
measures (array-like): (:attr:`measure_dim`,) array of measures.
Returns:
If there is an elite with measures in the same cell as the measures
specified, then this method returns dict where all the fields hold
the info of the elite. Otherwise, this method returns a dict filled
with the same "empty" values described in :meth:`retrieve`.
tuple: If there is an elite with measures in the same cell as the
measures specified, then this method returns a True value and a dict
where all the fields hold the info of the elite. Otherwise, this
method returns a False value and a dict filled with the same "empty"
values described in :meth:`retrieve`.
Raises:
ValueError: ``measures`` is not of shape (:attr:`measure_dim`,).
ValueError: ``measures`` has non-finite values (inf or NaN).
Expand All @@ -610,10 +608,9 @@ def retrieve_single(self, measures):
check_1d_shape(measures, "measures", self.measure_dim, "measure_dim")
check_finite(measures, "measures")

return {
field: arr[0]
for field, arr in self.retrieve(measures[None]).items()
}
occupied, data = self.retrieve(measures[None])

return occupied[0], {field: arr[0] for field, arr in data.items()}

def sample_elites(self, n):
"""Randomly samples elites from the archive.
Expand Down Expand Up @@ -834,10 +831,8 @@ def cqd_score(self,
penalties = np.copy(penalties) # Copy since we return this.
check_is_1d(penalties, "penalties")

objective_batch, measures_batch = self._store.data(
["objective", "measures"],
return_type="tuple",
)
objective_batch = self._store.data("objective")
measures_batch = self._store.data("measures")

norm_objectives = objective_batch / (obj_max - obj_min)

16 changes: 12 additions & 4 deletions tests/archives/archive_base_test.py
Expand Up @@ -340,19 +340,23 @@ def test_basic_stats(data):


def test_retrieve_gets_correct_elite(data):
elites = data.archive_with_elite.retrieve([data.measures])
occupied, elites = data.archive_with_elite.retrieve([data.measures])
assert occupied[0]
assert np.all(elites["solution"][0] == data.solution)
assert elites["objective"][0] == data.objective
assert np.all(elites["measures"][0] == data.measures)
assert elites["threshold"][0] == data.objective
# Avoid checking elites["index"] since the meaning varies by archive.
assert elites["metadata"][0] == data.metadata


def test_retrieve_empty_values(data):
elites = data.archive.retrieve([data.measures])
occupied, elites = data.archive.retrieve([data.measures])
assert not occupied[0]
assert np.all(np.isnan(elites["solution"][0]))
assert np.isnan(elites["objective"])
assert np.all(np.isnan(elites["measures"][0]))
assert np.isnan(elites["threshold"])
assert elites["index"][0] == -1
assert elites["metadata"][0] is None

Expand All @@ -363,19 +367,23 @@ def test_retrieve_wrong_shape(data):


def test_retrieve_single_gets_correct_elite(data):
elite = data.archive_with_elite.retrieve_single(data.measures)
occupied, elite = data.archive_with_elite.retrieve_single(data.measures)
assert occupied
assert np.all(elite["solution"] == data.solution)
assert elite["objective"] == data.objective
assert np.all(elite["measures"] == data.measures)
assert elite["threshold"] == data.objective
# Avoid checking elite["index"] since the meaning varies by archive.
assert elite["metadata"] == data.metadata


def test_retrieve_single_empty_values(data):
elite = data.archive.retrieve_single(data.measures)
occupied, elite = data.archive.retrieve_single(data.measures)
assert not occupied
assert np.all(np.isnan(elite["solution"]))
assert np.isnan(elite["objective"])
assert np.all(np.isnan(elite["measures"]))
assert np.isnan(elite["threshold"])
assert elite["index"] == -1
assert elite["metadata"] is None

5 changes: 3 additions & 2 deletions tutorials/arm_repertoire.ipynb
Expand Up @@ -449,9 +449,10 @@
}
],
"source": [
"elite = archive.retrieve_single([0, 0])\n",
"occupied, elite = archive.retrieve_single([0, 0])\n",
"_, ax = plt.subplots()\n",
"if elite[\"solution\"] is not None: # This is None if there is no solution for [0,0].\n",
"# `occupied` indicates if there was an elite in the corresponding cell.\n",
"if occupied:\n",
" visualize(elite[\"solution\"], link_lengths, elite[\"objective\"], ax)"
]
},
16 changes: 7 additions & 9 deletions tutorials/lunar_lander.ipynb
Expand Up @@ -746,7 +746,6 @@
"id": "t2QPnuqgFKhr"
},
"source": [
"\n",
"We can retrieve policies with measures that are close to a query with the [`retrieve_single`](https://docs.pyribs.org/en/latest/api/ribs.archives.GridArchive.html#ribs.archives.GridArchive.retrieve_single) method. This method will look up the cell corresponding to the queried measures. Then, the method will check if there is an elite in that cell, and return the elite if it exists (the method does not check neighboring cells for elites). The returned elite may not have the exact measures requested because the elite only has to be in the same cell as the queried measures.\n",
"\n",
"Below, we first retrieve a policy that impacted the ground on the left (approximately -0.4) with low velocity (approximately -0.10) by querying for `[-0.4, -0.10]`."
Expand Down Expand Up @@ -789,10 +788,9 @@
}
],
"source": [
"elite = archive.retrieve_single([-0.4, -0.10])\n",
"# NaN objective indicates the solution could not be retrieved because there was\n",
"# no elite in the corresponding cell.\n",
"if not np.isnan(elite[\"objective\"]):\n",
"occupied, elite = archive.retrieve_single([-0.4, -0.10])\n",
"# `occupied` indicates if there was an elite in the corresponding cell.\n",
"if occupied:\n",
" print(f\"Objective: {elite['objective']}\")\n",
" print(f\"Measures: (x-pos: {elite['measures'][0]}, y-vel: {elite['measures'][1]})\")\n",
" display_video(elite[\"solution\"])"
Expand Down Expand Up @@ -848,8 +846,8 @@
}
],
"source": [
"elite = archive.retrieve_single([0.6, -0.10])\n",
"if not np.isnan(elite[\"objective\"]):\n",
"occupied, elite = archive.retrieve_single([0.6, -0.10])\n",
"if occupied:\n",
" print(f\"Objective: {elite['objective']}\")\n",
" print(f\"Measures: (x-pos: {elite['measures'][0]}, y-vel: {elite['measures'][1]})\")\n",
" display_video(elite[\"solution\"])"
Expand Down Expand Up @@ -901,8 +899,8 @@
}
],
"source": [
"elite = archive.retrieve_single([0.0, -0.10])\n",
"if not np.isnan(elite[\"objective\"]):\n",
"occupied, elite = archive.retrieve_single([0.0, -0.10])\n",
"if occupied:\n",
" print(f\"Objective: {elite['objective']}\")\n",
" print(f\"Measures: (x-pos: {elite['measures'][0]}, y-vel: {elite['measures'][1]})\")\n",
" display_video(elite[\"solution\"])"
