Skip to content

Commit

Permalink
checkpointing callback handle objective not in metrics for fit (#674)
Browse files Browse the repository at this point in the history
Co-authored-by: Haifeng Jin <[email protected]>
  • Loading branch information
haifeng-jin and haifeng-jin authored Mar 25, 2022
1 parent 0c79129 commit bd21a5d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
16 changes: 16 additions & 0 deletions keras_tuner/engine/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ def __init__(self, name, direction):
self.name = name
self.direction = direction

def has_value(self, logs):
"""Check if objective value exists in logs.
Args:
logs: A dictionary with the metric names as the keys and the metric
values as the values, which is in the same format as the `logs`
argument for `Callback.on_epoch_end()`.
Returns:
Boolean, whether we can compute objective value from the logs.
"""
return self.name in logs

def get_value(self, logs):
"""Get the objective value from the metrics logs.
Expand Down Expand Up @@ -81,6 +94,9 @@ def __init__(self, objectives):
objective.name: objective.direction for objective in self.objectives
}

def has_value(self, logs):
return all([key in logs for key in self.name_to_direction])

def get_value(self, logs):
obj_value = 0
for metric_name, metric_value in logs.items():
Expand Down
12 changes: 12 additions & 0 deletions keras_tuner/engine/objective_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def test_objective_better_than_min():
assert not obj.better_than(0, 0)


def test_objective_has_value():
obj = objective.create_objective("loss")
assert obj.has_value({"loss": 3.0})
assert not obj.has_value({"accuracy": 3.0})


def test_objective_get_value():
obj = objective.create_objective("loss")
assert obj.get_value({"accuracy": 3.0, "loss": 2.0}) == 2.0
Expand Down Expand Up @@ -97,3 +103,9 @@ def test_multi_objective_not_equal():
obj1 = objective.create_objective(["loss", "loss"])
obj2 = objective.create_objective(["loss", "accuracy"])
assert obj1 != obj2


def test_multi_objective_has_value():
obj = objective.create_objective(["loss", "accuracy"])
assert obj.has_value({"loss": 1.0, "accuracy": 1.0, "mse": 2.0})
assert not obj.has_value({"accuracy": 1.0, "mse": 2.0})
6 changes: 4 additions & 2 deletions keras_tuner/engine/tuner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,10 @@ def __init__(self, objective, filepath):
self.best_value = float("inf")

def on_epoch_end(self, epoch, logs=None):
if isinstance(self.objective, obj_module.DefaultObjective):
# Save on every epoch if no objective is specified.
if not self.objective.has_value(logs):
# Save on every epoch if metric value is not in the logs. Either no
# objective is specified, or objective is computed and returned
# after `fit()`.
self.model.save_weights(self.filepath)
return
current_value = self.objective.get_value(logs)
Expand Down

0 comments on commit bd21a5d

Please sign in to comment.