Skip to content

Commit

Permalink
Steric clash detection (#883)
Browse files Browse the repository at this point in the history
* Code to detect steric clashes and mol plot cleanup

* update tests

* CR comments, fix test
  • Loading branch information
jkausrelay authored Oct 14, 2022
1 parent 7ff2872 commit e6233c9
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 95 deletions.
3 changes: 2 additions & 1 deletion examples/relative_free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from timemachine.constants import DEFAULT_FF
from timemachine.fe import atom_mapping, pdb_writer
from timemachine.fe.rbfe import HostConfig, estimate_relative_free_energy, plot_atom_mapping_grid
from timemachine.fe.rbfe import HostConfig, estimate_relative_free_energy
from timemachine.fe.single_topology import AtomMapMixin
from timemachine.fe.utils import plot_atom_mapping_grid
from timemachine.ff import Forcefield
from timemachine.md import builders
from timemachine.testsystems.relative import get_hif2a_ligand_pair_single_topology
Expand Down
18 changes: 17 additions & 1 deletion tests/test_fe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from rdkit import Chem
from rdkit.Chem import AllChem

from timemachine.fe import utils
from timemachine.constants import DEFAULT_FF
from timemachine.fe import model_utils, utils
from timemachine.fe.model_utils import image_molecule
from timemachine.ff import Forcefield

pytestmark = [pytest.mark.nogpu]

Expand Down Expand Up @@ -115,3 +117,17 @@ def test_experimental_conversions_to_kj():
)

np.testing.assert_allclose(utils.convert_uM_to_kJ_per_mole(0.15), -38.951164)


def test_get_strained_atoms():
    """Check that get_strained_atoms reports nothing for a fresh benzene
    conformer and reports exactly the two atoms that are forced to overlap."""
    ff = Forcefield.load_from_file(DEFAULT_FF)
    np.random.seed(2022)
    mol = Chem.AddHs(Chem.MolFromSmiles("c1ccccc1"))
    # The embedder takes its seed through `randomSeed`; pass it explicitly so
    # the starting conformer is the same on every run.
    conf_id = AllChem.EmbedMolecule(mol, randomSeed=2022)
    assert conf_id != -1, "conformer embedding failed"
    assert model_utils.get_strained_atoms(mol, ff) == []

    # force a clash: place the second-to-last atom 0.01 away (per axis) from the last
    x0 = utils.get_romol_conf(mol)
    x0[-2, :] = x0[-1, :] + 0.01
    utils.set_romol_conf(mol, x0)
    assert model_utils.get_strained_atoms(mol, ff) == [10, 11]
46 changes: 7 additions & 39 deletions tests/test_interpolate_fe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from rdkit import Chem

from timemachine.constants import BOLTZ
from timemachine.fe import pdb_writer, single_topology, utils
from timemachine.fe import atom_mapping, pdb_writer, single_topology, utils
from timemachine.fe.system import simulate_system
from timemachine.fe.utils import get_romol_conf
from timemachine.ff import Forcefield
Expand All @@ -27,40 +27,12 @@ def test_hif2a_free_energy_estimates():
mol_a = all_mols[1]
mol_b = all_mols[4]

core = np.array(
[
[0, 0],
[2, 2],
[1, 1],
[6, 6],
[5, 5],
[4, 4],
[3, 3],
[15, 16],
[16, 17],
[17, 18],
[18, 19],
[19, 20],
[20, 21],
[32, 30],
[26, 25],
[27, 26],
[7, 7],
[8, 8],
[9, 9],
[10, 10],
[29, 11],
[11, 12],
[12, 13],
[14, 15],
[31, 29],
[13, 14],
[23, 24],
[30, 28],
[28, 27],
[21, 22],
]
)
core_smarts = atom_mapping.mcs(mol_a, mol_b).smartsString
query_mol = Chem.MolFromSmarts(core_smarts)
core = atom_mapping.get_core_by_mcs(mol_a, mol_b, query_mol)
svg = utils.plot_atom_mapping_grid(mol_a, mol_b, core_smarts, core)
with open("atom_mapping.svg", "w") as fh:
fh.write(svg)

