Support custom cell/RNN layers with extension types #20485
Comments
Hi @Johansmm, thanks for reporting the issue. Your code needs several corrections.
Attached gist for your reference.
Hi @mehtamansi29, thanks for your response. However, I don't think I explained the purpose of the code clearly: I want to know how to write RNN models with a custom cell that works with extension-type tensors. To show my problem, I wrote the example based on:
I would like my model to be able to run inference like this:

```python
import tensorflow as tf

mt = MaskedTensor(tf.random.uniform((2, 10, 5)), tf.ones((2, 10, 5)))
model(mt)
```

Error:
I hope the purpose of my question is clearer now.
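Keras layers appear to accept only dense tensors, so one possible workaround (not native extension-type support) is to unpack the extension type at the model boundary and feed its components to the RNN. The sketch below assumes a minimal `MaskedTensor` with `values` and `mask` fields (the issue's own definition is not shown), the TensorFlow backend, and a hypothetical helper `run_masked_rnn`. `MinimalRNNCell` follows the custom-cell pattern from the Keras `RNN` documentation. The per-feature mask is collapsed to one boolean per timestep and passed through the RNN layer's `mask` argument, so masked steps carry the previous state forward.

```python
import os

os.environ.setdefault("KERAS_BACKEND", "tensorflow")  # assumption: TF backend

import keras
import tensorflow as tf


# Assumption: a minimal stand-in for the issue's MaskedTensor, whose real
# definition is not shown in the thread.
class MaskedTensor(tf.experimental.ExtensionType):
    values: tf.Tensor
    mask: tf.Tensor


class MinimalRNNCell(keras.layers.Layer):
    """Plain RNN cell: h_t = x_t @ W + h_{t-1} @ U."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            name="kernel",
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer="orthogonal",
            name="recurrent_kernel",
        )

    def call(self, inputs, states):
        prev_output = states[0] if isinstance(states, (list, tuple)) else states
        output = tf.matmul(inputs, self.kernel) + tf.matmul(
            prev_output, self.recurrent_kernel
        )
        return output, [output]


def run_masked_rnn(rnn_layer, mt):
    """Hypothetical helper: feed a MaskedTensor's dense components to an RNN."""
    # One boolean per (batch, timestep): a step counts as valid only if
    # all of its features are unmasked.
    step_mask = tf.reduce_all(tf.cast(mt.mask, tf.bool), axis=-1)
    return rnn_layer(mt.values, mask=step_mask)


mt = MaskedTensor(
    values=tf.random.uniform((2, 10, 5)), mask=tf.ones((2, 10, 5))
)
rnn = keras.layers.RNN(MinimalRNNCell(4))
y = run_masked_rnn(rnn, mt)
print(y.shape)  # (2, 4)
```

Collapsing the mask with `reduce_all` is one design choice; `reduce_any` would instead keep a timestep if any feature is valid.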
Hi @Johansmm, here is a reference on creating a custom RNN layer by subclassing, and on creating RNN models with a custom cell that works with extension-type tensors.
Attached gist for reference.
Hi @mehtamansi29, I don't think I understand your example, because you are creating a … With your example, it is not possible to run inference with `input_spec`:

```python
y = model(input_spec)
# Raises the following error:
# ValueError: Inputs to a layer should be tensors. Got 'MaskedTensor ....
```
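Judging from the error message, the layer rejects the `MaskedTensor` object itself rather than its contents. A generic, if manual, way around this for any extension type is to flatten the value into its component tensors with `tf.nest.flatten(..., expand_composites=True)` and rebuild it afterwards with `tf.nest.pack_sequence_as`. The `MaskedTensor` below is an assumed minimal definition, not the one from the issue.

```python
import tensorflow as tf


# Assumption: minimal stand-in for the issue's MaskedTensor.
class MaskedTensor(tf.experimental.ExtensionType):
    values: tf.Tensor
    mask: tf.Tensor


mt = MaskedTensor(
    values=tf.random.uniform((2, 10, 5)), mask=tf.ones((2, 10, 5))
)

# Extension types are composite tensors, so expand_composites=True flattens
# them into plain tf.Tensor components that Keras layers accept.
components = tf.nest.flatten(mt, expand_composites=True)

# Dense-only processing (for example, Keras layers) would run on `components`.

# Rebuild an object of the same extension type from the (possibly updated)
# component tensors.
rebuilt = tf.nest.pack_sequence_as(mt, components, expand_composites=True)
print(len(components), type(rebuilt).__name__)
```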
Issue type: Feature Request
Have you reproduced the bug with TensorFlow Nightly? Yes
Source: source
TensorFlow version: 2.15
Custom code: Yes
OS platform and distribution: Windows 11
Mobile device: No response
Python version: 3.11
Bazel version: No response
GCC/compiler version: No response
CUDA/cuDNN version: No response
GPU model and memory: No response
Current behavior?
I want to write a Keras-like model with keras.layers.RNN that supports Extension types, both for inputs and states.
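For states, a similar boundary workaround is possible, sketched here under assumptions: the cell declares one state tensor per extension-type field (`state_size = [units, units]`), the extension-type initial state is passed to `keras.layers.RNN` as a list of its dense fields, and the returned states are reassembled afterwards. `MaskedStateCell` and its "frozen units" semantics are hypothetical, chosen only to give the state mask a visible effect. The sketch also assumes the TensorFlow backend.

```python
import os

os.environ.setdefault("KERAS_BACKEND", "tensorflow")  # assumption: TF backend

import keras
import tensorflow as tf


# Assumption: minimal stand-in for the issue's MaskedTensor.
class MaskedTensor(tf.experimental.ExtensionType):
    values: tf.Tensor
    mask: tf.Tensor


class MaskedStateCell(keras.layers.Layer):
    """Hypothetical cell whose state is a MaskedTensor carried as two
    dense tensors: state values and a 0/1 state mask."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = [units, units]  # (values, mask)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            name="kernel",
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer="orthogonal",
            name="recurrent_kernel",
        )

    def call(self, inputs, states):
        prev_values, state_mask = states[0], states[1]
        candidate = tf.tanh(
            tf.matmul(inputs, self.kernel)
            + tf.matmul(prev_values, self.recurrent_kernel)
        )
        # Update only the state units whose mask is 1; the others stay frozen.
        new_values = tf.where(state_mask > 0, candidate, prev_values)
        return new_values, [new_values, state_mask]


batch, steps, features, units = 2, 10, 5, 4
x = tf.random.uniform((batch, steps, features))
init_state = MaskedTensor(
    values=tf.zeros((batch, units)),
    mask=tf.constant([[1.0, 1.0, 0.0, 0.0]] * batch),  # freeze the last two units
)

rnn = keras.layers.RNN(MaskedStateCell(units), return_state=True)
# Split the extension-type state into dense tensors for Keras...
output, h, m = rnn(x, initial_state=[init_state.values, init_state.mask])
# ...and reassemble it afterwards.
final_state = MaskedTensor(values=h, mask=m)
print(final_state.values.shape)  # (2, 4)
```

This keeps Keras working entirely on dense tensors. The extension type exists only outside the RNN call, which is the limitation the feature request asks to remove.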
Standalone code to reproduce the issue
Relevant log output