From 2f77c7a24d8333299dd8c38f8bb2dbfcbdee6347 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com> Date: Wed, 8 Feb 2023 18:49:52 -0800 Subject: [PATCH] Preparation for switch in Keras serialization format (#834) --- keras_tuner/engine/conditions.py | 5 ++--- .../engine/hyperparameters/__init__.py | 5 +++-- .../hyperparameters/hp_types/__init__.py | 5 +++-- keras_tuner/engine/metrics_tracking.py | 21 +++++++++++++++++-- keras_tuner/utils.py | 18 ++++++++++++++++ 5 files changed, 45 insertions(+), 9 deletions(-) diff --git a/keras_tuner/engine/conditions.py b/keras_tuner/engine/conditions.py index e95600355..aa57c857f 100644 --- a/keras_tuner/engine/conditions.py +++ b/keras_tuner/engine/conditions.py @@ -17,7 +17,6 @@ import abc import six -from tensorflow import keras from keras_tuner import utils from keras_tuner.protos import keras_tuner_pb2 @@ -145,8 +144,8 @@ def to_proto(self): def deserialize(config): - return keras.utils.deserialize_keras_object(config, module_objects=ALL_CLASSES) + return utils.deserialize_keras_object(config, module_objects=ALL_CLASSES) def serialize(obj): - return keras.utils.serialize_keras_object(obj) + return utils.serialize_keras_object(obj) diff --git a/keras_tuner/engine/hyperparameters/__init__.py b/keras_tuner/engine/hyperparameters/__init__.py index 8c4cb3041..d3f249491 100644 --- a/keras_tuner/engine/hyperparameters/__init__.py +++ b/keras_tuner/engine/hyperparameters/__init__.py @@ -14,6 +14,7 @@ from tensorflow import keras +from keras_tuner import utils from keras_tuner.engine.hyperparameters import hp_types from keras_tuner.engine.hyperparameters.hp_types import Boolean from keras_tuner.engine.hyperparameters.hp_types import Choice @@ -32,8 +33,8 @@ def deserialize(config): - return keras.utils.deserialize_keras_object(config, module_objects=ALL_CLASSES) + return utils.deserialize_keras_object(config, module_objects=ALL_CLASSES) def serialize(obj): - return keras.utils.serialize_keras_object(obj) + return utils.serialize_keras_object(obj) diff --git a/keras_tuner/engine/hyperparameters/hp_types/__init__.py b/keras_tuner/engine/hyperparameters/hp_types/__init__.py index 9914880c9..d20c7e16a 100644 --- a/keras_tuner/engine/hyperparameters/hp_types/__init__.py +++ b/keras_tuner/engine/hyperparameters/hp_types/__init__.py @@ -14,6 +14,7 @@ from tensorflow import keras +from keras_tuner import utils from keras_tuner.engine.hyperparameters.hp_types.boolean_hp import Boolean from keras_tuner.engine.hyperparameters.hp_types.choice_hp import Choice from keras_tuner.engine.hyperparameters.hp_types.fixed_hp import Fixed @@ -32,8 +33,8 @@ def deserialize(config): - return keras.utils.deserialize_keras_object(config, module_objects=ALL_CLASSES) + return utils.deserialize_keras_object(config, module_objects=ALL_CLASSES) def serialize(obj): - return keras.utils.serialize_keras_object(obj) + return utils.serialize_keras_object(obj) diff --git a/keras_tuner/engine/metrics_tracking.py b/keras_tuner/engine/metrics_tracking.py index d8cd723f9..703c2a24d 100644 --- a/keras_tuner/engine/metrics_tracking.py +++ b/keras_tuner/engine/metrics_tracking.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import numpy as np import six @@ -332,10 +333,26 @@ def infer_metric_direction(metric): return "max" try: - metric = keras.metrics.get(metric_name) + if ( + "use_legacy_format" + in inspect.getargspec(keras.metrics.deserialize).args + ): + metric = keras.metrics.deserialize( + metric_name, use_legacy_format=True + ) + else: + metric = keras.metrics.deserialize(metric_name) except ValueError: try: - metric = keras.losses.get(metric_name) + if ( + "use_legacy_format" + in inspect.getargspec(keras.losses.deserialize).args + ): + metric = keras.losses.deserialize( + metric_name, use_legacy_format=True + ) + else: + metric = keras.losses.deserialize(metric_name) except Exception: # Direction can't be inferred. return None diff --git a/keras_tuner/utils.py b/keras_tuner/utils.py index bb71655f6..be62421e3 100644 --- a/keras_tuner/utils.py +++ b/keras_tuner/utils.py @@ -64,6 +64,24 @@ def check_tf_version(): ) +def serialize_keras_object(obj): + if hasattr(tf.keras.utils, "legacy"): + return tf.keras.utils.legacy.serialize_keras_object(obj) + else: + return tf.keras.utils.serialize_keras_object(obj) + + +def deserialize_keras_object(config, module_objects=None, custom_objects=None): + if hasattr(tf.keras.utils, "legacy"): + return tf.keras.utils.legacy.deserialize_keras_object( + config, custom_objects, module_objects + ) + else: + return tf.keras.utils.deserialize_keras_object( + config, custom_objects, module_objects + ) + + def to_list(values): if isinstance(values, list): return values