forked from ehoogeboom/e3_diffusion_for_molecules
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpsi4_chain.py
162 lines (131 loc) · 5.34 KB
/
psi4_chain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import psi4
import glob
import sys
import copy
import numpy as np
import ase
from ase.io import xyz
from ase.optimize import BFGS
from ase import units
from ase.calculators.psi4 import Psi4
from ase.calculators.calculator import CalculationFailed
from xtb.ase.calculator import XTB
from ase.build import molecule
from psi4.driver.p4util.exceptions import OptimizationConvergenceError, SCFConvergenceError
from qcelemental.exceptions import ValidationError
def xyz_to_mol(xyz_fn):
    """Read the file at ``xyz_fn`` with ``ase.io.read`` in "xyz" format.

    Returns whatever ``ase.io.read`` returns for that file.
    """
    molecule_from_file = ase.io.read(xyz_fn, format="xyz")
    return molecule_from_file
def get_xtb_calc():
    """Create a fresh ``XTB`` calculator configured with the GFN2-xTB method."""
    xtb_calculator = XTB(method="GFN2-xTB")
    return xtb_calculator
def get_psi4_calc(atoms, basis="6-31G_2df_p_", num_threads=1):
    """Build a ``Psi4`` calculator for ``atoms`` with the QM9 functional/basis.

    Args:
        atoms: Structure handed to the ``Psi4`` calculator constructor.
        basis: Basis-set name passed through to ``Psi4``.
        num_threads: Thread count passed through to ``Psi4``.

    Returns:
        The constructed ``Psi4`` calculator, or ``None`` when construction
        raises ``ValidationError``.
    """
    # Alternative settings that were considered:
    #   faster/less accurate:
    #     Psi4(atoms=atoms, method="pbe", basis="6-31g", memory="16GB", num_threads=8)
    #   cc-pVDZ is probably more common than 6-31G(2df,p):
    #     Psi4(atoms=atoms, method="B3LYP", basis="cc-pVDZ", memory="16GB", num_threads=8)

    # qm9 functional/basis
    try:
        return Psi4(atoms=atoms, method="B3LYP", basis=basis,
                    memory="16GB", num_threads=num_threads)
    except ValidationError:
        # The failure is reported to callers only through the None return
        # value; they are expected to check for it. The warning prints and
        # the breakpoint() that used to follow this return were unreachable
        # and have been removed.
        return None
def get_ef(atoms, method="psi4", basis="6-31G_2df_p_", num_threads=1):
    """Compute the potential energy and forces of ``atoms``.

    A calculator is attached to ``atoms`` (``atoms.calc`` is overwritten).

    Args:
        atoms: Structure to evaluate.
        method: Which calculator to use, ``"psi4"`` or ``"xtb"``.
        basis: Basis-set name forwarded to ``get_psi4_calc`` (psi4 only).
        num_threads: Thread count forwarded to ``get_psi4_calc`` (psi4 only).

    Returns:
        A tuple ``(energy, forces)``:

        * normally, the values from ``atoms.get_potential_energy()`` and
          ``atoms.get_forces()``;
        * ``(0, zeros)`` when no calculator could be built;
        * ``(nan, array of nan)`` when the calculation itself fails with
          ``CalculationFailed`` or ``SCFConvergenceError``.

        The placeholder arrays have the shape of ``atoms.get_positions()``.

    Raises:
        ValueError: If ``method`` is neither ``"psi4"`` nor ``"xtb"``.
    """
    # Reject unknown methods explicitly; previously this fell through and
    # crashed with an UnboundLocalError on ``calc``.
    if method not in ("psi4", "xtb"):
        raise ValueError(
            "unknown method {!r}; expected 'psi4' or 'xtb'".format(method))

    if method == "psi4":
        calc = get_psi4_calc(atoms, basis=basis, num_threads=num_threads)
    else:
        calc = get_xtb_calc()

    if calc is None:
        return 0, np.zeros_like(atoms.get_positions())

    atoms.calc = calc
    try:
        return atoms.get_potential_energy(), atoms.get_forces()
    except (CalculationFailed, SCFConvergenceError):
        print("WARNING: Calculation Failed on atoms:")
        # "-" asks ase.io.write to dump the offending structure to stdout.
        ase.io.write("-", atoms, format="xyz")
        return np.nan, np.nan * np.ones_like(atoms.get_positions())
