Paged Attention support for FA3
kadeng committed Oct 11, 2024
1 parent bedf877 commit a5bac6b
Showing 10 changed files with 1,414 additions and 122 deletions.
391 changes: 391 additions & 0 deletions hopper/copy_paged_sm90_tma.hpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions hopper/flash.h
@@ -98,6 +98,7 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
int page_num_blocks;

// The dropout probability (probability of keeping an activation).
float p_dropout;
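The new `page_num_blocks` field records how many physical KV blocks the paged cache holds, complementing the existing `page_block_size`. A minimal sketch of the block arithmetic these two fields describe (plain Python; the variable names are illustrative, not part of the API):

```python
import math

# Illustrative sizes only.
page_block_size = 256          # tokens per physical KV block; a multiple of 256 per the check in flash_api.cpp
seqlens_k = [1000, 17, 512]    # per-sequence key/value lengths

# Each sequence needs ceil(seqlen / page_block_size) logical blocks.
blocks_per_seq = [math.ceil(s / page_block_size) for s in seqlens_k]   # [4, 1, 2]

# The cache must provide at least this many physical blocks;
# page_num_blocks is taken from k.size(0) when the paged path is used.
page_num_blocks = sum(blocks_per_seq)                                  # 7
```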
45 changes: 39 additions & 6 deletions hopper/flash_api.cpp
@@ -52,7 +52,7 @@ void set_params_fprop(Flash_fwd_params &params,

params.is_bf16 = q.dtype() == torch::kBFloat16;
params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;

params.page_num_blocks = 0;
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
@@ -212,6 +212,7 @@ void set_params_dgrad(Flash_bwd_params &params,
params.dq_ptr = dq.data_ptr();
params.dk_ptr = dk.data_ptr();
params.dv_ptr = dv.data_ptr();
params.page_num_blocks = 0;
params.dq_row_stride = dq.stride(-3);
params.dk_row_stride = dk.stride(-3);
params.dv_row_stride = dv.stride(-3);
@@ -443,6 +444,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
int max_seqlen_q,
const int max_seqlen_k,
const float softmax_scale,
@@ -472,25 +474,46 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);

at::Tensor block_table;
const bool paged_KV = block_table_.has_value();
if (paged_KV) {
block_table = block_table_.value();
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
}

const auto sizes = q.sizes();

const int batch_size = cu_seqlens_q.numel() - 1;
int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int num_heads_k = k.size(1);
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

const int total_q = q.sizes()[0];

const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : k.size(0);
const int page_block_size = !paged_KV ? -1 : k.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");

TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

CHECK_SHAPE(q, total_q, num_heads, head_size_og);
const int total_k = k.size(0);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);

if (!paged_KV) {
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
} else {
CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
}

CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
if (seqused_q.has_value()){
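For reference, here is a small sketch of tensors that would satisfy the paged-KV checks in the hunk above. All sizes are made up; only the shapes, dtypes, and device placement mirror the `CHECK_SHAPE`/`TORCH_CHECK` calls:

```python
import torch

# Illustrative sizes only.
batch_size = 3
num_heads_k = 8
head_size = 128
page_block_size = 256              # must be a multiple of 256
max_num_blocks_per_seq = 4
num_blocks = 12                    # total physical blocks in the cache (k.size(0))

# Paged layout: K/V caches are (num_blocks, page_block_size, num_heads_k, head_size).
k_cache = torch.empty(num_blocks, page_block_size, num_heads_k, head_size,
                      dtype=torch.bfloat16, device="cuda")
v_cache = torch.empty_like(k_cache)

# block_table maps (batch index, logical block index) -> physical block index.
# It must be int32, on the same device, with a contiguous last dimension.
block_table = torch.zeros(batch_size, max_num_blocks_per_seq,
                          dtype=torch.int32, device="cuda")
```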
@@ -571,6 +594,17 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
params.total_q = total_q;
params.total_k = total_k;

if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.k_batch_stride = k.stride(0);
params.v_batch_stride = v.stride(0);
params.page_num_blocks = k.size(0);
}
params.page_block_size = page_block_size;
params.page_num_blocks = num_blocks;

//printf("mha_varlen_fwd: params.seqlen_k=%d, max_seqlen_k=%d, params.page_num_blocks=%d\n", (int)params.seqlen_k, (int)max_seqlen_k, (int)params.page_num_blocks);
if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
@@ -585,7 +619,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
if (out_.has_value()) { out_.value().copy_(out); }
}

return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
}

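The parameters set above (`block_table`, `block_table_batch_stride`, `page_block_size`, `page_num_blocks`, plus the per-block K/V batch strides) are what let the kernel translate a logical key position into a physical cache location. A rough Python rendering of that lookup, assuming the layout established by the shape checks; the actual device-side gather uses TMA and lives in `copy_paged_sm90_tma.hpp`, which is not reproduced here:

```python
def paged_kv_lookup(k_cache, block_table, batch_idx, token_idx, page_block_size):
    """Return the K vectors for one logical token position in a paged cache.

    k_cache:     (num_blocks, page_block_size, num_heads_k, head_size)
    block_table: (batch_size, max_num_blocks_per_seq), int32
    """
    logical_block = token_idx // page_block_size       # which block of the sequence
    offset_in_block = token_idx % page_block_size      # position inside that block
    physical_block = int(block_table[batch_idx, logical_block])
    return k_cache[physical_block, offset_in_block]    # (num_heads_k, head_size)
```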
67 changes: 57 additions & 10 deletions hopper/flash_attn_interface.py
@@ -2,19 +2,31 @@

from typing import Optional, Union

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
import flashattn_hopper_cuda

import torch
import torch.nn as nn

# isort: on


def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = None, descale_k = None, descale_v = None):

def _flash_attn_forward(
q,
k,
v,
softmax_scale,
causal,
window_size,
descale_q=None,
descale_k=None,
descale_v=None,
):
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
q,
@@ -45,7 +57,7 @@ def _flash_attn_backward(
softmax_scale,
causal,
window_size,
deterministic=False
deterministic=False,
):
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
@@ -67,6 +79,7 @@ def _flash_attn_backward(
)
return dq, dk, dv, softmax_d


def _flash_attn_varlen_forward(
q,
k,
@@ -77,10 +90,14 @@ def _flash_attn_varlen_forward(
max_seqlen_k,
softmax_scale,
causal,
block_table,
window_size=(-1, -1),
seqused_q=None,
seqused_k=None,
):
assert (
block_table is None or k.dtype != torch.float8_e4m3fn
), "Paged Attention / block_table is not supported for fp8 just yet"
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
@@ -92,6 +109,7 @@ def _flash_attn_varlen_forward(
cu_seqlens_k,
seqused_q,
seqused_k,
block_table,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
@@ -218,7 +236,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -238,6 +256,7 @@ def forward(
deterministic=False,
seqused_q=None,
seqused_k=None,
block_table=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -254,10 +273,18 @@ def forward(
window_size=window_size,
seqused_q=seqused_q,
seqused_k=seqused_k,
block_table=block_table,
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k,
seqused_q, seqused_k
q,
k,
v,
out_padded,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
)
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
@@ -269,7 +296,9 @@ def forward(

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = (
ctx.saved_tensors
)
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_varlen_backward(
dout,
@@ -295,7 +324,22 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
return (
dq,
dk,
dv,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)


def flash_attn_func(
@@ -389,6 +433,7 @@ def flash_attn_varlen_func(
deterministic=False,
seqused_q=None,
seqused_k=None,
block_table=None,
):
"""
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -424,6 +469,7 @@ def flash_attn_varlen_func(
query and output tokens in each sequence.
seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
key and value tokens in each sequence.
block_table: Optional block_table of dtype int32 and shape [batch_size, num_blocks_per_seq] used for paged attention.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
@@ -444,4 +490,5 @@ def flash_attn_varlen_func(
deterministic,
seqused_q,
seqused_k,
block_table,
)
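A hypothetical end-to-end call with a paged KV cache, based on the signature and docstring above. The import path, tensor sizes, and block assignment are illustrative only; note that the fp8 path is rejected when `block_table` is passed, per the assert in `_flash_attn_varlen_forward`:

```python
import torch
from flash_attn_interface import flash_attn_varlen_func  # import path may differ per install

batch_size, num_heads, num_heads_k, head_size = 2, 16, 8, 128
page_block_size, max_num_blocks_per_seq, num_blocks = 256, 4, 8

# Two query sequences of length 5 and 9; two cached key sequences of length 300 and 700.
cu_seqlens_q = torch.tensor([0, 5, 14], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 300, 1000], dtype=torch.int32, device="cuda")
max_seqlen_q, max_seqlen_k = 9, 700

q = torch.randn(14, num_heads, head_size, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_size,
                      dtype=torch.bfloat16, device="cuda")
v_cache = torch.randn_like(k_cache)

# Trivial block table: sequence b uses physical blocks [b*4, b*4 + 4).
block_table = torch.arange(batch_size * max_num_blocks_per_seq, dtype=torch.int32,
                           device="cuda").reshape(batch_size, max_num_blocks_per_seq)

out = flash_attn_varlen_func(
    q, k_cache, v_cache,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k,
    causal=True,
    block_table=block_table,
)
```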
1 change: 0 additions & 1 deletion hopper/flash_fwd_kernel.h
@@ -100,7 +100,6 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
if (warp_group_idx == 0) { // Producer
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 24 : 32>();
// cutlass::arch::warpgroup_reg_dealloc<56>();

int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
(Remaining changed files not rendered.)