From f1da3187603aaa8dd98ef0da17b0e65ec19887e7 Mon Sep 17 00:00:00 2001 From: Bryon Tjanaka <38124174+btjanaka@users.noreply.github.com> Date: Fri, 10 Nov 2023 18:41:58 -0800 Subject: [PATCH] Return occupied booleans in retrieve (#414) ## Description Previously, we relied on sentinel values to indicate whether a given cell was occupied. Since it is entirely possible that users want to use these sentinel values in their fields, we now return a separate `occupied` array that indicates which cells are occupied. Considerations: - Chose not to support additional return types like tuple and pandas for now, as such flexibility is less essential in `retrieve`, and this feature can be added fairly easily later on - We still set the sentinel values depending on the field type since it may be confusing to see arbitrary values for a given field without seeing the occupied array. - the `threshold` field is now included in outputs from `retrieve()` ## TODO - [x] Implement new retrieve and retrieve_single methods - [x] Fix tests - [x] Fix usage in tutorials ## Questions ## Status - [x] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md) - [x] I have formatted my code using `yapf` - [x] I have tested my code by running `pytest` - [x] I have linted my code with `pylint` - [x] I have added a one-line description of my change to the changelog in `HISTORY.md` - [x] This PR is ready to go --- HISTORY.md | 1 + ribs/archives/_archive_base.py | 85 ++++++++++++++--------------- tests/archives/archive_base_test.py | 16 ++++-- tutorials/arm_repertoire.ipynb | 5 +- tutorials/lunar_lander.ipynb | 16 +++--- 5 files changed, 63 insertions(+), 60 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 2918b7dd5..778366617 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -6,6 +6,7 @@ #### API +- **Backwards-incompatible:** Return occupied booleans in retrieve ({pr}`414`) - **Backwards-incompatible:** Deprecate `as_pandas` in favor of `data(return_type="pandas")` ({pr}`408`) - **Backwards-incompatible:** Replace ArchiveDataFrame batch methods with diff --git a/ribs/archives/_archive_base.py b/ribs/archives/_archive_base.py index a5a131c7e..2eab4906b 100644 --- a/ribs/archives/_archive_base.py +++ b/ribs/archives/_archive_base.py @@ -520,7 +520,7 @@ def retrieve(self, measures_batch): This method operates in batch, i.e., it takes in a batch of measures and outputs the batched data for the elites:: - elites = archive.retrieve(...) + occupied, elites = archive.retrieve(...) elites["solution"] # Shape: (batch_size, solution_dim) elites["objective"] elites["measures"] @@ -528,7 +528,8 @@ def retrieve(self, measures_batch): elites["metadata"] If the cell associated with ``elites["measures"][i]`` has an elite in - it, then ``elites["solution"][i]``, ``elites["objective"][i]``, + it, then ``occupied[i]`` will be True. Furthermore, + ``elites["solution"][i]``, ``elites["objective"][i]``, ``elites["measures"][i]``, ``elites["index"][i]``, and ``elites["metadata"][i]`` will be set to the properties of the elite. Note that ``elites["measures"][i]`` may not be equal to the @@ -536,14 +537,13 @@ def retrieve(self, measures_batch): need to be in the same archive cell. If the cell associated with ``measures_batch[i]`` *does not* have any - elite in it, then the corresponding outputs are set to empty values -- - namely: + elite in it, then ``occupied[i]`` will be set to False. Furthermore, the + corresponding outputs will be set to empty values -- namely: - * ``elites["solution"][i]`` will be an array of NaN - * ``elites["objective"][i]`` will be NaN - * ``elites["measures"][i]`` will be an array of NaN - * ``elites["index"][i]`` will be -1 - * ``elites["metadata"][i]`` will be None + * NaN for floating-point fields + * -1 for the "index" field + * 0 for integer fields + * None for object fields If you need to retrieve a *single* elite associated with some measures, consider using :meth:`retrieve_single`. @@ -552,7 +552,11 @@ def retrieve(self, measures_batch): measures_batch (array-like): (batch_size, :attr:`measure_dim`) array of coordinates in measure space. Returns: - dict: See above. + tuple: 2-element tuple of (occupied array, dict). The occupied array + indicates whether each of the cells indicated by the measures in + measures_batch has an elite, while the dict contains the data of + those elites. The dict maps from field name to the corresponding + array. Raises: ValueError: ``measures_batch`` is not of shape (batch_size, :attr:`measure_dim`). @@ -564,44 +568,38 @@ def retrieve(self, measures_batch): check_finite(measures_batch, "measures_batch") occupied, data = self._store.retrieve(self.index_of(measures_batch)) + unoccupied = ~occupied - return { - # For each occupied_batch[i], this np.where selects - # self._solution_arr[index_batch][i] if occupied_batch[i] is True. - # Otherwise, it uses the alternate value (a solution array - # consisting of np.nan). - "solution": - np.where(occupied[:, None], data["solution"], - np.full(self._solution_dim, np.nan)), - # Here the alternative is just a scalar np.nan. - "objective": - np.where(occupied, data["objective"], np.nan), - # And here it is a measures array of np.nan. - "measures": - np.where(occupied[:, None], data["measures"], - np.full(self._measure_dim, np.nan)), - # Indices must be integers, so np.nan would not work, so we use -1. - "index": - np.where(occupied, data["index"], -1), - "metadata": - np.where(occupied, data["metadata"], None), - } + for name, arr in data.items(): + if arr.dtype == object: + fill_val = None + elif name == "index": + fill_val = -1 + elif np.issubdtype(arr.dtype, np.integer): + fill_val = 0 + else: # Floating-point and other fields. + fill_val = np.nan + + arr[unoccupied] = fill_val + + return occupied, data def retrieve_single(self, measures): """Retrieves the elite with measures in the same cell as the measures specified. While :meth:`retrieve` takes in a *batch* of measures, this method takes - in the measures for only *one* solution and returns a dict with single - entries. + in the measures for only *one* solution and returns a single bool and a + dict with single entries. Args: measures (array-like): (:attr:`measure_dim`,) array of measures. Returns: - If there is an elite with measures in the same cell as the measures - specified, then this method returns dict where all the fields hold - the info of the elite. Otherwise, this method returns a dict filled - with the same "empty" values described in :meth:`retrieve`. + tuple: If there is an elite with measures in the same cell as the + measures specified, then this method returns a True value and a dict + where all the fields hold the info of the elite. Otherwise, this + method returns a False value and a dict filled with the same "empty" + values described in :meth:`retrieve`. Raises: ValueError: ``measures`` is not of shape (:attr:`measure_dim`,). ValueError: ``measures`` has non-finite values (inf or NaN). @@ -610,10 +608,9 @@ def retrieve_single(self, measures): check_1d_shape(measures, "measures", self.measure_dim, "measure_dim") check_finite(measures, "measures") - return { - field: arr[0] - for field, arr in self.retrieve(measures[None]).items() - } + occupied, data = self.retrieve(measures[None]) + + return occupied[0], {field: arr[0] for field, arr in data.items()} def sample_elites(self, n): """Randomly samples elites from the archive. @@ -834,10 +831,8 @@ def cqd_score(self, penalties = np.copy(penalties) # Copy since we return this. check_is_1d(penalties, "penalties") - objective_batch, measures_batch = self._store.data( - ["objective", "measures"], - return_type="tuple", - ) + objective_batch = self._store.data("objective") + measures_batch = self._store.data("measures") norm_objectives = objective_batch / (obj_max - obj_min) diff --git a/tests/archives/archive_base_test.py b/tests/archives/archive_base_test.py index 84a50ef13..58b236a8c 100644 --- a/tests/archives/archive_base_test.py +++ b/tests/archives/archive_base_test.py @@ -340,19 +340,23 @@ def test_basic_stats(data): def test_retrieve_gets_correct_elite(data): - elites = data.archive_with_elite.retrieve([data.measures]) + occupied, elites = data.archive_with_elite.retrieve([data.measures]) + assert occupied[0] assert np.all(elites["solution"][0] == data.solution) assert elites["objective"][0] == data.objective assert np.all(elites["measures"][0] == data.measures) + assert elites["threshold"][0] == data.objective # Avoid checking elites["index"] since the meaning varies by archive. assert elites["metadata"][0] == data.metadata def test_retrieve_empty_values(data): - elites = data.archive.retrieve([data.measures]) + occupied, elites = data.archive.retrieve([data.measures]) + assert not occupied[0] assert np.all(np.isnan(elites["solution"][0])) assert np.isnan(elites["objective"]) assert np.all(np.isnan(elites["measures"][0])) + assert np.isnan(elites["threshold"]) assert elites["index"][0] == -1 assert elites["metadata"][0] is None @@ -363,19 +367,23 @@ def test_retrieve_wrong_shape(data): def test_retrieve_single_gets_correct_elite(data): - elite = data.archive_with_elite.retrieve_single(data.measures) + occupied, elite = data.archive_with_elite.retrieve_single(data.measures) + assert occupied assert np.all(elite["solution"] == data.solution) assert elite["objective"] == data.objective assert np.all(elite["measures"] == data.measures) + assert elite["threshold"] == data.objective # Avoid checking elite["index"] since the meaning varies by archive. assert elite["metadata"] == data.metadata def test_retrieve_single_empty_values(data): - elite = data.archive.retrieve_single(data.measures) + occupied, elite = data.archive.retrieve_single(data.measures) + assert not occupied assert np.all(np.isnan(elite["solution"])) assert np.isnan(elite["objective"]) assert np.all(np.isnan(elite["measures"])) + assert np.isnan(elite["threshold"]) assert elite["index"] == -1 assert elite["metadata"] is None diff --git a/tutorials/arm_repertoire.ipynb b/tutorials/arm_repertoire.ipynb index 07f0de50b..ffa80dc39 100644 --- a/tutorials/arm_repertoire.ipynb +++ b/tutorials/arm_repertoire.ipynb @@ -449,9 +449,10 @@ } ], "source": [ - "elite = archive.retrieve_single([0, 0])\n", + "occupied, elite = archive.retrieve_single([0, 0])\n", "_, ax = plt.subplots()\n", - "if elite[\"solution\"] is not None: # This is None if there is no solution for [0,0].\n", + "# `occupied` indicates if there was an elite in the corresponding cell.\n", + "if occupied:\n", " visualize(elite[\"solution\"], link_lengths, elite[\"objective\"], ax)" ] }, diff --git a/tutorials/lunar_lander.ipynb b/tutorials/lunar_lander.ipynb index 55cc081cb..f58597440 100644 --- a/tutorials/lunar_lander.ipynb +++ b/tutorials/lunar_lander.ipynb @@ -746,7 +746,6 @@ "id": "t2QPnuqgFKhr" }, "source": [ - "\n", "We can retrieve policies with measures that are close to a query with the [`retrieve_single`](https://docs.pyribs.org/en/latest/api/ribs.archives.GridArchive.html#ribs.archives.GridArchive.retrieve_single) method. This method will look up the cell corresponding to the queried measures. Then, the method will check if there is an elite in that cell, and return the elite if it exists (the method does not check neighboring cells for elites). The returned elite may not have the exact measures requested because the elite only has to be in the same cell as the queried measures.\n", "\n", "Below, we first retrieve a policy that impacted the ground on the left (approximately -0.4) with low velocity (approximately -0.10) by querying for `[-0.4, -0.10]`." @@ -789,10 +788,9 @@ } ], "source": [ - "elite = archive.retrieve_single([-0.4, -0.10])\n", - "# NaN objective indicates the solution could not be retrieved because there was\n", - "# no elite in the corresponding cell.\n", - "if not np.isnan(elite[\"objective\"]):\n", + "occupied, elite = archive.retrieve_single([-0.4, -0.10])\n", + "# `occupied` indicates if there was an elite in the corresponding cell.\n", + "if occupied:\n", " print(f\"Objective: {elite['objective']}\")\n", " print(f\"Measures: (x-pos: {elite['measures'][0]}, y-vel: {elite['measures'][1]})\")\n", " display_video(elite[\"solution\"])" @@ -848,8 +846,8 @@ } ], "source": [ - "elite = archive.retrieve_single([0.6, -0.10])\n", - "if not np.isnan(elite[\"objective\"]):\n", + "occupied, elite = archive.retrieve_single([0.6, -0.10])\n", + "if occupied:\n", " print(f\"Objective: {elite['objective']}\")\n", " print(f\"Measures: (x-pos: {elite['measures'][0]}, y-vel: {elite['measures'][1]})\")\n", " display_video(elite[\"solution\"])" @@ -901,8 +899,8 @@ } ], "source": [ - "elite = archive.retrieve_single([0.0, -0.10])\n", - "if not np.isnan(elite[\"objective\"]):\n", + "occupied, elite = archive.retrieve_single([0.0, -0.10])\n", + "if occupied:\n", " print(f\"Objective: {elite['objective']}\")\n", " print(f\"Measures: (x-pos: {elite['measures'][0]}, y-vel: {elite['measures'][1]})\")\n", " display_video(elite[\"solution\"])"