def relax(atoms, method="psi4", basis="6-31G_2df_p_", fmax=0.03, num_threads=1):
    """Relax a structure with a BFGS optimizer.

    The optimization runs on a new ``ase.Atoms`` object built from ``atoms``,
    and that new object is what gets returned.

    Args:
        atoms: Starting structure.
        method: Which calculator to use, ``"psi4"`` or ``"xtb"``.
        basis: Basis-set name forwarded to ``get_psi4_calc`` (psi4 only).
        fmax: Force threshold passed to ``BFGS.run``.
        num_threads: Thread count forwarded to ``get_psi4_calc`` (psi4 only).

    Returns:
        A tuple ``(ok, atoms, n_steps)``. ``ok`` is ``False`` (with the
        unrelaxed structure and ``0`` steps) when no calculator could be
        built; otherwise ``True`` with the optimized structure and the
        optimizer's step count.

    Raises:
        ValueError: If ``method`` is neither ``"psi4"`` nor ``"xtb"``.
    """
    # Reject unknown methods before doing any work; previously this fell
    # through and crashed with an UnboundLocalError on ``calc``.
    if method not in ("psi4", "xtb"):
        raise ValueError(
            "unknown method {!r}; expected 'psi4' or 'xtb'".format(method))

    atoms = ase.Atoms(atoms)
    if method == "psi4":
        calc = get_psi4_calc(atoms, basis=basis, num_threads=num_threads)
    else:
        calc = get_xtb_calc()

    if calc is None:
        return False, atoms, 0

    atoms.calc = calc
    opt = BFGS(atoms)
    opt.run(fmax=fmax)
    return True, atoms, opt.get_number_of_steps()
def chain_fn(chain_id, frame_id):
    """Return the path of frame ``frame_id`` in chain ``chain_id``.

    The frame number is zero-padded to at least three digits.
    """
    return "outputs/edm_qm9/eval/chain_{}/chain_{:0>3d}.txt".format(chain_id, frame_id)
def chain_summary_fn(chain_id):
    """Return the path of the summary file for chain ``chain_id``."""
    summary_pattern = "outputs/edm_qm9/eval/chain_{}/chain_summary.npy"
    return summary_pattern.format(chain_id)
