diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/base.py index 662f049cd..dd1616c7f 100644 --- a/openfe/protocols/openmm_afe/base.py +++ b/openfe/protocols/openmm_afe/base.py @@ -746,9 +746,13 @@ def _run_simulation( # Get the relevant simulation steps mc_steps = settings['integrator_settings'].n_steps.m - equil_steps, prod_steps = settings_validation.get_simsteps( - equil_length=settings['simulation_settings'].equilibration_length, - prod_length=settings['simulation_settings'].production_length, + equil_steps = settings_validation.get_simsteps( + sim_length=settings['simulation_settings'].equilibration_length, + timestep=settings['integrator_settings'].timestep, + mc_steps=mc_steps, + ) + prod_steps = settings_validation.get_simsteps( + sim_length=settings['simulation_settings'].production_length, timestep=settings['integrator_settings'].timestep, mc_steps=mc_steps, ) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 8086aaa72..97035386e 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -637,10 +637,13 @@ def run(self, *, dry=False, verbose=True, settings_validation.validate_timestep( forcefield_settings.hydrogen_mass, timestep ) - equil_steps, prod_steps = settings_validation.get_simsteps( - equil_length=sim_settings.equilibration_length, - prod_length=sim_settings.production_length, - timestep=timestep, mc_steps=mc_steps + equil_steps = settings_validation.get_simsteps( + sim_length=sim_settings.equilibration_length, + timestep=timestep, mc_steps=mc_steps, + ) + prod_steps = settings_validation.get_simsteps( + sim_length=sim_settings.production_length, + timestep=timestep, mc_steps=mc_steps, ) solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA) diff --git a/openfe/protocols/openmm_utils/settings_validation.py b/openfe/protocols/openmm_utils/settings_validation.py index 5745a178d..976bc4963 100644 --- a/openfe/protocols/openmm_utils/settings_validation.py +++ b/openfe/protocols/openmm_utils/settings_validation.py @@ -37,17 +37,15 @@ def validate_timestep(hmass: float, timestep: unit.Quantity): raise ValueError(errmsg) -def get_simsteps(equil_length: unit.Quantity, prod_length: unit.Quantity, - timestep: unit.Quantity, mc_steps: int) -> Tuple[int, int]: +def get_simsteps(sim_length: unit.Quantity, + timestep: unit.Quantity, mc_steps: int) -> int: """ - Gets and validates the number of equilibration and production steps. + Gets and validates the number of simulation steps. Parameters ---------- - equil_length : unit.Quantity - Simulation equilibration length. - prod_length : unit.Quantity - Simulation production length. + sim_length : unit.Quantity + Simulation length. timestep : unit.Quantity Integration timestep. mc_steps : int @@ -55,29 +53,21 @@ def get_simsteps(equil_length: unit.Quantity, prod_length: unit.Quantity, Returns ------- - equil_steps : int - The number of equilibration timesteps. - prod_steps : int - The number of production timesteps. + sim_steps : int + The number of simulation timesteps. """ - equil_time = round(equil_length.to('attosecond').m) - prod_time = round(prod_length.to('attosecond').m) + sim_time = round(sim_length.to('attosecond').m) ts = round(timestep.to('attosecond').m) - equil_steps, mod = divmod(equil_time, ts) + sim_steps, mod = divmod(sim_time, ts) if mod != 0: - raise ValueError("Equilibration time not divisible by timestep") - prod_steps, mod = divmod(prod_time, ts) - if mod != 0: - raise ValueError("Production time not divisible by timestep") + raise ValueError("Simulation time not divisible by timestep") - for var in [("Equilibration", equil_steps, equil_time), - ("Production", prod_steps, prod_time)]: - if (var[1] % mc_steps) != 0: - errmsg = (f"{var[0]} time {var[2]/1000000} ps should contain a " - "number of steps divisible by the number of integrator " - f"timesteps between MC moves {mc_steps}") - raise ValueError(errmsg) + if (sim_steps % mc_steps) != 0: + errmsg = (f"Simulation time {sim_time/1000000} ps should contain a " + "number of steps divisible by the number of integrator " + f"timesteps between MC moves {mc_steps}") + raise ValueError(errmsg) - return equil_steps, prod_steps + return sim_steps diff --git a/openfe/tests/protocols/test_openmmutils.py b/openfe/tests/protocols/test_openmmutils.py index 7147f82d4..fb80e90fc 100644 --- a/openfe/tests/protocols/test_openmmutils.py +++ b/openfe/tests/protocols/test_openmmutils.py @@ -27,43 +27,30 @@ def test_validate_timestep(): settings_validation.validate_timestep(2.0, 4.0 * unit.femtoseconds) -@pytest.mark.parametrize('e,p,ts,mc,es,ps', [ - [1 * unit.nanoseconds, 5 * unit.nanoseconds, 4 * unit.femtoseconds, - 250, 250000, 1250000], - [1 * unit.picoseconds, 1 * unit.picoseconds, 2 * unit.femtoseconds, - 250, 500, 500], +@pytest.mark.parametrize('s,ts,mc,es', [ + [5 * unit.nanoseconds, 4 * unit.femtoseconds, 250, 1250000], + [1 * unit.nanoseconds, 4 * unit.femtoseconds, 250, 250000], + [1 * unit.picoseconds, 2 * unit.femtoseconds, 250, 500], ]) -def test_get_simsteps(e, p, ts, mc, es, ps): - equil_steps, prod_steps = settings_validation.get_simsteps(e, p, ts, mc) +def test_get_simsteps(s, ts, mc, es): + sim_steps = settings_validation.get_simsteps(s, ts, mc) - assert equil_steps == es - assert prod_steps == ps + assert sim_steps == es -@pytest.mark.parametrize('nametype, timelengths', [ - ['Equilibration', [1.003 * unit.picoseconds, 1 * unit.picoseconds]], - ['Production', [1 * unit.picoseconds, 1.003 * unit.picoseconds]], -]) -def test_get_simsteps_indivisible_simtime(nametype, timelengths): - errmsg = f"{nametype} time not divisible by timestep" +def test_get_simsteps_indivisible_simtime(): + errmsg = "Simulation time not divisible by timestep" + timelength = 1.003 * unit.picosecond with pytest.raises(ValueError, match=errmsg): - settings_validation.get_simsteps( - timelengths[0], - timelengths[1], - 2 * unit.femtoseconds, - 100) + settings_validation.get_simsteps(timelength, 2 * unit.femtoseconds, 100) -@pytest.mark.parametrize('nametype, timelengths', [ - ['Equilibration', [1 * unit.picoseconds, 10 * unit.picoseconds]], - ['Production', [10 * unit.picoseconds, 1 * unit.picoseconds]], -]) -def test_mc_indivisible(nametype, timelengths): - errmsg = f"{nametype} time 1.0 ps should contain" +def test_mc_indivisible(): + errmsg = "Simulation time 1.0 ps should contain" + timelength = 1 * unit.picoseconds with pytest.raises(ValueError, match=errmsg): settings_validation.get_simsteps( - timelengths[0], timelengths[1], - 2 * unit.femtoseconds, 1000) + timelength, 2 * unit.femtoseconds, 1000) def test_get_alchemical_components(benzene_modifications, @@ -90,7 +77,7 @@ def test_get_alchemical_components(benzene_modifications, def test_duplicate_chemical_components(benzene_modifications): stateA = openfe.ChemicalSystem({'A': benzene_modifications['toluene'], - 'B': benzene_modifications['toluene'],}) + 'B': benzene_modifications['toluene'], }) stateB = openfe.ChemicalSystem({'A': benzene_modifications['toluene']}) errmsg = "state A components B:" @@ -139,7 +126,7 @@ def test_multiple_proteins(T4_protein_component): def test_get_components_gas(benzene_modifications): state = openfe.ChemicalSystem({'A': benzene_modifications['benzene'], - 'B': benzene_modifications['toluene'],}) + 'B': benzene_modifications['toluene'], }) s, p, mols = system_validation.get_components(state) @@ -152,7 +139,7 @@ def test_components_solvent(benzene_modifications): state = openfe.ChemicalSystem({'S': openfe.SolventComponent(), 'A': benzene_modifications['benzene'], - 'B': benzene_modifications['toluene'],}) + 'B': benzene_modifications['toluene'], }) s, p, mols = system_validation.get_components(state)