From 3e848a6e1f7bd7484c337c545df04f1d0d98690a Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 4 Dec 2024 13:24:29 +0000 Subject: [PATCH 01/33] first pass at abstract classes --- .../protocols/openmm_utils/omm_restraints.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 openfe/protocols/openmm_utils/omm_restraints.py diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/omm_restraints.py new file mode 100644 index 000000000..2a62d30ae --- /dev/null +++ b/openfe/protocols/openmm_utils/omm_restraints.py @@ -0,0 +1,64 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Classes for applying restraints to OpenMM Systems. + +Acknowledgements +---------------- +Many of the classes here are at least in part inspired, if not taken from +`Yank `_ and +`OpenMMTools `_. + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from typing import Optional, Union + +from openmmtools.states import GlobalParameterState + + +class RestraintParameterState(GlobalParameterState): + """ + Composable state to control `lambda_restraints` OpenMM Force parameters. + + See :class:`openmmtools.states.GlobalParameterState` for more details. + + Parameters + ---------- + parameters_name_suffix : Optional[str] + If specified, the state will control a modified version of the parameter + ``lambda_restraints_{parameters_name_suffix}` instead of just ``lambda_restraints``. + lambda_restraints : Optional[float] + The strength of the restraint. If defined, must be between 0 and 1. + + Acknowledgement + --------------- + Partially reproduced from Yank. + """ + + lambda_restraints = GlobalParameterState.GlobalParameter('lambda_restraints', standard_value=1.0) + + @lambda_restraints.validator + def lambda_restraints(self, instance, new_value): + if new_value is not None and not (0.0 <= new_value <= 1.0): + errmsg = ("lambda_restraints must be between 0.0 and 1.0, " + f"got {new_value}") + raise ValueError(errmsg) + # Not crashing out on None to match upstream behaviour + return new_value + + +class BaseHostGuestRestraints(abc.ABC): + """ + An abstract base class for defining objects that apply a restraint between + two entities (referred to as a Host and a Guest). + + + TODO + ---- + Add some examples here. + """ + def __init__(self, host_atoms: list[int], guest_atoms: list[int], restraint_settings: SettingBaseModel, restraint_geometry: BaseRestraintGeometry): + From ef050e5ab47fc228d98d6fba00476b331c46ab3d Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 5 Dec 2024 15:31:32 +0000 Subject: [PATCH 02/33] A start at restraints and forces --- openfe/protocols/openmm_utils/omm_forces.py | 134 +++++++++++++ .../protocols/openmm_utils/omm_restraints.py | 187 +++++++++++++++++- 2 files changed, 315 insertions(+), 6 deletions(-) create mode 100644 openfe/protocols/openmm_utils/omm_forces.py diff --git a/openfe/protocols/openmm_utils/omm_forces.py b/openfe/protocols/openmm_utils/omm_forces.py new file mode 100644 index 000000000..e9246b694 --- /dev/null +++ b/openfe/protocols/openmm_utils/omm_forces.py @@ -0,0 +1,134 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Custom OpenMM Forces + +TODO +---- +* Add relevant duecredit entries. +""" +import numpy as np +import openmm + + +def get_boresch_energy_function( + control_parameter: str, + K_r: float, r_aA0: float, + K_thetaA: float, theta_A0: float, + K_thetaB: float, theta_B0: float, + K_phiA: float, phi_A0: float, + K_phiB: float, phi_B0: float, + K_phiC: float, phi_C0: float +) -> str: + energy_function = ( + f"{control_parameter} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + (K_phiC/2)*dphi_C^2; " + "dphi_A = dA - floor(dA/(2*pi)+0.5)*(2*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "dphi_B = dB - floor(dB/(2*pi)+0.5)*(2*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "dphi_C = dC - floor(dC/(2*pi)+0.5)*(2*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + f"K_r = {K_r}; " + f"r_aA0 = {r_aA0}; " + f"K_thetaA = {K_thetaA}; " + f"theta_A0 = {theta_A0}; " + f"K_thetaB = {K_thetaB}; " + f"theta_B0 = {theta_B0}; " + f"K_phiA = {K_phiA}; " + f"phi_A0 = {phi_A0}; " + f"K_phiB = {K_phiB}; " + f"phi_B0 = {phi_B0}; " + f"K_phiC = {K_phiC}; " + f"phi_C0 = {phi_C0}; " + ) + return energy_function + + +def get_periodic_boresch_energy_function( + control_parameter: str, + K_r: float, r_aA0: float, + K_thetaA: float, theta_A0: float, + K_thetaB: float, theta_B0: float, + K_phiA: float, phi_A0: float, + K_phiB: float, phi_B0: float, + K_phiC: float, phi_C0: float +) -> str: + energy_function = ( + f"{control_parameter} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*uphi_A + (K_phiB/2)*uphi_B + (K_phiC/2)*uphi_C; " + "uphi_A = (1-cos(dA)); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "uphi_B = (1-cos(dB)); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "uphi_C = (1-cos(dC)); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + f"K_r = {K_r}; " + f"r_aA0 = {r_aA0}; " + f"K_thetaA = {K_thetaA}; " + f"theta_A0 = {theta_A0}; " + f"K_thetaB = {K_thetaB}; " + f"theta_B0 = {theta_B0}; " + f"K_phiA = {K_phiA}; " + f"phi_A0 = {phi_A0}; " + f"K_phiB = {K_phiB}; " + f"phi_B0 = {phi_B0}; " + f"K_phiC = {K_phiC}; " + f"phi_C0 = {phi_C0}; " + ) + return energy_function + + +def get_custom_compound_bond_force( + n_particles: int = 6, energy_function: str = BORESCH_ENERGY_FUNCTION +): + """ + Return an OpenMM CustomCompoundForce + + TODO + ---- + Change this to a direct subclass like openmmtools.force. + + Acknowledgements + ---------------- + Boresch-like energy functions are reproduced from `Yank `_ + """ + return openmm.CustomCompoundBondForce(n_particles, energy_function) + + +def add_force_in_separate_group( + system: openmm.System, + force: openmm.Force, +): + """ + Add force to a System in a separate force group. + + Parameters + ---------- + system : openmm.System + System to add the Force to. + force : openmm.Force + The Force to add to the System. + + Raises + ------ + ValueError + If all 32 force groups are occupied. + + + TODO + ---- + Unlike the original Yank implementation, we assume that + all 32 force groups will not be filled. Should this be an issue + we can consider just separating it from NonbondedForce. + + Acknowledgements + ---------------- + Mostly reproduced from `Yank `_. + """ + available_force_groups = set(range(32)) + for force in system.getForces(): + available_force_groups.discard(force.getForceGroup()) + + force.setForceGroup(min(available_force_groups)) + system.addForce(force) diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/omm_restraints.py index 2a62d30ae..f678428c8 100644 --- a/openfe/protocols/openmm_utils/omm_restraints.py +++ b/openfe/protocols/openmm_utils/omm_restraints.py @@ -14,9 +14,28 @@ * Add relevant duecredit entries. """ import abc -from typing import Optional, Union +from typing import Optional, Union, Callable -from openmmtools.states import GlobalParameterState +import openmm +from openmmtools.forces import ( + HarmonicRestraintForce, + HarmonicRestraintBondForce, + FlatBottomRestraintForce, + FlatBottomRestraintBondForce, +) +from openmmtools.states import GlobalParameterState, ThermodynamicState + +from gufe.settings.models import SettingsBaseModel +from openfe.protocols.openmm_utils.omm_forces import ( + get_custom_compound_bond_force, + add_force_in_separate_group, + get_boresch_energy_function, + get_periodic_boresch_energy_function, +) + + +class BaseRestraintGeometry: + pass class RestraintParameterState(GlobalParameterState): @@ -38,13 +57,16 @@ class RestraintParameterState(GlobalParameterState): Partially reproduced from Yank. """ - lambda_restraints = GlobalParameterState.GlobalParameter('lambda_restraints', standard_value=1.0) + lambda_restraints = GlobalParameterState.GlobalParameter( + "lambda_restraints", standard_value=1.0 + ) @lambda_restraints.validator def lambda_restraints(self, instance, new_value): if new_value is not None and not (0.0 <= new_value <= 1.0): - errmsg = ("lambda_restraints must be between 0.0 and 1.0, " - f"got {new_value}") + errmsg = ( + "lambda_restraints must be between 0.0 and 1.0, " f"got {new_value}" + ) raise ValueError(errmsg) # Not crashing out on None to match upstream behaviour return new_value @@ -60,5 +82,158 @@ class BaseHostGuestRestraints(abc.ABC): ---- Add some examples here. """ - def __init__(self, host_atoms: list[int], guest_atoms: list[int], restraint_settings: SettingBaseModel, restraint_geometry: BaseRestraintGeometry): + def __init__( + self, + host_atoms: list[int], + guest_atoms: list[int], + restraint_settings: SettingsBaseModel, + restraint_geometry: BaseRestraintGeometry, + controlling_parameter_name: str = "lambda_restraints", + ): + self.host_atoms = host_atoms + self.guest_atoms = guest_atoms + self.settings = restraint_settings + self.geometry = restraint_geometry + self._verify_input() + + @abc.abstractmethod + def _verify_inputs(self): + pass + + @abc.abstractmethod + def add_force(self, thermodynamic_state: ThermodynamicState): + pass + + @abc.abstractmethod + def get_standard_state_correction(self, thermodynamic_state: ThermodynamicState): + pass + + @abc.abstractmethod + def _get_force(self): + pass + + +class SingleBondMixin: + def _verify_input(self): + if len(self.host_atoms) != 1 or len(self.guest_atoms) != 1: + errmsg = ( + "host_atoms and guest_atoms must only include a single index " + f"each, got {len(host_atoms)} and " + f"{len(guest_atoms)} respectively." + ) + raise ValueError(errmsg) + super()._verify_inputs() + + +class BaseRadialllySymmetricRestraintForce(BaseHostGuestRestraints): + def _verify_inputs(self) -> None: + if not isinstance(self.settings, BaseDistanceRestraintSettings): + errmsg = f"Incorrect settings type {self.settings} passed through" + raise ValueError(errmsg) + if not isinstance(self.geometry, DistanceRestraintGeometry): + errmsg = f"Incorrect geometry type {self.geometry} passed through" + raise ValueError(errmsg) + + def add_force(self, thermodynamic_state: ThermodynamicState) -> None: + force = self._get_force() + force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) + # Note .system is a call to get_system() so it's returning a copy + system = thermodynamic_state.system + add_force_in_separate_group(system, force) + thermodynamic_state.system = system + + def get_standard_state_correction( + self, thermodynamic_state: ThermodynamicState + ) -> float: + force = self._get_force() + return force.compute_standard_state_correction( + thermodynamic_state, volume="system" + ) + + def _get_force(self): + raise NotImplementedError("only implemented in child classes") + + +class HarmonicBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): + def _get_force(self) -> openmm.Force: + return HarmonicRestraintBondForce( + spring_constant=self.settings.spring_constant, + restrained_atom_index1=self.host_atoms[0], + restrained_atom_index2=self.guest_atoms[0], + controlling_parameter_name=self.controlling_parameter_name, + ) + + +class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): + def _get_force(self) -> openmm.Force: + return FlatBottomRestraintBondForce( + spring_constant=self.settings.spring_constant, + well_radius=self.settings.well_radius, + restrained_atom_index1=self.host_atoms[0], + restrained_atom_index2=self.guest_atoms[0], + controlling_parameter_name=self.controlling_parameter_name, + ) + + +class CentroidHarmonicRestraint(BaseRadialllySymmetricRestraintForce): + def _get_force(self) -> openmm.Force: + return HarmonicRestraintForce( + spring_constant=self.settings.spring_constant, + restrained_atom_index1=self.host_atoms, + restrained_atom_index2=self.guest_atoms, + controlling_parameter_name=self.controlling_parameter_name, + ) + + +class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): + def _get_force(self): + return FlatBottomRestraintBondForce( + spring_constant=self.settings.spring_constant, + well_radius=self.settings.well_radius, + restrained_atom_index1=self.host_atoms, + restrained_atom_index2=self.guest_atoms, + controlling_parameter_name=self.controlling_parameter_name, + ) + + +class BoreschRestraint(BaseHostGuestRestraints): + _EFUNC_METHOD: Callable = get_boresch_energy_function + def _verify_inputs(self) -> None: + if not isinstance(self.settings, BoreschRestraintSettings): + errmsg = f"Incorrect settings type {self.settings} passed through" + raise ValueError(errmsg) + if not isinstance(self.geometry, BoreschRestraintGeometry): + errmsg = f"Incorrect geometry type {self.geometry} passed through" + raise ValueError(errmsg) + + def add_force(self, thermodynamic_state: ThermodynamicState) -> None: + force = self._get_force() + force.addGlobalParameter(self.controlling_parameter_name, 1.0) + force.addBond(self.host_atoms + self.guest_atoms, []) + force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) + # Note .system is a call to get_system() so it's returning a copy + system = thermodynamic_state.system + add_force_in_separate_group(system, force) + thermodynamic_state.system = system + + def _get_force(self) -> openmm.Force: + efunc = _EFUNC_METHOD( + self.controlling_parameter_name, + self.settings.K_r, + self.geometry.r_aA0, + self.settings.K_thetaA, + self.geometry.theta_A0, + self.settings.K_thetaB, + self.geometry.theta_B0, + self.settings.K_phiA, + self.geometry.phi_A0, + self.settings.K_phiB, + self.geometry.phi_B0, + self.settings.K_phiC, + self.geometry.phi_C0, + ) + + return get_custom_compound_bond_force( + n_particles=6, energy_function=efunc + ) From 420f3e56d6321a2df49f06a865fefc628e771d87 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 6 Dec 2024 16:24:15 +0000 Subject: [PATCH 03/33] Add boresch restraint class --- openfe/protocols/openmm_utils/omm_forces.py | 42 +------- .../protocols/openmm_utils/omm_restraints.py | 95 ++++++++++++++----- 2 files changed, 76 insertions(+), 61 deletions(-) diff --git a/openfe/protocols/openmm_utils/omm_forces.py b/openfe/protocols/openmm_utils/omm_forces.py index e9246b694..3ad9d0aa6 100644 --- a/openfe/protocols/openmm_utils/omm_forces.py +++ b/openfe/protocols/openmm_utils/omm_forces.py @@ -13,46 +13,22 @@ def get_boresch_energy_function( control_parameter: str, - K_r: float, r_aA0: float, - K_thetaA: float, theta_A0: float, - K_thetaB: float, theta_B0: float, - K_phiA: float, phi_A0: float, - K_phiB: float, phi_B0: float, - K_phiC: float, phi_C0: float ) -> str: energy_function = ( f"{control_parameter} * E; " "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " "+ (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + (K_phiC/2)*dphi_C^2; " - "dphi_A = dA - floor(dA/(2*pi)+0.5)*(2*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " - "dphi_B = dB - floor(dB/(2*pi)+0.5)*(2*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " - "dphi_C = dC - floor(dC/(2*pi)+0.5)*(2*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + "dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " f"pi = {np.pi}; " - f"K_r = {K_r}; " - f"r_aA0 = {r_aA0}; " - f"K_thetaA = {K_thetaA}; " - f"theta_A0 = {theta_A0}; " - f"K_thetaB = {K_thetaB}; " - f"theta_B0 = {theta_B0}; " - f"K_phiA = {K_phiA}; " - f"phi_A0 = {phi_A0}; " - f"K_phiB = {K_phiB}; " - f"phi_B0 = {phi_B0}; " - f"K_phiC = {K_phiC}; " - f"phi_C0 = {phi_C0}; " ) return energy_function def get_periodic_boresch_energy_function( control_parameter: str, - K_r: float, r_aA0: float, - K_thetaA: float, theta_A0: float, - K_thetaB: float, theta_B0: float, - K_phiA: float, phi_A0: float, - K_phiB: float, phi_B0: float, - K_phiC: float, phi_C0: float ) -> str: energy_function = ( f"{control_parameter} * E; " @@ -63,18 +39,6 @@ def get_periodic_boresch_energy_function( "uphi_B = (1-cos(dB)); dB = dihedral(p2,p3,p4,p5) - phi_B0; " "uphi_C = (1-cos(dC)); dC = dihedral(p3,p4,p5,p6) - phi_C0; " f"pi = {np.pi}; " - f"K_r = {K_r}; " - f"r_aA0 = {r_aA0}; " - f"K_thetaA = {K_thetaA}; " - f"theta_A0 = {theta_A0}; " - f"K_thetaB = {K_thetaB}; " - f"theta_B0 = {theta_B0}; " - f"K_phiA = {K_phiA}; " - f"phi_A0 = {phi_A0}; " - f"K_phiB = {K_phiB}; " - f"phi_B0 = {phi_B0}; " - f"K_phiC = {K_phiC}; " - f"phi_C0 = {phi_C0}; " ) return energy_function diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/omm_restraints.py index f678428c8..d03b1f195 100644 --- a/openfe/protocols/openmm_utils/omm_restraints.py +++ b/openfe/protocols/openmm_utils/omm_restraints.py @@ -16,7 +16,9 @@ import abc from typing import Optional, Union, Callable +import numpy as np import openmm +from openmm import unit as omm_unit from openmmtools.forces import ( HarmonicRestraintForce, HarmonicRestraintBondForce, @@ -24,6 +26,7 @@ FlatBottomRestraintBondForce, ) from openmmtools.states import GlobalParameterState, ThermodynamicState +from openff.units.openmm import to_openmm from gufe.settings.models import SettingsBaseModel from openfe.protocols.openmm_utils.omm_forces import ( @@ -143,7 +146,7 @@ def add_force(self, thermodynamic_state: ThermodynamicState) -> None: add_force_in_separate_group(system, force) thermodynamic_state.system = system - def get_standard_state_correction( + def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState ) -> float: force = self._get_force() @@ -157,8 +160,9 @@ def _get_force(self): class HarmonicBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): def _get_force(self) -> openmm.Force: + spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintBondForce( - spring_constant=self.settings.spring_constant, + spring_constant=spring_constant, restrained_atom_index1=self.host_atoms[0], restrained_atom_index2=self.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, @@ -167,9 +171,11 @@ def _get_force(self) -> openmm.Force: class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): def _get_force(self) -> openmm.Force: + spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( - spring_constant=self.settings.spring_constant, - well_radius=self.settings.well_radius, + spring_constant=spring_constant, + well_radius=well_radius, restrained_atom_index1=self.host_atoms[0], restrained_atom_index2=self.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, @@ -178,8 +184,9 @@ def _get_force(self) -> openmm.Force: class CentroidHarmonicRestraint(BaseRadialllySymmetricRestraintForce): def _get_force(self) -> openmm.Force: + spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintForce( - spring_constant=self.settings.spring_constant, + spring_constant=spring_constant, restrained_atom_index1=self.host_atoms, restrained_atom_index2=self.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, @@ -188,9 +195,11 @@ def _get_force(self) -> openmm.Force: class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): def _get_force(self): + spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( - spring_constant=self.settings.spring_constant, - well_radius=self.settings.well_radius, + spring_constant=spring_constant, + well_radius=well_radius, restrained_atom_index1=self.host_atoms, restrained_atom_index2=self.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, @@ -209,8 +218,6 @@ def _verify_inputs(self) -> None: def add_force(self, thermodynamic_state: ThermodynamicState) -> None: force = self._get_force() - force.addGlobalParameter(self.controlling_parameter_name, 1.0) - force.addBond(self.host_atoms + self.guest_atoms, []) force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) # Note .system is a call to get_system() so it's returning a copy system = thermodynamic_state.system @@ -220,20 +227,64 @@ def add_force(self, thermodynamic_state: ThermodynamicState) -> None: def _get_force(self) -> openmm.Force: efunc = _EFUNC_METHOD( self.controlling_parameter_name, - self.settings.K_r, - self.geometry.r_aA0, - self.settings.K_thetaA, - self.geometry.theta_A0, - self.settings.K_thetaB, - self.geometry.theta_B0, - self.settings.K_phiA, - self.geometry.phi_A0, - self.settings.K_phiB, - self.geometry.phi_B0, - self.settings.K_phiC, - self.geometry.phi_C0, ) - return get_custom_compound_bond_force( + force = get_custom_compound_bond_force( n_particles=6, energy_function=efunc ) + + param_values = [] + + parameter_dict = { + 'K_r': self.settings.K_r, + 'r_aA0': self.geometry.r_aA0, + 'K_thetaA': self.settings.K_thetaA, + 'theta_A0': self.geometry.theta_A0, + 'K_thetaB': self.settings.K_thetaB, + 'theta_B0': self.geometry.theta_B0, + 'K_phiA': self.settings.K_phiA, + 'phi_A0': self.geometry.phi_A0, + 'K_phiB': self.settings.K_phiB, + 'phi_B0': self.geometry.phi_B0, + 'K_phiC': self.settings.K_phiC, + 'phi_C0': self.geometry.phi_C0, + } + for key, val in parameter_dict.items(): + param_values.append(to_openmm(val).value_in_unit_system(omm_unit.md_unit_system)) + force.addPerBondParameter(key) + + force.addGlobalParameter(self.controlling_parameter_name, 1.0) + force.addBond(self.host_atoms + self.guest_atoms, param_values) + return force + + def get_standard_state_correction( + self, thermodynamic_state: ThermodynamicState + ) -> float: + + StandardV = 1660.53928 * unit.angstroms**3 + kt = from_openmm(thermodynamic_state.kT) + + # distances + r_aA0 = self.geometry.r_aA0.to('nm') + sin_thetaA0 = np.sin(self.geometry.theta_A0.to('radians')) + sin_thetaB0 = np.sin(self.geometry.theta_B0.to('radians')) + + # restraint energies + K_r = self.settings.K_r.to('kilojoule_per_mole') + K_thetaA = self.settings.K_thetaA.to('kilojoule_per_mole') + k_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole') + K_phiA = self.settings.K_phiA.to('kilojoule_per_mole') + K_phiB = self.settings.K_phiB.to('kilojoule_per_mole') + K_phiC = self.settings.K_phiC.to('kilojoule_per_mole') + + numerator1 = 8.0 * (np.pi**2) * StandardV + denum1 = (r_aA0**2) * sin_thetaA0 * sin_thetaB0 + numerator2 = np.sqrt(K_r * K_thetaA * K_thetaB * K_phiA * K_phiB * K_phiC) + denum2 = (2.0 * np.pi * kt)**3 + + dG = -kt * np.log((numerator1/denum1) * (numerator2/denum2)) + + return dG + + +# TODO - implement periodic torsion Boresch restraint From 2e2b52e38a44c636fb431c0544c12503de59e0cf Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 6 Dec 2024 16:35:00 +0000 Subject: [PATCH 04/33] Fix units --- openfe/protocols/openmm_utils/omm_restraints.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/omm_restraints.py index d03b1f195..104b07a4a 100644 --- a/openfe/protocols/openmm_utils/omm_restraints.py +++ b/openfe/protocols/openmm_utils/omm_restraints.py @@ -261,7 +261,7 @@ def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState ) -> float: - StandardV = 1660.53928 * unit.angstroms**3 + StandardV = 1.66053928 * unit.nanometer**3 kt = from_openmm(thermodynamic_state.kT) # distances @@ -270,12 +270,12 @@ def get_standard_state_correction( sin_thetaB0 = np.sin(self.geometry.theta_B0.to('radians')) # restraint energies - K_r = self.settings.K_r.to('kilojoule_per_mole') - K_thetaA = self.settings.K_thetaA.to('kilojoule_per_mole') - k_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole') - K_phiA = self.settings.K_phiA.to('kilojoule_per_mole') - K_phiB = self.settings.K_phiB.to('kilojoule_per_mole') - K_phiC = self.settings.K_phiC.to('kilojoule_per_mole') + K_r = self.settings.K_r.to('kilojoule_per_mole / nm ** 2') + K_thetaA = self.settings.K_thetaA.to('kilojoule_per_mole / radians ** 2') + k_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole / radians ** 2') + K_phiA = self.settings.K_phiA.to('kilojoule_per_mole / radians ** 2') + K_phiB = self.settings.K_phiB.to('kilojoule_per_mole / radians ** 2') + K_phiC = self.settings.K_phiC.to('kilojoule_per_mole / radians ** 2') numerator1 = 8.0 * (np.pi**2) * StandardV denum1 = (r_aA0**2) * sin_thetaA0 * sin_thetaB0 From 76c5fcfe845250b527ab815cda89e204daedc206 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 9 Dec 2024 12:25:46 +0000 Subject: [PATCH 05/33] Fix correction return in kj/mole --- openfe/protocols/openmm_utils/omm_restraints.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/omm_restraints.py index 104b07a4a..915d51d3c 100644 --- a/openfe/protocols/openmm_utils/omm_restraints.py +++ b/openfe/protocols/openmm_utils/omm_restraints.py @@ -12,6 +12,7 @@ TODO ---- * Add relevant duecredit entries. +* Add Periodic Torsion Boresch class """ import abc from typing import Optional, Union, Callable @@ -26,7 +27,8 @@ FlatBottomRestraintBondForce, ) from openmmtools.states import GlobalParameterState, ThermodynamicState -from openff.units.openmm import to_openmm +from openff.units.openmm import to_openmm, from_openmm +from openff.units import unit from gufe.settings.models import SettingsBaseModel from openfe.protocols.openmm_utils.omm_forces import ( @@ -148,11 +150,13 @@ def add_force(self, thermodynamic_state: ThermodynamicState) -> None: def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState - ) -> float: + ) -> unit.Quantity: force = self._get_force() - return force.compute_standard_state_correction( + corr = force.compute_standard_state_correction( thermodynamic_state, volume="system" ) + dg = corr * thermodynamic_state.kT + return from_openmm(dg).to('kilojoule_per_mole') def _get_force(self): raise NotImplementedError("only implemented in child classes") @@ -194,7 +198,7 @@ def _get_force(self) -> openmm.Force: class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): - def _get_force(self): + def _get_force(self) -> openmm.Force: spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( @@ -259,7 +263,7 @@ def _get_force(self) -> openmm.Force: def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState - ) -> float: + ) -> unit.Quantity: StandardV = 1.66053928 * unit.nanometer**3 kt = from_openmm(thermodynamic_state.kT) @@ -285,6 +289,3 @@ def get_standard_state_correction( dG = -kt * np.log((numerator1/denum1) * (numerator2/denum2)) return dG - - -# TODO - implement periodic torsion Boresch restraint From e9cd918c60a9d3053f2459acb61b37cead9f2579 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 9 Dec 2024 12:57:04 +0000 Subject: [PATCH 06/33] Add more restraint API bits --- .../openmm_utils/restraints/__init__.py | 0 .../openmm_utils/restraints/geometry.py | 56 +++++++++++++++++++ .../{ => restraints}/omm_forces.py | 0 .../{ => restraints}/omm_restraints.py | 34 +++++------ 4 files changed, 71 insertions(+), 19 deletions(-) create mode 100644 openfe/protocols/openmm_utils/restraints/__init__.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry.py rename openfe/protocols/openmm_utils/{ => restraints}/omm_forces.py (100%) rename openfe/protocols/openmm_utils/{ => restraints}/omm_restraints.py (90%) diff --git a/openfe/protocols/openmm_utils/restraints/__init__.py b/openfe/protocols/openmm_utils/restraints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/protocols/openmm_utils/restraints/geometry.py b/openfe/protocols/openmm_utils/restraints/geometry.py new file mode 100644 index 000000000..3d1f37a10 --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry.py @@ -0,0 +1,56 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + + +class BaseRestraintGeometry(BaseModel, abc.ABC): + class Config: + arbitrary_types_allowed = True + + +class HostGuestRestraintGeometry(BaseRestraintGeometry): + """ + An ordered list of guest atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + guest_atoms: list[int] + """ + An ordered list of host atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + host_atoms: list[int] + + @validator("guest_atoms", "host_atoms") + def positive_idxs(cls, v): + if any([i < 0 for i in v]): + errmsg = "negative indices passed" + raise ValueError(errmsg) + return v + + +class BondDistanceRestraintGeoemtry(HostGuestRestraintGeometry): + @validator("host_atoms", "guest_atoms") + def single_atoms(cls, v): + if len(v) != 1: + errmsg = ( + "Host and guest atom lists must only include a single atom, " + f"got {len(v)} atoms." + ) + raise ValueError(errmsg) + return v diff --git a/openfe/protocols/openmm_utils/omm_forces.py b/openfe/protocols/openmm_utils/restraints/omm_forces.py similarity index 100% rename from openfe/protocols/openmm_utils/omm_forces.py rename to openfe/protocols/openmm_utils/restraints/omm_forces.py diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/omm_restraints.py similarity index 90% rename from openfe/protocols/openmm_utils/omm_restraints.py rename to openfe/protocols/openmm_utils/restraints/omm_restraints.py index 915d51d3c..0bfb6eb80 100644 --- a/openfe/protocols/openmm_utils/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/omm_restraints.py @@ -90,14 +90,10 @@ class BaseHostGuestRestraints(abc.ABC): def __init__( self, - host_atoms: list[int], - guest_atoms: list[int], restraint_settings: SettingsBaseModel, restraint_geometry: BaseRestraintGeometry, controlling_parameter_name: str = "lambda_restraints", ): - self.host_atoms = host_atoms - self.guest_atoms = guest_atoms self.settings = restraint_settings self.geometry = restraint_geometry self._verify_input() @@ -121,7 +117,7 @@ def _get_force(self): class SingleBondMixin: def _verify_input(self): - if len(self.host_atoms) != 1 or len(self.guest_atoms) != 1: + if len(self.geometry.host_atoms) != 1 or len(self.geometry.guest_atoms) != 1: errmsg = ( "host_atoms and guest_atoms must only include a single index " f"each, got {len(host_atoms)} and " @@ -148,7 +144,7 @@ def add_force(self, thermodynamic_state: ThermodynamicState) -> None: add_force_in_separate_group(system, force) thermodynamic_state.system = system - def get_standard_state_correction( + def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState ) -> unit.Quantity: force = self._get_force() @@ -164,48 +160,48 @@ def _get_force(self): class HarmonicBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): def _get_force(self) -> openmm.Force: - spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintBondForce( spring_constant=spring_constant, - restrained_atom_index1=self.host_atoms[0], - restrained_atom_index2=self.guest_atoms[0], + restrained_atom_index1=self.geometry.host_atoms[0], + restrained_atom_index2=self.geometry.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, ) class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): def _get_force(self) -> openmm.Force: - spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, - restrained_atom_index1=self.host_atoms[0], - restrained_atom_index2=self.guest_atoms[0], + restrained_atom_index1=self.geometry.host_atoms[0], + restrained_atom_index2=self.geometry.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, ) class CentroidHarmonicRestraint(BaseRadialllySymmetricRestraintForce): def _get_force(self) -> openmm.Force: - spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintForce( spring_constant=spring_constant, - restrained_atom_index1=self.host_atoms, - restrained_atom_index2=self.guest_atoms, + restrained_atom_index1=self.geometry.host_atoms, + restrained_atom_index2=self.geometry.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, ) class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): def _get_force(self) -> openmm.Force: - spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, - restrained_atom_index1=self.host_atoms, - restrained_atom_index2=self.guest_atoms, + restrained_atom_index1=self.geometry.host_atoms, + restrained_atom_index2=self.geometry.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, ) @@ -258,7 +254,7 @@ def _get_force(self) -> openmm.Force: force.addPerBondParameter(key) force.addGlobalParameter(self.controlling_parameter_name, 1.0) - force.addBond(self.host_atoms + self.guest_atoms, param_values) + force.addBond(self.geometry.host_atoms + self.geometry.guest_atoms, param_values) return force def get_standard_state_correction( From f1bbd8a440fd254f2b1eb0dfa17f1326e605a11e Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 9 Dec 2024 15:50:07 +0000 Subject: [PATCH 07/33] move some things around --- openfe/protocols/openmm_utils/restraints/omm_restraints.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/omm_restraints.py index 0bfb6eb80..599230ccc 100644 --- a/openfe/protocols/openmm_utils/restraints/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/omm_restraints.py @@ -5,7 +5,7 @@ Acknowledgements ---------------- -Many of the classes here are at least in part inspired, if not taken from +Many of the classes here are at least in part inspired from `Yank `_ and `OpenMMTools `_. @@ -39,10 +39,6 @@ ) -class BaseRestraintGeometry: - pass - - class RestraintParameterState(GlobalParameterState): """ Composable state to control `lambda_restraints` OpenMM Force parameters. From ac452e9df96f2c8ae23536be702e6939b1e98769 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 11 Dec 2024 11:52:04 +0000 Subject: [PATCH 08/33] Some changes --- .../openmm_utils/restraints/geometry.py | 54 +++++++++++++++---- .../openmm_utils/restraints/omm_restraints.py | 4 +- 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry.py b/openfe/protocols/openmm_utils/restraints/geometry.py index 3d1f37a10..14e1cd289 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry.py +++ b/openfe/protocols/openmm_utils/restraints/geometry.py @@ -10,6 +10,10 @@ import abc from pydantic.v1 import BaseModel, validator +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds + class BaseRestraintGeometry(BaseModel, abc.ABC): class Config: @@ -25,6 +29,7 @@ class HostGuestRestraintGeometry(BaseRestraintGeometry): The order matters! It will be used to define the underlying force. """ + guest_atoms: list[int] """ An ordered list of host atoms to restrain. @@ -44,13 +49,42 @@ def positive_idxs(cls, v): return v -class BondDistanceRestraintGeoemtry(HostGuestRestraintGeometry): - @validator("host_atoms", "guest_atoms") - def single_atoms(cls, v): - if len(v) != 1: - errmsg = ( - "Host and guest atom lists must only include a single atom, " - f"got {len(v)} atoms." - ) - raise ValueError(errmsg) - return v +class CentroidDistanceMixin: + def get_distance(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + ag1 = u.atoms[self.host_atoms] + ag2 = u.atoms[self.guest_atoms] + bond = calc_bonds( + ag1.center_of_mass(), ag2.center_of_mass(), u.atoms.dimensions + ) + # convert to float so we avoid having a np.float64 + return float(bond) * unit.angstrom + + +def _check_single_atoms(value): + if len(value) != 1: + errmsg = ( + "Host and guest atom lists must only include a single atom, " + f"got {len(value)} atoms." + ) + raise ValueError(errmsg) + return value + + +class BondDistanceMixin: + def get_distance(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + at1 = u.atoms[self.host_atoms[0]] + at2 = u.atoms[self.guest_atoms[0]] + bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) + # convert to float so we avoid having a np.float64 value + return float(bond) * unit.angstrom + + +class CentroidDistanceRestraintGeometry(HostGuestRestraintGeometry, CentroidDistanceMixin): + pass + + +class BondDistanceRestraintGeoemtry(HostGuestRestraintGeometry, BondDistanceMixin): + _check_host_atoms: classmethod = validator("host_atoms", allow_reuse=True)(_check_single_atoms) + _check_guest_atoms: classmethod = validator("guest_atoms", allow_reuse=True)(_check_single_atoms) diff --git a/openfe/protocols/openmm_utils/restraints/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/omm_restraints.py index 599230ccc..ab5f4e821 100644 --- a/openfe/protocols/openmm_utils/restraints/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/omm_restraints.py @@ -168,7 +168,7 @@ def _get_force(self) -> openmm.Force: class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): def _get_force(self) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(self.geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, @@ -192,7 +192,7 @@ def _get_force(self) -> openmm.Force: class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): def _get_force(self) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(self.geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, From a19c86cae5100520cef62a87de7a6dcd44cb5050 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 12:03:52 +0000 Subject: [PATCH 09/33] refactor restraints --- .../openmm_utils/restraints/geometry.py | 90 ----- .../restraints/geometry/__init__.py | 0 .../openmm_utils/restraints/geometry/base.py | 50 +++ .../restraints/geometry/boresch.py | 66 ++++ .../restraints/geometry/flatbottom.py | 90 +++++ .../restraints/geometry/harmonic.py | 94 +++++ .../openmm_utils/restraints/geometry/utils.py | 360 ++++++++++++++++++ .../restraints/openmm/__init__.py | 0 .../restraints/{ => openmm}/omm_forces.py | 0 .../restraints/{ => openmm}/omm_restraints.py | 0 .../openmm_utils/restraints/search.py | 360 ++++++++++++++++++ 11 files changed, 1020 insertions(+), 90 deletions(-) delete mode 100644 openfe/protocols/openmm_utils/restraints/geometry.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/__init__.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/base.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/boresch.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/harmonic.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/utils.py create mode 100644 openfe/protocols/openmm_utils/restraints/openmm/__init__.py rename openfe/protocols/openmm_utils/restraints/{ => openmm}/omm_forces.py (100%) rename openfe/protocols/openmm_utils/restraints/{ => openmm}/omm_restraints.py (100%) create mode 100644 openfe/protocols/openmm_utils/restraints/search.py diff --git a/openfe/protocols/openmm_utils/restraints/geometry.py b/openfe/protocols/openmm_utils/restraints/geometry.py deleted file mode 100644 index 14e1cd289..000000000 --- a/openfe/protocols/openmm_utils/restraints/geometry.py +++ /dev/null @@ -1,90 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -""" -Restraint Geometry classes - -TODO ----- -* Add relevant duecredit entries. -""" -import abc -from pydantic.v1 import BaseModel, validator - -from openff.units import unit -import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds - - -class BaseRestraintGeometry(BaseModel, abc.ABC): - class Config: - arbitrary_types_allowed = True - - -class HostGuestRestraintGeometry(BaseRestraintGeometry): - """ - An ordered list of guest atoms to restrain. - - Note - ---- - The order matters! It will be used to define the underlying - force. - """ - - guest_atoms: list[int] - """ - An ordered list of host atoms to restrain. - - Note - ---- - The order matters! It will be used to define the underlying - force. - """ - host_atoms: list[int] - - @validator("guest_atoms", "host_atoms") - def positive_idxs(cls, v): - if any([i < 0 for i in v]): - errmsg = "negative indices passed" - raise ValueError(errmsg) - return v - - -class CentroidDistanceMixin: - def get_distance(self, topology, coordinates) -> unit.Quantity: - u = mda.Universe(topology, coordinates) - ag1 = u.atoms[self.host_atoms] - ag2 = u.atoms[self.guest_atoms] - bond = calc_bonds( - ag1.center_of_mass(), ag2.center_of_mass(), u.atoms.dimensions - ) - # convert to float so we avoid having a np.float64 - return float(bond) * unit.angstrom - - -def _check_single_atoms(value): - if len(value) != 1: - errmsg = ( - "Host and guest atom lists must only include a single atom, " - f"got {len(value)} atoms." - ) - raise ValueError(errmsg) - return value - - -class BondDistanceMixin: - def get_distance(self, topology, coordinates) -> unit.Quantity: - u = mda.Universe(topology, coordinates) - at1 = u.atoms[self.host_atoms[0]] - at2 = u.atoms[self.guest_atoms[0]] - bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) - # convert to float so we avoid having a np.float64 value - return float(bond) * unit.angstrom - - -class CentroidDistanceRestraintGeometry(HostGuestRestraintGeometry, CentroidDistanceMixin): - pass - - -class BondDistanceRestraintGeoemtry(HostGuestRestraintGeometry, BondDistanceMixin): - _check_host_atoms: classmethod = validator("host_atoms", allow_reuse=True)(_check_single_atoms) - _check_guest_atoms: classmethod = validator("guest_atoms", allow_reuse=True)(_check_single_atoms) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/__init__.py b/openfe/protocols/openmm_utils/restraints/geometry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/protocols/openmm_utils/restraints/geometry/base.py b/openfe/protocols/openmm_utils/restraints/geometry/base.py new file mode 100644 index 000000000..21a714cde --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry/base.py @@ -0,0 +1,50 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds, calc_angles + + +class BaseRestraintGeometry(BaseModel, abc.ABC): + class Config: + arbitrary_types_allowed = True + + +class HostGuestRestraintGeometry(BaseRestraintGeometry): + """ + An ordered list of guest atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + + guest_atoms: list[int] + """ + An ordered list of host atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + host_atoms: list[int] + + @validator("guest_atoms", "host_atoms") + def positive_idxs(cls, v): + if any([i < 0 for i in v]): + errmsg = "negative indices passed" + raise ValueError(errmsg) + return v + diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py new file mode 100644 index 000000000..822382b9c --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -0,0 +1,66 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds, calc_angles + +from .base import HostGuestRestraintGeometry + + +class BoreschRestraintGeometry(HostGuestRestraintGeometry): + """ + A class that defines the restraint geometry for a Boresch restraint. + + The restraint is defined by the following: + + H0 G2 + - - + - - + H1 - - H2 -- G0 - - G1 + + Where HX represents the X index of ``host_atoms`` and GX + the X index of ``guest_atoms``. + """ + def get_bond_distance(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + at1 = u.atoms[host_atoms[2]] + at2 = u.atoms[guest_atoms[0]] + bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) + # convert to float so we avoid having a np.float64 + return float(bond) * unit.angstrom + + def get_angles(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + at1 = u.atoms[host_atoms[1]] + at2 = u.atoms[host_atoms[2]] + at3 = u.atoms[guest_atoms[0]] + at4 = u.atoms[guest_atoms[1]] + + angleA = calc_angles(at1.position, at2.position, at3.position, u.atoms.dimensions) + angleB = calc_angles(at2.position, at3.position, at4.position, u.atoms.dimensions) + return angleA, angleB + + def get_dihedrals(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + at1 = u.atoms[host_atoms[0]] + at2 = u.atoms[host_atoms[1]] + at3 = u.atoms[host_atoms[2]] + at4 = u.atoms[guest_atoms[0]] + at5 = u.atoms[guest_atoms[1]] + at6 = u.atoms[guest_atoms[2]] + + dihA = calc_dihedrals(at1.position, at2.position, at3.position, at4.position, u.atoms.dimensions) + dihB = calc_dihedrals(at2.position, at3.position, at4.position, at5.position, u.atoms.dimensions) + dihC = calc_dihedrals(at3.position, at4.position, at5.position, at6.position, u.atoms.dimensions) + + return dihA, dihB, dihC diff --git a/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py b/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py new file mode 100644 index 000000000..c7e987736 --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py @@ -0,0 +1,90 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +import numpy as np +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds, calc_angles + +from .harmonic import ( + DistanceRestraintGeometry, + _get_selection, +) + + +class FlatBottomDistanceGeometry(DistanceRestraintGeometry): + """ + A geometry class for a flat bottom distance restraint between two groups + of atoms. + """ + + well_radius: FloatQuantity["nanometer"] + + +class COMDistanceAnalysis(AnalysisBase): + """ + Get a timeseries of COM distances between two AtomGroups + + Parameters + ---------- + group1 : MDAnalysis.AtomGroup + Atoms defining the first centroid. + group2 : MDANalysis.AtomGroup + Atoms defining the second centroid. + """ + + _analysis_algorithm_is_parallelizable = False + + def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): + super().__init__(host_atoms.universe.trajectory, **kwargs) + + self.ag1 = group1 + self.ag2 = group2 + + def _prepare(self): + self.results.distances = np.zeros(self.n_frames) + + def _single_frame(self): + com_dist = calc_bonds( + self.ag1.center_of_mass(), + self.ag2.center_of_mass(), + box=self.ag1.universe.dimensions, + ) + self.results.distances[self._frame_index] = com_dist + + def _conclude(self): + pass + + +def get_flatbottom_distance_restraint( + topology: Union[str, openmm.app.Topology], + trajectory: pathlib.Path, + topology_format: Optional[str] = None, + host_atoms: Optional[list[int]] = None, + guest_atoms: Optional[list[int]] = None, + host_selection: Optional[str] = None, + guest_selection: Optional[str] = None, + padding: unit.Quantity = 0.5 * unit.nanometer, +) -> FlatBottomDistanceGeometry: + u = mda.Universe(topology, trajectory, topology_format=topology_format) + + guest_ag = _get_selection(u, guest_atoms, guest_selection) + host_ag = _get_selection(u, host_atoms, host_selection) + + com_dists = COMDistanceAnalysis(guest_ag, host_ag) + com_dists.run() + + well_radius = com_dists.results.distances.max() * unit.angstrom + padding + return FlatBottomDistanceGeometry( + guest_atoms=guest_atoms, host_atoms=host_atoms, well_radius=well_radius + ) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py b/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py new file mode 100644 index 000000000..36e7a61a7 --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py @@ -0,0 +1,94 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds, calc_angles +from rdkit import Chem + +from .base import HostGuestRestraintGeometry +from .utils import _get_central_atom_idx + + +class DistanceRestraintGeometry(HostGuestRestraintGeometry): + """ + A geometry class for a distance restraint between two groups of atoms. + """ + + def get_distance(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + ag1 = u.atoms[self.host_atoms] + ag2 = u.atoms[self.guest_atoms] + bond = calc_bonds( + ag1.center_of_mass(), ag2.center_of_mass(), box=u.atoms.dimensions + ) + # convert to float so we avoid having a np.float64 + return float(bond) * unit.angstrom + + +def _get_selection(universe, atom_list, selection): + if atom_list is None: + if selection is None: + raise ValueError( + "one of either the atom lists or selections must be defined" + ) + + ag = universe.select_atoms(selection) + else: + ag = universe.atoms[atom_list] + + return ag + + +def get_distance_restraint( + topology: Union[str, openmm.app.Topology], + trajectory: pathlib.Path, + topology_format: Optional[str] = None, + host_atoms: Optional[list[int]] = None, + guest_atoms: Optional[list[int]] = None, + host_selection: Optional[str] = None, + guest_selection: Optional[str] = None, +) -> DistanceRestraintGeometry: + u = mda.Universe(topology, trajectory, topology_format=topology_format) + + guest_ag = _get_selection(u, guest_atoms, guest_selection) + host_ag = _get_selection(u, host_atoms, host_selection) + + return DistanceRestraintGeometry(guest_atoms=guest_atoms, host_atoms=host_atoms) + + +def get_molecule_centers_restraint( + topology: Union[str, openmm.app.Topology], + trajectory: pathlib.Path, + molA_rdmol: Chem.Mol, + molB_rdmol: Chem.Mol, + molA_idxs: list[int], + molB_idxs: list[int], + topology_format: Optional[str] = None, +): + # We assume that the mol idxs are ordered + centerA = molA_idxs[_get_central_atom_idx(molA_rdmol)] + centerB = molB_idxs[_get_central_atom_idx(molB_rdmol)] + + u = mda.Universe(topology, trajectory, topology_format=topology_format) + guest_ag = _get_selection( + u, + [centerA], + None, + ) + guest_ag = _get_selection( + u, + [centerB], + None, + ) + + return DistsanceRestraintGeometry(guest_atoms=guest_atoms, host_atoms=host_atoms) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py new file mode 100644 index 000000000..6b3d94eb7 --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -0,0 +1,360 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Search methods for generating Geometry objects + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +from openff.toolkit import Molecule as OFFMol +from openff.units import unit +import networkx as nx +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds, calc_angles + + +def _get_aromatic_atom_idxs(rdmol) -> list[int]: + """ + Helper method to get aromatic atoms idxs + in a RDKit Molecule + + Parameters + ---------- + rdmol : ??? + RDKit Molecule + + Returns + ------- + list[int] + A list of the aromatic atom idxs + """ + idxs = [ + at.GetIdx() for at in rdmol.GetAtoms() + if at.GetIsAromatic() + ] + return idxs + + +def _get_heavy_atom_idxs(rdmol) -> list[int]: + """ + Get idxs of heavy atoms in an RDKit Molecule + + Parameters + ---------- + rmdol : ??? + + Returns + ------- + list[int] + A list of heavy atom idxs + """ + idxs = [ + at.GetIdx() for at in rdmol.GetAtoms() + if at.GetAtomicNum() > 1 + ] + return idxs + + +def _get_central_atom_idx(rdmol) -> int: + offmol = OFFMol(rdmol, allow_undefined_stereo=True) + # We take the zero-th entry if there are multiple center + # atoms (e.g. equal likelihood centers) + center = nx.center(offmol.to_networkx())[0] + return center + + +def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: + """ + Sort a list of atoms by their distance from a target atom. + + Parameters + ---------- + target_idx : int + The idx of the target atom. + atom_idxs : list[int] + The idx values of the atoms to sort. + rdmol : ??? + RDKit Molecule the atoms belong to + + Returns + ------- + list[int] + The input atom idxs sorted by their distance from the target atom. + """ + distances = [] + + conformer = rdmol.GetConformer() + # Get the target atom position + target_pos = conformer.GetAtomPosition(target_idx) + + for idx in atom_idxs: + pos = conformer.GetAtomPosition(idx) + distances.append(((target_pos - pos).Length(), idx)) + + return [i[1] for i in sorted(distances)] + + +def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): + angles = [] + + # Get the base atom and its neighbors + at1 = rdmol.GetAtomWithIdx(atom_idx) + at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] + + # We loop at2 and at3 through the sorted atom_pool in order to get + # a list of angles in the branch that are sorted by how close the atoms + # are from the central atom + for at2 in atom_pool: + if at2 in at1_neighbors: + at2_neighbors = [ + at.GetIdx() + for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() + ] + for at3 in atom_pool: + if at3 != atom_idx and at3 in at2_neighbors: + angles.append((atom_idx, at2, at3)) + return angles + + +def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: + """ + Get a list of ligand anchor atoms (e.g. l1, l2, and l3 of an orientational restraint). + + Parameters + ---------- + rdmol : ??? + Molecule object for the ligand to apply a restraint to. + + Returns + ------- + angles : list[tuple[int, int, int]] + A list of ligand atom triples denoting the possible l1, l2, and l3 + restraint atoms. Ordered by likelihood of restraint-ability. + """ + # Find the central atom + center = _get_central_atom_idx(rdmol) + + # Get a pool of potential anchor atoms looking for aromatic atoms + anchor_pool = _get_aromatic_atoms(rdmol) + + # If there are not enough aromatic atoms, then default to heavy atoms + if len(anchor_pool) < 3: + anchor_pool = _get_heavy_atoms(rdmol) + + # Raise an error if we have less than 3 anchors + if len(anchor_pool) < 3: + errmsg = f"Too few potential ligand anchor atoms, {len(anchor_pool)}" + raise ValueError(errmsg) + + # Sort the pool of anchor atoms by their distance from the central atom + sorted_anchor_pool = _sort_by_distance_from_target(rdmol, center, anchor_pool) + + # Get a list of ligand anchor angle atoms + angles = [] + for atom in sorted_anchor_pool: + angles.extend( + _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) + ) + + +def get_host_anchors(positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str): + """ + Get a list of host anchor atomss sorted by their distance from a ligand anchor atom. + + Parameters + ---------- + positions : openmm.unit.Quantity + Positions of the input system + topology : openmm.app.Topology + OpenMM Topology for input system + exclude_resids : list[int] + List of residue numbers to exclude from host selection + lig_anchor_idx : int + The index of the l1 ligand anchor. + selection : str + Selection string for the host atoms. + """ + # Create an mdtraj trajectory to manipulate + # First fetch the box vectors and pass them as lengths and angles + vectors = from_openmm(topology.getPeriodicBoxVectors()) + a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles(vectors[0].m, vectors[1].m, vectors[2].m) + + traj = mdt.Trajectory( + positions[np.newaxis, ...], + mdt.Topology.from_openmm(topology) + ) + + # Get all the potential protein atoms matching the selection + host_sel = traj.topology.select(selection) + + # Get residues to exclude from the selection + exclude_sel = np.array([ + at.index for at in + chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) + ]) + + # Remove exclusion + anchors = host_sel[np.isin(host_sel, exclude_sel, invert=True)] + + # Compute distanecs from ligand l1 anchor atom + pairs = np.vstack((anchors, np.array([lig_anchor_idx for _ in range(len(anchors))]))).T + + distances = mdt.compute_distances(traj, pairs, periodic=True) + + return np.array([pairs[i][0] for i in np.argsort(distances[0])]) + + +def is_collinear(positions, atoms, threshold=0.9): + """ + Check whether any sequential vectors in a sequence of atoms are collinear. + + Parameters + ---------- + positions : openmm.unit.Quantity + System positions. + atoms : list[int] + The indices of the atoms to test. + threshold : float + Atoms are not collinear if their sequential vector separation dot + products are less than ``threshold``. Default 0.9. + + Returns + ------- + result : bool + Returns True if any sequential pair of vectors is collinear; False otherwise. + + Notes + ----- + Originally from Yank, with modifications from Separated Topologies + """ + results = False + for i in range(len(atoms) - 2): + v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] + v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] + normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) + result = result or (np.abs(normalized_inner_product) > threshold) + return result + + +def check_angle(angle, force_constant=83.68): + """ + Check whether the chosen angle is less than 10 kT from 0 or 180 + + Parameters + ---------- + angle : float + The angle to check in degrees. + force_constant : float + Force constant of the angle. + + Note + ---- + We assume the temperature to be 298.15 Kelvin. + """ + # TODO: convert this to unit.Quantity so we don't end up with + # conversion errors + RT = 8.31445985 * 0.001 * 298.15 + # check if angle is <10kT from 0 or 180 + check1 = 0.5 * force_constant * np.power((angle - 0.0) / 180.0 * np.pi, 2) + check2 = 0.5 * force_constant * np.power((angle - 180.0) / 180.0 * np.pi, 2) + ang_check_1 = check1 / RT + ang_check_2 = check2 / RT + if ang_check_1 < 10.0 or ang_check_2 < 10.0: + return False + return True + + + + +class FindHostAtoms(AnalysisBase): + """ + Class filter host atoms based on their distance + from a set of guest atoms. + + Parameters + ---------- + host_atoms : MDAnalysis.AtomGroup + Initial selection of host atoms to filter from. + guest_atoms : MDANalysis.AtomGroup + Selection of guest atoms to search around. + search_distance: unit.Quantity + Distance to filter atoms within. + """ + _analysis_algorithm_is_parallelizable = False + + def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): + super().__init__(host_atoms.universe.trajectory, **kwargs) + + self.host_ag = host_atoms + self.guest_ag = guest_atoms + self.cutoff = search_distance.to('angstrom').m + + def _prepare(self): + self.results.host_idxs = set() + + def _single_frame(self): + pairs = capped_distance( + reference=self.host_ag.positions, + configuration=self.guest_ag.positions, + max_cutoff=self.cutoff, + min_cutoff=None + box=self.guest_ag.universe.dimensions, + return_distances=False) + + host_idxs = [self.guest_ag.atoms[p].index for p in pairs[:, 1]] + self.results.host_idxs.update(set(host_idxs)) + + def _conclude(self): + pass + + +def find_host_atoms(topology, trajectory, host_selection, guest_selection, cutoff) -> mda.AtomGroup: + """ + Get an AtomGroup of the host atoms based on their distances from the guest atoms. + """ + u = mda.Universe(topology, trajectory) + + def _get_selection(selection): + """ + If it's a str, call select_atoms, if not a list of atom idxs + """ + if isinstance(selection, str): + ag = u.select_atoms(host_selection) + else: + ag = u.atoms[host_ag] + return ag + + host_ag = _get_selection(host_selection) + guest_ag = _get_selection(guest_selection) + + finder = FindHostAtoms(host_ag, guest_ag, cutoff) + finder.run() + + return u.atoms[list(finder.results.host_idxs)] + +def get_molecule_center_idx(atomgroup): + offmol = Molecule(atomgroup.convert_to("RDKIT"), allow_undefined_stereo=True) + # Check if the molecule is whole, otherwise throw an error. + nx = offmol.to_networkx() + + +def get_distance_restraint(topology, trajectory, host_atoms, guest_atoms, host_selection, guest_selection): + u = mda.Universe(topology, trajectory) + + if guest_atoms is None: + if guest_selection is None: + raise ValueError("one of guest_atoms or guest_selections must be defined") + guest_ag = u.select_atoms(guest_selection) + else: + + + if host_atoms is None: + if host_selection is None: + raise ValueError("one of host_atoms or host_selection must be defined") + + host_ag = u.select_atoms(host_selection) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/__init__.py b/openfe/protocols/openmm_utils/restraints/openmm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/protocols/openmm_utils/restraints/omm_forces.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py similarity index 100% rename from openfe/protocols/openmm_utils/restraints/omm_forces.py rename to openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py diff --git a/openfe/protocols/openmm_utils/restraints/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py similarity index 100% rename from openfe/protocols/openmm_utils/restraints/omm_restraints.py rename to openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py diff --git a/openfe/protocols/openmm_utils/restraints/search.py b/openfe/protocols/openmm_utils/restraints/search.py new file mode 100644 index 000000000..6b3d94eb7 --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/search.py @@ -0,0 +1,360 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Search methods for generating Geometry objects + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +from openff.toolkit import Molecule as OFFMol +from openff.units import unit +import networkx as nx +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds, calc_angles + + +def _get_aromatic_atom_idxs(rdmol) -> list[int]: + """ + Helper method to get aromatic atoms idxs + in a RDKit Molecule + + Parameters + ---------- + rdmol : ??? + RDKit Molecule + + Returns + ------- + list[int] + A list of the aromatic atom idxs + """ + idxs = [ + at.GetIdx() for at in rdmol.GetAtoms() + if at.GetIsAromatic() + ] + return idxs + + +def _get_heavy_atom_idxs(rdmol) -> list[int]: + """ + Get idxs of heavy atoms in an RDKit Molecule + + Parameters + ---------- + rmdol : ??? + + Returns + ------- + list[int] + A list of heavy atom idxs + """ + idxs = [ + at.GetIdx() for at in rdmol.GetAtoms() + if at.GetAtomicNum() > 1 + ] + return idxs + + +def _get_central_atom_idx(rdmol) -> int: + offmol = OFFMol(rdmol, allow_undefined_stereo=True) + # We take the zero-th entry if there are multiple center + # atoms (e.g. equal likelihood centers) + center = nx.center(offmol.to_networkx())[0] + return center + + +def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: + """ + Sort a list of atoms by their distance from a target atom. + + Parameters + ---------- + target_idx : int + The idx of the target atom. + atom_idxs : list[int] + The idx values of the atoms to sort. + rdmol : ??? + RDKit Molecule the atoms belong to + + Returns + ------- + list[int] + The input atom idxs sorted by their distance from the target atom. + """ + distances = [] + + conformer = rdmol.GetConformer() + # Get the target atom position + target_pos = conformer.GetAtomPosition(target_idx) + + for idx in atom_idxs: + pos = conformer.GetAtomPosition(idx) + distances.append(((target_pos - pos).Length(), idx)) + + return [i[1] for i in sorted(distances)] + + +def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): + angles = [] + + # Get the base atom and its neighbors + at1 = rdmol.GetAtomWithIdx(atom_idx) + at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] + + # We loop at2 and at3 through the sorted atom_pool in order to get + # a list of angles in the branch that are sorted by how close the atoms + # are from the central atom + for at2 in atom_pool: + if at2 in at1_neighbors: + at2_neighbors = [ + at.GetIdx() + for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() + ] + for at3 in atom_pool: + if at3 != atom_idx and at3 in at2_neighbors: + angles.append((atom_idx, at2, at3)) + return angles + + +def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: + """ + Get a list of ligand anchor atoms (e.g. l1, l2, and l3 of an orientational restraint). + + Parameters + ---------- + rdmol : ??? + Molecule object for the ligand to apply a restraint to. + + Returns + ------- + angles : list[tuple[int, int, int]] + A list of ligand atom triples denoting the possible l1, l2, and l3 + restraint atoms. Ordered by likelihood of restraint-ability. + """ + # Find the central atom + center = _get_central_atom_idx(rdmol) + + # Get a pool of potential anchor atoms looking for aromatic atoms + anchor_pool = _get_aromatic_atoms(rdmol) + + # If there are not enough aromatic atoms, then default to heavy atoms + if len(anchor_pool) < 3: + anchor_pool = _get_heavy_atoms(rdmol) + + # Raise an error if we have less than 3 anchors + if len(anchor_pool) < 3: + errmsg = f"Too few potential ligand anchor atoms, {len(anchor_pool)}" + raise ValueError(errmsg) + + # Sort the pool of anchor atoms by their distance from the central atom + sorted_anchor_pool = _sort_by_distance_from_target(rdmol, center, anchor_pool) + + # Get a list of ligand anchor angle atoms + angles = [] + for atom in sorted_anchor_pool: + angles.extend( + _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) + ) + + +def get_host_anchors(positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str): + """ + Get a list of host anchor atomss sorted by their distance from a ligand anchor atom. + + Parameters + ---------- + positions : openmm.unit.Quantity + Positions of the input system + topology : openmm.app.Topology + OpenMM Topology for input system + exclude_resids : list[int] + List of residue numbers to exclude from host selection + lig_anchor_idx : int + The index of the l1 ligand anchor. + selection : str + Selection string for the host atoms. + """ + # Create an mdtraj trajectory to manipulate + # First fetch the box vectors and pass them as lengths and angles + vectors = from_openmm(topology.getPeriodicBoxVectors()) + a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles(vectors[0].m, vectors[1].m, vectors[2].m) + + traj = mdt.Trajectory( + positions[np.newaxis, ...], + mdt.Topology.from_openmm(topology) + ) + + # Get all the potential protein atoms matching the selection + host_sel = traj.topology.select(selection) + + # Get residues to exclude from the selection + exclude_sel = np.array([ + at.index for at in + chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) + ]) + + # Remove exclusion + anchors = host_sel[np.isin(host_sel, exclude_sel, invert=True)] + + # Compute distanecs from ligand l1 anchor atom + pairs = np.vstack((anchors, np.array([lig_anchor_idx for _ in range(len(anchors))]))).T + + distances = mdt.compute_distances(traj, pairs, periodic=True) + + return np.array([pairs[i][0] for i in np.argsort(distances[0])]) + + +def is_collinear(positions, atoms, threshold=0.9): + """ + Check whether any sequential vectors in a sequence of atoms are collinear. + + Parameters + ---------- + positions : openmm.unit.Quantity + System positions. + atoms : list[int] + The indices of the atoms to test. + threshold : float + Atoms are not collinear if their sequential vector separation dot + products are less than ``threshold``. Default 0.9. + + Returns + ------- + result : bool + Returns True if any sequential pair of vectors is collinear; False otherwise. + + Notes + ----- + Originally from Yank, with modifications from Separated Topologies + """ + results = False + for i in range(len(atoms) - 2): + v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] + v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] + normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) + result = result or (np.abs(normalized_inner_product) > threshold) + return result + + +def check_angle(angle, force_constant=83.68): + """ + Check whether the chosen angle is less than 10 kT from 0 or 180 + + Parameters + ---------- + angle : float + The angle to check in degrees. + force_constant : float + Force constant of the angle. + + Note + ---- + We assume the temperature to be 298.15 Kelvin. + """ + # TODO: convert this to unit.Quantity so we don't end up with + # conversion errors + RT = 8.31445985 * 0.001 * 298.15 + # check if angle is <10kT from 0 or 180 + check1 = 0.5 * force_constant * np.power((angle - 0.0) / 180.0 * np.pi, 2) + check2 = 0.5 * force_constant * np.power((angle - 180.0) / 180.0 * np.pi, 2) + ang_check_1 = check1 / RT + ang_check_2 = check2 / RT + if ang_check_1 < 10.0 or ang_check_2 < 10.0: + return False + return True + + + + +class FindHostAtoms(AnalysisBase): + """ + Class filter host atoms based on their distance + from a set of guest atoms. + + Parameters + ---------- + host_atoms : MDAnalysis.AtomGroup + Initial selection of host atoms to filter from. + guest_atoms : MDANalysis.AtomGroup + Selection of guest atoms to search around. + search_distance: unit.Quantity + Distance to filter atoms within. + """ + _analysis_algorithm_is_parallelizable = False + + def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): + super().__init__(host_atoms.universe.trajectory, **kwargs) + + self.host_ag = host_atoms + self.guest_ag = guest_atoms + self.cutoff = search_distance.to('angstrom').m + + def _prepare(self): + self.results.host_idxs = set() + + def _single_frame(self): + pairs = capped_distance( + reference=self.host_ag.positions, + configuration=self.guest_ag.positions, + max_cutoff=self.cutoff, + min_cutoff=None + box=self.guest_ag.universe.dimensions, + return_distances=False) + + host_idxs = [self.guest_ag.atoms[p].index for p in pairs[:, 1]] + self.results.host_idxs.update(set(host_idxs)) + + def _conclude(self): + pass + + +def find_host_atoms(topology, trajectory, host_selection, guest_selection, cutoff) -> mda.AtomGroup: + """ + Get an AtomGroup of the host atoms based on their distances from the guest atoms. + """ + u = mda.Universe(topology, trajectory) + + def _get_selection(selection): + """ + If it's a str, call select_atoms, if not a list of atom idxs + """ + if isinstance(selection, str): + ag = u.select_atoms(host_selection) + else: + ag = u.atoms[host_ag] + return ag + + host_ag = _get_selection(host_selection) + guest_ag = _get_selection(guest_selection) + + finder = FindHostAtoms(host_ag, guest_ag, cutoff) + finder.run() + + return u.atoms[list(finder.results.host_idxs)] + +def get_molecule_center_idx(atomgroup): + offmol = Molecule(atomgroup.convert_to("RDKIT"), allow_undefined_stereo=True) + # Check if the molecule is whole, otherwise throw an error. + nx = offmol.to_networkx() + + +def get_distance_restraint(topology, trajectory, host_atoms, guest_atoms, host_selection, guest_selection): + u = mda.Universe(topology, trajectory) + + if guest_atoms is None: + if guest_selection is None: + raise ValueError("one of guest_atoms or guest_selections must be defined") + guest_ag = u.select_atoms(guest_selection) + else: + + + if host_atoms is None: + if host_selection is None: + raise ValueError("one of host_atoms or host_selection must be defined") + + host_ag = u.select_atoms(host_selection) From 20dd1dcce9f6608d0fe5b3d37fcd257709f4a109 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 14:02:05 +0000 Subject: [PATCH 10/33] add some angle checks --- .../openmm_utils/restraints/geometry/utils.py | 132 +++++++++++++++++- 1 file changed, 126 insertions(+), 6 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 6b3d94eb7..80b7c3372 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -12,20 +12,22 @@ from openff.toolkit import Molecule as OFFMol from openff.units import unit +from openff.units.types import FloatQuantity import networkx as nx +from rdkit import Chem import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.lib.distances import calc_bonds, calc_angles -def _get_aromatic_atom_idxs(rdmol) -> list[int]: +def get_aromatic_atom_idxs(rdmol: Chem.Mol) -> list[int]: """ Helper method to get aromatic atoms idxs in a RDKit Molecule Parameters ---------- - rdmol : ??? + rdmol : Chem.Mol RDKit Molecule Returns @@ -40,13 +42,13 @@ def _get_aromatic_atom_idxs(rdmol) -> list[int]: return idxs -def _get_heavy_atom_idxs(rdmol) -> list[int]: +def get_heavy_atom_idxs(rdmol: Chem.Mol) -> list[int]: """ Get idxs of heavy atoms in an RDKit Molecule Parameters ---------- - rmdol : ??? + rmdol : Chem.Mol Returns ------- @@ -60,14 +62,132 @@ def _get_heavy_atom_idxs(rdmol) -> list[int]: return idxs -def _get_central_atom_idx(rdmol) -> int: +def get_central_atom_idx(rdmol: Chem.Mol) -> int: + """ + Get the central atom in an rdkit Molecule. + + Parameters + ---------- + rdmol : Chem.Mol + RDKit Molcule to query + + Returns + ------- + center : int + Index of central atom in Molecule + + Note + ---- + If there are equal likelihood centers, will return + the first entry. + """ + # TODO: switch to a manual conversion to avoid an OpenFF dependency offmol = OFFMol(rdmol, allow_undefined_stereo=True) + nx_mol = offmol.to_networkx() + if not nx.is_weakly_connected(nx_mol): + errmsg = "A disconnected molecule was passed, cannot find the center" + raise ValueError(errmsg) + # We take the zero-th entry if there are multiple center # atoms (e.g. equal likelihood centers) - center = nx.center(offmol.to_networkx())[0] + center = nx.center(nx_mol)[0] return center +def is_collinear(positions, atoms, threshold=0.9): + """ + Check whether any sequential vectors in a sequence of atoms are collinear. + + Parameters + ---------- + positions : openmm.unit.Quantity + System positions. + atoms : list[int] + The indices of the atoms to test. + threshold : float + Atoms are not collinear if their sequential vector separation dot + products are less than ``threshold``. Default 0.9. + + Returns + ------- + result : bool + Returns True if any sequential pair of vectors is collinear; False otherwise. + + Notes + ----- + Originally from Yank, with modifications from Separated Topologies + """ + results = False + for i in range(len(atoms) - 2): + v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] + v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] + normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) + result = result or (np.abs(normalized_inner_product) > threshold) + return result + + +def check_angle_energy( + angle: FloatQuantity['radians'], + force_constant: FloatQuantity['unit.kilojoule_per_mole / unit.radians**2'] = 83.68 * unit.kilojoule_per_mole / unit.radians**2, + temperature: FloatQuantity['kelvin'] = 298.15 * unit.kelvin +) -> bool: + """ + Check whether the chosen angle is less than 10 kT from 0 or 180 + + Parameters + ---------- + angle : unit.Quantity + The angle to check in units compatible with radians. + force_constant : unit.Quantity + Force constant of the angle in units compatible with kilojoule_per_mole / radians ** 2. + temperature: unit.Quantity + The system temperature in units compatible with Kelvin. + + Note + ---- + We assume the temperature to be 298.15 Kelvin. + """ + # Convert things + angle_rads = angle.to('radians') + frc_const = force_constant.to('unit.kilojoule_per_mole / unit.radians**2') + temp_kelvin = temperature.to('kelvin') + RT = 8.31445985 * 0.001 * temp_kelvin + + # check if angle is <10kT from 0 or 180 + check1 = 0.5 * frc_const * np.power((angle - 0.0), 2) + check2 = 0.5 * frc_const * np.power((angle - np.pi), 2) + ang_check_1 = check1 / RT + ang_check_2 = check2 / RT + if ang_check_1 < 10.0 or ang_check_2 < 10.0: + return False + return True + + +def check_dihedral_bounds( + dihedral: FloatQuantity['radians'] + lower_cutoff: FloatQuantity['radians'] = 2.618 * unit.radians, + upper_cutoff: FloatQuantity['radians'] = -2.6.18 * unit.radians, +): + """ + Check that a dihedral does not exceed the bounds set by + lower_cutoff and upper_cutoff. + + Parameters + ---------- + dihedral : unit.Quantity + Dihedral in units compatible with radians. + lower_cutoff : unit.Quantity + Dihedral lower cutoff in units compatible with radians. + upper_cutoff : unit.Quantity + Dihedral upper cutoff in units compatible with radians. + """ + if (dihedral < lower_cutoff) or (dihedral > upper_cutoff): + return False + return True + + + + def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: """ Sort a list of atoms by their distance from a target atom. From 9ab74a8225d7efc511fac79975abd5a2f795b5de Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 14:19:39 +0000 Subject: [PATCH 11/33] only construct with settings --- .../restraints/openmm/omm_restraints.py | 111 ++++++++++-------- 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py index ab5f4e821..e53e828d5 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py @@ -87,40 +87,42 @@ class BaseHostGuestRestraints(abc.ABC): def __init__( self, restraint_settings: SettingsBaseModel, - restraint_geometry: BaseRestraintGeometry, controlling_parameter_name: str = "lambda_restraints", ): self.settings = restraint_settings - self.geometry = restraint_geometry - self._verify_input() + self._verify_settings() @abc.abstractmethod - def _verify_inputs(self): + def _verify_settings(self): pass @abc.abstractmethod - def add_force(self, thermodynamic_state: ThermodynamicState): + def _verify_geometry(self, geometry): pass @abc.abstractmethod - def get_standard_state_correction(self, thermodynamic_state: ThermodynamicState): + def add_force(self, thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry): pass @abc.abstractmethod - def _get_force(self): + def get_standard_state_correction(self, thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry): + pass + + @abc.abstractmethod + def _get_force(self, geometry: BaseRestraintGeometry): pass class SingleBondMixin: - def _verify_input(self): - if len(self.geometry.host_atoms) != 1 or len(self.geometry.guest_atoms) != 1: + def _verify_geometry(self, geometry: BaseRestraintGeometry): + if len(geometry.host_atoms) != 1 or len(geometry.guest_atoms) != 1: errmsg = ( "host_atoms and guest_atoms must only include a single index " f"each, got {len(host_atoms)} and " f"{len(guest_atoms)} respectively." ) raise ValueError(errmsg) - super()._verify_inputs() + super()._verify_geometry(geometry) class BaseRadialllySymmetricRestraintForce(BaseHostGuestRestraints): @@ -128,12 +130,15 @@ def _verify_inputs(self) -> None: if not isinstance(self.settings, BaseDistanceRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" raise ValueError(errmsg) - if not isinstance(self.geometry, DistanceRestraintGeometry): - errmsg = f"Incorrect geometry type {self.geometry} passed through" + + def _verify_geometry(self, geometry: DistanceRestraintGeometry) + if not isinstance(geometry, DistanceRestraintGeometry): + errmsg = f"Incorrect geometry class type {geometry} passed through" raise ValueError(errmsg) - def add_force(self, thermodynamic_state: ThermodynamicState) -> None: - force = self._get_force() + def add_force(self, thermodynamic_state: ThermodynamicState, geometry: DistanceRestraintGeometry) -> None: + self._verify_geometry(geometry) + force = self._get_force(geometry) force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) # Note .system is a call to get_system() so it's returning a copy system = thermodynamic_state.system @@ -141,87 +146,92 @@ def add_force(self, thermodynamic_state: ThermodynamicState) -> None: thermodynamic_state.system = system def get_standard_state_correction( - self, thermodynamic_state: ThermodynamicState + self, + thermodynamic_state: ThermodynamicState, + geometry: DistanceRestraintGeometry, ) -> unit.Quantity: - force = self._get_force() + self._verify_geometry(geometry) + force = self._get_force(geometry) corr = force.compute_standard_state_correction( thermodynamic_state, volume="system" ) dg = corr * thermodynamic_state.kT return from_openmm(dg).to('kilojoule_per_mole') - def _get_force(self): + def _get_force(self, geometry: DistanceRestraintGeometry): raise NotImplementedError("only implemented in child classes") class HarmonicBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): - def _get_force(self) -> openmm.Force: + def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintBondForce( spring_constant=spring_constant, - restrained_atom_index1=self.geometry.host_atoms[0], - restrained_atom_index2=self.geometry.guest_atoms[0], + restrained_atom_index1=geometry.host_atoms[0], + restrained_atom_index2=geometry.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, ) class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): - def _get_force(self) -> openmm.Force: + def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(self.geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, - restrained_atom_index1=self.geometry.host_atoms[0], - restrained_atom_index2=self.geometry.guest_atoms[0], + restrained_atom_index1=geometry.host_atoms[0], + restrained_atom_index2=geometry.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, ) class CentroidHarmonicRestraint(BaseRadialllySymmetricRestraintForce): - def _get_force(self) -> openmm.Force: + def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintForce( spring_constant=spring_constant, - restrained_atom_index1=self.geometry.host_atoms, - restrained_atom_index2=self.geometry.guest_atoms, + restrained_atom_index1=geometry.host_atoms, + restrained_atom_index2=geometry.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, ) class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): - def _get_force(self) -> openmm.Force: + def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(self.geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, - restrained_atom_index1=self.geometry.host_atoms, - restrained_atom_index2=self.geometry.guest_atoms, + restrained_atom_index1=geometry.host_atoms, + restrained_atom_index2=geometry.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, ) class BoreschRestraint(BaseHostGuestRestraints): - _EFUNC_METHOD: Callable = get_boresch_energy_function - def _verify_inputs(self) -> None: + def _verify_settings(self) -> None: if not isinstance(self.settings, BoreschRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" raise ValueError(errmsg) - if not isinstance(self.geometry, BoreschRestraintGeometry): - errmsg = f"Incorrect geometry type {self.geometry} passed through" + + def _verify_geometry(self, geometry: BoreschRestraintGeometry): + if not isinstance(geometry, BoreschRestraintGeometry): + errmsg = f"Incorrect geometry class type {geometry} passed through" raise ValueError(errmsg) - def add_force(self, thermodynamic_state: ThermodynamicState) -> None: - force = self._get_force() + def add_force(self, thermodynamic_state: ThermodynamicState, geometry: BoreschRestraintGeometry) -> None: + _verify_geometry(geometry) + force = self._get_force(geometry) force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) # Note .system is a call to get_system() so it's returning a copy system = thermodynamic_state.system add_force_in_separate_group(system, force) thermodynamic_state.system = system - def _get_force(self) -> openmm.Force: - efunc = _EFUNC_METHOD( + def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: + efunc = get_boresch_energy_function( self.controlling_parameter_name, ) @@ -233,37 +243,38 @@ def _get_force(self) -> openmm.Force: parameter_dict = { 'K_r': self.settings.K_r, - 'r_aA0': self.geometry.r_aA0, + 'r_aA0': geometry.r_aA0, 'K_thetaA': self.settings.K_thetaA, - 'theta_A0': self.geometry.theta_A0, + 'theta_A0': geometry.theta_A0, 'K_thetaB': self.settings.K_thetaB, - 'theta_B0': self.geometry.theta_B0, + 'theta_B0': geometry.theta_B0, 'K_phiA': self.settings.K_phiA, - 'phi_A0': self.geometry.phi_A0, + 'phi_A0': geometry.phi_A0, 'K_phiB': self.settings.K_phiB, - 'phi_B0': self.geometry.phi_B0, + 'phi_B0': geometry.phi_B0, 'K_phiC': self.settings.K_phiC, - 'phi_C0': self.geometry.phi_C0, + 'phi_C0': geometry.phi_C0, } for key, val in parameter_dict.items(): param_values.append(to_openmm(val).value_in_unit_system(omm_unit.md_unit_system)) force.addPerBondParameter(key) force.addGlobalParameter(self.controlling_parameter_name, 1.0) - force.addBond(self.geometry.host_atoms + self.geometry.guest_atoms, param_values) + force.addBond(geometry.host_atoms + geometry.guest_atoms, param_values) return force def get_standard_state_correction( - self, thermodynamic_state: ThermodynamicState + self, thermodynamic_state: ThermodynamicState, geometry: BoreschRestraintGeometry ) -> unit.Quantity: + self._verify_geometry(geometry) StandardV = 1.66053928 * unit.nanometer**3 kt = from_openmm(thermodynamic_state.kT) # distances - r_aA0 = self.geometry.r_aA0.to('nm') - sin_thetaA0 = np.sin(self.geometry.theta_A0.to('radians')) - sin_thetaB0 = np.sin(self.geometry.theta_B0.to('radians')) + r_aA0 = geometry.r_aA0.to('nm') + sin_thetaA0 = np.sin(geometry.theta_A0.to('radians')) + sin_thetaB0 = np.sin(geometry.theta_B0.to('radians')) # restraint energies K_r = self.settings.K_r.to('kilojoule_per_mole / nm ** 2') From 8f2e1e03dd613caf1f91df2d07d61e9751a03c3d Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 14:35:33 +0000 Subject: [PATCH 12/33] Add more checks to utilities --- .../openmm_utils/restraints/geometry/utils.py | 45 +++++++++++++++++-- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 80b7c3372..30e81123f 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -10,9 +10,12 @@ import abc from pydantic.v1 import BaseModel, validator +import numpy as np +from scipy.stats import circvar, circmean, circstd + from openff.toolkit import Molecule as OFFMol from openff.units import unit -from openff.units.types import FloatQuantity +from openff.models.types import FloatQuantity, ArrayQuantity import networkx as nx from rdkit import Chem import MDAnalysis as mda @@ -132,7 +135,7 @@ def check_angle_energy( temperature: FloatQuantity['kelvin'] = 298.15 * unit.kelvin ) -> bool: """ - Check whether the chosen angle is less than 10 kT from 0 or 180 + Check whether the chosen angle is less than 10 kT from 0 or pi radians Parameters ---------- @@ -143,6 +146,12 @@ def check_angle_energy( temperature: unit.Quantity The system temperature in units compatible with Kelvin. + + Returns + ------- + bool + If the angle is less than 10 kT from 0 or pi radians + Note ---- We assume the temperature to be 298.15 Kelvin. @@ -167,7 +176,7 @@ def check_dihedral_bounds( dihedral: FloatQuantity['radians'] lower_cutoff: FloatQuantity['radians'] = 2.618 * unit.radians, upper_cutoff: FloatQuantity['radians'] = -2.6.18 * unit.radians, -): +) -> bool: """ Check that a dihedral does not exceed the bounds set by lower_cutoff and upper_cutoff. @@ -180,12 +189,42 @@ def check_dihedral_bounds( Dihedral lower cutoff in units compatible with radians. upper_cutoff : unit.Quantity Dihedral upper cutoff in units compatible with radians. + + Returns + ------- + bool + ``True`` if the dihedral is within the upper and lower + cutoff bounds. """ if (dihedral < lower_cutoff) or (dihedral > upper_cutoff): return False return True +def check_angular_variance( + angles: ArrayQuantity['radians'] + width: FloatQuantity['radians'] +) -> bool: + """ + Check that the variance of a list of ``angles`` does not exceed + a given ``width`` + + Parameters + ---------- + angles : ArrayLike[unit.Quantity] + An array of angles in units compatible with radians. + width : unit.Quantity + The width to check the variance against, in units compatible with radians. + + Returns + ------- + bool + ``True`` if the variance of the angles is less than the width. + + """ + array = angles.to('radians').m + variance = circvar(array) + return not (variance * unit.radians > width) def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: From 0a480aa0771f7112b9c92647877c7f02c5bc6a50 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 21:14:59 +0000 Subject: [PATCH 13/33] host finding code --- .../restraints/geometry/boresch.py | 213 +++++++++++++++- .../openmm_utils/restraints/geometry/utils.py | 240 ++++++++++-------- 2 files changed, 337 insertions(+), 116 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index 822382b9c..d6241f3d7 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -10,6 +10,8 @@ import abc from pydantic.v1 import BaseModel, validator +from rdkit import Chem + from openff.units import unit import MDAnalysis as mda from MDAnalysis.lib.distances import calc_bonds, calc_angles @@ -31,6 +33,7 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): Where HX represents the X index of ``host_atoms`` and GX the X index of ``guest_atoms``. """ + def get_bond_distance(self, topology, coordinates) -> unit.Quantity: u = mda.Universe(topology, coordinates) at1 = u.atoms[host_atoms[2]] @@ -46,8 +49,12 @@ def get_angles(self, topology, coordinates) -> unit.Quantity: at3 = u.atoms[guest_atoms[0]] at4 = u.atoms[guest_atoms[1]] - angleA = calc_angles(at1.position, at2.position, at3.position, u.atoms.dimensions) - angleB = calc_angles(at2.position, at3.position, at4.position, u.atoms.dimensions) + angleA = calc_angles( + at1.position, at2.position, at3.position, u.atoms.dimensions + ) + angleB = calc_angles( + at2.position, at3.position, at4.position, u.atoms.dimensions + ) return angleA, angleB def get_dihedrals(self, topology, coordinates) -> unit.Quantity: @@ -59,8 +66,204 @@ def get_dihedrals(self, topology, coordinates) -> unit.Quantity: at5 = u.atoms[guest_atoms[1]] at6 = u.atoms[guest_atoms[2]] - dihA = calc_dihedrals(at1.position, at2.position, at3.position, at4.position, u.atoms.dimensions) - dihB = calc_dihedrals(at2.position, at3.position, at4.position, at5.position, u.atoms.dimensions) - dihC = calc_dihedrals(at3.position, at4.position, at5.position, at6.position, u.atoms.dimensions) + dihA = calc_dihedrals( + at1.position, at2.position, at3.position, at4.position, u.atoms.dimensions + ) + dihB = calc_dihedrals( + at2.position, at3.position, at4.position, at5.position, u.atoms.dimensions + ) + dihC = calc_dihedrals( + at3.position, at4.position, at5.position, at6.position, u.atoms.dimensions + ) return dihA, dihB, dihC + + +def _sort_by_distance_from_atom( + rdmol: Chem.Mol, target_idx: int, atom_idxs: Iterable[int] +) -> list[int]: + """ + Sort a list of RDMol atoms by their distance from a target atom. + + Parameters + ---------- + target_idx : int + The idx of the atom to measure from. + atom_idxs : list[int] + The idx values of the atoms to sort. + rdmol : Chem.Mol + RDKit Molecule the atoms belong to + + Returns + ------- + list[int] + The input atom idxs sorted by their distance from the target atom. + """ + distances = [] + + conformer = rdmol.GetConformer() + # Get the target atom position + target_pos = conformer.GetAtomPosition(target_idx) + + for idx in atom_idxs: + pos = conformer.GetAtomPosition(idx) + distances.append(((target_pos - pos).Length(), idx)) + + return [i[1] for i in sorted(distances)] + + +def _get_bonded_angles_from_pool( + rdmol: Chem.Mol, atom_idx: int, atom_pool: list[int] +) -> list[tuple[int, int, int]]: + """ + Get all bonded angles starting from ``atom_idx`` from a pool of atoms. + + Parameters + ---------- + rdmol : Chem.Mol + The RDKit Molecule + atom_idx : int + The index of the atom to search angles from. + atom_pool : list[int] + The list of indices to pick possible angle partners from. + + Returns + ------- + list[tuple[int, int, int]] + A list of tuples containing all the angles. + """ + angles = [] + + # Get the base atom and its neighbors + at1 = rdmol.GetAtomWithIdx(atom_idx) + at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] + + # We loop at2 and at3 through the sorted atom_pool in order to get + # a list of angles in the branch that are sorted by how close the atoms + # are from the central atom + for at2 in atom_pool: + if at2 in at1_neighbors: + at2_neighbors = [ + at.GetIdx() for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() + ] + for at3 in atom_pool: + if at3 != atom_idx and at3 in at2_neighbors: + angles.append((atom_idx, at2, at3)) + return angles + + +def get_small_molecule_atom_candidates( + topology: Union[str, openmm.app.Topology], + trajectory: Union[str, pathlib.Path], + rdmol: Chem.Mol, + ligand_idxs: list[int], + rmsf_cutoff: unit.Quantity = 1 * unit.angstrom, + angle_force_constant=83.68 * unit.kilojoule_per_mole / unit.radians**2, +): + """ + Get a list of potential ligand atom choices for a Boresch restraint + being applied to a given small molecule. + + TODO: remember to update the RDMol with the last frame positions + """ + if isinstance(topology, openmm.app.Topology): + topology_format = "OPENMMTOPOLOGY" + else: + topology_format = None + + u = mda.Universe(topology, trajectory, topology_format=topology_format) + ligand_ag = u.atoms[ligand_idxs] + + # 0. Get the ligand RMSF + rmsf = get_local_rmsf(ligand_ag) + u.trajectory[-1] # forward to the last frame + + # 1. Get the pool of atoms to work with + # TODO: move to a helper function to make it easier to test + # Get a list of all the aromatic rings + # Note: no need to keep track of rings because we'll filter by + # bonded terms after, so if we only keep rings then all the bonded + # atoms should be within the same ring system. + atom_pool = set() + for ring in get_aromatic_rings(rdmol): + max_rmsf = rmsf[list(ring)].max() + if max_rmsf < rmsf_cutoff: + atom_pool.update(ring) + + # if we don't have enough atoms just get all the heavy atoms + if len(atom_pool) < 3: + heavy_atoms = get_heavy_atom_idxs(rdmol) + atom_pool = set(heavy_atoms[rmsf[heavy_atoms] < rmsf_cutoff]) + if len(atom_pool) < 3: + errmsg = ( + "No suitable ligand atoms for " "the boresch restraint could be found" + ) + raise ValueError(errmsg) + + # 2. Get the central atom + center = get_central_atom_idx(rdmol) + + # 3. Sort the atom pool based on their distance from the center + sorted_anchor_pool = _sort_by_distance_from_atom(rdmol, center, anchor_pool) + + # 4. Get a list of probable angles + angles_list = [] + for atom in sorted_anchor_pool: + angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) + for angle in _angles: + angle_ag = ligand_ag.atoms[angle] + collinear = is_collinear(ligand_ag.positions, angle) + angle_value = ( + calc_angle( + angle_ag.atoms[0].position, + angle_ag.atoms[1].position, + angle_ag.atoms[2].position, + box=angle_ag.universe.dimensions, + ) + * unit.radians + ) + energy = check_angle_energy( + angle_value, angle_force_constant, 298.15 * unit.kelvin + ) + if not collinear and energy: + angles_list.append(angle) + + return angles_list + + +def get_host_atom_candidates( + topology: Union[str, openmm.app.Topology], + trajectory: Union[str, pathlib.Path], + host_idxs: list[int], + l1_idx: int, + host_selection: str, + dssp_filter: bool = False, + rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, + min_distance: unit.Quantity = 10 * unit.nanometer, + max_distance: unit.Quantity = 30 * unit.nanometer, + angle_force_constant=83.68 * unit.kilojoule_per_mole / unit.radians**2, +): + if isinstance(topology, openmm.app.Topology): + topology_format = "OPENMMTOPOLOGY" + else: + topology_format = None + + u = mda.Universe(topology, trajectory, topology_format=topology_format) + protein_ag1 = u.atoms[host_idxs] + protein_ag2 = protein_ag.select_atoms(protein_selection) + + # 0. TODO: implement DSSP filter + # Should be able to just call MDA's DSSP method, but will need to catch an exception + if dssp_filter: + raise NotImplementedError("DSSP filtering is not currently implemented") + + # 1. Get the RMSF & filter + rmsf = get_local_rmsf(sub_protein_ag) + protein_ag3 = sub_protein_ag.atoms[rmsf[heavy_atoms] < rmsf_cutoff] + + # 2. Search of atoms within the min/max cutoff + atom_finder = FindHostAtoms( + protein_ag3, u.atoms[l1_idx], min_search_distance, max_search_distance + ) + atom_finder.run() + return atom_finder.results.host_idxs diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 30e81123f..c8226af0d 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -20,8 +20,37 @@ from rdkit import Chem import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.analysis.rmsf import RMSF from MDAnalysis.lib.distances import calc_bonds, calc_angles +from openfe_analysis.transformations import Aligner, NoJump + + +def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: + """ + Get a list of tuples with the indices for each ring in an rdkit Molecule. + + Parameters + ---------- + rdmol : Chem.Mol + RDKit Molecule + + Returns + ------- + list[tuple[int]] + List of tuples for each ring. + """ + ringinfo = rdmol.GetRingInfo() + arom_idxs = get_aromatic_atom_idxs(rdmol) + + aromatic_rings = [] + + for ring in ringinfo.AtomRings(): + if all(a in aroms for a in ring): + aromatic_rings.append(ring) + + return aromatic_rings + def get_aromatic_atom_idxs(rdmol: Chem.Mol) -> list[int]: """ @@ -38,10 +67,7 @@ def get_aromatic_atom_idxs(rdmol: Chem.Mol) -> list[int]: list[int] A list of the aromatic atom idxs """ - idxs = [ - at.GetIdx() for at in rdmol.GetAtoms() - if at.GetIsAromatic() - ] + idxs = [at.GetIdx() for at in rdmol.GetAtoms() if at.GetIsAromatic()] return idxs @@ -58,10 +84,7 @@ def get_heavy_atom_idxs(rdmol: Chem.Mol) -> list[int]: list[int] A list of heavy atom idxs """ - idxs = [ - at.GetIdx() for at in rdmol.GetAtoms() - if at.GetAtomicNum() > 1 - ] + idxs = [at.GetIdx() for at in rdmol.GetAtoms() if at.GetAtomicNum() > 1] return idxs @@ -124,15 +147,19 @@ def is_collinear(positions, atoms, threshold=0.9): for i in range(len(atoms) - 2): v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] - normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) + normalized_inner_product = np.dot(v1, v2) / np.sqrt( + np.dot(v1, v1) * np.dot(v2, v2) + ) result = result or (np.abs(normalized_inner_product) > threshold) return result def check_angle_energy( - angle: FloatQuantity['radians'], - force_constant: FloatQuantity['unit.kilojoule_per_mole / unit.radians**2'] = 83.68 * unit.kilojoule_per_mole / unit.radians**2, - temperature: FloatQuantity['kelvin'] = 298.15 * unit.kelvin + angle: FloatQuantity["radians"], + force_constant: FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] = 83.68 + * unit.kilojoule_per_mole + / unit.radians**2, + temperature: FloatQuantity["kelvin"] = 298.15 * unit.kelvin, ) -> bool: """ Check whether the chosen angle is less than 10 kT from 0 or pi radians @@ -157,9 +184,9 @@ def check_angle_energy( We assume the temperature to be 298.15 Kelvin. """ # Convert things - angle_rads = angle.to('radians') - frc_const = force_constant.to('unit.kilojoule_per_mole / unit.radians**2') - temp_kelvin = temperature.to('kelvin') + angle_rads = angle.to("radians") + frc_const = force_constant.to("unit.kilojoule_per_mole / unit.radians**2") + temp_kelvin = temperature.to("kelvin") RT = 8.31445985 * 0.001 * temp_kelvin # check if angle is <10kT from 0 or 180 @@ -167,15 +194,15 @@ def check_angle_energy( check2 = 0.5 * frc_const * np.power((angle - np.pi), 2) ang_check_1 = check1 / RT ang_check_2 = check2 / RT - if ang_check_1 < 10.0 or ang_check_2 < 10.0: + if ang_check_1 < 10.0 or ang_check_2 < 10.0: return False return True def check_dihedral_bounds( - dihedral: FloatQuantity['radians'] - lower_cutoff: FloatQuantity['radians'] = 2.618 * unit.radians, - upper_cutoff: FloatQuantity['radians'] = -2.6.18 * unit.radians, + dihedral: FloatQuantity["radians"], + lower_cutoff: FloatQuantity["radians"] = 2.618 * unit.radians, + upper_cutoff: FloatQuantity["radians"] = -2.618 * unit.radians, ) -> bool: """ Check that a dihedral does not exceed the bounds set by @@ -202,8 +229,7 @@ def check_dihedral_bounds( def check_angular_variance( - angles: ArrayQuantity['radians'] - width: FloatQuantity['radians'] + angles: ArrayQuantity["radians"], width: FloatQuantity["radians"] ) -> bool: """ Check that the variance of a list of ``angles`` does not exceed @@ -222,45 +248,14 @@ def check_angular_variance( ``True`` if the variance of the angles is less than the width. """ - array = angles.to('radians').m + array = angles.to("radians").m variance = circvar(array) return not (variance * unit.radians > width) -def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: - """ - Sort a list of atoms by their distance from a target atom. - - Parameters - ---------- - target_idx : int - The idx of the target atom. - atom_idxs : list[int] - The idx values of the atoms to sort. - rdmol : ??? - RDKit Molecule the atoms belong to - - Returns - ------- - list[int] - The input atom idxs sorted by their distance from the target atom. - """ - distances = [] - - conformer = rdmol.GetConformer() - # Get the target atom position - target_pos = conformer.GetAtomPosition(target_idx) - - for idx in atom_idxs: - pos = conformer.GetAtomPosition(idx) - distances.append(((target_pos - pos).Length(), idx)) - - return [i[1] for i in sorted(distances)] - - def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): angles = [] - + # Get the base atom and its neighbors at1 = rdmol.GetAtomWithIdx(atom_idx) at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] @@ -271,8 +266,7 @@ def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): for at2 in atom_pool: if at2 in at1_neighbors: at2_neighbors = [ - at.GetIdx() - for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() + at.GetIdx() for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() ] for at3 in atom_pool: if at3 != atom_idx and at3 in at2_neighbors: @@ -283,12 +277,12 @@ def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: """ Get a list of ligand anchor atoms (e.g. l1, l2, and l3 of an orientational restraint). - + Parameters ---------- rdmol : ??? Molecule object for the ligand to apply a restraint to. - + Returns ------- angles : list[tuple[int, int, int]] @@ -304,7 +298,7 @@ def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: # If there are not enough aromatic atoms, then default to heavy atoms if len(anchor_pool) < 3: anchor_pool = _get_heavy_atoms(rdmol) - + # Raise an error if we have less than 3 anchors if len(anchor_pool) < 3: errmsg = f"Too few potential ligand anchor atoms, {len(anchor_pool)}" @@ -316,15 +310,15 @@ def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: # Get a list of ligand anchor angle atoms angles = [] for atom in sorted_anchor_pool: - angles.extend( - _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) - ) + angles.extend(_get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool)) -def get_host_anchors(positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str): +def get_host_anchors( + positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str +): """ Get a list of host anchor atomss sorted by their distance from a ligand anchor atom. - + Parameters ---------- positions : openmm.unit.Quantity @@ -341,30 +335,35 @@ def get_host_anchors(positions, topology, exclude_resids: list[int], lig_anchor_ # Create an mdtraj trajectory to manipulate # First fetch the box vectors and pass them as lengths and angles vectors = from_openmm(topology.getPeriodicBoxVectors()) - a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles(vectors[0].m, vectors[1].m, vectors[2].m) - + a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles( + vectors[0].m, vectors[1].m, vectors[2].m + ) + traj = mdt.Trajectory( - positions[np.newaxis, ...], - mdt.Topology.from_openmm(topology) + positions[np.newaxis, ...], mdt.Topology.from_openmm(topology) ) - + # Get all the potential protein atoms matching the selection host_sel = traj.topology.select(selection) - + # Get residues to exclude from the selection - exclude_sel = np.array([ - at.index for at in - chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) - ]) - + exclude_sel = np.array( + [ + at.index + for at in chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) + ] + ) + # Remove exclusion anchors = host_sel[np.isin(host_sel, exclude_sel, invert=True)] - + # Compute distanecs from ligand l1 anchor atom - pairs = np.vstack((anchors, np.array([lig_anchor_idx for _ in range(len(anchors))]))).T - + pairs = np.vstack( + (anchors, np.array([lig_anchor_idx for _ in range(len(anchors))])) + ).T + distances = mdt.compute_distances(traj, pairs, periodic=True) - + return np.array([pairs[i][0] for i in np.argsort(distances[0])]) @@ -395,7 +394,9 @@ def is_collinear(positions, atoms, threshold=0.9): for i in range(len(atoms) - 2): v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] - normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) + normalized_inner_product = np.dot(v1, v2) / np.sqrt( + np.dot(v1, v1) * np.dot(v2, v2) + ) result = result or (np.abs(normalized_inner_product) > threshold) return result @@ -403,7 +404,7 @@ def is_collinear(positions, atoms, threshold=0.9): def check_angle(angle, force_constant=83.68): """ Check whether the chosen angle is less than 10 kT from 0 or 180 - + Parameters ---------- angle : float @@ -417,19 +418,17 @@ def check_angle(angle, force_constant=83.68): """ # TODO: convert this to unit.Quantity so we don't end up with # conversion errors - RT = 8.31445985 * 0.001 * 298.15 + RT = 8.31445985 * 0.001 * 298.15 # check if angle is <10kT from 0 or 180 check1 = 0.5 * force_constant * np.power((angle - 0.0) / 180.0 * np.pi, 2) check2 = 0.5 * force_constant * np.power((angle - 180.0) / 180.0 * np.pi, 2) ang_check_1 = check1 / RT ang_check_2 = check2 / RT - if ang_check_1 < 10.0 or ang_check_2 < 10.0: + if ang_check_1 < 10.0 or ang_check_2 < 10.0: return False return True - - class FindHostAtoms(AnalysisBase): """ Class filter host atoms based on their distance @@ -441,17 +440,28 @@ class FindHostAtoms(AnalysisBase): Initial selection of host atoms to filter from. guest_atoms : MDANalysis.AtomGroup Selection of guest atoms to search around. - search_distance: unit.Quantity - Distance to filter atoms within. + min_search_distance: unit.Quantity + Minimum distance to filter atoms within. + max_search_distance: unit.Quantity + Maximum distance to filter atoms within. """ + _analysis_algorithm_is_parallelizable = False - def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): + def __init__( + self, + host_atoms, + guest_atoms, + min_search_distance, + max_search_distance, + **kwargs, + ): super().__init__(host_atoms.universe.trajectory, **kwargs) self.host_ag = host_atoms self.guest_ag = guest_atoms - self.cutoff = search_distance.to('angstrom').m + self.min_cutoff = min_search_distance.to("angstrom").m + self.max_cutoff = max_search_distance.to("angstrom").m def _prepare(self): self.results.host_idxs = set() @@ -460,19 +470,22 @@ def _single_frame(self): pairs = capped_distance( reference=self.host_ag.positions, configuration=self.guest_ag.positions, - max_cutoff=self.cutoff, - min_cutoff=None + max_cutoff=self.max_cutoff, + min_cutoff=self.min_cutoff, box=self.guest_ag.universe.dimensions, - return_distances=False) + return_distances=False, + ) - host_idxs = [self.guest_ag.atoms[p].index for p in pairs[:, 1]] + host_idxs = [self.guest_ag.atoms[p].ix for p in pairs[:, 1]] self.results.host_idxs.update(set(host_idxs)) def _conclude(self): - pass + self.results.host_idxs = np.array(self.results.host_idxs) -def find_host_atoms(topology, trajectory, host_selection, guest_selection, cutoff) -> mda.AtomGroup: +def find_host_atoms( + topology, trajectory, host_selection, guest_selection, cutoff +) -> mda.AtomGroup: """ Get an AtomGroup of the host atoms based on their distances from the guest atoms. """ @@ -487,7 +500,7 @@ def _get_selection(selection): else: ag = u.atoms[host_ag] return ag - + host_ag = _get_selection(host_selection) guest_ag = _get_selection(guest_selection) @@ -496,24 +509,29 @@ def _get_selection(selection): return u.atoms[list(finder.results.host_idxs)] -def get_molecule_center_idx(atomgroup): - offmol = Molecule(atomgroup.convert_to("RDKIT"), allow_undefined_stereo=True) - # Check if the molecule is whole, otherwise throw an error. - nx = offmol.to_networkx() +def get_local_rmsf(atomgroup: mda.AtomGroup): + """ + Get the RMSF of an AtomGroup when aligned upon itself. -def get_distance_restraint(topology, trajectory, host_atoms, guest_atoms, host_selection, guest_selection): - u = mda.Universe(topology, trajectory) + Parameters + ---------- + atomgroup : MDAnalysis.AtomGroup - if guest_atoms is None: - if guest_selection is None: - raise ValueError("one of guest_atoms or guest_selections must be defined") - guest_ag = u.select_atoms(guest_selection) - else: + Return + ------ + rmsf + ArrayQuantity of RMSF values. + """ + # First let's copy our Universe + copy_u = atomgroup.universe.copy() + ag = copy_u.atoms[atomgroup.atoms.ix] + nojump = NoJump(ag) + align = Aligner(ag) - if host_atoms is None: - if host_selection is None: - raise ValueError("one of host_atoms or host_selection must be defined") + copy_u.trajectory.add_transformations(nojump, align) - host_ag = u.select_atoms(host_selection) + rmsf = RMSF(ag) + rmsf.run() + return rmsf.results.rmsf * unit.angstrom From 7a7be903a62ca03ecfa76326e621636e78cb9537 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 21:21:30 +0000 Subject: [PATCH 14/33] fix up weird black wrapping --- .../openmm_utils/restraints/geometry/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index c8226af0d..82e7e621f 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -26,6 +26,10 @@ from openfe_analysis.transformations import Aligner, NoJump +DEFAULT_ANGLE_FRC_CONSTANT = 83.68 * unit.kilojoule_per_mole / unit.radians**2 +ANGLE_FRC_CONSTANT_TYPE = FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] + + def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: """ Get a list of tuples with the indices for each ring in an rdkit Molecule. @@ -156,9 +160,7 @@ def is_collinear(positions, atoms, threshold=0.9): def check_angle_energy( angle: FloatQuantity["radians"], - force_constant: FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] = 83.68 - * unit.kilojoule_per_mole - / unit.radians**2, + force_constant: ANGLE_FRC_CONSTANT_TYPE = DEFAULT_ANGLE_FRC_CONSTANT, temperature: FloatQuantity["kelvin"] = 298.15 * unit.kelvin, ) -> bool: """ @@ -170,10 +172,9 @@ def check_angle_energy( The angle to check in units compatible with radians. force_constant : unit.Quantity Force constant of the angle in units compatible with kilojoule_per_mole / radians ** 2. - temperature: unit.Quantity + temperature : unit.Quantity The system temperature in units compatible with Kelvin. - Returns ------- bool From 733f3b3c681a1112d081e043f7866cb198fbf3aa Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 21:47:56 +0000 Subject: [PATCH 15/33] remove old search file, add more changes to boresch search --- .../restraints/geometry/boresch.py | 111 ++++-- .../openmm_utils/restraints/search.py | 360 ------------------ 2 files changed, 83 insertions(+), 388 deletions(-) delete mode 100644 openfe/protocols/openmm_utils/restraints/search.py diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index d6241f3d7..a35c4f5a9 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -15,6 +15,8 @@ from openff.units import unit import MDAnalysis as mda from MDAnalysis.lib.distances import calc_bonds, calc_angles +import numpy as np +import numpy.typing as npt from .base import HostGuestRestraintGeometry @@ -152,19 +154,84 @@ def _get_bonded_angles_from_pool( return angles -def get_small_molecule_atom_candidates( +def _get_atom_pool(rdmol: Chem.Mol, rmsf: npt.NDArray) -> Optional[set[int]]: + """ + Filter atoms based on rmsf & rings, defaulting to heavy atoms if + there are not enough. + + Parameters + ---------- + rdmol : Chem.Mol + The RDKit Molecule to search through + rmsf : npt.NDArray + A 1-D array of RMSF values for each atom. + + Returns + ------- + atom_pool : Optional[set[int]] + """ + # Get a list of all the aromatic rings + # Note: no need to keep track of rings because we'll filter by + # bonded terms after, so if we only keep rings then all the bonded + # atoms should be within the same ring system. + atom_pool = set() + for ring in get_aromatic_rings(rdmol): + max_rmsf = rmsf[list(ring)].max() + if max_rmsf < rmsf_cutoff: + atom_pool.update(ring) + + # if we don't have enough atoms just get all the heavy atoms + if len(atom_pool) < 3: + heavy_atoms = get_heavy_atom_idxs(rdmol) + atom_pool = set(heavy_atoms[rmsf[heavy_atoms] < rmsf_cutoff]) + if len(atom_pool) < 3: + return None + + return atom_pool + + +def get_small_molecule_guest_atom_candidates( topology: Union[str, openmm.app.Topology], trajectory: Union[str, pathlib.Path], rdmol: Chem.Mol, ligand_idxs: list[int], rmsf_cutoff: unit.Quantity = 1 * unit.angstrom, - angle_force_constant=83.68 * unit.kilojoule_per_mole / unit.radians**2, -): + angle_force_constant: unit.Quantity = 83.68 * unit.kilojoule_per_mole / unit.radians**2, +) -> list[tuple[int]]: """ Get a list of potential ligand atom choices for a Boresch restraint being applied to a given small molecule. - TODO: remember to update the RDMol with the last frame positions + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + The topology of the system. + trajectory : Union[str, pathlib.Path] + A path to the system's coordinate trajectory. + rdmol : Chem.Mol + An RDKit Molecule representing the small molecule ordered in + the same way as it is listed in the topology. + ligand_idxs : list[int] + The ligand indices in the topology. + rmsf_cutoff : unit.Quantity + The RMSF filter cut-off. + angle_force_constant : unit.Quantity + The force constant for the l1-l2-l3 atom angle. + + Returns + ------- + angle_list : list[tuple[int]] + A list of tuples for each valid l1, l2, l3 angle. If ``None``, no + angles could be found. + + Raises + ------ + ValueError + If no suitable ligand atoms could be found. + + TODO + ---- + Remember to update the RDMol with the last frame positions. """ if isinstance(topology, openmm.app.Topology): topology_format = "OPENMMTOPOLOGY" @@ -179,26 +246,12 @@ def get_small_molecule_atom_candidates( u.trajectory[-1] # forward to the last frame # 1. Get the pool of atoms to work with - # TODO: move to a helper function to make it easier to test - # Get a list of all the aromatic rings - # Note: no need to keep track of rings because we'll filter by - # bonded terms after, so if we only keep rings then all the bonded - # atoms should be within the same ring system. - atom_pool = set() - for ring in get_aromatic_rings(rdmol): - max_rmsf = rmsf[list(ring)].max() - if max_rmsf < rmsf_cutoff: - atom_pool.update(ring) + atom_pool = _get_atom_pool(rdmol: Chem.Mol, rmsf: npt.NDArray) - # if we don't have enough atoms just get all the heavy atoms - if len(atom_pool) < 3: - heavy_atoms = get_heavy_atom_idxs(rdmol) - atom_pool = set(heavy_atoms[rmsf[heavy_atoms] < rmsf_cutoff]) - if len(atom_pool) < 3: - errmsg = ( - "No suitable ligand atoms for " "the boresch restraint could be found" - ) - raise ValueError(errmsg) + if atom_pool is None: + # We don't have enough atoms so we raise an error + errmsg = "No suitable ligand atoms were found for the restraint" + raise ValueError(errmsg) # 2. Get the central atom center = get_central_atom_idx(rdmol) @@ -211,7 +264,7 @@ def get_small_molecule_atom_candidates( for atom in sorted_anchor_pool: angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) for angle in _angles: - angle_ag = ligand_ag.atoms[angle] + angle_ag = ligand_ag.atoms[list(angle)] collinear = is_collinear(ligand_ag.positions, angle) angle_value = ( calc_angle( @@ -219,8 +272,7 @@ def get_small_molecule_atom_candidates( angle_ag.atoms[1].position, angle_ag.atoms[2].position, box=angle_ag.universe.dimensions, - ) - * unit.radians + ) * unit.radians ) energy = check_angle_energy( angle_value, angle_force_constant, 298.15 * unit.kelvin @@ -239,10 +291,13 @@ def get_host_atom_candidates( host_selection: str, dssp_filter: bool = False, rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, - min_distance: unit.Quantity = 10 * unit.nanometer, - max_distance: unit.Quantity = 30 * unit.nanometer, + min_distance: unit.Quantity = 1 * unit.nanometer, + max_distance: unit.Quantity = 3 * unit.nanometer, angle_force_constant=83.68 * unit.kilojoule_per_mole / unit.radians**2, ): + """ + + """ if isinstance(topology, openmm.app.Topology): topology_format = "OPENMMTOPOLOGY" else: diff --git a/openfe/protocols/openmm_utils/restraints/search.py b/openfe/protocols/openmm_utils/restraints/search.py deleted file mode 100644 index 6b3d94eb7..000000000 --- a/openfe/protocols/openmm_utils/restraints/search.py +++ /dev/null @@ -1,360 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -""" -Search methods for generating Geometry objects - -TODO ----- -* Add relevant duecredit entries. -""" -import abc -from pydantic.v1 import BaseModel, validator - -from openff.toolkit import Molecule as OFFMol -from openff.units import unit -import networkx as nx -import MDAnalysis as mda -from MDAnalysis.analysis.base import AnalysisBase -from MDAnalysis.lib.distances import calc_bonds, calc_angles - - -def _get_aromatic_atom_idxs(rdmol) -> list[int]: - """ - Helper method to get aromatic atoms idxs - in a RDKit Molecule - - Parameters - ---------- - rdmol : ??? - RDKit Molecule - - Returns - ------- - list[int] - A list of the aromatic atom idxs - """ - idxs = [ - at.GetIdx() for at in rdmol.GetAtoms() - if at.GetIsAromatic() - ] - return idxs - - -def _get_heavy_atom_idxs(rdmol) -> list[int]: - """ - Get idxs of heavy atoms in an RDKit Molecule - - Parameters - ---------- - rmdol : ??? - - Returns - ------- - list[int] - A list of heavy atom idxs - """ - idxs = [ - at.GetIdx() for at in rdmol.GetAtoms() - if at.GetAtomicNum() > 1 - ] - return idxs - - -def _get_central_atom_idx(rdmol) -> int: - offmol = OFFMol(rdmol, allow_undefined_stereo=True) - # We take the zero-th entry if there are multiple center - # atoms (e.g. equal likelihood centers) - center = nx.center(offmol.to_networkx())[0] - return center - - -def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: - """ - Sort a list of atoms by their distance from a target atom. - - Parameters - ---------- - target_idx : int - The idx of the target atom. - atom_idxs : list[int] - The idx values of the atoms to sort. - rdmol : ??? - RDKit Molecule the atoms belong to - - Returns - ------- - list[int] - The input atom idxs sorted by their distance from the target atom. - """ - distances = [] - - conformer = rdmol.GetConformer() - # Get the target atom position - target_pos = conformer.GetAtomPosition(target_idx) - - for idx in atom_idxs: - pos = conformer.GetAtomPosition(idx) - distances.append(((target_pos - pos).Length(), idx)) - - return [i[1] for i in sorted(distances)] - - -def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): - angles = [] - - # Get the base atom and its neighbors - at1 = rdmol.GetAtomWithIdx(atom_idx) - at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] - - # We loop at2 and at3 through the sorted atom_pool in order to get - # a list of angles in the branch that are sorted by how close the atoms - # are from the central atom - for at2 in atom_pool: - if at2 in at1_neighbors: - at2_neighbors = [ - at.GetIdx() - for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() - ] - for at3 in atom_pool: - if at3 != atom_idx and at3 in at2_neighbors: - angles.append((atom_idx, at2, at3)) - return angles - - -def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: - """ - Get a list of ligand anchor atoms (e.g. l1, l2, and l3 of an orientational restraint). - - Parameters - ---------- - rdmol : ??? - Molecule object for the ligand to apply a restraint to. - - Returns - ------- - angles : list[tuple[int, int, int]] - A list of ligand atom triples denoting the possible l1, l2, and l3 - restraint atoms. Ordered by likelihood of restraint-ability. - """ - # Find the central atom - center = _get_central_atom_idx(rdmol) - - # Get a pool of potential anchor atoms looking for aromatic atoms - anchor_pool = _get_aromatic_atoms(rdmol) - - # If there are not enough aromatic atoms, then default to heavy atoms - if len(anchor_pool) < 3: - anchor_pool = _get_heavy_atoms(rdmol) - - # Raise an error if we have less than 3 anchors - if len(anchor_pool) < 3: - errmsg = f"Too few potential ligand anchor atoms, {len(anchor_pool)}" - raise ValueError(errmsg) - - # Sort the pool of anchor atoms by their distance from the central atom - sorted_anchor_pool = _sort_by_distance_from_target(rdmol, center, anchor_pool) - - # Get a list of ligand anchor angle atoms - angles = [] - for atom in sorted_anchor_pool: - angles.extend( - _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) - ) - - -def get_host_anchors(positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str): - """ - Get a list of host anchor atomss sorted by their distance from a ligand anchor atom. - - Parameters - ---------- - positions : openmm.unit.Quantity - Positions of the input system - topology : openmm.app.Topology - OpenMM Topology for input system - exclude_resids : list[int] - List of residue numbers to exclude from host selection - lig_anchor_idx : int - The index of the l1 ligand anchor. - selection : str - Selection string for the host atoms. - """ - # Create an mdtraj trajectory to manipulate - # First fetch the box vectors and pass them as lengths and angles - vectors = from_openmm(topology.getPeriodicBoxVectors()) - a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles(vectors[0].m, vectors[1].m, vectors[2].m) - - traj = mdt.Trajectory( - positions[np.newaxis, ...], - mdt.Topology.from_openmm(topology) - ) - - # Get all the potential protein atoms matching the selection - host_sel = traj.topology.select(selection) - - # Get residues to exclude from the selection - exclude_sel = np.array([ - at.index for at in - chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) - ]) - - # Remove exclusion - anchors = host_sel[np.isin(host_sel, exclude_sel, invert=True)] - - # Compute distanecs from ligand l1 anchor atom - pairs = np.vstack((anchors, np.array([lig_anchor_idx for _ in range(len(anchors))]))).T - - distances = mdt.compute_distances(traj, pairs, periodic=True) - - return np.array([pairs[i][0] for i in np.argsort(distances[0])]) - - -def is_collinear(positions, atoms, threshold=0.9): - """ - Check whether any sequential vectors in a sequence of atoms are collinear. - - Parameters - ---------- - positions : openmm.unit.Quantity - System positions. - atoms : list[int] - The indices of the atoms to test. - threshold : float - Atoms are not collinear if their sequential vector separation dot - products are less than ``threshold``. Default 0.9. - - Returns - ------- - result : bool - Returns True if any sequential pair of vectors is collinear; False otherwise. - - Notes - ----- - Originally from Yank, with modifications from Separated Topologies - """ - results = False - for i in range(len(atoms) - 2): - v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] - v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] - normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) - result = result or (np.abs(normalized_inner_product) > threshold) - return result - - -def check_angle(angle, force_constant=83.68): - """ - Check whether the chosen angle is less than 10 kT from 0 or 180 - - Parameters - ---------- - angle : float - The angle to check in degrees. - force_constant : float - Force constant of the angle. - - Note - ---- - We assume the temperature to be 298.15 Kelvin. - """ - # TODO: convert this to unit.Quantity so we don't end up with - # conversion errors - RT = 8.31445985 * 0.001 * 298.15 - # check if angle is <10kT from 0 or 180 - check1 = 0.5 * force_constant * np.power((angle - 0.0) / 180.0 * np.pi, 2) - check2 = 0.5 * force_constant * np.power((angle - 180.0) / 180.0 * np.pi, 2) - ang_check_1 = check1 / RT - ang_check_2 = check2 / RT - if ang_check_1 < 10.0 or ang_check_2 < 10.0: - return False - return True - - - - -class FindHostAtoms(AnalysisBase): - """ - Class filter host atoms based on their distance - from a set of guest atoms. - - Parameters - ---------- - host_atoms : MDAnalysis.AtomGroup - Initial selection of host atoms to filter from. - guest_atoms : MDANalysis.AtomGroup - Selection of guest atoms to search around. - search_distance: unit.Quantity - Distance to filter atoms within. - """ - _analysis_algorithm_is_parallelizable = False - - def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): - super().__init__(host_atoms.universe.trajectory, **kwargs) - - self.host_ag = host_atoms - self.guest_ag = guest_atoms - self.cutoff = search_distance.to('angstrom').m - - def _prepare(self): - self.results.host_idxs = set() - - def _single_frame(self): - pairs = capped_distance( - reference=self.host_ag.positions, - configuration=self.guest_ag.positions, - max_cutoff=self.cutoff, - min_cutoff=None - box=self.guest_ag.universe.dimensions, - return_distances=False) - - host_idxs = [self.guest_ag.atoms[p].index for p in pairs[:, 1]] - self.results.host_idxs.update(set(host_idxs)) - - def _conclude(self): - pass - - -def find_host_atoms(topology, trajectory, host_selection, guest_selection, cutoff) -> mda.AtomGroup: - """ - Get an AtomGroup of the host atoms based on their distances from the guest atoms. - """ - u = mda.Universe(topology, trajectory) - - def _get_selection(selection): - """ - If it's a str, call select_atoms, if not a list of atom idxs - """ - if isinstance(selection, str): - ag = u.select_atoms(host_selection) - else: - ag = u.atoms[host_ag] - return ag - - host_ag = _get_selection(host_selection) - guest_ag = _get_selection(guest_selection) - - finder = FindHostAtoms(host_ag, guest_ag, cutoff) - finder.run() - - return u.atoms[list(finder.results.host_idxs)] - -def get_molecule_center_idx(atomgroup): - offmol = Molecule(atomgroup.convert_to("RDKIT"), allow_undefined_stereo=True) - # Check if the molecule is whole, otherwise throw an error. - nx = offmol.to_networkx() - - -def get_distance_restraint(topology, trajectory, host_atoms, guest_atoms, host_selection, guest_selection): - u = mda.Universe(topology, trajectory) - - if guest_atoms is None: - if guest_selection is None: - raise ValueError("one of guest_atoms or guest_selections must be defined") - guest_ag = u.select_atoms(guest_selection) - else: - - - if host_atoms is None: - if host_selection is None: - raise ValueError("one of host_atoms or host_selection must be defined") - - host_ag = u.select_atoms(host_selection) From 96decfffe7036ab5cff19ad4453517957a3291dd Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 22:19:34 +0000 Subject: [PATCH 16/33] Remove duplicate methods --- .../restraints/geometry/boresch.py | 34 ++- .../openmm_utils/restraints/geometry/utils.py | 218 ++++-------------- 2 files changed, 75 insertions(+), 177 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index a35c4f5a9..cf14b73aa 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -37,6 +37,13 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): """ def get_bond_distance(self, topology, coordinates) -> unit.Quantity: + """ + Get the H2 - G0 distance + + Parameters + ---------- + topology : + """ u = mda.Universe(topology, coordinates) at1 = u.atoms[host_atoms[2]] at2 = u.atoms[guest_atoms[0]] @@ -293,10 +300,30 @@ def get_host_atom_candidates( rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, min_distance: unit.Quantity = 1 * unit.nanometer, max_distance: unit.Quantity = 3 * unit.nanometer, - angle_force_constant=83.68 * unit.kilojoule_per_mole / unit.radians**2, ): """ + Get a list of suitable host atoms. + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + The topology of the system. + trajectory : Union[str, pathlib.Path] + A path to the system's coordinate trajectory. + host_idxs : list[int] + A list of the host indices in the system topology. + l1_idx : int + The index of the proposed l1 binding atom. + host_selection : str + An MDAnalysis selection string to fileter the host by. + dssp_filter : bool + Whether or not to apply a DSSP filter on the host selection. + rmsf_cutoff : uni.Quantity + The maximum RMSF value allowwed for any candidate host atom. + min_distance : unit.Quantity + The minimum search distance around l1 for suitable candidate atoms. + max_distance : unit.Quantity + The maximum search distance around l1 for suitable candidate atoms. """ if isinstance(topology, openmm.app.Topology): topology_format = "OPENMMTOPOLOGY" @@ -322,3 +349,8 @@ def get_host_atom_candidates( ) atom_finder.run() return atom_finder.results.host_idxs + + +def select_boresch_atoms( + +): diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 82e7e621f..a74e83ed3 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -11,6 +11,7 @@ from pydantic.v1 import BaseModel, validator import numpy as np +import numpy.typing as npt from scipy.stats import circvar, circmean, circstd from openff.toolkit import Molecule as OFFMol @@ -22,6 +23,7 @@ from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.analysis.rmsf import RMSF from MDAnalysis.lib.distances import calc_bonds, calc_angles +from MDAnalysis.coordinates.memory import MemoryReader from openfe_analysis.transformations import Aligner, NoJump @@ -30,6 +32,46 @@ ANGLE_FRC_CONSTANT_TYPE = FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] +def _get_mda_coord_format(coordinates: Union[str, npt.NDArray]) -> Optional[MemoryReader]: + """ + Helper to set the coordinate format to MemoryReader + if the coordinates are an NDArray. + + Parameters + ---------- + coordinates : Union[str, npt.NDArray] + + Returns + ------- + Optional[MemoryReader] + Either the MemoryReader class or None. + """ + if isinstance(coordinates, npt.NDArray): + return MemoryReader + else: + return None + +def _get_mda_topology_format(topology: Union[str, openmm.app.Topology]) -> Optional[str]: + """ + Helper to set the topology format to OPENMMTOPOLOGY + if the topology is an openmm.app.Topology. + + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + + + Returns + ------- + Optional[str] + The string `OPENMMTOPOLOGY` or None. + """ + if isinstance(topology, openmm.app.Topology): + return "OPENMMTOPOLOGY" + else: + return None + + def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: """ Get a list of tuples with the indices for each ring in an rdkit Molecule. @@ -254,182 +296,6 @@ def check_angular_variance( return not (variance * unit.radians > width) -def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): - angles = [] - - # Get the base atom and its neighbors - at1 = rdmol.GetAtomWithIdx(atom_idx) - at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] - - # We loop at2 and at3 through the sorted atom_pool in order to get - # a list of angles in the branch that are sorted by how close the atoms - # are from the central atom - for at2 in atom_pool: - if at2 in at1_neighbors: - at2_neighbors = [ - at.GetIdx() for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() - ] - for at3 in atom_pool: - if at3 != atom_idx and at3 in at2_neighbors: - angles.append((atom_idx, at2, at3)) - return angles - - -def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: - """ - Get a list of ligand anchor atoms (e.g. l1, l2, and l3 of an orientational restraint). - - Parameters - ---------- - rdmol : ??? - Molecule object for the ligand to apply a restraint to. - - Returns - ------- - angles : list[tuple[int, int, int]] - A list of ligand atom triples denoting the possible l1, l2, and l3 - restraint atoms. Ordered by likelihood of restraint-ability. - """ - # Find the central atom - center = _get_central_atom_idx(rdmol) - - # Get a pool of potential anchor atoms looking for aromatic atoms - anchor_pool = _get_aromatic_atoms(rdmol) - - # If there are not enough aromatic atoms, then default to heavy atoms - if len(anchor_pool) < 3: - anchor_pool = _get_heavy_atoms(rdmol) - - # Raise an error if we have less than 3 anchors - if len(anchor_pool) < 3: - errmsg = f"Too few potential ligand anchor atoms, {len(anchor_pool)}" - raise ValueError(errmsg) - - # Sort the pool of anchor atoms by their distance from the central atom - sorted_anchor_pool = _sort_by_distance_from_target(rdmol, center, anchor_pool) - - # Get a list of ligand anchor angle atoms - angles = [] - for atom in sorted_anchor_pool: - angles.extend(_get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool)) - - -def get_host_anchors( - positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str -): - """ - Get a list of host anchor atomss sorted by their distance from a ligand anchor atom. - - Parameters - ---------- - positions : openmm.unit.Quantity - Positions of the input system - topology : openmm.app.Topology - OpenMM Topology for input system - exclude_resids : list[int] - List of residue numbers to exclude from host selection - lig_anchor_idx : int - The index of the l1 ligand anchor. - selection : str - Selection string for the host atoms. - """ - # Create an mdtraj trajectory to manipulate - # First fetch the box vectors and pass them as lengths and angles - vectors = from_openmm(topology.getPeriodicBoxVectors()) - a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles( - vectors[0].m, vectors[1].m, vectors[2].m - ) - - traj = mdt.Trajectory( - positions[np.newaxis, ...], mdt.Topology.from_openmm(topology) - ) - - # Get all the potential protein atoms matching the selection - host_sel = traj.topology.select(selection) - - # Get residues to exclude from the selection - exclude_sel = np.array( - [ - at.index - for at in chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) - ] - ) - - # Remove exclusion - anchors = host_sel[np.isin(host_sel, exclude_sel, invert=True)] - - # Compute distanecs from ligand l1 anchor atom - pairs = np.vstack( - (anchors, np.array([lig_anchor_idx for _ in range(len(anchors))])) - ).T - - distances = mdt.compute_distances(traj, pairs, periodic=True) - - return np.array([pairs[i][0] for i in np.argsort(distances[0])]) - - -def is_collinear(positions, atoms, threshold=0.9): - """ - Check whether any sequential vectors in a sequence of atoms are collinear. - - Parameters - ---------- - positions : openmm.unit.Quantity - System positions. - atoms : list[int] - The indices of the atoms to test. - threshold : float - Atoms are not collinear if their sequential vector separation dot - products are less than ``threshold``. Default 0.9. - - Returns - ------- - result : bool - Returns True if any sequential pair of vectors is collinear; False otherwise. - - Notes - ----- - Originally from Yank, with modifications from Separated Topologies - """ - results = False - for i in range(len(atoms) - 2): - v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] - v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] - normalized_inner_product = np.dot(v1, v2) / np.sqrt( - np.dot(v1, v1) * np.dot(v2, v2) - ) - result = result or (np.abs(normalized_inner_product) > threshold) - return result - - -def check_angle(angle, force_constant=83.68): - """ - Check whether the chosen angle is less than 10 kT from 0 or 180 - - Parameters - ---------- - angle : float - The angle to check in degrees. - force_constant : float - Force constant of the angle. - - Note - ---- - We assume the temperature to be 298.15 Kelvin. - """ - # TODO: convert this to unit.Quantity so we don't end up with - # conversion errors - RT = 8.31445985 * 0.001 * 298.15 - # check if angle is <10kT from 0 or 180 - check1 = 0.5 * force_constant * np.power((angle - 0.0) / 180.0 * np.pi, 2) - check2 = 0.5 * force_constant * np.power((angle - 180.0) / 180.0 * np.pi, 2) - ang_check_1 = check1 / RT - ang_check_2 = check2 / RT - if ang_check_1 < 10.0 or ang_check_2 < 10.0: - return False - return True - - class FindHostAtoms(AnalysisBase): """ Class filter host atoms based on their distance From 2d97de82dc762f5f294e503282efcbe7bb0f03bf Mon Sep 17 00:00:00 2001 From: Irfan Alibay Date: Thu, 12 Dec 2024 22:21:50 +0000 Subject: [PATCH 17/33] Apply suggestions from code review --- .../openmm_utils/restraints/openmm/omm_restraints.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py index e53e828d5..a3fe777d3 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py @@ -125,7 +125,7 @@ def _verify_geometry(self, geometry: BaseRestraintGeometry): super()._verify_geometry(geometry) -class BaseRadialllySymmetricRestraintForce(BaseHostGuestRestraints): +class BaseRadiallySymmetricRestraintForce(BaseHostGuestRestraints): def _verify_inputs(self) -> None: if not isinstance(self.settings, BaseDistanceRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" @@ -162,7 +162,7 @@ def _get_force(self, geometry: DistanceRestraintGeometry): raise NotImplementedError("only implemented in child classes") -class HarmonicBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): +class HarmonicBondRestraint(BaseRadiallySymmetricRestraintForce, SingleBondMixin): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintBondForce( @@ -173,7 +173,7 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: ) -class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): +class FlatBottomBondRestraint(BaseRadiallySymmetricRestraintForce, SingleBondMixin): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) @@ -186,7 +186,7 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: ) -class CentroidHarmonicRestraint(BaseRadialllySymmetricRestraintForce): +class CentroidHarmonicRestraint(BaseRadiallySymmetricRestraintForce): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintForce( @@ -197,7 +197,7 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: ) -class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): +class CentroidFlatBottomRestraint(BaseRadiallySymmetricRestraintForce): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) From 116ba64b0cb29acf69124a7743fc74c2d24815c7 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 22:33:09 +0000 Subject: [PATCH 18/33] Add some more docstring --- .../restraints/geometry/boresch.py | 72 ++++++++++++------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index cf14b73aa..cfaf443dc 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -35,23 +35,48 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): Where HX represents the X index of ``host_atoms`` and GX the X index of ``guest_atoms``. """ - - def get_bond_distance(self, topology, coordinates) -> unit.Quantity: + def get_bond_distance( + self, + topology: Union[str, openmm.app.Topology], + coordinates: Union[str, npt.NDArray], + ) -> unit.Quantity: """ - Get the H2 - G0 distance + Get the H2 - G0 distance. Parameters ---------- - topology : + topology : Union[str, openmm.app.Topology] + coordinates : Union[str, npt.NDArray] + A coordinate file or NDArray in frame-atom-coordinate + order in Angstrom. """ - u = mda.Universe(topology, coordinates) + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) at1 = u.atoms[host_atoms[2]] at2 = u.atoms[guest_atoms[0]] bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) # convert to float so we avoid having a np.float64 return float(bond) * unit.angstrom - def get_angles(self, topology, coordinates) -> unit.Quantity: + def get_angles( + self, + topology: Union[str, openmm.app.Topology], + coordinates: Union[str, npt.NDArray], + ) -> unit.Quantity: + """ + Get the H1-H2-G0, and H2-G0-G1 angles. + + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + coordinates : Union[str, npt.NDArray] + A coordinate file or NDArray in frame-atom-coordinate + order in Angstrom. + """ u = mda.Universe(topology, coordinates) at1 = u.atoms[host_atoms[1]] at2 = u.atoms[host_atoms[2]] @@ -66,7 +91,21 @@ def get_angles(self, topology, coordinates) -> unit.Quantity: ) return angleA, angleB - def get_dihedrals(self, topology, coordinates) -> unit.Quantity: + def get_dihedrals( + self, + topology: Union[str, openmm.app.Topology], + coordinates: Union[str, npt.NDArray], + ) -> unit.Quantity: + """ + Get the H0-H1-H2-G0, H1-H2-G0-G1, and H2-G0-G1-G2 dihedrals. + + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + coordinates : Union[str, npt.NDArray] + A coordinate file or NDArray in frame-atom-coordinate + order in Angstrom. + """ u = mda.Universe(topology, coordinates) at1 = u.atoms[host_atoms[0]] at2 = u.atoms[host_atoms[1]] @@ -84,7 +123,6 @@ def get_dihedrals(self, topology, coordinates) -> unit.Quantity: dihC = calc_dihedrals( at3.position, at4.position, at5.position, at6.position, u.atoms.dimensions ) - return dihA, dihB, dihC @@ -203,7 +241,6 @@ def get_small_molecule_guest_atom_candidates( rdmol: Chem.Mol, ligand_idxs: list[int], rmsf_cutoff: unit.Quantity = 1 * unit.angstrom, - angle_force_constant: unit.Quantity = 83.68 * unit.kilojoule_per_mole / unit.radians**2, ) -> list[tuple[int]]: """ Get a list of potential ligand atom choices for a Boresch restraint @@ -222,8 +259,6 @@ def get_small_molecule_guest_atom_candidates( The ligand indices in the topology. rmsf_cutoff : unit.Quantity The RMSF filter cut-off. - angle_force_constant : unit.Quantity - The force constant for the l1-l2-l3 atom angle. Returns ------- @@ -271,20 +306,9 @@ def get_small_molecule_guest_atom_candidates( for atom in sorted_anchor_pool: angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) for angle in _angles: + # Check that the angle is at least not collinear angle_ag = ligand_ag.atoms[list(angle)] - collinear = is_collinear(ligand_ag.positions, angle) - angle_value = ( - calc_angle( - angle_ag.atoms[0].position, - angle_ag.atoms[1].position, - angle_ag.atoms[2].position, - box=angle_ag.universe.dimensions, - ) * unit.radians - ) - energy = check_angle_energy( - angle_value, angle_force_constant, 298.15 * unit.kelvin - ) - if not collinear and energy: + if not is_collinear(ligand_ag.positions, angle): angles_list.append(angle) return angles_list From 9ae60da278e4c54eaf2f98feaf2387bef09a76a7 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 13 Dec 2024 01:20:40 +0000 Subject: [PATCH 19/33] Add minimized vectors on the collinear checks --- .../restraints/geometry/boresch.py | 181 +++++++++++++++--- .../openmm_utils/restraints/geometry/utils.py | 26 ++- 2 files changed, 168 insertions(+), 39 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index cfaf443dc..363e22c6b 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -8,13 +8,15 @@ * Add relevant duecredit entries. """ import abc +import pathlib from pydantic.v1 import BaseModel, validator from rdkit import Chem from openff.units import unit import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds, calc_angles +from MDANalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals import numpy as np import numpy.typing as npt @@ -37,8 +39,8 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): """ def get_bond_distance( self, - topology: Union[str, openmm.app.Topology], - coordinates: Union[str, npt.NDArray], + topology: Union[str, pathlib.Path, openmm.app.Topology], + coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ Get the H2 - G0 distance. @@ -64,8 +66,8 @@ def get_bond_distance( def get_angles( self, - topology: Union[str, openmm.app.Topology], - coordinates: Union[str, npt.NDArray], + topology: Union[str, pathlib.Path, openmm.app.Topology], + coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ Get the H1-H2-G0, and H2-G0-G1 angles. @@ -77,7 +79,12 @@ def get_angles( A coordinate file or NDArray in frame-atom-coordinate order in Angstrom. """ - u = mda.Universe(topology, coordinates) + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) at1 = u.atoms[host_atoms[1]] at2 = u.atoms[host_atoms[2]] at3 = u.atoms[guest_atoms[0]] @@ -93,8 +100,8 @@ def get_angles( def get_dihedrals( self, - topology: Union[str, openmm.app.Topology], - coordinates: Union[str, npt.NDArray], + topology: Union[str, pathlib.Path, openmm.app.Topology], + coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ Get the H0-H1-H2-G0, H1-H2-G0-G1, and H2-G0-G1-G2 dihedrals. @@ -106,7 +113,12 @@ def get_dihedrals( A coordinate file or NDArray in frame-atom-coordinate order in Angstrom. """ - u = mda.Universe(topology, coordinates) + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) at1 = u.atoms[host_atoms[0]] at2 = u.atoms[host_atoms[1]] at3 = u.atoms[host_atoms[2]] @@ -235,12 +247,12 @@ def _get_atom_pool(rdmol: Chem.Mol, rmsf: npt.NDArray) -> Optional[set[int]]: return atom_pool -def get_small_molecule_guest_atom_candidates( - topology: Union[str, openmm.app.Topology], +def get_guest_atom_candidates( + topology: Union[str, pathlib.Path, openmm.app.Topology], trajectory: Union[str, pathlib.Path], rdmol: Chem.Mol, - ligand_idxs: list[int], - rmsf_cutoff: unit.Quantity = 1 * unit.angstrom, + guest_idxs: list[int], + rmsf_cutoff: unit.Quantity = 1 * unit.nanometer, ) -> list[tuple[int]]: """ Get a list of potential ligand atom choices for a Boresch restraint @@ -255,7 +267,7 @@ def get_small_molecule_guest_atom_candidates( rdmol : Chem.Mol An RDKit Molecule representing the small molecule ordered in the same way as it is listed in the topology. - ligand_idxs : list[int] + guest_idxs : list[int] The ligand indices in the topology. rmsf_cutoff : unit.Quantity The RMSF filter cut-off. @@ -275,13 +287,14 @@ def get_small_molecule_guest_atom_candidates( ---- Remember to update the RDMol with the last frame positions. """ - if isinstance(topology, openmm.app.Topology): - topology_format = "OPENMMTOPOLOGY" - else: - topology_format = None + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) - u = mda.Universe(topology, trajectory, topology_format=topology_format) - ligand_ag = u.atoms[ligand_idxs] + ligand_ag = u.atoms[guest_idxs] # 0. Get the ligand RMSF rmsf = get_local_rmsf(ligand_ag) @@ -308,14 +321,20 @@ def get_small_molecule_guest_atom_candidates( for angle in _angles: # Check that the angle is at least not collinear angle_ag = ligand_ag.atoms[list(angle)] - if not is_collinear(ligand_ag.positions, angle): - angles_list.append(angle) + if not is_collinear(ligand_ag.positions, angle, u.dimensions): + angles_list.append( + ( + angle_ag.atoms[0].ix, + angle_ag.atoms[1].ix, + angle_ag.atoms[2].ix + ) + ) return angles_list def get_host_atom_candidates( - topology: Union[str, openmm.app.Topology], + topology: Union[str, pathlib.Path, openmm.app.Topology], trajectory: Union[str, pathlib.Path], host_idxs: list[int], l1_idx: int, @@ -349,12 +368,13 @@ def get_host_atom_candidates( max_distance : unit.Quantity The maximum search distance around l1 for suitable candidate atoms. """ - if isinstance(topology, openmm.app.Topology): - topology_format = "OPENMMTOPOLOGY" - else: - topology_format = None + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) - u = mda.Universe(topology, trajectory, topology_format=topology_format) protein_ag1 = u.atoms[host_idxs] protein_ag2 = protein_ag.select_atoms(protein_selection) @@ -375,6 +395,107 @@ def get_host_atom_candidates( return atom_finder.results.host_idxs -def select_boresch_atoms( +class EvaluateH2Atoms(AnalysisBase): + """ + Class to evaluate the suitability of a set of host atoms + as a H2 atom (i.e. bonded to the guest G0 atom). + + Parameters + ---------- + guest_atoms: MDAnalysis.AtomGroup + The guest atoms representing G0-G1-G2. + host_atom_pool: MDAnalysis.AtomGroup + The pool of atoms to pick a H2 from. + angle_force_constant : unit.Quantity + The force constant for the H2-G0-G1 angle. + """ + + +def find_boresch_restraint( + topology: Union[str, pathlib.Path, openmm.app.Topology], + trajectory: Union[str, pathlib.Path], + guest_rdmol: Chem.Mol, + guest_idxs: list[int], + host_idxs: list[int], + guest_restraint_atom_idxs: Optional[list[int]] = None, + host_restraint_atoms_idxs Optional[list[int]] = None, + host_selection: str = 'all', + dssp_filter: bool = False, + rmsf_custoff: unit.Quantity = 0.1 * unit.nanometer, + host_min_distance: unit.Quantity = 1 * unit.nanometer, + host_max_distance: unit.Quantity = 3 * unit.nanometer, +) -> BoreschRestraintGeometry: + """ + Find suitable Boresch-style restraints between a host and guest entity. + + Parameters + ---------- + ... + + Returns + ------- + ... + """ + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) + u.trajectory[-1] # Work with the final frame + + if (guest_restraint_atoms_idxs is not None) and (host_restraint_atoms_idxs is not None): + # In this case assume the picked atoms were intentional / representative + # of the input and go with it + guest_ag = u.select_atoms[guest_idxs] + guest_angle = (at.ix for at in guest_ag.atoms[guest_restraint_atom_idxs]) + host_ag = u.select_atoms[host_idxs] + host_angle = (at.ix for at in host_ag.atoms[host_restraint_atoms_idxs]) + # TODO sort out the return on this + return BoreschRestraintGeometry(...) + + if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): + # This is not an intended outcome, crash out here + errmsg = ( + "both ``guest_restraints_atoms_idxs`` and ``host_restraint_atoms_idxs`` " + "must be set or both must be None. " + f"Got {guest_restraint_atoms_idxs} and {host_atoms_restraint_atoms_idxs}" + ) + raise ValueError(errmsg) + + # Fetch the guest angles + guest_angles = get_guest_atom_candidates( + topology=topology, + trajectory=trajectory, + rdmol=guest_rdmol, + guest_idxs=guest_idxs, + rmsf_cutoff=rmsf_cutoff, + ) + + guest_angle = guest_angles[0] + + # Fetch the host atom pool + host_pool = get_host_atom_candidates( + topology=topology, + trajectory=trajectory, + host_idxs=host_idxs, + l1_idx=guest_angle[0], + host_selection=host_selection, + dssp_filter=dssp_filter, + rmsf_cutoff=rmsf_custoff, + min_distance=host_min_distance, + max_distance=host_max_distance, + ) + + # Get the guest angle atomgroup + guest_ag = u.atoms[list(guest_angle)] + + # Find all suitable H2 idxs + h2_idxs = [] + for i in host_pool: + host2_at = u.atoms[i] + pos = np.vstack((at.position, guest_ag.positions)) + angle = calc_angles(pos[0], pos[1], pos[2], box=u.dimensions) * unit.radians + dihed = calc_dihedrals(pos[0], pos[1], pos[2], pos[3], box=u.dimensions) * unit.radians + collinear = is_collinear(positions, [0, 1, 2, 3]) -): diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index a74e83ed3..a1f983621 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -22,7 +22,7 @@ import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.analysis.rmsf import RMSF -from MDAnalysis.lib.distances import calc_bonds, calc_angles +from MDAnalysis.lib.distances import calc_bonds, calc_angles, minimize_vectors from MDAnalysis.coordinates.memory import MemoryReader from openfe_analysis.transformations import Aligner, NoJump @@ -166,19 +166,21 @@ def get_central_atom_idx(rdmol: Chem.Mol) -> int: return center -def is_collinear(positions, atoms, threshold=0.9): +def is_collinear(positions, atoms, dimensions=None, threshold=0.9): """ Check whether any sequential vectors in a sequence of atoms are collinear. Parameters ---------- positions : openmm.unit.Quantity - System positions. + System positions. atoms : list[int] - The indices of the atoms to test. + The indices of the atoms to test. + dimensions : Optional[npt.NDArray] + The dimensions of the system to minimize vectors. threshold : float - Atoms are not collinear if their sequential vector separation dot - products are less than ``threshold``. Default 0.9. + Atoms are not collinear if their sequential vector separation dot + products are less than ``threshold``. Default 0.9. Returns ------- @@ -187,12 +189,18 @@ def is_collinear(positions, atoms, threshold=0.9): Notes ----- - Originally from Yank, with modifications from Separated Topologies + Originally from Yank. """ results = False for i in range(len(atoms) - 2): - v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] - v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] + v1 = minimize_vectors( + positions[atoms[i + 1], :] - positions[atoms[i], :], + box=dimensions, + ) + v2 = minimize_vectors( + positions[atoms[i + 2], :] - positions[atoms[i + 1], :], + box=dimensions, + ) normalized_inner_product = np.dot(v1, v2) / np.sqrt( np.dot(v1, v1) * np.dot(v2, v2) ) From 3cce308a8548df90305d7ce11c03f128dee6cc0d Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 14 Dec 2024 00:07:26 +0000 Subject: [PATCH 20/33] add host atom finding routine --- .../restraints/geometry/boresch.py | 277 +++++++++++++++--- .../openmm_utils/restraints/geometry/utils.py | 20 +- 2 files changed, 259 insertions(+), 38 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index 363e22c6b..41bd3b3b7 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -19,6 +19,7 @@ from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals import numpy as np import numpy.typing as npt +from scipy.stats import circmean from .base import HostGuestRestraintGeometry @@ -29,10 +30,10 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): The restraint is defined by the following: - H0 G2 + H2 G2 - - - - - H1 - - H2 -- G0 - - G1 + H1 - - H0 -- G0 - - G1 Where HX represents the X index of ``host_atoms`` and GX the X index of ``guest_atoms``. @@ -43,7 +44,7 @@ def get_bond_distance( coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ - Get the H2 - G0 distance. + Get the H0 - G0 distance. Parameters ---------- @@ -58,7 +59,7 @@ def get_bond_distance( format=_get_mda_coord_format(coordinates), topology_format=_get_mda_topology_format(topology) ) - at1 = u.atoms[host_atoms[2]] + at1 = u.atoms[host_atoms[0]] at2 = u.atoms[guest_atoms[0]] bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) # convert to float so we avoid having a np.float64 @@ -70,7 +71,7 @@ def get_angles( coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ - Get the H1-H2-G0, and H2-G0-G1 angles. + Get the H1-H0-G0, and H0-G0-G1 angles. Parameters ---------- @@ -86,7 +87,7 @@ def get_angles( topology_format=_get_mda_topology_format(topology) ) at1 = u.atoms[host_atoms[1]] - at2 = u.atoms[host_atoms[2]] + at2 = u.atoms[host_atoms[0]] at3 = u.atoms[guest_atoms[0]] at4 = u.atoms[guest_atoms[1]] @@ -104,7 +105,7 @@ def get_dihedrals( coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ - Get the H0-H1-H2-G0, H1-H2-G0-G1, and H2-G0-G1-G2 dihedrals. + Get the H2-H1-H0-G0, H1-H0-G0-G1, and H0-G0-G1-G2 dihedrals. Parameters ---------- @@ -119,9 +120,9 @@ def get_dihedrals( format=_get_mda_coord_format(coordinates), topology_format=_get_mda_topology_format(topology) ) - at1 = u.atoms[host_atoms[0]] + at1 = u.atoms[host_atoms[2]] at2 = u.atoms[host_atoms[1]] - at3 = u.atoms[host_atoms[2]] + at3 = u.atoms[host_atoms[0]] at4 = u.atoms[guest_atoms[0]] at5 = u.atoms[guest_atoms[1]] at6 = u.atoms[guest_atoms[2]] @@ -275,7 +276,7 @@ def get_guest_atom_candidates( Returns ------- angle_list : list[tuple[int]] - A list of tuples for each valid l1, l2, l3 angle. If ``None``, no + A list of tuples for each valid G0, G1, G2 angle. If ``None``, no angles could be found. Raises @@ -343,7 +344,7 @@ def get_host_atom_candidates( rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, min_distance: unit.Quantity = 1 * unit.nanometer, max_distance: unit.Quantity = 3 * unit.nanometer, -): +) -> npt.NDArray: """ Get a list of suitable host atoms. @@ -367,6 +368,11 @@ def get_host_atom_candidates( The minimum search distance around l1 for suitable candidate atoms. max_distance : unit.Quantity The maximum search distance around l1 for suitable candidate atoms. + + Return + ------ + NDArray + Array of host atom indexes """ u = mda.Universe( topology, @@ -395,20 +401,212 @@ def get_host_atom_candidates( return atom_finder.results.host_idxs -class EvaluateH2Atoms(AnalysisBase): +class EvaluateHostAtoms1(AnalysisBase): """ Class to evaluate the suitability of a set of host atoms - as a H2 atom (i.e. bonded to the guest G0 atom). + as H1 atoms (i.e. the second host atom). Parameters ---------- - guest_atoms: MDAnalysis.AtomGroup - The guest atoms representing G0-G1-G2. - host_atom_pool: MDAnalysis.AtomGroup - The pool of atoms to pick a H2 from. + reference : MDAnalysis.AtomGroup + The reference preceeding three atoms. + host_atom_pool : MDAnalysis.AtomGroup + The pool of atoms to pick an atom from. + minimum_distance : unit.Quantity + The minimum distance from the bound reference atom. angle_force_constant : unit.Quantity - The force constant for the H2-G0-G1 angle. + The force constant for the angle. + temperature : unit.Quantity + The system temperature in Kelvin """ + def __init__( + self, + reference, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + **kwargs + ): + super().__init__(reference.universe.trajectory, **kwargs) + + if len(reference) != 3: + errmsg = "Incorrect number of reference atoms passed" + raise ValueError(errmsg) + + self.reference = reference + self.host_atom_pool = host_atom_pool + self.minimum_distance = minimum_distance.to('angstrom').m + self.angle_force_constant = angle_force_constant + self.temperature = temperature + + def _prepare(self): + self.results.distances = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.angles = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.dihedrals = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.collinear = np.empty( + (len(self.host_atom_pool), self.n_frames), + dtype=bool, + ) + self.results.valid = np.empty( + len(self.host_atom_pool), + dtype=bool, + ) + + def _single_frame(self): + for i, at in enumerate(self.host_atom_pool): + distance = calc_bonds( + at.position, + self.reference.atoms[0].position, + box=self.reference.dimensions, + ) + angle = calc_angles( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + box=self.reference.dimensions, + ) + dihedral = calc_dihedrals( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + self.reference.atoms[2].position, + box=self.reference.dimensions + ) + collinear = is_collinear( + positions=np.vstack((at.position, self.reference.positions)), + dimensions=self.reference.dimensions, + ) + self.results.distances[i][self._frame_index] = distance + self.results.angles[i][self._frame_index] = angle + self.results.dihedrals[i][self._frame_index] = dihedral + self.results.collinear[i][self._frame_index] = collinear + + def _conclude(self): + for i, at in enumerate(self.host_atom_pool): + distance_bounds = all( + self.results.distances[i] > self.minimum_distance + ) + mean_angle = circmean(self.results.angles[i], high=np.pi, low=0) + angle_bounds = check_angle_not_flat( + angle=mean_angle * unit.radians, + force_constant=self.angle_force_constant, + temperature=self.temperature, + ) + angle_variance = check_angular_variance( + self.results.angles[i] * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=0 * unit.radians, + width=1.745 * unit.radians, + ) + mean_dihed = circmean(self.results.dihedrals[i], high=np.pi, low=-np.pi) + dihed_bounds = check_dihedral_bounds(mean_dihed) + dihed_variance = check_angular_variance( + self.results.dihedrals[i] * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=-np.pi * unit.radians, + width=5.23 * unit.radians, + ) + not_collinear = not all(self.results.collinear[i]) + if all([distance_bounds, angle_bounds, angle_variance, dihed_bounds, dihed_variance, not_collinear]): + self.results.valid[i] = True + + +class EvaluateHostAtoms2(EvaluateH21Atoms): + def _prepare(self): + self.results.distances1 = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.ditances2 = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.dihedrals = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.collinear = np.empty( + (len(self.host_atom_pool), self.n_frames), + dtype=bool, + ) + self.results.valid = np.empty( + len(self.host_atom_pool), + dtype=bool, + ) + + def _single_frame(self): + for i, at in enumerate(self.host_atom_pool): + distance1 = calc_bonds( + at.position, + self.reference.atoms[0].position, + box=self.reference.dimensions, + ) + distance2 = calc_bonds( + at.position, + self.reference.atoms[1].position, + box=self.reference.dimensions, + ) + dihedral = calc_dihedrals( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + self.reference.atoms[2].position, + box=self.reference.dimensions + ) + collinear = is_collinear( + positions=np.vstack((at.position, self.reference.positions)), + dimensions=self.reference.dimensions, + ) + self.results.distances1[i][self._frame_index] = distance + self.results.distances2[i][self._frame_index] = angle + self.results.dihedrals[i][self._frame_index] = dihedral + self.results.collinear[i][self._frame_index] = collinear + + def _conclude(self): + for i, at in enumerate(self.host_atom_pool): + distance1_bounds = all( + self.results.distances1[i] > self.minimum_distance + ) + distance2_bounds = all( + self.results.distances2[i] > self.minimum_distance + ) + mean_dihed = circmean(self.results.dihedrals[i], high=np.pi, low=-np.pi) + dihed_bounds = check_dihedral_bounds(mean_dihed) + dihed_variance = check_angular_variance( + self.results.dihedrals[i] * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=-np.pi * unit.radians, + width=5.23 * unit.radians, + ) + not_collinear = not all(self.results.collinear[i]) + if all([distance1_bounds, distance2_bounds, dihed_bounds, dihed_variance, not_collinear]): + self.results.valid[i] = True + + +def _find_host_angle(g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature): + h0_eval = EvaluateHAtoms1(g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) + h0_eval.run() + + for i, valid_h0 in enumerate(h0_eval.results.valid): + if valid_h0: + g1g2h0_atoms = g0g1g2_atoms.atoms[1:] + host_atom_pool.atoms[i] + h1_eval = EvaluateHAtoms1(g1g2h0_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) + for j, valid_h1 in enumerate(h1_eval.results.valid): + g2h0h1_atoms = g1g2h0_atoms.atoms[1:] + host_atom_pool.atoms[j] + h2_eval = EvaluateHAtoms2(g2h0h1_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) + + if any(h2_eval.ressults.valid): + d1_avgs = [d.mean() for d in h2_eval.results.distances1] + d2_avgs = [d.mean() for d in h2_eval.results.distances2] + dsum_avgs = d1_avgs + d2_avgs + k = dsum_avgs.argmin() + + return host_atom_pool.atoms[[i, j, k]].ix + return None def find_boresch_restraint( @@ -424,6 +622,8 @@ def find_boresch_restraint( rmsf_custoff: unit.Quantity = 0.1 * unit.nanometer, host_min_distance: unit.Quantity = 1 * unit.nanometer, host_max_distance: unit.Quantity = 3 * unit.nanometer, + angle_force_constant: unit.Quantity = 83.68 * unit.kilojoule_per_mole / unit.radians**2, + temperature: unit.Quantity = 298.15 * unit.kelvin, ) -> BoreschRestraintGeometry: """ Find suitable Boresch-style restraints between a host and guest entity. @@ -448,11 +648,11 @@ def find_boresch_restraint( # In this case assume the picked atoms were intentional / representative # of the input and go with it guest_ag = u.select_atoms[guest_idxs] - guest_angle = (at.ix for at in guest_ag.atoms[guest_restraint_atom_idxs]) + guest_angle = [at.ix for at in guest_ag.atoms[guest_restraint_atom_idxs]] host_ag = u.select_atoms[host_idxs] - host_angle = (at.ix for at in host_ag.atoms[host_restraint_atoms_idxs]) + host_angle = [at.ix for at in host_ag.atoms[host_restraint_atoms_idxs]] # TODO sort out the return on this - return BoreschRestraintGeometry(...) + return BoreschRestraintGeometry(host_atoms=host_angle, guest_atoms=guest_angle) if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): # This is not an intended outcome, crash out here @@ -463,7 +663,7 @@ def find_boresch_restraint( ) raise ValueError(errmsg) - # Fetch the guest angles + # 1. Fetch the guest angles guest_angles = get_guest_atom_candidates( topology=topology, trajectory=trajectory, @@ -472,9 +672,14 @@ def find_boresch_restraint( rmsf_cutoff=rmsf_cutoff, ) + if len(guest_angles) != 0: + errmsg = "No suitable ligand atoms found for the restraint." + raise ValueError(errmsg) + + # We pick the first angle / ligand atom set as the one to use guest_angle = guest_angles[0] - # Fetch the host atom pool + # 2. We next fetch the host atom pool host_pool = get_host_atom_candidates( topology=topology, trajectory=trajectory, @@ -487,15 +692,21 @@ def find_boresch_restraint( max_distance=host_max_distance, ) - # Get the guest angle atomgroup - guest_ag = u.atoms[list(guest_angle)] + # 3. We then loop through the guest angles to find suitable host atoms + for guest_angle in guest_angles: + host_angle = _find_host_angle( + g0g1g2_atoms=u.atoms[list(guest_angle)], + host_atom_pool=u.atoms[host_pool], + minimum_distance=0.5 * unit.nanometer, + angle_force_constant=angle_force_constant, + temperature=temperature, + ) + # continue if it's empty, otherwise stop + if host_angle is not None: + break - # Find all suitable H2 idxs - h2_idxs = [] - for i in host_pool: - host2_at = u.atoms[i] - pos = np.vstack((at.position, guest_ag.positions)) - angle = calc_angles(pos[0], pos[1], pos[2], box=u.dimensions) * unit.radians - dihed = calc_dihedrals(pos[0], pos[1], pos[2], pos[3], box=u.dimensions) * unit.radians - collinear = is_collinear(positions, [0, 1, 2, 3]) + if host_angle is None: + errmsg = "No suitable host atoms could be found" + raise ValueError(errmsg) + return BoreschRestraintGeometry(host_atoms=host_angle, guest_atoms=guest_angle) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index a1f983621..96a665ee5 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -208,7 +208,7 @@ def is_collinear(positions, atoms, dimensions=None, threshold=0.9): return result -def check_angle_energy( +def check_angle_not_flat( angle: FloatQuantity["radians"], force_constant: ANGLE_FRC_CONSTANT_TYPE = DEFAULT_ANGLE_FRC_CONSTANT, temperature: FloatQuantity["kelvin"] = 298.15 * unit.kelvin, @@ -228,7 +228,7 @@ def check_angle_energy( Returns ------- bool - If the angle is less than 10 kT from 0 or pi radians + False if the angle is less than 10 kT from 0 or pi radians Note ---- @@ -280,7 +280,10 @@ def check_dihedral_bounds( def check_angular_variance( - angles: ArrayQuantity["radians"], width: FloatQuantity["radians"] + angles: ArrayQuantity["radians"], width: FloatQuantity["radians"], + upper_bound: FloatQuantity['radians'], + lower_bound: FloatQuantity['radians'], + width: FloatQuantity['radians'], ) -> bool: """ Check that the variance of a list of ``angles`` does not exceed @@ -290,6 +293,10 @@ def check_angular_variance( ---------- angles : ArrayLike[unit.Quantity] An array of angles in units compatible with radians. + upper_bound: FloatQuantity['radians'] + The upper bound in the angle range. + lower_bound: FloatQuantity['radians'] + The lower bound in the angle range. width : unit.Quantity The width to check the variance against, in units compatible with radians. @@ -299,8 +306,11 @@ def check_angular_variance( ``True`` if the variance of the angles is less than the width. """ - array = angles.to("radians").m - variance = circvar(array) + variance = circvar( + angles.to("radians").m, + high=upper_bound.to("radians").m, + low=lower_bound.to("radians").m + ) return not (variance * unit.radians > width) From 9171d3992e2418f1453e0888cf48ec2661c2bc25 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 14 Dec 2024 00:12:14 +0000 Subject: [PATCH 21/33] autoformatting --- .../restraints/geometry/boresch.py | 133 ++++++++++-------- .../openmm_utils/restraints/geometry/utils.py | 24 ++-- 2 files changed, 93 insertions(+), 64 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index 41bd3b3b7..844817f52 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -38,9 +38,10 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): Where HX represents the X index of ``host_atoms`` and GX the X index of ``guest_atoms``. """ + def get_bond_distance( self, - topology: Union[str, pathlib.Path, openmm.app.Topology], + topology: Union[str, pathlib.Path, openmm.app.Topology], coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ @@ -57,7 +58,7 @@ def get_bond_distance( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) at1 = u.atoms[host_atoms[0]] at2 = u.atoms[guest_atoms[0]] @@ -84,7 +85,7 @@ def get_angles( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) at1 = u.atoms[host_atoms[1]] at2 = u.atoms[host_atoms[0]] @@ -118,7 +119,7 @@ def get_dihedrals( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) at1 = u.atoms[host_atoms[2]] at2 = u.atoms[host_atoms[1]] @@ -292,7 +293,7 @@ def get_guest_atom_candidates( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) ligand_ag = u.atoms[guest_idxs] @@ -302,7 +303,7 @@ def get_guest_atom_candidates( u.trajectory[-1] # forward to the last frame # 1. Get the pool of atoms to work with - atom_pool = _get_atom_pool(rdmol: Chem.Mol, rmsf: npt.NDArray) + atom_pool = _get_atom_pool(rdmol, rmsf) if atom_pool is None: # We don't have enough atoms so we raise an error @@ -324,11 +325,7 @@ def get_guest_atom_candidates( angle_ag = ligand_ag.atoms[list(angle)] if not is_collinear(ligand_ag.positions, angle, u.dimensions): angles_list.append( - ( - angle_ag.atoms[0].ix, - angle_ag.atoms[1].ix, - angle_ag.atoms[2].ix - ) + (angle_ag.atoms[0].ix, angle_ag.atoms[1].ix, angle_ag.atoms[2].ix) ) return angles_list @@ -378,7 +375,7 @@ def get_host_atom_candidates( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) protein_ag1 = u.atoms[host_idxs] @@ -419,6 +416,7 @@ class EvaluateHostAtoms1(AnalysisBase): temperature : unit.Quantity The system temperature in Kelvin """ + def __init__( self, reference, @@ -426,7 +424,7 @@ def __init__( minimum_distance, angle_force_constant, temperature, - **kwargs + **kwargs, ): super().__init__(reference.universe.trajectory, **kwargs) @@ -436,20 +434,14 @@ def __init__( self.reference = reference self.host_atom_pool = host_atom_pool - self.minimum_distance = minimum_distance.to('angstrom').m + self.minimum_distance = minimum_distance.to("angstrom").m self.angle_force_constant = angle_force_constant self.temperature = temperature def _prepare(self): - self.results.distances = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) - self.results.angles = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) - self.results.dihedrals = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) + self.results.distances = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.angles = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.dihedrals = np.zeros((len(self.host_atom_pool), self.n_frames)) self.results.collinear = np.empty( (len(self.host_atom_pool), self.n_frames), dtype=bool, @@ -477,7 +469,7 @@ def _single_frame(self): self.reference.atoms[0].position, self.reference.atoms[1].position, self.reference.atoms[2].position, - box=self.reference.dimensions + box=self.reference.dimensions, ) collinear = is_collinear( positions=np.vstack((at.position, self.reference.positions)), @@ -490,9 +482,7 @@ def _single_frame(self): def _conclude(self): for i, at in enumerate(self.host_atom_pool): - distance_bounds = all( - self.results.distances[i] > self.minimum_distance - ) + distance_bounds = all(self.results.distances[i] > self.minimum_distance) mean_angle = circmean(self.results.angles[i], high=np.pi, low=0) angle_bounds = check_angle_not_flat( angle=mean_angle * unit.radians, @@ -514,21 +504,24 @@ def _conclude(self): width=5.23 * unit.radians, ) not_collinear = not all(self.results.collinear[i]) - if all([distance_bounds, angle_bounds, angle_variance, dihed_bounds, dihed_variance, not_collinear]): + if all( + [ + distance_bounds, + angle_bounds, + angle_variance, + dihed_bounds, + dihed_variance, + not_collinear, + ] + ): self.results.valid[i] = True class EvaluateHostAtoms2(EvaluateH21Atoms): def _prepare(self): - self.results.distances1 = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) - self.results.ditances2 = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) - self.results.dihedrals = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) + self.results.distances1 = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.ditances2 = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.dihedrals = np.zeros((len(self.host_atom_pool), self.n_frames)) self.results.collinear = np.empty( (len(self.host_atom_pool), self.n_frames), dtype=bool, @@ -555,7 +548,7 @@ def _single_frame(self): self.reference.atoms[0].position, self.reference.atoms[1].position, self.reference.atoms[2].position, - box=self.reference.dimensions + box=self.reference.dimensions, ) collinear = is_collinear( positions=np.vstack((at.position, self.reference.positions)), @@ -568,12 +561,8 @@ def _single_frame(self): def _conclude(self): for i, at in enumerate(self.host_atom_pool): - distance1_bounds = all( - self.results.distances1[i] > self.minimum_distance - ) - distance2_bounds = all( - self.results.distances2[i] > self.minimum_distance - ) + distance1_bounds = all(self.results.distances1[i] > self.minimum_distance) + distance2_bounds = all(self.results.distances2[i] > self.minimum_distance) mean_dihed = circmean(self.results.dihedrals[i], high=np.pi, low=-np.pi) dihed_bounds = check_dihedral_bounds(mean_dihed) dihed_variance = check_angular_variance( @@ -583,21 +572,49 @@ def _conclude(self): width=5.23 * unit.radians, ) not_collinear = not all(self.results.collinear[i]) - if all([distance1_bounds, distance2_bounds, dihed_bounds, dihed_variance, not_collinear]): + if all( + [ + distance1_bounds, + distance2_bounds, + dihed_bounds, + dihed_variance, + not_collinear, + ] + ): self.results.valid[i] = True -def _find_host_angle(g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature): - h0_eval = EvaluateHAtoms1(g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) +def _find_host_angle( + g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature +): + h0_eval = EvaluateHAtoms1( + g0g1g2_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + ) h0_eval.run() for i, valid_h0 in enumerate(h0_eval.results.valid): if valid_h0: g1g2h0_atoms = g0g1g2_atoms.atoms[1:] + host_atom_pool.atoms[i] - h1_eval = EvaluateHAtoms1(g1g2h0_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) + h1_eval = EvaluateHAtoms1( + g1g2h0_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + ) for j, valid_h1 in enumerate(h1_eval.results.valid): g2h0h1_atoms = g1g2h0_atoms.atoms[1:] + host_atom_pool.atoms[j] - h2_eval = EvaluateHAtoms2(g2h0h1_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) + h2_eval = EvaluateHAtoms2( + g2h0h1_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + ) if any(h2_eval.ressults.valid): d1_avgs = [d.mean() for d in h2_eval.results.distances1] @@ -616,13 +633,15 @@ def find_boresch_restraint( guest_idxs: list[int], host_idxs: list[int], guest_restraint_atom_idxs: Optional[list[int]] = None, - host_restraint_atoms_idxs Optional[list[int]] = None, - host_selection: str = 'all', + host_restraint_atoms_idxs: Optional[list[int]] = None, + host_selection: str = "all", dssp_filter: bool = False, rmsf_custoff: unit.Quantity = 0.1 * unit.nanometer, host_min_distance: unit.Quantity = 1 * unit.nanometer, host_max_distance: unit.Quantity = 3 * unit.nanometer, - angle_force_constant: unit.Quantity = 83.68 * unit.kilojoule_per_mole / unit.radians**2, + angle_force_constant: unit.Quantity = ( + 83.68 * unit.kilojoule_per_mole / unit.radians**2 + ), temperature: unit.Quantity = 298.15 * unit.kelvin, ) -> BoreschRestraintGeometry: """ @@ -640,11 +659,13 @@ def find_boresch_restraint( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) u.trajectory[-1] # Work with the final frame - if (guest_restraint_atoms_idxs is not None) and (host_restraint_atoms_idxs is not None): + if (guest_restraint_atoms_idxs is not None) and ( + host_restraint_atoms_idxs is not None + ): # In this case assume the picked atoms were intentional / representative # of the input and go with it guest_ag = u.select_atoms[guest_idxs] @@ -654,7 +675,9 @@ def find_boresch_restraint( # TODO sort out the return on this return BoreschRestraintGeometry(host_atoms=host_angle, guest_atoms=guest_angle) - if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): + if (guest_restraint_atoms_idxs is not None) ^ ( + host_restraint_atoms_idxs is not None + ): # This is not an intended outcome, crash out here errmsg = ( "both ``guest_restraints_atoms_idxs`` and ``host_restraint_atoms_idxs`` " diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 96a665ee5..91ce61e8b 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -32,7 +32,9 @@ ANGLE_FRC_CONSTANT_TYPE = FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] -def _get_mda_coord_format(coordinates: Union[str, npt.NDArray]) -> Optional[MemoryReader]: +def _get_mda_coord_format( + coordinates: Union[str, npt.NDArray] +) -> Optional[MemoryReader]: """ Helper to set the coordinate format to MemoryReader if the coordinates are an NDArray. @@ -51,7 +53,10 @@ def _get_mda_coord_format(coordinates: Union[str, npt.NDArray]) -> Optional[Memo else: return None -def _get_mda_topology_format(topology: Union[str, openmm.app.Topology]) -> Optional[str]: + +def _get_mda_topology_format( + topology: Union[str, openmm.app.Topology] +) -> Optional[str]: """ Helper to set the topology format to OPENMMTOPOLOGY if the topology is an openmm.app.Topology. @@ -59,7 +64,7 @@ def _get_mda_topology_format(topology: Union[str, openmm.app.Topology]) -> Optio Parameters ---------- topology : Union[str, openmm.app.Topology] - + Returns ------- @@ -177,7 +182,7 @@ def is_collinear(positions, atoms, dimensions=None, threshold=0.9): atoms : list[int] The indices of the atoms to test. dimensions : Optional[npt.NDArray] - The dimensions of the system to minimize vectors. + The dimensions of the system to minimize vectors. threshold : float Atoms are not collinear if their sequential vector separation dot products are less than ``threshold``. Default 0.9. @@ -280,10 +285,11 @@ def check_dihedral_bounds( def check_angular_variance( - angles: ArrayQuantity["radians"], width: FloatQuantity["radians"], - upper_bound: FloatQuantity['radians'], - lower_bound: FloatQuantity['radians'], - width: FloatQuantity['radians'], + angles: ArrayQuantity["radians"], + width: FloatQuantity["radians"], + upper_bound: FloatQuantity["radians"], + lower_bound: FloatQuantity["radians"], + width: FloatQuantity["radians"], ) -> bool: """ Check that the variance of a list of ``angles`` does not exceed @@ -309,7 +315,7 @@ def check_angular_variance( variance = circvar( angles.to("radians").m, high=upper_bound.to("radians").m, - low=lower_bound.to("radians").m + low=lower_bound.to("radians").m, ) return not (variance * unit.radians > width) From 033a1e44aadfbf330531b0443842b6f21095e6ec Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 14 Dec 2024 01:21:39 +0000 Subject: [PATCH 22/33] various fixes --- .../restraints/geometry/__init__.py | 4 + .../openmm_utils/restraints/geometry/base.py | 5 - .../restraints/geometry/boresch.py | 170 +++++++++++------- .../restraints/geometry/flatbottom.py | 16 +- .../restraints/geometry/harmonic.py | 33 ++-- .../openmm_utils/restraints/geometry/utils.py | 56 +++--- .../restraints/openmm/omm_forces.py | 2 +- .../restraints/openmm/omm_restraints.py | 108 +++++++---- 8 files changed, 236 insertions(+), 158 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/__init__.py b/openfe/protocols/openmm_utils/restraints/geometry/__init__.py index e69de29bb..1c1b4c56a 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/__init__.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/__init__.py @@ -0,0 +1,4 @@ +from .base import BaseRestraintGeometry +from .harmonic import DistanceRestraintGeometry +from .flatbottom import FlatBottomDistanceGeometry +from .boresch import BoreschRestraintGeometry diff --git a/openfe/protocols/openmm_utils/restraints/geometry/base.py b/openfe/protocols/openmm_utils/restraints/geometry/base.py index 21a714cde..5db9225ac 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/base.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/base.py @@ -10,10 +10,6 @@ import abc from pydantic.v1 import BaseModel, validator -from openff.units import unit -import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds, calc_angles - class BaseRestraintGeometry(BaseModel, abc.ABC): class Config: @@ -47,4 +43,3 @@ def positive_idxs(cls, v): errmsg = "negative indices passed" raise ValueError(errmsg) return v - diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index 844817f52..0d6806611 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -7,21 +7,34 @@ ---- * Add relevant duecredit entries. """ -import abc import pathlib -from pydantic.v1 import BaseModel, validator +from typing import Union, Optional, Iterable from rdkit import Chem +import openmm from openff.units import unit import MDAnalysis as mda -from MDANalysis.analysis.base import AnalysisBase +from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals import numpy as np import numpy.typing as npt from scipy.stats import circmean from .base import HostGuestRestraintGeometry +from .utils import ( + _get_mda_coord_format, + _get_mda_topology_format, + get_aromatic_rings, + get_heavy_atom_idxs, + get_central_atom_idx, + is_collinear, + check_angular_variance, + check_dihedral_bounds, + check_angle_not_flat, + FindHostAtoms, + get_local_rmsf +) class BoreschRestraintGeometry(HostGuestRestraintGeometry): @@ -60,8 +73,8 @@ def get_bond_distance( format=_get_mda_coord_format(coordinates), topology_format=_get_mda_topology_format(topology), ) - at1 = u.atoms[host_atoms[0]] - at2 = u.atoms[guest_atoms[0]] + at1 = u.atoms[self.host_atoms[0]] + at2 = u.atoms[self.guest_atoms[0]] bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) # convert to float so we avoid having a np.float64 return float(bond) * unit.angstrom @@ -87,10 +100,10 @@ def get_angles( format=_get_mda_coord_format(coordinates), topology_format=_get_mda_topology_format(topology), ) - at1 = u.atoms[host_atoms[1]] - at2 = u.atoms[host_atoms[0]] - at3 = u.atoms[guest_atoms[0]] - at4 = u.atoms[guest_atoms[1]] + at1 = u.atoms[self.host_atoms[1]] + at2 = u.atoms[self.host_atoms[0]] + at3 = u.atoms[self.guest_atoms[0]] + at4 = u.atoms[self.guest_atoms[1]] angleA = calc_angles( at1.position, at2.position, at3.position, u.atoms.dimensions @@ -121,21 +134,24 @@ def get_dihedrals( format=_get_mda_coord_format(coordinates), topology_format=_get_mda_topology_format(topology), ) - at1 = u.atoms[host_atoms[2]] - at2 = u.atoms[host_atoms[1]] - at3 = u.atoms[host_atoms[0]] - at4 = u.atoms[guest_atoms[0]] - at5 = u.atoms[guest_atoms[1]] - at6 = u.atoms[guest_atoms[2]] + at1 = u.atoms[self.host_atoms[2]] + at2 = u.atoms[self.host_atoms[1]] + at3 = u.atoms[self.host_atoms[0]] + at4 = u.atoms[self.guest_atoms[0]] + at5 = u.atoms[self.guest_atoms[1]] + at6 = u.atoms[self.guest_atoms[2]] dihA = calc_dihedrals( - at1.position, at2.position, at3.position, at4.position, u.atoms.dimensions + at1.position, at2.position, at3.position, at4.position, + box=u.dimensions ) dihB = calc_dihedrals( - at2.position, at3.position, at4.position, at5.position, u.atoms.dimensions + at2.position, at3.position, at4.position, at5.position, + box=u.dimensions ) dihC = calc_dihedrals( - at3.position, at4.position, at5.position, at6.position, u.atoms.dimensions + at3.position, at4.position, at5.position, at6.position, + box=u.dimensions ) return dihA, dihB, dihC @@ -213,7 +229,11 @@ def _get_bonded_angles_from_pool( return angles -def _get_atom_pool(rdmol: Chem.Mol, rmsf: npt.NDArray) -> Optional[set[int]]: +def _get_atom_pool( + rdmol: Chem.Mol, + rmsf: npt.NDArray, + rmsf_cutoff: unit.Quantity +) -> Optional[set[int]]: """ Filter atoms based on rmsf & rings, defaulting to heavy atoms if there are not enough. @@ -291,8 +311,8 @@ def get_guest_atom_candidates( """ u = mda.Universe( topology, - coordinates, - format=_get_mda_coord_format(coordinates), + trajectory, + format=_get_mda_coord_format(trajectory), topology_format=_get_mda_topology_format(topology), ) @@ -314,18 +334,22 @@ def get_guest_atom_candidates( center = get_central_atom_idx(rdmol) # 3. Sort the atom pool based on their distance from the center - sorted_anchor_pool = _sort_by_distance_from_atom(rdmol, center, anchor_pool) + sorted_atom_pool = _sort_by_distance_from_atom(rdmol, center, atom_pool) # 4. Get a list of probable angles angles_list = [] - for atom in sorted_anchor_pool: - angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) - for angle in _angles: + for atom in sorted_atom_pool: + angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_atom_pool) + for angle in angles: # Check that the angle is at least not collinear angle_ag = ligand_ag.atoms[list(angle)] if not is_collinear(ligand_ag.positions, angle, u.dimensions): angles_list.append( - (angle_ag.atoms[0].ix, angle_ag.atoms[1].ix, angle_ag.atoms[2].ix) + ( + angle_ag.atoms[0].ix, + angle_ag.atoms[1].ix, + angle_ag.atoms[2].ix + ) ) return angles_list @@ -373,26 +397,29 @@ def get_host_atom_candidates( """ u = mda.Universe( topology, - coordinates, - format=_get_mda_coord_format(coordinates), + trajectory, + format=_get_mda_coord_format(trajectory), topology_format=_get_mda_topology_format(topology), ) - protein_ag1 = u.atoms[host_idxs] - protein_ag2 = protein_ag.select_atoms(protein_selection) + host_ag1 = u.atoms[host_idxs] + host_ag2 = host_ag1.select_atoms(host_selection) # 0. TODO: implement DSSP filter - # Should be able to just call MDA's DSSP method, but will need to catch an exception + # Should be able to just call MDA's DSSP method + # but will need to catch an exception if dssp_filter: - raise NotImplementedError("DSSP filtering is not currently implemented") + raise NotImplementedError( + "DSSP filtering is not currently implemented" + ) # 1. Get the RMSF & filter - rmsf = get_local_rmsf(sub_protein_ag) - protein_ag3 = sub_protein_ag.atoms[rmsf[heavy_atoms] < rmsf_cutoff] + rmsf = get_local_rmsf(host_ag2) + protein_ag3 = host_ag2.atoms[rmsf < rmsf_cutoff] # 2. Search of atoms within the min/max cutoff atom_finder = FindHostAtoms( - protein_ag3, u.atoms[l1_idx], min_search_distance, max_search_distance + protein_ag3, u.atoms[l1_idx], min_distance, max_distance ) atom_finder.run() return atom_finder.results.host_idxs @@ -439,9 +466,15 @@ def __init__( self.temperature = temperature def _prepare(self): - self.results.distances = np.zeros((len(self.host_atom_pool), self.n_frames)) - self.results.angles = np.zeros((len(self.host_atom_pool), self.n_frames)) - self.results.dihedrals = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.distances = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.angles = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.dihedrals = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) self.results.collinear = np.empty( (len(self.host_atom_pool), self.n_frames), dtype=bool, @@ -517,7 +550,7 @@ def _conclude(self): self.results.valid[i] = True -class EvaluateHostAtoms2(EvaluateH21Atoms): +class EvaluateHostAtoms2(EvaluateHostAtoms1): def _prepare(self): self.results.distances1 = np.zeros((len(self.host_atom_pool), self.n_frames)) self.results.ditances2 = np.zeros((len(self.host_atom_pool), self.n_frames)) @@ -554,8 +587,8 @@ def _single_frame(self): positions=np.vstack((at.position, self.reference.positions)), dimensions=self.reference.dimensions, ) - self.results.distances1[i][self._frame_index] = distance - self.results.distances2[i][self._frame_index] = angle + self.results.distances1[i][self._frame_index] = distance1 + self.results.distances2[i][self._frame_index] = distance2 self.results.dihedrals[i][self._frame_index] = dihedral self.results.collinear[i][self._frame_index] = collinear @@ -585,9 +618,13 @@ def _conclude(self): def _find_host_angle( - g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature + g0g1g2_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature ): - h0_eval = EvaluateHAtoms1( + h0_eval = EvaluateHostAtoms1( g0g1g2_atoms, host_atom_pool, minimum_distance, @@ -599,7 +636,7 @@ def _find_host_angle( for i, valid_h0 in enumerate(h0_eval.results.valid): if valid_h0: g1g2h0_atoms = g0g1g2_atoms.atoms[1:] + host_atom_pool.atoms[i] - h1_eval = EvaluateHAtoms1( + h1_eval = EvaluateHostAtoms1( g1g2h0_atoms, host_atom_pool, minimum_distance, @@ -608,7 +645,7 @@ def _find_host_angle( ) for j, valid_h1 in enumerate(h1_eval.results.valid): g2h0h1_atoms = g1g2h0_atoms.atoms[1:] + host_atom_pool.atoms[j] - h2_eval = EvaluateHAtoms2( + h2_eval = EvaluateHostAtoms2( g2h0h1_atoms, host_atom_pool, minimum_distance, @@ -632,11 +669,11 @@ def find_boresch_restraint( guest_rdmol: Chem.Mol, guest_idxs: list[int], host_idxs: list[int], - guest_restraint_atom_idxs: Optional[list[int]] = None, + guest_restraint_atoms_idxs: Optional[list[int]] = None, host_restraint_atoms_idxs: Optional[list[int]] = None, host_selection: str = "all", dssp_filter: bool = False, - rmsf_custoff: unit.Quantity = 0.1 * unit.nanometer, + rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, host_min_distance: unit.Quantity = 1 * unit.nanometer, host_max_distance: unit.Quantity = 3 * unit.nanometer, angle_force_constant: unit.Quantity = ( @@ -657,32 +694,35 @@ def find_boresch_restraint( """ u = mda.Universe( topology, - coordinates, - format=_get_mda_coord_format(coordinates), + trajectory, + format=_get_mda_coord_format(trajectory), topology_format=_get_mda_topology_format(topology), ) u.trajectory[-1] # Work with the final frame - if (guest_restraint_atoms_idxs is not None) and ( - host_restraint_atoms_idxs is not None - ): - # In this case assume the picked atoms were intentional / representative - # of the input and go with it + if (guest_restraint_atoms_idxs is not None) and (host_restraint_atoms_idxs is not None): # fmt: skip + # In this case assume the picked atoms were intentional / + # representative of the input and go with it guest_ag = u.select_atoms[guest_idxs] - guest_angle = [at.ix for at in guest_ag.atoms[guest_restraint_atom_idxs]] + guest_angle = [ + at.ix for at in guest_ag.atoms[guest_restraint_atoms_idxs] + ] host_ag = u.select_atoms[host_idxs] - host_angle = [at.ix for at in host_ag.atoms[host_restraint_atoms_idxs]] + host_angle = [ + at.ix for at in host_ag.atoms[host_restraint_atoms_idxs] + ] # TODO sort out the return on this - return BoreschRestraintGeometry(host_atoms=host_angle, guest_atoms=guest_angle) + return BoreschRestraintGeometry( + host_atoms=host_angle, guest_atoms=guest_angle + ) - if (guest_restraint_atoms_idxs is not None) ^ ( - host_restraint_atoms_idxs is not None - ): + if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): # fmt: skip # This is not an intended outcome, crash out here errmsg = ( - "both ``guest_restraints_atoms_idxs`` and ``host_restraint_atoms_idxs`` " + "both ``guest_restraints_atoms_idxs`` and " + "``host_restraint_atoms_idxs`` " "must be set or both must be None. " - f"Got {guest_restraint_atoms_idxs} and {host_atoms_restraint_atoms_idxs}" + f"Got {guest_restraint_atoms_idxs} and {host_restraint_atoms_idxs}" ) raise ValueError(errmsg) @@ -710,7 +750,7 @@ def find_boresch_restraint( l1_idx=guest_angle[0], host_selection=host_selection, dssp_filter=dssp_filter, - rmsf_cutoff=rmsf_custoff, + rmsf_cutoff=rmsf_cutoff, min_distance=host_min_distance, max_distance=host_max_distance, ) @@ -732,4 +772,6 @@ def find_boresch_restraint( errmsg = "No suitable host atoms could be found" raise ValueError(errmsg) - return BoreschRestraintGeometry(host_atoms=host_angle, guest_atoms=guest_angle) + return BoreschRestraintGeometry( + host_atoms=host_angle, guest_atoms=guest_angle + ) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py b/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py index c7e987736..c9007dd59 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py @@ -7,14 +7,15 @@ ---- * Add relevant duecredit entries. """ -import abc -from pydantic.v1 import BaseModel, validator - +import pathlib +from typing import Union, Optional import numpy as np +from openmm import app from openff.units import unit +from openff.models.types import FloatQuantity import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase -from MDAnalysis.lib.distances import calc_bonds, calc_angles +from MDAnalysis.lib.distances import calc_bonds from .harmonic import ( DistanceRestraintGeometry, @@ -27,7 +28,6 @@ class FlatBottomDistanceGeometry(DistanceRestraintGeometry): A geometry class for a flat bottom distance restraint between two groups of atoms. """ - well_radius: FloatQuantity["nanometer"] @@ -45,8 +45,8 @@ class COMDistanceAnalysis(AnalysisBase): _analysis_algorithm_is_parallelizable = False - def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): - super().__init__(host_atoms.universe.trajectory, **kwargs) + def __init__(self, group1, group2, **kwargs): + super().__init__(group1.universe.trajectory, **kwargs) self.ag1 = group1 self.ag2 = group2 @@ -67,7 +67,7 @@ def _conclude(self): def get_flatbottom_distance_restraint( - topology: Union[str, openmm.app.Topology], + topology: Union[str, app.Topology], trajectory: pathlib.Path, topology_format: Optional[str] = None, host_atoms: Optional[list[int]] = None, diff --git a/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py b/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py index 36e7a61a7..770f86bcb 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py @@ -7,12 +7,12 @@ ---- * Add relevant duecredit entries. """ -import abc -from pydantic.v1 import BaseModel, validator - +import pathlib +from typing import Union, Optional +from openmm import app from openff.units import unit import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds, calc_angles +from MDAnalysis.lib.distances import calc_bonds from rdkit import Chem from .base import HostGuestRestraintGeometry @@ -50,7 +50,7 @@ def _get_selection(universe, atom_list, selection): def get_distance_restraint( - topology: Union[str, openmm.app.Topology], + topology: Union[str, app.Topology], trajectory: pathlib.Path, topology_format: Optional[str] = None, host_atoms: Optional[list[int]] = None, @@ -61,34 +61,25 @@ def get_distance_restraint( u = mda.Universe(topology, trajectory, topology_format=topology_format) guest_ag = _get_selection(u, guest_atoms, guest_selection) + guest_atoms = [a.ix for a in guest_ag] host_ag = _get_selection(u, host_atoms, host_selection) + host_atoms = [a.ix for a in host_ag] - return DistanceRestraintGeometry(guest_atoms=guest_atoms, host_atoms=host_atoms) + return DistanceRestraintGeometry( + guest_atoms=guest_atoms, host_atoms=host_atoms + ) def get_molecule_centers_restraint( - topology: Union[str, openmm.app.Topology], - trajectory: pathlib.Path, molA_rdmol: Chem.Mol, molB_rdmol: Chem.Mol, molA_idxs: list[int], molB_idxs: list[int], - topology_format: Optional[str] = None, ): # We assume that the mol idxs are ordered centerA = molA_idxs[_get_central_atom_idx(molA_rdmol)] centerB = molB_idxs[_get_central_atom_idx(molB_rdmol)] - u = mda.Universe(topology, trajectory, topology_format=topology_format) - guest_ag = _get_selection( - u, - [centerA], - None, - ) - guest_ag = _get_selection( - u, - [centerB], - None, + return DistanceRestraintGeometry( + guest_atoms=[centerA], host_atoms=[centerB] ) - - return DistsanceRestraintGeometry(guest_atoms=guest_atoms, host_atoms=host_atoms) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 91ce61e8b..7d6906650 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -7,29 +7,26 @@ ---- * Add relevant duecredit entries. """ -import abc -from pydantic.v1 import BaseModel, validator - +from typing import Union, Optional import numpy as np import numpy.typing as npt -from scipy.stats import circvar, circmean, circstd +from scipy.stats import circvar +import openmm from openff.toolkit import Molecule as OFFMol from openff.units import unit -from openff.models.types import FloatQuantity, ArrayQuantity import networkx as nx from rdkit import Chem import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase -from MDAnalysis.analysis.rmsf import RMSF -from MDAnalysis.lib.distances import calc_bonds, calc_angles, minimize_vectors +from MDAnalysis.analysis.rms import RMSF +from MDAnalysis.lib.distances import minimize_vectors, capped_distance from MDAnalysis.coordinates.memory import MemoryReader from openfe_analysis.transformations import Aligner, NoJump DEFAULT_ANGLE_FRC_CONSTANT = 83.68 * unit.kilojoule_per_mole / unit.radians**2 -ANGLE_FRC_CONSTANT_TYPE = FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] def _get_mda_coord_format( @@ -97,7 +94,7 @@ def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: aromatic_rings = [] for ring in ringinfo.AtomRings(): - if all(a in aroms for a in ring): + if all(a in arom_idxs for a in ring): aromatic_rings.append(ring) return aromatic_rings @@ -190,13 +187,14 @@ def is_collinear(positions, atoms, dimensions=None, threshold=0.9): Returns ------- result : bool - Returns True if any sequential pair of vectors is collinear; False otherwise. + Returns True if any sequential pair of vectors is collinear; + False otherwise. Notes ----- Originally from Yank. """ - results = False + result = False for i in range(len(atoms) - 2): v1 = minimize_vectors( positions[atoms[i + 1], :] - positions[atoms[i], :], @@ -214,9 +212,9 @@ def is_collinear(positions, atoms, dimensions=None, threshold=0.9): def check_angle_not_flat( - angle: FloatQuantity["radians"], - force_constant: ANGLE_FRC_CONSTANT_TYPE = DEFAULT_ANGLE_FRC_CONSTANT, - temperature: FloatQuantity["kelvin"] = 298.15 * unit.kelvin, + angle: unit.Quantity, + force_constant: unit.Quantity = DEFAULT_ANGLE_FRC_CONSTANT, + temperature: unit.Quantity = 298.15 * unit.kelvin, ) -> bool: """ Check whether the chosen angle is less than 10 kT from 0 or pi radians @@ -246,8 +244,8 @@ def check_angle_not_flat( RT = 8.31445985 * 0.001 * temp_kelvin # check if angle is <10kT from 0 or 180 - check1 = 0.5 * frc_const * np.power((angle - 0.0), 2) - check2 = 0.5 * frc_const * np.power((angle - np.pi), 2) + check1 = 0.5 * frc_const * np.power((angle_rads - 0.0), 2) + check2 = 0.5 * frc_const * np.power((angle_rads - np.pi), 2) ang_check_1 = check1 / RT ang_check_2 = check2 / RT if ang_check_1 < 10.0 or ang_check_2 < 10.0: @@ -256,9 +254,9 @@ def check_angle_not_flat( def check_dihedral_bounds( - dihedral: FloatQuantity["radians"], - lower_cutoff: FloatQuantity["radians"] = 2.618 * unit.radians, - upper_cutoff: FloatQuantity["radians"] = -2.618 * unit.radians, + dihedral: unit.Quantity, + lower_cutoff: unit.Quantity = 2.618 * unit.radians, + upper_cutoff: unit.Quantity = -2.618 * unit.radians, ) -> bool: """ Check that a dihedral does not exceed the bounds set by @@ -285,11 +283,10 @@ def check_dihedral_bounds( def check_angular_variance( - angles: ArrayQuantity["radians"], - width: FloatQuantity["radians"], - upper_bound: FloatQuantity["radians"], - lower_bound: FloatQuantity["radians"], - width: FloatQuantity["radians"], + angles: unit.Quantity, + upper_bound: unit.Quantity, + lower_bound: unit.Quantity, + width: unit.Quantity, ) -> bool: """ Check that the variance of a list of ``angles`` does not exceed @@ -299,12 +296,13 @@ def check_angular_variance( ---------- angles : ArrayLike[unit.Quantity] An array of angles in units compatible with radians. - upper_bound: FloatQuantity['radians'] - The upper bound in the angle range. - lower_bound: FloatQuantity['radians'] - The lower bound in the angle range. + upper_bound: unit.Quantity + The upper bound in the angle range in radians compatible units. + lower_bound: unit.Quantity + The lower bound in the angle range in radians compatible units. width : unit.Quantity - The width to check the variance against, in units compatible with radians. + The width to check the variance against, in units compatible with + radians. Returns ------- diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py index 3ad9d0aa6..9c288515d 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py @@ -44,7 +44,7 @@ def get_periodic_boresch_energy_function( def get_custom_compound_bond_force( - n_particles: int = 6, energy_function: str = BORESCH_ENERGY_FUNCTION + energy_function: str, n_particles: int = 6, ): """ Return an OpenMM CustomCompoundForce diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py index a3fe777d3..2b8898a22 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py @@ -15,7 +15,6 @@ * Add Periodic Torsion Boresch class """ import abc -from typing import Optional, Union, Callable import numpy as np import openmm @@ -31,11 +30,16 @@ from openff.units import unit from gufe.settings.models import SettingsBaseModel -from openfe.protocols.openmm_utils.omm_forces import ( + +from openfe.protocols.openmm_utils.restraints.geometry import ( + BaseRestraintGeometry, + DistanceRestraintGeometry, + BoreschRestraintGeometry +) +from .omm_forces import ( get_custom_compound_bond_force, add_force_in_separate_group, get_boresch_energy_function, - get_periodic_boresch_energy_function, ) @@ -49,7 +53,8 @@ class RestraintParameterState(GlobalParameterState): ---------- parameters_name_suffix : Optional[str] If specified, the state will control a modified version of the parameter - ``lambda_restraints_{parameters_name_suffix}` instead of just ``lambda_restraints``. + ``lambda_restraints_{parameters_name_suffix}` instead of just + ``lambda_restraints``. lambda_restraints : Optional[float] The strength of the restraint. If defined, must be between 0 and 1. @@ -66,7 +71,8 @@ class RestraintParameterState(GlobalParameterState): def lambda_restraints(self, instance, new_value): if new_value is not None and not (0.0 <= new_value <= 1.0): errmsg = ( - "lambda_restraints must be between 0.0 and 1.0, " f"got {new_value}" + "lambda_restraints must be between 0.0 and 1.0 " + f"and got {new_value}" ) raise ValueError(errmsg) # Not crashing out on None to match upstream behaviour @@ -101,11 +107,19 @@ def _verify_geometry(self, geometry): pass @abc.abstractmethod - def add_force(self, thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry): + def add_force( + self, + thermodynamic_state: ThermodynamicState, + geometry: BaseRestraintGeometry + ): pass @abc.abstractmethod - def get_standard_state_correction(self, thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry): + def get_standard_state_correction( + self, + thermodynamic_state: ThermodynamicState, + geometry: BaseRestraintGeometry + ): pass @abc.abstractmethod @@ -118,8 +132,8 @@ def _verify_geometry(self, geometry: BaseRestraintGeometry): if len(geometry.host_atoms) != 1 or len(geometry.guest_atoms) != 1: errmsg = ( "host_atoms and guest_atoms must only include a single index " - f"each, got {len(host_atoms)} and " - f"{len(guest_atoms)} respectively." + f"each, got {len(geometry.host_atoms)} and " + f"{len(geometry.guest_atoms)} respectively." ) raise ValueError(errmsg) super()._verify_geometry(geometry) @@ -127,19 +141,25 @@ def _verify_geometry(self, geometry: BaseRestraintGeometry): class BaseRadiallySymmetricRestraintForce(BaseHostGuestRestraints): def _verify_inputs(self) -> None: - if not isinstance(self.settings, BaseDistanceRestraintSettings): + if not isinstance(self.settings, DistanceRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" raise ValueError(errmsg) - def _verify_geometry(self, geometry: DistanceRestraintGeometry) + def _verify_geometry(self, geometry: DistanceRestraintGeometry): if not isinstance(geometry, DistanceRestraintGeometry): errmsg = f"Incorrect geometry class type {geometry} passed through" raise ValueError(errmsg) - def add_force(self, thermodynamic_state: ThermodynamicState, geometry: DistanceRestraintGeometry) -> None: + def add_force( + self, + thermodynamic_state: ThermodynamicState, + geometry: DistanceRestraintGeometry + ) -> None: self._verify_geometry(geometry) force = self._get_force(geometry) - force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) + force.setUsesPeriodicBoundaryConditions( + thermodynamic_state.is_periodic + ) # Note .system is a call to get_system() so it's returning a copy system = thermodynamic_state.system add_force_in_separate_group(system, force) @@ -162,9 +182,13 @@ def _get_force(self, geometry: DistanceRestraintGeometry): raise NotImplementedError("only implemented in child classes") -class HarmonicBondRestraint(BaseRadiallySymmetricRestraintForce, SingleBondMixin): +class HarmonicBondRestraint( + BaseRadiallySymmetricRestraintForce, SingleBondMixin +): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: - spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintBondForce( spring_constant=spring_constant, restrained_atom_index1=geometry.host_atoms[0], @@ -173,10 +197,16 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: ) -class FlatBottomBondRestraint(BaseRadiallySymmetricRestraintForce, SingleBondMixin): +class FlatBottomBondRestraint( + BaseRadiallySymmetricRestraintForce, SingleBondMixin +): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: - spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm( + geometry.well_radius + ).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, @@ -188,7 +218,9 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: class CentroidHarmonicRestraint(BaseRadiallySymmetricRestraintForce): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: - spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintForce( spring_constant=spring_constant, restrained_atom_index1=geometry.host_atoms, @@ -199,9 +231,13 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: class CentroidFlatBottomRestraint(BaseRadiallySymmetricRestraintForce): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: - spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) - return FlatBottomRestraintBondForce( + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm( + geometry.well_radius + ).value_in_unit_system(omm_unit.md_unit_system) + return FlatBottomRestraintForce( spring_constant=spring_constant, well_radius=well_radius, restrained_atom_index1=geometry.host_atoms, @@ -221,10 +257,16 @@ def _verify_geometry(self, geometry: BoreschRestraintGeometry): errmsg = f"Incorrect geometry class type {geometry} passed through" raise ValueError(errmsg) - def add_force(self, thermodynamic_state: ThermodynamicState, geometry: BoreschRestraintGeometry) -> None: - _verify_geometry(geometry) + def add_force( + self, + thermodynamic_state: ThermodynamicState, + geometry: BoreschRestraintGeometry + ) -> None: + self._verify_geometry(geometry) force = self._get_force(geometry) - force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) + force.setUsesPeriodicBoundaryConditions( + thermodynamic_state.is_periodic + ) # Note .system is a call to get_system() so it's returning a copy system = thermodynamic_state.system add_force_in_separate_group(system, force) @@ -236,7 +278,7 @@ def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: ) force = get_custom_compound_bond_force( - n_particles=6, energy_function=efunc + energy_function=efunc, n_particles=6, ) param_values = [] @@ -256,7 +298,9 @@ def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: 'phi_C0': geometry.phi_C0, } for key, val in parameter_dict.items(): - param_values.append(to_openmm(val).value_in_unit_system(omm_unit.md_unit_system)) + param_values.append( + to_openmm(val).value_in_unit_system(omm_unit.md_unit_system) + ) force.addPerBondParameter(key) force.addGlobalParameter(self.controlling_parameter_name, 1.0) @@ -264,7 +308,9 @@ def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: return force def get_standard_state_correction( - self, thermodynamic_state: ThermodynamicState, geometry: BoreschRestraintGeometry + self, + thermodynamic_state: ThermodynamicState, + geometry: BoreschRestraintGeometry ) -> unit.Quantity: self._verify_geometry(geometry) @@ -279,14 +325,16 @@ def get_standard_state_correction( # restraint energies K_r = self.settings.K_r.to('kilojoule_per_mole / nm ** 2') K_thetaA = self.settings.K_thetaA.to('kilojoule_per_mole / radians ** 2') - k_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole / radians ** 2') + K_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole / radians ** 2') K_phiA = self.settings.K_phiA.to('kilojoule_per_mole / radians ** 2') K_phiB = self.settings.K_phiB.to('kilojoule_per_mole / radians ** 2') K_phiC = self.settings.K_phiC.to('kilojoule_per_mole / radians ** 2') numerator1 = 8.0 * (np.pi**2) * StandardV denum1 = (r_aA0**2) * sin_thetaA0 * sin_thetaB0 - numerator2 = np.sqrt(K_r * K_thetaA * K_thetaB * K_phiA * K_phiB * K_phiC) + numerator2 = np.sqrt( + K_r * K_thetaA * K_thetaB * K_phiA * K_phiB * K_phiC + ) denum2 = (2.0 * np.pi * kt)**3 dG = -kt * np.log((numerator1/denum1) * (numerator2/denum2)) From fe1308ee4beff476e6d5e0cb967991fa538cb1c1 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 14 Dec 2024 02:00:19 +0000 Subject: [PATCH 23/33] docstring drive --- .../restraints/openmm/omm_forces.py | 13 +++++ .../restraints/openmm/omm_restraints.py | 48 ++++++++++++++++++- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py index 9c288515d..52cbbec98 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py @@ -14,6 +14,19 @@ def get_boresch_energy_function( control_parameter: str, ) -> str: + """ + Return a Boresch-style energy function for a CustomCompoundForce. + + Parameters + ---------- + control_parameter : str + A string for the lambda scaling control parameter + + Returns + ------- + str + The energy function string. + """ energy_function = ( f"{control_parameter} * E; " "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py index 2b8898a22..18f9f2f34 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py @@ -96,14 +96,21 @@ def __init__( controlling_parameter_name: str = "lambda_restraints", ): self.settings = restraint_settings + self.controlling_parameter_name = controlling_parameter_name self._verify_settings() @abc.abstractmethod def _verify_settings(self): + """ + Method for validating the settings passed on object construction. + """ pass @abc.abstractmethod def _verify_geometry(self, geometry): + """ + Method for validating that the geometry object passed is correct. + """ pass @abc.abstractmethod @@ -112,6 +119,18 @@ def add_force( thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry ): + """ + Method for in-place adding a force to the System of a + ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + """ pass @abc.abstractmethod @@ -119,7 +138,24 @@ def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry - ): + ) -> unit.Quantity: + """ + Get the standard state correction for the Force. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + + Returns + ------- + correction : unit.Quantity + The standard state correction free energy in units compatible + with kilojoule per mole. + """ pass @abc.abstractmethod @@ -304,7 +340,15 @@ def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: force.addPerBondParameter(key) force.addGlobalParameter(self.controlling_parameter_name, 1.0) - force.addBond(geometry.host_atoms + geometry.guest_atoms, param_values) + atoms = [ + geometry.host_atoms[2], + geometry.host_atoms[1], + geometry.host_atoms[0], + geometry.guest_atoms[0], + geometry.guest_atoms[1], + geometry.guest_atoms[2], + ] + force.addBond(atoms, param_values) return force def get_standard_state_correction( From d71b9616055f95c4211bc460afead1446430115e Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sun, 15 Dec 2024 22:59:44 +0000 Subject: [PATCH 24/33] Migrate to restraint_utils --- .../restraints/geometry/harmonic.py | 85 ---- .../__init__.py | 0 .../geometry/__init__.py | 0 .../geometry/base.py | 3 + .../geometry/boresch.py | 289 +++++++++++--- .../geometry/flatbottom.py | 51 ++- .../restraint_utils/geometry/harmonic.py | 144 +++++++ .../geometry/utils.py | 83 ++-- .../openmm/__init__.py | 0 .../openmm/omm_forces.py | 22 +- .../openmm/omm_restraints.py | 371 ++++++++++++++++-- openfe/tests/protocols/restraints/__init__.py | 0 .../restraints/test_geometry_base.py | 25 ++ .../restraints/test_omm_restraints.py | 31 ++ .../restraints/test_openmm_forces.py | 115 ++++++ 15 files changed, 1000 insertions(+), 219 deletions(-) delete mode 100644 openfe/protocols/openmm_utils/restraints/geometry/harmonic.py rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/__init__.py (100%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/geometry/__init__.py (100%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/geometry/base.py (94%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/geometry/boresch.py (74%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/geometry/flatbottom.py (58%) create mode 100644 openfe/protocols/restraint_utils/geometry/harmonic.py rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/geometry/utils.py (88%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/openmm/__init__.py (100%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/openmm/omm_forces.py (85%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/openmm/omm_restraints.py (50%) create mode 100644 openfe/tests/protocols/restraints/__init__.py create mode 100644 openfe/tests/protocols/restraints/test_geometry_base.py create mode 100644 openfe/tests/protocols/restraints/test_omm_restraints.py create mode 100644 openfe/tests/protocols/restraints/test_openmm_forces.py diff --git a/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py b/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py deleted file mode 100644 index 770f86bcb..000000000 --- a/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py +++ /dev/null @@ -1,85 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -""" -Restraint Geometry classes - -TODO ----- -* Add relevant duecredit entries. -""" -import pathlib -from typing import Union, Optional -from openmm import app -from openff.units import unit -import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds -from rdkit import Chem - -from .base import HostGuestRestraintGeometry -from .utils import _get_central_atom_idx - - -class DistanceRestraintGeometry(HostGuestRestraintGeometry): - """ - A geometry class for a distance restraint between two groups of atoms. - """ - - def get_distance(self, topology, coordinates) -> unit.Quantity: - u = mda.Universe(topology, coordinates) - ag1 = u.atoms[self.host_atoms] - ag2 = u.atoms[self.guest_atoms] - bond = calc_bonds( - ag1.center_of_mass(), ag2.center_of_mass(), box=u.atoms.dimensions - ) - # convert to float so we avoid having a np.float64 - return float(bond) * unit.angstrom - - -def _get_selection(universe, atom_list, selection): - if atom_list is None: - if selection is None: - raise ValueError( - "one of either the atom lists or selections must be defined" - ) - - ag = universe.select_atoms(selection) - else: - ag = universe.atoms[atom_list] - - return ag - - -def get_distance_restraint( - topology: Union[str, app.Topology], - trajectory: pathlib.Path, - topology_format: Optional[str] = None, - host_atoms: Optional[list[int]] = None, - guest_atoms: Optional[list[int]] = None, - host_selection: Optional[str] = None, - guest_selection: Optional[str] = None, -) -> DistanceRestraintGeometry: - u = mda.Universe(topology, trajectory, topology_format=topology_format) - - guest_ag = _get_selection(u, guest_atoms, guest_selection) - guest_atoms = [a.ix for a in guest_ag] - host_ag = _get_selection(u, host_atoms, host_selection) - host_atoms = [a.ix for a in host_ag] - - return DistanceRestraintGeometry( - guest_atoms=guest_atoms, host_atoms=host_atoms - ) - - -def get_molecule_centers_restraint( - molA_rdmol: Chem.Mol, - molB_rdmol: Chem.Mol, - molA_idxs: list[int], - molB_idxs: list[int], -): - # We assume that the mol idxs are ordered - centerA = molA_idxs[_get_central_atom_idx(molA_rdmol)] - centerB = molB_idxs[_get_central_atom_idx(molB_rdmol)] - - return DistanceRestraintGeometry( - guest_atoms=[centerA], host_atoms=[centerB] - ) diff --git a/openfe/protocols/openmm_utils/restraints/__init__.py b/openfe/protocols/restraint_utils/__init__.py similarity index 100% rename from openfe/protocols/openmm_utils/restraints/__init__.py rename to openfe/protocols/restraint_utils/__init__.py diff --git a/openfe/protocols/openmm_utils/restraints/geometry/__init__.py b/openfe/protocols/restraint_utils/geometry/__init__.py similarity index 100% rename from openfe/protocols/openmm_utils/restraints/geometry/__init__.py rename to openfe/protocols/restraint_utils/geometry/__init__.py diff --git a/openfe/protocols/openmm_utils/restraints/geometry/base.py b/openfe/protocols/restraint_utils/geometry/base.py similarity index 94% rename from openfe/protocols/openmm_utils/restraints/geometry/base.py rename to openfe/protocols/restraint_utils/geometry/base.py index 5db9225ac..0ca6ae200 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/base.py +++ b/openfe/protocols/restraint_utils/geometry/base.py @@ -12,6 +12,9 @@ class BaseRestraintGeometry(BaseModel, abc.ABC): + """ + A base class for a restraint geometry. + """ class Config: arbitrary_types_allowed = True diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/restraint_utils/geometry/boresch.py similarity index 74% rename from openfe/protocols/openmm_utils/restraints/geometry/boresch.py rename to openfe/protocols/restraint_utils/geometry/boresch.py index 0d6806611..6e740f48d 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/restraint_utils/geometry/boresch.py @@ -14,6 +14,7 @@ import openmm from openff.units import unit +from openff.models.types import FloatQuantity import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals @@ -51,107 +52,137 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): Where HX represents the X index of ``host_atoms`` and GX the X index of ``guest_atoms``. """ + r_aA0: FloatQuantity['nanometer'] + """ + The equilibrium distance between H0 and G0. + """ + theta_A0: FloatQuantity['radians'] + """ + The equilibrium angle value between H1, H0, and G0. + """ + theta_B0: FloatQuantity['radians'] + """ + The equilibrium angle value between H0, G0, and G1. + """ + phi_A0: FloatQuantity['radians'] + """ + The equilibrium dihedral value between H2, H1, H0, and G0. + """ + phi_B0: FloatQuantity['radians'] + + """ + The equilibrium dihedral value between H1, H0, G0, and G1. + """ + phi_C0: FloatQuantity['radians'] + + """ + The equilibrium dihedral value between H0, G0, G1, and G2. + """ def get_bond_distance( self, - topology: Union[str, pathlib.Path, openmm.app.Topology], - coordinates: Union[str, pathlib.Path, npt.NDArray], + universe: mda.Universe, ) -> unit.Quantity: """ Get the H0 - G0 distance. Parameters ---------- - topology : Union[str, openmm.app.Topology] - coordinates : Union[str, npt.NDArray] - A coordinate file or NDArray in frame-atom-coordinate - order in Angstrom. + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + bond : unit.Quantity + The H0-G0 distance. """ - u = mda.Universe( - topology, - coordinates, - format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology), + at1 = universe.atoms[self.host_atoms[0]] + at2 = universe.atoms[self.guest_atoms[0]] + bond = calc_bonds( + at1.position, + at2.position, + box=universe.atoms.dimensions ) - at1 = u.atoms[self.host_atoms[0]] - at2 = u.atoms[self.guest_atoms[0]] - bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) # convert to float so we avoid having a np.float64 return float(bond) * unit.angstrom def get_angles( self, - topology: Union[str, pathlib.Path, openmm.app.Topology], - coordinates: Union[str, pathlib.Path, npt.NDArray], - ) -> unit.Quantity: + universe: mda.Universe, + ) -> tuple[unit.Quantity, unit.Quantity]: """ Get the H1-H0-G0, and H0-G0-G1 angles. Parameters ---------- - topology : Union[str, openmm.app.Topology] - coordinates : Union[str, npt.NDArray] - A coordinate file or NDArray in frame-atom-coordinate - order in Angstrom. + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + angleA : unit.Quantity + The H1-H0-G0 angle. + angleB : unit.Quantity + The H0-G0-G1 angle. """ - u = mda.Universe( - topology, - coordinates, - format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology), - ) - at1 = u.atoms[self.host_atoms[1]] - at2 = u.atoms[self.host_atoms[0]] - at3 = u.atoms[self.guest_atoms[0]] - at4 = u.atoms[self.guest_atoms[1]] + at1 = universe.atoms[self.host_atoms[1]] + at2 = universe.atoms[self.host_atoms[0]] + at3 = universe.atoms[self.guest_atoms[0]] + at4 = universe.atoms[self.guest_atoms[1]] angleA = calc_angles( - at1.position, at2.position, at3.position, u.atoms.dimensions + at1.position, + at2.position, + at3.position, + box=universe.atoms.dimensions ) angleB = calc_angles( - at2.position, at3.position, at4.position, u.atoms.dimensions + at2.position, + at3.position, + at4.position, + box=universe.atoms.dimensions ) return angleA, angleB def get_dihedrals( self, - topology: Union[str, pathlib.Path, openmm.app.Topology], - coordinates: Union[str, pathlib.Path, npt.NDArray], - ) -> unit.Quantity: + universe: mda.Universe, + ) -> tuple[unit.Quantity, unit.Quantity, unit.Quantity]: """ Get the H2-H1-H0-G0, H1-H0-G0-G1, and H0-G0-G1-G2 dihedrals. Parameters ---------- - topology : Union[str, openmm.app.Topology] - coordinates : Union[str, npt.NDArray] - A coordinate file or NDArray in frame-atom-coordinate - order in Angstrom. + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + dihA : unit.Quantity + The H2-H1-H0-G0 angle. + dihB : unit.Quantity + The H1-H0-G0-G1 angle. + dihC : unit.Quantity + The H0-G0-G1-G2 angle. """ - u = mda.Universe( - topology, - coordinates, - format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology), - ) - at1 = u.atoms[self.host_atoms[2]] - at2 = u.atoms[self.host_atoms[1]] - at3 = u.atoms[self.host_atoms[0]] - at4 = u.atoms[self.guest_atoms[0]] - at5 = u.atoms[self.guest_atoms[1]] - at6 = u.atoms[self.guest_atoms[2]] + at1 = universe.atoms[self.host_atoms[2]] + at2 = universe.atoms[self.host_atoms[1]] + at3 = universe.atoms[self.host_atoms[0]] + at4 = universe.atoms[self.guest_atoms[0]] + at5 = universe.atoms[self.guest_atoms[1]] + at6 = universe.atoms[self.guest_atoms[2]] dihA = calc_dihedrals( at1.position, at2.position, at3.position, at4.position, - box=u.dimensions + box=universe.dimensions ) dihB = calc_dihedrals( at2.position, at3.position, at4.position, at5.position, - box=u.dimensions + box=universe.dimensions ) dihC = calc_dihedrals( at3.position, at4.position, at5.position, at6.position, - box=u.dimensions + box=universe.dimensions ) return dihA, dihB, dihC @@ -307,7 +338,7 @@ def get_guest_atom_candidates( TODO ---- - Remember to update the RDMol with the last frame positions. + Should the RDMol have a specific frame position? """ u = mda.Universe( topology, @@ -663,6 +694,66 @@ def _find_host_angle( return None +def _get_restraint_distances( + atomgroup: mda.AtomGroup +) -> tuple[unit.Quantity]: + """ + Get the bond, angle, and dihedral distances for an input atomgroup + defining the six atoms for a Boresch-like restraint. + + The atoms must be in the order of H0, H1, H2, G0, G1, G2. + + Parameters + ---------- + atomgroup : mda.AtomGroup + An AtomGroup defining the restrained atoms in order. + + Returns + ------- + bond : unit.Quantity + The H0-G0 bond value. + angle1 : unit.Quantity + The H1-H0-G0 angle value. + angle2 : unit.Quantity + The H0-G0-G1 angle value. + dihed1 : unit.Quantity + The H2-H1-H0-G0 dihedral value. + dihed2 : unit.Quantity + The H1-H0-G0-G1 dihedral value. + dihed3 : unit.Quantity + The H0-G0-G1-G2 dihedral value. + """ + + bond = calc_bonds( + atomgroup.atoms[0].position, + atomgroup.atoms[3], + box=atomgroup.dimensions + ) + + angles = [] + for idx_set in [[1, 0, 3], [0, 3, 4]]: + angle = calc_angles( + atomgroup.atoms[idx_set[0]].position, + atomgroup.atoms[idx_set[1]].position, + atomgroup.atoms[idx_set[2]].position, + box=atomgroup.dimensions, + ) + angles.append(angle * unit.radians) + + dihedrals = [] + for idx_set in [[2, 1, 0, 3], [1, 0, 3, 4], [0, 3, 4, 5]]: + dihed = calc_dihedrals( + atomgroup.atoms[idx_set[0]].position, + atomgroup.atoms[idx_set[1]].position, + atomgroup.atoms[idx_set[2]].position, + atomgroup.atoms[idx_set[3]].position, + box=atomgroup.dimensions, + ) + dihedrals.append(dihed * unit.radians) + + return bond, angles[0], angles[1], dihedrals[0], dihedrals[1], dihedrals[2] + + def find_boresch_restraint( topology: Union[str, pathlib.Path, openmm.app.Topology], trajectory: Union[str, pathlib.Path], @@ -682,15 +773,60 @@ def find_boresch_restraint( temperature: unit.Quantity = 298.15 * unit.kelvin, ) -> BoreschRestraintGeometry: """ - Find suitable Boresch-style restraints between a host and guest entity. + Find suitable Boresch-style restraints between a host and guest entity + based on the approach of Baumann et al. [1] with some modifications. Parameters ---------- - ... + topology : Union[str, pathlib.Path, openmm.app.Topology] + A topology of the system. + trajectory : Union[str, pathlib.Path] + A path to a coordinate trajectory file. + guest_rdmol : Chem.Mol + An RDKit Mol for the guest molecule. + guest_idxs : list[int] + Indices in the topology for the guest molecule. + host_idxs : list[int] + Indices in the topology for the host molecule. + guest_restraint_atoms_idxs : Optional[list[int]] + User selected indices of the guest molecule itself (i.e. indexed + starting a 0 for the guest molecule). This overrides the + restraint search and a restraint using these indices will + be retruned. Must be defined alongside ``host_restraint_atoms_idxs``. + host_restraint_atoms_idxs : Optional[list[int]] + User selected indices of the host molecule itself (i.e. indexed + starting a 0 for the hosts molecule). This overrides the + restraint search and a restraint using these indices will + be returnned. Must be defined alongside ``guest_restraint_atoms_idxs``. + host_selection : str + An MDAnalysis selection string to sub-select the host atoms. + dssp_filter : bool + Whether or not to filter the host atoms by their secondary structure. + rmsf_cutoff : unit.Quantity + The cutoff value for atom root mean square fluction. Atoms with RMSF + values above this cutoff will be disregarded. + Must be in units compatible with nanometer. + host_min_distance : unit.Quantity + The minimum distance between any host atom and the guest G0 atom. + Must be in units compatible with nanometer. + host_max_distance : unit.Quantity + The maximum distance between any host atom and the guest G0 atom. + Must be in units compatible with nanometer. + angle_force_constant : unit.Quantity + The force constant for the G1-G0-H0 and G0-H0-H1 angles. Must be + in units compatible with kilojoule / mole / radians ** 2. + temperature : unit.Quantity + The system temperature in units compatible with Kelvin. Returns ------- - ... + BoreschRestraintGeometry + An object defining the parameters of the Boresch-like restraint. + + References + ---------- + [1] Baumann, Hannah M., et al. "Broadening the scope of binding free energy + calculations using a Separated Topologies approach." (2023). """ u = mda.Universe( topology, @@ -698,7 +834,6 @@ def find_boresch_restraint( format=_get_mda_coord_format(trajectory), topology_format=_get_mda_topology_format(topology), ) - u.trajectory[-1] # Work with the final frame if (guest_restraint_atoms_idxs is not None) and (host_restraint_atoms_idxs is not None): # fmt: skip # In this case assume the picked atoms were intentional / @@ -711,9 +846,23 @@ def find_boresch_restraint( host_angle = [ at.ix for at in host_ag.atoms[host_restraint_atoms_idxs] ] - # TODO sort out the return on this + + # Set the equilibrium values as those of the final frame + u.trajectory[-1] + atomgroup = u.atoms[host_angle + guest_angle] + bond, ang1, ang2, dih1, dih2, dih3 = _get_restraint_distances( + atomgroup + ) + return BoreschRestraintGeometry( - host_atoms=host_angle, guest_atoms=guest_angle + host_atoms=host_angle, + guest_atoms=guest_angle, + r_aA0=bond, + theta_A0=ang1, + theta_B0=ang2, + phi_A0=dih1, + phi_B0=dih2, + phi_C0=dih3 ) if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): # fmt: skip @@ -772,6 +921,20 @@ def find_boresch_restraint( errmsg = "No suitable host atoms could be found" raise ValueError(errmsg) + # Set the equilibrium values as those of the final frame + u.trajectory[-1] + atomgroup = u.atoms[host_angle + guest_angle] + bond, ang1, ang2, dih1, dih2, dih3 = _get_restraint_distances( + atomgroup + ) + return BoreschRestraintGeometry( - host_atoms=host_angle, guest_atoms=guest_angle + host_atoms=host_angle, + guest_atoms=guest_angle, + r_aA0=bond, + theta_A0=ang1, + theta_B0=ang2, + phi_A0=dih1, + phi_B0=dih2, + phi_C0=dih3 ) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py b/openfe/protocols/restraint_utils/geometry/flatbottom.py similarity index 58% rename from openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py rename to openfe/protocols/restraint_utils/geometry/flatbottom.py index c9007dd59..3b4599f56 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py +++ b/openfe/protocols/restraint_utils/geometry/flatbottom.py @@ -19,9 +19,10 @@ from .harmonic import ( DistanceRestraintGeometry, - _get_selection, ) +from .utils import _get_mda_topology_format, _get_mda_selection + class FlatBottomDistanceGeometry(DistanceRestraintGeometry): """ @@ -42,7 +43,6 @@ class COMDistanceAnalysis(AnalysisBase): group2 : MDANalysis.AtomGroup Atoms defining the second centroid. """ - _analysis_algorithm_is_parallelizable = False def __init__(self, group1, group2, **kwargs): @@ -68,18 +68,55 @@ def _conclude(self): def get_flatbottom_distance_restraint( topology: Union[str, app.Topology], - trajectory: pathlib.Path, - topology_format: Optional[str] = None, + trajectory: Union[str, pathlib.Path], host_atoms: Optional[list[int]] = None, guest_atoms: Optional[list[int]] = None, host_selection: Optional[str] = None, guest_selection: Optional[str] = None, padding: unit.Quantity = 0.5 * unit.nanometer, ) -> FlatBottomDistanceGeometry: - u = mda.Universe(topology, trajectory, topology_format=topology_format) + """ + Get a FlatBottomDistanceGeometry by analyzing the COM distance + change between two sets of atoms. + + The ``well_radius`` is defined as the maximum COM distance plus + ``padding``. + + Parameters + ---------- + topology : Union[str, app.Topology] + A topology defining the system. + trajectory : Union[str, pathlib.Path] + A coordinate trajectory for the system. + host_atoms : Optional[list[int]] + A list of host atoms indices. Either ``host_atoms`` or + ``host_selection`` must be defined. + guest_atoms : Optional[list[int]] + A list of guest atoms indices. Either ``guest_atoms`` or + ``guest_selection`` must be defined. + host_selection : Optional[str] + An MDAnalysis selection string to define the host atoms. + Either ``host_atoms`` or ``host_selection`` must be defined. + guest_selection : Optional[str] + An MDAnalysis selection string to define the guest atoms. + Either ``guest_atoms`` or ``guest_selection`` must be defined. + padding : unit.Quantity + A padding value to add to the ``well_radius`` definition. + Must be in units compatible with nanometers. + + Returns + ------- + FlatBottomDistanceGeometry + An object defining a flat bottom restraint geometry. + """ + u = mda.Universe( + topology, + trajectory, + topology_format=_get_mda_topology_format(topology) + ) - guest_ag = _get_selection(u, guest_atoms, guest_selection) - host_ag = _get_selection(u, host_atoms, host_selection) + guest_ag = _get_mda_selection(u, guest_atoms, guest_selection) + host_ag = _get_mda_selection(u, host_atoms, host_selection) com_dists = COMDistanceAnalysis(guest_ag, host_ag) com_dists.run() diff --git a/openfe/protocols/restraint_utils/geometry/harmonic.py b/openfe/protocols/restraint_utils/geometry/harmonic.py new file mode 100644 index 000000000..197a8bc44 --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/harmonic.py @@ -0,0 +1,144 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import pathlib +from typing import Union, Optional +from openmm import app +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds +from rdkit import Chem + +from .base import HostGuestRestraintGeometry +from .utils import ( + get_central_atom_idx, + _get_mda_selection, + _get_mda_topology_format, +) + + +class DistanceRestraintGeometry(HostGuestRestraintGeometry): + """ + A geometry class for a distance restraint between two groups of atoms. + """ + + def get_distance(self, universe: mda.Universe) -> unit.Quantity: + """ + Get the center of mass distance between the host and guest atoms. + + Parameters + ---------- + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + bond : unit.Quantity + The center of mass distance between the two groups of atoms. + """ + ag1 = universe.atoms[self.host_atoms] + ag2 = universe.atoms[self.guest_atoms] + bond = calc_bonds( + ag1.center_of_mass(), + ag2.center_of_mass(), + box=universe.atoms.dimensions + ) + # convert to float so we avoid having a np.float64 + return float(bond) * unit.angstrom + + +def get_distance_restraint( + topology: Union[str, pathlib.Path, app.Topology], + trajectory: Union[str, pathlib.Path], + host_atoms: Optional[list[int]] = None, + guest_atoms: Optional[list[int]] = None, + host_selection: Optional[str] = None, + guest_selection: Optional[str] = None, +) -> DistanceRestraintGeometry: + """ + Get a DistanceRestraintGeometry between two groups of atoms. + + You can either select the groups by passing through a set of indices + or an MDAnalysis selection. + + Parameters + ---------- + topology : Union[str, pathlib.Path, app.Topology] + A path or object defining the system topology. + trajectory : Union[str, pathlib.Path] + Coordinates for the system. + host_atoms : Optional[list[int]] + A list of host atoms indices. Either ``host_atoms`` or + ``host_selection`` must be defined. + guest_atoms : Optional[list[int]] + A list of guest atoms indices. Either ``guest_atoms`` or + ``guest_selection`` must be defined. + host_selection : Optional[str] + An MDAnalysis selection string to define the host atoms. + Either ``host_atoms`` or ``host_selection`` must be defined. + guest_selection : Optional[str] + An MDAnalysis selection string to define the guest atoms. + Either ``guest_atoms`` or ``guest_selection`` must be defined. + + Returns + ------- + DistanceRestraintGeometry + An object that defines a distance restraint geometry. + """ + u = mda.Universe( + topology, + trajectory, + topology_format=_get_mda_topology_format(topology) + ) + + guest_ag = _get_mda_selection(u, guest_atoms, guest_selection) + guest_atoms = [a.ix for a in guest_ag] + host_ag = _get_mda_selection(u, host_atoms, host_selection) + host_atoms = [a.ix for a in host_ag] + + return DistanceRestraintGeometry( + guest_atoms=guest_atoms, host_atoms=host_atoms + ) + + +def get_molecule_centers_restraint( + molA_rdmol: Chem.Mol, + molB_rdmol: Chem.Mol, + molA_idxs: list[int], + molB_idxs: list[int], +): + """ + Get a DistanceRestraintGeometry between the central atoms of + two molecules. + + Parameters + ---------- + molA_rdmol : Chem.Mol + An RDKit Molecule for the first molecule. + molB_rdmol : Chem.Mol + An RDKit Molecule for the first molecule. + molA_idxs : list[int] + The indices of the first molecule in the system. Note we assume these + to be sorted in the same order as the input rdmol. + molB_idxs : list[int] + The indices of the first molecule in the system. Note we assume these + to be sorted in the same order as the input rdmol. + + Returns + ------- + DistanceRestraintGeometry + An object that defines a distance restraint geometry. + """ + # We assume that the mol idxs are ordered + centerA = molA_idxs[get_central_atom_idx(molA_rdmol)] + centerB = molB_idxs[get_central_atom_idx(molB_rdmol)] + + return DistanceRestraintGeometry( + guest_atoms=[centerA], host_atoms=[centerB] + ) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/restraint_utils/geometry/utils.py similarity index 88% rename from openfe/protocols/openmm_utils/restraints/geometry/utils.py rename to openfe/protocols/restraint_utils/geometry/utils.py index 7d6906650..4b734b410 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/restraint_utils/geometry/utils.py @@ -22,13 +22,59 @@ from MDAnalysis.analysis.rms import RMSF from MDAnalysis.lib.distances import minimize_vectors, capped_distance from MDAnalysis.coordinates.memory import MemoryReader +from MDAnalysis.transformations.nojump import NoJump -from openfe_analysis.transformations import Aligner, NoJump +from openfe_analysis.transformations import Aligner DEFAULT_ANGLE_FRC_CONSTANT = 83.68 * unit.kilojoule_per_mole / unit.radians**2 +def _get_mda_selection( + universe: mda.Universe, + atom_list: Optional[list[int]], + selection: Optional[str] +) -> mda.AtomGroup: + """ + Return an AtomGroup based on either a list of atom indices or an + mdanalysis string selection. + + Parameters + ---------- + universe : mda.Universe + The MDAnalysis Universe to get the AtomGroup from. + atom_list : Optional[list[int]] + A list of atom indices. + selection : Optional[str] + An MDAnalysis selection string. + + Returns + ------- + ag : mda.AtomGroup + An atom group selected from the inputs. + + Raises + ------ + ValueError + If both ``atom_list`` and ``selection`` are ``None`` + or are defined. + """ + if atom_list is None: + if selection is None: + raise ValueError( + "one of either the atom lists or selections must be defined" + ) + + ag = universe.select_atoms(selection) + else: + if selection is not None: + raise ValueError( + "both atom_list and selection cannot be defined together" + ) + ag = universe.atoms[atom_list] + return ag + + def _get_mda_coord_format( coordinates: Union[str, npt.NDArray] ) -> Optional[MemoryReader]: @@ -224,7 +270,8 @@ def check_angle_not_flat( angle : unit.Quantity The angle to check in units compatible with radians. force_constant : unit.Quantity - Force constant of the angle in units compatible with kilojoule_per_mole / radians ** 2. + Force constant of the angle in units compatible with + kilojoule_per_mole / radians ** 2. temperature : unit.Quantity The system temperature in units compatible with Kelvin. @@ -334,7 +381,6 @@ class FindHostAtoms(AnalysisBase): max_search_distance: unit.Quantity Maximum distance to filter atoms within. """ - _analysis_algorithm_is_parallelizable = False def __init__( @@ -372,34 +418,7 @@ def _conclude(self): self.results.host_idxs = np.array(self.results.host_idxs) -def find_host_atoms( - topology, trajectory, host_selection, guest_selection, cutoff -) -> mda.AtomGroup: - """ - Get an AtomGroup of the host atoms based on their distances from the guest atoms. - """ - u = mda.Universe(topology, trajectory) - - def _get_selection(selection): - """ - If it's a str, call select_atoms, if not a list of atom idxs - """ - if isinstance(selection, str): - ag = u.select_atoms(host_selection) - else: - ag = u.atoms[host_ag] - return ag - - host_ag = _get_selection(host_selection) - guest_ag = _get_selection(guest_selection) - - finder = FindHostAtoms(host_ag, guest_ag, cutoff) - finder.run() - - return u.atoms[list(finder.results.host_idxs)] - - -def get_local_rmsf(atomgroup: mda.AtomGroup): +def get_local_rmsf(atomgroup: mda.AtomGroup) -> unit.Quantity: """ Get the RMSF of an AtomGroup when aligned upon itself. @@ -416,7 +435,7 @@ def get_local_rmsf(atomgroup: mda.AtomGroup): copy_u = atomgroup.universe.copy() ag = copy_u.atoms[atomgroup.atoms.ix] - nojump = NoJump(ag) + nojump = NoJump() align = Aligner(ag) copy_u.trajectory.add_transformations(nojump, align) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/__init__.py b/openfe/protocols/restraint_utils/openmm/__init__.py similarity index 100% rename from openfe/protocols/openmm_utils/restraints/openmm/__init__.py rename to openfe/protocols/restraint_utils/openmm/__init__.py diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py b/openfe/protocols/restraint_utils/openmm/omm_forces.py similarity index 85% rename from openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py rename to openfe/protocols/restraint_utils/openmm/omm_forces.py index 52cbbec98..2947c8e03 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py +++ b/openfe/protocols/restraint_utils/openmm/omm_forces.py @@ -43,6 +43,20 @@ def get_boresch_energy_function( def get_periodic_boresch_energy_function( control_parameter: str, ) -> str: + """ + Return a Boresch-style energy function with a periodic torsion for a + CustomCompoundForce. + + Parameters + ---------- + control_parameter : str + A string for the lambda scaling control parameter + + Returns + ------- + str + The energy function string. + """ energy_function = ( f"{control_parameter} * E; " "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " @@ -104,8 +118,12 @@ def add_force_in_separate_group( Mostly reproduced from `Yank `_. """ available_force_groups = set(range(32)) - for force in system.getForces(): - available_force_groups.discard(force.getForceGroup()) + for existing_force in system.getForces(): + available_force_groups.discard(existing_force.getForceGroup()) + + if len(available_force_groups) == 0: + errmsg = "No available force groups could be found" + raise ValueError(errmsg) force.setForceGroup(min(available_force_groups)) system.addForce(force) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py b/openfe/protocols/restraint_utils/openmm/omm_restraints.py similarity index 50% rename from openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py rename to openfe/protocols/restraint_utils/openmm/omm_restraints.py index 18f9f2f34..c77b1cd0b 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py +++ b/openfe/protocols/restraint_utils/openmm/omm_restraints.py @@ -31,7 +31,7 @@ from gufe.settings.models import SettingsBaseModel -from openfe.protocols.openmm_utils.restraints.geometry import ( +from openfe.protocols.restraint_utils.geometry import ( BaseRestraintGeometry, DistanceRestraintGeometry, BoreschRestraintGeometry @@ -87,22 +87,20 @@ class BaseHostGuestRestraints(abc.ABC): TODO ---- - Add some examples here. + Add some developer examples here. """ def __init__( self, restraint_settings: SettingsBaseModel, - controlling_parameter_name: str = "lambda_restraints", ): self.settings = restraint_settings - self.controlling_parameter_name = controlling_parameter_name - self._verify_settings() + self._verify_inputs() @abc.abstractmethod - def _verify_settings(self): + def _verify_inputs(self): """ - Method for validating the settings passed on object construction. + Method for validating that the inputs to the class are correct. """ pass @@ -117,10 +115,11 @@ def _verify_geometry(self, geometry): def add_force( self, thermodynamic_state: ThermodynamicState, - geometry: BaseRestraintGeometry + geometry: BaseRestraintGeometry, + controlling_parameter_name: str, ): """ - Method for in-place adding a force to the System of a + Method for in-place adding the Force to the System of a ThermodynamicState. Parameters @@ -130,6 +129,8 @@ def add_force( new force. geometry : BaseRestraintGeometry A geometry object defining the restraint parameters. + controlling_parameter_name : str + The name of the controlling parameter for the Force. """ pass @@ -140,7 +141,8 @@ def get_standard_state_correction( geometry: BaseRestraintGeometry ) -> unit.Quantity: """ - Get the standard state correction for the Force. + Get the standard state correction for the Force when + applied to the input ThermodynamicState. Parameters ---------- @@ -159,11 +161,23 @@ def get_standard_state_correction( pass @abc.abstractmethod - def _get_force(self, geometry: BaseRestraintGeometry): + def _get_force( + self, + geometry: BaseRestraintGeometry, + controlling_parameter_name: str, + ): + """ + Helper method to get the relevant OpenMM Force for this + class, given an input geometry. + """ pass class SingleBondMixin: + """ + A mixin to extend geometry checks for Forces that can only hold + a single atom. + """ def _verify_geometry(self, geometry: BaseRestraintGeometry): if len(geometry.host_atoms) != 1 or len(geometry.guest_atoms) != 1: errmsg = ( @@ -176,6 +190,12 @@ def _verify_geometry(self, geometry: BaseRestraintGeometry): class BaseRadiallySymmetricRestraintForce(BaseHostGuestRestraints): + """ + A base class for all radially symmetic Forces acting between + two sets of atoms. + + Must be subclassed. + """ def _verify_inputs(self) -> None: if not isinstance(self.settings, DistanceRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" @@ -189,10 +209,25 @@ def _verify_geometry(self, geometry: DistanceRestraintGeometry): def add_force( self, thermodynamic_state: ThermodynamicState, - geometry: DistanceRestraintGeometry + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str = "lambda_restraints", ) -> None: + """ + Method for in-place adding the Force to the System of the + given ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + """ self._verify_geometry(geometry) - force = self._get_force(geometry) + force = self._get_force(geometry, controlling_parameter_name) force.setUsesPeriodicBoundaryConditions( thermodynamic_state.is_periodic ) @@ -206,6 +241,24 @@ def get_standard_state_correction( thermodynamic_state: ThermodynamicState, geometry: DistanceRestraintGeometry, ) -> unit.Quantity: + """ + Get the standard state correction for the Force when + applied to the input ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + + Returns + ------- + correction : unit.Quantity + The standard state correction free energy in units compatible + with kilojoule per mole. + """ self._verify_geometry(geometry) force = self._get_force(geometry) corr = force.compute_standard_state_correction( @@ -214,14 +267,50 @@ def get_standard_state_correction( dg = corr * thermodynamic_state.kT return from_openmm(dg).to('kilojoule_per_mole') - def _get_force(self, geometry: DistanceRestraintGeometry): + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str + ): raise NotImplementedError("only implemented in child classes") class HarmonicBondRestraint( BaseRadiallySymmetricRestraintForce, SingleBondMixin ): - def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: + """ + A class to add a harmonic restraint between two atoms + in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.HarmonicRestraintBondForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the HarmonicRestraintBondForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + HarmonicRestraintBondForce + An OpenMM Force that applies a harmonic restraint between + two atoms. + """ spring_constant = to_openmm( self.settings.spring_constant ).value_in_unit_system(omm_unit.md_unit_system) @@ -229,14 +318,46 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant=spring_constant, restrained_atom_index1=geometry.host_atoms[0], restrained_atom_index2=geometry.guest_atoms[0], - controlling_parameter_name=self.controlling_parameter_name, + controlling_parameter_name=controlling_parameter_name, ) class FlatBottomBondRestraint( BaseRadiallySymmetricRestraintForce, SingleBondMixin ): - def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: + """ + A class to add a flat bottom restraint between two atoms + in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.FlatBottomRestraintBondForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the FlatBottomRestraintBondForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + FlatBottomRestraintBondForce + An OpenMM Force that applies a flat bottom restraint between + two atoms. + """ spring_constant = to_openmm( self.settings.spring_constant ).value_in_unit_system(omm_unit.md_unit_system) @@ -248,12 +369,44 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: well_radius=well_radius, restrained_atom_index1=geometry.host_atoms[0], restrained_atom_index2=geometry.guest_atoms[0], - controlling_parameter_name=self.controlling_parameter_name, + controlling_parameter_name=controlling_parameter_name, ) class CentroidHarmonicRestraint(BaseRadiallySymmetricRestraintForce): - def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: + """ + A class to add a harmonic restraint between the centroid of + two sets of atoms in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.HarmonicRestraintForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the HarmonicRestraintForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + HarmonicRestraintForce + An OpenMM Force that applies a harmonic restraint between + the centroid of two sets of atoms. + """ spring_constant = to_openmm( self.settings.spring_constant ).value_in_unit_system(omm_unit.md_unit_system) @@ -261,12 +414,44 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant=spring_constant, restrained_atom_index1=geometry.host_atoms, restrained_atom_index2=geometry.guest_atoms, - controlling_parameter_name=self.controlling_parameter_name, + controlling_parameter_name=controlling_parameter_name, ) class CentroidFlatBottomRestraint(BaseRadiallySymmetricRestraintForce): - def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: + """ + A class to add a flat bottom restraint between the centroid + of two sets of atoms in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.FlatBottomRestraintForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the FlatBottomRestraintForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + FlatBottomRestraintForce + An OpenMM Force that applies a flat bottom restraint between + the centroid of two sets of atoms. + """ spring_constant = to_openmm( self.settings.spring_constant ).value_in_unit_system(omm_unit.md_unit_system) @@ -278,17 +463,80 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: well_radius=well_radius, restrained_atom_index1=geometry.host_atoms, restrained_atom_index2=geometry.guest_atoms, - controlling_parameter_name=self.controlling_parameter_name, + controlling_parameter_name=controlling_parameter_name, ) class BoreschRestraint(BaseHostGuestRestraints): - def _verify_settings(self) -> None: + """ + A class to add a Boresch-like restraint between six atoms, + + The restraint is defined as a + :class:`openmmtools.forces.CustomCompoundForce` with the + following energy function: + + lambda_control_parameter * E; + E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 + + (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 + + (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + + (K_phiC/2)*dphi_C^2; + dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); + dA = dihedral(p1,p2,p3,p4) - phi_A0; + dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); + dB = dihedral(p2,p3,p4,p5) - phi_B0; + dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); + dC = dihedral(p3,p4,p5,p6) - phi_C0; + + Where p1, p2, p3, p4, p5, p6 represent host atoms 2, 1, 0, + and guest atoms 0, 1, 2 respectively. + + ``lambda_control_parameter`` is a control parameter for + scaling the Force. + + ``K_r`` is defined as the bond spring constant between + p3 and p4 and must be provided in the settings in units + compatible with kilojoule / mole. + + ``r_aA0`` is the equilibrium distance of the bond between + p3 and p4. This must be provided by the Geometry class in + units compatiblle with nanometer. + + ``K_thetaA`` and ``K_thetaB`` are the spring constants for the angles + formed by (p2, p3, p4) and (p3, p4, p5). They must be provided in the + settings in units compatible with kilojoule / mole / radians**2. + + ``theta_A0`` and ``theta_B0`` are the equilibrium values for angles + (p2, p3, p4) and (p3, p4, p5). They must be provided by the + Geometry class in units compatible with radians. + + ``phi_A0``, ``phi_B0``, and ``phi_C0`` are the equilibrium constants + for the dihedrals formed by (p1, p2, p3, p4), (p2, p3, p4, p5), and + (p3, p4, p5, p6). They must be provided in the settings in units + compatible with kilojoule / mole / radians ** 2. + + ``phi_A0``, ``phi_B0``, and ``phi_C0`` are the equilibrium values + for the dihedrals formed by (p1, p2, p3, p4), (p2, p3, p4, p5), and + (p3, p4, p5, p6). They must be provided in the Geometry class in + units compatible with radians. + + + Notes + ----- + * Settings must define the ``K_r`` (d) + """ + def _verify_inputs(self) -> None: + """ + Method for validating that the geometry object is correct. + """ if not isinstance(self.settings, BoreschRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" raise ValueError(errmsg) def _verify_geometry(self, geometry: BoreschRestraintGeometry): + """ + Method for validating that the geometry object is correct. + """ if not isinstance(geometry, BoreschRestraintGeometry): errmsg = f"Incorrect geometry class type {geometry} passed through" raise ValueError(errmsg) @@ -296,10 +544,28 @@ def _verify_geometry(self, geometry: BoreschRestraintGeometry): def add_force( self, thermodynamic_state: ThermodynamicState, - geometry: BoreschRestraintGeometry + geometry: BoreschRestraintGeometry, + controlling_parameter_name: str, ) -> None: + """ + Method for in-place adding the Boresch CustomCompoundForce + to the System of the given ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + """ self._verify_geometry(geometry) - force = self._get_force(geometry) + force = self._get_force( + geometry, + controlling_parameter_name, + ) force.setUsesPeriodicBoundaryConditions( thermodynamic_state.is_periodic ) @@ -308,10 +574,29 @@ def add_force( add_force_in_separate_group(system, force) thermodynamic_state.system = system - def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: - efunc = get_boresch_energy_function( - self.controlling_parameter_name, - ) + def _get_force( + self, + geometry: BoreschRestraintGeometry, + controlling_parameter_name: str + ) -> openmm.CustomCompoundBondForce: + """ + Get the CustomCompoundForce with a Boresch-like energy function + given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + CustomCompoundForce + An OpenMM CustomCompoundForce that applies a Boresch-like + restraint between 6 atoms. + """ + efunc = get_boresch_energy_function(controlling_parameter_name) force = get_custom_compound_bond_force( energy_function=efunc, n_particles=6, @@ -339,7 +624,7 @@ def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: ) force.addPerBondParameter(key) - force.addGlobalParameter(self.controlling_parameter_name, 1.0) + force.addGlobalParameter(controlling_parameter_name, 1.0) atoms = [ geometry.host_atoms[2], geometry.host_atoms[1], @@ -356,6 +641,32 @@ def get_standard_state_correction( thermodynamic_state: ThermodynamicState, geometry: BoreschRestraintGeometry ) -> unit.Quantity: + """ + Get the standard state correction for the Boresch-like + restraint when applied to the input ThermodynamicState. + + The correction is calculated using the analytical method + as defined by Boresch et al. [1] + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + + Returns + ------- + correction : unit.Quantity + The standard state correction free energy in units compatible + with kilojoule per mole. + + References + ---------- + [1] Boresch S, Tettinger F, Leitgeb M, Karplus M. J Phys Chem B. 107:9535, 2003. + http://dx.doi.org/10.1021/jp0217839 + """ self._verify_geometry(geometry) StandardV = 1.66053928 * unit.nanometer**3 diff --git a/openfe/tests/protocols/restraints/__init__.py b/openfe/tests/protocols/restraints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/tests/protocols/restraints/test_geometry_base.py b/openfe/tests/protocols/restraints/test_geometry_base.py new file mode 100644 index 000000000..139c57dc5 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_geometry_base.py @@ -0,0 +1,25 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +from openfe.protocols.restraint_utils.geometry.base import ( + HostGuestRestraintGeometry +) + + +def test_hostguest_geometry(): + """ + A very basic will it build test. + """ + geom = HostGuestRestraintGeometry(guest_atoms=[1, 2, 3], host_atoms=[4]) + + assert isinstance(geom, HostGuestRestraintGeometry) + + +def test_hostguest_positiveidxs_validator(): + """ + Check that the validator is working as intended. + """ + with pytest.raises(ValueError, match="negative indices passed"): + geom = HostGuestRestraintGeometry(guest_atoms=[-1, 1], host_atoms=[0]) diff --git a/openfe/tests/protocols/restraints/test_omm_restraints.py b/openfe/tests/protocols/restraints/test_omm_restraints.py new file mode 100644 index 000000000..0e346f9c5 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_omm_restraints.py @@ -0,0 +1,31 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +from openfe.protocols.restraint_utils.openmm.omm_restraints import ( + RestraintParameterState, +) + + +def test_parameter_state_default(): + param_state = RestraintParameterState() + assert param_state.lambda_restraints is None + + +@pytest.mark.parametrize('suffix', [None, 'foo']) +@pytest.mark.parametrize('lambda_var', [0, 0.5, 1.0]) +def test_parameter_state_suffix(suffix, lambda_var): + param_state = RestraintParameterState( + parameters_name_suffix=suffix, lambda_restraints=lambda_var + ) + + if suffix is not None: + param_name = f'lambda_restraints_{suffix}' + else: + param_name = 'lambda_restraints' + + assert getattr(param_state, param_name) == lambda_var + assert len(param_state._parameters.keys()) == 1 + assert param_state._parameters[param_name] == lambda_var + assert param_state._parameters_name_suffix == suffix diff --git a/openfe/tests/protocols/restraints/test_openmm_forces.py b/openfe/tests/protocols/restraints/test_openmm_forces.py new file mode 100644 index 000000000..cd2a7f21e --- /dev/null +++ b/openfe/tests/protocols/restraints/test_openmm_forces.py @@ -0,0 +1,115 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest +import numpy as np +import openmm +from openfe.protocols.restraint_utils.openmm.omm_forces import ( + get_boresch_energy_function, + get_periodic_boresch_energy_function, + get_custom_compound_bond_force, + add_force_in_separate_group, +) + + +@pytest.mark.parametrize('param', ['foo', 'bar']) +def test_boresch_energy_function(param): + """ + Base regression test for the energy function + """ + fn = get_boresch_energy_function(param) + assert fn == ( + f"{param} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + (K_phiC/2)*dphi_C^2; " + "dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + + +@pytest.mark.parametrize('param', ['foo', 'bar']) +def test_periodic_boresch_energy_function(param): + """ + Base regression test for the energy function + """ + fn = get_periodic_boresch_energy_function(param) + assert fn == ( + f"{param} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*uphi_A + (K_phiB/2)*uphi_B + (K_phiC/2)*uphi_C; " + "uphi_A = (1-cos(dA)); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "uphi_B = (1-cos(dB)); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "uphi_C = (1-cos(dC)); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + + +@pytest.mark.parametrize('num_atoms', [6, 20]) +def test_custom_compound_force(num_atoms): + fn = get_boresch_energy_function('lambda_restraints') + force = get_custom_compound_bond_force(fn, num_atoms) + + # Check we have the right object + assert isinstance(force, openmm.CustomCompoundBondForce) + + # Check the energy function + assert force.getEnergyFunction() == fn + + # Check the number of particles + assert force.getNumParticlesPerBond() == num_atoms + + +@pytest.mark.parametrize('groups, expected', [ + [[0, 1, 2, 3, 4], 5], + [[1, 2, 3, 4, 5], 0], +]) +def test_add_force_in_separate_group(groups, expected): + # Create an empty system + system = openmm.System() + + # Create some forces with some force groups + base_forces = [ + openmm.NonbondedForce(), + openmm.HarmonicBondForce(), + openmm.HarmonicAngleForce(), + openmm.PeriodicTorsionForce(), + openmm.CMMotionRemover(), + ] + + for force, group in zip(base_forces, groups): + force.setForceGroup(group) + + [system.addForce(force) for force in base_forces] + + # Get your CustomCompoundBondForce + fn = get_boresch_energy_function('lambda_restraints') + new_force = get_custom_compound_bond_force(fn, 6) + # new_force.setForceGroup(5) + # system.addForce(new_force) + add_force_in_separate_group(system=system, force=new_force) + + # Loop through and check that we go assigned the expected force group + for force in system.getForces(): + if isinstance(force, openmm.CustomCompoundBondForce): + assert force.getForceGroup() == expected + + +def test_add_too_many_force_groups(): + # Create a system + system = openmm.System() + + # Fill it upu with 32 forces with different groups + for i in range(32): + f = openmm.HarmonicBondForce() + f.setForceGroup(i) + system.addForce(f) + + # Now try to add another force + with pytest.raises(ValueError, match="No available force group"): + add_force_in_separate_group( + system=system, force=openmm.HarmonicBondForce() + ) \ No newline at end of file From c914b18c63026524de550303f8662819c354bcd0 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 16 Dec 2024 09:22:31 +0000 Subject: [PATCH 25/33] base for restraint settings --- openfe/protocols/restraint_utils/settings.py | 23 ++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 openfe/protocols/restraint_utils/settings.py diff --git a/openfe/protocols/restraint_utils/settings.py b/openfe/protocols/restraint_utils/settings.py new file mode 100644 index 000000000..0c12aef17 --- /dev/null +++ b/openfe/protocols/restraint_utils/settings.py @@ -0,0 +1,23 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Settings for adding restraints. +""" +from typing import Optional, Literal +from openff.units import unit +from openff.models.types import FloatQuantity, ArrayQuantity + +from gufe.settings import ( + SettingsBaseModel, +) + + +from pydantic.v1 import validator + + +class BaseRestraintSettings(SettingsBaseModel): + """ + Base class for RestraintSettings objects. + """ + class Config: + arbitrary_types_allowed = True From d24a5a5a6909c35b906c3c9a42d1d15f929fac0b Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 16 Dec 2024 11:20:14 +0000 Subject: [PATCH 26/33] Fix up some things --- .../restraint_utils/geometry/boresch.py | 237 ++++++------------ .../restraint_utils/geometry/harmonic.py | 26 -- 2 files changed, 82 insertions(+), 181 deletions(-) diff --git a/openfe/protocols/restraint_utils/geometry/boresch.py b/openfe/protocols/restraint_utils/geometry/boresch.py index 6e740f48d..2b0cd2313 100644 --- a/openfe/protocols/restraint_utils/geometry/boresch.py +++ b/openfe/protocols/restraint_utils/geometry/boresch.py @@ -79,113 +79,6 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): The equilibrium dihedral value between H0, G0, G1, and G2. """ - def get_bond_distance( - self, - universe: mda.Universe, - ) -> unit.Quantity: - """ - Get the H0 - G0 distance. - - Parameters - ---------- - universe : mda.Universe - A Universe representing the system of interest. - - Returns - ------- - bond : unit.Quantity - The H0-G0 distance. - """ - at1 = universe.atoms[self.host_atoms[0]] - at2 = universe.atoms[self.guest_atoms[0]] - bond = calc_bonds( - at1.position, - at2.position, - box=universe.atoms.dimensions - ) - # convert to float so we avoid having a np.float64 - return float(bond) * unit.angstrom - - def get_angles( - self, - universe: mda.Universe, - ) -> tuple[unit.Quantity, unit.Quantity]: - """ - Get the H1-H0-G0, and H0-G0-G1 angles. - - Parameters - ---------- - universe : mda.Universe - A Universe representing the system of interest. - - Returns - ------- - angleA : unit.Quantity - The H1-H0-G0 angle. - angleB : unit.Quantity - The H0-G0-G1 angle. - """ - at1 = universe.atoms[self.host_atoms[1]] - at2 = universe.atoms[self.host_atoms[0]] - at3 = universe.atoms[self.guest_atoms[0]] - at4 = universe.atoms[self.guest_atoms[1]] - - angleA = calc_angles( - at1.position, - at2.position, - at3.position, - box=universe.atoms.dimensions - ) - angleB = calc_angles( - at2.position, - at3.position, - at4.position, - box=universe.atoms.dimensions - ) - return angleA, angleB - - def get_dihedrals( - self, - universe: mda.Universe, - ) -> tuple[unit.Quantity, unit.Quantity, unit.Quantity]: - """ - Get the H2-H1-H0-G0, H1-H0-G0-G1, and H0-G0-G1-G2 dihedrals. - - Parameters - ---------- - universe : mda.Universe - A Universe representing the system of interest. - - Returns - ------- - dihA : unit.Quantity - The H2-H1-H0-G0 angle. - dihB : unit.Quantity - The H1-H0-G0-G1 angle. - dihC : unit.Quantity - The H0-G0-G1-G2 angle. - """ - at1 = universe.atoms[self.host_atoms[2]] - at2 = universe.atoms[self.host_atoms[1]] - at3 = universe.atoms[self.host_atoms[0]] - at4 = universe.atoms[self.guest_atoms[0]] - at5 = universe.atoms[self.guest_atoms[1]] - at6 = universe.atoms[self.guest_atoms[2]] - - dihA = calc_dihedrals( - at1.position, at2.position, at3.position, at4.position, - box=universe.dimensions - ) - dihB = calc_dihedrals( - at2.position, at3.position, at4.position, at5.position, - box=universe.dimensions - ) - dihC = calc_dihedrals( - at3.position, at4.position, at5.position, at6.position, - box=universe.dimensions - ) - return dihA, dihB, dihC - def _sort_by_distance_from_atom( rdmol: Chem.Mol, target_idx: int, atom_idxs: Iterable[int] @@ -220,7 +113,7 @@ def _sort_by_distance_from_atom( return [i[1] for i in sorted(distances)] -def _get_bonded_angles_from_pool( +def _bonded_angles_from_pool( rdmol: Chem.Mol, atom_idx: int, atom_pool: list[int] ) -> list[tuple[int, int, int]]: """ @@ -300,7 +193,7 @@ def _get_atom_pool( return atom_pool -def get_guest_atom_candidates( +def find_guest_atom_candidates( topology: Union[str, pathlib.Path, openmm.app.Topology], trajectory: Union[str, pathlib.Path], rdmol: Chem.Mol, @@ -370,7 +263,7 @@ def get_guest_atom_candidates( # 4. Get a list of probable angles angles_list = [] for atom in sorted_atom_pool: - angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_atom_pool) + angles = _bonded_angles_from_pool(rdmol, atom, sorted_atom_pool) for angle in angles: # Check that the angle is at least not collinear angle_ag = ligand_ag.atoms[list(angle)] @@ -386,7 +279,7 @@ def get_guest_atom_candidates( return angles_list -def get_host_atom_candidates( +def find_host_atom_candidates( topology: Union[str, pathlib.Path, openmm.app.Topology], trajectory: Union[str, pathlib.Path], host_idxs: list[int], @@ -459,7 +352,7 @@ def get_host_atom_candidates( class EvaluateHostAtoms1(AnalysisBase): """ Class to evaluate the suitability of a set of host atoms - as H1 atoms (i.e. the second host atom). + as either H0 or H1 atoms (i.e. the first and second host atoms). Parameters ---------- @@ -474,7 +367,6 @@ class EvaluateHostAtoms1(AnalysisBase): temperature : unit.Quantity The system temperature in Kelvin """ - def __init__( self, reference, @@ -582,6 +474,23 @@ def _conclude(self): class EvaluateHostAtoms2(EvaluateHostAtoms1): + """ + Class to evaluate the suitability of a set of host atoms + as H2 atoms (i.e. the third host atoms). + + Parameters + ---------- + reference : MDAnalysis.AtomGroup + The reference preceeding three atoms. + host_atom_pool : MDAnalysis.AtomGroup + The pool of atoms to pick an atom from. + minimum_distance : unit.Quantity + The minimum distance from the bound reference atom. + angle_force_constant : unit.Quantity + The force constant for the angle. + temperature : unit.Quantity + The system temperature in Kelvin + """ def _prepare(self): self.results.distances1 = np.zeros((len(self.host_atom_pool), self.n_frames)) self.results.ditances2 = np.zeros((len(self.host_atom_pool), self.n_frames)) @@ -648,15 +557,36 @@ def _conclude(self): self.results.valid[i] = True -def _find_host_angle( - g0g1g2_atoms, - host_atom_pool, - minimum_distance, - angle_force_constant, - temperature -): +def _find_host_anchor( + guest_atoms: mda.AtomGroup, + host_atom_pool: mda.AtomGroup, + minimum_distance: unit.Quantity, + angle_force_constant: unit.Quantity, + temperature: unit.Quantity +) -> Optional[list[int]]: + """ + Find suitable atoms for the H0-H1-H2 portion of the restraint. + + Parameters + ---------- + guest_atoms : mda.AtomGroup + The guest anchor atoms for G0-G1-G2 + host_atom_pool : mda.AtomGroup + The host atoms to search from. + minimum_distance : unit.Quantity + The minimum distance to pick host atoms from each other. + angle_force_constant : unit.Quantity + The force constant for the G1-G0-H0 and G0-H0-H1 angles. + temperature : unit.Quantity + The target system temperature. + + Returns + ------- + Optional[list[int]] + A list of indices for a selected combination of H0, H1, and H2. + """ h0_eval = EvaluateHostAtoms1( - g0g1g2_atoms, + guest_atoms, host_atom_pool, minimum_distance, angle_force_constant, @@ -666,7 +596,7 @@ def _find_host_angle( for i, valid_h0 in enumerate(h0_eval.results.valid): if valid_h0: - g1g2h0_atoms = g0g1g2_atoms.atoms[1:] + host_atom_pool.atoms[i] + g1g2h0_atoms = guest_atoms.atoms[1:] + host_atom_pool.atoms[i] h1_eval = EvaluateHostAtoms1( g1g2h0_atoms, host_atom_pool, @@ -690,7 +620,7 @@ def _find_host_angle( dsum_avgs = d1_avgs + d2_avgs k = dsum_avgs.argmin() - return host_atom_pool.atoms[[i, j, k]].ix + return list(host_atom_pool.atoms[[i, j, k]].ix) return None @@ -839,24 +769,24 @@ def find_boresch_restraint( # In this case assume the picked atoms were intentional / # representative of the input and go with it guest_ag = u.select_atoms[guest_idxs] - guest_angle = [ + guest_anchor = [ at.ix for at in guest_ag.atoms[guest_restraint_atoms_idxs] ] host_ag = u.select_atoms[host_idxs] - host_angle = [ + host_anchor = [ at.ix for at in host_ag.atoms[host_restraint_atoms_idxs] ] # Set the equilibrium values as those of the final frame u.trajectory[-1] - atomgroup = u.atoms[host_angle + guest_angle] + atomgroup = u.atoms[host_anchor + guest_anchor] bond, ang1, ang2, dih1, dih2, dih3 = _get_restraint_distances( atomgroup ) return BoreschRestraintGeometry( - host_atoms=host_angle, - guest_atoms=guest_angle, + host_atoms=host_anchor, + guest_atoms=guest_anchor, r_aA0=bond, theta_A0=ang1, theta_B0=ang2, @@ -875,8 +805,8 @@ def find_boresch_restraint( ) raise ValueError(errmsg) - # 1. Fetch the guest angles - guest_angles = get_guest_atom_candidates( + # 1. Fetch the guest anchors + guest_anchors = find_guest_atom_candidates( topology=topology, trajectory=trajectory, rdmol=guest_rdmol, @@ -884,53 +814,50 @@ def find_boresch_restraint( rmsf_cutoff=rmsf_cutoff, ) - if len(guest_angles) != 0: + if len(guest_anchors) != 0: errmsg = "No suitable ligand atoms found for the restraint." raise ValueError(errmsg) - # We pick the first angle / ligand atom set as the one to use - guest_angle = guest_angles[0] - - # 2. We next fetch the host atom pool - host_pool = get_host_atom_candidates( - topology=topology, - trajectory=trajectory, - host_idxs=host_idxs, - l1_idx=guest_angle[0], - host_selection=host_selection, - dssp_filter=dssp_filter, - rmsf_cutoff=rmsf_cutoff, - min_distance=host_min_distance, - max_distance=host_max_distance, - ) + # 2. We then loop through the guest anchors to find suitable host atoms + for guest_anchor in guest_anchors: + # We next fetch the host atom pool + host_pool = find_host_atom_candidates( + topology=topology, + trajectory=trajectory, + host_idxs=host_idxs, + l1_idx=guest_anchor, + host_selection=host_selection, + dssp_filter=dssp_filter, + rmsf_cutoff=rmsf_cutoff, + min_distance=host_min_distance, + max_distance=host_max_distance, + ) - # 3. We then loop through the guest angles to find suitable host atoms - for guest_angle in guest_angles: - host_angle = _find_host_angle( - g0g1g2_atoms=u.atoms[list(guest_angle)], + host_anchor = _find_host_anchor( + guest_atoms=u.atoms[list(guest_anchor)], host_atom_pool=u.atoms[host_pool], minimum_distance=0.5 * unit.nanometer, angle_force_constant=angle_force_constant, temperature=temperature, ) # continue if it's empty, otherwise stop - if host_angle is not None: + if host_anchor is not None: break - if host_angle is None: + if host_anchor is None: errmsg = "No suitable host atoms could be found" raise ValueError(errmsg) # Set the equilibrium values as those of the final frame u.trajectory[-1] - atomgroup = u.atoms[host_angle + guest_angle] + atomgroup = u.atoms[host_anchor + guest_anchor] bond, ang1, ang2, dih1, dih2, dih3 = _get_restraint_distances( atomgroup ) return BoreschRestraintGeometry( - host_atoms=host_angle, - guest_atoms=guest_angle, + host_atoms=host_anchor, + guest_atoms=guest_anchor, r_aA0=bond, theta_A0=ang1, theta_B0=ang2, diff --git a/openfe/protocols/restraint_utils/geometry/harmonic.py b/openfe/protocols/restraint_utils/geometry/harmonic.py index 197a8bc44..81e2f22b2 100644 --- a/openfe/protocols/restraint_utils/geometry/harmonic.py +++ b/openfe/protocols/restraint_utils/geometry/harmonic.py @@ -10,9 +10,7 @@ import pathlib from typing import Union, Optional from openmm import app -from openff.units import unit import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds from rdkit import Chem from .base import HostGuestRestraintGeometry @@ -28,30 +26,6 @@ class DistanceRestraintGeometry(HostGuestRestraintGeometry): A geometry class for a distance restraint between two groups of atoms. """ - def get_distance(self, universe: mda.Universe) -> unit.Quantity: - """ - Get the center of mass distance between the host and guest atoms. - - Parameters - ---------- - universe : mda.Universe - A Universe representing the system of interest. - - Returns - ------- - bond : unit.Quantity - The center of mass distance between the two groups of atoms. - """ - ag1 = universe.atoms[self.host_atoms] - ag2 = universe.atoms[self.guest_atoms] - bond = calc_bonds( - ag1.center_of_mass(), - ag2.center_of_mass(), - box=universe.atoms.dimensions - ) - # convert to float so we avoid having a np.float64 - return float(bond) * unit.angstrom - def get_distance_restraint( topology: Union[str, pathlib.Path, app.Topology], From 4dd16afd3b3a84a0a98c01d67ffc21b497099f1f Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 16 Dec 2024 13:55:48 +0000 Subject: [PATCH 27/33] Add restraint settings --- .../restraint_utils/openmm/omm_restraints.py | 10 +- openfe/protocols/restraint_utils/settings.py | 116 +++++++++++++++++- 2 files changed, 117 insertions(+), 9 deletions(-) diff --git a/openfe/protocols/restraint_utils/openmm/omm_restraints.py b/openfe/protocols/restraint_utils/openmm/omm_restraints.py index c77b1cd0b..0eaaa9585 100644 --- a/openfe/protocols/restraint_utils/openmm/omm_restraints.py +++ b/openfe/protocols/restraint_utils/openmm/omm_restraints.py @@ -288,7 +288,7 @@ class HarmonicBondRestraint( Notes ----- * Settings must contain a ``spring_constant`` for the - Force in units compatible with kilojoule/mole. + Force in units compatible with kilojoule/mole/nm**2. """ def _get_force( self, @@ -335,7 +335,7 @@ class FlatBottomBondRestraint( Notes ----- * Settings must contain a ``spring_constant`` for the - Force in units compatible with kilojoule/mole. + Force in units compatible with kilojoule/mole/nm**2. """ def _get_force( self, @@ -384,7 +384,7 @@ class CentroidHarmonicRestraint(BaseRadiallySymmetricRestraintForce): Notes ----- * Settings must contain a ``spring_constant`` for the - Force in units compatible with kilojoule/mole. + Force in units compatible with kilojoule/mole/nm**2. """ def _get_force( self, @@ -429,7 +429,7 @@ class CentroidFlatBottomRestraint(BaseRadiallySymmetricRestraintForce): Notes ----- * Settings must contain a ``spring_constant`` for the - Force in units compatible with kilojoule/mole. + Force in units compatible with kilojoule/mole/nm**2. """ def _get_force( self, @@ -510,7 +510,7 @@ class BoreschRestraint(BaseHostGuestRestraints): (p2, p3, p4) and (p3, p4, p5). They must be provided by the Geometry class in units compatible with radians. - ``phi_A0``, ``phi_B0``, and ``phi_C0`` are the equilibrium constants + ``phi_A0``, ``phi_B0``, and ``phi_C0`` are the equilibrium force constants for the dihedrals formed by (p1, p2, p3, p4), (p2, p3, p4, p5), and (p3, p4, p5, p6). They must be provided in the settings in units compatible with kilojoule / mole / radians ** 2. diff --git a/openfe/protocols/restraint_utils/settings.py b/openfe/protocols/restraint_utils/settings.py index 0c12aef17..66998debb 100644 --- a/openfe/protocols/restraint_utils/settings.py +++ b/openfe/protocols/restraint_utils/settings.py @@ -2,22 +2,130 @@ # For details, see https://github.com/OpenFreeEnergy/openfe """ Settings for adding restraints. + +TODO +---- +* Rename from host/guest to molA/molB? """ from typing import Optional, Literal from openff.units import unit from openff.models.types import FloatQuantity, ArrayQuantity - +from pydantic.v1 import validator from gufe.settings import ( SettingsBaseModel, ) -from pydantic.v1 import validator - - class BaseRestraintSettings(SettingsBaseModel): """ Base class for RestraintSettings objects. """ class Config: arbitrary_types_allowed = True + + +class DistanceRestraintSettings(BaseRestraintSettings): + """ + Settings defining a distance restraint between + two groups of atoms defined as ``host`` and ``guest``. + """ + spring_constant: FloatQuantity['kilojoule_per_mole / nm ** 2'] + """ + The distance restraint potential spring constant. + """ + host_atoms: Optional[list[int]] = None + """ + The indices of the host component atoms to restrain. + If defined, these will override any automatic selection. + """ + guest_atoms: Optional[list[int]] = None + """ + The indices of the guest component atoms to restraint. + If defined, these will override any automatic selection. + """ + central_atoms_only: bool = False + """ + Whether to apply the restraint solely to the central atoms + of each group. + + Note: this can only be applied if ``host`` and ``guest`` + represent small molecules. + """ + + +class FlatBottomRestraintSettings(DistanceRestraintSettings): + """ + Settings to define a flat bottom restraint between two + groups of atoms named ``host`` and ``guest``. + """ + well_radius: Optional[FloatQuantity['nm']] = None + """ + The distance at which the harmonic restraint is imposed + in units of distance. + """ + + +class BoreschRestraintSettings(BaseRestraintSettings): + """ + Settings to define a Boresch-style restraint between + two groups of atoms named ``host`` and ``guest``. + + The restraint is defined in the following manner: + + H2 G2 + - - + - - + H1 - - H0 -- G0 - - G1 + + Where HX represents the X index of ``host_atoms`` + and GX the X indexx of ``guest_atoms``. + + By default, the Boresch-like restraint will be + obtained using a modified version of the + search algorithm implemented by Baumann et al. [1]. + + If ``guest_atoms`` and ``host_atoms`` are defined, + these indices will be used instead. + + References + ---------- + [1] Baumann, Hannah M., et al. "Broadening the scope of binding free + energy calculations using a Separated Topologies approach." (2023). + """ + K_r: FloatQuantity['kilojoule_per_mole / nm ** 2'] + """ + The bond spring constant between H0 and G0. + """ + K_thetaA: FloatQuantity['kilojoule_per_mole / radians ** 2'] + """ + The spring constant for the angle formed by H1-H0-G0. + """ + K_thetaB: FloatQuantity['kilojoule_per_mole / radians ** 2'] + """ + The spring constant for the angle formed by H0-G0-G1. + """ + phi_A0: FloatQuantity['kilojoule_per_mole / radians ** 2'] + """ + The equilibrium force constant for the dihedral formed by + H2-H1-H0-G0. + """ + phi_B0: FloatQuantity['kilojoule_per_mole / radians ** 2'] + """ + The equilibrium force constant for the dihedral formed by + H1-H0-G0-G1. + """ + phi_C0: FloatQuantity['kilojoule_per_mole / radians ** 2'] + """ + The equilibrium force constant for the dihedral formed by + H0-G0-G1-G2. + """ + host_atoms: Optional[list[int]] = None + """ + The indices of the host component atoms to restrain. + If defined, these will override any automatic selection. + """ + guest_atoms: Optional[list[int]] = None + """ + The indices of the guest component atoms to restraint. + If defined, these will override any automatic selection. + """ \ No newline at end of file From ec0d01de87de97dd88feb4add9e360c8bc901c7c Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 16 Dec 2024 13:57:01 +0000 Subject: [PATCH 28/33] Add missing settings imports --- openfe/protocols/restraint_utils/openmm/omm_restraints.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/openfe/protocols/restraint_utils/openmm/omm_restraints.py b/openfe/protocols/restraint_utils/openmm/omm_restraints.py index 0eaaa9585..8df9cb837 100644 --- a/openfe/protocols/restraint_utils/openmm/omm_restraints.py +++ b/openfe/protocols/restraint_utils/openmm/omm_restraints.py @@ -36,6 +36,12 @@ DistanceRestraintGeometry, BoreschRestraintGeometry ) + +from openfe.protocols.restraint_utils.settings import ( + DistanceRestraintSettings, + BoreschRestraintSettings, +) + from .omm_forces import ( get_custom_compound_bond_force, add_force_in_separate_group, From e19db658d11f53722a32d936e4e57e8d18cb455a Mon Sep 17 00:00:00 2001 From: IAlibay Date: Tue, 17 Dec 2024 11:31:16 +0000 Subject: [PATCH 29/33] Settings and some tests for them --- .../restraint_utils/geometry/base.py | 2 +- openfe/protocols/restraint_utils/settings.py | 22 +++++- .../protocols/restraints/test_settings.py | 77 +++++++++++++++++++ 3 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 openfe/tests/protocols/restraints/test_settings.py diff --git a/openfe/protocols/restraint_utils/geometry/base.py b/openfe/protocols/restraint_utils/geometry/base.py index 0ca6ae200..798befd45 100644 --- a/openfe/protocols/restraint_utils/geometry/base.py +++ b/openfe/protocols/restraint_utils/geometry/base.py @@ -42,7 +42,7 @@ class HostGuestRestraintGeometry(BaseRestraintGeometry): @validator("guest_atoms", "host_atoms") def positive_idxs(cls, v): - if any([i < 0 for i in v]): + if v is not None and any([i < 0 for i in v]): errmsg = "negative indices passed" raise ValueError(errmsg) return v diff --git a/openfe/protocols/restraint_utils/settings.py b/openfe/protocols/restraint_utils/settings.py index 66998debb..0a06714c0 100644 --- a/openfe/protocols/restraint_utils/settings.py +++ b/openfe/protocols/restraint_utils/settings.py @@ -52,6 +52,13 @@ class DistanceRestraintSettings(BaseRestraintSettings): represent small molecules. """ + @validator("guest_atoms", "host_atoms") + def positive_idxs(cls, v): + if v is not None and any([i < 0 for i in v]): + errmsg = "negative indices passed" + raise ValueError(errmsg) + return v + class FlatBottomRestraintSettings(DistanceRestraintSettings): """ @@ -63,6 +70,12 @@ class FlatBottomRestraintSettings(DistanceRestraintSettings): The distance at which the harmonic restraint is imposed in units of distance. """ + @validator("well_radius") + def positive_value(cls, v): + if v is not None and v.m < 0: + errmsg = f"well radius cannot be negative {v}" + raise ValueError(errmsg) + return v class BoreschRestraintSettings(BaseRestraintSettings): @@ -128,4 +141,11 @@ class BoreschRestraintSettings(BaseRestraintSettings): """ The indices of the guest component atoms to restraint. If defined, these will override any automatic selection. - """ \ No newline at end of file + """ + + @validator("guest_atoms", "host_atoms") + def positive_idxs_list(cls, v): + if v is not None and any([i < 0 for i in v]): + errmsg = "negative indices passed" + raise ValueError(errmsg) + return v \ No newline at end of file diff --git a/openfe/tests/protocols/restraints/test_settings.py b/openfe/tests/protocols/restraints/test_settings.py new file mode 100644 index 000000000..49ed0dca0 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_settings.py @@ -0,0 +1,77 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Test the restraint settings. +""" +import pytest +import numpy as np +import openmm +from openff.units import unit +from openfe.protocols.restraint_utils.settings import ( + DistanceRestraintSettings, + FlatBottomRestraintSettings, + BoreschRestraintSettings, +) + + +def test_distance_restraint_settings_default(): + """ + Basic settings regression test + """ + settings = DistanceRestraintSettings( + spring_constant=10 * unit.kilojoule_per_mole / unit.nm ** 2, + ) + assert settings.central_atoms_only is False + assert isinstance(settings, DistanceRestraintSettings) + + +def test_distance_restraint_negative_idxs(): + """ + Check that an error is raised if you have negative + atom indices in host atoms. + """ + with pytest.raises(ValueError, match="negative indices passed"): + _ = DistanceRestraintSettings( + spring_constant=10 * unit.kilojoule_per_mole / unit.nm ** 2, + host_atoms=[-1, 0, 2], + guest_atoms=[0, 1, 2], + ) + + +def test_flatbottom_restraint_settings_default(): + """ + Basic settings regression test + """ + settings = FlatBottomRestraintSettings( + spring_constant=10 * unit.kilojoule_per_mole / unit.nm ** 2, + well_radius=1*unit.nanometer, + ) + assert isinstance(settings, FlatBottomRestraintSettings) + + +def test_flatbottom_restraint_negative_well(): + """ + Check that an error is raised if you have a negative + well radius. + """ + with pytest.raises(ValueError, match="negative indices passed"): + _ = DistanceRestraintSettings( + spring_constant=10 * unit.kilojoule_per_mole / unit.nm ** 2, + host_atoms=[-1, 0, 2], + guest_atoms=[0, 1, 2], + ) + + +def test_boresch_restraint_settings_default(): + """ + Basic settings regression test + """ + settings = BoreschRestraintSettings( + K_r=10 * unit.kilojoule_per_mole / unit.nm ** 2, + K_thetaA=10 * unit.kilojoule_per_mole / unit.radians ** 2, + K_thetaB=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_A0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_B0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_C0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + ) + assert isinstance(settings, BoreschRestraintSettings) From 854d1c6e0255ba30af6388f6322da3ce9ee4d90f Mon Sep 17 00:00:00 2001 From: IAlibay Date: Tue, 17 Dec 2024 11:40:38 +0000 Subject: [PATCH 30/33] negative idxs test --- .../protocols/restraints/test_settings.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/openfe/tests/protocols/restraints/test_settings.py b/openfe/tests/protocols/restraints/test_settings.py index 49ed0dca0..82660cfaf 100644 --- a/openfe/tests/protocols/restraints/test_settings.py +++ b/openfe/tests/protocols/restraints/test_settings.py @@ -75,3 +75,21 @@ def test_boresch_restraint_settings_default(): phi_C0=10 * unit.kilojoule_per_mole / unit.radians ** 2, ) assert isinstance(settings, BoreschRestraintSettings) + + +def test_boresch_restraint_negative_idxs(): + """ + Check that the positive_idxs_list validator is + working as expected. + """ + with pytest.raises(ValueError, match='negative indices'): + settings = BoreschRestraintSettings( + K_r=10 * unit.kilojoule_per_mole / unit.nm ** 2, + K_thetaA=10 * unit.kilojoule_per_mole / unit.radians ** 2, + K_thetaB=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_A0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_B0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_C0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + host_atoms=[-1, 0], + guest_atoms=[0, 1], + ) From 9b53cd28043537ede777c169999844b628a9e490 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Tue, 17 Dec 2024 11:56:13 +0000 Subject: [PATCH 31/33] addressing some mypy issues --- .../restraint_utils/geometry/boresch.py | 21 +++++++++++-------- .../restraint_utils/geometry/flatbottom.py | 12 ++++++++++- .../restraint_utils/geometry/utils.py | 4 ++-- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/openfe/protocols/restraint_utils/geometry/boresch.py b/openfe/protocols/restraint_utils/geometry/boresch.py index 2b0cd2313..8a4eaaceb 100644 --- a/openfe/protocols/restraint_utils/geometry/boresch.py +++ b/openfe/protocols/restraint_utils/geometry/boresch.py @@ -168,6 +168,9 @@ def _get_atom_pool( The RDKit Molecule to search through rmsf : npt.NDArray A 1-D array of RMSF values for each atom. + rmsf_cutoff : unit.Quantity + The rmsf cutoff value for selecting atoms in units compatible with + nanometer. Returns ------- @@ -177,7 +180,7 @@ def _get_atom_pool( # Note: no need to keep track of rings because we'll filter by # bonded terms after, so if we only keep rings then all the bonded # atoms should be within the same ring system. - atom_pool = set() + atom_pool: set[tuple[int]] = set() for ring in get_aromatic_rings(rdmol): max_rmsf = rmsf[list(ring)].max() if max_rmsf < rmsf_cutoff: @@ -195,7 +198,7 @@ def _get_atom_pool( def find_guest_atom_candidates( topology: Union[str, pathlib.Path, openmm.app.Topology], - trajectory: Union[str, pathlib.Path], + trajectory: Union[str, pathlib.Path, npt.NDArray], rdmol: Chem.Mol, guest_idxs: list[int], rmsf_cutoff: unit.Quantity = 1 * unit.nanometer, @@ -208,7 +211,7 @@ def find_guest_atom_candidates( ---------- topology : Union[str, openmm.app.Topology] The topology of the system. - trajectory : Union[str, pathlib.Path] + trajectory : Union[str, pathlib.Path, npt.NDArray] A path to the system's coordinate trajectory. rdmol : Chem.Mol An RDKit Molecule representing the small molecule ordered in @@ -247,7 +250,7 @@ def find_guest_atom_candidates( u.trajectory[-1] # forward to the last frame # 1. Get the pool of atoms to work with - atom_pool = _get_atom_pool(rdmol, rmsf) + atom_pool = _get_atom_pool(rdmol, rmsf, rmsf_cutoff) if atom_pool is None: # We don't have enough atoms so we raise an error @@ -281,7 +284,7 @@ def find_guest_atom_candidates( def find_host_atom_candidates( topology: Union[str, pathlib.Path, openmm.app.Topology], - trajectory: Union[str, pathlib.Path], + trajectory: Union[str, pathlib.Path, npt.NDArray], host_idxs: list[int], l1_idx: int, host_selection: str, @@ -297,8 +300,8 @@ def find_host_atom_candidates( ---------- topology : Union[str, openmm.app.Topology] The topology of the system. - trajectory : Union[str, pathlib.Path] - A path to the system's coordinate trajectory. + trajectory : Union[str, pathlib.Path, npt.NDArray] + The system's coordinate trajectory. host_idxs : list[int] A list of the host indices in the system topology. l1_idx : int @@ -615,8 +618,8 @@ def _find_host_anchor( ) if any(h2_eval.ressults.valid): - d1_avgs = [d.mean() for d in h2_eval.results.distances1] - d2_avgs = [d.mean() for d in h2_eval.results.distances2] + d1_avgs = np.array([d.mean() for d in h2_eval.results.distances1]) + d2_avgs = np.array([d.mean() for d in h2_eval.results.distances2]) dsum_avgs = d1_avgs + d2_avgs k = dsum_avgs.argmin() diff --git a/openfe/protocols/restraint_utils/geometry/flatbottom.py b/openfe/protocols/restraint_utils/geometry/flatbottom.py index 3b4599f56..c5e975401 100644 --- a/openfe/protocols/restraint_utils/geometry/flatbottom.py +++ b/openfe/protocols/restraint_utils/geometry/flatbottom.py @@ -117,11 +117,21 @@ def get_flatbottom_distance_restraint( guest_ag = _get_mda_selection(u, guest_atoms, guest_selection) host_ag = _get_mda_selection(u, host_atoms, host_selection) + guest_idxs = [a.ix for a in guest_ag] + host_idxs = [a.ix for a in host_ag] + + if len(host_idxs) == 0 or len(guest_idxs) == 0: + errmsg = ( + "no atoms found in either the host or guest atom groups" + f"host_atoms: {host_idxs}" + f"guest_atoms: {guest_idxs}" + ) + raise ValueError(errmsg) com_dists = COMDistanceAnalysis(guest_ag, host_ag) com_dists.run() well_radius = com_dists.results.distances.max() * unit.angstrom + padding return FlatBottomDistanceGeometry( - guest_atoms=guest_atoms, host_atoms=host_atoms, well_radius=well_radius + guest_atoms=guest_idxs, host_atoms=host_idxs, well_radius=well_radius ) diff --git a/openfe/protocols/restraint_utils/geometry/utils.py b/openfe/protocols/restraint_utils/geometry/utils.py index 4b734b410..988a31cfe 100644 --- a/openfe/protocols/restraint_utils/geometry/utils.py +++ b/openfe/protocols/restraint_utils/geometry/utils.py @@ -76,7 +76,7 @@ def _get_mda_selection( def _get_mda_coord_format( - coordinates: Union[str, npt.NDArray] + coordinates: Union[str, pathlib.Path, npt.NDArray] ) -> Optional[MemoryReader]: """ Helper to set the coordinate format to MemoryReader @@ -84,7 +84,7 @@ def _get_mda_coord_format( Parameters ---------- - coordinates : Union[str, npt.NDArray] + coordinates : Union[str, pathlib.Path, npt.NDArray] Returns ------- From de35e3f9445798e8d38f8c574e9e603f3050ead8 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 20 Dec 2024 02:10:28 +0000 Subject: [PATCH 32/33] Addressing reviews --- .../restraint_utils/geometry/boresch.py | 61 +++++++++++++++---- .../restraint_utils/geometry/harmonic.py | 4 +- .../restraint_utils/geometry/utils.py | 11 +++- .../restraint_utils/openmm/omm_restraints.py | 7 ++- 4 files changed, 65 insertions(+), 18 deletions(-) diff --git a/openfe/protocols/restraint_utils/geometry/boresch.py b/openfe/protocols/restraint_utils/geometry/boresch.py index 8a4eaaceb..4fb4b45b9 100644 --- a/openfe/protocols/restraint_utils/geometry/boresch.py +++ b/openfe/protocols/restraint_utils/geometry/boresch.py @@ -114,7 +114,10 @@ def _sort_by_distance_from_atom( def _bonded_angles_from_pool( - rdmol: Chem.Mol, atom_idx: int, atom_pool: list[int] + rdmol: Chem.Mol, + atom_idx: int, + atom_pool: list[int], + aromatic_only: bool, ) -> list[tuple[int, int, int]]: """ Get all bonded angles starting from ``atom_idx`` from a pool of atoms. @@ -127,11 +130,19 @@ def _bonded_angles_from_pool( The index of the atom to search angles from. atom_pool : list[int] The list of indices to pick possible angle partners from. + aromatic_only : bool + Prune any angles that include non-aromatic bonds. Returns ------- list[tuple[int, int, int]] A list of tuples containing all the angles. + + Notes + ----- + * In the original SepTop code at3 is picked as directly bonded to at1. + By comparison here we instead follow the case that at3 is bonded to + at2 but not bonded to at1. """ angles = [] @@ -150,14 +161,28 @@ def _bonded_angles_from_pool( for at3 in atom_pool: if at3 != atom_idx and at3 in at2_neighbors: angles.append((atom_idx, at2, at3)) + + if aromatic_only: + aromatic_rings = get_aromatic_rings(rdmol) + + def _belongs_to_ring(angle, aromatic_rings): + for ring in aromatic_rings: + if all(a in ring for a in angle): + return True + return False + + for angle in angles: + if not _belongs_to_ring(angle, aromatic_rings): + angles.remove(angle) + return angles -def _get_atom_pool( +def _get_guest_atom_pool( rdmol: Chem.Mol, rmsf: npt.NDArray, rmsf_cutoff: unit.Quantity -) -> Optional[set[int]]: +) -> tuple[Optional[set[int]], bool]: """ Filter atoms based on rmsf & rings, defaulting to heavy atoms if there are not enough. @@ -175,12 +200,16 @@ def _get_atom_pool( Returns ------- atom_pool : Optional[set[int]] + A pool of candidate atoms. + ring_atoms_only : bool + True if only ring atoms were selected. """ # Get a list of all the aromatic rings # Note: no need to keep track of rings because we'll filter by # bonded terms after, so if we only keep rings then all the bonded # atoms should be within the same ring system. atom_pool: set[tuple[int]] = set() + ring_atoms_only: bool = True for ring in get_aromatic_rings(rdmol): max_rmsf = rmsf[list(ring)].max() if max_rmsf < rmsf_cutoff: @@ -188,12 +217,13 @@ def _get_atom_pool( # if we don't have enough atoms just get all the heavy atoms if len(atom_pool) < 3: + ring_atoms_only = False heavy_atoms = get_heavy_atom_idxs(rdmol) atom_pool = set(heavy_atoms[rmsf[heavy_atoms] < rmsf_cutoff]) if len(atom_pool) < 3: - return None + return None, False - return atom_pool + return atom_pool, ring_atoms_only def find_guest_atom_candidates( @@ -250,7 +280,7 @@ def find_guest_atom_candidates( u.trajectory[-1] # forward to the last frame # 1. Get the pool of atoms to work with - atom_pool = _get_atom_pool(rdmol, rmsf, rmsf_cutoff) + atom_pool, rings_only = _get_guest_atom_pool(rdmol, rmsf, rmsf_cutoff) if atom_pool is None: # We don't have enough atoms so we raise an error @@ -266,7 +296,12 @@ def find_guest_atom_candidates( # 4. Get a list of probable angles angles_list = [] for atom in sorted_atom_pool: - angles = _bonded_angles_from_pool(rdmol, atom, sorted_atom_pool) + angles = _bonded_angles_from_pool( + rdmol=rdmol, + atom_idx=atom, + atom_pool=sorted_atom_pool, + aromatic_only=rings_only, + ) for angle in angles: # Check that the angle is at least not collinear angle_ag = ligand_ag.atoms[list(angle)] @@ -342,11 +377,11 @@ def find_host_atom_candidates( # 1. Get the RMSF & filter rmsf = get_local_rmsf(host_ag2) - protein_ag3 = host_ag2.atoms[rmsf < rmsf_cutoff] + host_ag3 = host_ag2.atoms[rmsf < rmsf_cutoff] # 2. Search of atoms within the min/max cutoff atom_finder = FindHostAtoms( - protein_ag3, u.atoms[l1_idx], min_distance, max_distance + host_ag3, u.atoms[l1_idx], min_distance, max_distance ) atom_finder.run() return atom_finder.results.host_idxs @@ -599,9 +634,9 @@ def _find_host_anchor( for i, valid_h0 in enumerate(h0_eval.results.valid): if valid_h0: - g1g2h0_atoms = guest_atoms.atoms[1:] + host_atom_pool.atoms[i] + g1g0h0_atoms = guest_atoms.atoms[:2] + host_atom_pool.atoms[i] h1_eval = EvaluateHostAtoms1( - g1g2h0_atoms, + g1g0h0_atoms, host_atom_pool, minimum_distance, angle_force_constant, @@ -617,7 +652,7 @@ def _find_host_anchor( temperature, ) - if any(h2_eval.ressults.valid): + if any(h2_eval.results.valid): d1_avgs = np.array([d.mean() for d in h2_eval.results.distances1]) d2_avgs = np.array([d.mean() for d in h2_eval.results.distances2]) dsum_avgs = d1_avgs + d2_avgs @@ -828,7 +863,7 @@ def find_boresch_restraint( topology=topology, trajectory=trajectory, host_idxs=host_idxs, - l1_idx=guest_anchor, + l1_idx=guest_anchor[0], host_selection=host_selection, dssp_filter=dssp_filter, rmsf_cutoff=rmsf_cutoff, diff --git a/openfe/protocols/restraint_utils/geometry/harmonic.py b/openfe/protocols/restraint_utils/geometry/harmonic.py index 81e2f22b2..1cc8fb0e0 100644 --- a/openfe/protocols/restraint_utils/geometry/harmonic.py +++ b/openfe/protocols/restraint_utils/geometry/harmonic.py @@ -96,12 +96,12 @@ def get_molecule_centers_restraint( molA_rdmol : Chem.Mol An RDKit Molecule for the first molecule. molB_rdmol : Chem.Mol - An RDKit Molecule for the first molecule. + An RDKit Molecule for the second molecule. molA_idxs : list[int] The indices of the first molecule in the system. Note we assume these to be sorted in the same order as the input rdmol. molB_idxs : list[int] - The indices of the first molecule in the system. Note we assume these + The indices of the second molecule in the system. Note we assume these to be sorted in the same order as the input rdmol. Returns diff --git a/openfe/protocols/restraint_utils/geometry/utils.py b/openfe/protocols/restraint_utils/geometry/utils.py index 988a31cfe..d27e3156e 100644 --- a/openfe/protocols/restraint_utils/geometry/utils.py +++ b/openfe/protocols/restraint_utils/geometry/utils.py @@ -8,6 +8,7 @@ * Add relevant duecredit entries. """ from typing import Union, Optional +from itertools import combinations import numpy as np import numpy.typing as npt from scipy.stats import circvar @@ -134,14 +135,22 @@ def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: list[tuple[int]] List of tuples for each ring. """ + ringinfo = rdmol.GetRingInfo() arom_idxs = get_aromatic_atom_idxs(rdmol) aromatic_rings = [] + # Add to the aromatic_rings list if all the atoms in a ring are aromatic for ring in ringinfo.AtomRings(): if all(a in arom_idxs for a in ring): - aromatic_rings.append(ring) + aromatic_rings.append(set(ring)) + + # Reduce the ring list by merging any rings that have colliding atoms + for x, y in combinations(aromatic_rings, 2): + if not x.isdisjoint(y): + x.update(y) + aromatic_rings.remove(y) return aromatic_rings diff --git a/openfe/protocols/restraint_utils/openmm/omm_restraints.py b/openfe/protocols/restraint_utils/openmm/omm_restraints.py index 8df9cb837..352bbd9bd 100644 --- a/openfe/protocols/restraint_utils/openmm/omm_restraints.py +++ b/openfe/protocols/restraint_utils/openmm/omm_restraints.py @@ -62,13 +62,16 @@ class RestraintParameterState(GlobalParameterState): ``lambda_restraints_{parameters_name_suffix}` instead of just ``lambda_restraints``. lambda_restraints : Optional[float] - The strength of the restraint. If defined, must be between 0 and 1. + The scaling parameter for the restraint. If defined, + must be between 0 and 1. In most cases, a value of 1 indicates that the + restraint is fully turned on, whilst a value of 0 indicates that it is + innactive. Acknowledgement --------------- Partially reproduced from Yank. """ - + # We set the standard system to a fully interacting restraint lambda_restraints = GlobalParameterState.GlobalParameter( "lambda_restraints", standard_value=1.0 ) From 7c502c76e03b0f5977636a94ee39e8d16f3789ec Mon Sep 17 00:00:00 2001 From: IAlibay Date: Tue, 7 Jan 2025 17:54:55 +0000 Subject: [PATCH 33/33] Add a todo from last year --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 2b80370e0..b74bd3342 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -692,7 +692,7 @@ def run(self, *, dry=False, verbose=True, # Extract relevant settings protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['protocol'].settings stateA = self._inputs['stateA'] - stateB = self._inputs['stateB'] + stateB = self._inputs['stateB'] # TODO: open an issue about this not being used. mapping = self._inputs['ligandmapping'] forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings