
Fix return_attention_scores bug #1977

Merged

Conversation

abheesht17 (Collaborator)

No description provided.

    query=x,
    value=x,
    attention_mask=self_attention_mask,
    return_attention_scores=True,
Member

I was looking at how return_attention_scores is used in MHA, and seeing this, I wondered whether setting return_attention_scores=True here could cause any problems. If we hardcode it to True here, the user has no way to disable it if they want to use flash attention, right?
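(For context, a minimal runnable sketch of the behavior under discussion, using keras.layers.MultiHeadAttention with toy shapes chosen purely for illustration: requesting the scores changes the call's return type and forces the layer to materialize the full per-head attention matrix, which is exactly what a fused/flash attention kernel would otherwise avoid.)

    import numpy as np
    import keras

    # Toy inputs; shapes are illustrative only.
    x = np.random.random((2, 8, 32)).astype("float32")
    mha = keras.layers.MultiHeadAttention(num_heads=4, key_dim=8)

    # Default: a single output tensor is returned.
    out = mha(query=x, value=x)
    print(out.shape)  # (2, 8, 32)

    # With scores requested, an (output, scores) tuple is returned, and the
    # full (batch, heads, query_len, key_len) score tensor must be computed,
    # so a fused/flash attention path cannot be used for this call.
    out, scores = mha(query=x, value=x, return_attention_scores=True)
    print(scores.shape)  # (2, 4, 8, 8)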

abheesht17 (Collaborator, Author)

Ah, good spot. I'll keep the if...else, then.

Member

I think we can still get rid of the if/else block by passing the flag directly to MHA (i.e., return_attention_scores=return_attention_scores); we don't have to pass the flag only when it's True. What do you think?

abheesht17 (Collaborator, Author)

Hmmm, but when return_attention_scores is True, it returns two elements, whereas when it's False, it returns just one element.

If we want to do it that way, it'd have to be something like this:

            ...
            attention_layer_output = self._self_attention_layer(
                query=x,
                value=x,
                attention_mask=self_attention_mask,
                return_attention_scores=return_attention_scores,
                training=training,
            )

            if return_attention_scores:
                x, attention_scores = attention_layer_output
            else:
                x = attention_layer_output
            ...

Member

You're right! Let's keep it the way it is then!
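
(For readers following along: "the way it is" refers to the conditional call around the diff excerpt at the top of this thread. A sketch of that pattern, with the surrounding code assumed rather than copied from the PR:)

            # Sketch only: names like self._self_attention_layer come from the
            # diff excerpt above; the exact surrounding code is an assumption.
            if return_attention_scores:
                x, attention_scores = self._self_attention_layer(
                    query=x,
                    value=x,
                    attention_mask=self_attention_mask,
                    return_attention_scores=True,
                    training=training,
                )
            else:
                x = self._self_attention_layer(
                    query=x,
                    value=x,
                    attention_mask=self_attention_mask,
                    training=training,
                )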

@SamanehSaadat (Member)

@abheesht17 I was wondering if you've checked why one of the tests is failing.

@SamanehSaadat (Member)

> @abheesht17 I was wondering if you've checked why one of the tests is failing.

I think this might be the root cause: https://github.com/keras-team/keras/pull/20448/files#r1837037902

@SamanehSaadat (Member) left a comment

Thanks, Abheesht, for catching and fixing this bug!

@SamanehSaadat (Member)

I'll merge as the test failure is not related to this PR.

@SamanehSaadat merged commit d97db05 into keras-team:master on Nov 11, 2024
6 of 7 checks passed
@abheesht17 (Collaborator, Author)

> @abheesht17 I was wondering if you've checked why one of the tests is failing.
>
> I think this might be the root cause: https://github.com/keras-team/keras/pull/20448/files#r1837037902

I'll take a look later today! Sorry for the delay.

@SamanehSaadat (Member)

> I'll take a look later today! Sorry for the delay

All good! The test failure was unrelated to this PR and it should be fixed now (keras-team/keras#20482).

@abheesht17 (Collaborator, Author)

> All good! The test failure was unrelated to this PR and it should be fixed now (keras-team/keras#20482).

Awesome!
