Skip to content

Commit

Permalink
update comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jotix16 committed Jun 29, 2021
1 parent 1cb1ae4 commit 1292e80
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,8 @@ def __init__(self, start, size, min_size=None, **kwargs):

if seq_lens is not None:
mask = tf.greater_equal(
tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None]) # (B,Tn)
mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim))) # (B,1,1,..,Tn,1)
tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None]) # (B,T1,..,Tn)
mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim))) # (B,T1,..,Tn,1,..)
slices = where_bc(mask, tf.zeros_like(slices), slices)

self.output.size_placeholder = x.size_placeholder.copy()
Expand Down

0 comments on commit 1292e80

Please sign in to comment.