-
Notifications
You must be signed in to change notification settings - Fork 319
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds subclass of MetadataToRange Transform that provides sensible def…
- Loading branch information
1 parent
842ec7d
commit 7b698f7
Showing
2 changed files
with
242 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from typing import Any, Optional, TYPE_CHECKING | ||
|
||
from ax.core.map_metric import MapMetric | ||
from ax.core.observation import Observation, ObservationFeatures | ||
from ax.core.search_space import SearchSpace | ||
from ax.modelbridge.transforms.metadata_to_range import MetadataToRange | ||
from ax.models.types import TConfig | ||
from pyre_extensions import assert_is_instance | ||
|
||
if TYPE_CHECKING: | ||
# import as module to make sphinx-autodoc-typehints happy | ||
from ax import modelbridge as modelbridge_module # noqa F401 | ||
|
||
|
||
class MapKeyToRange(MetadataToRange): | ||
DEFAULT_LOG_SCALE: bool = True | ||
DEFAULT_MAP_KEY: str = MapMetric.map_key_info.key | ||
|
||
def __init__( | ||
self, | ||
search_space: SearchSpace | None = None, | ||
observations: list[Observation] | None = None, | ||
modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, | ||
config: TConfig | None = None, | ||
) -> None: | ||
config = config or {} | ||
self.parameters: dict[str, dict[str, Any]] = assert_is_instance( | ||
config.setdefault("parameters", {}), dict | ||
) | ||
# TODO[tiao]: raise warning if `DEFAULT_MAP_KEY` is already in keys(?) | ||
self.parameters.setdefault(self.DEFAULT_MAP_KEY, {}) | ||
super().__init__( | ||
search_space=search_space, | ||
observations=observations, | ||
modelbridge=modelbridge, | ||
config=config, | ||
) | ||
|
||
def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: | ||
if not obsf.parameters: | ||
for p in self._parameter_list: | ||
# TODO[tiao]: can we use be p.target_value? | ||
# (not its original intended use but could be advantageous) | ||
obsf.parameters[p.name] = p.upper | ||
return | ||
super()._transform_observation_feature(obsf) |
188 changes: 188 additions & 0 deletions
188
ax/modelbridge/transforms/tests/test_map_key_to_range_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from copy import deepcopy | ||
from typing import Iterator | ||
|
||
import numpy as np | ||
from ax.core.observation import Observation, ObservationData, ObservationFeatures | ||
from ax.core.parameter import ParameterType, RangeParameter | ||
from ax.core.search_space import SearchSpace | ||
from ax.exceptions.core import DataRequiredError | ||
from ax.modelbridge.transforms.map_key_to_range import MapKeyToRange | ||
from ax.utils.common.testutils import TestCase | ||
from pyre_extensions import assert_is_instance | ||
|
||
|
||
WIDTHS = [2.0, 4.0, 8.0] | ||
HEIGHTS = [4.0, 2.0, 8.0] | ||
STEPS_ENDS = [1, 5, 3] | ||
|
||
|
||
def _enumerate() -> Iterator[tuple[int, float, float, float]]: | ||
yield from ( | ||
(trial_index, width, height, float(i + 1)) | ||
for trial_index, (width, height, steps_end) in enumerate( | ||
zip(WIDTHS, HEIGHTS, STEPS_ENDS) | ||
) | ||
for i in range(steps_end) | ||
) | ||
|
||
|
||
class MapKeyToRangeTransformTest(TestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
|
||
self.search_space = SearchSpace( | ||
parameters=[ | ||
RangeParameter( | ||
name="width", | ||
parameter_type=ParameterType.FLOAT, | ||
lower=1, | ||
upper=20, | ||
), | ||
RangeParameter( | ||
name="height", | ||
parameter_type=ParameterType.FLOAT, | ||
lower=1, | ||
upper=20, | ||
), | ||
] | ||
) | ||
|
||
self.observations = [] | ||
for trial_index, width, height, steps in _enumerate(): | ||
obs_feat = ObservationFeatures( | ||
trial_index=trial_index, | ||
parameters={"width": width, "height": height}, | ||
metadata={ | ||
"foo": 42, | ||
MapKeyToRange.DEFAULT_MAP_KEY: steps, | ||
}, | ||
) | ||
obs_data = ObservationData( | ||
metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) | ||
) | ||
self.observations.append(Observation(features=obs_feat, data=obs_data)) | ||
|
||
# does not require explicitly specifying `config` | ||
self.t = MapKeyToRange( | ||
observations=self.observations, | ||
) | ||
|
||
def test_Init(self) -> None: | ||
self.assertEqual(len(self.t._parameter_list), 1) | ||
|
||
p = self.t._parameter_list[0] | ||
|
||
self.assertEqual(p.name, MapKeyToRange.DEFAULT_MAP_KEY) | ||
self.assertEqual(p.parameter_type, ParameterType.FLOAT) | ||
self.assertEqual(p.lower, 1.0) | ||
self.assertEqual(p.upper, 5.0) | ||
self.assertTrue(p.log_scale) | ||
self.assertFalse(p.logit_scale) | ||
self.assertIsNone(p.digits) | ||
self.assertFalse(p.is_fidelity) | ||
self.assertIsNone(p.target_value) | ||
|
||
with self.assertRaises(DataRequiredError): | ||
MetadataToRange(search_space=None, observations=None) | ||
with self.assertRaises(DataRequiredError): | ||
MetadataToRange(search_space=None, observations=[]) | ||
|
||
with self.subTest("infer parameter type"): | ||
observations = [] | ||
for trial_index, width, height, steps in _enumerate(): | ||
obs_feat = ObservationFeatures( | ||
trial_index=trial_index, | ||
parameters={"width": width, "height": height}, | ||
metadata={ | ||
"foo": 42, | ||
"bar": int(steps), | ||
}, | ||
) | ||
obs_data = ObservationData( | ||
metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) | ||
) | ||
observations.append(Observation(features=obs_feat, data=obs_data)) | ||
|
||
# test that one is able to override default config | ||
with self.subTest(msg="override default config"): | ||
t = MapKeyToRange( | ||
observations=self.observations, | ||
config={ | ||
"parameters": {MapKeyToRange.DEFAULT_MAP_KEY: {"log_scale": False}} | ||
}, | ||
) | ||
self.assertDictEqual(t.parameters, {"steps": {"log_scale": False}}) | ||
|
||
self.assertEqual(len(t._parameter_list), 1) | ||
|
||
p = t._parameter_list[0] | ||
|
||
self.assertEqual(p.name, MapKeyToRange.DEFAULT_MAP_KEY) | ||
self.assertEqual(p.parameter_type, ParameterType.FLOAT) | ||
self.assertEqual(p.lower, 1.0) | ||
self.assertEqual(p.upper, 5.0) | ||
self.assertFalse(p.log_scale) | ||
|
||
def test_TransformSearchSpace(self) -> None: | ||
ss2 = deepcopy(self.search_space) | ||
ss2 = self.t.transform_search_space(ss2) | ||
|
||
self.assertSetEqual( | ||
set(ss2.parameters), | ||
{"height", "width", MapKeyToRange.DEFAULT_MAP_KEY}, | ||
) | ||
|
||
p = assert_is_instance( | ||
ss2.parameters[MapKeyToRange.DEFAULT_MAP_KEY], RangeParameter | ||
) | ||
|
||
self.assertEqual(p.name, MapKeyToRange.DEFAULT_MAP_KEY) | ||
self.assertEqual(p.parameter_type, ParameterType.FLOAT) | ||
self.assertEqual(p.lower, 1.0) | ||
self.assertEqual(p.upper, 5.0) | ||
self.assertTrue(p.log_scale) | ||
self.assertFalse(p.logit_scale) | ||
self.assertIsNone(p.digits) | ||
self.assertFalse(p.is_fidelity) | ||
self.assertIsNone(p.target_value) | ||
|
||
def test_TransformObservationFeatures(self) -> None: | ||
observation_features = [obs.features for obs in self.observations] | ||
obs_ft2 = deepcopy(observation_features) | ||
obs_ft2 = self.t.transform_observation_features(obs_ft2) | ||
|
||
self.assertEqual( | ||
obs_ft2, | ||
[ | ||
ObservationFeatures( | ||
trial_index=trial_index, | ||
parameters={ | ||
"width": width, | ||
"height": height, | ||
MapKeyToRange.DEFAULT_MAP_KEY: steps, | ||
}, | ||
metadata={"foo": 42}, | ||
) | ||
for trial_index, width, height, steps in _enumerate() | ||
], | ||
) | ||
obs_ft2 = self.t.untransform_observation_features(obs_ft2) | ||
self.assertEqual(obs_ft2, observation_features) | ||
|
||
def test_TransformObservationFeaturesWithEmptyParameters(self) -> None: | ||
obsf = ObservationFeatures(parameters={}) | ||
self.t.transform_observation_features([obsf]) | ||
|
||
p = self.t._parameter_list[0] | ||
self.assertEqual( | ||
obsf, | ||
ObservationFeatures(parameters={MapKeyToRange.DEFAULT_MAP_KEY: p.upper}), | ||
) |