Sharding the computation between GPUs #83
Comments
Before I start working on this, I would like to have a discussion:
I'd be interested in your comments, @Matematija.
Sorry for not replying earlier. I would say that #1 is definitely the way to go. Backing up a little bit, my intuition is that breaking up the integration grid into parts that get committed to different devices is a good starting point. Many large tensors in DFT inherit their "largeness" from the grid.
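A minimal sketch of what committing grid-sized arrays to several devices could look like with JAX's sharding API; the array names, the grid size, and the toy LDA-like integrand below are placeholders for illustration, not Grad DFT's actual data structures.

```python
# Hypothetical sketch: shard grid-sized arrays across all local devices along
# the grid axis, so per-grid-point work and the final reduction are split.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

n_grid = 4096 * jax.device_count()             # placeholder grid size, divisible by the device count
weights = jnp.asarray(np.random.rand(n_grid))  # quadrature weights (placeholder values)
rho = jnp.asarray(np.random.rand(n_grid))      # density on the grid (placeholder values)

# One mesh axis named "grid" spanning all local devices.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("grid",))
shard_grid = NamedSharding(mesh, P("grid"))

# Commit the grid-sized arrays to the devices, split along their leading axis.
weights = jax.device_put(weights, shard_grid)
rho = jax.device_put(rho, shard_grid)

@jax.jit
def lda_x_energy(rho, weights):
    # Toy LDA exchange integral; jit propagates the input sharding, so the
    # elementwise work runs per-device and the sum becomes a cross-device reduction.
    c_x = -(3.0 / 4.0) * (3.0 / jnp.pi) ** (1.0 / 3.0)
    return jnp.sum(weights * c_x * rho ** (4.0 / 3.0))

print(lda_x_energy(rho, weights))
```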
I think I have to disagree with @Matematija on this one. While parallelizing single DFT executions is nice, a lot of performance can be extracted by simply sending all data to a single (good) GPU on an HPC cluster. On Perlmutter, for example, I can perform differentiable SCF calculations in a matter of seconds for reasonable solid materials with reasonable basis sets. Given that the use case for Grad DFT will most of the time be learning from small systems (since this is the domain in which accurate wavefunction calculations are possible for the training data), I think it is best to put some effort into parallelizing batched loss function computation with sharding. This should, in principle, be pretty easy as our data is stored in
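A rough sketch of batch-level data parallelism for a training loss, where each device holds a slice of the batch and the parameters are replicated. Note that `predict_energy`, the padded feature arrays, and the parameter shape below are stand-ins, not Grad DFT's actual API or data layout.

```python
# Hypothetical sketch: shard the batch of training systems across devices,
# replicate the parameters, and let jit parallelize the batched loss.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

n_batch = 2 * jax.device_count()  # toy batch of systems
n_feat = 16                       # systems padded to a common feature size
features = jnp.asarray(np.random.randn(n_batch, n_feat))
targets = jnp.asarray(np.random.randn(n_batch))

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("batch",))
batch_sharding = NamedSharding(mesh, P("batch"))   # split along the batch axis
replicated = NamedSharding(mesh, P())              # parameters copied to every device

features = jax.device_put(features, batch_sharding)
targets = jax.device_put(targets, batch_sharding)
params = jax.device_put(jnp.ones((n_feat,)), replicated)

def predict_energy(params, x):
    # Stand-in for a single differentiable SCF / energy evaluation.
    return jnp.dot(params, x)

@jax.jit
def batched_loss(params, features, targets):
    preds = jax.vmap(predict_energy, in_axes=(None, 0))(params, features)
    return jnp.mean((preds - targets) ** 2)

loss, grads = jax.value_and_grad(batched_loss)(params, features, targets)
print(loss, grads.shape)
```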
If the target use case is learning smaller systems, then I agree that the problem I was describing doesn't exist by definition. Simple batch-level parallelism should do the trick. I was imagining large molecules with tens of millions of grid points at inference time.
I guess that in the (possibly far) future, we may wish to add such parallelization. I just think that for now batch parallelism is a good target. I'm experimenting with this now...
Add sharding following https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
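For reference, the core pattern from that notebook is roughly the following; the mesh axis name and the array shapes are just placeholders.

```python
# Minimal sketch of the pattern from the linked notebook: build a device mesh,
# commit an array to it with a NamedSharding, and let jit parallelize over it.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

n_dev = jax.device_count()
mesh = Mesh(mesh_utils.create_device_mesh((n_dev,)), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

x = jax.device_put(jnp.arange(4.0 * n_dev).reshape(n_dev, 4), sharding)
jax.debug.visualize_array_sharding(x)  # check how the rows are laid out across devices

# jit-compiled functions operate on the sharded array in parallel; the output
# inherits a sharding inferred from the inputs.
y = jax.jit(jnp.sin)(x)
jax.debug.visualize_array_sharding(y)
```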