Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Oct 16, 2024
1 parent 0bd8120 commit 8d1c2f7
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,28 +677,30 @@ def attention_prefill_backward_triton_new_impl(
delta = torch.empty_like(softmax_lse)

if bwd_preprocessing_use_o:
_bwd_preprocess_use_o_old[(batch_headsize * num_blocks_m,)](
o,
do,
delta,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
N_CTX_Q=seqlen_q
)
# _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
# o,
# do,
# delta,
# stride_oz, stride_oh, stride_om, stride_ok,
# stride_oz, stride_oh, stride_om, stride_ok,
# BLOCK_M=BLOCK_M,
# BLOCK_DMODEL=BLOCK_DMODEL,
# ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
# N_CTX_Q=seqlen_q,
# Z=batch,
# H=nheads_q,
# )
if False:
_bwd_preprocess_use_o_old[(batch_headsize * num_blocks_m,)](
o,
do,
delta,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
N_CTX_Q=seqlen_q
)
else:
_bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
o,
do,
delta,
stride_oz, stride_oh, stride_om, stride_ok,
stride_oz, stride_oh, stride_om, stride_ok,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
N_CTX_Q=seqlen_q,
Z=batch,
H=nheads_q,
)
else:
_bwd_preprocess_use_p[(num_blocks_m, batch_headsize)](
q,
Expand Down

0 comments on commit 8d1c2f7

Please sign in to comment.