Skip to content

Commit

Permalink
Finalize fix of not needing to store transition criterion on steps (#…
Browse files Browse the repository at this point in the history
…1985)

Summary:
Pull Request resolved: #1985

This diff:
Finalizes the fix to the storage by removing the need to unset transiton criteria and doesn't store transition criterion anymore. It can do so because in the decoder we reconstruct the generation step, which automatically fills in the relevant node fields during its init method.

Reviewed By: lena-kashtelyan

Differential Revision: D50752054

fbshipit-source-id: 28a6529cc20d0239c521a37fd58dd379b1dc9610
  • Loading branch information
mgarrard authored and facebook-github-bot committed Nov 15, 2023
1 parent de81262 commit b3ee6c9
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 17 deletions.
2 changes: 0 additions & 2 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,6 @@ def _unset_non_persistent_state_fields(self) -> None:
self._model = None
for s in self._steps:
s._model_spec_to_gen_from = None
# TODO: @mgarrard remove once re-enabled criterion storage
s._transition_criteria = []

def __repr__(self) -> str:
"""String representation of this generation strategy."""
Expand Down
1 change: 0 additions & 1 deletion ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,6 @@ def test_save_and_load_generation_strategy(self) -> None:
)
second_client = AxClient(db_settings=db_settings)
second_client.load_experiment_from_database("unique_test_experiment")
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(second_client.generation_strategy, generation_strategy)

@patch(
Expand Down
1 change: 0 additions & 1 deletion ax/service/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,6 @@ def test_sqa_storage(self) -> None:
# Check that experiment and GS were saved.
exp, gs = scheduler._load_experiment_and_generation_strategy(experiment.name)
self.assertEqual(exp, experiment)
self.two_sobol_steps_GS._unset_non_persistent_state_fields()
self.assertEqual(gs, self.two_sobol_steps_GS)
scheduler.run_all_trials()
# Check that experiment and GS were saved and test reloading with reduced state.
Expand Down
9 changes: 1 addition & 8 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,14 +728,7 @@ def generation_step_from_json(
if gen_kwargs
else None,
index=generation_step_json.pop("index", -1),
should_deduplicate=generation_step_json.pop("should_deduplicate")
if "should_deduplicate" in generation_step_json
else False,
)
generation_step._transition_criteria = transition_criteria_from_json(
generation_step_json.pop("transition_criteria")
if "transition_criteria" in generation_step_json.keys()
else None
should_deduplicate=generation_step_json.pop("should_deduplicate", False),
)
return generation_step

Expand Down
12 changes: 8 additions & 4 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import numpy as np
import torch
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.core.metric import Metric
from ax.core.runner import Runner
from ax.exceptions.core import AxStorageWarning
Expand Down Expand Up @@ -328,8 +327,14 @@ def test_EncodeDecode(self) -> None:
converted_object = converted_object.state_dict()
if isinstance(original_object, GenerationStrategy):
original_object._unset_non_persistent_state_fields()
if isinstance(original_object, BenchmarkMethod):
original_object.generation_strategy._unset_non_persistent_state_fields()
# for the test, completion criterion are set post init
# and therefore do not become transition critirion, unset
# for this specific test only
if "with_completion_criteria" in fake_func.keywords:
for step in original_object._steps:
step._transition_criteria = None
for step in converted_object._steps:
step._transition_criteria = None
try:
self.assertEqual(
original_object,
Expand Down Expand Up @@ -402,7 +407,6 @@ def test_DecodeGenerationStrategy(self) -> None:
decoder_registry=CORE_DECODER_REGISTRY,
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
)
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertGreater(len(new_generation_strategy._steps), 0)
self.assertIsInstance(new_generation_strategy._steps[0].model, Models)
Expand Down
1 change: 0 additions & 1 deletion ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,6 @@ def test_EncodeDecodeGenerationStrategy(self) -> None:
# pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`.
gs_id=generation_strategy._db_id
)
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertIsNone(generation_strategy._experiment)

Expand Down

0 comments on commit b3ee6c9

Please sign in to comment.