Port BeamSampler to core #1181

Merged: 7 commits, Jul 31, 2023

Conversation

shivance
Collaborator

Closes #1156

@shivance
Collaborator Author

Hey @mattdangerw!
Could you tell me the alternatives for these in core?

  1. tf.nest.map_structure
  2. tf.newaxis
  3. tf.while_loop
  4. tf.ensure_shape
  5. tf.reduce_any
  6. tf.reduce_all

@abheesht17
Collaborator

  2. ops.expand_dims
  3. ops.while_loop
  4. You can probably just do x.shape == (…) or something?
  5. ops.any
  6. ops.all
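
As a rough sketch of how these replacements behave (not code from this PR), the snippet below uses NumPy as a stand-in, on the assumption that the `ops` functions of the same names follow NumPy semantics. The `while_loop` helper, the toy `prompt` array, and `end_token_id` are hypothetical and only illustrate the `cond`/`body` loop contract.

```python
import numpy as np


def while_loop(cond, body, loop_vars):
    """Hypothetical plain-Python stand-in for a cond/body loop."""
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


end_token_id = 0  # assumed end/padding id for this toy example
prompt = np.array([[5, 3, 0, 0], [7, 0, 0, 0]])

# tf.newaxis -> expand_dims: add a trailing axis explicitly.
row_ids = np.expand_dims(np.arange(prompt.shape[0]), axis=-1)  # shape (2, 1)

# tf.reduce_any -> any: does each row contain the end token?
row_done = np.any(prompt == end_token_id, axis=-1)

# tf.reduce_all -> all: have all rows finished?
all_done = bool(np.all(row_done))

# tf.ensure_shape -> a plain shape comparison.
assert prompt.shape == (2, 4)


def cond(index, total):
    # Keep looping until every position in the first row is visited.
    return index < prompt.shape[1]


def body(index, total):
    # Advance the index and accumulate the token id at that position.
    return index + 1, total + int(prompt[0, index])


# tf.while_loop -> loop over (cond, body, loop_vars).
index, total = while_loop(cond, body, (0, 0))
```

The loop variables are passed and returned as a tuple, which is the shape of the contract that matters when moving a `tf.while_loop` body over.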

@@ -186,16 +187,16 @@ def body(prompt, cache, index, log_probs):

 def gather_beams(x):
     x = unflatten_beams(x)
-    x = tf.gather(x, beam_indices, axis=1, batch_dims=1)
+    x = ops.take_along_axis(x, beam_indices, axis=1, batch_dims=1)
Member

batch_dims is not an argument to take_along_axis

Collaborator Author

@shivance shivance Jul 28, 2023

I think this line will throw a dimension mismatch between beam_indices and x.
[Update]: CI confirms this for the Torch backend.
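
One way around the missing `batch_dims` argument is to give the indices the same rank as `x` and let them broadcast. Below is a minimal NumPy sketch, assuming `ops.take_along_axis` broadcasts the way `np.take_along_axis` does. The two-argument `gather_beams` signature and the toy shapes are illustrative, not the PR's code.

```python
import numpy as np


def gather_beams(x, beam_indices):
    """Reorder the beam axis of x ([batch, beams, ...]) per batch row."""
    # Append singleton axes so the indices have the same rank as x.
    extra_dims = x.ndim - beam_indices.ndim
    idx = beam_indices.reshape(beam_indices.shape + (1,) * extra_dims)
    # The singleton axes broadcast against the trailing dims of x.
    return np.take_along_axis(x, idx, axis=1)


# Toy input: batch=2, beams=3, length=2.
x = np.arange(2 * 3 * 2).reshape(2, 3, 2)
beam_indices = np.array([[2, 0, 1], [1, 1, 0]])

out = gather_beams(x, beam_indices)

# Reference result using batched fancy indexing.
expected = x[np.arange(2)[:, None], beam_indices]
```

Here `out[b, i]` equals `x[b, beam_indices[b, i]]`, which is what `tf.gather(..., axis=1, batch_dims=1)` computes for these shapes.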

@mattdangerw
Member

For tf.ensure_shape, ideally we can ditch it.

@mattdangerw
Member

For tf.nest.map_structure, just leave as is for now. On a separate PR, we can update to the dm_nest module.

@shivance
Collaborator Author

> For tf.ensure_shape, ideally we can ditch it.

There is this comment at line 180 of the file:

# Ensure shape is `[None]`, otherwise it causes issues after
# converting to TFLite.

)
-sorted_log_probs = tf.gather(
-    all_log_probs, sorted_indices, axis=-1, batch_dims=1
+sorted_indices = ops.argsort(all_log_probs, axis=-1)
Collaborator Author

ops.argsort has no option to set the sort direction. Because of this, a couple of tests fail locally with an assertion error.

A possible solution would be to reverse sorted_indices along axis=-1?

Collaborator

ops.argsort(-all_log_probs, axis=-1) should do the trick

Collaborator Author

@shivance shivance Jul 27, 2023

Great! This fixes the failing test for tf backend

Collaborator

You need the minus sign in front so that it is sorted in descending order

Collaborator Author

Flip does the job here.

> Great! This fixes the failing test for tf backend

I was referring to ops.flip here 😂. Probably you commented at the same time.

Collaborator

:o. Missed the ops.flip line

Member

I think flipping the sign of the log probs would be more efficient; probably worth doing that and avoiding flip.
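
For comparison, here is a small NumPy sketch of the two ways to get a descending order that were discussed in this thread, assuming `ops.argsort`, `ops.flip`, and `ops.take_along_axis` behave like their NumPy namesakes. The sample values are made up for illustration.

```python
import numpy as np

# Toy log probs for one batch row (illustrative values).
all_log_probs = np.array([[-1.2, -0.3, -2.5, -0.7]])

# Option 1: sort ascending, then reverse along the last axis.
by_flip = np.flip(np.argsort(all_log_probs, axis=-1), axis=-1)

# Option 2: negate first, so a single ascending sort is descending.
by_negation = np.argsort(-all_log_probs, axis=-1)

# Gather the log probs in sorted order.
sorted_log_probs = np.take_along_axis(all_log_probs, by_negation, axis=-1)
```

Both options give the same order when the values are distinct. With ties, the order of the tied entries may differ between them. Negation skips the extra flip pass.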

@mattdangerw
Member

mattdangerw commented Jul 27, 2023

> Ensure shape is [None], otherwise it causes issues after converting to TFLite.

Yeah, I would still ditch and see what happens. This isn't a fix we should need to carry around.

@mattdangerw
Member

Worked through a lot of this while working on #1187; this ended up being quite tricky. Pushed some fixes.

@mattdangerw
Member

/gcbrun

@mattdangerw
Member

(this last failure is an OOM, unrelated I think)

@shivance
Collaborator Author

shivance commented Jul 29, 2023

/gcbrun

@mattdangerw mattdangerw merged commit c4c8c36 into keras-team:master Jul 31, 2023
8 of 9 checks passed
@shivance shivance deleted the beam-sampler branch August 5, 2023 14:19
Development

Successfully merging this pull request may close these issues.

Port BeamSampler to multi-backend keras
3 participants