Fix beam sampler
mattdangerw committed Jul 29, 2023
1 parent 657cc4f commit c4112d7
Showing 1 changed file with 37 additions and 36 deletions.
73 changes: 37 additions & 36 deletions keras_nlp/samplers/beam_sampler.py
@@ -14,15 +14,13 @@
"""Beam Sampler."""

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.utils.python_utils import format_docstring
from keras_nlp.utils.tensor_utils import tensor_to_list


@format_docstring(call_args=call_args_docstring)
@@ -53,7 +51,7 @@ class BeamSampler(Sampler):
batch_size, length, vocab_size = 1, 12, len(int_lookup)
def next(prompt, cache, index):
prompt_batch_size = ops.shape(prompt)[0]
prompt_batch_size = tf.shape(prompt)[0]
hidden_states = np.ones((prompt_batch_size, 10))
# A uniform distribution over our alphabet.
logits = np.ones((prompt_batch_size, vocab_size))
@@ -76,7 +74,7 @@ def next(prompt, cache, index):
batch_size, length, vocab_size = 1, 8, len(int_lookup)
def next(prompt, cache, index):
prompt_batch_size = ops.shape(prompt)[0]
prompt_batch_size = tf.shape(prompt)[0]
hidden_states = np.ones((prompt_batch_size, 10))
# A uniform distribution over our alphabet.
logits = np.ones((batch_size, vocab_size))
@@ -118,39 +116,37 @@ def __call__(
hidden_states=None,
):
batch_size, max_length = ops.shape(prompt)[0], ops.shape(prompt)[1]
# Make sure max length and start index are the same dtype.
# index = ops.cast(index, max_length.dtype)
index = ops.cast(index, "int32")

def create_beams(x):
"""Add initial beam state."""
return ops.repeat(x, self.num_beams, axis=0)

def flatten_beams(x):
"""Combine the beam dim and batch dim."""
flat_shape = [batch_size * self.num_beams] + tensor_to_list(
x.shape
)[2:]
return ops.reshape(x, new_shape=flat_shape)
flat_shape = (batch_size * self.num_beams,) + tuple(x.shape)[2:]
return ops.reshape(x, flat_shape)

def unflatten_beams(x):
"""Separate the beam dim and batch dim."""
unflat_shape = [batch_size, self.num_beams] + tensor_to_list(
x.shape
)[1:]
return ops.reshape(x, new_shape=unflat_shape)
unflat_shape = (batch_size, self.num_beams) + tuple(x.shape)[1:]
return ops.reshape(x, unflat_shape)

if mask is None:
mask = ops.zeros_like(prompt, dtype="bool")
else:
mask = ops.cast(mask, dtype="bool")
# `ops.while_loop` will not accept `None` as a value for `loop_vars`.
cache = () if cache is None else cache
has_cache = cache is not None
cache = cache if has_cache else ()
# Add extra sequences for each beam.
prompt, mask = create_beams(prompt), create_beams(mask)
cache = tf.nest.map_structure(create_beams, cache)
# Setup the initial beam log-likelihoods.
# On the first loop, make sure only the original beam is considered.
log_probs = [[0.0] + [-1e9] * (self.num_beams - 1)]
log_probs = ops.array(
[[0.0] + [-1e9] * (self.num_beams - 1)], dtype="float32"
)
log_probs = flatten_beams(ops.repeat(log_probs, batch_size, axis=0))

def cond(prompt, cache, index, log_probs):
@@ -159,7 +155,7 @@ def cond(prompt, cache, index, log_probs):
# Stop if all sequences have produced a *new* end_token_id.
end_tokens = (prompt == end_token_id) & (~mask)
prompt_done = ops.any(end_tokens, axis=-1)
return not ops.all(prompt_done)
return ops.logical_not(ops.all(prompt_done))

def body(prompt, cache, index, log_probs):
# Compute the softmax distribution for the next token.
@@ -168,45 +164,47 @@ def body(prompt, cache, index, log_probs):
probs = keras.activations.softmax(logits / self.temperature)

# Compute the running log-likelihood of each new candidate.
next_log_probs = ops.log(probs) + log_probs[..., tf.newaxis]
next_log_probs = ops.log(probs) + log_probs[..., None]
# Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`.
next_log_probs = ops.reshape(
next_log_probs, new_shape=[batch_size, -1]
)
next_log_probs = ops.reshape(next_log_probs, [batch_size, -1])

# Compute the top beam indices and next tokens.
next_log_probs, indices = ops.top_k(
next_log_probs, k=self.num_beams, sorted=False
)
beam_indices = indices // vocab_size
next_token = flatten_beams(indices % vocab_size)
# Ensure shape is `[None]`, otherwise it causes issues after
# converting to TFLite.
# next_token = tf.ensure_shape(next_token, [None])
# We need `ensure_shape` as `top_k` will change the static shape.
next_log_probs = flatten_beams(next_log_probs)
# log_probs = tf.ensure_shape(next_log_probs, log_probs.shape)
# Work around tensor shape when not using XLA.
if isinstance(log_probs, tf.Tensor):
log_probs = tf.ensure_shape(next_log_probs, log_probs.shape)
else:
log_probs = next_log_probs

def gather_beams(x):
x = unflatten_beams(x)
print(x.shape, beam_indices.shape)
x = ops.take_along_axis(x, beam_indices, axis=1)
indices = beam_indices
for axis in range(2, len(x.shape)):
indices = ops.expand_dims(indices, axis=axis)
x = ops.take_along_axis(x, indices, axis=1)
return flatten_beams(x)

prompt = gather_beams(prompt)
cache = tf.nest.map_structure(gather_beams, cache)
if has_cache:
cache = tf.nest.map_structure(gather_beams, cache)

# Update each beam with the next token.
next_token = ops.cast(next_token, prompt.dtype)
# Don't overwrite anywhere mask is True.
next_token = ops.where(mask[:, index], prompt[:, index], next_token)
# Update the prompt with the next token.
next_token = next_token[:, tf.newaxis]
prompt = dynamic_update_slice(prompt, next_token, [0, index])
next_token = next_token[:, None]
prompt = ops.slice_update(prompt, [0, index], next_token)
# Return the iteration of the loop state.
return (prompt, cache, index + 1, log_probs)

prompt, _, _, log_probs = ops.while_loop(
prompt, _, _, log_probs = self.run_loop(
cond=cond,
body=body,
loop_vars=(prompt, cache, index, log_probs),
@@ -217,18 +215,21 @@ def gather_beams(x):
all_log_probs = unflatten_beams(log_probs)

if self.return_all_beams:
sorted_indices = ops.argsort(all_log_probs, axis=-1)
sorted_indices = ops.flip(sorted_indices, axis=-1)
sorted_indices = ops.argsort(-all_log_probs, axis=-1)
sorted_log_probs = ops.take_along_axis(
all_log_probs, sorted_indices, axis=-1
all_log_probs,
sorted_indices,
axis=1,
)
sorted_prompts = ops.take_along_axis(
all_prompts, sorted_indices, axis=1
all_prompts,
ops.expand_dims(sorted_indices, -1),
axis=1,
)
return sorted_prompts, sorted_log_probs
else:
# Gather the top beam at each batch index.
top_beams = ops.argmax(all_log_probs, axis=-1)[:, tf.newaxis]
top_beams = ops.argmax(all_log_probs, axis=-1)[:, None, None]
prompt = ops.take_along_axis(all_prompts, top_beams, axis=1)
return ops.squeeze(prompt, axis=1)
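
The heart of the loop body above is the candidate-selection step: the per-beam log-likelihoods are flattened to shape (batch_size, num_beams * vocab_size), the top num_beams entries are kept, and integer division and modulo recover which beam each winner extends and which token it appends. Below is a minimal numpy sketch of that idea, not the library code; the names and sizes are invented, and argsort stands in for top_k.

import numpy as np

batch_size, num_beams, vocab_size = 2, 3, 5
# Fake running log-likelihoods for every (batch, beam, token) candidate.
log_probs = np.log(np.random.rand(batch_size, num_beams, vocab_size))

# Flatten the beam and vocab axes so top-k ranks all candidates per batch.
flat = log_probs.reshape(batch_size, num_beams * vocab_size)
top_indices = np.argsort(-flat, axis=-1)[:, :num_beams]
beam_indices = top_indices // vocab_size  # which beam each winner extends
next_tokens = top_indices % vocab_size    # which token each winner appends
print(beam_indices.shape, next_tokens.shape)  # (2, 3) (2, 3)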

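The reworked gather_beams also shows why the index tensor has to match the rank of the data it gathers: take_along_axis (in numpy semantics, which the keras ops mirror) expects indices with the same number of dimensions as the input, so the beam indices receive one trailing axis per remaining dimension before the gather on axis 1. A small numpy illustration of the same convention, with shapes invented for the example:

import numpy as np

batch_size, num_beams, length = 2, 3, 4
prompt = np.arange(batch_size * num_beams * length).reshape(
    batch_size, num_beams, length
)
beam_indices = np.array([[2, 0, 0], [1, 2, 2]])  # beams to keep, per batch row

indices = beam_indices
for axis in range(2, prompt.ndim):
    indices = np.expand_dims(indices, axis=axis)  # (2, 3) -> (2, 3, 1)
# Row k of each batch element becomes a copy of beam beam_indices[b, k].
reordered = np.take_along_axis(prompt, indices, axis=1)
print(reordered.shape)  # (2, 3, 4)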

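For context, the docstring excerpt near the top of the diff sets up a toy alphabet and a next function; driving the sampler end to end looks roughly like the sketch below. This is an illustrative reconstruction in the spirit of that example, not the commit's exact docstring, and it assumes the TensorFlow backend so that tf.shape and iterating over the output tensor behave as shown.

import numpy as np
import tensorflow as tf
import keras_nlp

# Toy alphabet of lowercase characters with ids in [0, 25].
int_lookup = {i: chr(i + ord("a")) for i in range(26)}
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 12, len(int_lookup)

def next(prompt, cache, index):
    prompt_batch_size = tf.shape(prompt)[0]
    hidden_states = np.ones((prompt_batch_size, 10))
    # A uniform distribution over the alphabet.
    logits = np.ones((prompt_batch_size, vocab_size))
    return logits, hidden_states, cache

output = keras_nlp.samplers.BeamSampler()(
    next=next,
    prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"),
    index=5,
)
print(["".join([int_lookup[int(i)] for i in s]) for s in output])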