Use Strides

This is a combination of 11 commits.

use strides in bwd

add layout test in forward

fix shape layout function

smaller tests

save

fix varlen error

no headsize passed to bwd

deal with varlen layout

save

save

save

save
micmelesse committed Oct 16, 2024
1 parent ce80a7e commit a168999
Showing 9 changed files with 586 additions and 439 deletions.
262 changes: 186 additions & 76 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -3,11 +3,12 @@
import triton.language as tl

from .bwd_ref import attention_backward_pytorch_ref_impl
from .utils import get_shape_from_layout, get_strides_from_layout

DEBUG = False

@triton.jit
def _bwd_preprocess_use_o(
def _bwd_preprocess_use_o_old(
Out,
DO,
Delta,
@@ -36,6 +37,51 @@ def _bwd_preprocess_use_o(
tl.store(Delta + off_m, delta)



@triton.jit
def _bwd_preprocess_use_o(
Out,
DO,
Delta,
stride_oz, stride_oh, stride_om, stride_ok,
stride_doz, stride_doh, stride_dom, stride_dok,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
N_CTX_Q: tl.constexpr,
Z: tl.constexpr,
H: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_bh = tl.program_id(1)

# Compute batch and head indices
batch_idx = pid_bh // H
head_idx = pid_bh % H

off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_d = tl.arange(0, BLOCK_DMODEL)

# create masks
mask_m = off_m < N_CTX_Q
mask_d = off_d < ACTUAL_BLOCK_DMODEL

# compute pointers using strides
out_ptrs = Out + batch_idx * stride_oz + head_idx * stride_oh + off_m[:, None] * stride_om + off_d[None, :] * stride_ok
do_ptrs = DO + batch_idx * stride_doz + head_idx * stride_doh + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok

# load
o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)

# compute delta
delta = tl.sum(o * do, axis=1)

# write-back delta
delta_ptrs = Delta + pid_bh * N_CTX_Q + off_m
tl.store(delta_ptrs, delta, mask=mask_m)
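# Hedged reference sketch (editorial illustration, not part of this commit): per (batch, head)
# slice, the kernel above computes delta[m] = sum_d o[m, d] * do[m, d], i.e. in PyTorch roughly
#
#   delta_ref = (o.to(torch.float32) * do.to(torch.float32)).sum(dim=-1)  # shape (Z, H, N_CTX_Q) for bhsd
#
# which can be used to sanity check the stride-based pointer arithmetic.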


@triton.jit
def _bwd_preprocess_use_p(
Q, # Pointer to queries
@@ -385,7 +431,7 @@ def _bwd_kernel(
else:
dq_offset = DQ + off_z * stride_qz + off_h * stride_qh

# inner loop
# inner loop
if SEQUENCE_PARALLEL:
_bwd_kernel_one_col_block(
Q,
@@ -494,11 +540,31 @@ def _bwd_kernel(
USE_EXP2=USE_EXP2,
)

# NOTE: smaller blocks have lower accuracy, probably due to more accumulation error. 128 * 128 seems good but leads to OOM; 64 * 64 has accumulation errors but no OOM.
def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, dk, dv, sm_scale, head_size, alibi_slopes, causal, layout, use_exp2, bwd_preprocessing_use_o, BLOCK_M=64, BLOCK_N=64):

DEBUG_INPUT=False

# NOTE: smaller blocks have lower accuracy, probably due to more accumulation error. 128 * 128 seems good but leads to OOM; 64 * 64 has accumulation errors but no OOM.
def attention_prefill_backward_triton_new_impl(
do,
q,
k,
v,
o,
softmax_lse,
dq,
dk,
dv,
sm_scale: float,
alibi_slopes,
causal,
layout: str,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q: int,
max_seqlen_k: int,
use_exp2: bool,
bwd_preprocessing_use_o: bool,
BLOCK_M=64,
BLOCK_N=64,
):
if DEBUG:
print()
print("attention_prefill_backward_triton_new_impl")
@@ -512,43 +578,44 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq,
print("dk:", dk, dk.shape if dk is not None else None)
print("dv:", dv, dv.shape if dv is not None else None)
print("sm_scale:", sm_scale)
print("head_size:", head_size)
print("alibi_slopes:", alibi_slopes)
print("layout:", layout)
print("use_exp2:", use_exp2)
print("bwd_preprocessing_use_o:", bwd_preprocessing_use_o)
print("BLOCK_M:", BLOCK_M)
print("BLOCK_N:", BLOCK_N)

# the kernel wants bhsd
if layout == "bshd":
print("Changing layout to bhsd!")
do = do.transpose(1, 2).contiguous()
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
o = o.transpose(1, 2).contiguous()
# TODO: does L/M need to be transposed? it may be possible to use strides instead
elif layout == "bhsd":
pass
# make contiguous
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
softmax_lse = softmax_lse.contiguous()

# get strides and shape
if True:
batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
stride_qz, stride_qh, stride_qm, stride_qk = q_strides
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
stride_dq_all = q.numel()
batch_headsize = batch * nheads_q
else:
raise ValueError(f"Unknown layout {layout}")
batch_q, heads_q, seqlen_q, head_size_q = q.shape
batch_k, heads_k, seqlen_k, head_size_k = k.shape
batch_headsize = batch_q * heads_q
stride_dq_all = dq.numel()
stride_qz, stride_qh, stride_qm, stride_qk = q.stride(0), q.stride(1), q.stride(2), q.stride(3)
stride_kz, stride_kh, stride_kn, stride_kk = k.stride(0), k.stride(1), k.stride(2), k.stride(3)
stride_vz, stride_vh, stride_vn, stride_vk = v.stride(0), v.stride(1), v.stride(2), v.stride(3)
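# Hedged sketch of the helper (assumption; the real get_strides_from_layout lives in .utils and
# is not shown in this diff): for the non-varlen layouts it is expected to return per-tensor
# strides in (z, h, seq, head_dim) order, roughly
#
#   def get_strides_from_layout(q, k, v, o, layout):
#       if layout == "bhsd":
#           dims = (0, 1, 2, 3)
#       elif layout == "bshd":
#           dims = (0, 2, 1, 3)  # reorder b, s, h, d -> z, h, seq, head_dim
#       else:
#           raise ValueError(f"Unknown layout {layout}")
#       pick = lambda t: tuple(t.stride(d) for d in dims)
#       return pick(q), pick(k), pick(v), pick(o)
#
# so the kernels can index bshd tensors directly instead of transposing to bhsd first.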

sequence_parallel = False
causal = False

batch_q, heads_q, N_CTX_Q, head_size_q = q.shape
batch_k, heads_k, N_CTX_K, head_size_k = k.shape

assert (batch_q == batch_k)
assert (heads_q == heads_k) # just for now
assert (head_size_q == head_size_k == head_size)

batch = batch_q

# divide up the problem
num_blocks_m = triton.cdiv(N_CTX_Q, BLOCK_M)
num_blocks_n = triton.cdiv(N_CTX_K, BLOCK_N)
num_blocks_m = triton.cdiv(seqlen_q, BLOCK_M)
num_blocks_n = triton.cdiv(seqlen_k, BLOCK_N)

# get closest power of 2 over or equal to 32.
padded_d_model = 1 << (head_size - 1).bit_length()
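# Hedged worked example (editorial, not part of this commit): (head_size - 1).bit_length()
# rounds head_size up to the next power of two, e.g.
#   head_size = 64 -> 1 << 6 = 64
#   head_size = 72 -> 1 << 7 = 128
# per the comment above, the elided code presumably clamps the result to at least 32.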
@@ -563,9 +630,13 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq,
new_dq_shape = (replicas,) + q.shape
if dq is None:
dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
else:
dq = dq.contiguous()
else:
if dq is None:
dq = torch.zeros_like(q, dtype=q.dtype)
else:
dq = dq.contiguous()

# NOTE: the kernel does in-place accumulation so dq has to be zeros. This avoids the case where we are passed an empty dq that is not all zeros
dq.zero_()
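# Hedged illustration (editorial, not part of this commit): because the kernel accumulates into
# dq, a caller-provided buffer must start at zero, e.g.
#
#   dq_buf = torch.empty_like(q)  # may contain arbitrary data
#   dq_buf.zero_()                # required before launching the backward kernel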
@@ -575,12 +646,16 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq,
dk = torch.zeros_like(k)
else:
dk = torch.empty_like(k)
else:
dk = dk.contiguous()

if dv is None:
if True:
dv = torch.zeros_like(v)
else:
dv = torch.empty_like(v)
else:
dv = dv.contiguous()

# assert contiguous
assert do.is_contiguous()
@@ -593,32 +668,41 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq,
assert dk.is_contiguous()
assert dv.is_contiguous()

batch_headsize = batch * heads_q
stride_dq_all = dq.numel()
stride_qz, stride_qh, stride_qm, stride_qk = q.stride(0), q.stride(1), q.stride(2), q.stride(3)
stride_kz, stride_kh, stride_kn, stride_kk = k.stride(0), k.stride(1), k.stride(2), k.stride(3)
stride_vz, stride_vh, stride_vn, stride_vk = v.stride(0), v.stride(1), v.stride(2), v.stride(3)
num_warps = 4 # NOTE: original is 8. changing it to 1 caused issues, be careful
num_stages = 1

if True:
delta = torch.zeros_like(softmax_lse)
else:
delta = torch.empty_like(softmax_lse)


if bwd_preprocessing_use_o:
_bwd_preprocess_use_o[(batch_headsize * num_blocks_m,)](
o,
do,
delta,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
N_CTX_Q=N_CTX_Q
)
if False:
_bwd_preprocess_use_o_old[(batch_headsize * num_blocks_m,)](
o,
do,
delta,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
N_CTX_Q=seqlen_q
)
else:
_bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
o,
do,
delta,
stride_oz, stride_oh, stride_om, stride_ok,
stride_oz, stride_oh, stride_om, stride_ok,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
N_CTX_Q=seqlen_q,
Z=batch,
H=nheads_q,
)
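# Hedged note on the launch grid (editorial, not part of this commit): the grid is now 2D, so
# tl.program_id(0) gives the BLOCK_M tile index (pid_m) and tl.program_id(1) the flattened
# (batch, head) index (pid_bh), which the kernel splits as
#
#   batch_idx, head_idx = pid_bh // H, pid_bh % H
#
# replacing the old 1D grid of size batch_headsize * num_blocks_m.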
else:
_bwd_preprocess_use_p[(num_blocks_m, batch_headsize)](
_bwd_preprocess_use_p[(num_blocks_m, batch_headsize)](
q,
k,
v,
Expand All @@ -639,19 +723,18 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq,
stride_vh,
stride_vn,
stride_vk,
Z=batch_q,
H=heads_q,
N_CTX_Q=N_CTX_Q,
N_CTX_K=N_CTX_K,
Z=batch,
H=nheads_q,
N_CTX_Q=seqlen_q,
N_CTX_K=seqlen_k,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
USE_EXP2=use_exp2,
)


if False:
if DEBUG:
print("_bwd_kernel inputs")
print("do:", do, do.shape)
print("q:", q, q.shape)
@@ -667,12 +750,10 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq,
print("stride_qz, stride_qh, stride_qm, stride_qk:", stride_qz, stride_qh, stride_qm, stride_qk)
print("stride_kz, stride_kh, stride_kn, stride_kk:", stride_kz, stride_kh, stride_kn, stride_kk)
print("stride_vz, stride_vh, stride_vn, stride_vk:", stride_vz, stride_vh, stride_vn, stride_vk)
print("batch_q:", batch_q)
print("heads_q:",heads_q)
print("N_CTX_Q:",N_CTX_Q)
print("N_CTX_K:",N_CTX_K)
print("batch_q * head_size_q * N_CTX_Q:",batch_q * head_size_q * N_CTX_Q)
print("num_blocks_n * batch_q * head_size_q * N_CTX_Q:",num_blocks_n * batch_q * head_size_q * N_CTX_Q)
print("batch_q:", batch)
print("heads_q:",nheads_q)
print("seqlen_q:",seqlen_q)
print("seqlen_k:",seqlen_k)
print("BLOCK_M:",BLOCK_M)
print("BLOCK_N:",BLOCK_M)
print("BLOCK_DMODEL:",BLOCK_DMODEL)
@@ -699,10 +780,10 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
batch_q,
heads_q,
N_CTX_Q,
N_CTX_K,
batch,
nheads_q,
seqlen_q,
seqlen_k,
num_blocks_m,
num_blocks_n,
BLOCK_M=BLOCK_M,
@@ -719,23 +800,53 @@ def attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq,
if len(dq.shape) == 5:
dq = dq.sum(dim=0)
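# Hedged note (editorial, not part of this commit): with sequence_parallel enabled, dq is
# allocated with an extra leading replica dimension, (replicas,) + q.shape, presumably so each
# column block can write its own partial dq; the partials are reduced here, i.e.
#
#   dq = dq.sum(dim=0)  # reduce the leading replica dimension back to q.shape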

# go back to original layout
if layout == "bshd":
print("Changing back to bshd!")
dq = dq.transpose(1, 2)
dk = dk.transpose(1, 2)
dv = dv.transpose(1, 2)
elif layout == "bhsd":
pass
else:
raise ValueError(f"Unknown layout {layout}")

return dq, dk, dv, delta, None, None


def attention_prefill_backward_triton_impl(do, q, k, v, o, softmax_lse, dq, dk, dv, sm_scale, head_size, alibi_slopes, causal, layout, use_exp2, bwd_preprocessing_use_o, use_new):
def attention_prefill_backward_triton_impl(
do,
q,
k,
v,
o,
softmax_lse,
dq,
dk,
dv,
sm_scale: float,
alibi_slopes,
causal,
layout: str,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q: int,
max_seqlen_k: int,
use_exp2: bool,
bwd_preprocessing_use_o: bool,
use_new,
):
if use_new:
return attention_prefill_backward_triton_new_impl(do, q, k, v, o, softmax_lse, dq, dk, dv, sm_scale, head_size, alibi_slopes, causal, layout, use_exp2, bwd_preprocessing_use_o)
return attention_prefill_backward_triton_new_impl(
do,
q,
k,
v,
o,
softmax_lse,
dq,
dk,
dv,
sm_scale,
alibi_slopes,
causal,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2,
bwd_preprocessing_use_o,
)
else:
# test pytorch impl
dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
@@ -757,4 +868,3 @@ def attention_prefill_backward_triton_impl(do, q, k, v, o, softmax_lse, dq, dk,
dv = dv_ref

return dq, dk, dv, delta_ref, None, None
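# Hedged usage sketch (editorial; argument values are illustrative and not taken from this
# commit): with the new signature, a bhsd, non-varlen call might look like
#
#   dq, dk, dv, delta, _, _ = attention_prefill_backward_triton_impl(
#       do, q, k, v, o, softmax_lse,
#       None, None, None,                 # let the impl allocate dq / dk / dv
#       sm_scale=q.shape[-1] ** -0.5,
#       alibi_slopes=None,
#       causal=False,
#       layout="bhsd",
#       cu_seqlens_q=None, cu_seqlens_k=None,
#       max_seqlen_q=q.shape[2], max_seqlen_k=k.shape[2],
#       use_exp2=False,
#       bwd_preprocessing_use_o=True,
#       use_new=True,
#   )
#
# head_size is no longer passed explicitly; it is recovered from the tensor shapes via
# get_shape_from_layout.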
