diff --git a/ax/modelbridge/best_model_selector.py b/ax/modelbridge/best_model_selector.py index 7829954c5c4..4b691b3b2b6 100644 --- a/ax/modelbridge/best_model_selector.py +++ b/ax/modelbridge/best_model_selector.py @@ -9,15 +9,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Callable -from enum import Enum -from functools import partial +from enum import unique from typing import Any, Union import numpy as np from ax.exceptions.core import UserInputError from ax.modelbridge.model_spec import ModelSpec from ax.utils.common.base import Base +from ax.utils.common.func_enum import FuncEnum from ax.utils.common.typeutils import not_none # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. @@ -36,27 +35,24 @@ def best_model(self, model_specs: list[ModelSpec]) -> ModelSpec: """ -class ReductionCriterion(Enum): +@unique +class ReductionCriterion(FuncEnum): """An enum for callables that are used for aggregating diagnostics over metrics and selecting the best diagnostic in ``SingleDiagnosticBestModelSelector``. + NOTE: The methods defined by this enum should all share identical signatures: + ``Callable[[ARRAYLIKE], np.ndarray]``, and reside in this file. + NOTE: This is used to ensure serializability of the callables. """ - # NOTE: Callables need to be wrapped in `partial` to be registered as members. - # pyre-fixme[35]: Target cannot be annotated. - # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. - MEAN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.mean) - # pyre-fixme[35]: Target cannot be annotated. - # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. - MIN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.min) - # pyre-fixme[35]: Target cannot be annotated. - # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. - MAX: Callable[[ARRAYLIKE], np.ndarray] = partial(np.max) + MEAN = "mean_reduction_criterion" + MIN = "min_reduction_criterion" + MAX = "max_reduction_criterion" # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. def __call__(self, array_like: ARRAYLIKE) -> np.ndarray: - return self.value(array_like) + return super().__call__(array_like=array_like) class SingleDiagnosticBestModelSelector(BestModelSelector): @@ -132,3 +128,23 @@ def best_model(self, model_specs: list[ModelSpec]) -> ModelSpec: best_diagnostic = self.criterion(aggregated_diagnostic_values).item() best_index = aggregated_diagnostic_values.index(best_diagnostic) return model_specs[best_index] + + +# ------------------------- Reduction criteria ------------------------- # + + +# Wrap the numpy functions, to be able to access them directly from this +# module in `ReductionCriterion(FuncEnum)` and to have typechecking +def mean_reduction_criterion(array_like: ARRAYLIKE) -> np.ndarray: + """Compute the mean of an array-like object.""" + return np.mean(array_like) + + +def min_reduction_criterion(array_like: ARRAYLIKE) -> np.ndarray: + """Compute the min of an array-like object.""" + return np.min(array_like) + + +def max_reduction_criterion(array_like: ARRAYLIKE) -> np.ndarray: + """Compute the max of an array-like object.""" + return np.max(array_like) diff --git a/ax/modelbridge/tests/test_best_model_selector.py b/ax/modelbridge/tests/test_best_model_selector.py index 6359996f18b..a2a244c9808 100644 --- a/ax/modelbridge/tests/test_best_model_selector.py +++ b/ax/modelbridge/tests/test_best_model_selector.py @@ -6,8 +6,11 @@ # pyre-strict +import inspect from unittest.mock import Mock, patch +import numpy as np + from ax.exceptions.core import UserInputError from ax.modelbridge.best_model_selector import ( ReductionCriterion, @@ -36,6 +39,32 @@ def setUp(self) -> None: ms._last_cv_kwargs = {} self.model_specs.append(ms) + def test_member_typing(self) -> None: + for reduction_criterion in ReductionCriterion: + signature = inspect.signature(reduction_criterion._get_function_for_value()) + self.assertEqual(signature.return_annotation, "np.ndarray") + + # pyre-fixme [56]: Pyre was not able to infer the type of argument + # `numpy` to decorator factory `unittest.mock.patch` + @patch(f"{ReductionCriterion.__module__}.np", wraps=np) + def test_ReductionCriterion(self, mock_np: Mock) -> None: + untested_reduction_criteria = set(ReductionCriterion) + # Check MEAN (should just fall through to `np.mean`) + array = np.array([1, 2, 3]) # and then use this var all the way down + self.assertEqual(ReductionCriterion.MEAN(array), np.mean(array)) + mock_np.mean.assert_called_once() + untested_reduction_criteria.remove(ReductionCriterion.MEAN) + # Check MIN (should just fall through to `np.min`) + self.assertEqual(ReductionCriterion.MIN(np.array([1, 2, 3])), 1.0) + mock_np.min.assert_called_once() + untested_reduction_criteria.remove(ReductionCriterion.MIN) + # Check MAX (should just fall through to `np.max`) + self.assertEqual(ReductionCriterion.MAX(np.array([1, 2, 3])), 3.0) + mock_np.max.assert_called_once() + untested_reduction_criteria.remove(ReductionCriterion.MAX) + # There should be no untested reduction criteria left + self.assertEqual(len(untested_reduction_criteria), 0) + def test_user_input_error(self) -> None: with self.assertRaisesRegex(UserInputError, "ReductionCriterion"): SingleDiagnosticBestModelSelector( diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index ec3d85b229d..d0c2d8f2e71 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -77,6 +77,9 @@ class TestGenerationStrategyWithoutModelBridgeMocks(TestCase): test class that makes use of mocking rather sparingly. """ + def _setUp(self) -> None: + super().setUp() + @fast_botorch_optimize @patch( "ax.modelbridge.generation_node._extract_model_state_after_gen",