Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Steric clash detection #883

Merged
merged 4 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/relative_free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from timemachine.constants import DEFAULT_FF
from timemachine.fe import atom_mapping, pdb_writer
from timemachine.fe.rbfe import HostConfig, estimate_relative_free_energy, plot_atom_mapping_grid
from timemachine.fe.rbfe import HostConfig, estimate_relative_free_energy
from timemachine.fe.single_topology import AtomMapMixin
from timemachine.fe.utils import plot_atom_mapping_grid
from timemachine.ff import Forcefield
from timemachine.md import builders
from timemachine.testsystems.relative import get_hif2a_ligand_pair_single_topology
Expand Down
18 changes: 17 additions & 1 deletion tests/test_fe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from rdkit import Chem
from rdkit.Chem import AllChem

from timemachine.fe import utils
from timemachine.constants import DEFAULT_FF
from timemachine.fe import model_utils, utils
from timemachine.fe.model_utils import image_molecule
from timemachine.ff import Forcefield

pytestmark = [pytest.mark.nogpu]

Expand Down Expand Up @@ -115,3 +117,17 @@ def test_experimental_conversions_to_kj():
)

np.testing.assert_allclose(utils.convert_uM_to_kJ_per_mole(0.15), -38.951164)


def test_get_clashing_atoms():
    """An embedded benzene reports no clashes; pushing atom 0 onto atom 1 reports some."""
    forcefield = Forcefield.load_from_file(DEFAULT_FF)

    # NOTE(review): confirm that seeding numpy is what makes the embedded
    # conformer reproducible; EmbedMolecule is called without its own seed.
    np.random.seed(2022)
    benzene = Chem.AddHs(Chem.MolFromSmiles("c1ccccc1"))
    AllChem.EmbedMolecule(benzene)

    clashes = model_utils.get_clashing_atoms(benzene, forcefield)
    assert clashes == []

    # Force a clash by placing atom 0 almost on top of atom 1.
    coords = utils.get_romol_conf(benzene)
    coords[0, :] = coords[1, :] + 0.05
    utils.set_romol_conf(benzene, coords)

    clashes = model_utils.get_clashing_atoms(benzene, forcefield)
    assert clashes == [0, 5, 6]
46 changes: 7 additions & 39 deletions tests/test_interpolate_fe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from rdkit import Chem

from timemachine.constants import BOLTZ
from timemachine.fe import pdb_writer, single_topology, utils
from timemachine.fe import atom_mapping, pdb_writer, single_topology, utils
from timemachine.fe.system import simulate_system
from timemachine.fe.utils import get_romol_conf
from timemachine.ff import Forcefield
Expand All @@ -27,40 +27,12 @@ def test_hif2a_free_energy_estimates():
mol_a = all_mols[1]
mol_b = all_mols[4]

core = np.array(
[
[0, 0],
[2, 2],
[1, 1],
[6, 6],
[5, 5],
[4, 4],
[3, 3],
[15, 16],
[16, 17],
[17, 18],
[18, 19],
[19, 20],
[20, 21],
[32, 30],
[26, 25],
[27, 26],
[7, 7],
[8, 8],
[9, 9],
[10, 10],
[29, 11],
[11, 12],
[12, 13],
[14, 15],
[31, 29],
[13, 14],
[23, 24],
[30, 28],
[28, 27],
[21, 22],
]
)
core_smarts = atom_mapping.mcs(mol_a, mol_b).smartsString
query_mol = Chem.MolFromSmarts(core_smarts)
core = atom_mapping.get_core_by_mcs(mol_a, mol_b, query_mol)
svg = utils.plot_atom_mapping_grid(mol_a, mol_b, core_smarts, core)
with open("atom_mapping.svg", "w") as fh:
fh.write(svg)

st = single_topology.SingleTopology(mol_a, mol_b, core, forcefield)

Expand All @@ -77,10 +49,6 @@ def test_hif2a_free_energy_estimates():
kT = BOLTZ * 300.0
beta = 1 / kT

svg = utils.plot_atom_mapping_grid(mol_a, mol_b, core)
with open("atom_mapping.svg", "w") as fh:
fh.write(svg)

for lambda_idx, U_fn in enumerate(U_fns):
# print("lambda", lambda_schedule[lambda_idx], "U", U_fn(x0))
# continue
Expand Down
48 changes: 46 additions & 2 deletions timemachine/fe/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,58 @@
import tempfile
from typing import List, Optional

import jax
import numpy as np
from rdkit import Chem
from simtk.openmm import app

from timemachine.fe.topology import BaseTopology
from timemachine.fe.utils import get_romol_conf
from timemachine.ff import Forcefield

def mol_has_all_hydrogens(mol: Chem.Mol) -> bool:
    """
    Return True if the molecule already carries all of its hydrogens.

    The check compares the atom count of `mol` with the atom count of the
    molecule returned by Chem.AddHs(mol); equal counts mean AddHs had
    nothing to add.
    """
    return mol.GetNumAtoms() == Chem.AddHs(mol).GetNumAtoms()