def process_chain(chain_id, last_frame_id=999):
    """Compare every frame of a chain against the relaxed final frame.

    The last frame is relaxed to obtain a reference ("ground state")
    structure and energy. The chain is then walked backwards from the frame
    before the last one. For each frame this records:

    * the energy gap to the reference, divided by ``units.kcal / units.mol``;
    * the mean (over atoms) cosine similarity between the force on each atom
      and that atom's displacement towards the following frame.

    The walk stops at the first frame whose atomic numbers differ from those
    of the final frame. Results are also written to
    ``chain_summary_fn(chain_id)``.

    Args:
        chain_id: Identifier of the chain, used to build the file paths.
        last_frame_id: Index of the final frame of the chain.

    Returns:
        A tuple ``(energy_suboptimalities, avg_cos_similarities)`` of lists,
        ordered from the latest processed frame to the earliest.
    """
    mol = xyz_to_mol(chain_fn(chain_id, last_frame_id))
    final_atomic_numbers = mol.get_atomic_numbers()

    # compute ground state from the last frame in the chain.
    # relax() returns (ok, atoms, n_steps); only the structure is needed.
    # If no calculator could be built, this is the unrelaxed structure.
    _, molgs, _ = relax(mol)
    egs, fgs = get_ef(molgs)

    avg_cos_similarities = []
    energy_suboptimalities = []
    all_positions = [molgs.get_positions()]
    all_forces = [fgs]

    # we're looping backwards over frames, so keep track of the "next" mol
    # rather than the "previous" one
    next_mol = ase.Atoms(mol)
    for frame_id in range(last_frame_id - 1, -1, -1):
        mol = xyz_to_mol(chain_fn(chain_id, frame_id))

        # only look at frames that have the same atomic numbers as the final
        # frame; array_equal also copes with a differing number of atoms
        if not np.array_equal(mol.get_atomic_numbers(), final_atomic_numbers):
            break

        e, forces = get_ef(mol)

        # compare energy to ground-state energy
        energy_suboptimalities.append((e - egs) / (units.kcal / units.mol))

        # compare forces on current frame to the displacement btwn current & next
        displacement = next_mol.get_positions() - mol.get_positions()
        f_norm = np.linalg.norm(forces, axis=1)
        displacement_norm = np.linalg.norm(displacement, axis=1)
        f_dot_displacement = (forces * displacement).sum(axis=1)
        # NOTE: an atom with zero force or zero displacement gives a 0/0 here,
        # which makes that frame's mean nan.
        cos_theta = f_dot_displacement / (f_norm * displacement_norm)
        avg_cos_similarities.append(cos_theta.mean())

        all_positions.append(mol.get_positions())
        all_forces.append(forces)
        next_mol = ase.Atoms(mol)

    # np.savez writes an .npz-style archive into the opened file, even though
    # the path ends in ".npy".
    with open(chain_summary_fn(chain_id), "wb") as summary_file:
        np.savez(summary_file,
                 energy_suboptimalities=energy_suboptimalities,
                 avg_cos_similarities=avg_cos_similarities,
                 all_positions=np.array(all_positions),
                 all_forces=np.array(all_forces),
                 gs_positions=molgs.get_positions())
    return energy_suboptimalities, avg_cos_similarities
def _report_epoch_force_stats(epoch):
    """Print force-norm and energy-gap statistics for one epoch's last frame.

    This code used to sit after ``exit()`` in the script body, where it could
    never run. It is kept as an uncalled helper with its defects fixed:
    ``relax()``'s tuple is unpacked, the unit conversion uses
    ``units.kcal / units.mol`` instead of undefined names, and the trailing
    ``breakpoint()`` and the unused frame-098 evaluation are gone.

    Args:
        epoch: Epoch number used to build the chain file path and echoed in
            the printed line.
    """
    fn99 = "outputs/edm_qm9/epoch_{}_0/chain/chain_099.txt".format(epoch)
    mol99 = xyz_to_mol(fn99)
    e99, f99 = get_ef(mol99)

    # relax() returns (ok, atoms, n_steps); only the structure is needed.
    _, molgs, _ = relax(mol99)
    egs, _ = get_ef(molgs)

    force_norm = np.linalg.norm(f99, axis=1)
    energy_gap = (e99 - egs) / (units.kcal / units.mol)
    print(", ".join(map(str, [epoch, force_norm.min(), force_norm.mean(),
                              force_norm.max(), energy_gap])))


if __name__ == "__main__":
    # The epoch argument is still required on the command line, as before,
    # although the chain evaluation below does not use it.
    epoch = int(sys.argv[1])
    psi4.set_memory("32 GB")
    #psi4.core.set_output_file("psi4_output.{}.txt".format(epoch))
    #psi4.set_options({'reference': 'uhf'})
    #psi4.set_num_threads(2)

    for chain_id in range(100):
        delta_e, cos_sim = process_chain(chain_id)
        print("chain", chain_id)
        print(delta_e)
        print(cos_sim)
    # The script ends here. To run the per-epoch report instead, call
    # _report_epoch_force_stats(epoch).