Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Vanilla Bwd and Refactor #86

Draft
wants to merge 52 commits into
base: main_perf
Choose a base branch
from
Draft

Conversation

micmelesse
Copy link

@micmelesse micmelesse commented Oct 14, 2024

This PR

  • enables vanilla bwd
  • refactors our code into several files.
  • fixes a bug with softmax lse that was returned with forward
  • adds an alternate mode that uses exp instead of exp2. This was useful to debug issues with both forward and backward
  • creates interfaces for fa and pytorch that use implementation functions with explicit paramaters.
  • adds a pytorch ref implementations that mimics our triton kernels for testing.

Vanilla BWD

This is a combination of 79 commits.

save test_flash_attn_output

use impl functions

pass layout

add ref

move arround impls

fix stride issue

save oai kernel

add baseline impl

save bwd kernel working

remove old impl

remove block_ptrs from bwd

pass padded dmodel and apply masking. the old test cases work but cases with small d don't work

save

save

more prints

rename to M to L

save

add notes

add old_bwd back

fa failure fails in kernels too

isolate new bwd and keep old bwd in place

clean up

softmax_lse doesnot match refernce

LOG flag

softmax_lse with LN2

move qk_scale to loop

pass ln2 to fwd

just print kernel input

test softmax output from forward

test exp_scores_triton

save all the ref

create ref USE_EXP2 path

return scores

mask scores when returning them. Basic impl test passes

scores and output match

show max_diff

return score needs to be adjusted as we find new maxes

all good outputs. old style RCP2 example

prep bwd_impl test

save

try openai

save

fix softmax_lse bug

test_op_bwd_impl starting to work!

new kernel. exp2 works but exp is faliing

fix bwd exp2

add m and n masks. small cases still don't work

match old and new kernel prints

compare old and new

print inputs

save

old kernel match on dv

dq works

compare to pytorch including softmax in forward

fix bwd impl bug

small sizes in bwd impl work

old bwd test pass. Moving on to kernel tests

dq, dk and dv are filled in place if given. Need to match cast to match fa

fix non bug

fix dv mismatch. use_exp2 was set to true in fwd

fix case up 128

refactor and clean up a bit more

issue is that dq and dk are not zeros

dq must be zeroed out

ignore segfaults

fa ref and my ref match!

all tests run

use tolerance 1e-3

we need to figure out preprocessing

save

clean up

save

test delta diff

move old impl out

new preprocess function

preprocessing_use_o flag

working _bwd_preprocess_use_p

basic cases pass

all green

fwd exp2 usage is done right before exp
@micmelesse
Copy link
Author

micmelesse commented Oct 15, 2024

The kernel tests pass on MI300 but seems the ci MI200 have issues.
image

Use Strides

This is a combination of 11 commits.

use strides in bwd

add layout test in forward

fix shape layout function

smaller tests

save

fix varlen error

no headsize passed to bwd

deal with varlen layout

save

save

save

save
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant