Skip to content

Commit

Permalink
Update TrialAsTask transform to default pending points to most recent…
Browse files Browse the repository at this point in the history
… trial

Summary: For MT mbm implementations, we use the TrialAsTask transform, however, until now this has only been used in scenarios where the trial is composed from a singular model. If we want to extend trial composition to be from multiple models, if the first model that is called is to create the trial generates a set of points, those become pending points. During fit and gen, the trialastask transform is called, and when we go to transform the pending points from that model within the current trial, it doesn't have an associated trial, so then we'll end up with a ```KeyError: 'TRIAL_PARAM' ``` error.

Differential Revision: D64974517
  • Loading branch information
mgarrard authored and facebook-github-bot committed Oct 25, 2024
1 parent 9803b25 commit aa14aad
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
40 changes: 40 additions & 0 deletions ax/modelbridge/transforms/tests/test_trial_as_task_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,46 @@ def test_TransformObservationFeatures(self) -> None:
obs_ft4 = self.t3.untransform_observation_features(obs_ft4)
self.assertEqual(obs_ft4, self.training_feats)

def test_TransformObservationFeaturesWithoutTrialIndex(self) -> None:
obs_ft_no_trial_index = deepcopy(self.training_feats)
obs_ft_no_trial_index.append(
ObservationFeatures(
{"x": 20},
)
)
obs_ft_trans = [
ObservationFeatures({"x": 1, "TRIAL_PARAM": "0"}),
ObservationFeatures({"x": 2, "TRIAL_PARAM": "0"}),
ObservationFeatures({"x": 3, "TRIAL_PARAM": "1"}),
ObservationFeatures({"x": 4, "TRIAL_PARAM": "2"}),
ObservationFeatures({"x": 20, "TRIAL_PARAM": "2"}),
]
obs_ft_trans2 = [
ObservationFeatures({"x": 1, "bp1": "v1", "bp2": "u1"}),
ObservationFeatures({"x": 2, "bp1": "v1", "bp2": "u1"}),
ObservationFeatures({"x": 3, "bp1": "v2", "bp2": "u1"}),
ObservationFeatures({"x": 4, "bp1": "v3", "bp2": "u2"}),
ObservationFeatures({"x": 20, "bp1": "v3", "bp2": "u2"}),
]

# test can transform and untransform with no config
obs_ft_no_trial_index_transformed = self.t.transform_observation_features(
obs_ft_no_trial_index
)
self.assertEqual(obs_ft_no_trial_index_transformed, obs_ft_trans)
untransformed = self.t.untransform_observation_features(
obs_ft_no_trial_index_transformed
)
# test can transform and untransform with config trial level map
self.assertEqual(untransformed, obs_ft_no_trial_index)
obs_ft_no_index_transformed_2 = self.t2.transform_observation_features(
obs_ft_no_trial_index
)
self.assertEqual(obs_ft_no_index_transformed_2, obs_ft_trans2)
# can transform and untransform are equal with empty config
obs_ft4 = self.t3.untransform_observation_features(obs_ft_no_trial_index)
self.assertEqual(obs_ft4, obs_ft_no_trial_index)

def test_TransformSearchSpace(self) -> None:
ss2 = deepcopy(self.search_space)
ss2 = self.t.transform_search_space(ss2)
Expand Down
7 changes: 7 additions & 0 deletions ax/modelbridge/transforms/trial_as_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ def transform_observation_features(
# typing.SupportsInt]` for 1st param but got `Optional[np.int64]`.
obsf.parameters[p_name] = level_dict[int(obsf.trial_index)]
obsf.trial_index = None
elif len(obsf.parameters) > 0:
# If the trial index is none, but the parameters are not empty
# perform the transform by assuming the observation is from the
# most recent trial. This is needed for generating trials composed
# of points from multiple models.
for p_name, level_dict in self.trial_level_map.items():
obsf.parameters[p_name] = level_dict[max(level_dict)]
return observation_features

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
Expand Down

0 comments on commit aa14aad

Please sign in to comment.