def assert_mol_has_all_hydrogens(mol: Chem.Mol):
    """
    Raise AssertionError if `mol` is missing hydrogens.

    Raises
    ------
    AssertionError
        If mol_has_all_hydrogens(mol) is False.
    """
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if not mol_has_all_hydrogens(mol):
        raise AssertionError("Hydrogens missing for mol")


def get_vacuum_val_and_grad_fn(mol: Chem.Mol, ff: Forcefield):
    """
    Return a function which returns the potential energy and its gradient
    at the given coordinates for the molecule in vacuum.

    Parameters
    ----------
    mol:
        Molecule used to build the vacuum system.
    ff:
        Forcefield used to parameterize the molecule.

    Returns
    -------
    callable
        f(x) -> (U(x), dU/dx), where x are the coordinates. The second
        value is the gradient of U with respect to x as computed by jax
        (it is not negated), so callers that need forces must flip the sign.
    """
    top = BaseTopology(mol, ff)
    vacuum_system = top.setup_end_state()
    U = vacuum_system.get_U_fn()

    # value_and_grad evaluates the energy and its gradient together, and
    # jitting the combined function compiles both instead of only the gradient.
    return jax.jit(jax.value_and_grad(U, argnums=0))


def get_clashing_atoms(mol: Chem.Mol, ff: Forcefield, max_force: float = 50000) -> List[int]:
    """
    Return a list of atom indices that are clashing based on the max_force.

    The molecule's current conformer is evaluated in vacuum and the norm of
    the energy gradient on each atom is compared against `max_force`.

    Parameters
    ----------
    mol:
        Molecule whose conformer is checked.
    ff:
        Forcefield used to build the vacuum potential.
    max_force:
        If the magnitude of the force on atom i is greater than max force,
        consider this a clash.

    Returns
    -------
    List[int]
        Indices of the clashing atoms in ascending order; empty if none.
    """
    x0 = get_romol_conf(mol)
    val_and_grad_fn = get_vacuum_val_and_grad_fn(mol, ff)
    _, grads = val_and_grad_fn(x0)

    # Only magnitudes are compared, so the sign convention (gradient vs.
    # force) does not matter here.
    magnitudes = np.linalg.norm(grads, axis=1)
    return np.flatnonzero(magnitudes > max_force).tolist()


def apply_hmr(masses, bond_list, multiplier=2):
Expand Down
42 changes: 0 additions & 42 deletions timemachine/fe/rbfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import matplotlib.pyplot as plt
import numpy as np
import pymbar
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from scipy.spatial.distance import cdist

from timemachine.constants import BOLTZ, DEFAULT_TEMP
Expand All @@ -26,46 +24,6 @@
from timemachine.md.barostat.utils import get_bond_list, get_group_indices


def plot_atom_mapping_grid(mol_a, mol_b, core_smarts, core, show_idxs=False):
    """
    Draw the core query, mol_a and mol_b in a three-cell SVG grid with the
    mapped atoms highlighted in matching colors.

    Parameters
    ----------
    mol_a, mol_b:
        Molecules to depict; copies made with Chem.Mol(...) are drawn.
    core_smarts:
        SMARTS string parsed into the query molecule shown in the first cell.
    core:
        2D integer array; row i holds the (mol_a, mol_b) atom indices that
        are aligned to query atom i.
    show_idxs:
        If True, label every atom in all three depictions with its index.

    Returns
    -------
    The result of Draw.MolsToGridImage called with useSVG=True.
    """
    depiction_a = Chem.Mol(mol_a)
    depiction_b = Chem.Mol(mol_b)
    depiction_q = Chem.MolFromSmarts(core_smarts)

    AllChem.Compute2DCoords(depiction_q)

    # Lay out both molecules on top of the query so the cores line up.
    query_to_a = [[q_idx, int(a_idx)] for q_idx, a_idx in enumerate(core[:, 0])]
    query_to_b = [[q_idx, int(b_idx)] for q_idx, b_idx in enumerate(core[:, 1])]
    AllChem.GenerateDepictionMatching2DStructure(depiction_a, depiction_q, atomMap=query_to_a)
    AllChem.GenerateDepictionMatching2DStructure(depiction_b, depiction_q, atomMap=query_to_b)

    # One random RGB triple per core row, shared by the three panels.
    palette = [tuple(rgb.tolist()) for rgb in np.random.random((len(core), 3))]
    colors_a = {int(a_idx): color for (a_idx, _), color in zip(core, palette)}
    colors_b = {int(b_idx): color for (_, b_idx), color in zip(core, palette)}
    colors_q = {q_idx: color for q_idx, color in enumerate(palette)}

    if show_idxs:
        for depiction in (depiction_a, depiction_b, depiction_q):
            for atom in depiction.GetAtoms():
                atom.SetProp("molAtomMapNumber", str(atom.GetIdx()))

    query_atoms = list(range(depiction_q.GetNumAtoms()))
    return Draw.MolsToGridImage(
        [depiction_q, depiction_a, depiction_b],
        molsPerRow=3,
        highlightAtomLists=[query_atoms, core[:, 0].tolist(), core[:, 1].tolist()],
        highlightAtomColors=[colors_q, colors_a, colors_b],
        subImgSize=(300, 300),
        legends=["core", get_mol_name(mol_a), get_mol_name(mol_b)],
        useSVG=True,
    )