st = single_topology.SingleTopology(mol_a, mol_b, core, forcefield)

Expand All @@ -77,10 +49,6 @@ def test_hif2a_free_energy_estimates():
kT = BOLTZ * 300.0
beta = 1 / kT

svg = utils.plot_atom_mapping_grid(mol_a, mol_b, core)
with open("atom_mapping.svg", "w") as fh:
fh.write(svg)

for lambda_idx, U_fn in enumerate(U_fns):
# print("lambda", lambda_schedule[lambda_idx], "U", U_fn(x0))
# continue
Expand Down
48 changes: 46 additions & 2 deletions timemachine/fe/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,58 @@
import tempfile
from typing import List, Optional

import jax
import numpy as np
from rdkit import Chem
from simtk.openmm import app

from timemachine.fe.topology import BaseTopology
from timemachine.fe.utils import get_romol_conf
from timemachine.ff import Forcefield

def assert_mol_has_all_hydrogens(mol: Chem.Mol):

def mol_has_all_hydrogens(mol: Chem.Mol) -> bool:
    """
    Return True if `mol` already carries all of its hydrogens explicitly.

    The check adds hydrogens to a copy via Chem.AddHs and compares atom
    counts: if the count is unchanged, nothing was missing. This is a pure
    predicate and never raises on a missing hydrogen; use
    assert_mol_has_all_hydrogens for the raising variant.
    """
    return mol.GetNumAtoms() == Chem.AddHs(mol).GetNumAtoms()


def assert_mol_has_all_hydrogens(mol: Chem.Mol):
    """
    Raise AssertionError("Hydrogens missing for mol") unless
    mol_has_all_hydrogens(mol) is true.
    """
    # Raised explicitly instead of via an `assert` statement so the check is
    # not stripped when Python runs with -O.
    if not mol_has_all_hydrogens(mol):
        raise AssertionError("Hydrogens missing for mol")


def get_vacuum_val_and_grad_fn(mol: Chem.Mol, ff: Forcefield):
    """
    Build a callable that evaluates the molecule's vacuum potential energy
    and the gradient of that energy with respect to the coordinates.

    The returned function takes coordinates `x` and returns the tuple
    (U(x), dU/dx). Only the gradient is jit-compiled; the energy is
    evaluated by calling the potential function directly.
    """
    potential = BaseTopology(mol, ff).setup_end_state().get_U_fn()
    jitted_grad = jax.jit(jax.grad(potential, argnums=0))

    def val_and_grad_fn(x):
        energy = potential(x)
        gradient = jitted_grad(x)
        return energy, gradient

    return val_and_grad_fn


def get_strained_atoms(mol: Chem.Mol, ff: Forcefield, max_force: float = 50000) -> List[int]:
    """
    Return a list of atom indices that are strained based on the max_force.

    The molecule's coordinates (as read by get_romol_conf) are evaluated with
    the vacuum potential from get_vacuum_val_and_grad_fn, and the per-atom
    gradient norm is compared against the threshold.

    Parameters
    ----------
    mol:
        Molecule whose current conformer is checked.
    ff:
        Forcefield used to build the vacuum potential.
    max_force:
        If the magnitude of the force on atom i is greater than max force,
        consider this a clash. Must be a number; None is not supported
        because it cannot be compared against the gradient norms.

    Returns
    -------
    List of int
        Indices of the strained atoms in ascending order; empty if no atom
        exceeds the threshold.
    """
    x0 = get_romol_conf(mol)
    val_and_grad_fn = get_vacuum_val_and_grad_fn(mol, ff)
    _, grads = val_and_grad_fn(x0)
    # One norm per atom (row); the sign of the gradient is irrelevant here.
    norm_grads = np.linalg.norm(grads, axis=1)
    # flatnonzero yields the positions of the True entries; tolist() converts
    # them to plain Python ints.
    return np.flatnonzero(norm_grads > max_force).tolist()


