Skip to content

Commit

Permalink
DimensionTag get_for_batch size identity fix control flow ctx
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Sep 2, 2021
1 parent b738ab8 commit bbe57d1
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,11 @@ def get_for_batch(self, batch):
# when there are different beams with same beam size!
# This breaks the current logic in get_tag_from_size_tensor.
# As a workaround, we make an explicit new tensor here.
from .basic import get_valid_scope_name_from_str
dyn_size_ext.placeholder = tf.identity(
dyn_size_ext.placeholder,
name=get_valid_scope_name_from_str("%s_size_beam_%s" % (dyn_size_ext.name, batch.beam.name)))
from .basic import get_valid_scope_name_from_str, same_control_flow_ctx
with same_control_flow_ctx(dyn_size_ext.placeholder):
dyn_size_ext.placeholder = tf.identity(
dyn_size_ext.placeholder,
name=get_valid_scope_name_from_str("%s_identity_for_beam_%s" % (dyn_size_ext.name, batch.beam.name)))
dyn_size_ext.placeholder._RETURNN_dyn_size_beam = batch.beam
dyn_size_ext.placeholder._RETURNN_beam_expanded_base_data = beam_expanded_base_data
else:
Expand Down

0 comments on commit bbe57d1

Please sign in to comment.