def get_batch_U_fns(bps, lamb):
# return a function that takes in coords, boxes, lambda
all_U_fns = []
Expand Down
61 changes: 51 additions & 10 deletions timemachine/fe/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Optional

import numpy as np
import simtk.unit
from numpy.typing import NDArray
Expand Down Expand Up @@ -74,6 +76,31 @@ def draw_mol(mol, highlightAtoms, highlightColors):
# display(SVG(svg))


def draw_mol_idx(mol, highlight: Optional[List[int]] = None, scale_factor=None):
    """
    Render a one-cell SVG grid of mol with every atom labeled by its index.

    Parameters
    ----------
    mol:
        Molecule to draw; a copy made with Chem.Mol(mol) is depicted.
    highlight: List of int or None
        If specified, highlight the given atom idxs.
    scale_factor:
        If truthy, passed to AllChem.NormalizeDepiction as scaleFactor
        after the 2D coordinates are computed.

    Returns
    -------
    The result of Draw.MolsToGridImage called with useSVG=True.
    """
    depiction = Chem.Mol(mol)
    AllChem.Compute2DCoords(depiction)
    if scale_factor:
        AllChem.NormalizeDepiction(depiction, scaleFactor=scale_factor)

    for atom in depiction.GetAtoms():
        atom.SetProp("molAtomMapNumber", str(atom.GetIdx()))

    highlight_lists = None if highlight is None else [highlight]
    return Draw.MolsToGridImage(
        [depiction],
        molsPerRow=1,
        highlightAtomLists=highlight_lists,
        subImgSize=(500, 500),
        legends=[get_mol_name(depiction)],
        useSVG=True,
    )


def get_atom_map_colors(core, seed=2022):
rng = np.random.default_rng(seed)

Expand All @@ -96,28 +123,42 @@ def plot_atom_mapping(mol_a, mol_b, core, seed=2022):
draw_mol(mol_b, core[:, 1].tolist(), atom_colors_b)


def plot_atom_mapping_grid(mol_a, mol_b, core_smarts, core, show_idxs=False, seed=2022):
    """
    Draw the core query, mol_a and mol_b in a three-cell SVG grid with the
    mapped atoms highlighted in matching colors.

    Parameters
    ----------
    mol_a, mol_b:
        Molecules to depict; copies made with Chem.Mol(...) are drawn.
    core_smarts:
        SMARTS string parsed into the query molecule shown in the first cell.
    core:
        2D integer array; row i holds the (mol_a, mol_b) atom indices that
        are aligned to query atom i.
    show_idxs:
        If True, label every atom in all three depictions with its index.
    seed:
        Seed for the random highlight colors, so repeated calls give the
        same picture.

    Returns
    -------
    The result of Draw.MolsToGridImage called with useSVG=True.
    """
    mol_a_2d = Chem.Mol(mol_a)
    mol_b_2d = Chem.Mol(mol_b)
    mol_q_2d = Chem.MolFromSmarts(core_smarts)

    AllChem.Compute2DCoords(mol_q_2d)

    # Lay out both molecules on top of the query so the cores line up.
    q_to_a = [[q_idx, int(a_idx)] for q_idx, a_idx in enumerate(core[:, 0])]
    q_to_b = [[q_idx, int(b_idx)] for q_idx, b_idx in enumerate(core[:, 1])]

    AllChem.GenerateDepictionMatching2DStructure(mol_a_2d, mol_q_2d, atomMap=q_to_a)
    AllChem.GenerateDepictionMatching2DStructure(mol_b_2d, mol_q_2d, atomMap=q_to_b)

    # A local generator keeps the colors reproducible without touching the
    # global numpy random state.
    rng = np.random.default_rng(seed)
    atom_colors_a = {}
    atom_colors_b = {}
    atom_colors_q = {}
    for c_idx, ((a_idx, b_idx), rgb) in enumerate(zip(core, rng.random((len(core), 3)))):
        color = tuple(rgb.tolist())
        atom_colors_a[int(a_idx)] = color
        atom_colors_b[int(b_idx)] = color
        atom_colors_q[int(c_idx)] = color

    if show_idxs:
        for mol_2d in (mol_a_2d, mol_b_2d, mol_q_2d):
            for atom in mol_2d.GetAtoms():
                atom.SetProp("molAtomMapNumber", str(atom.GetIdx()))

    return Draw.MolsToGridImage(
        [mol_q_2d, mol_a_2d, mol_b_2d],
        molsPerRow=3,
        highlightAtomLists=[list(range(mol_q_2d.GetNumAtoms())), core[:, 0].tolist(), core[:, 1].tolist()],
        highlightAtomColors=[atom_colors_q, atom_colors_a, atom_colors_b],
        subImgSize=(300, 300),
        legends=["core", get_mol_name(mol_a), get_mol_name(mol_b)],
        useSVG=True,
    )

Expand Down