Skip to content

Commit

Permalink
linted all files (#332)
Browse files Browse the repository at this point in the history
* linted all files

* remove var that doesn't exist

* ignore E501 for test_template_generators.py
  • Loading branch information
mikemhenry authored May 3, 2024
1 parent 212d00e commit a0b79a6
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 40 deletions.
8 changes: 4 additions & 4 deletions openmmforcefields/generators/template_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,11 +1092,11 @@ def as_attrib(quantity):
atom_types = etree.SubElement(root, "AtomTypes")
for atom_index, atom in enumerate(molecule.atoms):
# Create a new atom type for each atom in the molecule
paricle_indices = [atom_index]
element_symbol = atom.symbol
atom_type = etree.SubElement(atom_types, "Type", name=atom.typename,
element=element_symbol, mass=as_attrib(atom.mass))
atom_type.set('class', atom.typename) # 'class' is a reserved Python keyword, so use alternative API
atom_type = etree.SubElement(
atom_types, "Type", name=atom.typename, element=element_symbol, mass=as_attrib(atom.mass)
)
atom_type.set("class", atom.typename) # 'class' is a reserved Python keyword, so use alternative API

supported_forces = {
"NonbondedForce",
Expand Down
94 changes: 58 additions & 36 deletions openmmforcefields/tests/test_template_generators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# ruff: noqa: E501
import copy
import logging
import os
Expand Down Expand Up @@ -67,7 +68,6 @@ def setUp(self):
molecule = Molecule.from_smiles("C=O")
molecule.generate_conformers(n_conformers=1)


molecule.conformers[0][0, 0] += unit.Quantity(0.1, unit.angstroms)

molecules.insert(0, molecule)
Expand Down Expand Up @@ -373,10 +373,8 @@ def test_add_molecules(self):
try:
system = forcefield.createSystem(openmm_topology, nonbondedMethod=NoCutoff)
except Exception as e:

print(forcefield._atomTypes.keys())


PDBFile.writeFile(
openmm_topology,
molecule.conformers[0].to_openmm(),
Expand Down Expand Up @@ -862,19 +860,29 @@ def compute_energy(system, positions):
system = copy.deepcopy(system)
for index, force in enumerate(system.getForces()):
force.setForceGroup(index)
platform = openmm.Platform.getPlatformByName('Reference')
platform = openmm.Platform.getPlatformByName("Reference")
integrator = openmm.VerletIntegrator(0.001)
context = openmm.Context(system, integrator, platform)
context.setPositions(positions)
openmm_energy = {
'total' : context.getState(getEnergy=True).getPotentialEnergy(),
'components' : { system.getForce(index).__class__.__name__ : context.getState(getEnergy=True, groups=(1 << index)).getPotentialEnergy() for index in range(system.getNumForces()) },
}
"total": context.getState(getEnergy=True).getPotentialEnergy(),
"components": {
system.getForce(index).__class__.__name__: context.getState(
getEnergy=True, groups=(1 << index)
).getPotentialEnergy()
for index in range(system.getNumForces())
},
}

openmm_forces = {
'total' : context.getState(getForces=True).getForces(asNumpy=True),
'components' : { system.getForce(index).__class__.__name__ : context.getState(getForces=True, groups=(1 << index)).getForces(asNumpy=True) for index in range(system.getNumForces()) },
}
"total": context.getState(getForces=True).getForces(asNumpy=True),
"components": {
system.getForce(index).__class__.__name__: context.getState(
getForces=True, groups=(1 << index)
).getForces(asNumpy=True)
for index in range(system.getNumForces())
},
}

del context, integrator
return openmm_energy, openmm_forces
Expand Down Expand Up @@ -909,67 +917,81 @@ def compare_energies(cls, molecule, template_generated_system, reference_system)
from openmm import unit

def write_xml(filename, system):
with open(filename, 'w') as outfile:
print(f'Writing {filename}...')
with open(filename, "w") as outfile:
print(f"Writing {filename}...")
outfile.write(openmm.XmlSerializer.serialize(system))
# DEBUG
print(openmm.XmlSerializer.serialize(system))

# Make sure both systems contain the same energy components
reference_components = set(reference_energy['components'])
template_components = set(template_energy['components'])
reference_components = set(reference_energy["components"])
template_components = set(template_energy["components"])
if len(reference_components.difference(template_components)) > 0:
raise Exception(f'Reference system contains components {reference_components.difference(template_components)} that do not appear in template-generated system.')
raise Exception(
f"Reference system contains components {reference_components.difference(template_components)} that do not appear in template-generated system."
)
if len(template_components.difference(reference_components)) > 0:
raise Exception(f'Template-generated system contains components {template_components.difference(reference_components)} that do not appear in reference system.')
raise Exception(
f"Template-generated system contains components {template_components.difference(reference_components)} that do not appear in reference system."
)
components = reference_components

# Compare energies
ENERGY_DEVIATION_TOLERANCE = 1.0e-2 * unit.kilocalories_per_mole
delta = (template_energy['total'] - reference_energy['total'])
delta = template_energy["total"] - reference_energy["total"]
if abs(delta) > ENERGY_DEVIATION_TOLERANCE:
# Show breakdown by components
print('Energy components:')
print("Energy components:")
print(f"{'component':24} {'Template (kcal/mol)':>20} {'Reference (kcal/mol)':>20}")
for key in components:
reference_component_energy = reference_energy['components'][key]
template_component_energy = template_energy['components'][key]
print(f'{key:24} {(template_component_energy/unit.kilocalories_per_mole):20.3f} {(reference_component_energy/unit.kilocalories_per_mole):20.3f} kcal/mol')
print(f'{"TOTAL":24} {(template_energy["total"]/unit.kilocalories_per_mole):20.3f} {(reference_energy["total"]/unit.kilocalories_per_mole):20.3f} kcal/mol')
write_xml('reference_system.xml', reference_system)
write_xml('template_system.xml', template_system) # What's this? This variable does not exist
raise Exception(f'Energy deviation for {molecule.to_smiles()} ({delta/unit.kilocalories_per_mole} kcal/mol) exceeds threshold ({ENERGY_DEVIATION_TOLERANCE})')
reference_component_energy = reference_energy["components"][key]
template_component_energy = template_energy["components"][key]
print(
f"{key:24} {(template_component_energy/unit.kilocalories_per_mole):20.3f} {(reference_component_energy/unit.kilocalories_per_mole):20.3f} kcal/mol"
)
print(
f'{"TOTAL":24} {(template_energy["total"]/unit.kilocalories_per_mole):20.3f} {(reference_energy["total"]/unit.kilocalories_per_mole):20.3f} kcal/mol'
)
write_xml("reference_system.xml", reference_system)
raise Exception(
f"Energy deviation for {molecule.to_smiles()} ({delta/unit.kilocalories_per_mole} kcal/mol) exceeds threshold ({ENERGY_DEVIATION_TOLERANCE})"
)

# Compare forces
def norm(x):
N = x.shape[0]
return np.sqrt((1.0/N) * (x**2).sum())
return np.sqrt((1.0 / N) * (x**2).sum())

def relative_deviation(x, y):
FORCE_UNIT = unit.kilocalories_per_mole / unit.angstroms
if hasattr(x, 'value_in_unit'):
if hasattr(x, "value_in_unit"):
x = x / FORCE_UNIT
if hasattr(y, 'value_in_unit'):
if hasattr(y, "value_in_unit"):
y = y / FORCE_UNIT

if norm(y) > 0:
return norm(x-y) / np.sqrt(norm(x)**2 + norm(y)**2)
return norm(x - y) / np.sqrt(norm(x) ** 2 + norm(y) ** 2)
else:
return 0

RELATIVE_FORCE_DEVIATION_TOLERANCE = 1.0e-5
relative_force_deviation = relative_deviation(template_forces['total'], reference_forces['total'])
relative_force_deviation = relative_deviation(template_forces["total"], reference_forces["total"])
if relative_force_deviation > RELATIVE_FORCE_DEVIATION_TOLERANCE:
# Show breakdown by components
print('Force components:')
print("Force components:")
print(f"{'component':24} {'relative deviation':>24}")
for key in components:
print(f"{key:24} {relative_deviation(template_forces['components'][key], reference_forces['components'][key]):24.10f}")
print(
f"{key:24} {relative_deviation(template_forces['components'][key], reference_forces['components'][key]):24.10f}"
)
print(f'{"TOTAL":24} {relative_force_deviation:24.10f}')
write_xml('system-smirnoff.xml', reference_system)
write_xml('openmm-smirnoff.xml', template_generated_system)
raise Exception(f'Relative force deviation for {molecule.to_smiles()} ({relative_force_deviation}) exceeds threshold ({RELATIVE_FORCE_DEVIATION_TOLERANCE})')
write_xml("system-smirnoff.xml", reference_system)
write_xml("openmm-smirnoff.xml", template_generated_system)
raise Exception(
f"Relative force deviation for {molecule.to_smiles()} ({relative_force_deviation}) exceeds threshold ({RELATIVE_FORCE_DEVIATION_TOLERANCE})"
)



class TestSMIRNOFFTemplateGenerator(TemplateGeneratorBaseCase):
TEMPLATE_GENERATOR = SMIRNOFFTemplateGenerator

Expand Down

0 comments on commit a0b79a6

Please sign in to comment.