Skip to content

Commit

Permalink
[WIP] Using the robust solver for pyMBAR - avoiding convergence Failu… (
Browse files Browse the repository at this point in the history
#735)

* [WIP] Using the robust solver for pyMBAR - avoiding convergence Failures.

## Description
In this PR, we propose switching the default solver protocol to "robust" when pyMBAR 4 or newer is installed, so that we get an improved convergence rate at the cost of slightly more run time, i.e. fewer convergence errors are thrown.

## Todos
- [ ] Implement feature / fix bug
- [ ] Add [tests](https://github.com/choderalab/openmmtools/tree/master/openmmtools/tests)
- [ ] Update [documentation](https://github.com/choderalab/openmmtools/tree/master/docs) as needed
- [ ] Update [changelog](https://github.com/choderalab/openmmtools/blob/master/docs/releasehistory.rst) to summarize changes in behavior, enhancements, and bugfixes implemented in this PR

## Status
- [ ] Ready to go

## Changelog message
```

```

* bump ci

* doing this a different way

* use Version to compare versions

* fix micromamba, see mamba-org/micromamba-releases#58

* fix version check

* forgot how we import pymbar in this package

* pymbar 3 stores version differently

* added test from pymbar issue 419

* re-run flaky tests

* check to see if we get an expected value

* Add some tol

* use pytest.approx to make @IAlibay happy ;)

* update the doc strings to explain where the file came from

---------

Co-authored-by: Mike Henry <[email protected]>
  • Loading branch information
RiesBen and mikemhenry authored Oct 3, 2024
1 parent c2a13c0 commit 9334bc9
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ jobs:
- name: Setup micromamba
uses: mamba-org/setup-micromamba@v1
with:
micromamba-version: '2.0.0-0'
environment-file: devtools/conda-envs/test_env.yaml
environment-name: openmmtools-test
create-args: >-
Expand Down
7 changes: 7 additions & 0 deletions openmmtools/multistate/multistateanalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import copy
import inspect
import logging
from packaging.version import Version
import re
from typing import Optional, NamedTuple, Union

Expand All @@ -37,6 +38,7 @@
import simtk.unit as units
from scipy.special import logsumexp

from openmmtools.multistate import pymbar
from openmmtools import multistate, utils, forces
from openmmtools.multistate.pymbar import (
statistical_inefficiency_multiple,
Expand Down Expand Up @@ -567,6 +569,11 @@ def __init__(self, reporter, name=None, reference_states=(0, -1),
self.reference_states = reference_states
self._user_extra_analysis_kwargs = analysis_kwargs # Store the user-specified (higher priority) keywords

# If we are using pymbar 4, change the default behavior to use the robust solver protocol if the user
# didn't set a kwarg to control the solver protocol
if Version(pymbar.__version__) >= Version("4") and "solver_protocol" not in self._user_extra_analysis_kwargs:
self._user_extra_analysis_kwargs["solver_protocol"] = "robust"

# Initialize cached values that are read or derived from the Reporter.
self._cache = {} # This cache should be always set with _update_cache().
self.clear()
Expand Down
3 changes: 2 additions & 1 deletion openmmtools/multistate/pymbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
subsample_correlated_data,
statistical_inefficiency
)
from pymbar import MBAR
from pymbar import MBAR, __version__
from pymbar.utils import ParameterError
except ImportError:
# pymbar < 4
Expand All @@ -22,6 +22,7 @@
)
from pymbar import MBAR
from pymbar.utils import ParameterError
from pymbar.version import short_version as __version__


def _pymbar_bar(
Expand Down
63 changes: 62 additions & 1 deletion openmmtools/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
import shutil
import sys
import tempfile
import time
from io import StringIO

import numpy as np
import yaml

import pytest
import requests

try:
import openmm
Expand Down Expand Up @@ -306,6 +308,7 @@ def run(self, include_unsampled_states=False):
# Clean up.
del simulation

@pytest.mark.flaky(reruns=3)
def test_with_unsampled_states(self):
"""Test multistate sampler on a harmonic oscillator with unsampled endstates"""
self.run(include_unsampled_states=True)
Expand Down Expand Up @@ -1861,7 +1864,7 @@ def test_analysis_opens_without_checkpoint(self):
del reporter
self.REPORTER(storage_path, checkpoint_storage=cp_file_mod, open_mode="r")

@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(sys.platform == "darwin", reason="seg faults on osx sometimes")
def test_storage_reporter_and_string(self):
"""Test that creating a MultiState by storage string and reporter is the same"""
thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(
Expand Down Expand Up @@ -2612,6 +2615,64 @@ def test_resume_velocities_from_legacy_storage(self):
state.velocities.value_in_unit_system(unit.md_unit_system) != 0
), "At least some velocity in sampler state from new checkpoint is expected to different from zero."

@pytest.fixture
def download_nc_file(tmpdir):
    """Download the ala-thr reporter file into ``tmpdir`` and yield its path.

    The download is attempted up to ``MAX_RETRIES`` times, sleeping
    ``RETRY_DELAY`` seconds between attempts. If every attempt raises a
    ``requests`` exception, the requesting test is failed via ``pytest.fail``.

    If the URL ever starts to 404, see
    https://github.com/choderalab/pymbar/issues/419#issuecomment-1718386779
    and https://github.com/choderalab/openmmtools/pull/735#issuecomment-2378070388

    NOTE(review): the URL ends in ``.zip`` but the response body is written
    unmodified to ``ala-thr.nc`` -- presumably the attachment was only renamed
    for upload; confirm against the linked comments.
    """
    FILE_URL = "https://github.com/user-attachments/files/17156868/ala-thr.zip"
    MAX_RETRIES = 3
    RETRY_DELAY = 2  # seconds to wait before the next attempt
    file_name = os.path.join(tmpdir, "ala-thr.nc")

    # ``retries`` counts the attempts made so far (1-based).
    for retries in range(1, MAX_RETRIES + 1):
        try:
            # The timeout keeps a stalled connection from hanging the test run.
            response = requests.get(FILE_URL, timeout=20)
            # Turn 4xx/5xx responses into an HTTPError so they are retried too.
            response.raise_for_status()
            with open(file_name, "wb") as f:
                f.write(response.content)
        except requests.exceptions.RequestException as e:
            # HTTPError is a subclass of RequestException, so this one clause
            # covers both connection problems and bad status codes.
            if retries >= MAX_RETRIES:
                pytest.fail(f"Failed to download file after {MAX_RETRIES} retries: {e}")
            print(f"Retrying download... ({retries}/{MAX_RETRIES})")
            time.sleep(RETRY_DELAY)
        else:
            # Download and write succeeded; stop retrying.
            break

    yield file_name


def test_pymbar_issue_419(download_nc_file):
    """
    Regression test: the free energy analysis of an ala-thr mutation nc file converges.

    With the pymbar 4 default solver (as of 2024-10-02) the analysis of this file
    fails to converge, while the pymbar 3 defaults do converge.
    Since PR #735 (https://github.com/choderalab/openmmtools/pull/735) the
    MultiStateSamplerAnalyzer uses the "robust" solver protocol under pymbar 4
    unless the user sets a solver protocol themselves.
    See https://github.com/choderalab/pymbar/issues/419#issuecomment-1718386779 for
    details on how the file was generated.
    """
    from openmmtools.multistate import MultiStateReporter, MultiStateSamplerAnalyzer

    # Open the downloaded storage file and analyze at most 1000 iterations.
    analyzer = MultiStateSamplerAnalyzer(
        MultiStateReporter(download_nc_file),
        max_n_iterations=1000,
    )
    free_energy, free_energy_error = analyzer.get_free_energy()

    # Free energy difference between the first and the last state.
    assert free_energy[0, -1] == pytest.approx(-52.00083148433459)
    # Estimated uncertainty of that difference.
    assert free_energy_error[0, -1] == pytest.approx(0.21365627649558516)


# ==============================================================================
# MAIN AND TESTS
Expand Down

0 comments on commit 9334bc9

Please sign in to comment.