diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 7bf92050e..4e3cccb08 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -702,7 +702,7 @@ def run(self, *, dry=False, verbose=True, # Extract relevant settings protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['protocol'].settings stateA = self._inputs['stateA'] - stateB = self._inputs['stateB'] + stateB = self._inputs['stateB'] # TODO: open an issue about this not being used. mapping = self._inputs['ligandmapping'] forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings diff --git a/openfe/protocols/restraint_utils/__init__.py b/openfe/protocols/restraint_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/protocols/restraint_utils/geometry/__init__.py b/openfe/protocols/restraint_utils/geometry/__init__.py new file mode 100644 index 000000000..1c1b4c56a --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/__init__.py @@ -0,0 +1,4 @@ +from .base import BaseRestraintGeometry +from .harmonic import DistanceRestraintGeometry +from .flatbottom import FlatBottomDistanceGeometry +from .boresch import BoreschRestraintGeometry diff --git a/openfe/protocols/restraint_utils/geometry/base.py b/openfe/protocols/restraint_utils/geometry/base.py new file mode 100644 index 000000000..798befd45 --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/base.py @@ -0,0 +1,48 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + + +class BaseRestraintGeometry(BaseModel, abc.ABC): + """ + A base class for a restraint geometry. + """ + class Config: + arbitrary_types_allowed = True + + +class HostGuestRestraintGeometry(BaseRestraintGeometry): + """ + An ordered list of guest atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + + guest_atoms: list[int] + """ + An ordered list of host atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + host_atoms: list[int] + + @validator("guest_atoms", "host_atoms") + def positive_idxs(cls, v): + if v is not None and any([i < 0 for i in v]): + errmsg = "negative indices passed" + raise ValueError(errmsg) + return v diff --git a/openfe/protocols/restraint_utils/geometry/boresch/__init__.py b/openfe/protocols/restraint_utils/geometry/boresch/__init__.py new file mode 100644 index 000000000..57c306c58 --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/boresch/__init__.py @@ -0,0 +1,4 @@ +from .geometry import ( + BoreschRestraintGeometry, + find_boresch_restraint, +) diff --git a/openfe/protocols/restraint_utils/geometry/boresch/geometry.py b/openfe/protocols/restraint_utils/geometry/boresch/geometry.py new file mode 100644 index 000000000..0640ecf21 --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/boresch/geometry.py @@ -0,0 +1,298 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +from typing import Optional + +from rdkit import Chem + +from openff.units import unit +from openff.models.types import FloatQuantity +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals + +from openfe.protocols.restraint_utils.geometry.base import ( + HostGuestRestraintGeometry +) +from .guest import find_guest_atom_candidates +from .host import find_host_atom_candidates, find_host_anchor + + +class BoreschRestraintGeometry(HostGuestRestraintGeometry): + """ + A class that defines the restraint geometry for a Boresch restraint. + + The restraint is defined by the following: + + H2 G2 + - - + - - + H1 - - H0 -- G0 - - G1 + + Where HX represents the X index of ``host_atoms`` and GX + the X index of ``guest_atoms``. + """ + r_aA0: FloatQuantity['nanometer'] + """ + The equilibrium distance between H0 and G0. + """ + theta_A0: FloatQuantity['radians'] + """ + The equilibrium angle value between H1, H0, and G0. + """ + theta_B0: FloatQuantity['radians'] + """ + The equilibrium angle value between H0, G0, and G1. + """ + phi_A0: FloatQuantity['radians'] + """ + The equilibrium dihedral value between H2, H1, H0, and G0. + """ + phi_B0: FloatQuantity['radians'] + + """ + The equilibrium dihedral value between H1, H0, G0, and G1. + """ + phi_C0: FloatQuantity['radians'] + + """ + The equilibrium dihedral value between H0, G0, G1, and G2. + """ + + +def _get_restraint_distances( + atomgroup: mda.AtomGroup +) -> tuple[unit.Quantity]: + """ + Get the bond, angle, and dihedral distances for an input atomgroup + defining the six atoms for a Boresch-like restraint. + + The atoms must be in the order of H0, H1, H2, G0, G1, G2. + + Parameters + ---------- + atomgroup : mda.AtomGroup + An AtomGroup defining the restrained atoms in order. + + Returns + ------- + bond : unit.Quantity + The H0-G0 bond value. + angle1 : unit.Quantity + The H1-H0-G0 angle value. + angle2 : unit.Quantity + The H0-G0-G1 angle value. + dihed1 : unit.Quantity + The H2-H1-H0-G0 dihedral value. + dihed2 : unit.Quantity + The H1-H0-G0-G1 dihedral value. + dihed3 : unit.Quantity + The H0-G0-G1-G2 dihedral value. + """ + bond = calc_bonds( + atomgroup.atoms[0].position, + atomgroup.atoms[3].position, + box=atomgroup.dimensions + ) * unit.angstroms + + angles = [] + for idx_set in [[1, 0, 3], [0, 3, 4]]: + angle = calc_angles( + atomgroup.atoms[idx_set[0]].position, + atomgroup.atoms[idx_set[1]].position, + atomgroup.atoms[idx_set[2]].position, + box=atomgroup.dimensions, + ) + angles.append(angle * unit.radians) + + dihedrals = [] + for idx_set in [[2, 1, 0, 3], [1, 0, 3, 4], [0, 3, 4, 5]]: + dihed = calc_dihedrals( + atomgroup.atoms[idx_set[0]].position, + atomgroup.atoms[idx_set[1]].position, + atomgroup.atoms[idx_set[2]].position, + atomgroup.atoms[idx_set[3]].position, + box=atomgroup.dimensions, + ) + dihedrals.append(dihed * unit.radians) + + return bond, angles[0], angles[1], dihedrals[0], dihedrals[1], dihedrals[2] + + +def find_boresch_restraint( + universe: mda.Universe, + guest_rdmol: Chem.Mol, + guest_idxs: list[int], + host_idxs: list[int], + guest_restraint_atoms_idxs: Optional[list[int]] = None, + host_restraint_atoms_idxs: Optional[list[int]] = None, + host_selection: str = "all", + dssp_filter: bool = False, + rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, + host_min_distance: unit.Quantity = 1 * unit.nanometer, + host_max_distance: unit.Quantity = 3 * unit.nanometer, + angle_force_constant: unit.Quantity = ( + 83.68 * unit.kilojoule_per_mole / unit.radians**2 + ), + temperature: unit.Quantity = 298.15 * unit.kelvin, +) -> BoreschRestraintGeometry: + """ + Find suitable Boresch-style restraints between a host and guest entity + based on the approach of Baumann et al. [1] with some modifications. + + Parameters + ---------- + universe : mda.Universe + An MDAnalysis Universe defining the system and its coordinates. + guest_rdmol : Chem.Mol + An RDKit Mol for the guest molecule. + guest_idxs : list[int] + Indices in the topology for the guest molecule. + host_idxs : list[int] + Indices in the topology for the host molecule. + guest_restraint_atoms_idxs : Optional[list[int]] + User selected indices of the guest molecule itself (i.e. indexed + starting a 0 for the guest molecule). This overrides the + restraint search and a restraint using these indices will + be retruned. Must be defined alongside ``host_restraint_atoms_idxs``. + host_restraint_atoms_idxs : Optional[list[int]] + User selected indices of the host molecule itself (i.e. indexed + starting a 0 for the hosts molecule). This overrides the + restraint search and a restraint using these indices will + be returnned. Must be defined alongside ``guest_restraint_atoms_idxs``. + host_selection : str + An MDAnalysis selection string to sub-select the host atoms. + dssp_filter : bool + Whether or not to filter the host atoms by their secondary structure. + rmsf_cutoff : unit.Quantity + The cutoff value for atom root mean square fluction. Atoms with RMSF + values above this cutoff will be disregarded. + Must be in units compatible with nanometer. + host_min_distance : unit.Quantity + The minimum distance between any host atom and the guest G0 atom. + Must be in units compatible with nanometer. + host_max_distance : unit.Quantity + The maximum distance between any host atom and the guest G0 atom. + Must be in units compatible with nanometer. + angle_force_constant : unit.Quantity + The force constant for the G1-G0-H0 and G0-H0-H1 angles. Must be + in units compatible with kilojoule / mole / radians ** 2. + temperature : unit.Quantity + The system temperature in units compatible with Kelvin. + + Returns + ------- + BoreschRestraintGeometry + An object defining the parameters of the Boresch-like restraint. + + References + ---------- + [1] Baumann, Hannah M., et al. "Broadening the scope of binding free energy + calculations using a Separated Topologies approach." (2023). + """ + if (guest_restraint_atoms_idxs is not None) and (host_restraint_atoms_idxs is not None): # fmt: skip + # In this case assume the picked atoms were intentional / + # representative of the input and go with it + guest_ag = universe.select_atoms[guest_idxs] + guest_anchor = [ + at.ix for at in guest_ag.atoms[guest_restraint_atoms_idxs] + ] + host_ag = universe.select_atoms[host_idxs] + host_anchor = [ + at.ix for at in host_ag.atoms[host_restraint_atoms_idxs] + ] + + # Set the equilibrium values as those of the final frame + universe.trajectory[-1] + atomgroup = universe.atoms[host_anchor + guest_anchor] + bond, ang1, ang2, dih1, dih2, dih3 = _get_restraint_distances( + atomgroup + ) + + # TODO: add checks to warn if this is a badly picked + # set of atoms. + return BoreschRestraintGeometry( + host_atoms=host_anchor, + guest_atoms=guest_anchor, + r_aA0=bond, + theta_A0=ang1, + theta_B0=ang2, + phi_A0=dih1, + phi_B0=dih2, + phi_C0=dih3 + ) + + if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): # fmt: skip + # This is not an intended outcome, crash out here + errmsg = ( + "both ``guest_restraints_atoms_idxs`` and " + "``host_restraint_atoms_idxs`` " + "must be set or both must be None. " + f"Got {guest_restraint_atoms_idxs} and {host_restraint_atoms_idxs}" + ) + raise ValueError(errmsg) + + # 1. Fetch the guest anchors + guest_anchors = find_guest_atom_candidates( + universe=universe, + rdmol=guest_rdmol, + guest_idxs=guest_idxs, + rmsf_cutoff=rmsf_cutoff, + ) + + if len(guest_anchors) == 0: + errmsg = "No suitable ligand atoms found for the restraint." + raise ValueError(errmsg) + + # 2. We then loop through the guest anchors to find suitable host atoms + for guest_anchor in guest_anchors: + # We next fetch the host atom pool + # Note: return is a set, so need to convert it later on + host_pool = find_host_atom_candidates( + universe=universe, + host_idxs=host_idxs, + l1_idx=guest_anchor[0], + host_selection=host_selection, + dssp_filter=dssp_filter, + rmsf_cutoff=rmsf_cutoff, + min_distance=host_min_distance, + max_distance=host_max_distance, + ) + + host_anchor = find_host_anchor( + guest_atoms=universe.atoms[list(guest_anchor)], + host_atom_pool=universe.atoms[list(host_pool)], + minimum_distance=0.5 * unit.nanometer, + angle_force_constant=angle_force_constant, + temperature=temperature, + ) + # continue if it's empty, otherwise stop + if host_anchor is not None: + break + + if host_anchor is None: + errmsg = "No suitable host atoms could be found" + raise ValueError(errmsg) + + # Set the equilibrium values as those of the final frame + universe.trajectory[-1] + atomgroup = universe.atoms[list(host_anchor) + list(guest_anchor)] + bond, ang1, ang2, dih1, dih2, dih3 = _get_restraint_distances( + atomgroup + ) + + return BoreschRestraintGeometry( + host_atoms=host_anchor, + guest_atoms=guest_anchor, + r_aA0=bond, + theta_A0=ang1, + theta_B0=ang2, + phi_A0=dih1, + phi_B0=dih2, + phi_C0=dih3 + ) diff --git a/openfe/protocols/restraint_utils/geometry/boresch/guest.py b/openfe/protocols/restraint_utils/geometry/boresch/guest.py new file mode 100644 index 000000000..8c3490dd2 --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/boresch/guest.py @@ -0,0 +1,252 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +from typing import Optional, Iterable + +from rdkit import Chem + +from openff.units import unit +import MDAnalysis as mda +import numpy as np +import numpy.typing as npt + +from openfe.protocols.restraint_utils.geometry.utils import ( + get_aromatic_rings, + get_heavy_atom_idxs, + get_central_atom_idx, + is_collinear, + get_local_rmsf, +) + + +def _sort_by_distance_from_atom( + rdmol: Chem.Mol, target_idx: int, atom_idxs: Iterable[int] +) -> list[int]: + """ + Sort a list of RDMol atoms by their distance from a target atom. + + Parameters + ---------- + target_idx : int + The idx of the atom to measure from. + atom_idxs : list[int] + The idx values of the atoms to sort. + rdmol : Chem.Mol + RDKit Molecule the atoms belong to + + Returns + ------- + list[int] + The input atom idxs sorted by their distance from the target atom. + """ + distances = [] + + conformer = rdmol.GetConformer() + # Get the target atom position + target_pos = conformer.GetAtomPosition(target_idx) + + for idx in atom_idxs: + pos = conformer.GetAtomPosition(idx) + distances.append(((target_pos - pos).Length(), idx)) + + return [i[1] for i in sorted(distances)] + + +def _bonded_angles_from_pool( + rdmol: Chem.Mol, + atom_idx: int, + atom_pool: list[int], + aromatic_only: bool, +) -> list[tuple[int, int, int]]: + """ + Get all bonded angles starting from ``atom_idx`` from a pool of atoms. + + Parameters + ---------- + rdmol : Chem.Mol + The RDKit Molecule + atom_idx : int + The index of the atom to search angles from. + atom_pool : list[int] + The list of indices to pick possible angle partners from. + aromatic_only : bool + Prune any angles that include non-aromatic bonds. + + Returns + ------- + list[tuple[int, int, int]] + A list of tuples containing all the angles. + + Notes + ----- + * In the original SepTop code at3 is picked as directly bonded to at1. + By comparison here we instead follow the case that at3 is bonded to + at2 but not bonded to at1. + """ + angles = [] + + # Get the base atom and its neighbors + at1 = rdmol.GetAtomWithIdx(atom_idx) + at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] + + # We loop at2 and at3 through the sorted atom_pool in order to get + # a list of angles in the branch that are sorted by how close the atoms + # are from the central atom + for at2 in atom_pool: + if at2 in at1_neighbors: + at2_neighbors = [ + at.GetIdx() for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() + ] + for at3 in atom_pool: + if at3 != atom_idx and at3 in at2_neighbors: + angles.append((atom_idx, at2, at3)) + + if aromatic_only: # TODO: move this to its own method? + aromatic_rings = get_aromatic_rings(rdmol) + + def _belongs_to_ring(angle, aromatic_rings): + for ring in aromatic_rings: + if all(a in ring for a in angle): + return True + return False + + for angle in angles: + if not _belongs_to_ring(angle, aromatic_rings): + angles.remove(angle) + + return angles + + +def _get_guest_atom_pool( + rdmol: Chem.Mol, + rmsf: npt.NDArray, + rmsf_cutoff: unit.Quantity +) -> tuple[Optional[set[int]], bool]: + """ + Filter atoms based on rmsf & rings, defaulting to heavy atoms if + there are not enough. + + Parameters + ---------- + rdmol : Chem.Mol + The RDKit Molecule to search through + rmsf : npt.NDArray + A 1-D array of RMSF values for each atom. + rmsf_cutoff : unit.Quantity + The rmsf cutoff value for selecting atoms in units compatible with + nanometer. + + Returns + ------- + atom_pool : Optional[set[int]] + A pool of candidate atoms. + ring_atoms_only : bool + True if only ring atoms were selected. + """ + # Get a list of all the aromatic rings + # Note: no need to keep track of rings because we'll filter by + # bonded terms after, so if we only keep rings then all the bonded + # atoms should be within the same ring system. + atom_pool: set[int] = set() + ring_atoms_only: bool = True + for ring in get_aromatic_rings(rdmol): + max_rmsf = rmsf[list(ring)].max() + if max_rmsf < rmsf_cutoff: + atom_pool.update(ring) + + # if we don't have enough atoms just get all the heavy atoms + if len(atom_pool) < 3: + ring_atoms_only = False + heavy_atoms = np.array(get_heavy_atom_idxs(rdmol)) + atom_pool = set(heavy_atoms[rmsf[heavy_atoms] < rmsf_cutoff]) + if len(atom_pool) < 3: + return None, False + + return atom_pool, ring_atoms_only + + +def find_guest_atom_candidates( + universe: mda.Universe, + rdmol: Chem.Mol, + guest_idxs: list[int], + rmsf_cutoff: unit.Quantity = 1 * unit.nanometer, +) -> list[tuple[int]]: + """ + Get a list of potential ligand atom choices for a Boresch restraint + being applied to a given small molecule. + + Parameters + ---------- + universe : mda.Universe + An MDAnalysis Universe defining the system and its coordinates. + rdmol : Chem.Mol + An RDKit Molecule representing the small molecule ordered in + the same way as it is listed in the topology. + guest_idxs : list[int] + The ligand indices in the topology. + rmsf_cutoff : unit.Quantity + The RMSF filter cut-off. + + Returns + ------- + angle_list : list[tuple[int]] + A list of tuples for each valid G0, G1, G2 angle. If ``None``, no + angles could be found. + + Raises + ------ + ValueError + If no suitable ligand atoms could be found. + + TODO + ---- + Should the RDMol have a specific frame position? + """ + ligand_ag = universe.atoms[guest_idxs] + + # 0. Get the ligand RMSF + rmsf = get_local_rmsf(ligand_ag) + universe.trajectory[-1] # forward to the last frame + + # 1. Get the pool of atoms to work with + atom_pool, rings_only = _get_guest_atom_pool(rdmol, rmsf, rmsf_cutoff) + + if atom_pool is None: + # We don't have enough atoms so we raise an error + errmsg = "No suitable ligand atoms were found for the restraint" + raise ValueError(errmsg) + + # 2. Get the central atom + center = get_central_atom_idx(rdmol) + + # 3. Sort the atom pool based on their distance from the center + sorted_atom_pool = _sort_by_distance_from_atom(rdmol, center, atom_pool) + + # 4. Get a list of probable angles + angles_list = [] + for atom in sorted_atom_pool: + angles = _bonded_angles_from_pool( + rdmol=rdmol, + atom_idx=atom, + atom_pool=sorted_atom_pool, + aromatic_only=rings_only, + ) + for angle in angles: + # Check that the angle is at least not collinear + angle_ag = ligand_ag.atoms[list(angle)] + if not is_collinear(ligand_ag.positions, angle, universe.dimensions): + angles_list.append( + ( + angle_ag.atoms[0].ix, + angle_ag.atoms[1].ix, + angle_ag.atoms[2].ix + ) + ) + + return angles_list diff --git a/openfe/protocols/restraint_utils/geometry/boresch/host.py b/openfe/protocols/restraint_utils/geometry/boresch/host.py new file mode 100644 index 000000000..56f015241 --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/boresch/host.py @@ -0,0 +1,430 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +from typing import Optional +import warnings + +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals +import numpy as np +import numpy.typing as npt + +from openfe.protocols.restraint_utils.geometry.utils import ( + is_collinear, + check_angular_variance, + check_dihedral_bounds, + check_angle_not_flat, + FindHostAtoms, + get_local_rmsf, + stable_secondary_structure_selection +) + + +def find_host_atom_candidates( + universe: mda.Universe, + host_idxs: list[int], + l1_idx: int, + host_selection: str, + dssp_filter: bool = False, + rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, + min_distance: unit.Quantity = 1 * unit.nanometer, + max_distance: unit.Quantity = 3 * unit.nanometer, +) -> npt.NDArray: + """ + Get a list of suitable host atoms. + + Parameters + ---------- + universe : mda.Universe + An MDAnalysis Universe defining the system and its coordinates. + host_idxs : list[int] + A list of the host indices in the system topology. + l1_idx : int + The index of the proposed l1 binding atom. + host_selection : str + An MDAnalysis selection string to filter the host by. + dssp_filter : bool + Whether or not to apply a DSSP filter on the host selection. + rmsf_cutoff : uni.Quantity + The maximum RMSF value allowwed for any candidate host atom. + min_distance : unit.Quantity + The minimum search distance around l1 for suitable candidate atoms. + max_distance : unit.Quantity + The maximum search distance around l1 for suitable candidate atoms. + + Return + ------ + NDArray + Array of host atom indexes + """ + # Get an AtomGroup for the host based on the input host indices + host_ag = universe.atoms[host_idxs] + + # Filter the host AtomGroup based on ``host_selection` + selected_host_ag = host_ag.select_atoms(host_selection) + + # If requested, filter the host atoms based on if their residues exist + # within stable secondary structures. + if dssp_filter: + # TODO: allow user-supplied kwargs here + stable_ag = stable_secondary_structure_selection(selected_host_ag) + + if len(stable_ag) < 20: + wmsg = ( + "Secondary structure filtering: " + "Too few atoms found via secondary strcuture filtering will " + "try to only select all residues in protein chains instead." + ) + warnings.warn(wmsg) + stable_ag = protein_chain_selection(selected_host_ag) + + if len(stable_ag) < 20: + wmsg = ( + "Secondary structure filtering: " + "Too few atoms found in protein residue chains, will just " + "use all atoms." + ) + warnings.warn(wmsg) + else: + selected_host_ag = stable_ag + + # 1. Get the RMSF & filter to create a new AtomGroup + rmsf = get_local_rmsf(selected_host_ag) + filtered_host_ag = selected_host_ag.atoms[rmsf < rmsf_cutoff] + + # 2. Search of atoms within the min/max cutoff + atom_finder = FindHostAtoms( + host_atoms=filtered_host_ag, + guest_atoms=universe.atoms[l1_idx], + min_search_distance=min_distance, + max_search_distance=max_distance, + ) + atom_finder.run() + return atom_finder.results.host_idxs + + +class EvaluateHostAtoms1(AnalysisBase): + """ + Class to evaluate the suitability of a set of host atoms + as either H0 or H1 atoms (i.e. the first and second host atoms). + + Parameters + ---------- + reference : MDAnalysis.AtomGroup + The reference preceeding three atoms. + host_atom_pool : MDAnalysis.AtomGroup + The pool of atoms to pick an atom from. + minimum_distance : unit.Quantity + The minimum distance from the bound reference atom. + angle_force_constant : unit.Quantity + The force constant for the angle. + temperature : unit.Quantity + The system temperature in Kelvin + """ + def __init__( + self, + reference, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + **kwargs, + ): + super().__init__(reference.universe.trajectory, **kwargs) + + if len(reference) != 3: + errmsg = "Incorrect number of reference atoms passed" + raise ValueError(errmsg) + + self.reference = reference + self.host_atom_pool = host_atom_pool + self.minimum_distance = minimum_distance.to("angstrom").m + self.angle_force_constant = angle_force_constant + self.temperature = temperature + + def _prepare(self): + self.results.distances = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.angles = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.dihedrals = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.collinear = np.empty( + (len(self.host_atom_pool), self.n_frames), + dtype=bool, + ) + self.results.valid = np.empty( + len(self.host_atom_pool), + dtype=bool, + ) + # Set everything to False to begin with + self.results.valid[:] = False + + def _single_frame(self): + for i, at in enumerate(self.host_atom_pool): + distance = calc_bonds( + at.position, + self.reference.atoms[0].position, + box=self.reference.dimensions, + ) + angle = calc_angles( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + box=self.reference.dimensions, + ) + dihedral = calc_dihedrals( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + self.reference.atoms[2].position, + box=self.reference.dimensions, + ) + collinear = is_collinear( + positions=np.vstack((at.position, self.reference.positions)), + atoms=[0, 1, 2, 3], + dimensions=self.reference.dimensions, + ) + self.results.distances[i][self._frame_index] = distance + self.results.angles[i][self._frame_index] = angle + self.results.dihedrals[i][self._frame_index] = dihedral + self.results.collinear[i][self._frame_index] = collinear + + def _conclude(self): + for i, at in enumerate(self.host_atom_pool): + # Check distances + distance_bounds = all( + self.results.distances[i] > self.minimum_distance + ) + # Check angles + angle_bounds = all( + check_angle_not_flat( + angle=angle * unit.radians, + force_constant=self.angle_force_constant, + temperature=self.temperature + ) + for angle in self.results.angles[i] + ) + angle_variance = check_angular_variance( + self.results.angles[i] * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=0 * unit.radians, + width=1.745 * unit.radians, + ) + # Check dihedrals + dihed_bounds = all( + check_dihedral_bounds(dihed * unit.radians) + for dihed in self.results.dihedrals[i] + ) + dihed_variance = check_angular_variance( + self.results.dihedrals[i] * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=-np.pi * unit.radians, + width=5.23 * unit.radians, + ) + not_collinear = not all(self.results.collinear[i]) + if all( + [ + distance_bounds, + angle_bounds, + angle_variance, + dihed_bounds, + dihed_variance, + not_collinear, + ] + ): + self.results.valid[i] = True + + +class EvaluateHostAtoms2(EvaluateHostAtoms1): + """ + Class to evaluate the suitability of a set of host atoms + as H2 atoms (i.e. the third host atoms). + + Parameters + ---------- + reference : MDAnalysis.AtomGroup + The reference preceeding three atoms. + host_atom_pool : MDAnalysis.AtomGroup + The pool of atoms to pick an atom from. + minimum_distance : unit.Quantity + The minimum distance from the bound reference atom. + angle_force_constant : unit.Quantity + The force constant for the angle. + temperature : unit.Quantity + The system temperature in Kelvin + """ + def _prepare(self): + self.results.distances1 = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.distances2 = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.dihedrals = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.collinear = np.empty( + (len(self.host_atom_pool), self.n_frames), + dtype=bool, + ) + self.results.valid = np.empty( + len(self.host_atom_pool), + dtype=bool, + ) + # Default to valid == False + self.results.valid[:] = False + + def _single_frame(self): + for i, at in enumerate(self.host_atom_pool): + distance1 = calc_bonds( + at.position, + self.reference.atoms[0].position, + box=self.reference.dimensions, + ) + distance2 = calc_bonds( + at.position, + self.reference.atoms[1].position, + box=self.reference.dimensions, + ) + dihedral = calc_dihedrals( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + self.reference.atoms[2].position, + box=self.reference.dimensions, + ) + collinear = is_collinear( + positions=np.vstack((at.position, self.reference.positions)), + atoms=[0, 1, 2, 3], + dimensions=self.reference.dimensions, + ) + self.results.distances1[i][self._frame_index] = distance1 + self.results.distances2[i][self._frame_index] = distance2 + self.results.dihedrals[i][self._frame_index] = dihedral + self.results.collinear[i][self._frame_index] = collinear + + def _conclude(self): + for i, at in enumerate(self.host_atom_pool): + distance1_bounds = all( + self.results.distances1[i] > self.minimum_distance + ) + distance2_bounds = all( + self.results.distances2[i] > self.minimum_distance + ) + dihed_bounds = all( + check_dihedral_bounds(dihed * unit.radians) + for dihed in self.results.dihedrals[i] + ) + dihed_variance = check_angular_variance( + self.results.dihedrals[i] * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=-np.pi * unit.radians, + width=5.23 * unit.radians, + ) + not_collinear = not all(self.results.collinear[i]) + if all( + [ + distance1_bounds, + distance2_bounds, + dihed_bounds, + dihed_variance, + not_collinear, + ] + ): + self.results.valid[i] = True + + +def find_host_anchor( + guest_atoms: mda.AtomGroup, + host_atom_pool: mda.AtomGroup, + minimum_distance: unit.Quantity, + angle_force_constant: unit.Quantity, + temperature: unit.Quantity +) -> Optional[list[int]]: + """ + Find suitable atoms for the H0-H1-H2 portion of the restraint. + + Parameters + ---------- + guest_atoms : mda.AtomGroup + The guest anchor atoms for G0-G1-G2 + host_atom_pool : mda.AtomGroup + The host atoms to search from. + minimum_distance : unit.Quantity + The minimum distance to pick host atoms from each other. + angle_force_constant : unit.Quantity + The force constant for the G1-G0-H0 and G0-H0-H1 angles. + temperature : unit.Quantity + The target system temperature. + + Returns + ------- + Optional[list[int]] + A list of indices for a selected combination of H0, H1, and H2. + """ + # Evalulate the host_atom_pool for suitability as H0 atoms + h0_eval = EvaluateHostAtoms1( + guest_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + ) + h0_eval.run() + + for i, valid_h0 in enumerate(h0_eval.results.valid): + # If valid H0 atom, evaluate rest of host_atom_pool for suitability + # as H1 atoms. + if valid_h0: + h0g0g1_atoms = host_atom_pool.atoms[i] + guest_atoms.atoms[:2] + h1_eval = EvaluateHostAtoms1( + h0g0g1_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + ) + h1_eval.run() + for j, valid_h1 in enumerate(h1_eval.results.valid): + # If valid H1 atom, evaluate rest of host_atom_pool for + # suitability as H2 atoms + if valid_h1: + h1h0g0_atoms = host_atom_pool.atoms[j] + h0g0g1_atoms.atoms[:2] + h2_eval = EvaluateHostAtoms2( + h1h0g0_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + ) + h2_eval.run() + + if any(h2_eval.results.valid): + # Get the sum of the average distances (dsum_avgs) + # for all the host_atom_pool atoms + distance1_avgs = np.array( + [d.mean() for d in h2_eval.results.distances1] + ) + distance2_avgs = np.array( + [d.mean() for d in h2_eval.results.distances2] + ) + dsum_avgs = distance1_avgs + distance2_avgs + + # Now filter by validity as H2 atom + h2_dsum_avgs = [ + (idx, val) for idx, val in enumerate(dsum_avgs) + if h2_eval.results.valid[idx] + ] + + # Get the index of the H2 atom with the lowest + # average distance + k = sorted(h2_dsum_avgs, key=lambda x: x[1])[0][0] + + return list(host_atom_pool.atoms[[i, j, k]].ix) + return None diff --git a/openfe/protocols/restraint_utils/geometry/flatbottom.py b/openfe/protocols/restraint_utils/geometry/flatbottom.py new file mode 100644 index 000000000..1f88fbf59 --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/flatbottom.py @@ -0,0 +1,126 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +from typing import Optional +import numpy as np +from openff.units import unit +from openff.models.types import FloatQuantity +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds + +from .harmonic import ( + DistanceRestraintGeometry, +) + +from .utils import _get_mda_selection + + +class FlatBottomDistanceGeometry(DistanceRestraintGeometry): + """ + A geometry class for a flat bottom distance restraint between two groups + of atoms. + """ + well_radius: FloatQuantity["nanometer"] + + +class COMDistanceAnalysis(AnalysisBase): + """ + Get a timeseries of COM distances between two AtomGroups + + Parameters + ---------- + group1 : MDAnalysis.AtomGroup + Atoms defining the first centroid. + group2 : MDANalysis.AtomGroup + Atoms defining the second centroid. + """ + _analysis_algorithm_is_parallelizable = False + + def __init__(self, group1, group2, **kwargs): + super().__init__(group1.universe.trajectory, **kwargs) + + self.ag1 = group1 + self.ag2 = group2 + + def _prepare(self): + self.results.distances = np.zeros(self.n_frames) + + def _single_frame(self): + com_dist = calc_bonds( + self.ag1.center_of_mass(), + self.ag2.center_of_mass(), + box=self.ag1.universe.dimensions, + ) + self.results.distances[self._frame_index] = com_dist + + def _conclude(self): + pass + + +def get_flatbottom_distance_restraint( + universe: mda.Universe, + host_atoms: Optional[list[int]] = None, + guest_atoms: Optional[list[int]] = None, + host_selection: Optional[str] = None, + guest_selection: Optional[str] = None, + padding: unit.Quantity = 0.5 * unit.nanometer, +) -> FlatBottomDistanceGeometry: + """ + Get a FlatBottomDistanceGeometry by analyzing the COM distance + change between two sets of atoms. + + The ``well_radius`` is defined as the maximum COM distance plus + ``padding``. + + Parameters + ---------- + universe : mda.Universe + An MDAnalysis Universe defining the system and its coordinates. + host_atoms : Optional[list[int]] + A list of host atoms indices. Either ``host_atoms`` or + ``host_selection`` must be defined. + guest_atoms : Optional[list[int]] + A list of guest atoms indices. Either ``guest_atoms`` or + ``guest_selection`` must be defined. + host_selection : Optional[str] + An MDAnalysis selection string to define the host atoms. + Either ``host_atoms`` or ``host_selection`` must be defined. + guest_selection : Optional[str] + An MDAnalysis selection string to define the guest atoms. + Either ``guest_atoms`` or ``guest_selection`` must be defined. + padding : unit.Quantity + A padding value to add to the ``well_radius`` definition. + Must be in units compatible with nanometers. + + Returns + ------- + FlatBottomDistanceGeometry + An object defining a flat bottom restraint geometry. + """ + guest_ag = _get_mda_selection(universe, guest_atoms, guest_selection) + host_ag = _get_mda_selection(universe, host_atoms, host_selection) + guest_idxs = [a.ix for a in guest_ag] + host_idxs = [a.ix for a in host_ag] + + if len(host_idxs) == 0 or len(guest_idxs) == 0: + errmsg = ( + "no atoms found in either the host or guest atom groups" + f"host_atoms: {host_idxs}" + f"guest_atoms: {guest_idxs}" + ) + raise ValueError(errmsg) + + com_dists = COMDistanceAnalysis(guest_ag, host_ag) + com_dists.run() + + well_radius = com_dists.results.distances.max() * unit.angstrom + padding + return FlatBottomDistanceGeometry( + guest_atoms=guest_idxs, host_atoms=host_idxs, well_radius=well_radius + ) diff --git a/openfe/protocols/restraint_utils/geometry/harmonic.py b/openfe/protocols/restraint_utils/geometry/harmonic.py new file mode 100644 index 000000000..838724deb --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/harmonic.py @@ -0,0 +1,106 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +from typing import Optional +import MDAnalysis as mda +from rdkit import Chem + +from .base import HostGuestRestraintGeometry +from .utils import ( + get_central_atom_idx, + _get_mda_selection, +) + + +class DistanceRestraintGeometry(HostGuestRestraintGeometry): + """ + A geometry class for a distance restraint between two groups of atoms. + """ + + +def get_distance_restraint( + universe: mda.Universe, + host_atoms: Optional[list[int]] = None, + guest_atoms: Optional[list[int]] = None, + host_selection: Optional[str] = None, + guest_selection: Optional[str] = None, +) -> DistanceRestraintGeometry: + """ + Get a DistanceRestraintGeometry between two groups of atoms. + + You can either select the groups by passing through a set of indices + or an MDAnalysis selection. + + Parameters + ---------- + universe : mda.Universe + An MDAnalysis Universe defining the system and its coordinates. + host_atoms : Optional[list[int]] + A list of host atoms indices. Either ``host_atoms`` or + ``host_selection`` must be defined. + guest_atoms : Optional[list[int]] + A list of guest atoms indices. Either ``guest_atoms`` or + ``guest_selection`` must be defined. + host_selection : Optional[str] + An MDAnalysis selection string to define the host atoms. + Either ``host_atoms`` or ``host_selection`` must be defined. + guest_selection : Optional[str] + An MDAnalysis selection string to define the guest atoms. + Either ``guest_atoms`` or ``guest_selection`` must be defined. + + Returns + ------- + DistanceRestraintGeometry + An object that defines a distance restraint geometry. + """ + guest_ag = _get_mda_selection(universe, guest_atoms, guest_selection) + guest_atoms = [a.ix for a in guest_ag] + host_ag = _get_mda_selection(universe, host_atoms, host_selection) + host_atoms = [a.ix for a in host_ag] + + return DistanceRestraintGeometry( + guest_atoms=guest_atoms, host_atoms=host_atoms + ) + + +def get_molecule_centers_restraint( + molA_rdmol: Chem.Mol, + molB_rdmol: Chem.Mol, + molA_idxs: list[int], + molB_idxs: list[int], +): + """ + Get a DistanceRestraintGeometry between the central atoms of + two molecules. + + Parameters + ---------- + molA_rdmol : Chem.Mol + An RDKit Molecule for the first molecule. + molB_rdmol : Chem.Mol + An RDKit Molecule for the second molecule. + molA_idxs : list[int] + The indices of the first molecule in the system. Note we assume these + to be sorted in the same order as the input rdmol. + molB_idxs : list[int] + The indices of the second molecule in the system. Note we assume these + to be sorted in the same order as the input rdmol. + + Returns + ------- + DistanceRestraintGeometry + An object that defines a distance restraint geometry. + """ + # We assume that the mol idxs are ordered + centerA = molA_idxs[get_central_atom_idx(molA_rdmol)] + centerB = molB_idxs[get_central_atom_idx(molB_rdmol)] + + return DistanceRestraintGeometry( + guest_atoms=[centerA], host_atoms=[centerB] + ) diff --git a/openfe/protocols/restraint_utils/geometry/utils.py b/openfe/protocols/restraint_utils/geometry/utils.py new file mode 100644 index 000000000..4ca5de7a4 --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/utils.py @@ -0,0 +1,720 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Search methods for generating Geometry objects + +TODO +---- +* Add relevant duecredit entries. +""" +from typing import Union, Optional +from itertools import combinations, groupby +import numpy as np +import numpy.typing as npt +from scipy.stats import circvar +import warnings + +from openff.toolkit import Molecule as OFFMol +from openff.units import unit +import networkx as nx +from rdkit import Chem +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.analysis.rms import RMSF +from MDAnalysis.analysis.dssp import DSSP +from MDAnalysis.lib.distances import minimize_vectors, capped_distance +from MDAnalysis.transformations.nojump import NoJump + +from openfe_analysis.transformations import Aligner + + +DEFAULT_ANGLE_FRC_CONSTANT = 83.68 * unit.kilojoule_per_mole / unit.radians**2 + + +def _get_mda_selection( + universe: Union[mda.Universe, mda.AtomGroup], + atom_list: Optional[list[int]] = None, + selection: Optional[str] = None, +) -> mda.AtomGroup: + """ + Return an AtomGroup based on either a list of atom indices or an + mdanalysis string selection. + + Parameters + ---------- + universe : Union[mda.Universe, mda.AtomGroup] + The MDAnalysis Universe or AtomGroup to get the AtomGroup from. + atom_list : Optional[list[int]] + A list of atom indices. + selection : Optional[str] + An MDAnalysis selection string. + + Returns + ------- + ag : mda.AtomGroup + An atom group selected from the inputs. + + Raises + ------ + ValueError + If both ``atom_list`` and ``selection`` are ``None`` + or are defined. + """ + if atom_list is None: + if selection is None: + raise ValueError( + "one of either the atom lists or selections must be defined" + ) + + ag = universe.select_atoms(selection) + else: + if selection is not None: + raise ValueError( + "both atom_list and selection cannot be defined together" + ) + ag = universe.atoms[atom_list] + return ag + + +def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: + """ + Get a list of tuples with the indices for each ring in an rdkit Molecule. + + Parameters + ---------- + rdmol : Chem.Mol + RDKit Molecule + + Returns + ------- + list[tuple[int]] + List of tuples for each ring. + """ + + ringinfo = rdmol.GetRingInfo() + arom_idxs = get_aromatic_atom_idxs(rdmol) + + aromatic_rings = [] + + # Add to the aromatic_rings list if all the atoms in a ring are aromatic + for ring in ringinfo.AtomRings(): + if all(a in arom_idxs for a in ring): + aromatic_rings.append(set(ring)) + + # Reduce the ring list by merging any rings that have colliding atoms + for x, y in combinations(aromatic_rings, 2): + if not x.isdisjoint(y): + x.update(y) + aromatic_rings.remove(y) + + return aromatic_rings + + +def get_aromatic_atom_idxs(rdmol: Chem.Mol) -> list[int]: + """ + Helper method to get aromatic atoms idxs + in a RDKit Molecule + + Parameters + ---------- + rdmol : Chem.Mol + RDKit Molecule + + Returns + ------- + list[int] + A list of the aromatic atom idxs + """ + idxs = [at.GetIdx() for at in rdmol.GetAtoms() if at.GetIsAromatic()] + return idxs + + +def get_heavy_atom_idxs(rdmol: Chem.Mol) -> list[int]: + """ + Get idxs of heavy atoms in an RDKit Molecule + + Parameters + ---------- + rmdol : Chem.Mol + + Returns + ------- + list[int] + A list of heavy atom idxs + """ + idxs = [at.GetIdx() for at in rdmol.GetAtoms() if at.GetAtomicNum() > 1] + return idxs + + +def get_central_atom_idx(rdmol: Chem.Mol) -> int: + """ + Get the central atom in an rdkit Molecule. + + Parameters + ---------- + rdmol : Chem.Mol + RDKit Molcule to query + + Returns + ------- + int + Index of central atom in Molecule + + Note + ---- + If there are equal likelihood centers, will return + the first entry. + """ + # TODO: switch to a manual conversion to avoid an OpenFF dependency + offmol = OFFMol(rdmol, allow_undefined_stereo=True) + nx_mol = offmol.to_networkx() + + if not nx.is_weakly_connected(nx_mol.to_directed()): + errmsg = "A disconnected molecule was passed, cannot find the center" + raise ValueError(errmsg) + + # Get a list of all shortest paths + shortest_paths = [ + path + for node_paths in nx.shortest_path(nx_mol).values() + for path in node_paths.values() + ] + + # Get the longest of these paths (returns first instance) + longest_path = max(shortest_paths, key=len) + + # Return the index of the central atom + return longest_path[len(longest_path) // 2] + + +def is_collinear( + positions: npt.ArrayLike, + atoms: list[int], + dimensions=None, + threshold=0.9 +): + """ + Check whether any sequential vectors in a sequence of atoms are collinear. + + Approach: for each sequential set of 3 atoms (defined as A, B, and C), + calculates the nomralized inner product (i.e. cos^-1(angle)) between + vectors AB adn BC. If the absolute value of this inner product is + close to 1 (i.e. an angle of 0 radians), then the three atoms are + considered as colinear. You can use ``threshold`` to define how close + to 1 is considered "flat". + + Parameters + ---------- + positions : npt.ArrayLike + System positions. + atoms : list[int] + The indices of the atoms to test. + dimensions : Optional[npt.NDArray] + The dimensions of the system to minimize vectors. + threshold : float + Atoms are not collinear if their sequential vector separation dot + products are less than ``threshold``. Default 0.9. + + Returns + ------- + result : bool + Returns True if any sequential pair of vectors is collinear; + False otherwise. + + Notes + ----- + Originally from Yank. + """ + if len(atoms) < 3: + raise ValueError("Too few atoms passed for co-linearity test") + if len(positions) < len(atoms) or len(positions) < max(atoms) + 1: + errmsg = "atoms indices do not match the positions array passed" + raise ValueError(errmsg) + if not all(isinstance(x, int) for x in atoms): + errmsg = "atoms is not a list of index integers" + raise ValueError(errmsg) + + result = False + for i in range(len(atoms) - 2): + v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] + v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] + if dimensions is not None: + v1 = minimize_vectors(v1, box=dimensions) + v2 = minimize_vectors(v2, box=dimensions) + + normalized_inner_product = np.dot(v1, v2) / np.sqrt( + np.dot(v1, v1) * np.dot(v2, v2) + ) + result = result or (np.abs(normalized_inner_product) > threshold) + return result + + +def _wrap_angle(angle: unit.Quantity) -> unit.Quantity: + """ + Wrap an angle to -pi to pi radians. + + Parameters + ---------- + angle : unit.Quantity + An angle in radians compatible units. + + Returns + ------- + unit.Quantity + The angle in units of radians wrapped. + + Notes + ----- + Print automatically converts the angle to radians + as it passes it through arctan2. + """ + return np.arctan2(np.sin(angle), np.cos(angle)) + + +def check_angle_not_flat( + angle: unit.Quantity, + force_constant: unit.Quantity = DEFAULT_ANGLE_FRC_CONSTANT, + temperature: unit.Quantity = 298.15 * unit.kelvin, +) -> bool: + """ + Check whether the chosen angle is less than 10 kT from 0 or pi radians + + Parameters + ---------- + angle : unit.Quantity + The angle to check in units compatible with radians. + force_constant : unit.Quantity + Force constant of the angle in units compatible with + kilojoule_per_mole / radians ** 2. + temperature : unit.Quantity + The system temperature in units compatible with Kelvin. + + Returns + ------- + bool + False if the angle is less than 10 kT from 0 or pi radians + + Note + ---- + We assume the temperature to be 298.15 Kelvin. + + Acknowledgements + ---------------- + This code was initially contributed by Vytautas Gapsys. + """ + # Convert things + angle_rads = _wrap_angle(angle) + frc_const = force_constant.to("unit.kilojoule_per_mole / unit.radians**2") + temp_kelvin = temperature.to("kelvin") + RT = 8.31445985 * 0.001 * temp_kelvin + + # check if angle is <10kT from 0 or 180 + check1 = 0.5 * frc_const * np.power((angle_rads - 0.0), 2) + check2 = 0.5 * frc_const * np.power((angle_rads - np.pi), 2) + ang_check_1 = check1 / RT + ang_check_2 = check2 / RT + if ang_check_1.m < 10.0 or ang_check_2.m < 10.0: + return False + return True + + +def check_dihedral_bounds( + dihedral: unit.Quantity, + lower_cutoff: unit.Quantity = -2.618 * unit.radians, + upper_cutoff: unit.Quantity = 2.618 * unit.radians, +) -> bool: + """ + Check that a dihedral does not exceed the bounds set by + lower_cutoff and upper_cutoff on a -pi to pi range. + + All angles and cutoffs are wrapped to -pi to pi before + applying the check. + + Parameters + ---------- + dihedral : unit.Quantity + Dihedral in units compatible with radians. + lower_cutoff : unit.Quantity + Dihedral lower cutoff in units compatible with radians. + upper_cutoff : unit.Quantity + Dihedral upper cutoff in units compatible with radians. + + Returns + ------- + bool + ``True`` if the dihedral is within the upper and lower + cutoff bounds. + """ + dihed = _wrap_angle(dihedral) + lower = _wrap_angle(lower_cutoff) + upper = _wrap_angle(upper_cutoff) + if (dihed < lower) or (dihed > upper): + return False + return True + + +def check_angular_variance( + angles: unit.Quantity, + upper_bound: unit.Quantity, + lower_bound: unit.Quantity, + width: unit.Quantity, +) -> bool: + """ + Check that the variance of a list of ``angles`` does not exceed + a given ``width`` + + Parameters + ---------- + angles : ArrayLike unit.Quantity + An array of angles in units compatible with radians. + upper_bound: unit.Quantity + The upper bound in the angle range in radians compatible units. + lower_bound: unit.Quantity + The lower bound in the angle range in radians compatible units. + width : unit.Quantity + The width to check the variance against, in units compatible with + radians. + + Returns + ------- + bool + ``True`` if the variance of the angles is less than the width. + + """ + # scipy circ methods already recasts internally so we shouldn't + # need to wrap the angles + variance = circvar( + angles.to("radians").m, + high=upper_bound.to("radians").m, + low=lower_bound.to("radians").m, + ) + return not (variance * unit.radians > width) + + +class FindHostAtoms(AnalysisBase): + """ + Class filter host atoms based on their distance + from a set of guest atoms. + + Parameters + ---------- + host_atoms : MDAnalysis.AtomGroup + Initial selection of host atoms to filter from. + guest_atoms : MDANalysis.AtomGroup + Selection of guest atoms to search around. + min_search_distance: unit.Quantity + Minimum distance to filter atoms within. + max_search_distance: unit.Quantity + Maximum distance to filter atoms within. + """ + _analysis_algorithm_is_parallelizable = False + + def __init__( + self, + host_atoms, + guest_atoms, + min_search_distance, + max_search_distance, + **kwargs, + ): + super().__init__(host_atoms.universe.trajectory, **kwargs) + + def get_atomgroup(ag): + if ag._is_group: + return ag + return mda.AtomGroup([ag]) + + self.host_ag = get_atomgroup(host_atoms) + self.guest_ag = get_atomgroup(guest_atoms) + self.min_cutoff = min_search_distance.to("angstrom").m + self.max_cutoff = max_search_distance.to("angstrom").m + + def _prepare(self): + self.results.host_idxs = set(self.host_ag.atoms.ix) + + def _single_frame(self): + pairs = capped_distance( + reference=self.guest_ag.positions, + configuration=self.host_ag.positions, + max_cutoff=self.max_cutoff, + min_cutoff=self.min_cutoff, + box=self.guest_ag.universe.dimensions, + return_distances=False, + ) + + host_idxs = set(self.host_ag.atoms[p].ix for p in pairs[:, 1]) + + # We do an intersection as we go along to prune atoms that don't pass + # the distance selection criteria + self.results.host_idxs = self.results.host_idxs.intersection( + host_idxs + ) + + def _conclude(self): + self.results.host_idxs = np.array(list(self.results.host_idxs)) + + +def get_local_rmsf(atomgroup: mda.AtomGroup) -> unit.Quantity: + """ + Get the RMSF of an AtomGroup when aligned upon itself. + + Parameters + ---------- + atomgroup : MDAnalysis.AtomGroup + + Return + ------ + rmsf + ArrayQuantity of RMSF values. + """ + # First let's copy our Universe + copy_u = atomgroup.universe.copy() + ag = copy_u.atoms[atomgroup.atoms.ix] + + nojump = NoJump() + align = Aligner(ag) + + copy_u.trajectory.add_transformations(nojump, align) + + rmsf = RMSF(ag) + rmsf.run() + return rmsf.results.rmsf * unit.angstrom + + +def _atomgroup_has_bonds( + atomgroup: Union[mda.AtomGroup, mda.Universe] +) -> bool: + """ + Check if all residues in an AtomGroup or Univese has bonds. + + Parameters + ---------- + atomgroup : Union[mda.Atomgroup, mda.Universe] + Either an MDAnalysis AtomGroup or Universe to check for bonds. + + Returns + ------- + bool + True if all residues contain at least one bond, False otherwise. + """ + if not hasattr(atomgroup, 'bonds'): + return False + + if not all(len(r.atoms.bonds) > 0 for r in atomgroup.residues): + return False + + return True + + +def stable_secondary_structure_selection( + atomgroup: mda.AtomGroup, + trim_chain_start: int = 10, + trim_chain_end: int = 10, + min_structure_size: int = 6, + trim_structure_ends: int = 2, +) -> mda.AtomGroup: + """ + Select all atoms in a given AtomGroup which belong to residues with a + stable secondary structure as defined by Baumann et al.[1] + + The selection algorithm works in the following manner: + 1. Protein residues are selected from the ``atomgroup``. + 2. If there are fewer than 30 protein residues, raise an error. + 3. Split the protein residues by fragment, guessing bonds if necessary. + 4. Discard the first ``trim_chain_start`` and the last + ``trim_chain_end`` residues per fragment. + 5. Run DSSP using the last trajectory frame on the remaining + fragment residues. + 6. Extract all contiguous structure units that are longer than + ``min_structure_size``, removing ``trim_structure_ends`` + residues from each end of the structure. + 7. For all extract structures, if there are more beta-sheet + residues than there are alpha-helix residues, then allow + residues to be selected from either structure type. If not, + then only allow alpha-helix residues. + 8. Select all atoms in the ``atomgroup`` that belong to residues + from extracted structure units of the selected structure type. + + Parameters + ---------- + atomgroup : mda.AtomgGroup + The AtomGroup to select atoms from. + trim_chain_start: int + The number of residues to trim from the start of each + protein chain. Default 10. + trim_chain_end : int + The number of residues to trim from the end of each + protein chain. Default 10. + min_structure_size : int + The minimum number of residues needed in a given + secondary structure unit to be considered stable. Default 8. + trim_structure_ends : int + The number of residues to trim from the end of each + secondary structure units. Default 3. + + Returns + ------- + AtomGroup : mda.AtomGroup + An AtomGroup containing all the atoms from the input AtomGroup + which belong to stable secondary structure residues. + + Raises + ------ + UserWarning + If there are no bonds for the protein atoms in the input + host residue. In this case, the bonds will be guessed + using a simple distance metric. + + Notes + ----- + * This selection algorithm assumes contiguous & ordered residues. + * We recommend always trimming at least one residue at the ends of + each chain using ``trim_chain_start`` and ``trim_chain_end`` to + avoid issues with capping residues. + * DSSP assignement is done on the final frame of the trajectory. + + References + ---------- + [1] Baumann, Hannah M., et al. "Broadening the scope of binding free energy + calculations using a Separated Topologies approach." (2023). + """ + # First let's copy our Universe so we don't overwrite its current state + copy_u = atomgroup.universe.copy() + + # Create an AtomGroup that contains all the protein residues in the + # input Universe - we will filter by what matches in the atomgroup later + copy_protein_ag = copy_u.select_atoms('protein').atoms + + # We need to split by fragments to account for multiple chains + # To do this, we need bonds! + if not _atomgroup_has_bonds(copy_protein_ag, 'bonds'): + wmsg = "No bonds found in input Universe, will attept to guess them." + warnings.warn(wmsg) + protein_ag.guess_bonds() + + structures = [] # container for all contiguous secondary structure units + # Counter for each residue type found + structure_residue_counts = {'H': 0, 'E': 0, '-': 0} + # THe minimum length any chain must have + min_chain_length = trim_chain_start + trim_chain_end + min_structure_size + + # Loop over each continually bonded section (i.e. chain) of the protein + for frag in copy_protein_ag.fragments: + # If this fragment is too small, skip processing it + if len(frag.residues) < min_chain_length: + continue + + # Trim the chain ends + chain = frag.residues[trim_chain_start:-trim_chain_end].atoms + + try: + # Run on the last frame + # TODO: maybe filter out any residue that changes secondary + # structure during the trajectory + dssp = DSSP(chain).run(start=-1) + except ValueError: + # DSSP may fail if it doesn't recognise the system's atom names + # or non-canonical residues are included, in this case just skip + continue + + # Tag each residue structure by its resindex + dssp_results = [ + (structure, resid) for structure, resid in + zip(dssp.results.dssp[0], chain.residues.resindices) + ] + + # Group by contiguous secondary structure + for _, group_iter in groupby(dssp_results, lambda x: x[0]): + group = list(group_iter) + if len(group) >= min_structure_size: + structures.append( + group[trim_structure_ends:-trim_structure_ends] + ) + num_residues = len(group) - (2 * trim_structure_ends) + structure_residue_counts[group[0][0]] += num_residues + + # If we have fewer alpha-helix residues than beta-sheet residues + # then we allow picking from beta-sheets too. + allowed_structures = ['H'] + if structure_residue_counts['H'] < structure_residue_counts['E']: + allowed_structures.append('E') + + allowed_residxs = [] + for structure in structures: + if structure[0][0] in allowed_structures: + allowed_residxs.extend([residue[1] for residue in structure]) + + # Resindexes are keyed at the Universe scale not AtomGroup + allowed_atoms = atomgroup.universe.residues[allowed_residxs].atoms + + # Pick up all the atoms that intersect the initial selection and + # those allowed. + return atomgroup.intersection(allowed_atoms) + + +def protein_chain_selection( + atomgroup: mda.AtomGroup, + trim_chain_start: int = 10, + trim_chain_end: int = 10, +) -> mda.AtomGroup: + """ + Return a sub-selection of the input AtomGroup which belongs to protein + chains trimmed by ``trim_chain_start`` and ``trim_chain_end``. + + Protein chains are defined as any continuously bonded part of system with + at least 30 residues which match the ``protein`` selection of MDAnalysis. + + Parameters + ---------- + atomgroup : mda.AtomgGroup + The AtomGroup to select atoms from. + trim_chain_start: int + The number of residues to trim from the start of each + protein chain. Default 10. + trim_chain_end : int + The number of residues to trim from the end of each + protein chain. Default 10. + + Returns + ------- + atomgroup : mda.AtomGroup + An AtomGroup containing all the atoms from the input AtomGroup + which belong to protein chains. + """ + # First let's copy our Universe so we don't overwrite its current state + copy_u = atomgroup.universe.copy() + + # Create an AtomGroup that contains all the protein residues in the + # input Universe - we will filter by what matches in the atomgroup later + copy_protein_ag = copy_u.select_atoms('protein').atoms + + # We need to split by fragments to account for multiple chains + # To do this, we need bonds! + if not _atomgroup_has_bonds(copy_protein_ag, 'bonds'): + wmsg = ( + "No bonds found in input Universe, will attept to guess them." + ) + warnings.warn(wmsg) + copy_protein_ag.guess_bonds() + + copy_chains_ags_list = [] + + # Loop over each continually bonded section (i.e. chain) of the protein + for frag in copy_protein_ag.fragments: + # If this chain is less than 30 residues, it's probably a peptide + if len(frag.residues) < 30: + continue + + chain = frag.residues[trim_chain_start:-trim_chain_end].atoms + copy_chains_ags_list.append(chain) + + # Create a single atomgroup from all chains + copy_chains_ag = sum(copy_chains_ags_list) + + # Now get a list of all the chain atoms in the original Universe + # Resindexes are keyed at the Universe scale not AtomGroup + chain_atoms = atomgroup.universe.atoms[copy_chains_ag.atoms.ix] + + # Return all atoms at the intersection of the input atomgroup and + # the chains atomgroup + return atomgroup.intersection(chain_atoms) diff --git a/openfe/protocols/restraint_utils/openmm/__init__.py b/openfe/protocols/restraint_utils/openmm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/protocols/restraint_utils/openmm/omm_forces.py b/openfe/protocols/restraint_utils/openmm/omm_forces.py new file mode 100644 index 000000000..2947c8e03 --- /dev/null +++ b/openfe/protocols/restraint_utils/openmm/omm_forces.py @@ -0,0 +1,129 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Custom OpenMM Forces + +TODO +---- +* Add relevant duecredit entries. +""" +import numpy as np +import openmm + + +def get_boresch_energy_function( + control_parameter: str, +) -> str: + """ + Return a Boresch-style energy function for a CustomCompoundForce. + + Parameters + ---------- + control_parameter : str + A string for the lambda scaling control parameter + + Returns + ------- + str + The energy function string. + """ + energy_function = ( + f"{control_parameter} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + (K_phiC/2)*dphi_C^2; " + "dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + return energy_function + + +def get_periodic_boresch_energy_function( + control_parameter: str, +) -> str: + """ + Return a Boresch-style energy function with a periodic torsion for a + CustomCompoundForce. + + Parameters + ---------- + control_parameter : str + A string for the lambda scaling control parameter + + Returns + ------- + str + The energy function string. + """ + energy_function = ( + f"{control_parameter} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*uphi_A + (K_phiB/2)*uphi_B + (K_phiC/2)*uphi_C; " + "uphi_A = (1-cos(dA)); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "uphi_B = (1-cos(dB)); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "uphi_C = (1-cos(dC)); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + return energy_function + + +def get_custom_compound_bond_force( + energy_function: str, n_particles: int = 6, +): + """ + Return an OpenMM CustomCompoundForce + + TODO + ---- + Change this to a direct subclass like openmmtools.force. + + Acknowledgements + ---------------- + Boresch-like energy functions are reproduced from `Yank `_ + """ + return openmm.CustomCompoundBondForce(n_particles, energy_function) + + +def add_force_in_separate_group( + system: openmm.System, + force: openmm.Force, +): + """ + Add force to a System in a separate force group. + + Parameters + ---------- + system : openmm.System + System to add the Force to. + force : openmm.Force + The Force to add to the System. + + Raises + ------ + ValueError + If all 32 force groups are occupied. + + + TODO + ---- + Unlike the original Yank implementation, we assume that + all 32 force groups will not be filled. Should this be an issue + we can consider just separating it from NonbondedForce. + + Acknowledgements + ---------------- + Mostly reproduced from `Yank `_. + """ + available_force_groups = set(range(32)) + for existing_force in system.getForces(): + available_force_groups.discard(existing_force.getForceGroup()) + + if len(available_force_groups) == 0: + errmsg = "No available force groups could be found" + raise ValueError(errmsg) + + force.setForceGroup(min(available_force_groups)) + system.addForce(force) diff --git a/openfe/protocols/restraint_utils/openmm/omm_restraints.py b/openfe/protocols/restraint_utils/openmm/omm_restraints.py new file mode 100644 index 000000000..352bbd9bd --- /dev/null +++ b/openfe/protocols/restraint_utils/openmm/omm_restraints.py @@ -0,0 +1,706 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Classes for applying restraints to OpenMM Systems. + +Acknowledgements +---------------- +Many of the classes here are at least in part inspired from +`Yank `_ and +`OpenMMTools `_. + +TODO +---- +* Add relevant duecredit entries. +* Add Periodic Torsion Boresch class +""" +import abc + +import numpy as np +import openmm +from openmm import unit as omm_unit +from openmmtools.forces import ( + HarmonicRestraintForce, + HarmonicRestraintBondForce, + FlatBottomRestraintForce, + FlatBottomRestraintBondForce, +) +from openmmtools.states import GlobalParameterState, ThermodynamicState +from openff.units.openmm import to_openmm, from_openmm +from openff.units import unit + +from gufe.settings.models import SettingsBaseModel + +from openfe.protocols.restraint_utils.geometry import ( + BaseRestraintGeometry, + DistanceRestraintGeometry, + BoreschRestraintGeometry +) + +from openfe.protocols.restraint_utils.settings import ( + DistanceRestraintSettings, + BoreschRestraintSettings, +) + +from .omm_forces import ( + get_custom_compound_bond_force, + add_force_in_separate_group, + get_boresch_energy_function, +) + + +class RestraintParameterState(GlobalParameterState): + """ + Composable state to control `lambda_restraints` OpenMM Force parameters. + + See :class:`openmmtools.states.GlobalParameterState` for more details. + + Parameters + ---------- + parameters_name_suffix : Optional[str] + If specified, the state will control a modified version of the parameter + ``lambda_restraints_{parameters_name_suffix}` instead of just + ``lambda_restraints``. + lambda_restraints : Optional[float] + The scaling parameter for the restraint. If defined, + must be between 0 and 1. In most cases, a value of 1 indicates that the + restraint is fully turned on, whilst a value of 0 indicates that it is + innactive. + + Acknowledgement + --------------- + Partially reproduced from Yank. + """ + # We set the standard system to a fully interacting restraint + lambda_restraints = GlobalParameterState.GlobalParameter( + "lambda_restraints", standard_value=1.0 + ) + + @lambda_restraints.validator + def lambda_restraints(self, instance, new_value): + if new_value is not None and not (0.0 <= new_value <= 1.0): + errmsg = ( + "lambda_restraints must be between 0.0 and 1.0 " + f"and got {new_value}" + ) + raise ValueError(errmsg) + # Not crashing out on None to match upstream behaviour + return new_value + + +class BaseHostGuestRestraints(abc.ABC): + """ + An abstract base class for defining objects that apply a restraint between + two entities (referred to as a Host and a Guest). + + + TODO + ---- + Add some developer examples here. + """ + + def __init__( + self, + restraint_settings: SettingsBaseModel, + ): + self.settings = restraint_settings + self._verify_inputs() + + @abc.abstractmethod + def _verify_inputs(self): + """ + Method for validating that the inputs to the class are correct. + """ + pass + + @abc.abstractmethod + def _verify_geometry(self, geometry): + """ + Method for validating that the geometry object passed is correct. + """ + pass + + @abc.abstractmethod + def add_force( + self, + thermodynamic_state: ThermodynamicState, + geometry: BaseRestraintGeometry, + controlling_parameter_name: str, + ): + """ + Method for in-place adding the Force to the System of a + ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + """ + pass + + @abc.abstractmethod + def get_standard_state_correction( + self, + thermodynamic_state: ThermodynamicState, + geometry: BaseRestraintGeometry + ) -> unit.Quantity: + """ + Get the standard state correction for the Force when + applied to the input ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + + Returns + ------- + correction : unit.Quantity + The standard state correction free energy in units compatible + with kilojoule per mole. + """ + pass + + @abc.abstractmethod + def _get_force( + self, + geometry: BaseRestraintGeometry, + controlling_parameter_name: str, + ): + """ + Helper method to get the relevant OpenMM Force for this + class, given an input geometry. + """ + pass + + +class SingleBondMixin: + """ + A mixin to extend geometry checks for Forces that can only hold + a single atom. + """ + def _verify_geometry(self, geometry: BaseRestraintGeometry): + if len(geometry.host_atoms) != 1 or len(geometry.guest_atoms) != 1: + errmsg = ( + "host_atoms and guest_atoms must only include a single index " + f"each, got {len(geometry.host_atoms)} and " + f"{len(geometry.guest_atoms)} respectively." + ) + raise ValueError(errmsg) + super()._verify_geometry(geometry) + + +class BaseRadiallySymmetricRestraintForce(BaseHostGuestRestraints): + """ + A base class for all radially symmetic Forces acting between + two sets of atoms. + + Must be subclassed. + """ + def _verify_inputs(self) -> None: + if not isinstance(self.settings, DistanceRestraintSettings): + errmsg = f"Incorrect settings type {self.settings} passed through" + raise ValueError(errmsg) + + def _verify_geometry(self, geometry: DistanceRestraintGeometry): + if not isinstance(geometry, DistanceRestraintGeometry): + errmsg = f"Incorrect geometry class type {geometry} passed through" + raise ValueError(errmsg) + + def add_force( + self, + thermodynamic_state: ThermodynamicState, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str = "lambda_restraints", + ) -> None: + """ + Method for in-place adding the Force to the System of the + given ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + """ + self._verify_geometry(geometry) + force = self._get_force(geometry, controlling_parameter_name) + force.setUsesPeriodicBoundaryConditions( + thermodynamic_state.is_periodic + ) + # Note .system is a call to get_system() so it's returning a copy + system = thermodynamic_state.system + add_force_in_separate_group(system, force) + thermodynamic_state.system = system + + def get_standard_state_correction( + self, + thermodynamic_state: ThermodynamicState, + geometry: DistanceRestraintGeometry, + ) -> unit.Quantity: + """ + Get the standard state correction for the Force when + applied to the input ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + + Returns + ------- + correction : unit.Quantity + The standard state correction free energy in units compatible + with kilojoule per mole. + """ + self._verify_geometry(geometry) + force = self._get_force(geometry) + corr = force.compute_standard_state_correction( + thermodynamic_state, volume="system" + ) + dg = corr * thermodynamic_state.kT + return from_openmm(dg).to('kilojoule_per_mole') + + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str + ): + raise NotImplementedError("only implemented in child classes") + + +class HarmonicBondRestraint( + BaseRadiallySymmetricRestraintForce, SingleBondMixin +): + """ + A class to add a harmonic restraint between two atoms + in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.HarmonicRestraintBondForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole/nm**2. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the HarmonicRestraintBondForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + HarmonicRestraintBondForce + An OpenMM Force that applies a harmonic restraint between + two atoms. + """ + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) + return HarmonicRestraintBondForce( + spring_constant=spring_constant, + restrained_atom_index1=geometry.host_atoms[0], + restrained_atom_index2=geometry.guest_atoms[0], + controlling_parameter_name=controlling_parameter_name, + ) + + +class FlatBottomBondRestraint( + BaseRadiallySymmetricRestraintForce, SingleBondMixin +): + """ + A class to add a flat bottom restraint between two atoms + in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.FlatBottomRestraintBondForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole/nm**2. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the FlatBottomRestraintBondForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + FlatBottomRestraintBondForce + An OpenMM Force that applies a flat bottom restraint between + two atoms. + """ + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm( + geometry.well_radius + ).value_in_unit_system(omm_unit.md_unit_system) + return FlatBottomRestraintBondForce( + spring_constant=spring_constant, + well_radius=well_radius, + restrained_atom_index1=geometry.host_atoms[0], + restrained_atom_index2=geometry.guest_atoms[0], + controlling_parameter_name=controlling_parameter_name, + ) + + +class CentroidHarmonicRestraint(BaseRadiallySymmetricRestraintForce): + """ + A class to add a harmonic restraint between the centroid of + two sets of atoms in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.HarmonicRestraintForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole/nm**2. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the HarmonicRestraintForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + HarmonicRestraintForce + An OpenMM Force that applies a harmonic restraint between + the centroid of two sets of atoms. + """ + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) + return HarmonicRestraintForce( + spring_constant=spring_constant, + restrained_atom_index1=geometry.host_atoms, + restrained_atom_index2=geometry.guest_atoms, + controlling_parameter_name=controlling_parameter_name, + ) + + +class CentroidFlatBottomRestraint(BaseRadiallySymmetricRestraintForce): + """ + A class to add a flat bottom restraint between the centroid + of two sets of atoms in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.FlatBottomRestraintForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole/nm**2. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the FlatBottomRestraintForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + FlatBottomRestraintForce + An OpenMM Force that applies a flat bottom restraint between + the centroid of two sets of atoms. + """ + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm( + geometry.well_radius + ).value_in_unit_system(omm_unit.md_unit_system) + return FlatBottomRestraintForce( + spring_constant=spring_constant, + well_radius=well_radius, + restrained_atom_index1=geometry.host_atoms, + restrained_atom_index2=geometry.guest_atoms, + controlling_parameter_name=controlling_parameter_name, + ) + + +class BoreschRestraint(BaseHostGuestRestraints): + """ + A class to add a Boresch-like restraint between six atoms, + + The restraint is defined as a + :class:`openmmtools.forces.CustomCompoundForce` with the + following energy function: + + lambda_control_parameter * E; + E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 + + (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 + + (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + + (K_phiC/2)*dphi_C^2; + dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); + dA = dihedral(p1,p2,p3,p4) - phi_A0; + dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); + dB = dihedral(p2,p3,p4,p5) - phi_B0; + dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); + dC = dihedral(p3,p4,p5,p6) - phi_C0; + + Where p1, p2, p3, p4, p5, p6 represent host atoms 2, 1, 0, + and guest atoms 0, 1, 2 respectively. + + ``lambda_control_parameter`` is a control parameter for + scaling the Force. + + ``K_r`` is defined as the bond spring constant between + p3 and p4 and must be provided in the settings in units + compatible with kilojoule / mole. + + ``r_aA0`` is the equilibrium distance of the bond between + p3 and p4. This must be provided by the Geometry class in + units compatiblle with nanometer. + + ``K_thetaA`` and ``K_thetaB`` are the spring constants for the angles + formed by (p2, p3, p4) and (p3, p4, p5). They must be provided in the + settings in units compatible with kilojoule / mole / radians**2. + + ``theta_A0`` and ``theta_B0`` are the equilibrium values for angles + (p2, p3, p4) and (p3, p4, p5). They must be provided by the + Geometry class in units compatible with radians. + + ``phi_A0``, ``phi_B0``, and ``phi_C0`` are the equilibrium force constants + for the dihedrals formed by (p1, p2, p3, p4), (p2, p3, p4, p5), and + (p3, p4, p5, p6). They must be provided in the settings in units + compatible with kilojoule / mole / radians ** 2. + + ``phi_A0``, ``phi_B0``, and ``phi_C0`` are the equilibrium values + for the dihedrals formed by (p1, p2, p3, p4), (p2, p3, p4, p5), and + (p3, p4, p5, p6). They must be provided in the Geometry class in + units compatible with radians. + + + Notes + ----- + * Settings must define the ``K_r`` (d) + """ + def _verify_inputs(self) -> None: + """ + Method for validating that the geometry object is correct. + """ + if not isinstance(self.settings, BoreschRestraintSettings): + errmsg = f"Incorrect settings type {self.settings} passed through" + raise ValueError(errmsg) + + def _verify_geometry(self, geometry: BoreschRestraintGeometry): + """ + Method for validating that the geometry object is correct. + """ + if not isinstance(geometry, BoreschRestraintGeometry): + errmsg = f"Incorrect geometry class type {geometry} passed through" + raise ValueError(errmsg) + + def add_force( + self, + thermodynamic_state: ThermodynamicState, + geometry: BoreschRestraintGeometry, + controlling_parameter_name: str, + ) -> None: + """ + Method for in-place adding the Boresch CustomCompoundForce + to the System of the given ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + """ + self._verify_geometry(geometry) + force = self._get_force( + geometry, + controlling_parameter_name, + ) + force.setUsesPeriodicBoundaryConditions( + thermodynamic_state.is_periodic + ) + # Note .system is a call to get_system() so it's returning a copy + system = thermodynamic_state.system + add_force_in_separate_group(system, force) + thermodynamic_state.system = system + + def _get_force( + self, + geometry: BoreschRestraintGeometry, + controlling_parameter_name: str + ) -> openmm.CustomCompoundBondForce: + """ + Get the CustomCompoundForce with a Boresch-like energy function + given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + CustomCompoundForce + An OpenMM CustomCompoundForce that applies a Boresch-like + restraint between 6 atoms. + """ + efunc = get_boresch_energy_function(controlling_parameter_name) + + force = get_custom_compound_bond_force( + energy_function=efunc, n_particles=6, + ) + + param_values = [] + + parameter_dict = { + 'K_r': self.settings.K_r, + 'r_aA0': geometry.r_aA0, + 'K_thetaA': self.settings.K_thetaA, + 'theta_A0': geometry.theta_A0, + 'K_thetaB': self.settings.K_thetaB, + 'theta_B0': geometry.theta_B0, + 'K_phiA': self.settings.K_phiA, + 'phi_A0': geometry.phi_A0, + 'K_phiB': self.settings.K_phiB, + 'phi_B0': geometry.phi_B0, + 'K_phiC': self.settings.K_phiC, + 'phi_C0': geometry.phi_C0, + } + for key, val in parameter_dict.items(): + param_values.append( + to_openmm(val).value_in_unit_system(omm_unit.md_unit_system) + ) + force.addPerBondParameter(key) + + force.addGlobalParameter(controlling_parameter_name, 1.0) + atoms = [ + geometry.host_atoms[2], + geometry.host_atoms[1], + geometry.host_atoms[0], + geometry.guest_atoms[0], + geometry.guest_atoms[1], + geometry.guest_atoms[2], + ] + force.addBond(atoms, param_values) + return force + + def get_standard_state_correction( + self, + thermodynamic_state: ThermodynamicState, + geometry: BoreschRestraintGeometry + ) -> unit.Quantity: + """ + Get the standard state correction for the Boresch-like + restraint when applied to the input ThermodynamicState. + + The correction is calculated using the analytical method + as defined by Boresch et al. [1] + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + + Returns + ------- + correction : unit.Quantity + The standard state correction free energy in units compatible + with kilojoule per mole. + + References + ---------- + [1] Boresch S, Tettinger F, Leitgeb M, Karplus M. J Phys Chem B. 107:9535, 2003. + http://dx.doi.org/10.1021/jp0217839 + """ + self._verify_geometry(geometry) + + StandardV = 1.66053928 * unit.nanometer**3 + kt = from_openmm(thermodynamic_state.kT) + + # distances + r_aA0 = geometry.r_aA0.to('nm') + sin_thetaA0 = np.sin(geometry.theta_A0.to('radians')) + sin_thetaB0 = np.sin(geometry.theta_B0.to('radians')) + + # restraint energies + K_r = self.settings.K_r.to('kilojoule_per_mole / nm ** 2') + K_thetaA = self.settings.K_thetaA.to('kilojoule_per_mole / radians ** 2') + K_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole / radians ** 2') + K_phiA = self.settings.K_phiA.to('kilojoule_per_mole / radians ** 2') + K_phiB = self.settings.K_phiB.to('kilojoule_per_mole / radians ** 2') + K_phiC = self.settings.K_phiC.to('kilojoule_per_mole / radians ** 2') + + numerator1 = 8.0 * (np.pi**2) * StandardV + denum1 = (r_aA0**2) * sin_thetaA0 * sin_thetaB0 + numerator2 = np.sqrt( + K_r * K_thetaA * K_thetaB * K_phiA * K_phiB * K_phiC + ) + denum2 = (2.0 * np.pi * kt)**3 + + dG = -kt * np.log((numerator1/denum1) * (numerator2/denum2)) + + return dG diff --git a/openfe/protocols/restraint_utils/settings.py b/openfe/protocols/restraint_utils/settings.py new file mode 100644 index 000000000..efe9c33f6 --- /dev/null +++ b/openfe/protocols/restraint_utils/settings.py @@ -0,0 +1,150 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Settings for adding restraints. + +TODO +---- +* Rename from host/guest to molA/molB? +""" +from typing import Optional +from openff.models.types import FloatQuantity +from pydantic.v1 import validator +from gufe.settings import ( + SettingsBaseModel, +) + + +class BaseRestraintSettings(SettingsBaseModel): + """ + Base class for RestraintSettings objects. + """ + class Config: + arbitrary_types_allowed = True + + +class DistanceRestraintSettings(BaseRestraintSettings): + """ + Settings defining a distance restraint between + two groups of atoms defined as ``host`` and ``guest``. + """ + spring_constant: FloatQuantity['kilojoule_per_mole / nm ** 2'] + """ + The distance restraint potential spring constant. + """ + host_atoms: Optional[list[int]] = None + """ + The indices of the host component atoms to restrain. + If defined, these will override any automatic selection. + """ + guest_atoms: Optional[list[int]] = None + """ + The indices of the guest component atoms to restraint. + If defined, these will override any automatic selection. + """ + central_atoms_only: bool = False + """ + Whether to apply the restraint solely to the central atoms + of each group. + + Note: this can only be applied if ``host`` and ``guest`` + represent small molecules. + """ + + @validator("guest_atoms", "host_atoms") + def positive_idxs(cls, v): + if v is not None and any([i < 0 for i in v]): + errmsg = "negative indices passed" + raise ValueError(errmsg) + return v + + +class FlatBottomRestraintSettings(DistanceRestraintSettings): + """ + Settings to define a flat bottom restraint between two + groups of atoms named ``host`` and ``guest``. + """ + well_radius: Optional[FloatQuantity['nm']] = None + """ + The distance at which the harmonic restraint is imposed + in units of distance. + """ + @validator("well_radius") + def positive_value(cls, v): + if v is not None and v.m < 0: + errmsg = f"well radius cannot be negative {v}" + raise ValueError(errmsg) + return v + + +class BoreschRestraintSettings(BaseRestraintSettings): + """ + Settings to define a Boresch-style restraint between + two groups of atoms named ``host`` and ``guest``. + + The restraint is defined in the following manner: + + H2 G2 + - - + - - + H1 - - H0 -- G0 - - G1 + + Where HX represents the X index of ``host_atoms`` + and GX the X indexx of ``guest_atoms``. + + By default, the Boresch-like restraint will be + obtained using a modified version of the + search algorithm implemented by Baumann et al. [1]. + + If ``guest_atoms`` and ``host_atoms`` are defined, + these indices will be used instead. + + References + ---------- + [1] Baumann, Hannah M., et al. "Broadening the scope of binding free + energy calculations using a Separated Topologies approach." (2023). + """ + K_r: FloatQuantity['kilojoule_per_mole / nm ** 2'] + """ + The bond spring constant between H0 and G0. + """ + K_thetaA: FloatQuantity['kilojoule_per_mole / radians ** 2'] + """ + The spring constant for the angle formed by H1-H0-G0. + """ + K_thetaB: FloatQuantity['kilojoule_per_mole / radians ** 2'] + """ + The spring constant for the angle formed by H0-G0-G1. + """ + phi_A0: FloatQuantity['kilojoule_per_mole / radians ** 2'] + """ + The equilibrium force constant for the dihedral formed by + H2-H1-H0-G0. + """ + phi_B0: FloatQuantity['kilojoule_per_mole / radians ** 2'] + """ + The equilibrium force constant for the dihedral formed by + H1-H0-G0-G1. + """ + phi_C0: FloatQuantity['kilojoule_per_mole / radians ** 2'] + """ + The equilibrium force constant for the dihedral formed by + H0-G0-G1-G2. + """ + host_atoms: Optional[list[int]] = None + """ + The indices of the host component atoms to restrain. + If defined, these will override any automatic selection. + """ + guest_atoms: Optional[list[int]] = None + """ + The indices of the guest component atoms to restraint. + If defined, these will override any automatic selection. + """ + + @validator("guest_atoms", "host_atoms") + def positive_idxs_list(cls, v): + if v is not None and any([i < 0 for i in v]): + errmsg = "negative indices passed" + raise ValueError(errmsg) + return v \ No newline at end of file diff --git a/openfe/tests/conftest.py b/openfe/tests/conftest.py index 51cfb598b..4ca8933f1 100644 --- a/openfe/tests/conftest.py +++ b/openfe/tests/conftest.py @@ -223,7 +223,7 @@ def T4_protein_component(): return comp -@pytest.fixture() +@pytest.fixture(scope='session') def eg5_protein_pdb(): with resources.files('openfe.tests.data.eg5') as d: yield str(d / 'eg5_protein.pdb') diff --git a/openfe/tests/protocols/restraints/__init__.py b/openfe/tests/protocols/restraints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/tests/protocols/restraints/test_geometry_base.py b/openfe/tests/protocols/restraints/test_geometry_base.py new file mode 100644 index 000000000..139c57dc5 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_geometry_base.py @@ -0,0 +1,25 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +from openfe.protocols.restraint_utils.geometry.base import ( + HostGuestRestraintGeometry +) + + +def test_hostguest_geometry(): + """ + A very basic will it build test. + """ + geom = HostGuestRestraintGeometry(guest_atoms=[1, 2, 3], host_atoms=[4]) + + assert isinstance(geom, HostGuestRestraintGeometry) + + +def test_hostguest_positiveidxs_validator(): + """ + Check that the validator is working as intended. + """ + with pytest.raises(ValueError, match="negative indices passed"): + geom = HostGuestRestraintGeometry(guest_atoms=[-1, 1], host_atoms=[0]) diff --git a/openfe/tests/protocols/restraints/test_geometry_harmonic.py b/openfe/tests/protocols/restraints/test_geometry_harmonic.py new file mode 100644 index 000000000..80c0e8134 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_geometry_harmonic.py @@ -0,0 +1,23 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +from openfe.protocols.restraint_utils.geometry.harmonic import ( + DistanceRestraintGeometry +) + + +def test_hostguest_geometry(): + """ + A very basic will it build test. + """ + geom = DistanceRestraintGeometry(guest_atoms=[1, 2, 3], host_atoms=[4]) + + assert isinstance(geom, DistanceRestraintGeometry) + + +def test_get_distance_restraint(): + """ + Check that you get a distance restraint. + """ diff --git a/openfe/tests/protocols/restraints/test_geometry_utils.py b/openfe/tests/protocols/restraints/test_geometry_utils.py new file mode 100644 index 000000000..4f8047ea0 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_geometry_utils.py @@ -0,0 +1,299 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +import itertools +from rdkit import Chem +import MDAnalysis as mda +from openff.units import unit +import numpy as np + +from openfe.protocols.restraint_utils.geometry.utils import ( + _get_mda_selection, + get_aromatic_rings, + get_aromatic_atom_idxs, + get_heavy_atom_idxs, + get_central_atom_idx, + is_collinear, + check_angle_not_flat, + _wrap_angle, + check_dihedral_bounds, + check_angular_variance, + _atomgroup_has_bonds, +) + + + +@pytest.fixture(scope='module') +def eg5_pdb_universe(eg5_protein_pdb): + return mda.Universe(eg5_protein_pdb) + + +def test_mda_selection_none_error(eg5_pdb_universe): + with pytest.raises(ValueError, match="one of either"): + _ = _get_mda_selection(eg5_pdb_universe) + + +def test_mda_selection_both_args_error(eg5_pdb_universe): + with pytest.raises(ValueError, match="both atom_list and"): + _ = _get_mda_selection( + eg5_pdb_universe, + atom_list=[0, 1, 2, 3], + selection="all" + ) + + +def test_mda_selection_universe_atom_list(eg5_pdb_universe): + test_ag = _get_mda_selection(eg5_pdb_universe, atom_list=[0, 1, 2]) + assert eg5_pdb_universe.atoms[[0, 1, 2]] == test_ag + + +def test_mda_selection_atomgroup_string(eg5_pdb_universe): + test_ag = _get_mda_selection(eg5_pdb_universe.atoms, selection='all') + assert test_ag == eg5_pdb_universe.atoms + + +@pytest.mark.parametrize('smiles, expected', [ + ['C1CCCCC1', []], + ['[C@@H]1([C@@H]([C@@H](OC([C@@H]1O)O)C(=O)O)O)O', []], + ['C1=CC=CC=C1', [6]], + ['C1=CC2C=CC1C=C2', [8]], + ['C1CC2=CC=CC=C2C1', [6]], + ['C1=COC=C1', [5]], + ['C1=CC=C2C=CC=CC2=C1', [10]], + ['C1=CC=C(C=C1)C2=CC=CC=C2', [6, 6]], + ['C1=CC=C(C=C1)C(C2=CC=CC=C2)(C3=CC=CC=C3Cl)N4C=CN=C4', [6, 6, 6, 5]] +]) +def test_aromatic_rings(smiles, expected): + mol = Chem.AddHs(Chem.MolFromSmiles(smiles)) + + # get the rings + rings = get_aromatic_rings(mol) + + # check we have the right number of rings & their size + for i, r in enumerate(rings): + assert len(r) == expected[i] + + # check that there is no overlap in atom between each ring + for x, y in itertools.combinations(rings, 2): + assert x.isdisjoint(y) + + # get the aromatic idx + arom_idxs = get_aromatic_atom_idxs(mol) + + # Check that all the ring indices are aromatic + assert all(idx in arom_idxs for idx in itertools.chain(*rings)) + + # Also check the lengths match + assert sum(len(r) for r in rings) == len(arom_idxs) + + # Finallly check that all the arom_idxs are actually aromatic + for idx in arom_idxs: + at = mol.GetAtomWithIdx(idx) + assert at.GetIsAromatic() + +@pytest.mark.parametrize('smiles, nheavy, nlight', [ + ['C1CCCCC1', 6, 12], + ['[C@@H]1([C@@H]([C@@H](OC([C@@H]1O)O)C(=O)O)O)O', 13, 10], + ['C1=CC=CC=C1', 6, 6], + ['C1=CC2C=CC1C=C2', 8, 8], + ['C1CC2=CC=CC=C2C1', 9, 10], + ['C1=COC=C1', 5, 4], + ['C1=CC=C2C=CC=CC2=C1', 10, 8], + ['C1=CC=C(C=C1)C2=CC=CC=C2', 12, 10], + ['C1=CC=C(C=C1)C(C2=CC=CC=C2)(C3=CC=CC=C3Cl)N4C=CN=C4', 25, 17] +]) +def test_heavy_atoms(smiles, nheavy, nlight): + mol = Chem.AddHs(Chem.MolFromSmiles(smiles)) + + n_atoms = len(list(mol.GetAtoms())) + + heavy_atoms = get_heavy_atom_idxs(mol) + + # check all the heavy atoms are indeed heavy + for idx in heavy_atoms: + at = mol.GetAtomWithIdx(idx) + assert at.GetAtomicNum() > 1 + + assert len(heavy_atoms) == nheavy + assert n_atoms == nheavy + nlight + + +@pytest.mark.parametrize('smiles, idx', [ + ['C1CCCCC1', 2], + ['[C@@H]1([C@@H]([C@@H](OC([C@@H]1O)O)C(=O)O)O)O', 3], + ['C1=CC=CC=C1', 2], + ['C1=CC2C=CC1C=C2', 2], + ['C1CC2=CC=CC=C2C1', 2], + ['C1=COC=C1', 4], + ['C1=CC=C2C=CC=CC2=C1', 3], + ['C1=CC=C(C=C1)C2=CC=CC=C2', 6], + ['C1=CC=C(C=C1)C(C2=CC=CC=C2)(C3=CC=CC=C3Cl)N4C=CN=C4', 6], + ['OC(COc1ccc(cc1)CC(=O)N)CNC(C)C', 3], +]) +def test_central_idx(smiles, idx): + """ + Regression tests for getting central atom idx. + """ + rdmol = Chem.AddHs(Chem.MolFromSmiles(smiles)) + assert get_central_atom_idx(rdmol) == idx + + +def test_central_atom_disconnected(): + mol = Chem.AddHs(Chem.MolFromSmiles('C.C')) + + with pytest.raises(ValueError, match='disconnected molecule'): + _ = get_central_atom_idx(mol) + + +def test_collinear_too_few_atoms(): + with pytest.raises(ValueError, match='Too few atoms passed'): + _ = is_collinear(None, [1, 2], None) + + +def test_collinear_index_match_error_length(): + with pytest.raises(ValueError, match='indices do not match'): + _ = is_collinear( + positions=np.zeros((3, 3)), + atoms=[0, 1, 2, 3], + ) + + +def test_collinear_index_match_error_index(): + with pytest.raises(ValueError, match='indices do not match'): + _ = is_collinear( + positions=np.zeros((3, 3)), + atoms=[1, 2, 3], + ) + + +@pytest.mark.parametrize('arr, truth', [ + [[[0, 0, -1], [1, 0, 0], [2, 0, 2]], True], + [[[0, 1, -1], [1, 0, 0], [2, 0, 2]], False], + [[[0, 1, -1], [1, 1, 0], [2, 1, 2]], True], + [[[0, 0, -1], [1, 1, 0], [2, 2, 2]], True], + [[[0, 0, -1], [1, 0, 0], [2, 0, 2]], True], + [[[2, 0, -1], [1, 0, 0], [0, 0, 2]], True], + [[[0, 0, 1], [0, 0, 0], [0, 0, 2]], True], + [[[1, 1, 1], [0, 0, 0], [2, 2, 2]], True] +]) +def test_is_collinear_three_atoms(arr, truth): + assert is_collinear(np.array(arr), [0, 1, 2]) == truth + + +@pytest.mark.parametrize('arr, truth', [ + [[[0, 0, -1], [1, 0, 0], [2, 0, 2], [3, 0, 4]], True], + [[[0, 0, -1], [1, 0, 0], [2, 0, 2], [3, 0, 2]], True], + [[[0, 0, 1], [1, 0, 0], [2, 0, 2], [3, 0, 4]], True], + [[[0, 1, -1], [1, 0, 0], [2, 0, 2], [3, 0, 2]], False], +]) +def test_is_collinear_four_atoms(arr, truth): + assert is_collinear(np.array(arr), [0, 1, 2, 3]) == truth + + +def test_wrap_angle_degrees(): + for i in range(0, 361, 1): + angle = _wrap_angle(i * unit.degrees) + if i > 180: + expected = ((i - 360) * unit.degrees).to('radians').m + else: + expected = (i * unit.degrees).to('radians').m + + assert angle.m == pytest.approx(expected) + + +@pytest.mark.parametrize('angle, expected', [ + [0 * unit.radians, 0 * unit.radians], + [1 * unit.radians, 1 * unit.radians], + [4 * unit.radians, 4 - (2 * np.pi) * unit.radians], + [-4 * unit.radians, -4 + (2 * np.pi) * unit.radians], +]) +def test_wrap_angle_radians(angle, expected): + assert _wrap_angle(angle) == pytest.approx(expected) + + +@pytest.mark.parametrize('limit, force, temperature', [ + [0.7695366605411506, 83.68, 298.15], + [0.8339791717799163, 83.68, 350.0], + [0.5441445910402979, 167.36, 298.15] +]) +def test_angle_not_flat(limit, force, temperature): + limit = limit * unit.radians + force = force * unit.kilojoule_per_mole / unit.radians ** 2 + temperature = temperature * unit.kelvin + + # test upper + assert check_angle_not_flat(limit + 0.01, force, temperature) + assert not check_angle_not_flat(limit - 0.01, force, temperature) + + # test lower + limit = np.pi - limit + assert check_angle_not_flat(limit - 0.01, force, temperature) + assert not check_angle_not_flat(limit + 0.01, force, temperature) + + +@pytest.mark.parametrize('dihed, expected', [ + [3 * unit.radians, False], + [0 * unit.radians, True], + [-3 * unit.radians, False], + [300 * unit.degrees, True], + [181 * unit.degrees, False], +]) +def test_check_dihedral_bounds(dihed, expected): + ret = check_dihedral_bounds(dihed) + assert ret == expected + + +@pytest.mark.parametrize('dihed, lower, upper, expected', [ + [3 * unit.radians, -3.1 * unit.radians, 3.1 * unit.radians, True], + [300 * unit.degrees, -61 * unit.degrees, 301 * unit.degrees, True], + [300 * unit.degrees, 299 * unit.degrees, -61 * unit.degrees, False] +]) +def test_check_dihedral_bounds_defined(dihed, lower, upper, expected): + ret = check_dihedral_bounds( + dihed, lower_cutoff=lower, upper_cutoff=upper + ) + assert ret == expected + + +def test_angular_variance(): + """ + Manual check with for an input number of angles with + a known variance of 0.36216 + """ + angles = [0, 1, 2, 6] + + assert check_angular_variance( + angles=angles * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=-np.pi * unit.radians, + width=0.37 * unit.radians + ) + + assert not check_angular_variance( + angles=angles * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=-np.pi * unit.radians, + width=0.35 * unit.radians + ) + + +def test_atomgroup_has_bonds(eg5_protein_pdb): + # Creating a new universe because we'll modify this one + u = mda.Universe(eg5_protein_pdb) + + # PDB has water bonds + assert len(u.bonds) == 14 + assert _atomgroup_has_bonds(u) is False + assert _atomgroup_has_bonds(u.select_atoms('resname HOH')) is True + + # Delete the topoplogy attr and everything is false + u.del_TopologyAttr('bonds') + assert _atomgroup_has_bonds(u) is False + assert _atomgroup_has_bonds(u.select_atoms('resname HOH')) is False + + # Guess some bonds back + ag = u.atoms[:100] + ag.guess_bonds() + assert _atomgroup_has_bonds(ag) is True diff --git a/openfe/tests/protocols/restraints/test_omm_restraints.py b/openfe/tests/protocols/restraints/test_omm_restraints.py new file mode 100644 index 000000000..0315d5b51 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_omm_restraints.py @@ -0,0 +1,31 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +from openfe.protocols.restraint_utils.openmm.omm_restraints import ( + RestraintParameterState, +) + + +def test_parameter_state_default(): + param_state = RestraintParameterState() + assert param_state.lambda_restraints is None + + +@pytest.mark.parametrize('suffix', [None, 'foo']) +@pytest.mark.parametrize('lambda_var', [0, 0.5, 1.0]) +def test_parameter_state_suffix(suffix, lambda_var): + param_state = RestraintParameterState( + parameters_name_suffix=suffix, lambda_restraints=lambda_var + ) + + if suffix is not None: + param_name = f'lambda_restraints_{suffix}' + else: + param_name = 'lambda_restraints' + + assert getattr(param_state, param_name) == lambda_var + assert len(param_state._parameters.keys()) == 1 + assert param_state._parameters[param_name] == lambda_var + assert param_state._parameters_name_suffix == suffix \ No newline at end of file diff --git a/openfe/tests/protocols/restraints/test_openmm_forces.py b/openfe/tests/protocols/restraints/test_openmm_forces.py new file mode 100644 index 000000000..cd2a7f21e --- /dev/null +++ b/openfe/tests/protocols/restraints/test_openmm_forces.py @@ -0,0 +1,115 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest +import numpy as np +import openmm +from openfe.protocols.restraint_utils.openmm.omm_forces import ( + get_boresch_energy_function, + get_periodic_boresch_energy_function, + get_custom_compound_bond_force, + add_force_in_separate_group, +) + + +@pytest.mark.parametrize('param', ['foo', 'bar']) +def test_boresch_energy_function(param): + """ + Base regression test for the energy function + """ + fn = get_boresch_energy_function(param) + assert fn == ( + f"{param} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + (K_phiC/2)*dphi_C^2; " + "dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + + +@pytest.mark.parametrize('param', ['foo', 'bar']) +def test_periodic_boresch_energy_function(param): + """ + Base regression test for the energy function + """ + fn = get_periodic_boresch_energy_function(param) + assert fn == ( + f"{param} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*uphi_A + (K_phiB/2)*uphi_B + (K_phiC/2)*uphi_C; " + "uphi_A = (1-cos(dA)); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "uphi_B = (1-cos(dB)); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "uphi_C = (1-cos(dC)); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + + +@pytest.mark.parametrize('num_atoms', [6, 20]) +def test_custom_compound_force(num_atoms): + fn = get_boresch_energy_function('lambda_restraints') + force = get_custom_compound_bond_force(fn, num_atoms) + + # Check we have the right object + assert isinstance(force, openmm.CustomCompoundBondForce) + + # Check the energy function + assert force.getEnergyFunction() == fn + + # Check the number of particles + assert force.getNumParticlesPerBond() == num_atoms + + +@pytest.mark.parametrize('groups, expected', [ + [[0, 1, 2, 3, 4], 5], + [[1, 2, 3, 4, 5], 0], +]) +def test_add_force_in_separate_group(groups, expected): + # Create an empty system + system = openmm.System() + + # Create some forces with some force groups + base_forces = [ + openmm.NonbondedForce(), + openmm.HarmonicBondForce(), + openmm.HarmonicAngleForce(), + openmm.PeriodicTorsionForce(), + openmm.CMMotionRemover(), + ] + + for force, group in zip(base_forces, groups): + force.setForceGroup(group) + + [system.addForce(force) for force in base_forces] + + # Get your CustomCompoundBondForce + fn = get_boresch_energy_function('lambda_restraints') + new_force = get_custom_compound_bond_force(fn, 6) + # new_force.setForceGroup(5) + # system.addForce(new_force) + add_force_in_separate_group(system=system, force=new_force) + + # Loop through and check that we go assigned the expected force group + for force in system.getForces(): + if isinstance(force, openmm.CustomCompoundBondForce): + assert force.getForceGroup() == expected + + +def test_add_too_many_force_groups(): + # Create a system + system = openmm.System() + + # Fill it upu with 32 forces with different groups + for i in range(32): + f = openmm.HarmonicBondForce() + f.setForceGroup(i) + system.addForce(f) + + # Now try to add another force + with pytest.raises(ValueError, match="No available force group"): + add_force_in_separate_group( + system=system, force=openmm.HarmonicBondForce() + ) \ No newline at end of file diff --git a/openfe/tests/protocols/restraints/test_settings.py b/openfe/tests/protocols/restraints/test_settings.py new file mode 100644 index 000000000..7c027df79 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_settings.py @@ -0,0 +1,93 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Test the restraint settings. +""" +import pytest +from openff.units import unit +from openfe.protocols.restraint_utils.settings import ( + DistanceRestraintSettings, + FlatBottomRestraintSettings, + BoreschRestraintSettings, +) + + +def test_distance_restraint_settings_default(): + """ + Basic settings regression test + """ + settings = DistanceRestraintSettings( + spring_constant=10 * unit.kilojoule_per_mole / unit.nm ** 2, + ) + assert settings.central_atoms_only is False + assert isinstance(settings, DistanceRestraintSettings) + + +def test_distance_restraint_negative_idxs(): + """ + Check that an error is raised if you have negative + atom indices in host atoms. + """ + with pytest.raises(ValueError, match="negative indices passed"): + _ = DistanceRestraintSettings( + spring_constant=10 * unit.kilojoule_per_mole / unit.nm ** 2, + host_atoms=[-1, 0, 2], + guest_atoms=[0, 1, 2], + ) + + +def test_flatbottom_restraint_settings_default(): + """ + Basic settings regression test + """ + settings = FlatBottomRestraintSettings( + spring_constant=10 * unit.kilojoule_per_mole / unit.nm ** 2, + well_radius=1*unit.nanometer, + ) + assert isinstance(settings, FlatBottomRestraintSettings) + + +def test_flatbottom_restraint_negative_well(): + """ + Check that an error is raised if you have a negative + well radius. + """ + with pytest.raises(ValueError, match="negative indices passed"): + _ = DistanceRestraintSettings( + spring_constant=10 * unit.kilojoule_per_mole / unit.nm ** 2, + host_atoms=[-1, 0, 2], + guest_atoms=[0, 1, 2], + ) + + +def test_boresch_restraint_settings_default(): + """ + Basic settings regression test + """ + settings = BoreschRestraintSettings( + K_r=10 * unit.kilojoule_per_mole / unit.nm ** 2, + K_thetaA=10 * unit.kilojoule_per_mole / unit.radians ** 2, + K_thetaB=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_A0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_B0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_C0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + ) + assert isinstance(settings, BoreschRestraintSettings) + + +def test_boresch_restraint_negative_idxs(): + """ + Check that the positive_idxs_list validator is + working as expected. + """ + with pytest.raises(ValueError, match='negative indices'): + _ = BoreschRestraintSettings( + K_r=10 * unit.kilojoule_per_mole / unit.nm ** 2, + K_thetaA=10 * unit.kilojoule_per_mole / unit.radians ** 2, + K_thetaB=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_A0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_B0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + phi_C0=10 * unit.kilojoule_per_mole / unit.radians ** 2, + host_atoms=[-1, 0], + guest_atoms=[0, 1], + )