From 3752770b23e79bc17e2bf8f9024dbceae921465e Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Sat, 29 Jul 2023 01:02:43 -0700 Subject: [PATCH] Fix comment --- keras_nlp/samplers/beam_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 8d1e94ccfe..6a076f3582 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -176,7 +176,7 @@ def body(prompt, cache, index, log_probs): next_token = flatten_beams(indices % vocab_size) # We need `ensure_shape` as `top_k` will change the static shape. next_log_probs = flatten_beams(next_log_probs) - # Work around tensor shape when not using XLA. + # Work around for top_k output shape on tf backend. if isinstance(log_probs, tf.Tensor): log_probs = tf.ensure_shape(next_log_probs, log_probs.shape) else: