Fix `return_attention_scores` bug #1977
Conversation
```python
query=x,
value=x,
attention_mask=self_attention_mask,
return_attention_scores=True,
```
I was looking to see how `return_attention_scores` is used in MHA, and looking at this, I was wondering whether setting `return_attention_scores=True` here could cause any problems. If we set it to `True` here, the user doesn't have a way to disable it if they want to use flash attention, right?
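For context, a minimal standalone sketch of how `keras.layers.MultiHeadAttention` behaves with this flag (the shapes and layer sizes here are illustrative, not from the PR's model):

```python
import numpy as np
import keras

# A small MHA layer; num_heads/key_dim are arbitrary illustrative values.
mha = keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
x = np.random.rand(1, 8, 16).astype("float32")

# With return_attention_scores=True the layer returns a tuple
# (output, attention_scores), where attention_scores has shape
# (batch, num_heads, query_len, key_len) -> here (1, 2, 8, 8).
out, scores = mha(query=x, value=x, return_attention_scores=True)

# With the default (False) only the attention output is returned,
# which leaves the layer free to take fused/flash attention paths.
out_only = mha(query=x, value=x)
```

This tuple-vs-single return is why the flag changes how the caller has to unpack the result.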
Ah, good spot. I'll keep the `if...else`, then.
I think we can still get rid of the if-else block by directly passing the flag to MHA (i.e., `return_attention_scores=return_attention_scores`). We don't have to pass the flag only when it's `True`. WDYT?
Hmmm, but when the flag is `True`, the layer returns two elements, whereas when it's `False`, it returns just one. If we want to do it that way, it'd have to be something like this:
```python
...
attention_layer_output = self._self_attention_layer(
    query=x,
    value=x,
    attention_mask=self_attention_mask,
    return_attention_scores=return_attention_scores,
    training=training,
)
if return_attention_scores:
    x, attention_scores = attention_layer_output
else:
    x = attention_layer_output
...
```
You're right! Let's keep it the way it is then!
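The approach they keep, branching on the flag so the tuple return is only unpacked when requested, can be sketched standalone like this (a plain `MultiHeadAttention` stands in for the model's self-attention layer; all names and shapes here are illustrative):

```python
import numpy as np
import keras

# Stand-in for the model's self-attention layer.
mha = keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
x = np.random.rand(1, 8, 16).astype("float32")
return_attention_scores = True  # flag threaded through from the caller

if return_attention_scores:
    # Tuple return: unpack the scores alongside the output.
    x_out, attention_scores = mha(
        query=x, value=x, return_attention_scores=True
    )
else:
    # Single return; MHA can still use fused/flash attention here.
    x_out = mha(query=x, value=x)
```

Keeping the branch means the `False` path never asks for scores at all, which is what preserves the flash-attention option discussed above.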
@abheesht17 I was wondering if you've checked why one of the tests is failing.
I think this might be the root cause: https://github.com/keras-team/keras/pull/20448/files#r1837037902
Thanks, Abheesht, for catching and fixing this bug!
I'll merge, as the test failure is not related to this PR.
I'll take a look later today! Sorry for the delay.
All good! The test failure was unrelated to this PR and it should be fixed now (keras-team/keras#20482).
Awesome!