Use of tf.where in model breaks (de)serialization #20409

Open
matemijolovic opened this issue Oct 25, 2024 · 3 comments

@matemijolovic

matemijolovic commented Oct 25, 2024

Description

When a Keras model contains tf.where anywhere in its graph, model (de)serialization breaks. The problem does not occur if keras.ops.where is used instead. This is not mentioned in the list of Keras 2 vs. Keras 3 incompatibilities.

The bug is tricky because only a warning is logged. In our case, the stateful optimizer's state was not restored on deserialization, so the optimizer settings were reset and training did not resume correctly.

If tf.where is no longer supported for some reason, I think it should fail earlier, at the serialization step.

Environment

macOS 15.0.1
Python 3.12.7
Tensorflow 2.17.0
Keras 3.6.0

Minimum reproducible example

To rule out an eager vs. graph mode issue, the check runs in both modes.

import keras
import tensorflow as tf

@keras.saving.register_keras_serializable(name='SimpleModel')
class SimpleModel(keras.Model):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.flatten = keras.layers.Flatten()
        self.dense = keras.layers.Dense(2)

    def call(self, input):
        masked_image = tf.where(input["mask"], input["image"], 0)
        # masked_image = keras.ops.where(input["mask"], input["image"], 0)  # replace the line above with this one to avoid the issue
        flat = self.flatten(masked_image)
        return {"label": self.dense(flat)}

def create_dataset():
    label_index = tf.constant([0] * 32)
    image_wh = 32
    features = {
        "image": tf.random.uniform([32, image_wh, image_wh, 3], 0, 255),
        "mask": tf.random.uniform([32, image_wh, image_wh, 3], 0, 1) < 0.5,
    }
    labels = {
        "label": label_index,
    }
    return features, labels

x, y = create_dataset()

for run_eagerly in [False, True]:
    print("run_eagerly:", run_eagerly)
    simple_model = SimpleModel()
    simple_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss='sparse_categorical_crossentropy',
        run_eagerly=run_eagerly,
    )
    simple_model.fit(x, y, epochs=1, steps_per_epoch=8)
    optimizer_iterations_1 = int(simple_model.optimizer.iterations)
    simple_model.save(f'simple_model-eager-{str(run_eagerly)}.keras')
    ss_restored_model = keras.models.load_model(f'simple_model-eager-{str(run_eagerly)}.keras')
    current_iterations = int(ss_restored_model.optimizer.iterations)
    print(
        f"actual iterations count before serialization: {optimizer_iterations_1}, "
        f"iterations after deserialization: {current_iterations}"
    )

Two warnings are logged:

UserWarning: Model 'simple_model' had a build config, but the model cannot be built automatically in `build_from_config(config)`. You should implement `def build_from_config(self, config)`, and you might also want to implement the method  that generates the config at saving time, `def get_build_config(self)`. The method `build_from_config()` is meant to create the state of the model (i.e. its variables) upon deserialization.
  instance.build_from_config(build_config)
UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 6 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
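The counts in the second warning are consistent with the restored model never having been built. A back-of-envelope sketch, assuming Keras 3's Adam tracks two scalar variables of its own (the iteration counter and the learning rate) plus one first-moment and one second-moment slot per trainable variable when `amsgrad=False` (the default): the trained model's Dense(2) layer has a kernel and a bias, so its optimizer holds 2 + 2 × 2 = 6 variables, while an unbuilt model exposes no trainable variables and its optimizer stays at 2. The helper names `dense_after_flatten_shapes` and `adam_variable_count` are hypothetical.

```python
from math import prod


def dense_after_flatten_shapes(image_shape, units):
    """Variable shapes of Dense(units) applied to a flattened input.

    image_shape excludes the batch dimension, e.g. (32, 32, 3).
    """
    flat_dim = prod(image_shape)  # Flatten: 32 * 32 * 3 = 3072
    return {"kernel": (flat_dim, units), "bias": (units,)}


def adam_variable_count(num_trainable_variables, amsgrad=False):
    """Assumed Adam bookkeeping: 2 scalars plus per-variable moment slots."""
    scalars = 2  # iteration counter and learning rate (assumption)
    slots = 3 if amsgrad else 2  # momentum, velocity (+ velocity_hat if amsgrad)
    return scalars + slots * num_trainable_variables


shapes = dense_after_flatten_shapes((32, 32, 3), units=2)
print(shapes)                            # {'kernel': (3072, 2), 'bias': (2,)}
print(adam_variable_count(len(shapes)))  # 6: optimizer of the trained (built) model
print(adam_variable_count(0))            # 2: optimizer of the restored, unbuilt model
```

Under these assumptions, the optimizer's saved slots have nothing to attach to, which is why its variables (including `iterations`) are skipped on load.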

It's not obvious how implementing build_from_config and get_build_config would help circumvent the issue.

I also found a relevant StackOverflow thread about tf.where issues in TF 2.17.0: link.
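Until Keras fails loudly here, one user-side way to make a bad round trip stop immediately is to escalate these two warnings to exceptions with Python's standard `warnings` module. A minimal sketch, assuming Keras emits them through `warnings.warn` (the `UserWarning:` prefix in the log suggests it does) and that their wording stays as quoted above; `load_strict` and `FATAL_LOAD_PATTERNS` are hypothetical names.

```python
import warnings

# Regexes matched against the *start* of each warning message
# (warnings.filterwarnings uses re.match). Wording copied from the log above.
FATAL_LOAD_PATTERNS = (
    r"Model '.*' had a build config, but the model cannot be built",
    r"Skipping variable loading for optimizer",
)


def load_strict(load_fn, *args, **kwargs):
    """Call load_fn(*args, **kwargs), raising instead of warning on a bad restore."""
    with warnings.catch_warnings():  # filters are restored on exit
        for pattern in FATAL_LOAD_PATTERNS:
            # "error" turns a matching UserWarning into a raised exception.
            warnings.filterwarnings("error", message=pattern, category=UserWarning)
        return load_fn(*args, **kwargs)


# Usage (hypothetical, with the model from the example above):
# restored = load_strict(keras.models.load_model, "simple_model-eager-False.keras")
```

Unrelated warnings are left alone, so only the two symptoms of a silently incomplete restore become hard errors.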

@james77777778
Contributor

The root cause is that when building the model from its config, Keras doesn't know the dtype of the inputs and defaults to keras.config.floatx() (float32 by default). tf.where requires a boolean condition, so it fails on the float mask.

You should either modify build_from_config or add tf.cast to input["mask"].

@keras.saving.register_keras_serializable(name="SimpleModel")
class SimpleModel(keras.Model):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.flatten = keras.layers.Flatten()
        self.dense = keras.layers.Dense(2)

    def call(self, input):
        masked_image = tf.where(input["mask"], input["image"], 0)
        flat = self.flatten(masked_image)
        return {"label": self.dense(flat)}

    def build_from_config(self, config):
        image_shape = config["input_shape"]["image"]
        self.flatten.build(image_shape)
        output_shape = self.flatten.compute_output_shape(image_shape)
        self.dense.build(output_shape)


# Alternatively, keep the default build and cast the mask to bool in call():
@keras.saving.register_keras_serializable(name="SimpleModel")
class SimpleModel(keras.Model):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.flatten = keras.layers.Flatten()
        self.dense = keras.layers.Dense(2)

    def call(self, input):
        masked_image = tf.where(tf.cast(input["mask"], "bool"), input["image"], 0)
        flat = self.flatten(masked_image)
        return {"label": self.dense(flat)}

@matemijolovic
Author

matemijolovic commented Oct 28, 2024

@james77777778 thanks for the suggestions.
Do you know whether tf.cast incurs any runtime penalty when used this way, or whether it acts more like a type specification (similar to keras.layers.Input in functional models)?

Regarding build_from_config and the warning messages in general, I think the problem is too hard to debug for anyone unfamiliar with the low-level details of Keras serialization. A clearer error or warning message would help.

@james77777778
Contributor

@matemijolovic
There might be a slight runtime penalty since your masks will be converted to float32 and then back to bool.
I also agree with your comments about the warning message.
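Whatever the runtime cost turns out to be, a bool to float32 and back to bool detour does not change the mask, because a cast to bool maps every nonzero value to True. A NumPy sketch, used here only as a stand-in for tf.cast (assuming both apply the same nonzero-is-True rule), with a mask of the same shape as the "mask" feature in create_dataset(); it also shows that a float32 copy of the mask takes four times the memory of the bool one.

```python
import numpy as np

rng = np.random.default_rng(0)
# Same shape as the "mask" feature in create_dataset(): [batch, h, w, channels]
mask = rng.uniform(size=(32, 32, 32, 3)) < 0.5

as_float = mask.astype(np.float32)  # True -> 1.0, False -> 0.0
round_trip = as_float.astype(bool)  # nonzero -> True

print(np.array_equal(mask, round_trip))  # True
print(mask.nbytes, as_float.nbytes)      # 98304 393216
```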

In your case, you could try building the model with the functional API, which is widely used in KerasHub. This way, the dtype of each input can be specified, avoiding any overhead.

Here's an example:

import tensorflow as tf

import keras


class CustomMask(keras.layers.Layer):
    def call(self, image, mask):
        return keras.ops.where(mask, image, 0)


def create_simple_model():
    image_input = keras.layers.Input(shape=(32, 32, 3), name="image")
    mask_input = keras.layers.Input(
        shape=(32, 32, 3), name="mask", dtype="bool"
    )
    custom_mask = CustomMask()
    flatten = keras.layers.Flatten()
    dense = keras.layers.Dense(2)
    x = custom_mask(image_input, mask_input)
    x = flatten(x)
    label = dense(x)
    return keras.Model(
        inputs={"image": image_input, "mask": mask_input},
        outputs={"label": label},
    )


def create_dataset():
    label_index = tf.constant([0] * 32)
    image_wh = 32
    features = {
        "image": tf.random.uniform([32, image_wh, image_wh, 3], 0, 255),
        "mask": tf.random.uniform([32, image_wh, image_wh, 3], 0, 1) < 0.5,
    }
    labels = {
        "label": label_index,
    }
    return features, labels


x, y = create_dataset()

for run_eagerly in [False, True]:
    print("run_eagerly:", run_eagerly)
    simple_model = create_simple_model()
    simple_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss="sparse_categorical_crossentropy",
        run_eagerly=run_eagerly,
    )
    simple_model.fit(x, y, epochs=1, steps_per_epoch=8)
    optimizer_iterations_1 = int(simple_model.optimizer.iterations)
    simple_model.save(f"simple_model-eager-{str(run_eagerly)}.keras")
    ss_restored_model = keras.models.load_model(
        f"simple_model-eager-{str(run_eagerly)}.keras"
    )
    current_iterations = int(ss_restored_model.optimizer.iterations)
    print(
        f"actual iterations count before serialization: {optimizer_iterations_1}, "
        f"iterations after deserialization: {current_iterations}"
    )

Outputs:

run_eagerly: False
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.8529 
actual iterations count before serialization: 8, iterations after deserialization: 8
run_eagerly: True
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 5.5751  
actual iterations count before serialization: 8, iterations after deserialization: 8
