Port BeamSampler to core
#1181
Hey @mattdangerw !
keras_nlp/samplers/beam_sampler.py
Outdated
@@ -186,16 +187,16 @@ def body(prompt, cache, index, log_probs):

def gather_beams(x):
x = unflatten_beams(x)
x = tf.gather(x, beam_indices, axis=1, batch_dims=1)
x = ops.take_along_axis(x, beam_indices, axis=1, batch_dims=1)
batch_dims is not an argument to take_along_axis.
I think this line will throw a dimension mismatch between beam_indices and x.
[Update]: CI reflects this for the Torch backend.
For |
For |
There is this comment at line no. 180 in the file
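For illustration, here is a minimal NumPy sketch of how the old tf.gather(..., axis=1, batch_dims=1) behaviour could be reproduced without batch_dims. It assumes ops.take_along_axis follows NumPy's take_along_axis semantics (indices must have the same rank as x), and the two-argument gather_beams helper is hypothetical, not the code that was eventually pushed to this PR.

```python
import numpy as np


def gather_beams(x, beam_indices):
    """Pick beams per batch row, like tf.gather(x, idx, axis=1, batch_dims=1).

    x:            (batch, beams, ...) array
    beam_indices: (batch, new_beams) integer array
    """
    # take_along_axis needs indices with the same rank as x, so append
    # singleton axes for every trailing dimension of x.
    extra_dims = x.ndim - beam_indices.ndim
    idx = beam_indices.reshape(beam_indices.shape + (1,) * extra_dims)
    # Broadcast the indices explicitly over the trailing dimensions.
    idx = np.broadcast_to(idx, beam_indices.shape + x.shape[2:])
    return np.take_along_axis(x, idx, axis=1)


x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
beam_indices = np.array([[2, 0, 0], [1, 1, 2]])
out = gather_beams(x, beam_indices)

# Reference result: plain per-row indexing.
expected = np.stack([x[b][beam_indices[b]] for b in range(x.shape[0])])
assert np.array_equal(out, expected)
```

Passing beam_indices directly, without the reshape, is what produces the rank mismatch mentioned above.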
keras_nlp/samplers/beam_sampler.py
Outdated
)
sorted_log_probs = tf.gather(
all_log_probs, sorted_indices, axis=-1, batch_dims=1
sorted_indices = ops.argsort(all_log_probs, axis=-1)
We don't have an option to set the sort direction in argsort. Because of this, a couple of tests fail locally with an assertion error.
A possible solution could be to reverse the sorted_indices array along axis=-1?
ops.argsort(-all_log_probs, axis=-1) should do the trick
Great! This fixes the failing test for tf backend
You need the minus sign in front so that it is sorted in descending order
Flip does the job here.
> Great! This fixes the failing test for tf backend
I was referring to ops.flip here 😂. You probably commented at the same time.
:o. Missed the ops.flip line.
I think flipping the sign of the log probs would be more efficient; it's probably worth doing that and avoiding flip.
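To make the two options in this thread concrete, here is a small NumPy sketch. It assumes ops.argsort and ops.flip behave like their NumPy counterparts, and the sample probabilities are made up for illustration. With no ties, both approaches give the same descending order; negating the input just avoids the extra reversal.

```python
import numpy as np

# Hypothetical log probabilities: 2 batch rows, 3 candidates each.
all_log_probs = np.log(np.array([[0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]))

# Option 1: negate the values so an ascending argsort yields descending order.
desc_by_negation = np.argsort(-all_log_probs, axis=-1)

# Option 2: ascending argsort, then reverse the indices along the last axis.
desc_by_flip = np.flip(np.argsort(all_log_probs, axis=-1), axis=-1)

# Gather the log probs in sorted order (indices already match the rank of x).
sorted_log_probs = np.take_along_axis(all_log_probs, desc_by_negation, axis=-1)

assert np.array_equal(desc_by_negation, desc_by_flip)
```

When values tie, the two options can order the tied indices differently, so they are not guaranteed to be interchangeable in every test.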
Yeah, I would still ditch and see what happens. This isn't a fix we should need to carry around.
Worked through a lot of this while working through #1187; this ended up being quite tricky. Pushed some fixes.
/gcbrun
(this last failure is an OOM, unrelated I think)
/gcbrun
Closes #1156