Use of `tf.where` in model breaks (de)serialization #20409
Comments
The root cause is that when building from the config, Keras doesn't know the dtype of the inputs and defaults to a float dtype, so `input["mask"]` reaches `tf.where` as a non-`bool` tensor. You should either implement `build_from_config` so the sublayers are built directly from the saved shapes:

```python
import keras
import tensorflow as tf


@keras.saving.register_keras_serializable(name="SimpleModel")
class SimpleModel(keras.Model):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.flatten = keras.layers.Flatten()
        self.dense = keras.layers.Dense(2)

    def call(self, input):
        masked_image = tf.where(input["mask"], input["image"], 0)
        flat = self.flatten(masked_image)
        return {"label": self.dense(flat)}

    def build_from_config(self, config):
        # Build the sublayers from the stored input shapes instead of
        # tracing `call` with placeholder inputs of an unknown dtype.
        image_shape = config["input_shape"]["image"]
        self.flatten.build(image_shape)
        output_shape = self.flatten.compute_output_shape(image_shape)
        self.dense.build(output_shape)
```

or cast the mask to `bool` inside `call`:

```python
import keras
import tensorflow as tf


@keras.saving.register_keras_serializable(name="SimpleModel")
class SimpleModel(keras.Model):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.flatten = keras.layers.Flatten()
        self.dense = keras.layers.Dense(2)

    def call(self, input):
        # Cast explicitly so tf.where always receives a bool condition,
        # even when the model is rebuilt with float placeholder inputs.
        masked_image = tf.where(tf.cast(input["mask"], "bool"), input["image"], 0)
        flat = self.flatten(masked_image)
        return {"label": self.dense(flat)}
```
@james77777778 thanks for the suggestions. Regarding
@matemijolovic In your case, you could try building the model with the functional API, which is widely adopted in KerasHub. This way, the dtypes of the inputs can be specified, avoiding any overhead. Here's an example:

```python
import tensorflow as tf
import keras


class CustomMask(keras.layers.Layer):
    def call(self, image, mask):
        return keras.ops.where(mask, image, 0)


def create_simple_model():
    image_input = keras.layers.Input(shape=(32, 32, 3), name="image")
    # Declaring dtype="bool" here is what preserves the mask dtype on reload.
    mask_input = keras.layers.Input(
        shape=(32, 32, 3), name="mask", dtype="bool"
    )
    custom_mask = CustomMask()
    flatten = keras.layers.Flatten()
    dense = keras.layers.Dense(2)
    x = custom_mask(image_input, mask_input)
    x = flatten(x)
    label = dense(x)
    return keras.Model(
        inputs={"image": image_input, "mask": mask_input},
        outputs={"label": label},
    )


def create_dataset():
    label_index = tf.constant([0] * 32)
    image_wh = 32
    features = {
        "image": tf.random.uniform([32, image_wh, image_wh, 3], 0, 255),
        "mask": tf.random.uniform([32, image_wh, image_wh, 3], 0, 1) < 0.5,
    }
    labels = {
        "label": label_index,
    }
    return features, labels


x, y = create_dataset()
for run_eagerly in [False, True]:
    print("run_eagerly:", run_eagerly)
    simple_model = create_simple_model()
    simple_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss="sparse_categorical_crossentropy",
        run_eagerly=run_eagerly,
    )
    simple_model.fit(x, y, epochs=1, steps_per_epoch=8)
    optimizer_iterations_1 = int(simple_model.optimizer.iterations)
    simple_model.save(f"simple_model-eager-{str(run_eagerly)}.keras")
    ss_restored_model = keras.models.load_model(
        f"simple_model-eager-{str(run_eagerly)}.keras"
    )
    current_iterations = int(ss_restored_model.optimizer.iterations)
    print(
        f"actual iterations count before serialization: {optimizer_iterations_1}, "
        f"iterations after deserialization: {current_iterations}"
    )
```

Outputs:

```
run_eagerly: False
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.8529
actual iterations count before serialization: 8, iterations after deserialization: 8
run_eagerly: True
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 5.5751
actual iterations count before serialization: 8, iterations after deserialization: 8
```
Description

When the Keras model contains `tf.where` somewhere in the graph, it breaks model (de)serialization. The problem doesn't happen if `keras.ops.where` is used instead. There is no mention of this in the list of Keras 2 / Keras 3 incompatibilities.

The bug is a bit tricky because only a warning message is logged. In our case, the optimizer state wasn't restored properly during deserialization, so the optimizer settings were reset and training failed to resume correctly.
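Since the failure only surfaces as a warning, one way to make it loud is to compare the optimizer's step counter before saving and after loading. The sketch below assumes Keras 3 and the `.keras` format; `load_and_check_optimizer` and `build_model` are hypothetical helpers, not Keras APIs, and `optimizer.iterations` is only a proxy for the full optimizer state.

```python
import os
import tempfile

import numpy as np
import keras


def load_and_check_optimizer(path, expected_iterations):
    """Hypothetical helper: load a .keras file and raise if the optimizer
    step counter does not match the value recorded before saving."""
    model = keras.models.load_model(path)
    restored = int(model.optimizer.iterations)
    if restored != expected_iterations:
        raise RuntimeError(
            f"optimizer state not restored: expected {expected_iterations} "
            f"iterations, got {restored}"
        )
    return model


def build_model():
    # Hypothetical minimal functional model that round-trips cleanly.
    inputs = keras.Input(shape=(4,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mse")
    return model


x = np.random.rand(12, 4).astype("float32")
y = np.random.rand(12, 1).astype("float32")

model = build_model()
model.fit(x, y, batch_size=4, epochs=1, verbose=0)  # 12 / 4 = 3 optimizer steps
steps_before_save = int(model.optimizer.iterations)

path = os.path.join(tempfile.mkdtemp(), "sketch_model.keras")
model.save(path)
restored = load_and_check_optimizer(path, steps_before_save)
print("restored iterations:", int(restored.optimizer.iterations))
```

In a training loop, calling such a check right after `load_model` turns the silent reset described above into an immediate error instead of a quietly restarted optimizer.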
If for some reason `tf.where` is no longer supported, I think it should fail earlier, at the serialization step.

Environment
MacOS 15.0.1
Python 3.12.7
Tensorflow 2.17.0
Keras 3.6.0
Minimum reproducible example
I tried to rule out an eager/graph-mode issue, so the check runs in both modes.
Two warnings are logged:
It's not obvious how implementing `build_from_config` and `get_build_config` could help circumvent the issue. I found another relevant StackOverflow thread about `tf.where` issues in TF 2.17.0: link.