Skip to content

Commit

Permalink
add new class,HFSP,HFSP_noPBC, for OCP
Browse files Browse the repository at this point in the history
  • Loading branch information
tdprice-858 authored and sakim8048 committed Apr 23, 2024
1 parent 9e69b72 commit 6c736c2
Showing 1 changed file with 58 additions and 44 deletions.
102 changes: 58 additions & 44 deletions pynta/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,37 @@ def calculate(self, atoms=None, properties=None, system_changes=calculator.all_c
energy,forces = self.get_energy_forces()
self.results["energy"] += energy
self.results["free_energy"] += energy
self.results["forces"] += force

class HarmonicallyForcedOCP(OCPCalculator):
def __init__(self, config_yml, checkpoint_path,trainer, cutoff, max_neighbors, cpu, seed, atom_bond_potentials, site_bond_potentials):
super().__init__(config_yml, checkpoint_path,trainer,cutoff,max_neighbors, cpu, seed)
self.parameters["site_bond_potentials"] = site_bond_potentials
self.parameters["atom_bond_potentials"] = atom_bond_potentials
print(self.parameters)
def get_energy_forces(self):
energy = 0.0
forces = np.zeros(self.atoms.positions.shape)
if hasattr(self.parameters,"atom_bond_potentials"):
for atom_bond_potential in self.parameters.atom_bond_potentials:
E,F = get_energy_forces_atom_bond(self.atoms,**atom_bond_potential)
energy += E
forces += F

if hasattr(self.parameters,"site_bond_potentials"):
for site_bond_potential in self.parameters.site_bond_potentials:
E,F = get_energy_forces_site_bond(self.atoms,**site_bond_potential)
energy += E
print(energy)
forces += F
print(forces)
print(f"energy = {energy} and forces = {forces}")
return energy[0][0],forces

def calculate(self, atoms=None, properties=None, system_changes=calculator.all_changes):
OCPCalculator.calculate(self,atoms=atoms,properties=properties,system_changes=system_changes)
energy,forces = self.get_energy_forces()
self.results["energy"] += energy
self.results["forces"] += forces

def run_harmonically_forced_xtb(atoms,atom_bond_potentials,site_bond_potentials,nslab,
Expand Down Expand Up @@ -124,57 +155,34 @@ def run_harmonically_forced_xtb(atoms,atom_bond_potentials,site_bond_potentials,
))

atoms.set_constraint(out_constraints)

hfxtb = HarmonicallyForcedXTB(method="GFN1-xTB",
atom_bond_potentials=atom_bond_potentials,
site_bond_potentials=site_bond_potentials)

atoms.calc = hfxtb

opt = Sella(atoms,trajectory="xtbharm.traj",order=0)
#hfxtb = HarmonicallyForcedXTB(method="GFN1-xTB",
# atom_bond_potentials=atom_bond_potentials,
# site_bond_potentials=site_bond_potentials)
#hfml = HarmonicallyForcedDeepMD(model="/global/cfs/cdirs/m3548/tdprice/12_ML_DiffMod_training/iter4/train/00/graph.pb",
# atom_bond_potentials=atom_bond_potentials,
# site_bond_potentials=site_bond_potentials)
#print(help(HarmonicallyForcedOCP))
hfml = HarmonicallyForcedOCP(config_yml = None, checkpoint_path = '/global/cfs/cdirs/m4126/tdprice/13_OCP_tests/checkpoints/eq2_153M_ec4_allmd.pt', trainer=None, cutoff=6, max_neighbors=50, cpu=True, seed = None, atom_bond_potentials=atom_bond_potentials, site_bond_potentials=site_bond_potentials)
#atoms.calc = hfxtb
print(f"hfml {hfml}")
print(f"atom_bond_potentials = {atom_bond_potentials}")
atoms.calc = hfml
view(atoms)
opt = Sella(atoms,trajectory=f"xtbharm_{i}.traj",order=0)

try:
opt.run(fmax=0.02,steps=150)
except Exception as e: #no pbc fallback
return run_harmonically_forced_xtb_no_pbc(atoms,atom_bond_potentials,site_bond_potentials,nslab,
print(e)
return run_harmonically_forced_xtb_no_pbc(pbc, atoms,atom_bond_potentials,site_bond_potentials,nslab, i,
molecule_to_atom_maps=molecule_to_atom_maps,ase_to_mol_num=ase_to_mol_num,
constraints=constraints,method=method,dthresh=4.0)

Eharm,Fharm = atoms.calc.get_energy_forces()

return atoms,Eharm,Fharm

class HarmonicallyForcedOCP(OCPCalculator):
def __init__(self, config_yml, checkpoint_path,trainer, cutoff, max_neighbors, cpu, seed, atom_bond_potentials, site_bond_potentials):
super().__init__(config_yml, checkpoint_path,trainer,cutoff,max_neighbors, cpu, seed)
self.parameters["site_bond_potentials"] = site_bond_potentials
self.parameters["atom_bond_potentials"] = atom_bond_potentials
print(self.parameters)
def get_energy_forces(self):
energy = 0.0
forces = np.zeros(self.atoms.positions.shape)
if hasattr(self.parameters,"atom_bond_potentials"):
for atom_bond_potential in self.parameters.atom_bond_potentials:
E,F = get_energy_forces_atom_bond(self.atoms,**atom_bond_potential)
energy += E
forces += F

if hasattr(self.parameters,"site_bond_potentials"):
for site_bond_potential in self.parameters.site_bond_potentials:
E,F = get_energy_forces_site_bond(self.atoms,**site_bond_potential)
energy += E
print(energy)
forces += F
print(forces)
print(f"energy = {energy} and forces = {forces}")
return energy[0][0],forces

def calculate(self, atoms=None, properties=None, system_changes=calculator.all_changes):
OCPCalculator.calculate(self,atoms=atoms,properties=properties,system_changes=system_changes)
energy,forces = self.get_energy_forces()
self.results["energy"] += energy
self.results["forces"] += forces

def run_harmonically_forced_xtb_no_pbc(atoms,atom_bond_potentials,site_bond_potentials,nslab,
molecule_to_atom_maps,ase_to_mol_num=None,
constraints=[],method="GFN1-xTB",dthresh=4.0):
Expand Down Expand Up @@ -317,13 +325,19 @@ def run_harmonically_forced_xtb_no_pbc(atoms,atom_bond_potentials,site_bond_pote
indices=list(range(n))
))

hfxtb = HarmonicallyForcedXTB(method="GFN1-xTB",
atom_bond_potentials=new_atom_bond_potentials,
site_bond_potentials=new_site_potentials)
#hfml = HarmonicallyForcedDeepMD(model="/global/cfs/cdirs/m3548/tdprice/12_ML_DiffMod_training/iter4/train/00/graph.pb",
# atom_bond_potentials=new_atom_bond_potentials,
# site_bond_potentials=new_site_potentials)
hfml = HarmonicallyForcedOCP(config_yml = None, checkpoint_path = '/global/cfs/cdirs/m4126/tdprice/13_OCP_tests/checkpoints/eq2_153M_ec4_allmd.pt', trainer=None, cutoff=6, max_neighbors=50, cpu=True, seed = None, atom_bond_potentials=atom_bond_potentials, site_bond_potentials=site_bond_potentials)
#hfml = HarmonicallyForcedOCP(checkpoint_path = '/global/cfs/cdirs/m4126/tdprice/13_OCP_tests/checkpoints/eq2_153M_ec4_allmd.pt', cpu=False, atom_bond_potentials=new_atom_bond_potentials, site_bond_potentials=new_site_potentials)
#hfxtb = HarmonicallyForcedXTB(method="GFN1-xTB",
# atom_bond_potentials=new_atom_bond_potentials,
# site_bond_potentials=new_site_potentials)
bigad.set_constraint(out_constraints)
bigad.calc = hfxtb
#bigad.calc = hfxtb
bigad.calc = hfml

opt = Sella(bigad,trajectory="xtbharm.traj",order=0)
opt = Sella(bigad,trajectory=f"xtbharm_{i}.traj",order=0)

try:
opt.run(fmax=0.02,steps=150)
Expand Down

0 comments on commit 6c736c2

Please sign in to comment.