Skip to content

Commit

Permalink
Support custom data fields in scheduler and emitters (#429)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

This PR makes it possible to pass in custom data to schedulers and
emitters. For instance, one can call

```python
scheduler.tell(objective, measures, my_data=[1,2,3], my_other_data=[4,5,6])
```

The same applies to `tell_dqd`, and also to the emitters' `tell` and `tell_dqd` methods.

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Update scheduler signatures
- [x] Update emitter signatures
- [x] Miscellaneous edits to ArrayStore and archives
- [x] Write scheduler test

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in
      [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog in
      `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka authored Nov 27, 2023
1 parent fede789 commit 74a2602
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 41 deletions.
3 changes: 2 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

#### API

- Support custom data fields in archive ({pr}`421`)
- Support custom data fields in archive, emitters, and scheduler ({pr}`421`,
{pr}`429`)
- **Backwards-incompatible:** Remove `_batch` from parameter names ({pr}`422`,
{pr}`424`, {pr}`425`, {pr}`426`, {pr}`428`)
- Add Gaussian, IsoLine Operators and Refactor GaussianEmitter/IsoLineEmitter
Expand Down
14 changes: 7 additions & 7 deletions ribs/archives/_archive_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def add(self, solution, objective, measures, **fields):
ValueError: ``objective`` or ``measures`` has non-finite values (inf
or NaN).
"""
new_data = validate_batch(
data = validate_batch(
self,
{
"solution": solution,
Expand All @@ -418,8 +418,8 @@ def add(self, solution, objective, measures, **fields):
)

add_info = self._store.add(
self.index_of(new_data["measures"]),
new_data,
self.index_of(data["measures"]),
data,
{
"dtype": self._dtype,
"learning_rate": self._learning_rate,
Expand Down Expand Up @@ -470,7 +470,7 @@ def add_single(self, solution, objective, measures, **fields):
ValueError: ``objective`` is non-finite (inf or NaN) or ``measures``
has non-finite values.
"""
new_data = validate_single(
data = validate_single(
self,
{
"solution": solution,
Expand All @@ -480,12 +480,12 @@ def add_single(self, solution, objective, measures, **fields):
},
)

for name, arr in new_data.items():
new_data[name] = np.expand_dims(arr, axis=0)
for name, arr in data.items():
data[name] = np.expand_dims(arr, axis=0)

add_info = self._store.add(
np.expand_dims(self.index_of_single(measures), axis=0),
new_data,
data,
{
"dtype": self._dtype,
"learning_rate": self._learning_rate,
Expand Down
5 changes: 4 additions & 1 deletion ribs/archives/_array_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,10 @@ def transform(indices, new_data, add_info, extra_args,
if new_data.keys() != self._fields.keys():
raise ValueError(
f"`new_data` had keys {new_data.keys()} but should have the "
f"same keys as this ArrayStore, i.e., {self._fields.keys()}")
f"same keys as this ArrayStore, i.e., {self._fields.keys()}. "
"You may be seeing this error if your archive has "
"extra_fields but the fields were not passed into "
"archive.add() or scheduler.tell().")

# Update occupancy data.
unique_indices = np.where(aggregate(indices, 1, func="len") != 0)[0]
Expand Down
11 changes: 9 additions & 2 deletions ribs/emitters/_emitter_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def ask(self):
"""
return np.empty((0, self.solution_dim), dtype=self.archive.dtype)

def tell(self, solution, objective, measures, status_batch, value_batch):
def tell(self, solution, objective, measures, status_batch, value_batch,
**fields):
"""Gives the emitter results from evaluating solutions.
This base class implementation (in :class:`~ribs.emitters.EmitterBase`)
Expand All @@ -123,6 +124,9 @@ def tell(self, solution, objective, measures, status_batch, value_batch):
series of calls to archive's :meth:`add_single()` method or by a
single call to archive's :meth:`add()`. For what these floats
represent, refer to :meth:`ribs.archives.add()`.
fields (keyword arguments): Additional data for each solution. Each
argument should be an array with batch_size as the first
dimension.
"""

def ask_dqd(self):
Expand All @@ -135,7 +139,7 @@ def ask_dqd(self):
return np.empty((0, self.solution_dim), dtype=self.archive.dtype)

def tell_dqd(self, solution, objective, measures, jacobian, status_batch,
value_batch):
value_batch, **fields):
"""Gives the emitter results from evaluating the gradient of the
solutions, only used for DQD emitters.
Expand All @@ -158,4 +162,7 @@ def tell_dqd(self, solution, objective, measures, jacobian, status_batch,
value_batch (numpy.ndarray): 1d array of floats returned by a series
of calls to archive's :meth:`add()` method. for what these
floats represent, refer to :meth:`ribs.archives.add()`.
fields (keyword arguments): Additional data for each solution. Each
argument should be an array with batch_size as the first
dimension.
"""
7 changes: 6 additions & 1 deletion ribs/emitters/_evolution_strategy_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def _check_restart(self, num_parents):
return False
raise ValueError(f"Invalid restart_rule {self._restart_rule}")

def tell(self, solution, objective, measures, status_batch, value_batch):
def tell(self, solution, objective, measures, status_batch, value_batch,
**fields):
"""Gives the emitter results from evaluating solutions.
The solutions are ranked based on the `rank()` function defined by
Expand All @@ -206,13 +207,17 @@ def tell(self, solution, objective, measures, status_batch, value_batch):
value_batch (array-like): 1D array of floats returned by a series
of calls to archive's :meth:`add()` method. For what these
floats represent, refer to :meth:`ribs.archives.add()`.
fields (keyword arguments): Additional data for each solution. Each
argument should be an array with batch_size as the first
dimension.
"""
data, add_info = validate_batch(
self.archive,
{
"solution": solution,
"objective": objective,
"measures": measures,
**fields,
},
{
"status": status_batch,
Expand Down
13 changes: 11 additions & 2 deletions ribs/emitters/_gradient_arborescence_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def _check_restart(self, num_parents):
raise ValueError(f"Invalid restart_rule {self._restart_rule}")

def tell_dqd(self, solution, objective, measures, jacobian, status_batch,
value_batch):
value_batch, **fields):
"""Gives the emitter results from evaluating the gradient of the
solutions.
Expand All @@ -334,13 +334,17 @@ def tell_dqd(self, solution, objective, measures, jacobian, status_batch,
value_batch (array-like): 1d array of floats returned by a series
of calls to archive's :meth:`add()` method. for what these
floats represent, refer to :meth:`ribs.archives.add()`.
fields (keyword arguments): Additional data for each solution. Each
argument should be an array with batch_size as the first
dimension.
"""
data, add_info, jacobian = validate_batch( # pylint: disable = unused-variable
self.archive,
{
"solution": solution,
"objective": objective,
"measures": measures,
**fields,
},
{
"status": status_batch,
Expand All @@ -355,7 +359,8 @@ def tell_dqd(self, solution, objective, measures, jacobian, status_batch,
jacobian /= norms
self._jacobian_batch = jacobian

def tell(self, solution, objective, measures, status_batch, value_batch):
def tell(self, solution, objective, measures, status_batch, value_batch,
**fields):
"""Gives the emitter results from evaluating solutions.
The solutions are ranked based on the `rank()` function defined by
Expand All @@ -374,6 +379,9 @@ def tell(self, solution, objective, measures, status_batch, value_batch):
value_batch (array-like): 1d array of floats returned by a series
of calls to archive's :meth:`add()` method. for what these
floats represent, refer to :meth:`ribs.archives.add()`.
fields (keyword arguments): Additional data for each solution. Each
argument should be an array with batch_size as the first
dimension.
Raises:
RuntimeError: This method was called without first passing gradients
with calls to ask_dqd() and tell_dqd().
Expand All @@ -384,6 +392,7 @@ def tell(self, solution, objective, measures, status_batch, value_batch):
"solution": solution,
"objective": objective,
"measures": measures,
**fields,
},
{
"status": status_batch,
Expand Down
6 changes: 5 additions & 1 deletion ribs/emitters/_gradient_operator_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def ask(self):
return sols

def tell_dqd(self, solution, objective, measures, jacobian, status_batch,
value_batch):
value_batch, **fields):
"""Gives the emitter results of evaluating solutions from ask_dqd().
Args:
Expand All @@ -310,13 +310,17 @@ def tell_dqd(self, solution, objective, measures, jacobian, status_batch,
value_batch (array-like): 1d array of floats returned by a series
of calls to archive's :meth:`add()` method. for what these
floats represent, refer to :meth:`ribs.archives.add()`.
fields (keyword arguments): Additional data for each solution. Each
argument should be an array with batch_size as the first
dimension.
"""
data, add_info, jacobian = validate_batch( # pylint: disable = unused-variable
self.archive,
{
"solution": solution,
"objective": objective,
"measures": measures,
**fields,
},
{
"status": status_batch,
Expand Down
24 changes: 16 additions & 8 deletions ribs/schedulers/_bandit_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,22 +305,25 @@ def tell_dqd(self, objective, measures, jacobian):
raise NotImplementedError("tell_dqd() is not supported by"
"BanditScheduler.")

def tell(self, objective, measures):
def tell(self, objective, measures, **fields):
"""Returns info for solutions from :meth:`ask`.
The emitters are the same with those used in the last call to
:meth:`ask`.
.. note:: The objective batch and measures batch must be in the same
order as the solutions created by :meth:`ask_dqd`; i.e.
``objective_batch[i]`` and ``measures_batch[i]`` should be the
objective and measures for ``solution_batch[i]``.
.. note:: The objective and measures arrays must be in the same order as
the solutions created by :meth:`ask_dqd`; i.e. ``objective[i]`` and
``measures[i]`` should be the objective and measures for
``solution[i]``.
Args:
objective_batch ((batch_size,) array): Each entry of this array
contains the objective function evaluation of a solution.
measures_batch ((batch_size, measures_dm) array): Each row of
this array contains a solution's coordinates in measure space.
fields (keyword arguments): Additional data for each solution. Each
argument should be an array with batch_size as the first
dimension.
Raises:
RuntimeError: This method is called without first calling
:meth:`ask`.
Expand All @@ -333,6 +336,7 @@ def tell(self, objective, measures):
data = self._validate_tell_data({
"objective": objective,
"measures": measures,
**fields,
})

archive_empty_before = self.archive.empty
Expand Down Expand Up @@ -382,7 +386,11 @@ def tell(self, objective, measures):
end = pos + n
self._selection[i] += n
self._success[i] += np.count_nonzero(status_batch[pos:end])
emitter.tell(self._cur_solutions[pos:end],
data["objective"][pos:end], data["measures"][pos:end],
status_batch[pos:end], value_batch[pos:end])
emitter.tell(
**{
name: arr[pos:end] for name, arr in data.items()
},
status_batch=status_batch[pos:end],
value_batch=value_batch[pos:end],
)
pos = end
50 changes: 33 additions & 17 deletions ribs/schedulers/_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,13 @@ def _add_to_archives(self, data):

return status_batch, value_batch

def tell_dqd(self, objective, measures, jacobian):
def tell_dqd(self, objective, measures, jacobian, **fields):
"""Returns info for solutions from :meth:`ask_dqd`.
.. note:: The objective batch, measures batch, and jacobian batch must
be in the same order as the solutions created by :meth:`ask_dqd`;
i.e. ``objective[i]``, ``measures[i]``, and ``jacobian[i]`` should
be the objective, measures, and jacobian for ``solution[i]``.
.. note:: The objective, measures, and jacobian arrays must be in the
same order as the solutions created by :meth:`ask_dqd`; i.e.
``objective[i]``, ``measures[i]``, and ``jacobian[i]`` should be the
objective, measures, and jacobian for ``solution[i]``.
Args:
objective ((batch_size,) array): Each entry of this array contains
Expand All @@ -282,6 +282,9 @@ def tell_dqd(self, objective, measures, jacobian):
solutions obtained from :meth:`ask_dqd`. Each matrix should
consist of the objective gradient of the solution followed by
the measure gradients.
fields (keyword arguments): Additional data for each solution. Each
argument should be an array with batch_size as the first
dimension.
Raises:
RuntimeError: This method is called without first calling
:meth:`ask`.
Expand All @@ -295,6 +298,7 @@ def tell_dqd(self, objective, measures, jacobian):
data = self._validate_tell_data({
"objective": objective,
"measures": measures,
**fields,
})

jacobian = np.asarray(jacobian)
Expand All @@ -306,25 +310,32 @@ def tell_dqd(self, objective, measures, jacobian):
pos = 0
for emitter, n in zip(self._emitters, self._num_emitted):
end = pos + n
emitter.tell_dqd(self._cur_solutions[pos:end],
data["objective"][pos:end],
data["measures"][pos:end], jacobian[pos:end],
status_batch[pos:end], value_batch[pos:end])
emitter.tell_dqd(
**{
name: arr[pos:end] for name, arr in data.items()
},
jacobian=jacobian[pos:end],
status_batch=status_batch[pos:end],
value_batch=value_batch[pos:end],
)
pos = end

def tell(self, objective, measures):
def tell(self, objective, measures, **fields):
"""Returns info for solutions from :meth:`ask`.
.. note:: The objective batch and measures batch must be in the same
order as the solutions created by :meth:`ask_dqd`; i.e.
``objective[i]`` and ``measures[i]`` should be the objective and
measures for ``solution[i]``.
.. note:: The objective and measures arrays must be in the same order as
the solutions created by :meth:`ask_dqd`; i.e. ``objective[i]`` and
``measures[i]`` should be the objective and measures for
``solution[i]``.
Args:
objective ((batch_size,) array): Each entry of this array contains
the objective function evaluation of a solution.
measures ((batch_size, measures_dm) array): Each row of this array
contains a solution's coordinates in measure space.
fields (keyword arguments): Additional data for each solution. Each
argument should be an array with batch_size as the first
dimension.
Raises:
RuntimeError: This method is called without first calling
:meth:`ask`.
Expand All @@ -337,6 +348,7 @@ def tell(self, objective, measures):
data = self._validate_tell_data({
"objective": objective,
"measures": measures,
**fields,
})

status_batch, value_batch = self._add_to_archives(data)
Expand All @@ -345,7 +357,11 @@ def tell(self, objective, measures):
pos = 0
for emitter, n in zip(self._emitters, self._num_emitted):
end = pos + n
emitter.tell(self._cur_solutions[pos:end],
data["objective"][pos:end], data["measures"][pos:end],
status_batch[pos:end], value_batch[pos:end])
emitter.tell(
**{
name: arr[pos:end] for name, arr in data.items()
},
status_batch=status_batch[pos:end],
value_batch=value_batch[pos:end],
)
pos = end
7 changes: 6 additions & 1 deletion tests/archives/grid_archive_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def assert_archive_elites(
objective_batch=None,
measures_batch=None,
grid_indices_batch=None,
metadata_batch=None,
):
"""Asserts that the archive contains a batch of elites.
Expand Down Expand Up @@ -63,8 +64,12 @@ def assert_archive_elites(
index_match = (grid_indices_batch is None or
data["index"][j] == index_batch[i])

# Used for testing custom fields.
metadata_match = (metadata_batch is None or
data["metadata"][j] == metadata_batch[i])

if (solution_match and objective_match and measures_match and
index_match):
index_match and metadata_match):
archive_covered[j] = True

assert np.all(archive_covered)
Expand Down
Loading

0 comments on commit 74a2602

Please sign in to comment.