Skip to content

Commit

Permalink
Update in other emitters
Browse files Browse the repository at this point in the history
  • Loading branch information
btjanaka committed Nov 27, 2023
1 parent eee3176 commit 91dfebd
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 237 deletions.
169 changes: 56 additions & 113 deletions ribs/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def check_solution_batch_dim(array,
f"shape {array.shape}.{extra_msg}")


def validate_batch(archive, data):
def validate_batch(archive, data, add_info=None, jacobian=None):
"""Preprocesses and validates batch arguments.
``data`` is a dict containing arrays with the data of each solution, e.g.,
Expand All @@ -107,10 +107,14 @@ def validate_batch(archive, data):
data["solution"] = np.asarray(data["solution"])
check_batch_shape(data["solution"], "solution", archive.solution_dim,
"solution_dim", "")

# Process and validate the other batch arguments.
batch_size = data["solution"].shape[0]

# Process and validate the other data.
for name, arr in data.items():
if name == "solution":
# Already checked above.
continue

if name == "objective":
arr = np.asarray(arr)
check_is_1d(arr, "objective", "")
Expand All @@ -132,32 +136,6 @@ def validate_batch(archive, data):
extra_msg="")
check_finite(arr, "measures")

elif name == "jacobian":
arr = np.asarray(arr)
check_batch_shape(arr, "jacobian",
(archive.measure_dim + 1, archive.solution_dim),
"measure_dim + 1, solution_dim")
check_finite(arr, "jacobian")

elif name == "status":
arr = np.asarray(arr)
check_is_1d(arr, "status", "")
check_solution_batch_dim(arr,
"status",
batch_size,
is_1d=True,
extra_msg="")
check_finite(arr, "status")

elif name == "value":
arr = np.asarray(arr)
check_is_1d(arr, "value", "")
check_solution_batch_dim(arr,
"value",
batch_size,
is_1d=True,
extra_msg="")

else:
arr = np.asarray(arr)
check_solution_batch_dim(arr,
Expand All @@ -168,97 +146,62 @@ def validate_batch(archive, data):

data[name] = arr

return data
extra_returns = []

# add_info is optional; check it if provided.
if add_info is not None:
for name, arr in add_info.items():
if name == "status":
arr = np.asarray(arr)
check_is_1d(arr, "status", "")
check_solution_batch_dim(arr,
"status",
batch_size,
is_1d=True,
extra_msg="")
check_finite(arr, "status")

elif name == "value":
arr = np.asarray(arr)
check_is_1d(arr, "value", "")
check_solution_batch_dim(arr,
"value",
batch_size,
is_1d=True,
extra_msg="")

else:
arr = np.asarray(arr)
check_solution_batch_dim(arr,
name,
batch_size,
is_1d=False,
extra_msg="")

add_info[name] = arr

extra_returns.append(add_info)

# jacobian is optional; check it if provided.
if jacobian is not None:
jacobian = np.asarray(jacobian)
check_batch_shape(jacobian, "jacobian",
(archive.measure_dim + 1, archive.solution_dim),
"measure_dim + 1, solution_dim")
check_finite(jacobian, "jacobian")
extra_returns.append(jacobian)

if extra_returns:
return data, *extra_returns
else:
return data


_BATCH_WARNING = (" Note that starting in pyribs 0.5.0, add() and tell() take"
" in a batch of solutions unlike in pyribs 0.4.0, where add()"
" and tell() only took in a single solution.")


def validate_batch_args(archive, solution_batch, **batch_kwargs):
"""Preprocesses and validates batch arguments.
The batch size of each argument in batch_kwargs is validated with respect to
solution_batch.
The arguments are assumed to come directly from users, so they may not be
arrays. Thus, we preprocess each argument by converting it into a numpy
array. We then perform checks on the array, including seeing if its batch
size matches the batch size of solution_batch. The arguments are then
returned in the same order that they were passed into the kwargs, with
solution_batch coming first.
Note that we can guarantee the order is the same as when passed in due to
PEP 468 (https://peps.python.org/pep-0468/), which guarantees that kwargs
will preserve the same order as they are listed.
See the for loop for the list of supported kwargs.
"""
# List of args to return.
returns = []

# Process and validate solution_batch.
solution_batch = np.asarray(solution_batch)
check_batch_shape(solution_batch, "solution_batch", archive.solution_dim,
"solution_dim", _BATCH_WARNING)
returns.append(solution_batch)

# Process and validate the other batch arguments.
batch_size = solution_batch.shape[0]
for name, arg in batch_kwargs.items():
if name == "objective_batch":
objective_batch = np.asarray(arg)
check_is_1d(objective_batch, "objective_batch", _BATCH_WARNING)
check_solution_batch_dim(objective_batch,
"objective_batch",
batch_size,
is_1d=True,
extra_msg=_BATCH_WARNING)
check_finite(objective_batch, "objective_batch")
returns.append(objective_batch)
elif name == "measures_batch":
measures_batch = np.asarray(arg)
check_batch_shape(measures_batch, "measures_batch",
archive.measure_dim, "measure_dim",
_BATCH_WARNING)
check_solution_batch_dim(measures_batch,
"measures_batch",
batch_size,
is_1d=False,
extra_msg=_BATCH_WARNING)
check_finite(measures_batch, "measures_batch")
returns.append(measures_batch)
elif name == "jacobian_batch":
jacobian_batch = np.asarray(arg)
check_batch_shape(jacobian_batch, "jacobian_batch",
(archive.measure_dim + 1, archive.solution_dim),
"measure_dim + 1, solution_dim")
check_finite(jacobian_batch, "jacobian_batch")
returns.append(jacobian_batch)
elif name == "status_batch":
status_batch = np.asarray(arg)
check_is_1d(status_batch, "status_batch", _BATCH_WARNING)
check_solution_batch_dim(status_batch,
"status_batch",
batch_size,
is_1d=True,
extra_msg=_BATCH_WARNING)
check_finite(status_batch, "status_batch")
returns.append(status_batch)
elif name == "value_batch":
value_batch = np.asarray(arg)
check_is_1d(value_batch, "value_batch", _BATCH_WARNING)
check_solution_batch_dim(value_batch,
"value_batch",
batch_size,
is_1d=True,
extra_msg=_BATCH_WARNING)
returns.append(value_batch)

return returns


def validate_single(archive, data):
"""Performs preprocessing and checks for arguments to add_single()."""
data["solution"] = np.asarray(data["solution"])
Expand Down
51 changes: 22 additions & 29 deletions ribs/emitters/_evolution_strategy_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from ribs._utils import check_shape, validate_batch_args
from ribs._utils import check_shape, validate_batch
from ribs.emitters._emitter_base import EmitterBase
from ribs.emitters.opt import _get_es
from ribs.emitters.rankers import _get_ranker
Expand Down Expand Up @@ -183,8 +183,7 @@ def _check_restart(self, num_parents):
return False
raise ValueError(f"Invalid restart_rule {self._restart_rule}")

def tell(self, solution_batch, objective_batch, measures_batch,
status_batch, value_batch):
def tell(self, solution, objective, measures, status_batch, value_batch):
"""Gives the emitter results from evaluating solutions.
The solutions are ranked based on the `rank()` function defined by
Expand All @@ -195,34 +194,30 @@ def tell(self, solution_batch, objective_batch, measures_batch,
when needed.
Args:
solution_batch (array-like): (batch_size, :attr:`solution_dim`)
array of solutions generated by this emitter's :meth:`ask()`
method.
objective_batch (array-like): 1D array containing the objective
function value of each solution.
measures_batch (array-like): (batch_size, measure space
dimension) array with the measure space coordinates of each
solution.
solution (array-like): (batch_size, :attr:`solution_dim`) array of
solutions generated by this emitter's :meth:`ask()` method.
objective (array-like): 1D array containing the objective function
value of each solution.
measures (array-like): (batch_size, measure space dimension) array
with the measure space coordinates of each solution.
status_batch (array-like): 1D array of
:class:`ribs.archive.AddStatus` returned by a series of calls
to archive's :meth:`add()` method.
value_batch (array-like): 1D array of floats returned by a series
of calls to archive's :meth:`add()` method. For what these
floats represent, refer to :meth:`ribs.archives.add()`.
"""
(
solution_batch,
objective_batch,
measures_batch,
status_batch,
value_batch,
) = validate_batch_args(
archive=self.archive,
solution_batch=solution_batch,
objective_batch=objective_batch,
measures_batch=measures_batch,
status_batch=status_batch,
value_batch=value_batch,
data, add_info = validate_batch(
self.archive,
{
"solution": solution,
"objective": objective,
"measures": measures,
},
{
"status": status_batch,
"value": value_batch,
},
)

# Increase iteration counter.
Expand All @@ -232,11 +227,9 @@ def tell(self, solution_batch, objective_batch, measures_batch,
new_sols = status_batch.astype(bool).sum()

# Sort the solutions using ranker.
indices, ranking_values = self._ranker.rank(self, self.archive,
self._rng, solution_batch,
objective_batch,
measures_batch,
status_batch, value_batch)
indices, ranking_values = self._ranker.rank(
self, self.archive, self._rng, data["solution"], data["objective"],
data["measures"], add_info["status"], add_info["value"])

# Select the number of parents.
num_parents = (new_sols if self._selection_rule == "filter" else
Expand Down
Loading

0 comments on commit 91dfebd

Please sign in to comment.