fused_linear_cross_entropy: Move float32 cast into kernel #238
Conversation
Thanks! This makes sense. I did try something similar before but saw divergence compared with casting on the torch side (not sure why, maybe I did it wrong). Also, the bfloat16 convergence test is not actually being run at the moment due to #176. After that fix is merged, we can run the convergence tests with bf16 to see if there is any gap.
I added a
Cool! I will take a deeper look today or tomorrow. This is exciting!
Can we merge this? In its current form, Liger Kernel is broken.
@hansonw can we resolve the conflict? ty
We can merge this once the conflict is resolved. Thanks!!
Summary
Another small optimization :) The `logits_chunk.float()` allocation may be surprisingly large; e.g., Cohere models have 256K vocabs, so each logit chunk in float32 could be something like 1024 * 256K * 4 = 1GB of VRAM (even more if the chunk size is larger).
I actually don't think any explicit casting is required within the Triton kernel, since the intermediate softmax variables like `m`, `d`, etc. are already float32 by default, so with type promotion the calculations should all be float32 regardless. However, I added explicit `.cast(tl.float32)` casts around all of the `X_ptr` loads to make this more obvious to the reader. In either case, the actual `liger_cross_entropy_kernel` runs so quickly that I don't think there's any performance difference; this change is purely to save the float32 allocation. (It might be more efficient without the explicit casts, but I was not able to measure anything: even with a 1K x 256K logit matrix the kernel kind of runs instantly lol.)
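For context on the `m` and `d` variables mentioned above: they are the running maximum and running denominator of the online softmax that the kernel accumulates in float32. A pure-Python sketch of that accumulation (illustrative only; the function name and the list-based loop are mine, not taken from the kernel):

```python
import math

def online_softmax_denominator(logits):
    """Single-pass computation of the running max (m) and the running
    sum of exp(x - m) (d), as in an online-softmax accumulation."""
    m = -math.inf  # running maximum so far
    d = 0.0        # running denominator, rescaled whenever m changes
    for x in logits:
        x = float(x)  # promote to full precision, analogous to .cast(tl.float32)
        m_new = max(m, x)
        # rescale the old denominator to the new max, then add this term
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return m, d
```

Because `m` and `d` live in full precision from the start, loading the logits in bf16 and promoting per element is enough; no float32 copy of the whole chunk needs to be materialized.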
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence
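The VRAM saving described in the summary can be sanity-checked with quick arithmetic (the 1024-row chunk and 256K vocab are the example values from the summary, not figures read out of the code):

```python
# Back-of-the-envelope size of the float32 logits chunk the PR avoids allocating.
chunk_rows = 1024
vocab_size = 256 * 1024   # 256K vocabulary, as in the Cohere example
bytes_per_float32 = 4

chunk_bytes = chunk_rows * vocab_size * bytes_per_float32
print(f"float32 logits chunk: {chunk_bytes / 2**30:.2f} GiB")  # prints 1.00 GiB
```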