From aa14aadd8a413ca4029305e128613dd6a1d96a99 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Fri, 25 Oct 2024 10:37:52 -0700 Subject: [PATCH] Update TrialAsTask transform to default pending points to most recent trial Summary: For MT mbm implementations, we use the TrialAsTask transform, however, until now this has only been used in scenarios where the trial is composed from a singular model. If we want to extend trial composition to be from multiple models, if the first model that is called is to create the trial generates a set of points, those become pending points. During fit and gen, the trialastask transform is called, and when we go to transform the pending points from that model within the current trial, it doesn't have an associated trial, so then we'll end up with a ```KeyError: 'TRIAL_PARAM' ``` error. Differential Revision: D64974517 --- .../tests/test_trial_as_task_transform.py | 40 +++++++++++++++++++ ax/modelbridge/transforms/trial_as_task.py | 7 ++++ 2 files changed, 47 insertions(+) diff --git a/ax/modelbridge/transforms/tests/test_trial_as_task_transform.py b/ax/modelbridge/transforms/tests/test_trial_as_task_transform.py index 4fdd3232a5a..b9a6c1fdc23 100644 --- a/ax/modelbridge/transforms/tests/test_trial_as_task_transform.py +++ b/ax/modelbridge/transforms/tests/test_trial_as_task_transform.py @@ -132,6 +132,46 @@ def test_TransformObservationFeatures(self) -> None: obs_ft4 = self.t3.untransform_observation_features(obs_ft4) self.assertEqual(obs_ft4, self.training_feats) + def test_TransformObservationFeaturesWithoutTrialIndex(self) -> None: + obs_ft_no_trial_index = deepcopy(self.training_feats) + obs_ft_no_trial_index.append( + ObservationFeatures( + {"x": 20}, + ) + ) + obs_ft_trans = [ + ObservationFeatures({"x": 1, "TRIAL_PARAM": "0"}), + ObservationFeatures({"x": 2, "TRIAL_PARAM": "0"}), + ObservationFeatures({"x": 3, "TRIAL_PARAM": "1"}), + ObservationFeatures({"x": 4, "TRIAL_PARAM": "2"}), + ObservationFeatures({"x": 20, "TRIAL_PARAM": "2"}), + ] + obs_ft_trans2 = [ + ObservationFeatures({"x": 1, "bp1": "v1", "bp2": "u1"}), + ObservationFeatures({"x": 2, "bp1": "v1", "bp2": "u1"}), + ObservationFeatures({"x": 3, "bp1": "v2", "bp2": "u1"}), + ObservationFeatures({"x": 4, "bp1": "v3", "bp2": "u2"}), + ObservationFeatures({"x": 20, "bp1": "v3", "bp2": "u2"}), + ] + + # test can transform and untransform with no config + obs_ft_no_trial_index_transformed = self.t.transform_observation_features( + obs_ft_no_trial_index + ) + self.assertEqual(obs_ft_no_trial_index_transformed, obs_ft_trans) + untransformed = self.t.untransform_observation_features( + obs_ft_no_trial_index_transformed + ) + # test can transform and untransform with config trial level map + self.assertEqual(untransformed, obs_ft_no_trial_index) + obs_ft_no_index_transformed_2 = self.t2.transform_observation_features( + obs_ft_no_trial_index + ) + self.assertEqual(obs_ft_no_index_transformed_2, obs_ft_trans2) + # can transform and untransform are equal with empty config + obs_ft4 = self.t3.untransform_observation_features(obs_ft_no_trial_index) + self.assertEqual(obs_ft4, obs_ft_no_trial_index) + def test_TransformSearchSpace(self) -> None: ss2 = deepcopy(self.search_space) ss2 = self.t.transform_search_space(ss2) diff --git a/ax/modelbridge/transforms/trial_as_task.py b/ax/modelbridge/transforms/trial_as_task.py index f7c6e357c07..8de266bca7c 100644 --- a/ax/modelbridge/transforms/trial_as_task.py +++ b/ax/modelbridge/transforms/trial_as_task.py @@ -124,6 +124,13 @@ def transform_observation_features( # typing.SupportsInt]` for 1st param but got `Optional[np.int64]`. obsf.parameters[p_name] = level_dict[int(obsf.trial_index)] obsf.trial_index = None + elif len(obsf.parameters) > 0: + # If the trial index is none, but the parameters are not empty + # perform the transform by assuming the observation is from the + # most recent trial. This is needed for generating trials composed + # of points from multiple models. + for p_name, level_dict in self.trial_level_map.items(): + obsf.parameters[p_name] = level_dict[max(level_dict)] return observation_features def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: