Using cumsum instead of a for loop #18
Comments
Thank you, and so sorry for the late reply! I've been a bit busy recently, but let me figure out the best way to incorporate these ideas in a bit. Thank you!
I have tested the original Mamba implementation. It's very fast! I used length=3136, bs=128, and channel=192 for the input x, with d_state=16 for B and C. The original impl achieves an inference speedup of ~48x over the cumsum impl.
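For anyone reproducing such measurements, here is a rough, hedged sketch of how GPU timings can be taken (warm-up plus synchronization); the workload below is only a stand-in for the two scan implementations being compared:

```python
import time
import torch

def bench(fn, *args, iters=20):
    """Crude CUDA timing: warm up, then synchronize around the timed loop."""
    for _ in range(3):
        fn(*args)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

# Shapes from the comment above: length=3136, bs=128, channel=192 (d_state=16 for B, C).
x = torch.randn(128, 3136, 192, device="cuda")
print(bench(torch.cumsum, x, 1))  # placeholder workload; swap in the actual scans here
```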
Are you testing the original impl in training mode or inference mode? The inference (recurrent or online) mode is not comparable to the forward pass for training, because the former is a recurrent step and the latter takes in the full sequence. Either way, neither …
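To make the distinction concrete, here is a toy sketch (made-up shapes, not the official API) of why a single recurrent step is not comparable to a full-sequence forward pass:

```python
import torch

b, l, d_in, n = 2, 8, 4, 3                  # toy sizes, not the real config
A_bar = torch.rand(b, l, d_in, n)           # discretized state transition per step
Bu = torch.randn(b, l, d_in, n)             # discretized input term per step

# Training-style forward pass: consume the whole length-l sequence
# (whether via a for loop, cumsum, or a parallel scan) -- cost scales with l.
h = torch.zeros(b, d_in, n)
for t in range(l):
    h = A_bar[:, t] * h + Bu[:, t]

# Recurrent / online inference: one constant-time update per new token,
# reusing the cached state h -- so its per-step latency is not directly
# comparable to the full-sequence forward pass above.
h_next = A_bar[:, -1] * h + Bu[:, -1]
```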
I also would like to point out that with a for loop you lose dynamic sequence length as an input. The impressive speed is tied to the hardware-aware optimization the author made in the official Mamba model, but the use of Triton and the tight GPU-specific optimization prevents me from converting the original model to ONNX with the official PyTorch exporter. Just leaving it here for anyone who needs ONNX model conversion in the future, and thank you guys for …
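For the ONNX angle specifically, a loop-free (cumsum-style) scan can in principle be exported with a dynamic sequence axis. A minimal, self-contained sketch using a toy stand-in module (not the real model, and not necessarily the linked fork's code) might look like this:

```python
import torch
import torch.nn as nn

class CumsumScanToy(nn.Module):
    """Toy stand-in for a cumsum-based SSM block (not the real model)."""
    def forward(self, x):
        # Any loop-free op over the sequence axis traces with a dynamic length.
        return torch.cumsum(x, dim=1)

model = CumsumScanToy().eval()
dummy = torch.randn(1, 64, 192)  # (batch, seq_len, d_model)

torch.onnx.export(
    model, dummy, "mamba_cumsum_toy.onnx",
    input_names=["x"], output_names=["y"],
    dynamic_axes={"x": {0: "batch", 1: "seq_len"},
                  "y": {0: "batch", 1: "seq_len"}},
    opset_version=17,
)
```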
I tested out this cumsum approach and found that it doesn't actually produce the same outputs as the standard for-loop one. Everything else being equal, the current function is slow but ultimately produces a model with sensible output. Using @PeaBrane's cumsum version is multiple orders of magnitude faster, but the model ends up producing mostly nonsensical output.
By "nonsensical" do you mean encountering nan or inf, or semantically the outputs are non-sensical. Note the sentence generation script used is stochastic, so everytime the generated outputs is going to be different. That being said, I did encounter some stablity issues when running the |
There is a problem with this code:
There is a way to perform the selective scan with two cumulative sums, i.e. torch.cumsum, which is effectively like a parallel scan but supported natively by PyTorch. I made a minimal commit in my fork here: PeaBrane@2908f50. The correctness and functionality are tested, and I observed an inference speedup of ~14x on an A30, though I'm not sure how close it is to the original impl with the parallel scan. More details are here.
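For readers skimming the thread, here is a minimal sketch of what a two-cumsum selective scan can look like, assuming argument shapes like the for-loop reference (u, delta: (b, l, d_in); A: (d_in, n); B, C: (b, l, n); D: (d_in,)). It illustrates the idea only and is not necessarily line-for-line what the linked commit does:

```python
import torch

def selective_scan_cumsum(u, delta, A, B, C, D):
    """Sketch of a selective scan built from two torch.cumsum calls.

    Assumed shapes (mirroring a for-loop reference implementation):
      u, delta: (b, l, d_in);  A: (d_in, n);  B, C: (b, l, n);  D: (d_in,)
    """
    # Discretization: log(A_bar) = delta * A, and B_bar * u is taken as delta * B * u.
    dA = torch.einsum('bld,dn->bldn', delta, A)            # (b, l, d_in, n)
    dBu = torch.einsum('bld,bln,bld->bldn', delta, B, u)   # (b, l, d_in, n)

    # Cumsum 1: running sum of log(A_bar) along the sequence axis,
    # S_t = sum_{r<=t} delta_r * A, so prod_{r=s+1..t} A_bar_r = exp(S_t - S_s).
    S = torch.cumsum(dA, dim=1)

    # Cumsum 2: h_t = exp(S_t) * sum_{s<=t} exp(-S_s) * dBu_s.
    # Caveat: exp(-S_s) grows quickly when delta * A is negative, which is a
    # plausible source of the float32 stability issues discussed above.
    h = torch.exp(S) * torch.cumsum(torch.exp(-S) * dBu, dim=1)

    y = torch.einsum('bldn,bln->bld', h, C)
    return y + u * D
```

With random inputs of matching shapes, the output can be checked against the for-loop version with torch.allclose; for this naive formulation, float32 agreement tends to degrade as the sequence gets longer because of the exp(-S) rescaling term.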
If interested, it would be nice if someone could review this change and discuss whether it could be merged here, albeit the explicitness of the code may suffer (as I understand the repo is meant to be pedagogical).