Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] Storage manager update #679

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 39 additions & 26 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
settings, ChemicalSystem, LigandAtomMapping, Component, ComponentMapping,
SmallMoleculeComponent, ProteinComponent, SolventComponent,
)
from gufe.storage import stagingregistry

from .equil_rfe_settings import (
RelativeHybridTopologyProtocolSettings, SystemSettings,
Expand Down Expand Up @@ -578,8 +579,10 @@ def __init__(self, *,
)

def run(self, *, dry=False, verbose=True,
scratch_basepath=None,
shared_basepath=None) -> dict[str, Any]:
scratch_basepath: pathlib.Path,
shared_basepath: stagingregistry.StagingPath,
permanent_basepath: stagingregistry.StagingPath,
) -> dict[str, Any]:
"""Run the relative free energy calculation.

Parameters
Expand All @@ -591,10 +594,12 @@ def run(self, *, dry=False, verbose=True,
verbose : bool
Verbose output of the simulation progress. Output is provided via
INFO level logging.
scratch_basepath: Pathlike, optional
Where to store temporary files, defaults to current working directory
shared_basepath : Pathlike, optional
Where to run the calculation, defaults to current working directory
scratch_basepath: pathlib.Path
Where to store temporary files
shared_basepath : StagingPath
Where to run the calculation
permanent_basepath : StagingPath
Where to store files that must persist beyond the DAG

Returns
-------
Expand All @@ -609,11 +614,6 @@ def run(self, *, dry=False, verbose=True,
"""
if verbose:
self.logger.info("Preparing the hybrid topology simulation")
if scratch_basepath is None:
scratch_basepath = pathlib.Path('.')
if shared_basepath is None:
# use cwd
shared_basepath = pathlib.Path('.')

# 0. General setup and settings dependency resolution step

Expand Down Expand Up @@ -664,11 +664,13 @@ def run(self, *, dry=False, verbose=True,
else:
ffcache = None

ffcache.register()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Problem 1 (easy): this line and the ffcache.as_path() below will fail if ffcache is None.

Problem 2 (more complicated): How is ffcache supposed to work? Isn't that supposed to be something that we can pull down from some previous unit? shared_basepath is supposed to be for a specific execution of a given unit. That directory might not exist until we're about to start executing this unit. How does a file get in there? How are we currently using that? [EDIT: I'm having a call with @IAlibay tomorrow to clarify what is going on with this]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reading through, I think the only thing that would be cached is the LJ parameters for the small molecules. We can probably remove this?


system_generator = system_creation.get_system_generator(
forcefield_settings=forcefield_settings,
thermo_settings=thermo_settings,
system_settings=system_settings,
cache=ffcache,
cache=ffcache.as_path(),
has_solvent=solvent_comp is not None,
)

Expand Down Expand Up @@ -812,10 +814,18 @@ def run(self, *, dry=False, verbose=True,
)

# a. Create the multistate reporter
nc = shared_basepath / sim_settings.output_filename
# TODO: Logic about keeping/not .nc files goes here
nc = (shared_basepath / sim_settings.output_filename)
checkpoint = (shared_basepath / sim_settings.checkpoint_storage)
real_time_analysis = (shared_basepath / "real_time_analysis.yaml")
# have to flag these files as being created so that they get brought back
nc.register()
checkpoint.register()
real_time_analysis.register()

chk = sim_settings.checkpoint_storage
reporter = multistate.MultiStateReporter(
storage=nc,
storage=str(nc.as_path()),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're using this str(file.as_path()) pattern a lot. When you use .as_path(), you're explicitly bypassing the automatic registration mechanism, meaning that you must explicitly register the object.

I'm not sure if this is implemented, but you should be able to instead just use str(file), which should also register the file. Equivalently, str(pathlib.Path(file)) will definitely work, which means you don't need to manually use .register().

You should only need to use .register() on paths that are created outside your control (like the checkpoint and RTA files, which are paths generated inside OpenMMTools and not exposed to us).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, in one case MultiStateReporter wants a string (not a path), and in the other the path is going into a subprocess call, so again it's bypassing file-like objects and has to be a string.

analysis_particle_indices=selection_indices,
checkpoint_interval=sim_settings.checkpoint_interval.m,
checkpoint_storage=chk,
Expand Down Expand Up @@ -947,13 +957,12 @@ def run(self, *, dry=False, verbose=True,
sampling_method=sampler_settings.sampler_method.lower(),
result_units=unit.kilocalorie_per_mole,
)
analyzer.plot(filepath=shared_basepath, filename_prefix="")
analyzer.plot(filepath=permanent_basepath, filename_prefix="")
analyzer.close()

else:
# clean up the reporter file
fns = [shared_basepath / sim_settings.output_filename,
shared_basepath / sim_settings.checkpoint_storage]
fns = [nc.as_path(), checkpoint.as_path()]
for fn in fns:
os.remove(fn)
finally:
Expand Down Expand Up @@ -981,35 +990,38 @@ def run(self, *, dry=False, verbose=True,
if not dry: # pragma: no-cover
return {
'nc': nc,
'last_checkpoint': chk,
'last_checkpoint': checkpoint,
**analyzer.unit_results_dict
}
else:
return {'debug': {'sampler': sampler}}

@staticmethod
def analyse(where) -> dict:
def analyse(where: stagingregistry.StagingPath) -> dict:
# don't put energy analysis in here, it uses the open file reporter
# whereas structural stuff requires that the file handle is closed
ret = subprocess.run(['openfe_analysis', str(where)],
output = (where / 'results.json')
ret = subprocess.run(['openfe_analysis', 'RFE_analysis',
str(where.as_path()),
str(output.as_path())],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
if ret.returncode:
return {'structural_analysis_error': ret.stderr}

data = json.loads(ret.stdout)
with open(output, 'r') as f:
data = json.load(f)

savedir = pathlib.Path(where)
if d := data['protein_2D_RMSD']:
fig = plotting.plot_2D_rmsd(d)
fig.savefig(savedir / "protein_2D_RMSD.png")
fig.savefig(where / "protein_2D_RMSD.png")
plt.close(fig)
f2 = plotting.plot_ligand_COM_drift(data['time(ps)'], data['ligand_wander'])
f2.savefig(savedir / "ligand_COM_drift.png")
f2.savefig(where / "ligand_COM_drift.png")
plt.close(f2)

f3 = plotting.plot_ligand_RMSD(data['time(ps)'], data['ligand_RMSD'])
f3.savefig(savedir / "ligand_RMSD.png")
f3.savefig(where / "ligand_RMSD.png")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(not really this line, but this is the closest changed line)

You may want to include the figure filenames in the return value, too. Otherwise it will be harder to download them from the cloud.

plt.close(f3)

return {'structural_analysis': data}
Expand All @@ -1020,7 +1032,8 @@ def _execute(
log_system_probe(logging.INFO, paths=[ctx.scratch])
with without_oechem_backend():
outputs = self.run(scratch_basepath=ctx.scratch,
shared_basepath=ctx.shared)
shared_basepath=ctx.shared,
permanent_basepath=ctx.permanent)

analysis_outputs = self.analyse(ctx.shared)

Expand Down