Using cumsum instead of a for loop #18

PeaBrane opened this issue Feb 9, 2024 · 7 comments

@PeaBrane

PeaBrane commented Feb 9, 2024

There is a way to perform the selective scan with two cumulative sums (torch.cumsum), which is effectively a parallel scan but natively supported by PyTorch.

I made a minimal commit in my fork: PeaBrane@2908f50. I tested correctness and functionality, and observed an inference speedup of ~14x on an A30, though I am not sure how close that is to the original implementation with parallel scan. More details are here.

If interested, it would be nice if someone could review this change and discuss whether it could be merged here, albeit the explicitness of the code may suffer (as I understand the repo is meant to be pedagogical).
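To make the "two cumulative sums" idea concrete, here is a minimal scalar sketch in plain Python. It is not the code from the linked commit: the function names `scan_loop` and `scan_cumsum` are hypothetical, and it assumes a recurrence of the form h_t = a_t * h_{t-1} + b_t with every a_t > 0. The first cumulative sum accumulates log a_t to get the running product P_t; the second accumulates b_s / P_s; their product recovers h_t.

```python
import math
from itertools import accumulate


def scan_loop(a, b):
    """Reference for-loop: h_t = a_t * h_{t-1} + b_t, starting from h = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def scan_cumsum(a, b):
    """Same recurrence via two cumulative sums (assumes every a_t > 0)."""
    # First cumulative sum: log P_t = sum of log a_r for r <= t.
    log_p = list(accumulate(math.log(a_t) for a_t in a))
    # Second cumulative sum: S_t = sum of b_s / P_s for s <= t.
    s = list(accumulate(b_t * math.exp(-lp) for b_t, lp in zip(b, log_p)))
    # h_t = P_t * S_t.
    return [math.exp(lp) * s_t for lp, s_t in zip(log_p, s)]


a = [0.9, 0.5, 0.8, 0.7]
b = [1.0, -2.0, 0.5, 3.0]
loop_out = scan_loop(a, b)
cumsum_out = scan_cumsum(a, b)
```

Both functions should agree up to floating-point error on short sequences. Note that dividing by P_s can become numerically fragile when the running product gets very small, so this sketch is only meant to show the algebra.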

@johnma2006
Owner

Thank you, and so sorry for the late reply! I’ve been a bit busy recently, but let me figure out the best way to incorporate these ideas in a bit. Thank you!

@huiserwang

huiserwang commented Feb 24, 2024

I have tested the original mamba implementation. It's so fast! I used length=3136, bs=128, and channel=192 for the input x, and d_state=16 for B and C. The original impl achieves an inference speedup of ~48x over the cumsum impl.

@PeaBrane
Author

> I have tested the original mamba implementation. It's so fast! I used length=3136, bs=128, and channel=192 for the input x, and d_state=16 for B and C. The original impl achieves an inference speedup of ~48x over the cumsum impl.

Are you testing the original impl in training mode or inference mode? The inference (recurrent or online) mode is not comparable to the forward pass for training, because the former is a single recurrent step and the latter takes in the full sequence. Either way, neither mamba-minimal nor mamba-tiny is optimized for training or inference; they are purely pedagogical.

@wredan

wredan commented Feb 27, 2024

I would also like to point out that the cumsum implementation is the better way to go if you need to convert mamba-minimal or mamba-tiny to ONNX. The static PyTorch converter says:

> It does not record any control-flow, like if-statements or loops

so with a for loop you lose the dynamic sequence-length input.

The insane speed is tied to the hardware-aware optimization the authors made in the official mamba model, but the use of Triton and the close-to-the-GPU optimization prevents me from converting the original model to ONNX with the official PyTorch exporter.

Just leaving this here for anyone who needs ONNX model conversion in the future. Also, thank you for mamba-minimal and mamba-tiny; they are great for understanding how mamba works.

@DustinEwan

I tested out this cumsum approach and found that it doesn't actually produce the same outputs as the standard for-loop version.

Everything else being equal, the current function is slow, but it ultimately produces a model with sensible output.

@PeaBrane's cumsum version is multiple orders of magnitude faster, but the model ends up producing mostly nonsensical output.

@PeaBrane
Author

PeaBrane commented Mar 8, 2024

> I tested out this cumsum approach and found that it doesn't actually produce the same outputs as the standard for-loop version.
>
> Everything else being equal, the current function is slow, but it ultimately produces a model with sensible output.
>
> @PeaBrane's cumsum version is multiple orders of magnitude faster, but the model ends up producing mostly nonsensical output.

By "nonsensical", do you mean the outputs contain nan or inf, or that they are semantically nonsensical? Note that the sentence generation script is stochastic, so the generated output will be different every time. That being said, I did encounter some stability issues when running the logcumsumexp scan on the GPU, where it would lead to nan or inf values (but no problem on the CPU).
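As a rough illustration of how inf values can appear in a cumulative-sum formulation at all, the sketch below estimates the largest argument passed to exp() when a naive scan divides by a running product of decay factors. This is an assumption-laden sketch, not a diagnosis of the behavior reported above: it does not model the logcumsumexp variant or explain a GPU/CPU difference, and the helper name `peak_exponent` is hypothetical.

```python
import math
import sys

# Largest x for which exp(x) is still finite in each precision.
FLOAT32_MAX_LOG = math.log(3.4028235e38)        # float32 maximum
FLOAT64_MAX_LOG = math.log(sys.float_info.max)  # float64 maximum


def peak_exponent(log_a):
    """Largest value of -cumsum(log_a) along the sequence.

    In a naive scan that forms 1 / P_t = exp(-cumsum(log_a)), this is the
    biggest argument exp() ever receives.
    """
    total, peak = 0.0, 0.0
    for v in log_a:
        total -= v
        peak = max(peak, total)
    return peak


# Hypothetical example: a constant decay of 0.9 over 1000 steps.
log_a = [math.log(0.9)] * 1000
peak = peak_exponent(log_a)
overflows_float32 = peak > FLOAT32_MAX_LOG
overflows_float64 = peak > FLOAT64_MAX_LOG
```

With these made-up numbers the peak exponent is large enough that exp() would overflow in float32 but not in float64, which is one reason the same formula can behave differently depending on precision and sequence length.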

@dftidft

dftidft commented Jul 9, 2024

There is a problem with this code:
dA_cumsum = F.pad(dA[:, 1:], (0, 0, 0, 0, 0, 1)).flip(1).cumsum(1).exp().flip(1)
dA[:, 1:] uses the t-th value in the sequence as input when predicting the t-th value.
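To make the indexing in that expression easier to inspect, here is a scalar sketch in plain Python of what shift, pad, flip, cumsum, exp, flip computes along the sequence axis. It assumes `dA` holds log-decay values with the sequence on dimension 1 and that the pad appends a single zero at the end of that dimension; the function name `suffix_decay` is hypothetical. It only shows which entries each position depends on and does not settle whether that dependence is a problem for the model.

```python
import math
from itertools import accumulate


def suffix_decay(dA):
    """Scalar analogue of pad(dA[1:]) -> flip -> cumsum -> exp -> flip.

    Entry t of the result equals exp(dA[t+1] + ... + dA[L-1]).
    """
    shifted = dA[1:] + [0.0]                          # drop index 0, pad a zero at the end
    rev_cumsum = list(accumulate(reversed(shifted)))  # cumulative sum over the flipped sequence
    return [math.exp(v) for v in reversed(rev_cumsum)]  # exp, then flip back


dA = [-0.1, -0.2, -0.3]
decay = suffix_decay(dA)
# Ratio of entries s=0 and t=2 is exp(dA[1] + dA[2]).
ratio = decay[0] / decay[2]
```

Entry t depends only on dA[t+1:], and the last entry is always 1. The ratio of entries s and t (with s < t) equals exp(dA[s+1] + ... + dA[t]), so the decay applied between step s and step t does include dA[t].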
