Skip to content

Commit

Permalink
Preparation for switch in Keras serialization format (#834)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkovela1 authored Feb 9, 2023
1 parent 8a0f9a7 commit 2f77c7a
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 9 deletions.
5 changes: 2 additions & 3 deletions keras_tuner/engine/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import abc

import six
from tensorflow import keras

from keras_tuner import utils
from keras_tuner.protos import keras_tuner_pb2
Expand Down Expand Up @@ -145,8 +144,8 @@ def to_proto(self):


def deserialize(config):
return keras.utils.deserialize_keras_object(config, module_objects=ALL_CLASSES)
return utils.deserialize_keras_object(config, module_objects=ALL_CLASSES)


def serialize(obj):
return keras.utils.serialize_keras_object(obj)
return utils.serialize_keras_object(obj)
5 changes: 3 additions & 2 deletions keras_tuner/engine/hyperparameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from tensorflow import keras

from keras_tuner import utils
from keras_tuner.engine.hyperparameters import hp_types
from keras_tuner.engine.hyperparameters.hp_types import Boolean
from keras_tuner.engine.hyperparameters.hp_types import Choice
Expand All @@ -32,8 +33,8 @@


def deserialize(config):
return keras.utils.deserialize_keras_object(config, module_objects=ALL_CLASSES)
return utils.deserialize_keras_object(config, module_objects=ALL_CLASSES)


def serialize(obj):
return keras.utils.serialize_keras_object(obj)
return utils.serialize_keras_object(obj)
5 changes: 3 additions & 2 deletions keras_tuner/engine/hyperparameters/hp_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from tensorflow import keras

from keras_tuner import utils
from keras_tuner.engine.hyperparameters.hp_types.boolean_hp import Boolean
from keras_tuner.engine.hyperparameters.hp_types.choice_hp import Choice
from keras_tuner.engine.hyperparameters.hp_types.fixed_hp import Fixed
Expand All @@ -32,8 +33,8 @@


def deserialize(config):
return keras.utils.deserialize_keras_object(config, module_objects=ALL_CLASSES)
return utils.deserialize_keras_object(config, module_objects=ALL_CLASSES)


def serialize(obj):
return keras.utils.serialize_keras_object(obj)
return utils.serialize_keras_object(obj)
21 changes: 19 additions & 2 deletions keras_tuner/engine/metrics_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import numpy as np
import six
Expand Down Expand Up @@ -332,10 +333,26 @@ def infer_metric_direction(metric):
return "max"

try:
metric = keras.metrics.get(metric_name)
if (
"use_legacy_format"
in inspect.getargspec(keras.metrics.deserialize).args
):
metric = keras.metrics.deserialize(
metric_name, use_legacy_format=True
)
else:
metric = keras.metrics.deserialize(metric_name)
except ValueError:
try:
metric = keras.losses.get(metric_name)
if (
"use_legacy_format"
in inspect.getargspec(keras.losses.deserialize).args
):
metric = keras.losses.deserialize(
metric_name, use_legacy_format=True
)
else:
metric = keras.losses.deserialize(metric_name)
except Exception:
# Direction can't be inferred.
return None
Expand Down
18 changes: 18 additions & 0 deletions keras_tuner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@ def check_tf_version():
)


def serialize_keras_object(obj):
if hasattr(tf.keras.utils, "legacy"):
return tf.keras.utils.legacy.serialize_keras_object(obj)
else:
return tf.keras.utils.serialize_keras_object(obj)


def deserialize_keras_object(config, module_objects=None, custom_objects=None):
if hasattr(tf.keras.utils, "legacy"):
return tf.keras.utils.legacy.deserialize_keras_object(
config, custom_objects, module_objects
)
else:
return tf.keras.utils.deserialize_keras_object(
config, custom_objects, module_objects
)


def to_list(values):
if isinstance(values, list):
return values
Expand Down

0 comments on commit 2f77c7a

Please sign in to comment.