
Derive associative scan algorithm for factorization #28

Open
dfm opened this issue Jan 28, 2021 · 5 comments
Labels
enhancement New feature or request

Comments

@dfm
Member

dfm commented Jan 28, 2021

I've derived the associative scan algorithms for matrix multiplication and solves, but I haven't been able to work out the factorization algorithm yet. The ops I've derived so far don't seem to have numerical issues, but I haven't tested them extensively. This would be interesting because it would enable a parallel implementation on a GPU.
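For context, the parallel-prefix idea behind these algorithms can be sketched in plain NumPy. This is a hypothetical illustration of the scan pattern (the helper `associative_scan` is written here from scratch), not the celerite2 derivation itself:

```python
import numpy as np
from functools import reduce


def associative_scan(combine, elems):
    """Inclusive prefix scan using an associative combine.

    On parallel hardware the pairwise combines at each level run
    concurrently, giving O(log n) depth instead of O(n).
    """
    n = len(elems)
    if n < 2:
        return list(elems)
    # Combine adjacent pairs, then scan the half-length sequence.
    reduced = [combine(elems[2 * i], elems[2 * i + 1]) for i in range(n // 2)]
    scanned = associative_scan(combine, reduced)
    # Interleave the recursive result with the remaining elements.
    out = [elems[0]]
    for i in range(1, n):
        if i % 2 == 1:
            out.append(scanned[i // 2])
        else:
            out.append(combine(scanned[i // 2 - 1], elems[i]))
    return out


# Matrix multiplication is associative, so cumulative matrix products
# fit this pattern directly.
rng = np.random.default_rng(0)
mats = [rng.normal(size=(2, 2)) for _ in range(7)]
prefix = associative_scan(lambda a, b: a @ b, mats)
expected = reduce(lambda a, b: a @ b, mats)
assert np.allclose(prefix[-1], expected)
```

In JAX, `jax.lax.associative_scan` provides this primitive directly, which is what makes a scan formulation of the factorization attractive for GPUs.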

@dfm dfm added the enhancement New feature or request label Jan 28, 2021
@bmorris3

bmorris3 commented Jul 9, 2021

Hi @dfm, sorry to be plaguing you 😅. I'm working on a JAX project with GPU acceleration and I'd like to use celerite2. If I use it out of the box, I get an error that says:

NotImplementedError: XLA translation rule for celerite2_factor on platform 'gpu' not found

which brought me here. Is this still on the to-do list?

@dfm
Member Author

dfm commented Jul 9, 2021

There is no GPU support planned for celerite2. It's possible to parallelize some of the algorithms, but in all the tests I've done it's slower than the CPU version, and it scales badly with the number of kernel terms J (O(J^3) instead of O(J^2)).

@bmorris3

bmorris3 commented Jul 9, 2021

Thanks for the quick response! If you have any pointers on alternatives I'd be grateful.

@dfm
Member Author

dfm commented Jul 9, 2021

I don't know of any good JAX libraries for GPs, but it's not too hard to implement the math yourself to try it out. If the GP is your bottleneck, I think it's unlikely that you'll get any benefit from a GPU. But if your computation is dominated by other parts of the model that do benefit from GPU acceleration, and you don't have too many data points, then it might be worth it. Here's an example implementation of naive GP computations using JAX + GPU acceleration that could get you started: https://github.com/dfm/tinygp/blob/main/src/tinygp/gp.py

@bmorris3

bmorris3 commented Jul 9, 2021

Thanks so much, as always!
