From fefc406a71919d545302b44e7616a5d9469c1bc8 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Tue, 29 Sep 2020 11:29:50 -0700 Subject: [PATCH] `ListSurrogate` Summary: `Surrogate` for `ModelListGP`: constructs a `ModelListGP` with specified sub-model classes under the hood Usage example: ``` m = BoTorchModel( surrogate=ListSurrogate( # Dictionary of which model class should be used to model which outcome. botorch_model_class_per_outcome={"m": SingleTaskGP, "n": SingleTaskGP} ) ) ``` D23606005 then introduces default usage of this for cases where there are multiple X-s in Xs and they are not all the same (so we are not using batched multi-output). Reviewed By: Balandat Differential Revision: D23382535 fbshipit-source-id: ff524c658e09a8837f71646f309289cde4da8c33 --- ax/models/tests/test_botorch_model.py | 50 ++--- .../torch/botorch_modular/acquisition.py | 106 +++++++++-- .../torch/botorch_modular/list_surrogate.py | 142 ++++++++++++++ ax/models/torch/botorch_modular/surrogate.py | 34 ++-- ax/models/torch/tests/test_acquisition.py | 20 +- ax/models/torch/tests/test_list_surrogate.py | 177 ++++++++++++++++++ ax/models/torch/tests/test_model.py | 23 ++- ax/models/torch/tests/test_surrogate.py | 30 ++- ax/utils/common/constants.py | 2 + ax/utils/testing/torch_stubs.py | 29 +++ 10 files changed, 519 insertions(+), 94 deletions(-) create mode 100644 ax/models/torch/botorch_modular/list_surrogate.py create mode 100644 ax/models/torch/tests/test_list_surrogate.py create mode 100644 ax/utils/testing/torch_stubs.py diff --git a/ax/models/tests/test_botorch_model.py b/ax/models/tests/test_botorch_model.py index 7fe9f37de39..532bd675dd4 100644 --- a/ax/models/tests/test_botorch_model.py +++ b/ax/models/tests/test_botorch_model.py @@ -5,16 +5,18 @@ # LICENSE file in the root directory of this source tree. from itertools import chain -from typing import Dict from unittest import mock import torch from ax.models.torch.botorch import BotorchModel, get_rounding_func from ax.models.torch.botorch_defaults import ( get_and_fit_model, + get_chebyshev_scalarization, recommend_best_out_of_sample_point, ) +from ax.models.torch.utils import sample_simplex from ax.utils.common.testutils import TestCase +from ax.utils.testing.torch_stubs import get_torch_test_data from botorch.acquisition.utils import get_infeasible_cost from botorch.models import FixedNoiseGP, ModelListGP from botorch.utils import get_objective_weights_transform @@ -23,11 +25,11 @@ from gpytorch.priors.lkj_prior import LKJCovariancePrior -FIT_MODEL_MO_PATH = "ax.models.torch.botorch_defaults.fit_gpytorch_model" -SAMPLE_SIMPLEX_UTIL_PATH = "ax.models.torch.utils.sample_simplex" -SAMPLE_HYPERSPHERE_UTIL_PATH = "ax.models.torch.utils.sample_hypersphere" +FIT_MODEL_MO_PATH = f"{get_and_fit_model.__module__}.fit_gpytorch_model" +SAMPLE_SIMPLEX_UTIL_PATH = f"{sample_simplex.__module__}.sample_simplex" +SAMPLE_HYPERSPHERE_UTIL_PATH = f"{sample_simplex.__module__}.sample_hypersphere" CHEBYSHEV_SCALARIZATION_PATH = ( - "ax.models.torch.botorch_defaults.get_chebyshev_scalarization" + f"{get_chebyshev_scalarization.__module__}.get_chebyshev_scalarization" ) @@ -35,32 +37,12 @@ def dummy_func(X: torch.Tensor) -> torch.Tensor: return X -def _get_optimizer_kwargs() -> Dict[str, int]: - return {"num_restarts": 2, "raw_samples": 2, "maxiter": 2, "batch_limit": 1} - - -def _get_torch_test_data( - dtype=torch.float, cuda=False, constant_noise=True, task_features=None -): - device = torch.device("cuda") if cuda else torch.device("cpu") - Xs = [torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], dtype=dtype, device=device)] - Ys = [torch.tensor([[3.0], [4.0]], dtype=dtype, device=device)] - Yvars = [torch.tensor([[0.0], [2.0]], dtype=dtype, device=device)] - if constant_noise: - Yvars[0].fill_(1.0) - bounds = [(0.0, 1.0), (1.0, 4.0), (2.0, 5.0)] - feature_names = ["x1", "x2", "x3"] - task_features = [] if task_features is None else task_features - metric_names = ["y", "r"] - return Xs, Ys, Yvars, bounds, task_features, feature_names, metric_names - - class BotorchModelTest(TestCase): def test_fixed_rank_BotorchModel(self, dtype=torch.float, cuda=False): - Xs1, Ys1, Yvars1, bounds, _, fns, __package__ = _get_torch_test_data( + Xs1, Ys1, Yvars1, bounds, _, fns, __package__ = get_torch_test_data( dtype=dtype, cuda=cuda, constant_noise=True ) - Xs2, Ys2, Yvars2, _, _, _, _ = _get_torch_test_data( + Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data( dtype=dtype, cuda=cuda, constant_noise=True ) model = BotorchModel(multitask_gp_ranks={"y": 2, "w": 1}) @@ -84,10 +66,10 @@ def test_fixed_rank_BotorchModel(self, dtype=torch.float, cuda=False): self.assertEqual(model_list[1]._rank, 1) def test_fixed_prior_BotorchModel(self, dtype=torch.float, cuda=False): - Xs1, Ys1, Yvars1, bounds, _, fns, __package__ = _get_torch_test_data( + Xs1, Ys1, Yvars1, bounds, _, fns, __package__ = get_torch_test_data( dtype=dtype, cuda=cuda, constant_noise=True ) - Xs2, Ys2, Yvars2, _, _, _, _ = _get_torch_test_data( + Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data( dtype=dtype, cuda=cuda, constant_noise=True ) kwargs = { @@ -131,10 +113,10 @@ def test_fixed_prior_BotorchModel(self, dtype=torch.float, cuda=False): ) def test_BotorchModel(self, dtype=torch.float, cuda=False): - Xs1, Ys1, Yvars1, bounds, tfs, fns, mns = _get_torch_test_data( + Xs1, Ys1, Yvars1, bounds, tfs, fns, mns = get_torch_test_data( dtype=dtype, cuda=cuda, constant_noise=True ) - Xs2, Ys2, Yvars2, _, _, _, _ = _get_torch_test_data( + Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data( dtype=dtype, cuda=cuda, constant_noise=True ) model = BotorchModel() @@ -457,7 +439,7 @@ def test_BotorchModel_double_cuda(self): self.test_BotorchModel(dtype=torch.double, cuda=True) def test_BotorchModelOneOutcome(self): - Xs1, Ys1, Yvars1, bounds, tfs, fns, mns = _get_torch_test_data( + Xs1, Ys1, Yvars1, bounds, tfs, fns, mns = get_torch_test_data( dtype=torch.float, cuda=False, constant_noise=True ) model = BotorchModel() @@ -479,10 +461,10 @@ def test_BotorchModelOneOutcome(self): self.assertTrue(f_cov.shape == torch.Size([2, 1, 1])) def test_BotorchModelConstraints(self): - Xs1, Ys1, Yvars1, bounds, tfs, fns, mns = _get_torch_test_data( + Xs1, Ys1, Yvars1, bounds, tfs, fns, mns = get_torch_test_data( dtype=torch.float, cuda=False, constant_noise=True ) - Xs2, Ys2, Yvars2, _, _, _, _ = _get_torch_test_data( + Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data( dtype=torch.float, cuda=False, constant_noise=True ) # make infeasible diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index dadad678ef9..4090a19f3e3 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -6,9 +6,10 @@ from __future__ import annotations -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from ax.core.types import TConfig +from ax.models.torch.botorch_modular.list_surrogate import ListSurrogate from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.models.torch.utils import ( _get_X_pending_and_observed, @@ -18,9 +19,11 @@ from ax.utils.common.constants import Keys from ax.utils.common.docutils import copy_doc from ax.utils.common.equality import Base -from ax.utils.common.typeutils import not_none +from ax.utils.common.typeutils import checked_cast, not_none from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.analytic import AnalyticAcquisitionFunction +from botorch.acquisition.objective import AcquisitionObjective +from botorch.models.model import Model from botorch.optim.optimize import optimize_acqf from botorch.utils.containers import TrainingData from torch import Tensor @@ -75,6 +78,9 @@ class Acquisition(Base): # class by default. `None` for the base `Acquisition` class, but can be # specified in subclasses. default_botorch_acqf_class: Optional[Type[AcquisitionFunction]] = None + # BoTorch `AcquisitionFunction` class associated with this `Acquisition` + # instance. Determined during `__init__`, do not set manually. + _botorch_acqf_class: Type[AcquisitionFunction] def __init__( self, @@ -95,13 +101,15 @@ def __init__( "BoTorch `AcquisitionFunction`, so `botorch_acqf_class` " "argument must be specified." ) - botorch_acqf_class = not_none( + self._botorch_acqf_class = not_none( botorch_acqf_class or self.default_botorch_acqf_class ) self.surrogate = surrogate self.options = options or {} + trd = self._extract_training_data(surrogate=surrogate) + Xs = [trd.X] if isinstance(trd, TrainingData) else [i.X for i in trd.values()] X_pending, X_observed = _get_X_pending_and_observed( - Xs=[self.surrogate.training_data.X], + Xs=Xs, pending_observations=pending_observations, objective_weights=objective_weights, outcome_constraints=outcome_constraints, @@ -120,17 +128,12 @@ def __init__( else: model = self.surrogate.model - objective = get_botorch_objective( + objective = self._get_botorch_objective( model=model, objective_weights=objective_weights, outcome_constraints=outcome_constraints, X_observed=X_observed, - use_scalarized_objective=issubclass( - botorch_acqf_class, AnalyticAcquisitionFunction - ), ) - # NOTE: Computing model dependencies might be handled entirely on - # BoTorch side. model_deps = self.compute_model_dependencies( surrogate=surrogate, bounds=bounds, @@ -142,19 +145,16 @@ def __init__( target_fidelities=target_fidelities, options=self.options, ) - data_deps = self.compute_data_dependencies( - training_data=self.surrogate.training_data - ) # pyre-ignore[28]: Some kwargs are not expected in base `Model` # but are expected in its subclasses. - self.acqf = botorch_acqf_class( + self.acqf = self._botorch_acqf_class( model=model, objective=objective, X_pending=X_pending, X_baseline=X_observed, **self.options, **model_deps, - **data_deps, + **self.compute_data_dependencies(training_data=trd), ) def optimize( @@ -171,7 +171,7 @@ def optimize( candidates and their associated acquisition function values. """ optimizer_options = optimizer_options or {} - # TODO: make use of `optimizer_class` when its added to BoTorch. + # NOTE: Could make use of `optimizer_class` when it's added to BoTorch. return optimize_acqf( self.acqf, bounds=bounds, @@ -236,17 +236,85 @@ def compute_model_dependencies( """Computes inputs to acquisition function class based on the given surrogate model. - NOTE: May not be needed if model dependencies are handled entirely on - the BoTorch side. + NOTE: When subclassing `Acquisition` from a superclass where this + method returns a non-empty dictionary of kwargs to `AcquisitionFunction`, + call `super().compute_model_dependencies` and then update that + dictionary of options with the options for the subclass you are creating + (unless the superclass' model dependencies should not be propagated to + the subclass). See `MultiFidelityKnowledgeGradient.compute_model_dependencies` + for an example. + + Args: + surrogate: The surrogate object containing the BoTorch `Model`, + with which this `Acquisition` is to be used. + bounds: A list of (lower, upper) tuples for each column of X in + the training data of the surrogate model. + objective_weights: The objective is to maximize a weighted sum of + the columns of f(x). These are the weights. + pending_observations: A list of tensors, each of which contains + points whose evaluation is pending (i.e. that have been + submitted for evaluation) for a given outcome. A list + of m (k_i x d) feature tensors X for m outcomes and k_i, + pending observations for outcome i. + outcome_constraints: A tuple of (A, b). For k outcome constraints + and m outputs at f(x), A is (k x m) and b is (k x 1) such that + A f(x) <= b. (Not used by single task models) + linear_constraints: A tuple of (A, b). For k linear constraints on + d-dimensional x, A is (k x d) and b is (k x 1) such that + A x <= b. (Not used by single task models) + fixed_features: A map {feature_index: value} for features that + should be fixed to a particular value during generation. + target_fidelities: Optional mapping from parameter name to its + target fidelity, applicable to fidelity parameters only. + options: The `options` kwarg dict, passed on initialization of + the `Acquisition` object. + + Returns: A dictionary of surrogate model-dependent options, to be passed + as kwargs to BoTorch`AcquisitionFunction` constructor. """ return {} @classmethod - def compute_data_dependencies(cls, training_data: TrainingData) -> Dict[str, Any]: + def compute_data_dependencies( + cls, training_data: Union[TrainingData, Dict[str, TrainingData]] + ) -> Dict[str, Any]: """Computes inputs to acquisition function class based on the given data in model's training data. NOTE: May not be needed if model dependencies are handled entirely on the BoTorch side. + + Args: + training_data: Either a `TrainingData` for 1 outcome, or a mapping of + outcome name to respective `TrainingData` (if `ListSurrogate` is used). + + Returns: A dictionary of training data-dependent options, to be passed + as kwargs to BoTorch`AcquisitionFunction` constructor. """ return {} + + def _get_botorch_objective( + self, + model: Model, + objective_weights: Tensor, + outcome_constraints: Optional[Tuple[Tensor, Tensor]] = None, + X_observed: Optional[Tensor] = None, + ) -> AcquisitionObjective: + return get_botorch_objective( + model=model, + objective_weights=objective_weights, + use_scalarized_objective=issubclass( + self._botorch_acqf_class, AnalyticAcquisitionFunction + ), + outcome_constraints=outcome_constraints, + X_observed=X_observed, + ) + + @classmethod + def _extract_training_data( + cls, surrogate: Surrogate + ) -> Union[TrainingData, Dict[str, TrainingData]]: + if isinstance(surrogate, ListSurrogate): + return checked_cast(dict, surrogate.training_data_per_outcome) + else: + return checked_cast(TrainingData, surrogate.training_data) diff --git a/ax/models/torch/botorch_modular/list_surrogate.py b/ax/models/torch/botorch_modular/list_surrogate.py new file mode 100644 index 00000000000..990ad120030 --- /dev/null +++ b/ax/models/torch/botorch_modular/list_surrogate.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from ax.core.types import TCandidateMetadata +from ax.models.torch.botorch_modular.surrogate import NOT_YET_FIT_MSG, Surrogate +from ax.utils.common.constants import Keys +from ax.utils.common.typeutils import not_none +from botorch.models.model import Model, TrainingData +from botorch.models.model_list_gp_regression import ModelListGP +from gpytorch.kernels import Kernel +from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood +from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood + + +class ListSurrogate(Surrogate): + """Special type of `Surrogate` that wraps a set of submodels into + `ModelListGP` under the hood for multi-outcome or multi-task + models. + + Args: + botorch_model_class_per_outcome: Mapping from metric name to + BoTorch model class that should be used as surrogate model for + that metric. + submodel_options_per_outcome: Optional mapping from metric name to + dictionary of kwargs for the submodel for that outcome. + mll_class: `MarginalLogLikelihood` class to use for model-fitting. + """ + + botorch_model_class_per_outcome: Dict[str, Type[Model]] + mll_class: Type[MarginalLogLikelihood] + kernel_class: Optional[Type[Kernel]] = None + _training_data_per_outcome: Optional[Dict[str, TrainingData]] = None + _model: Optional[Model] = None + # Special setting for surrogates instantiated via `Surrogate.from_BoTorch`, + # to avoid re-constructing the underlying BoTorch model on `Surrogate.fit` + # when set to `False`. + _should_reconstruct: bool = True + + def __init__( + self, + botorch_model_class_per_outcome: Dict[str, Type[Model]], + submodel_options_per_outcome: Optional[Dict[str, Dict[str, Any]]] = None, + mll_class: Type[MarginalLogLikelihood] = SumMarginalLogLikelihood, + ) -> None: + self.botorch_model_class_per_outcome = botorch_model_class_per_outcome + self.submodel_options_per_outcome = submodel_options_per_outcome + super().__init__(botorch_model_class=ModelListGP, mll_class=mll_class) + + @property + def training_data(self) -> TrainingData: + raise NotImplementedError( + f"{self.__class__.__name__} doesn't implement `training_data`, " + "use `training_data_per_outcome`." + ) + + @property + def training_data_per_outcome(self) -> Dict[str, TrainingData]: + if self._training_data_per_outcome is None: + raise ValueError(NOT_YET_FIT_MSG) + return not_none(self._training_data_per_outcome) + + # pyre-ignore[14]: `construct` takes in list of training data in list surrogate, + # whereas it takes just a single training data in base surrogate. + def construct(self, training_data: List[TrainingData], **kwargs: Any) -> None: + """Constructs the underlying BoTorch `Model` using the training data. + + Args: + training_data: List of Training data for the submodels of `ModelListGP`. + Each training data is for one outcome, and the order of outcomes + should match the order of metrics in `metric_names` argument. + **kwargs: Keyword arguments, expects all of: + - `fidelity_features`: Indices of columns in X that represent + fidelity. + - `task_features`: Indices of columns in X that represent tasks. + - `metric_names`: Names of metrics, in the same order as training + data (so it training data is `[tr_A, tr_B]`, the metrics would be + `["A" and "B"]`). These are used to match training data with correct + submodels of `ModelListGP`. + """ + metric_names = kwargs.get(Keys.METRIC_NAMES) + fidelity_features = kwargs.get(Keys.FIDELITY_FEATURES) + task_features = kwargs.get(Keys.TASK_FEATURES) + if metric_names is None: + raise ValueError("Metric names are required.") + + self._training_data_per_outcome = { + metric_name: tr for metric_name, tr in zip(metric_names, training_data) + } + submodel_options = self.submodel_options_per_outcome or {} + submodels = [] + + for metric_name, model_cls in self.botorch_model_class_per_outcome.items(): + if metric_name not in self.training_data_per_outcome: + continue # pragma: no cover + tr = self.training_data_per_outcome[metric_name] + formatted_model_inputs = model_cls.construct_inputs( + training_data=tr, + fidelity_features=fidelity_features or [], + task_features=task_features or [], + ) + kwargs = submodel_options.get(metric_name, {}) + # pyre-ignore[45]: Py raises informative msg if `model_cls` abstract. + submodels.append(model_cls(**formatted_model_inputs, **kwargs)) + self._model = ModelListGP(*submodels) + + # pyre-ignore[14]: `fit` takes in list of training data in list surrogate, + # whereas it takes just a single training data in base surrogate. + def fit( + self, + training_data: List[TrainingData], + bounds: List[Tuple[float, float]], + task_features: List[int], + feature_names: List[str], + metric_names: List[str], + fidelity_features: List[int], + target_fidelities: Optional[Dict[int, float]] = None, + candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None, + state_dict: Optional[Dict[str, torch.Tensor]] = None, + refit: bool = True, + ) -> None: + super().fit( + # pyre-ignore[6]: `Surrogate.fit` expects single training data + # and in `ListSurrogate` we use a list of training data. + training_data=training_data, + bounds=bounds, + task_features=task_features, + feature_names=feature_names, + metric_names=metric_names, + fidelity_features=fidelity_features, + target_fidelities=target_fidelities, + candidate_metadata=candidate_metadata, + state_dict=state_dict, + refit=refit, + ) diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 029632e5142..850bb13ba28 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -6,7 +6,6 @@ from __future__ import annotations -from inspect import isabstract from typing import Any, Dict, List, Optional, Tuple, Type import torch @@ -30,6 +29,12 @@ from torch import Tensor +NOT_YET_FIT_MSG = ( + "Underlying BoTorch `Model` has not yet received its training_data." + "Please fit the model first." +) + + class Surrogate(Base): """ **All classes in 'botorch_modular' directory are under @@ -73,13 +78,6 @@ def __init__( ) -> None: self.botorch_model_class = botorch_model_class self.mll_class = mll_class - # NOTE: assuming here that plugging in kernels will be made easier on the - # BoTorch side (for v0 can just always raise `NotImplementedError` if - # `kernel` kwarg is not None). - if kernel_class: - self.kernel_class = kernel_class - # NOTE: `validate_kernel_class` to be implemented on BoTorch `Model`. - # self.botorch_model_class.validate_kernel_class(kernel_class) # Temporary validation while we develop these customizations. if likelihood is not None: @@ -96,11 +94,15 @@ def model(self) -> Model: @property def training_data(self) -> TrainingData: if self._training_data is None: - raise ValueError( - "Underlying BoTorch `Model` has not yet received its training_data." - ) + raise ValueError(NOT_YET_FIT_MSG) return not_none(self._training_data) + @property + def training_data_per_outcome(self) -> Dict[str, TrainingData]: + raise NotImplementedError( # pragma: no cover + "`training_data_per_outcome` is only used in `ListSurrogate`." + ) + @property def dtype(self) -> torch.dtype: return self.training_data.X.dtype @@ -137,8 +139,6 @@ def construct(self, training_data: TrainingData, **kwargs: Any) -> None: - "fidelity_features": Indices of columns in X that represent fidelity. """ - if isabstract(self.botorch_model_class): - raise TypeError("Cannot construct an abstract model.") if not isinstance(training_data, TrainingData): raise ValueError( # pragma: no cover "Base `Surrogate` expects training data for single outcome." @@ -150,7 +150,7 @@ def construct(self, training_data: TrainingData, **kwargs: Any) -> None: training_data=self.training_data, fidelity_features=kwargs.get(Keys.FIDELITY_FEATURES), ) - # pyre-ignore[45]: Model isn't abstract per the check above. + # pyre-ignore[45]: Py raises informative msg if `model_cls` abstract. self._model = self.botorch_model_class(**formatted_model_inputs) def fit( @@ -168,7 +168,11 @@ def fit( ) -> None: if self._model is None or self._should_reconstruct: self.construct( - training_data=training_data, fidelity_features=fidelity_features + training_data=training_data, + fidelity_features=fidelity_features, + # Kwargs below are unused in base `Surrogate`, but used in subclasses. + metric_names=metric_names, + task_features=task_features, ) if state_dict is not None: self.model.load_state_dict(state_dict) diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index 63b13260c71..d7f1223bd6f 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -9,6 +9,7 @@ import torch from ax.models.torch.botorch_modular.acquisition import Acquisition +from ax.models.torch.botorch_modular.list_surrogate import ListSurrogate from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase @@ -17,9 +18,9 @@ from botorch.utils.containers import TrainingData -ACQUISITION_PATH = f"{Acquisition.__module__}" -CURRENT_PATH = f"{__name__}" -SURROGATE_PATH = f"{Surrogate.__module__}" +ACQUISITION_PATH = Acquisition.__module__ +CURRENT_PATH = __name__ +SURROGATE_PATH = Surrogate.__module__ # Used to avoid going through BoTorch `Acquisition.__init__` which @@ -233,3 +234,16 @@ def test_best_point(self, mock_best_point): def test_evaluate(self, mock_call): self.acquisition.evaluate(X=self.X) mock_call.assert_called_with(X=self.X) + + def test_extract_training_data(self): + self.assertEqual( # Base `Surrogate` case. + self.acquisition._extract_training_data(surrogate=self.surrogate), + self.training_data, + ) + # `ListSurrogate` case. + list_surrogate = ListSurrogate(botorch_model_class_per_outcome={}) + list_surrogate._training_data_per_outcome = {"a": self.training_data} + self.assertEqual( + self.acquisition._extract_training_data(surrogate=list_surrogate), + list_surrogate._training_data_per_outcome, + ) diff --git a/ax/models/torch/tests/test_list_surrogate.py b/ax/models/torch/tests/test_list_surrogate.py new file mode 100644 index 00000000000..3a1a728c86f --- /dev/null +++ b/ax/models/torch/tests/test_list_surrogate.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from unittest.mock import patch + +import torch +from ax.models.torch.botorch_modular.acquisition import Acquisition +from ax.models.torch.botorch_modular.list_surrogate import ( + NOT_YET_FIT_MSG, + ListSurrogate, +) +from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.utils import choose_model_class +from ax.utils.common.testutils import TestCase +from ax.utils.testing.torch_stubs import get_torch_test_data +from botorch.models.model import TrainingData +from botorch.models.model_list_gp_regression import ModelListGP +from botorch.models.multitask import FixedNoiseMultiTaskGP +from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood + + +SURROGATE_PATH = f"{Surrogate.__module__}" +CURRENT_PATH = f"{__name__}" +ACQUISITION_PATH = f"{Acquisition.__module__}" +RANK = "rank" + + +class ListSurrogateTest(TestCase): + def setUp(self): + self.outcomes = ["outcome_1", "outcome_2"] + self.mll_class = SumMarginalLogLikelihood + self.device = torch.device("cpu") + self.dtype = torch.float + self.task_features = [0] + Xs1, Ys1, Yvars1, bounds, _, _, _ = get_torch_test_data( + dtype=self.dtype, task_features=self.task_features + ) + Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data( + dtype=self.dtype, task_features=self.task_features + ) + self.botorch_model_class_per_outcome = { + self.outcomes[0]: choose_model_class( + Yvars=Yvars1, task_features=self.task_features, fidelity_features=[] + ), + self.outcomes[1]: choose_model_class( + Yvars=Yvars2, task_features=self.task_features, fidelity_features=[] + ), + } + self.expected_submodel_type = FixedNoiseMultiTaskGP + for submodel_cls in self.botorch_model_class_per_outcome.values(): + self.assertEqual(submodel_cls, FixedNoiseMultiTaskGP) + self.Xs = Xs1 + Xs2 + self.Ys = Ys1 + Ys2 + self.Yvars = Yvars1 + Yvars2 + self.training_data = [ + TrainingData(X=X, Y=Y, Yvar=Yvar) + for X, Y, Yvar in zip(self.Xs, self.Ys, self.Yvars) + ] + self.submodel_options = { + self.outcomes[0]: {RANK: 1}, + self.outcomes[1]: {RANK: 2}, + } + self.surrogate = ListSurrogate( + botorch_model_class_per_outcome=self.botorch_model_class_per_outcome, + mll_class=self.mll_class, + submodel_options_per_outcome=self.submodel_options, + ) + self.bounds = [(0.0, 1.0), (1.0, 4.0)] + self.feature_names = ["x1", "x2"] + + def check_ranks(self, c: ListSurrogate) -> type(None): + self.assertIsInstance(c, ListSurrogate) + self.assertIsInstance(c.model, ModelListGP) + for idx, submodel in enumerate(c.model.models): + self.assertIsInstance(submodel, self.expected_submodel_type) + self.assertEqual( + submodel._rank, self.submodel_options[self.outcomes[idx]][RANK] + ) + + def test_init(self): + self.assertEqual( + self.surrogate.botorch_model_class_per_outcome, + self.botorch_model_class_per_outcome, + ) + self.assertEqual(self.surrogate.mll_class, self.mll_class) + with self.assertRaises(NotImplementedError): + self.surrogate.training_data + with self.assertRaisesRegex(ValueError, NOT_YET_FIT_MSG): + self.surrogate.training_data_per_outcome + with self.assertRaisesRegex( + ValueError, "BoTorch `Model` has not yet been constructed" + ): + self.surrogate.model + + @patch( + f"{CURRENT_PATH}.FixedNoiseMultiTaskGP.construct_inputs", + # Mock to register calls, but still execute the function. + side_effect=FixedNoiseMultiTaskGP.construct_inputs, + ) + def test_construct(self, mock_MTGP_construct_inputs): + with self.assertRaisesRegex(ValueError, ".* are required"): + self.surrogate.construct(training_data=self.training_data) + self.surrogate.construct( + training_data=self.training_data, + fidelity_features=[], + task_features=self.task_features, + metric_names=self.outcomes, + ) + self.check_ranks(self.surrogate) + # Should construct inputs for MTGP twice. + self.assertEqual(len(mock_MTGP_construct_inputs.call_args_list), 2) + # First construct inputs should be called for MTGP with training data #0. + self.assertEqual( + # `call_args` is a tuple of (args, kwargs), and we are interested in kwargs. + mock_MTGP_construct_inputs.call_args_list[0][1], + { + "fidelity_features": [], + "task_features": self.task_features, + "training_data": self.training_data[0], + }, + ) + # Then, with training data #1. + self.assertEqual( + # `call_args` is a tuple of (args, kwargs), and we are interested in kwargs. + mock_MTGP_construct_inputs.call_args_list[1][1], + { + "fidelity_features": [], + "task_features": self.task_features, + "training_data": self.training_data[1], + }, + ) + + @patch(f"{CURRENT_PATH}.ModelListGP.load_state_dict", return_value=None) + @patch(f"{CURRENT_PATH}.SumMarginalLogLikelihood") + @patch(f"{SURROGATE_PATH}.fit_gpytorch_model") + def test_fit(self, mock_fit_gpytorch, mock_MLL, mock_state_dict): + surrogate = ListSurrogate( + botorch_model_class_per_outcome=self.botorch_model_class_per_outcome, + mll_class=SumMarginalLogLikelihood, + ) + # Checking that model is None before `fit` (and `construct`) calls. + self.assertIsNone(surrogate._model) + # Should instantiate mll and `fit_gpytorch_model` when `state_dict` + # is `None`. + surrogate.fit( + training_data=self.training_data, + bounds=self.bounds, + task_features=self.task_features, + feature_names=self.feature_names, + metric_names=self.outcomes, + fidelity_features=[], + ) + mock_state_dict.assert_not_called() + mock_MLL.assert_called_once() + mock_fit_gpytorch.assert_called_once() + mock_state_dict.reset_mock() + mock_MLL.reset_mock() + mock_fit_gpytorch.reset_mock() + # Should `load_state_dict` when `state_dict` is not `None` + # and `refit` is `False`. + state_dict = {} + surrogate.fit( + training_data=self.training_data, + bounds=self.bounds, + task_features=self.task_features, + feature_names=self.feature_names, + metric_names=self.outcomes, + fidelity_features=[], + refit=False, + state_dict=state_dict, + ) + mock_state_dict.assert_called_once() + mock_MLL.assert_not_called() + mock_fit_gpytorch.assert_not_called() diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index 50b62773dfd..c214e674faf 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -14,6 +14,7 @@ from ax.models.torch.botorch_modular.utils import choose_model_class from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase +from ax.utils.testing.torch_stubs import get_torch_test_data from botorch.acquisition.knowledge_gradient import qKnowledgeGradient from botorch.acquisition.monte_carlo import qExpectedImprovement from botorch.models.gp_regression import SingleTaskGP @@ -21,10 +22,11 @@ from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood -CURRENT_PATH = f"{__name__}" -MODEL_PATH = f"{BoTorchModel.__module__}" -SURROGATE_PATH = f"{Surrogate.__module__}" -UTILS_PATH = f"{choose_model_class.__module__}" +CURRENT_PATH = __name__ +MODEL_PATH = BoTorchModel.__module__ +SURROGATE_PATH = Surrogate.__module__ +UTILS_PATH = choose_model_class.__module__ +ACQUISITION_PATH = Acquisition.__module__ class BoTorchModelTest(TestCase): @@ -46,14 +48,21 @@ def setUp(self): surrogate_fit_options=self.surrogate_fit_options, ) - self.X = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]]) - self.Y = torch.tensor([[3.0], [4.0]]) - self.Yvar = torch.tensor([[0.0], [2.0]]) + self.dtype = torch.float + Xs1, Ys1, Yvars1, self.bounds, _, _, _ = get_torch_test_data(dtype=self.dtype) + Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data(dtype=self.dtype) + self.Xs = Xs1 + Xs2 + self.Ys = Ys1 + Ys2 + self.Yvars = Yvars1 + Yvars2 + self.X = self.Xs[0] + self.Y = self.Ys[0] + self.Yvar = self.Yvars[0] self.training_data = TrainingData(X=self.X, Y=self.Y, Yvar=self.Yvar) self.bounds = [(0.0, 10.0), (0.0, 10.0), (0.0, 10.0)] self.task_features = [] self.feature_names = ["x1", "x2", "x3"] self.metric_names = ["y"] + self.metric_names_for_list_surrogate = ["y1", "y2"] self.fidelity_features = [2] self.target_fidelities = {1: 1.0} self.candidate_metadata = [] diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 16bc8df62cb..a90d54d645a 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -11,6 +11,7 @@ from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase +from ax.utils.testing.torch_stubs import get_torch_test_data from botorch.acquisition.monte_carlo import qSimpleRegret from botorch.models.gp_regression import SingleTaskGP from botorch.models.model import Model @@ -32,23 +33,20 @@ def setUp(self): self.mll_class = ExactMarginalLogLikelihood self.device = torch.device("cpu") self.dtype = torch.float - self.X = torch.tensor( - [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], dtype=self.dtype, device=self.device + self.Xs, self.Ys, self.Yvars, self.bounds, _, _, _ = get_torch_test_data( + dtype=self.dtype + ) + self.training_data = TrainingData( + X=self.Xs[0], Y=self.Ys[0], Yvar=self.Yvars[0] ) - - self.Y = torch.tensor([[3.0], [4.0]], dtype=self.dtype, device=self.device) - self.Yvar = torch.tensor([[0.0], [2.0]], dtype=self.dtype, device=self.device) - - self.training_data = TrainingData(X=self.X, Y=self.Y, Yvar=self.Yvar) self.surrogate_kwargs = self.botorch_model_class.construct_inputs( self.training_data ) self.surrogate = Surrogate( botorch_model_class=self.botorch_model_class, mll_class=self.mll_class ) - self.bounds = [(0.0, 1.0), (1.0, 4.0), (2.0, 5.0)] self.task_features = [] - self.feature_names = ["x1", "x2", "x3"] + self.feature_names = ["x1", "x2"] self.metric_names = ["y"] self.fidelity_features = [] self.target_fidelities = {1: 1.0} @@ -112,16 +110,16 @@ def test_from_BoTorch(self): @patch(f"{CURRENT_PATH}.SingleTaskGP.__init__", return_value=None) def test_construct(self, mock_GP): - base_surrogate = Surrogate(botorch_model_class=Model) - with self.assertRaisesRegex(TypeError, "Cannot construct an abstract model."): - base_surrogate.construct( + with self.assertRaises(NotImplementedError): + # Base `Model` does not implement `construct_inputs`. + Surrogate(botorch_model_class=Model).construct( training_data=self.training_data, fidelity_features=self.fidelity_features, ) self.surrogate.construct( training_data=self.training_data, fidelity_features=self.fidelity_features ) - mock_GP.assert_called_with(train_X=self.X, train_Y=self.Y) + mock_GP.assert_called_with(train_X=self.Xs[0], train_Y=self.Ys[0]) @patch(f"{CURRENT_PATH}.SingleTaskGP.load_state_dict", return_value=None) @patch(f"{CURRENT_PATH}.ExactMarginalLogLikelihood") @@ -174,8 +172,8 @@ def test_predict(self, mock_predict): self.surrogate.construct( training_data=self.training_data, fidelity_features=self.fidelity_features ) - self.surrogate.predict(X=self.X) - mock_predict.assert_called_with(model=self.surrogate.model, X=self.X) + self.surrogate.predict(X=self.Xs[0]) + mock_predict.assert_called_with(model=self.surrogate.model, X=self.Xs[0]) def test_best_in_sample_point(self): self.surrogate.construct( @@ -190,7 +188,7 @@ def test_best_in_sample_point(self): bounds=self.bounds, objective_weights=None ) with patch( - f"{SURROGATE_PATH}.best_in_sample_point", return_value=(self.X, 0.0) + f"{SURROGATE_PATH}.best_in_sample_point", return_value=(self.Xs[0], 0.0) ) as mock_best_in_sample: best_point, observed_value = self.surrogate.best_in_sample_point( bounds=self.bounds, diff --git a/ax/utils/common/constants.py b/ax/utils/common/constants.py index 24431920480..65aacd48961 100644 --- a/ax/utils/common/constants.py +++ b/ax/utils/common/constants.py @@ -43,6 +43,7 @@ class Keys(str, Enum): IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF = "immutable_search_space_and_opt_config" MAXIMIZE = "maximize" METADATA = "metadata" + METRIC_NAMES = "metric_names" NUM_FANTASIES = "num_fantasies" NUM_INNER_RESTARTS = "num_inner_restarts" NUM_RESTARTS = "num_restarts" @@ -60,4 +61,5 @@ class Keys(str, Enum): STATE_DICT = "state_dict" SUBCLASS = "subclass" SUBSET_MODEL = "subset_model" + TASK_FEATURES = "task_features" WARM_START_REFITTING = "warm_start_refitting" diff --git a/ax/utils/testing/torch_stubs.py b/ax/utils/testing/torch_stubs.py new file mode 100644 index 00000000000..5c7fbb41373 --- /dev/null +++ b/ax/utils/testing/torch_stubs.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch + + +def get_optimizer_kwargs() -> Dict[str, int]: + return {"num_restarts": 2, "raw_samples": 2, "maxiter": 2, "batch_limit": 1} + + +def get_torch_test_data( + dtype=torch.float, cuda=False, constant_noise=True, task_features=None +): + device = torch.device("cuda") if cuda else torch.device("cpu") + Xs = [torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], dtype=dtype, device=device)] + Ys = [torch.tensor([[3.0], [4.0]], dtype=dtype, device=device)] + Yvars = [torch.tensor([[0.0], [2.0]], dtype=dtype, device=device)] + if constant_noise: + Yvars[0].fill_(1.0) + bounds = [(0.0, 1.0), (1.0, 4.0), (2.0, 5.0)] + feature_names = ["x1", "x2", "x3"] + task_features = [] if task_features is None else task_features + metric_names = ["y", "r"] + return Xs, Ys, Yvars, bounds, task_features, feature_names, metric_names