def apply_hmr(masses, bond_list, multiplier=2):
Expand Down
42 changes: 0 additions & 42 deletions timemachine/fe/rbfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import matplotlib.pyplot as plt
import numpy as np
import pymbar
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from scipy.spatial.distance import cdist

from timemachine.constants import BOLTZ, DEFAULT_TEMP
Expand All @@ -26,46 +24,6 @@
from timemachine.md.barostat.utils import get_bond_list, get_group_indices


def plot_atom_mapping_grid(mol_a, mol_b, core_smarts, core, show_idxs=False):
    """
    Draw the core query, mol_a and mol_b side by side as an SVG grid, with
    mapped atoms highlighted in matching colors.

    Parameters
    ----------
    mol_a, mol_b:
        Molecules to depict; copies are drawn, the inputs are not modified.
    core_smarts:
        SMARTS string of the common core. Row i of `core` is taken to map
        query atom i.
    core:
        Array of (mol_a index, mol_b index) pairs.
    show_idxs:
        If True, label every atom in all three panels with its index.
    """
    query_2d = Chem.MolFromSmarts(core_smarts)
    AllChem.Compute2DCoords(query_2d)

    # Depict each molecule so that its core atoms sit on the query's 2D layout.
    aligned = []
    for mol, column in ((mol_a, 0), (mol_b, 1)):
        mol_2d = Chem.Mol(mol)
        query_to_mol = [[int(q_idx), int(m_idx)] for q_idx, m_idx in enumerate(core[:, column])]
        AllChem.GenerateDepictionMatching2DStructure(mol_2d, query_2d, atomMap=query_to_mol)
        aligned.append(mol_2d)
    a_2d, b_2d = aligned

    # One random color per core pair, shared by the three panels.
    palette = np.random.random((len(core), 3))
    colors_q, colors_a, colors_b = {}, {}, {}
    for q_idx, ((a_idx, b_idx), rgb) in enumerate(zip(core, palette)):
        color = tuple(rgb.tolist())
        colors_a[int(a_idx)] = color
        colors_b[int(b_idx)] = color
        colors_q[int(q_idx)] = color

    if show_idxs:
        for drawn in (a_2d, b_2d, query_2d):
            for atom in drawn.GetAtoms():
                atom.SetProp("molAtomMapNumber", str(atom.GetIdx()))

    query_atoms = list(range(query_2d.GetNumAtoms()))
    return Draw.MolsToGridImage(
        [query_2d, a_2d, b_2d],
        molsPerRow=3,
        highlightAtomLists=[query_atoms, core[:, 0].tolist(), core[:, 1].tolist()],
        highlightAtomColors=[colors_q, colors_a, colors_b],
        subImgSize=(300, 300),
        legends=["core", get_mol_name(mol_a), get_mol_name(mol_b)],
        useSVG=True,
    )


def get_batch_U_fns(bps, lamb):
# return a function that takes in coords, boxes, lambda
all_U_fns = []
Expand Down
61 changes: 51 additions & 10 deletions timemachine/fe/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Optional

import numpy as np
import simtk.unit
from numpy.typing import NDArray
Expand Down Expand Up @@ -74,6 +76,31 @@ def draw_mol(mol, highlightAtoms, highlightColors):
# display(SVG(svg))


def draw_mol_idx(mol, highlight: Optional[List[int]] = None, scale_factor=None):
    """
    Draw mol in 2D with every atom labeled by its index.

    Parameters
    ----------
    mol:
        Molecule to draw; a copy is depicted, the input is not modified.
    highlight: List of int or None
        If specified, highlight the given atom idxs.
    scale_factor:
        If truthy, the depiction is normalized with this value passed as
        scaleFactor to AllChem.NormalizeDepiction.

    Returns
    -------
    The single-panel grid image produced by Draw.MolsToGridImage with
    useSVG=True.
    """
    depiction = Chem.Mol(mol)
    AllChem.Compute2DCoords(depiction)
    if scale_factor:
        AllChem.NormalizeDepiction(depiction, scaleFactor=scale_factor)

    for atom in depiction.GetAtoms():
        atom.SetProp("molAtomMapNumber", str(atom.GetIdx()))

    highlight_lists = None if highlight is None else [highlight]
    return Draw.MolsToGridImage(
        [depiction],
        molsPerRow=1,
        highlightAtomLists=highlight_lists,
        subImgSize=(500, 500),
        legends=[get_mol_name(depiction)],
        useSVG=True,
    )


