
[Draft] support qk head_dim different from vo head_dim #980

Open · wants to merge 1 commit into main
Conversation

@defei-coder commented Jun 6, 2024

Support a query/key head_dim different from the value head_dim, fixing issue-753 and issue-952.
Recently, DeepSeek-V2 proposed a new attention mechanism called MLA (Multi-head Latent Attention), which uses low-rank key-value joint compression to eliminate the inference-time key-value cache bottleneck and thereby supports efficient inference. MLA uses query/key head_dim=192 and value head_dim=128, but FlashAttention does not support this combination. It can be worked around by padding the value head_dim from 128 to 192, but that increases global memory traffic and hurts performance.
To make FlashAttention more versatile, I modified the code to support this. To keep compilation time down, only this one combination is added; other combinations can be implemented by users as needed.
Compared with padding the value head_dim from 128 to 192, running natively with query/key head_dim=192 and value head_dim=128 saves global memory and improves performance (forward is about 15% faster, backward about 5% faster).
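
For illustration, here is a minimal sketch of the padding workaround that this PR makes unnecessary. The shapes are arbitrary examples, and the call pattern assumes the public flash_attn_func Python API; the direct mismatched-head_dim call at the end assumes a build that includes this PR's (192, 128) combination, and none of this is code taken from the PR itself:

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads = 2, 1024, 16
qk_dim, v_dim = 192, 128  # MLA-style head dims (illustrative sizes)

q = torch.randn(batch, seqlen, nheads, qk_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads, qk_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads, v_dim, device="cuda", dtype=torch.float16)

# Workaround on stock FlashAttention: zero-pad V's head_dim up to match Q/K,
# run the kernel at head_dim=192, then slice the output back to 128.
# The zero columns of V contribute nothing to the attention-weighted sum,
# so the result is numerically identical, but the kernel moves 192-wide
# V and O tiles through global memory instead of 128-wide ones.
v_padded = torch.nn.functional.pad(v, (0, qk_dim - v_dim))
out = flash_attn_func(q, k, v_padded, causal=True)[..., :v_dim]

# With this PR, the mismatched head_dims should be usable directly
# (hypothetical call, assuming a build compiled with the (192, 128) combo):
# out = flash_attn_func(q, k, v, causal=True)
```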

@NiuMa-1234
Hi, the latest flash-attn3 has been released and it only supports head_dim 64 and 128 for now. Do you plan to support more?

@bzantium
Supporting the combinations (256, 256) and (128, 256) would be great.

Successfully merging this pull request may close these issues.

Is it possible to relax V shape requirements to have different head dim than q/k?