Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use MBAR bootstrap error #1077

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open

Use MBAR bootstrap error #1077

wants to merge 6 commits into from

Conversation

jthorton
Copy link
Collaborator

@jthorton jthorton commented Jan 14, 2025

Fixes #1012 by using the bootstrap error from pymbar3/4.

Would this be a good time to switch to only supporting pymbar4 so we only have to maintain a single interface for MBAR?

Note:

  • the full pymbar4 package brings in JAX
  • I found that 1000 iterations of bootstrapping only takes around 1 min for the default protocol (using jax)
  • For the extended charge changing protocol this can take up to 15 mins (using jax)
  • The variability in the dDGs between test runs was larger which meant I had to relax the relative tolerance on the tests

Checklist

  • Added a news entry

Developers certificate of origin

@jthorton jthorton requested review from IAlibay and atravitz January 14, 2025 17:21
Copy link

codecov bot commented Jan 14, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 91.65%. Comparing base (915d110) to head (9b8d3ad).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1077      +/-   ##
==========================================
- Coverage   94.46%   91.65%   -2.82%     
==========================================
  Files         135      135              
  Lines       10090    10083       -7     
==========================================
- Hits         9532     9242     -290     
- Misses        558      841     +283     
Flag Coverage Δ
fast-tests 91.65% <100.00%> (?)
slow-tests ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@IAlibay
Copy link
Member

IAlibay commented Jan 15, 2025

Would this be a good time to switch to only supporting pymbar4 so we only have to maintain a single interface for MBAR?

Yes I think that would be a good idea - if we think it's stable (we might have to benchmark a bit), we should make the jump.
If we go by spec0 rules pymbar 3 is > 2 years old.

PyMBAR 3 also has all kinds of stability issues we should try to avoid.

the full pymbar4 package brings in JAX

:/ how big of a dependency is JAX? It might be that we don't really have an option here. I know you can use pymbar 4 without JAX (that's how it gets deployed on PyPi). cc @atravitz

For the extended charge changing protocol this can take up to 15 mins (using jax)

Oof that's quite long. I guess as long as we're only doing that once in a multi-hour simulation it doesn't matter too much.

@jthorton
Copy link
Collaborator Author

JAX is around 60MB, but we can use pymbar-core which is the non-JAX version that should be a bit slower, how much slower, I am not sure but compared to a multi-hour simulation it should still be negligible!

  + jax                     0.4.35  pyhd8ed1ab_1         conda-forge/noarch        1MB
  + jaxlib                  0.4.35  cpu_py312hadfe8e1_0  conda-forge/osx-64       56MB

On the other hand, adding JAX is not too noticeable compared to the cudatoolkit?

Copy link
Member

@IAlibay IAlibay left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couple of todos:

  • Could you add a news entry please?
  • Could you make the necessary changes to switch to pymbar 4 please?

np.array([0.07471 , 0.052914, 0.041508, 0.036613, 0.032827, 0.030489,
0.028154, 0.026529, 0.025284, 0.023968]),
rtol=1e-04,
np.array([0.077645, 0.054695, 0.044680, 0.03947, 0.034822,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for updating these - the new error values are expected to be different, so we should update things where we can.

rtol=1e-04,
np.array([0.077645, 0.054695, 0.044680, 0.03947, 0.034822,
0.033443, 0.030793, 0.028777, 0.026683, 0.026199]),
rtol=1e-01,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little bit loose as a tolerance, but I guess it's fine given the bootstraps are stochastic.

except AttributeError:
r = mbar.compute_free_energy_differences()
# pymbar 4
mbar = MBAR(u_ln, N_l, solver_protocol="robust", n_bootstraps=1000, bootstrap_solver_protocol="robust")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is most of the cost in the forward & reverse analysis?

Copy link
Collaborator Author

@jthorton jthorton Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah running the bootstrapping on repeat is expensive! One thought on the forward and backward estimates should we be subsampling using g_t calculated for this subset of data? In the industry benchmarking I calculated it 3 ways no subsampling, subsample based on the % of data and subsample using the g_t calculated for the full set of data. https://github.com/OpenFreeEnergy/IndustryBenchmarks2024/blob/fb60d7a971cb5d04787d796b6adcf257d905786a/industry_benchmarks/analysis/1_download_and_extract_data.py#L464-L552

@IAlibay
Copy link
Member

IAlibay commented Jan 15, 2025

On the other hand, adding JAX is not too noticeable compared to the cudatoolkit?

Yeah - I also suspect we're picking up a ton of dependencies elsewhere.

Long term maybe we should look into an openfe-base version that has the very minimal set of dependencies for everything.

I'll let @atravitz weigh in, but generally I'm ok / would very much like it if we pushed for pymbar4 w/ JAX.

@IAlibay
Copy link
Member

IAlibay commented Jan 15, 2025

Completely forgot to ask @jthorton - could you have a look through our docs and see if there's anywhere we can make it clear that this is now the bootstrap error? I know some folks got confused by it all.

Copy link

No API break detected ✅

@jthorton
Copy link
Collaborator Author

Currently blocked by perses=0.10.3 which pins to pymbar3.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Switch to bootstrapping for MBAR errors.
2 participants