Skip to content

Commit

Permalink
Fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Jul 29, 2023
1 parent c4112d7 commit 3752770
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion keras_nlp/samplers/beam_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def body(prompt, cache, index, log_probs):
next_token = flatten_beams(indices % vocab_size)
# We need `ensure_shape` as `top_k` will change the static shape.
next_log_probs = flatten_beams(next_log_probs)
# Work around tensor shape when not using XLA.
# Work around for top_k output shape on tf backend.
if isinstance(log_probs, tf.Tensor):
log_probs = tf.ensure_shape(next_log_probs, log_probs.shape)
else:
Expand Down

0 comments on commit 3752770

Please sign in to comment.