def get_atom_map_colors(core, seed=2022):
rng = np.random.default_rng(seed)

Expand All @@ -96,28 +123,42 @@ def plot_atom_mapping(mol_a, mol_b, core, seed=2022):
draw_mol(mol_b, core[:, 1].tolist(), atom_colors_b)


def plot_atom_mapping_grid(mol_a, mol_b, core, show_idxs=False, seed=2022):
def plot_atom_mapping_grid(mol_a, mol_b, core_smarts, core, show_idxs=False, seed=2022):
    """
    Draw the core query, mol_a and mol_b side by side as an SVG grid, with
    mapped atoms highlighted in matching colors.

    Parameters
    ----------
    mol_a, mol_b:
        Molecules to depict; copies are drawn, the inputs are not modified.
    core_smarts:
        SMARTS string of the common core. Row i of `core` is taken to map
        query atom i.
    core:
        Array of (mol_a index, mol_b index) pairs.
    show_idxs:
        If True, label every atom in all three panels with its index.
    seed:
        Seed for the highlight colors, so repeated calls give the same
        picture. Pass None for fresh random colors on each call.
    """
    mol_a_2d = Chem.Mol(mol_a)
    mol_b_2d = Chem.Mol(mol_b)
    mol_q_2d = Chem.MolFromSmarts(core_smarts)

    AllChem.Compute2DCoords(mol_q_2d)

    # Query atom i corresponds to row i of core; align both molecules to the
    # query's 2D layout so the shared core is drawn identically in each panel.
    q_to_a = [[int(q_idx), int(a_idx)] for q_idx, a_idx in enumerate(core[:, 0])]
    q_to_b = [[int(q_idx), int(b_idx)] for q_idx, b_idx in enumerate(core[:, 1])]

    AllChem.GenerateDepictionMatching2DStructure(mol_a_2d, mol_q_2d, atomMap=q_to_a)
    AllChem.GenerateDepictionMatching2DStructure(mol_b_2d, mol_q_2d, atomMap=q_to_b)

    # A local generator keeps the colors reproducible without touching
    # NumPy's global random state.
    rng = np.random.default_rng(seed)
    atom_colors_a = {}
    atom_colors_b = {}
    atom_colors_q = {}
    for c_idx, ((a_idx, b_idx), rgb) in enumerate(zip(core, rng.random((len(core), 3)))):
        color = tuple(rgb.tolist())
        atom_colors_a[int(a_idx)] = color
        atom_colors_b[int(b_idx)] = color
        atom_colors_q[int(c_idx)] = color

    if show_idxs:
        for mol_2d in (mol_a_2d, mol_b_2d, mol_q_2d):
            for atom in mol_2d.GetAtoms():
                atom.SetProp("molAtomMapNumber", str(atom.GetIdx()))

    return Draw.MolsToGridImage(
        [mol_q_2d, mol_a_2d, mol_b_2d],
        molsPerRow=3,
        highlightAtomLists=[list(range(mol_q_2d.GetNumAtoms())), core[:, 0].tolist(), core[:, 1].tolist()],
        highlightAtomColors=[atom_colors_q, atom_colors_a, atom_colors_b],
        subImgSize=(300, 300),
        legends=["core", get_mol_name(mol_a), get_mol_name(mol_b)],
        useSVG=True,
    )

Expand Down

0 comments on commit e6233c9

Please sign in to comment.