Add masked LSTM support #2030

Open · wants to merge 4 commits into base: main
159 changes: 159 additions & 0 deletions tests/test_lstm.py
@@ -817,5 +817,164 @@ def func(x):
return tf.identity(y[0], name="output")
self.run_test_case(func, {"input:0": x_val}, [], ["output:0"], rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_embedding_unidirectional(self):
for go_backwards in [True, False]:
timesteps = 4
# Note: masked LSTM only supports post-padded input after conversion
# test case sequence_lens = [4, 2, 0]
x_val = np.array([
[1, 2, 3, 4],
[5, 6, 0, 0],
[0, 0, 0, 0]
], dtype=np.int32)

model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
x_embedding = tf.keras.layers.Embedding(
input_dim=10,
output_dim=5,
mask_zero=True,
embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41),
)(model_in)

# RNN layer inherits the mask propagated from above embedding layer
model_out = tf.keras.layers.LSTM(
units=5,
go_backwards=go_backwards,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)(x_embedding)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
# skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
return (tf.identity(y[1], name="output_yh"),
tf.identity(y[2], name="output_yc"))

output_list = ["output_yh:0", "output_yc:0"]
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_embedding_bidirectional(self):
timesteps = 4
# Note: masked LSTM only supports post-padded input after conversion
# test case sequence_lens = [4, 2, 0]
x_val = np.array([
[1, 2, 3, 4],
[5, 6, 0, 0],
[0, 0, 0, 0]
], dtype=np.int32)

model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
x_embedding = tf.keras.layers.Embedding(
input_dim=10,
output_dim=5,
mask_zero=True,
embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41),
)(model_in)

# RNN layer inherits the mask propagated from above embedding layer
lstm_layer = tf.keras.layers.LSTM(
units=5,
go_backwards=False,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)
model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_embedding)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
# skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
return (tf.identity(y[1], name="output_yh_f"),
tf.identity(y[2], name="output_yc_f"),
tf.identity(y[3], name="output_yh_r"),
tf.identity(y[4], name="output_yc_r"))

output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"]
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06,
require_lstm_count=2)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_unidirectional(self):
for go_backwards in [True, False]:
batch_size, timesteps, feat = 3, 4, 5
in_shape = (timesteps, feat)
x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
# Note: masked LSTM only supports post-padded input after conversion
# test case sequence_lens = [4, 2, 0]
x_val[1, 2:, :] = 0.
x_val[2, :, :] = 0.

model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32")
x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in)

# RNN layer inherits the mask propagated from above mask layer
model_out = tf.keras.layers.LSTM(
units=5,
go_backwards=go_backwards,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)(x_masked)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
# skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
return (tf.identity(y[1], name="output_yh"),
tf.identity(y[2], name="output_yc"))

output_list = ["output_yh:0", "output_yc:0"]
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_bidirectional(self):
batch_size, timesteps, feat = 3, 4, 5
in_shape = (timesteps, feat)
x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
# Note: masked LSTM only supports post-padded input after conversion
# test case sequence_lens = [4, 2, 0]
x_val[1, 2:, :] = 0.
x_val[2, :, :] = 0.

model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32")
x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in)

# RNN layer inherits the mask propagated from above mask layer
lstm_layer = tf.keras.layers.LSTM(
units=5,
go_backwards=False,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)
model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_masked)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
# skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
return (tf.identity(y[1], name="output_yh_f"),
tf.identity(y[2], name="output_yc_f"),
tf.identity(y[3], name="output_yh_r"),
tf.identity(y[4], name="output_yc_r"))

output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"]
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06,
require_lstm_count=2)


if __name__ == '__main__':
unittest_main()
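
The new tests all assume post-padded input, so the Keras mask collapses to a per-batch length vector. As a quick illustration (plain NumPy, not part of the patch) of how the sequence_lens = [4, 2, 0] quoted in the test comments follows from the padded float input:

import numpy as np

# Rebuild the float test input: post-padded, expected sequence_lens = [4, 2, 0].
batch_size, timesteps, feat = 3, 4, 5
x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
x_val[1, 2:, :] = 0.   # batch 1: only the first 2 timesteps are valid
x_val[2, :, :] = 0.    # batch 2: fully padded

# Masking(mask_value=0.) treats a timestep as valid when any feature differs from 0.
mask = np.any(x_val != 0., axis=-1)                   # (batch_size, timesteps), bool
sequence_lens = mask.sum(axis=-1).astype(np.int32)    # -> array([4, 2, 0], dtype=int32)
print(sequence_lens)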
27 changes: 18 additions & 9 deletions tf2onnx/onnx_opset/tensor.py
@@ -2300,15 +2300,24 @@ def version_10(cls, ctx, node, **kwargs):
const_axis_name = utils.make_name(f'const_{axis}')
const_axis = ctx.make_const(name=const_axis_name, np_val=np.array([axis], dtype=np.int64))

# Add a Constant node (seq_len) for ReverseSequence.
# Index 1 for the shape should not return 0, since rank(input) >=2
input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
op_name_scope=rv2_node_name)
axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
op_name_scope=rv2_node_name)
seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
inputs.append(seq_array.output[0])
# Add sequence_lens as ReverseSequence input
has_sequence_lens = node.get_attr_value("has_sequence_lens", False)
if not has_sequence_lens:
# Add a Constant node (seq_len) for ReverseSequence.
# Index 1 for the shape should not return 0, since rank(input) >=2
input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
op_name_scope=rv2_node_name)
axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
op_name_scope=rv2_node_name)
seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
inputs.append(seq_array.output[0])
else:
# masked backward LSTM:
# sequence_lens is appended to ReverseV2's input by lstm_tf2_rewriter
# to keep tensor post-padded after reverse
seq_lens_casted = ctx.make_node("Cast", [node.input[-1]], attr={'to': TensorProto.INT64}).output[0]
inputs.append(seq_lens_casted)

# Add a ReverseSequence node.

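To make the intent of the new else-branch concrete: a plain ReverseV2 flips the whole time axis and moves padding to the front, while ReverseSequence driven by sequence_lens reverses only the valid prefix of each batch entry, so the tensor stays post-padded. A small NumPy sketch of the difference (illustrative only, batch-major for readability; not part of the patch):

import numpy as np

x = np.array([[1, 2, 3, 4],
              [5, 6, 0, 0]])        # post-padded input, sequence_lens = [4, 2]
seq_lens = [4, 2]

full_reverse = x[:, ::-1]           # ReverseV2-style: [[4, 3, 2, 1], [0, 0, 6, 5]] -- padding leaks to the front

def reverse_sequence(batch, lens):
    # ReverseSequence-style: reverse only the first lens[i] steps of row i.
    out = batch.copy()
    for i, l in enumerate(lens):
        out[i, :l] = batch[i, :l][::-1]
    return out

per_seq_reverse = reverse_sequence(x, seq_lens)   # [[4, 3, 2, 1], [6, 5, 0, 0]] -- stays post-padded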
99 changes: 83 additions & 16 deletions tf2onnx/rewriter/lstm_tf2_rewriter.py
@@ -4,8 +4,10 @@
"""
tf2onnx.rewriter.lstm_tf2_rewriter - Rewrites LSTM pattern used by tf2.
"""

import logging
import numpy as np
from onnx import onnx_pb

from tf2onnx.graph_matcher import GraphMatcher
from tf2onnx.rewriter.rnn_utils import make_lstm_pattern
from tf2onnx.tf_loader import find_function
@@ -79,21 +81,35 @@ def rewriter_lstm_tf2(g, ops):
# extract output h_t
ht_mul = match_result.get_op("ht")
final_consumers = g.find_output_consumers(ht_mul.output[0])
select_ops = [n for n in final_consumers if n.type == "Select"]
select_ops = [n for n in final_consumers if n.type in ("Select", "SelectV2")]
def has_tensor_list_consumer(n):
return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))
select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]

# extract sequence length
seq_len_idx, mask_idx = None, None
if len(select_ops) == 1:
greater_eq = select_ops[0].inputs[0]
if greater_eq.type != "GreaterEqual":
continue
seq_len = greater_eq.inputs[1]
if not seq_len.is_graph_input():
select_op_condition = select_ops[0].inputs[0]
while select_op_condition.type == "Identity":
select_op_condition = select_op_condition.inputs[0]

# skip timesteps based on specific sequence length
if select_op_condition.type == "GreaterEqual":
seq_len = select_op_condition.inputs[1]
if not seq_len.is_graph_input():
continue
seq_len_idx = g.input_names.index(seq_len.output[0])

# masked LSTM: skip timesteps based on dynamically-computed boolean mask tensor
elif select_op_condition.type == "TensorListGetItem":
mask = select_op_condition.inputs[0]
if not mask.is_graph_input():
continue
mask_idx = g.input_names.index(mask.output[0])
else:
continue
seq_len_idx = g.input_names.index(seq_len.output[0])

final_consumers = g.find_output_consumers(select_ops[0].output[0])
else:
seq_len_idx = None

tensor_set_items = [n for n in final_consumers if n.type == "TensorListSetItem"]
if len(tensor_set_items) != 1:
@@ -209,6 +225,7 @@ def has_tensor_list_consumer(n):
# Keras
"w_idx": gk_idx,
"r_idx": hk_idx,
"mask_idx": mask_idx,
}

for op in ops:
@@ -276,15 +293,63 @@ def has_tensor_list_consumer(n):
tensor_array_inp = op.inputs[body_context["x_idx"]]
if not tensor_array_inp.type == "TensorListFromTensor":
continue
context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]

final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]
# parse sequence length
seq_len_idx = body_context["seq_len_idx"]
mask_idx = body_context["mask_idx"]
if seq_len_idx:
context.onnx_input_ids[0]["sequence_lens"] = op.input[seq_len_idx]
elif mask_idx:
logging.warning(
"Found mask-enabled LSTM. Converted ONNX model will only support post-padded LSTM input. "
"If input is pre- or randomly-padded, masked timesteps will not be correctly skipped.")

# derive sequence_lens from the boolean mask tensor
tensor_array_mask = op.inputs[body_context["mask_idx"]]
if not tensor_array_mask.type == "TensorListFromTensor":
continue
mask_mat = tensor_array_mask.input[0]
mask_mat_node = g.get_node_by_output(mask_mat)
is_mask_reverse = mask_mat_node.type == "ReverseV2"
# no need to reverse the mask sequence
# the positions of skipped timesteps per batch are irrelevant assuming post-padded input
if is_mask_reverse:
mask_mat = mask_mat_node.input[0]

# reduce mask tensor to sequence_lens assuming post-padded input
# transpose (1,0,2) -> boolean mask tensor (N, timesteps, 1)
# squeeze on dim(-1) -> boolean mask matrix (N, timesteps)
# reduceSum on dim(-1) -> sequence_lens (N)
mask_transpose_node = g.make_node(op_type="Transpose", inputs=[mask_mat], attr={"perm": [1, 0, 2]})
mask_squeeze = GraphBuilder(g).make_squeeze({"data": mask_transpose_node.output[0], "axes": [-1]})
mask_cast_node = g.make_node(op_type="Cast", inputs=[mask_squeeze],
attr={"to": onnx_pb.TensorProto.INT32})
sequence_lens = GraphBuilder(g).make_reduce_sum({"data": mask_cast_node.output[0],
"axes": [-1], "keepdims": 0})
context.onnx_input_ids[0]["sequence_lens"] = sequence_lens

# handle backward LSTM
tensor_array_inp_producer = tensor_array_inp.inputs[0]
is_input_reverse = tensor_array_inp_producer.type == "ReverseV2"
# backward LSTM is identified by the reverses of both input and mask tensors pre-LSTM
if is_mask_reverse != is_input_reverse:
continue
if is_input_reverse:
# TF uses simple "ReverseV2" to reverse input tensor with no assumption on padding position
# because reversed mask with shape (batch_size, timesteps) is explicit per-timestep.
# ONNX requires "ReverseSequence" to keep the reversed input tensor post-padded because mask
# is implied by sequence_lens. This requires passing sequence_lens to such "ReverseSequence" op.

# Note: tensor op conversions run after rewriters. Appending sequence_lens as a "ReverseV2" input
# signals alternative behavior in "ReverseV2" conversion in onnx_opset/tensor.py.
tensor_array_inp_producer.set_attr("has_sequence_lens", True)
inp_reverse_inputs = tensor_array_inp_producer.input
inp_reverse_inputs.append(sequence_lens)

context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]
if body_context["seq_len_idx"] is None:
context.onnx_input_ids[0]["sequence_lens"] = ""
else:
context.onnx_input_ids[0]["sequence_lens"] = op.input[body_context["seq_len_idx"]]
context.onnx_input_ids[0]["sequence_lens"] = ""

context.onnx_input_ids[0]["initial_c"] = initial_c
context.onnx_input_ids[0]["initial_h"] = initial_h

Expand All @@ -295,6 +360,8 @@ def has_tensor_list_consumer(n):
lstm_node = lstm_rewriter.create_rnn_node(context)[0]

squeeze_output = GraphBuilder(g).make_squeeze({"data": lstm_node.output[0], "axes": [1]})
final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]
for output in output_ys:
g.replace_all_inputs(output, squeeze_output)
