Support page kvcache in AMD ROCm (Dao-AILab#1198)
* Integrate ck branch of ck_tile/fa_bwd_opt

* Assume dq and q share the same stride

* update ck

* Integrate more stride of dq_acc

* Revert fwd dropout

* Fix parameter order

* Integrate ck with more stride

* update the limit of hdim of bwd

* Check argument

* Add test_flash_attn_causal

* Support unpad lse

* Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic

* Fix stride and Kn0

* Fix CK sync issue

* Fix typo

* Update CK for changing of fmha_fwd_args

* Add kvcache tmp

* Add kvcache

* Fix comment

* Sync behavior with ck

* Update CK to develop

* remove large test case

* Add kvcache test

* Fix page_block_size in arg

* Minor fix

* Fix stride error

* Update seqlen of kvcache before splitkv

* Fix compile error

* Fix bug when hdim is not a multiple of 8

* Fit ck arg

* support adaptive num_splits

* add more tests

* Refine test tolerance

* update CK

* Move override_num_splits_if_necessary into cpp

* update ck

* Update ck

* Support different flags for different versions of HIP

* remove coerce-illegal, because this is not required in FA

* Update ck to fix scratch memory

* Add coerce-illegal in some versions

* Add compile flag for rtn rounding

* remove redundant init

* Using env var to switch rounding mode

* update ck
rocking5566 authored Sep 16, 2024
1 parent cc1690d commit e2182cc
Showing 11 changed files with 1,749 additions and 131 deletions.
2 changes: 1 addition & 1 deletion csrc/composable_kernel
Submodule composable_kernel updated 386 files
41 changes: 32 additions & 9 deletions csrc/flash_attn_ck/flash_api.cpp
@@ -20,16 +20,16 @@ mha_fwd(at::Tensor &q,
c10::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<const at::Tensor> &leftpad_k_, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
@@ -89,11 +89,34 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);

std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
c10::optional<const at::Tensor> &leftpad_k_, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
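
A note for readers on the new paged-KV arguments above: when block_table_ is supplied, kcache/vcache are laid out as num_blocks x page_block_size x num_heads_k x head_size, and each row of block_table maps a sequence's logical token positions to physical cache blocks. The sketch below is illustrative only (it is not part of this commit, and the helper name and parameter choices are assumptions); it shows how a logical position resolves to an element offset under that layout.

#include <cstdint>

// Illustrative sketch (not from this commit): resolve a logical token position
// in a paged KV cache to the element offset of (head, dim) inside kcache,
// assuming the num_blocks x page_block_size x num_heads_k x head_size layout
// described in the comments above.
int64_t paged_kv_offset(const int32_t* block_table_row, // one row of block_table, i.e. block_table[b]
                        int64_t page_block_size,
                        int64_t num_heads_k,
                        int64_t head_size,
                        int64_t pos,   // logical position within the sequence
                        int64_t head,  // KV head index
                        int64_t dim)   // index within head_size
{
    const int64_t block    = block_table_row[pos / page_block_size]; // physical block id
    const int64_t in_block = pos % page_block_size;                  // row inside that block
    return ((block * page_block_size + in_block) * num_heads_k + head) * head_size + dim;
}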
34 changes: 34 additions & 0 deletions csrc/flash_attn_ck/flash_common.cpp
@@ -0,0 +1,34 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#include "flash_common.hpp"

namespace flash {
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
return num_splits;

hipDeviceProp_t props{};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
return num_splits;

// TODO - tile size should match the TileFmhaShape, hardcode for now
const int kM0 = 128;
const int kN1 = hdim_v;

const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;

if(num_splits < 1 && p_drop == 0.0f)
return num_splits_heuristic_ck(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);

return num_splits;
}

} // namespace flash
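
A note on the tile math above: with the hardcoded kM0 = 128, a decode-style call (max_seqlen_q = 1) yields a single M block per (batch, head) pair, so small batches leave most compute units idle, and that is when split-KV pays off. The standalone sketch below (workload and device sizes are assumed, not taken from this commit) walks through the occupancy check; the heuristic itself is only consulted when the caller passes num_splits < 1 and dropout is disabled, otherwise the requested value is returned unchanged.

#include <cstdio>

// Standalone sketch of the occupancy check performed above (assumes the same
// hardcoded kM0 = 128 and the multiProcessorCount * 2 estimate; the workload
// and CU count are made up for illustration).
int main() {
    const int batch = 2, nhead = 8, max_seqlen_q = 1;  // decode-style workload
    const int num_CUs = 104;                           // assumed device CU count

    const int kM0 = 128;
    const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;  // = 1
    const int work_items   = batch * nhead * num_m_blocks;    // = 16

    // Mirrors the "batch_nheads_mblocks >= 0.8f * num_SMs" early-out in
    // num_splits_heuristic_ck: if the grid already nearly fills the device,
    // a single split is used.
    const bool underfilled = work_items < 0.8f * (num_CUs * 2);
    std::printf("work items = %d, grid underfilled: %s\n",
                work_items, underfilled ? "yes" : "no");
    return 0;
}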
40 changes: 39 additions & 1 deletion csrc/flash_attn_ck/flash_common.hpp
@@ -24,7 +24,7 @@
namespace flash {
// Copy from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
if (arg.captured_) {
// static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
@@ -35,4 +35,42 @@ static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
}
}

inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
// If we have enough to almost fill the SMs, then just use 1 split
if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (!is_split_eligible(num_splits)) {
efficiency.push_back(0.f);
} else {
float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if (eff > max_efficiency) { max_efficiency = eff; }
efficiency.push_back(eff);
}
}
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (!is_split_eligible(num_splits)) { continue; }
if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}

int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);

} // namespace flash
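
To make the eligibility and efficiency rules above concrete, the standalone mirror below traces num_splits_heuristic_ck for one assumed workload (16 work items from batch * nhead * M blocks, num_SMs = 208, num_n_blocks = 16, max_splits = 128); the numbers are illustrative, and the logic is duplicated only so the snippet compiles without the CK/HIP headers. For these inputs the 85%-of-best rule settles on 8 splits.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Standalone mirror of num_splits_heuristic_ck() above, with made-up inputs,
// so the eligibility and efficiency rules can be traced end to end.
static int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int batch_nheads_mblocks = 16;  // e.g. batch 2 * 8 heads * 1 M block (decode)
    const int num_SMs = 208;              // multiProcessorCount * 2, as passed above (assumed)
    const int num_n_blocks = 16;          // assumed number of K blocks
    int max_splits = 128;

    if (batch_nheads_mblocks >= 0.8f * num_SMs) { std::puts("chosen num_splits = 1"); return 0; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});

    std::vector<float> eff(max_splits + 1, 0.f);
    float max_eff = 0.f;
    for (int s = 1; s <= max_splits; ++s) {
        // Ineligible split counts (same blocks-per-split as s - 1) keep efficiency 0.
        if (s != 1 && ceildiv(num_n_blocks, s) == ceildiv(num_n_blocks, s - 1)) continue;
        const float waves = float(batch_nheads_mblocks * s) / num_SMs;
        eff[s] = waves / std::ceil(waves);          // fraction of the last wave that is full
        max_eff = std::max(max_eff, eff[s]);
    }
    for (int s = 1; s <= max_splits; ++s) {
        if (eff[s] >= 0.85f * max_eff) {            // first split count within 85% of the best
            std::printf("chosen num_splits = %d (efficiency %.2f)\n", s, eff[s]);
            return 0;
        }
    }
    return 0;
}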
103 changes: 71 additions & 32 deletions csrc/flash_attn_ck/mha_bwd.cpp
@@ -11,7 +11,8 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool enable_alibi)
bool enable_alibi,
bool deterministic)
{
return fmha_bwd_traits{head_size,
head_size,
@@ -20,7 +21,9 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
has_dropout};
has_dropout,
false, // s_randval
deterministic};
}

fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
@@ -39,6 +42,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,
at::Tensor dq_acc,
at::Tensor d,
at::Tensor dq,
at::Tensor dk,
@@ -49,41 +53,57 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
uint64_t drop_offset)
{
// q: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_q = q.stride(0);
ck_tile::index_t stride_q = q.stride(1);
ck_tile::index_t nhead_stride_q = q.stride(2);

// k: (batch_size, seqlen_k, nheads_k, hdim)
ck_tile::index_t batch_stride_k = k.stride(0);
ck_tile::index_t stride_k = k.stride(1);
ck_tile::index_t nhead_stride_k = k.stride(2);

// v: (batch_size, seqlen_k, nheads_k, hdim)
ck_tile::index_t batch_stride_v = v.stride(0);
ck_tile::index_t stride_v = v.stride(1);
ck_tile::index_t nhead_stride_v = v.stride(2);

// o: (batch_size, seqlen_q, nheads, hdim)
// dq: (batch_size, seqlen_q, nheads, hdim)
// dk_expanded: (batch_size, seqlen_k, nheads, hdim)
// dv_expanded: (batch_size, seqlen_k, nheads, hdim)
// do: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_o = out.stride(0);
ck_tile::index_t stride_o = out.stride(1);
ck_tile::index_t nhead_stride_o = out.stride(2);

// alibi_slopes:(batch_size, nheads) or (nhead)
// lse: (batch_size, nheads, seqlen_q)
// d: (batch_size, nheads, seqlen_q)
ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);

ck_tile::index_t stride_q = q.stride(1);
ck_tile::index_t stride_k = k.stride(1);
ck_tile::index_t stride_v = v.stride(1);
ck_tile::index_t stride_o = out.stride(1);
// do: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_do = dout.stride(0);
ck_tile::index_t stride_do = dout.stride(1);
ck_tile::index_t stride_dk = dk.stride(1);
ck_tile::index_t stride_dv = dv.stride(1);

ck_tile::index_t nhead_stride_q = q.stride(2);
ck_tile::index_t nhead_stride_k = k.stride(2);
ck_tile::index_t nhead_stride_v = v.stride(2);
ck_tile::index_t nhead_stride_o = out.stride(2);
ck_tile::index_t nhead_stride_do = dout.stride(2);
ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);

ck_tile::index_t batch_stride_q = q.stride(0);
ck_tile::index_t batch_stride_k = k.stride(0);
ck_tile::index_t batch_stride_v = v.stride(0);
ck_tile::index_t batch_stride_o = out.stride(0);
ck_tile::index_t batch_stride_do = dout.stride(0);
ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
// d: (batch_size, nheads, seqlen_q)
// CK assume d share the same stride with lse

// dq: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_dq = dq.stride(0);
ck_tile::index_t stride_dq = dq.stride(1);
ck_tile::index_t nhead_stride_dq = dq.stride(2);

// dk_expanded: (batch_size, seqlen_k, nheads, hdim)
ck_tile::index_t batch_stride_dk = dk.stride(0);
ck_tile::index_t stride_dk = dk.stride(1);
ck_tile::index_t nhead_stride_dk = dk.stride(2);

// dv_expanded: (batch_size, seqlen_k, nheads, hdim)
ck_tile::index_t batch_stride_dv = dv.stride(0);
ck_tile::index_t stride_dv = dv.stride(1);
ck_tile::index_t nhead_stride_dv = dv.stride(2);

// dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);

float p_undrop = 1.0 - p_dropout;

@@ -96,6 +116,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
// alibi_slopes:(batch_size, nheads) or (nhead)
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}

@@ -112,6 +133,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
dq_acc.data_ptr(), // dq_acc
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqlen_k_ptr
@@ -132,6 +154,8 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
stride_o,
0, // stride_randval
stride_do,
stride_dq_acc,
stride_dq,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
@@ -143,6 +167,10 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
nhead_stride_dq_acc,
nhead_stride_dq,
nhead_stride_dk,
nhead_stride_dv,
0, // nhead_stride_dbias, FA without dbias
batch_stride_q,
batch_stride_k,
@@ -152,15 +180,17 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
batch_stride_dq_acc,
batch_stride_dq,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
split_stride_dq_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
p_undrop,
false, // s_randval
{drop_seed, drop_offset}};
}

@@ -224,7 +254,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
TORCH_CHECK(head_size_8x <= 256, "CK FlashAttention backward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
@@ -296,7 +326,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num

auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
// TODO - CK does not support dq_accum
at::Tensor dq_accum;

if (!deterministic) {
dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
} else {
const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
}

at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
@@ -326,10 +364,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num

if (seqlen_q > 0) {
ck_tile::stream_config stream_config{stream};
dq.zero_(); // ck use atomic operation on dq

auto traits =
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);

auto args =
get_ck_fmha_bwd_args(
@@ -347,6 +384,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
out,
softmax_lse,
dout_padded,
dq_accum,
softmax_d,
dq,
dk_expanded,
@@ -356,7 +394,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
drop_seed,
drop_offset);

fmha_bwd(traits, args, stream_config);
float t = fmha_bwd(traits, args, stream_config);
TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_();
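
One more note on the dq_accum buffer introduced above: the non-deterministic path keeps a single fp32 accumulator that CK updates with atomics, while the deterministic path allocates one slice per K tile (kN0 = 128 for head sizes up to 128, otherwise 64) so the reduction order is fixed. The standalone sketch below shows the sizing arithmetic with assumed shapes (none of the values are taken from this commit).

#include <cstdint>
#include <cstdio>

// Standalone sketch (assumed shapes) of the dq_accum sizing logic above.
int main() {
    const int64_t batch_size = 4, seqlen_q = 1024, seqlen_k = 4096;
    const int64_t num_heads = 16, head_size_8x = 128;
    const bool deterministic = true;

    const int64_t kN0     = head_size_8x <= 128 ? 128 : 64;             // K-tile length
    const int64_t nsplits = deterministic ? (seqlen_k + kN0 - 1) / kN0  // one slice per K tile
                                          : 1;                          // single atomic buffer
    const int64_t elems = nsplits * batch_size * seqlen_q * num_heads * head_size_8x;
    std::printf("dq_accum: %lld split(s), %.1f MiB of fp32\n",
                (long long)nsplits, elems * 4.0 / (1 << 20));
    return 0;
}

For these assumed shapes the deterministic buffer is 32 times larger than the single-split one, which is the usual memory cost of pinning down the reduction order.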