Skip to content

Commit

Permalink
Fix oracle resume bug (#837)
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin authored Feb 9, 2023
1 parent 2f77c7a commit 16b25c3
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Release v1.2.1

## Bug fixes
* The resume feature (`overwrite=False`) would crash in 1.2.0. This is now fixed.

# Release v1.2.0

## Breaking changes
Expand Down
4 changes: 2 additions & 2 deletions keras_tuner/engine/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,8 @@ def reload(self):
) from e

# Empty the ongoing_trials and send them for retry.
for _, trial_id in self.ongoing_trials.items():
self._retry_queue.append(trial_id)
for _, trial in self.ongoing_trials.items():
self._retry_queue.append(trial.trial_id)
self.ongoing_trials = {}

def _get_oracle_fname(self):
Expand Down
26 changes: 26 additions & 0 deletions keras_tuner/engine/oracle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,29 @@ def test_get_best_trial_with_nans(tmp_path):

assert len(oracle.get_best_trials()) > 0
assert oracle.get_best_trials()[0].trial_id == best_trial.trial_id


def test_overwrite_false_resume(tmp_path):
oracle = OracleStub(
directory=tmp_path, objective="val_loss", max_retries_per_trial=1
)
for i in range(10):
trial = oracle.create_trial(tuner_id="a")
oracle.update_trial(trial.trial_id, {"val_loss": np.random.rand()})
trial.status = trial_module.TrialStatus.COMPLETED
oracle.end_trial(trial)

trial = oracle.create_trial(tuner_id="a")
trial_id = trial.trial_id
oracle = OracleStub(
directory=tmp_path, objective="val_loss", max_retries_per_trial=1
)
oracle.reload()

trial = oracle.create_trial(tuner_id="a")
oracle.update_trial(trial.trial_id, {"val_loss": np.random.rand()})
trial.status = trial_module.TrialStatus.COMPLETED
oracle.end_trial(trial)

assert trial.trial_id == trial_id
assert oracle.get_trial(trial_id).status == trial_module.TrialStatus.COMPLETED

0 comments on commit 16b25c3

Please sign in to comment.