Skip to content

Commit

Permalink
Remove multi-surrogate support in MBM (facebook#2957)
Browse files Browse the repository at this point in the history
Summary:

MBM with multiple surrogates was never supported e2e and is being deprecated. This diff:
- Adds a new `surrogate_spec: Surrogate` input to `BoTorchModel`, to replace `surrogate_specs: dict[str, Surrogate]` input.
- Raises a `DeprecationWarning` (as an exception) if `surrogate_specs` is passed in with multiple elements.
- Cleans up the code that was necessary to support multiple surrogates.

Differential Revision: D64875988
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Oct 24, 2024
1 parent a85c7ee commit 5613fb9
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 423 deletions.
17 changes: 5 additions & 12 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
from ax.utils.common.constants import Keys
from ax.utils.common.kwargs import get_function_argument_names
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down Expand Up @@ -74,11 +73,9 @@ def test_botorch_modular(self) -> None:
self.assertEqual(gpei.model.botorch_acqf_class, qExpectedImprovement)
self.assertEqual(gpei.model.acquisition_class, Acquisition)
self.assertEqual(gpei.model.acquisition_options, {"best_f": 0.0})
self.assertIsInstance(gpei.model.surrogates[Keys.AUTOSET_SURROGATE], Surrogate)
self.assertIsInstance(gpei.model.surrogate, Surrogate)
# SingleTaskGP should be picked.
self.assertIsInstance(
gpei.model.surrogates[Keys.AUTOSET_SURROGATE].model, SingleTaskGP
)
self.assertIsInstance(gpei.model.surrogate.model, SingleTaskGP)

gr = gpei.gen(n=1)
self.assertIsNotNone(gr.best_arm_predictions)
Expand All @@ -96,14 +93,10 @@ def test_SAASBO(self) -> None:
self.assertIsInstance(saasbo, TorchModelBridge)
self.assertEqual(saasbo._model_key, "SAASBO")
self.assertIsInstance(saasbo.model, BoTorchModel)
surrogate_specs = saasbo.model.surrogate_specs
surrogate_spec = saasbo.model.surrogate_spec
self.assertEqual(
surrogate_specs,
{
"SAASBO_Surrogate": SurrogateSpec(
botorch_model_class=SaasFullyBayesianSingleTaskGP
)
},
surrogate_spec,
SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP),
)
self.assertEqual(
saasbo.model.surrogate.botorch_model_class, SaasFullyBayesianSingleTaskGP
Expand Down
1 change: 1 addition & 0 deletions ax/models/torch/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def __init__(
"instead. If you run into a use case that is not supported by MBM, "
"please raise this with an issue at https://github.com/facebook/Ax",
DeprecationWarning,
stacklevel=2,
)
self.model_constructor = model_constructor
self.model_predictor = model_predictor
Expand Down
5 changes: 2 additions & 3 deletions ax/models/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,7 @@ def optimize(
)
return candidates, acqf_values, arm_weights

# 2. Handle search spaces with discrete features.
# 2a. Handle the fully discrete search space.
# 2. Handle fully discrete search spaces.
if optimizer in (
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
Expand Down Expand Up @@ -384,7 +383,7 @@ def optimize(
)
return candidates, acqf_values, arm_weights

# 2b. Handle mixed search spaces that have discrete and continuous features.
# 3. Handle mixed search spaces that have discrete and continuous features.
# Only sequential optimization is supported for `optimize_acqf_mixed`.
candidates, acqf_values = optimize_acqf_mixed(
acq_function=self.acqf,
Expand Down
Loading

0 comments on commit 5613fb9

Please sign in to comment.