Skip to content

Commit

Permalink
fix: add optional force check (#780)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Bug Fixes**
- Enhanced error handling when processing forces in various data
formats.
- Added conditional checks to prevent potential runtime errors when
force data is missing.
- Improved robustness of data conversion methods across multiple
plugins.

- **Refactor**
	- Streamlined data handling for optional force and virial information.
- Implemented safer data extraction methods in ASE, PWmat, and VASP
plugins.
- Corrected a typographical error in the documentation of the driver
methods.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Jan 16, 2025
1 parent b826633 commit 4e5ab18
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
3 changes: 2 additions & 1 deletion dpdata/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def calculate(
self.results["energy"] = data["energies"][0]
# see https://gitlab.com/ase/ase/-/merge_requests/2485
self.results["free_energy"] = data["energies"][0]
self.results["forces"] = data["forces"][0]
if "forces" in data:
self.results["forces"] = data["forces"][0]
if "virials" in data:
self.results["virial"] = data["virials"][0].reshape(3, 3)

Expand Down
3 changes: 2 additions & 1 deletion dpdata/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ def label(self, data: dict) -> dict:
labeled_data = lb_data.copy()
else:
labeled_data["energies"] += lb_data["energies"]
labeled_data["forces"] += lb_data["forces"]
if "forces" in labeled_data and "forces" in lb_data:
labeled_data["forces"] += lb_data["forces"]
if "virials" in labeled_data and "virials" in lb_data:
labeled_data["virials"] += lb_data["virials"]
return labeled_data
Expand Down
12 changes: 9 additions & 3 deletions dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]:
cell=data["cells"][ii],
)

results = {"energy": data["energies"][ii], "forces": data["forces"][ii]}
results = {"energy": data["energies"][ii]}
if "forces" in data:
results["forces"] = data["forces"][ii]
if "virials" in data:
# convert to GPa as this is ase convention
# v_pref = 1 * 1e4 / 1.602176621e6
Expand Down Expand Up @@ -296,7 +298,10 @@ def from_labeled_system(
dict_frames["energies"] = np.append(
dict_frames["energies"], tmp["energies"][0]
)
dict_frames["forces"] = np.append(dict_frames["forces"], tmp["forces"][0])
if "forces" in tmp.keys() and "forces" in dict_frames.keys():
dict_frames["forces"] = np.append(
dict_frames["forces"], tmp["forces"][0]
)
if "virials" in tmp.keys() and "virials" in dict_frames.keys():
dict_frames["virials"] = np.append(
dict_frames["virials"], tmp["virials"][0]
Expand All @@ -305,7 +310,8 @@ def from_labeled_system(
## Correct the shape of numpy arrays
dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3)
dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3)
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
if "forces" in dict_frames.keys():
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
if "virials" in dict_frames.keys():
dict_frames["virials"] = dict_frames["virials"].reshape(-1, 3, 3)

Expand Down
4 changes: 3 additions & 1 deletion dpdata/plugins/pwmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ def from_labeled_system(
data["cells"],
data["coords"],
data["energies"],
data["forces"],
tmp_force,
tmp_virial,
) = dpdata.pwmat.movement.get_frames(
file_name, begin=begin, step=step, convergence_check=convergence_check
)
if tmp_force is not None:
data["forces"] = tmp_force
if tmp_virial is not None:
data["virials"] = tmp_virial
# scale virial to the unit of eV
Expand Down
4 changes: 3 additions & 1 deletion dpdata/plugins/vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def from_labeled_system(
data["cells"],
data["coords"],
data["energies"],
data["forces"],
tmp_force,
tmp_virial,
) = dpdata.vasp.outcar.get_frames(
file_name,
Expand All @@ -104,6 +104,8 @@ def from_labeled_system(
ml=ml,
convergence_check=convergence_check,
)
if tmp_force is not None:
data["forces"] = tmp_force
if tmp_virial is not None:
data["virials"] = tmp_virial
# scale virial to the unit of eV
Expand Down

0 comments on commit 4e5ab18

Please sign in to comment.