Adds subclass of MetadataToRange Transform that provides sensible defaults for MapData (#3155)

Summary: Pull Request resolved: #3155

Differential Revision: D66945078
Louis Tiao authored and facebook-github-bot committed Dec 12, 2024
1 parent 842ec7d commit 7b698f7
Showing 2 changed files with 242 additions and 0 deletions.
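
For orientation, here is a minimal usage sketch of the new transform, mirroring the test setup in the diff below. `observations` stands for any list of ax.core.observation.Observation whose features carry the map key in their metadata; it is not defined in this change.

from ax.modelbridge.transforms.map_key_to_range import MapKeyToRange

# No explicit `config` is needed; the transform defaults to the MapMetric map key
# and a log-scaled range whose bounds are inferred from the observed metadata values.
t = MapKeyToRange(observations=observations)
transformed = t.transform_observation_features([obs.features for obs in observations])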
54 changes: 54 additions & 0 deletions ax/modelbridge/transforms/map_key_to_range.py
@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Optional, TYPE_CHECKING

from ax.core.map_metric import MapMetric
from ax.core.observation import Observation, ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.metadata_to_range import MetadataToRange
from ax.models.types import TConfig
from pyre_extensions import assert_is_instance

if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401


class MapKeyToRange(MetadataToRange):
    DEFAULT_LOG_SCALE: bool = True
    DEFAULT_MAP_KEY: str = MapMetric.map_key_info.key

    def __init__(
        self,
        search_space: SearchSpace | None = None,
        observations: list[Observation] | None = None,
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: TConfig | None = None,
    ) -> None:
        config = config or {}
        self.parameters: dict[str, dict[str, Any]] = assert_is_instance(
            config.setdefault("parameters", {}), dict
        )
        # TODO[tiao]: raise warning if `DEFAULT_MAP_KEY` is already in keys(?)
        self.parameters.setdefault(self.DEFAULT_MAP_KEY, {})
        super().__init__(
            search_space=search_space,
            observations=observations,
            modelbridge=modelbridge,
            config=config,
        )

    def _transform_observation_feature(self, obsf: ObservationFeatures) -> None:
        if not obsf.parameters:
            for p in self._parameter_list:
                # TODO[tiao]: can we use p.target_value?
                # (not its original intended use but could be advantageous)
                obsf.parameters[p.name] = p.upper
            return
        super()._transform_observation_feature(obsf)
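
Assuming MetadataToRange honors the DEFAULT_MAP_KEY and DEFAULT_LOG_SCALE class attributes, constructing MapKeyToRange with no config should be roughly equivalent to configuring the parent class explicitly, as in this hypothetical sketch (not part of the diff):

from ax.core.map_metric import MapMetric
from ax.modelbridge.transforms.metadata_to_range import MetadataToRange

# `observations` is assumed to be a list of Observation objects whose metadata
# contain the MapMetric map key; the per-parameter "log_scale" option matches
# what the tests below pass when overriding the default.
t = MetadataToRange(
    observations=observations,
    config={"parameters": {MapMetric.map_key_info.key: {"log_scale": True}}},
)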
188 changes: 188 additions & 0 deletions ax/modelbridge/transforms/tests/test_map_key_to_range_transform.py
@@ -0,0 +1,188 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from copy import deepcopy
from typing import Iterator

import numpy as np
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.map_key_to_range import MapKeyToRange
from ax.modelbridge.transforms.metadata_to_range import MetadataToRange
from ax.utils.common.testutils import TestCase
from pyre_extensions import assert_is_instance


WIDTHS = [2.0, 4.0, 8.0]
HEIGHTS = [4.0, 2.0, 8.0]
STEPS_ENDS = [1, 5, 3]


def _enumerate() -> Iterator[tuple[int, float, float, float]]:
    yield from (
        (trial_index, width, height, float(i + 1))
        for trial_index, (width, height, steps_end) in enumerate(
            zip(WIDTHS, HEIGHTS, STEPS_ENDS)
        )
        for i in range(steps_end)
    )


class MapKeyToRangeTransformTest(TestCase):
    def setUp(self) -> None:
        super().setUp()

        self.search_space = SearchSpace(
            parameters=[
                RangeParameter(
                    name="width",
                    parameter_type=ParameterType.FLOAT,
                    lower=1,
                    upper=20,
                ),
                RangeParameter(
                    name="height",
                    parameter_type=ParameterType.FLOAT,
                    lower=1,
                    upper=20,
                ),
            ]
        )

        self.observations = []
        for trial_index, width, height, steps in _enumerate():
            obs_feat = ObservationFeatures(
                trial_index=trial_index,
                parameters={"width": width, "height": height},
                metadata={
                    "foo": 42,
                    MapKeyToRange.DEFAULT_MAP_KEY: steps,
                },
            )
            obs_data = ObservationData(
                metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
            )
            self.observations.append(Observation(features=obs_feat, data=obs_data))

        # does not require explicitly specifying `config`
        self.t = MapKeyToRange(
            observations=self.observations,
        )

    def test_Init(self) -> None:
        self.assertEqual(len(self.t._parameter_list), 1)

        p = self.t._parameter_list[0]

        self.assertEqual(p.name, MapKeyToRange.DEFAULT_MAP_KEY)
        self.assertEqual(p.parameter_type, ParameterType.FLOAT)
        self.assertEqual(p.lower, 1.0)
        self.assertEqual(p.upper, 5.0)
        self.assertTrue(p.log_scale)
        self.assertFalse(p.logit_scale)
        self.assertIsNone(p.digits)
        self.assertFalse(p.is_fidelity)
        self.assertIsNone(p.target_value)

        with self.assertRaises(DataRequiredError):
            MetadataToRange(search_space=None, observations=None)
        with self.assertRaises(DataRequiredError):
            MetadataToRange(search_space=None, observations=[])

        with self.subTest("infer parameter type"):
            observations = []
            for trial_index, width, height, steps in _enumerate():
                obs_feat = ObservationFeatures(
                    trial_index=trial_index,
                    parameters={"width": width, "height": height},
                    metadata={
                        "foo": 42,
                        "bar": int(steps),
                    },
                )
                obs_data = ObservationData(
                    metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
                )
                observations.append(Observation(features=obs_feat, data=obs_data))

        # test that one is able to override default config
        with self.subTest(msg="override default config"):
            t = MapKeyToRange(
                observations=self.observations,
                config={
                    "parameters": {MapKeyToRange.DEFAULT_MAP_KEY: {"log_scale": False}}
                },
            )
            self.assertDictEqual(t.parameters, {"steps": {"log_scale": False}})

            self.assertEqual(len(t._parameter_list), 1)

            p = t._parameter_list[0]

            self.assertEqual(p.name, MapKeyToRange.DEFAULT_MAP_KEY)
            self.assertEqual(p.parameter_type, ParameterType.FLOAT)
            self.assertEqual(p.lower, 1.0)
            self.assertEqual(p.upper, 5.0)
            self.assertFalse(p.log_scale)

    def test_TransformSearchSpace(self) -> None:
        ss2 = deepcopy(self.search_space)
        ss2 = self.t.transform_search_space(ss2)

        self.assertSetEqual(
            set(ss2.parameters),
            {"height", "width", MapKeyToRange.DEFAULT_MAP_KEY},
        )

        p = assert_is_instance(
            ss2.parameters[MapKeyToRange.DEFAULT_MAP_KEY], RangeParameter
        )

        self.assertEqual(p.name, MapKeyToRange.DEFAULT_MAP_KEY)
        self.assertEqual(p.parameter_type, ParameterType.FLOAT)
        self.assertEqual(p.lower, 1.0)
        self.assertEqual(p.upper, 5.0)
        self.assertTrue(p.log_scale)
        self.assertFalse(p.logit_scale)
        self.assertIsNone(p.digits)
        self.assertFalse(p.is_fidelity)
        self.assertIsNone(p.target_value)

    def test_TransformObservationFeatures(self) -> None:
        observation_features = [obs.features for obs in self.observations]
        obs_ft2 = deepcopy(observation_features)
        obs_ft2 = self.t.transform_observation_features(obs_ft2)

        self.assertEqual(
            obs_ft2,
            [
                ObservationFeatures(
                    trial_index=trial_index,
                    parameters={
                        "width": width,
                        "height": height,
                        MapKeyToRange.DEFAULT_MAP_KEY: steps,
                    },
                    metadata={"foo": 42},
                )
                for trial_index, width, height, steps in _enumerate()
            ],
        )
        obs_ft2 = self.t.untransform_observation_features(obs_ft2)
        self.assertEqual(obs_ft2, observation_features)

    def test_TransformObservationFeaturesWithEmptyParameters(self) -> None:
        obsf = ObservationFeatures(parameters={})
        self.t.transform_observation_features([obsf])

        p = self.t._parameter_list[0]
        self.assertEqual(
            obsf,
            ObservationFeatures(parameters={MapKeyToRange.DEFAULT_MAP_KEY: p.upper}),
        )
