Skip to content

Commit

Permalink
Support single fields in ArrayStore.retrieve (#411)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

This adds a shortcut whereby a single array can be retrieved from an
ArrayStore. For instance, `occupied, objective =
store.retrieve("objective")`. This call also extends to the data method,
e.g., `objective = store.retrieve("objective")`.

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in

[CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog
in
      `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka authored Nov 9, 2023
1 parent e08c1cc commit 8e13081
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
2 changes: 1 addition & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
- **Backwards-incompatible:** Rename `measure_*` columns to `measures_*` in
`as_pandas` ({pr}`396`)
- Add ArrayStore data structure ({pr}`395`, {pr}`398`, {pr}`400`, {pr}`402`,
{pr}`403`, {pr}`404`, {pr}`406`, {pr}`407`)
{pr}`403`, {pr}`404`, {pr}`406`, {pr}`407`, {pr}`411`)
- Add GradientOperatorEmitter to support OMG-MEGA and OG-MAP-Elites ({pr}`348`)

#### Improvements
Expand Down
33 changes: 22 additions & 11 deletions ribs/archives/_array_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,12 @@ def retrieve(self, indices, fields=None, return_type="dict"):
Args:
indices (array-like): List of indices at which to collect data.
fields (array-like of str): List of fields to include. By default,
all fields will be included, with an additional "index" as the
last field ("index" can also be placed anywhere in this list).
fields (str or array-like of str): List of fields to include. By
default, all fields will be included, with an additional "index"
as the last field ("index" can also be placed anywhere in this
list). This can also be a single str indicating a field name.
return_type (str): Type of data to return. See the ``data`` returned
below.
below. Ignored if ``fields`` is a str.
Returns:
tuple: 2-element tuple consisting of:
Expand All @@ -235,8 +236,10 @@ def retrieve(self, indices, fields=None, return_type="dict"):
not occupied, then the 6.0 returned in the ``dict`` example below
should be ignored.
- **data**: The data at the given indices. This can take the
following forms, depending on the ``return_type`` argument:
- **data**: The data at the given indices. If ``fields`` was a
single str, this will just be an array holding data for the given
field. Otherwise, this data can take the following forms,
depending on the ``return_type`` argument:
- ``return_type="dict"``: Dict mapping from the field name to the
field data at the given indices. For instance, if we have an
Expand Down Expand Up @@ -296,18 +299,24 @@ def retrieve(self, indices, fields=None, return_type="dict"):
ValueError: Invalid field name provided.
ValueError: Invalid return_type provided.
"""
single_field = isinstance(fields, str)
indices = np.asarray(indices, dtype=np.int32)
occupied = self._props["occupied"][indices] # Induces copy.

if return_type in ("dict", "pandas"):
if single_field:
data = None
elif return_type in ("dict", "pandas"):
data = {}
elif return_type == "tuple":
data = []
else:
raise ValueError(f"Invalid return_type {return_type}.")

fields = (itertools.chain(self._fields, ["index"])
if fields is None else fields)
if single_field:
fields = [fields]
elif fields is None:
fields = itertools.chain(self._fields, ["index"])

for name in fields:
# Collect array data.
#
Expand All @@ -321,7 +330,9 @@ def retrieve(self, indices, fields=None, return_type="dict"):
raise ValueError(f"`{name}` is not a field in this ArrayStore.")

# Accumulate data into the return type.
if return_type == "dict":
if single_field:
data = arr
elif return_type == "dict":
data[name] = arr
elif return_type == "tuple":
data.append(arr)
Expand Down Expand Up @@ -351,7 +362,7 @@ def data(self, fields=None, return_type="dict"):
Equivalent to calling :meth:`retrieve` with :attr:`occupied_list`.
Args:
fields (array-like of str): See :meth:`retrieve`.
fields (str or array-like of str): See :meth:`retrieve`.
return_type (str): See :meth:`retrieve`.
Returns:
See ``data`` in :meth:`retrieve`. ``occupied`` is not returned since
Expand Down
18 changes: 18 additions & 0 deletions tests/archives/array_store_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,24 @@ def test_retrieve_custom_fields(store, return_type):
assert np.all(df["objective"] == [2.0, 1.0])


def test_retrieve_single_field(store):
store.add(
[3, 5],
{
"objective": [1.0, 2.0],
"measures": [[1.0, 2.0], [3.0, 4.0]],
"solution": [np.zeros(10), np.ones(10)],
},
{}, # Empty extra_args.
[], # Empty transforms.
)

occupied, data = store.retrieve([5, 3], fields="objective")

assert np.all(occupied == [True, True])
assert np.all(data == [2.0, 1.0])


def test_add_simple_transform(store):

def obj_meas(indices, new_data, add_info, extra_args, occupied, cur_data):
Expand Down

0 comments on commit 8e13081

Please sign in to comment.