diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py
index 1128d961ff5..1562dc19b53 100644
--- a/ax/benchmark/runners/surrogate.py
+++ b/ax/benchmark/runners/surrogate.py
@@ -11,9 +11,11 @@
 
 import torch
 from ax.benchmark.runners.base import BenchmarkRunner
+from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
 from ax.core.base_trial import BaseTrial, TrialStatus
 from ax.core.observation import ObservationFeatures
 from ax.core.search_space import SearchSpaceDigest
+from ax.core.types import TParamValue
 from ax.modelbridge.torch import TorchModelBridge
 from ax.utils.common.base import Base
 from ax.utils.common.equality import equality_typechecker
@@ -22,6 +24,96 @@
 from torch import Tensor
 
 
+@dataclass(kw_only=True)
+class SurrogateTestFunction(ParamBasedTestProblem):
+    """
+    Data-generating function for surrogate benchmark problems.
+
+    Args:
+        num_objectives: The number of objectives. Defaults to 1.
+        name: The name of the test function.
+        outcome_names: Names of outcomes to return in `evaluate_true`, if the
+            surrogate produces more outcomes than are needed. If None, all
+            outcomes are returned.
+        _surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
+            for generating observations. If `None`, `get_surrogate_and_datasets`
+            must not be None and will be used to generate the surrogate when it
+            is needed.
+        _datasets: Either `None`, or the `SupervisedDataset`s used to fit
+            the surrogate model. If `None`, `get_surrogate_and_datasets` must
+            not be None and will be used to generate the datasets when they are
+            needed.
+        get_surrogate_and_datasets: Function that returns the surrogate and
+            datasets, to allow for lazy construction. If
+            `get_surrogate_and_datasets` is not provided, `surrogate` and
+            `datasets` must be provided, and vice versa.
+    """
+
+    name: str
+    num_objectives: int = 1
+    outcome_names: list[str] | None = None
+    _surrogate: TorchModelBridge | None = None
+    _datasets: list[SupervisedDataset] | None = None
+    get_surrogate_and_datasets: (
+        None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]]
+    ) = None
+
+    def __post_init__(self) -> None:
+        if self.get_surrogate_and_datasets is None and (
+            self._surrogate is None or self._datasets is None
+        ):
+            raise ValueError(
+                "If `get_surrogate_and_datasets` is None, `_surrogate` "
+                "and `_datasets` must not be None, and vice versa."
+            )
+        if (
+            self.outcome_names is not None
+            and len(self.outcome_names) != self.num_objectives
+        ):
+            raise ValueError(
+                f"Number of outcome names ({len(self.outcome_names)}) must match "
+                f"number of objectives ({self.num_objectives})."
+            )
+
+    def set_surrogate_and_datasets(self) -> None:
+        self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)()
+
+    @property
+    def surrogate(self) -> TorchModelBridge:
+        if self._surrogate is None:
+            self.set_surrogate_and_datasets()
+        return none_throws(self._surrogate)
+
+    @property
+    def datasets(self) -> list[SupervisedDataset]:
+        if self._datasets is None:
+            self.set_surrogate_and_datasets()
+        return none_throws(self._datasets)
+
+    def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
+        # We're ignoring the uncertainty predictions of the surrogate model here and
+        # using the mean predictions as the outcomes (before potentially adding noise)
+        means, _ = self.surrogate.predict(
+            # pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict
+            observation_features=[ObservationFeatures(params)]
+        )
+        names = list(means.keys()) if self.outcome_names is None else self.outcome_names
+        means = [means[name][0] for name in names]
+        return torch.tensor(
+            means,
+            device=self.surrogate.device,
+            dtype=self.surrogate.dtype,
+        )
+
+    @equality_typechecker
+    def __eq__(self, other: Base) -> bool:
+        if type(other) is not type(self):
+            return False
+
+        # Don't check surrogate, datasets, or callable
+        return self.name == other.name
+
+
 @dataclass
 class SurrogateRunner(BenchmarkRunner):
     """Runner for surrogate benchmark problems.
diff --git a/ax/benchmark/tests/runners/test_botorch_test_problem.py b/ax/benchmark/tests/runners/test_botorch_test_problem.py
index fa2d9555c29..ac7e593d4e8 100644
--- a/ax/benchmark/tests/runners/test_botorch_test_problem.py
+++ b/ax/benchmark/tests/runners/test_botorch_test_problem.py
@@ -7,9 +7,10 @@
 
 # pyre-strict
 
+from contextlib import nullcontext
 from dataclasses import replace
 from itertools import product
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import numpy as np
 
@@ -19,13 +20,17 @@
     BoTorchTestProblem,
     ParamBasedTestProblemRunner,
 )
+from ax.benchmark.runners.surrogate import SurrogateTestFunction
 from ax.core.arm import Arm
 from ax.core.base_trial import TrialStatus
 from ax.core.trial import Trial
 from ax.exceptions.core import UnsupportedError
 from ax.utils.common.testutils import TestCase
 from ax.utils.common.typeutils import checked_cast
-from ax.utils.testing.benchmark_stubs import TestParamBasedTestProblem
+from ax.utils.testing.benchmark_stubs import (
+    get_soo_surrogate_test_function,
+    TestParamBasedTestProblem,
+)
 from botorch.test_functions.synthetic import Ackley, ConstrainedHartmann, Hartmann
 from botorch.utils.transforms import normalize
 
@@ -87,7 +92,13 @@ def test_synthetic_runner(self) -> None:
             )
             for num_objectives, noise_std in product((1, 2), (None, 0.0, 1.0))
         ]
 
-        for test_problem, noise_std in botorch_cases + param_based_cases:
+        surrogate_cases = [
+            (get_soo_surrogate_test_function(lazy=False), noise_std)
+            for noise_std in (None, 0.0, 1.0)
+        ]
+        for test_problem, noise_std in (
+            botorch_cases + param_based_cases + surrogate_cases
+        ):
             num_objectives = test_problem.num_objectives
             outcome_names = [f"objective_{i}" for i in range(num_objectives)]
@@ -148,7 +159,20 @@ def test_synthetic_runner(self) -> None:
                 )
                 params = dict(zip(param_names, (x.item() for x in X.unbind(-1))))
 
-                Y = runner.get_Y_true(params=params)
+                with (
+                    nullcontext()
+                    if not isinstance(test_problem, SurrogateTestFunction)
+                    else patch.object(
+                        # pyre-fixme: `ParamBasedTestProblem` has no attribute
+                        # `_surrogate`.
+                        runner.test_problem._surrogate,
+                        "predict",
+                        return_value=({"objective_0": [4.2]}, None),
+                    )
+                ):
+                    Y = runner.get_Y_true(params=params)
+                    oracle = runner.evaluate_oracle(parameters=params)
+
                 if (
                     isinstance(test_problem, BoTorchTestProblem)
                     and test_problem.modified_bounds is not None
@@ -176,12 +200,13 @@ def test_synthetic_runner(self) -> None:
                     )
                     else:
                         expected_Y = obj
+                elif isinstance(test_problem, SurrogateTestFunction):
+                    expected_Y = torch.tensor([4.2], dtype=torch.double)
                 else:
                     expected_Y = torch.full(
                         torch.Size([2]), X.pow(2).sum().item(), dtype=torch.double
                     )
                 self.assertTrue(torch.allclose(Y, expected_Y))
-                oracle = runner.evaluate_oracle(parameters=params)
                 self.assertTrue(np.equal(Y.numpy(), oracle).all())
 
             with self.subTest(f"test `run()`, {test_description}"):
@@ -192,11 +217,28 @@ def test_synthetic_runner(self) -> None:
                 trial.arms = [arm]
                 trial.arm = arm
                 trial.index = 0
-                res = runner.run(trial=trial)
+
+                with (
+                    nullcontext()
+                    if not isinstance(test_problem, SurrogateTestFunction)
+                    else patch.object(
+                        # pyre-fixme: `ParamBasedTestProblem` has no attribute
+                        # `_surrogate`.
+                        runner.test_problem._surrogate,
+                        "predict",
+                        return_value=({"objective_0": [4.2]}, None),
+                    )
+                ):
+                    res = runner.run(trial=trial)
                 self.assertEqual({"Ys", "Ystds", "outcome_names"}, res.keys())
                 self.assertEqual({"0_0"}, res["Ys"].keys())
-                if noise_std is not None:
+                if isinstance(noise_std, float):
                     self.assertEqual(res["Ystds"]["0_0"], [noise_std] * len(Y))
+                elif isinstance(noise_std, dict):
+                    self.assertEqual(
+                        res["Ystds"]["0_0"],
+                        [noise_std[k] for k in runner.outcome_names],
+                    )
                 else:
                     self.assertEqual(res["Ys"]["0_0"], Y.tolist())
                     self.assertEqual(res["Ystds"]["0_0"], [0.0] * len(Y))
diff --git a/ax/benchmark/tests/runners/test_surrogate_runner.py b/ax/benchmark/tests/runners/test_surrogate_runner.py
index e0046c32433..b59c842e2c3 100644
--- a/ax/benchmark/tests/runners/test_surrogate_runner.py
+++ b/ax/benchmark/tests/runners/test_surrogate_runner.py
@@ -8,12 +8,83 @@
 from unittest.mock import MagicMock, patch
 
 import torch
-from ax.benchmark.runners.surrogate import SurrogateRunner
+from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
 from ax.core.parameter import ParameterType, RangeParameter
 from ax.core.search_space import SearchSpace
 from ax.modelbridge.torch import TorchModelBridge
 from ax.utils.common.testutils import TestCase
-from ax.utils.testing.benchmark_stubs import get_soo_surrogate
+from ax.utils.testing.benchmark_stubs import (
+    get_soo_surrogate_legacy,
+    get_soo_surrogate_test_function,
+)
+
+
+class TestSurrogateTestFunction(TestCase):
+    def test_surrogate_test_function(self) -> None:
+        # Construct a test function around a mocked surrogate model.
+        for noise_std in (0.0, 0.1, {"dummy_metric": 0.2}):
+            with self.subTest(noise_std=noise_std):
+                surrogate = MagicMock()
+                mock_mean = torch.tensor([[0.1234]], dtype=torch.double)
+                surrogate.predict = MagicMock(return_value=(mock_mean, 0))
+                surrogate.device = torch.device("cpu")
+                surrogate.dtype = torch.double
+                test_function = SurrogateTestFunction(
+                    name="test test function",
+                    num_objectives=1,
+                    _surrogate=surrogate,
+                    _datasets=[],
+                )
+                self.assertEqual(test_function.name, "test test function")
+                self.assertIs(test_function.surrogate, surrogate)
+                self.assertEqual(test_function.num_objectives, 1)
+
+    def test_lazy_instantiation(self) -> None:
+        test_function = get_soo_surrogate_test_function()
+
+        self.assertIsNone(test_function._surrogate)
+        self.assertIsNone(test_function._datasets)
+
+        # Accessing `surrogate` sets datasets and surrogate
+        self.assertIsInstance(test_function.surrogate, TorchModelBridge)
+        self.assertIsInstance(test_function._surrogate, TorchModelBridge)
+        self.assertIsInstance(test_function._datasets, list)
+
+        # Accessing `datasets` also sets datasets and surrogate
+        test_function = get_soo_surrogate_test_function()
+        self.assertIsInstance(test_function.datasets, list)
+        self.assertIsInstance(test_function._surrogate, TorchModelBridge)
+        self.assertIsInstance(test_function._datasets, list)
+
+        with patch.object(
+            test_function,
+            "get_surrogate_and_datasets",
+            wraps=test_function.get_surrogate_and_datasets,
+        ) as mock_get_surrogate_and_datasets:
+            test_function.surrogate
+            mock_get_surrogate_and_datasets.assert_not_called()
+
+    def test_instantiation_raises_with_missing_args(self) -> None:
+        with self.assertRaisesRegex(
+            ValueError, "If `get_surrogate_and_datasets` is None, `_surrogate` and "
+        ):
+            SurrogateTestFunction(name="test runner", num_objectives=1)
+
+    def test_equality(self) -> None:
+        def _construct_test_function(name: str) -> SurrogateTestFunction:
+            return SurrogateTestFunction(
+                name=name,
+                _surrogate=MagicMock(),
+                _datasets=[],
+                num_objectives=1,
+            )
+
+        runner_1 = _construct_test_function("test 1")
+        runner_2 = _construct_test_function("test 2")
+        runner_1a = _construct_test_function("test 1")
+        self.assertEqual(runner_1, runner_1a)
+        self.assertNotEqual(runner_1, runner_2)
+        self.assertNotEqual(runner_1, 1)
 
 
 class TestSurrogateRunner(TestCase):
@@ -49,7 +120,7 @@ def test_surrogate_runner(self) -> None:
             self.assertEqual(runner.noise_stds, noise_std)
 
     def test_lazy_instantiation(self) -> None:
-        runner = get_soo_surrogate().runner
+        runner = get_soo_surrogate_legacy().runner
 
         self.assertIsNone(runner._surrogate)
         self.assertIsNone(runner._datasets)
@@ -60,7 +131,7 @@ def test_lazy_instantiation(self) -> None:
         self.assertIsInstance(runner._datasets, list)
 
         # Accessing `datasets` also sets datasets and surrogate
-        runner = get_soo_surrogate().runner
+        runner = get_soo_surrogate_legacy().runner
         self.assertIsInstance(runner.datasets, list)
         self.assertIsInstance(runner._surrogate, TorchModelBridge)
         self.assertIsInstance(runner._datasets, list)
diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py
index 0bf7d49c4dd..6888a1efbaa 100644
--- a/ax/utils/testing/benchmark_stubs.py
+++ b/ax/utils/testing/benchmark_stubs.py
@@ -16,8 +16,11 @@
 from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch
 from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
 from ax.benchmark.problems.surrogate import SurrogateBenchmarkProblem
-from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
-from ax.benchmark.runners.surrogate import SurrogateRunner
+from ax.benchmark.runners.botorch_test import (
+    ParamBasedTestProblem,
+    ParamBasedTestProblemRunner,
+)
+from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
 from ax.core.experiment import Experiment
 from ax.core.objective import MultiObjective, Objective
 from ax.core.optimization_config import (
@@ -75,7 +78,58 @@ def get_multi_objective_benchmark_problem(
     )
 
 
-def get_soo_surrogate() -> SurrogateBenchmarkProblem:
+def get_soo_surrogate_test_function(lazy: bool = True) -> SurrogateTestFunction:
+    experiment = get_branin_experiment(with_completed_trial=True)
+    surrogate = TorchModelBridge(
+        experiment=experiment,
+        search_space=experiment.search_space,
+        model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
+        data=experiment.lookup_data(),
+        transforms=[],
+    )
+    if lazy:
+        test_function = SurrogateTestFunction(
+            num_objectives=1,
+            name="test",
+            get_surrogate_and_datasets=lambda: (surrogate, []),
+        )
+    else:
+        test_function = SurrogateTestFunction(
+            num_objectives=1,
+            name="test",
+            _surrogate=surrogate,
+            _datasets=[],
+        )
+    return test_function
+
+
+def get_soo_surrogate() -> BenchmarkProblem:
+    experiment = get_branin_experiment(with_completed_trial=True)
+    test_function = get_soo_surrogate_test_function()
+    runner = ParamBasedTestProblemRunner(
+        test_problem=test_function, outcome_names=["branin"]
+    )
+
+    observe_noise_sd = True
+    objective = Objective(
+        metric=BenchmarkMetric(
+            name="branin", lower_is_better=True, observe_noise_sd=observe_noise_sd
+        ),
+    )
+    optimization_config = OptimizationConfig(objective=objective)
+
+    return BenchmarkProblem(
+        name="test",
+        search_space=experiment.search_space,
+        optimization_config=optimization_config,
+        num_trials=6,
+        observe_noise_stds=observe_noise_sd,
+        optimal_value=0.0,
+        runner=runner,
+    )
+
+
+def get_soo_surrogate_legacy() -> SurrogateBenchmarkProblem:
     experiment = get_branin_experiment(with_completed_trial=True)
     surrogate = TorchModelBridge(
         experiment=experiment,