Skip to content

Commit

Permalink
add tests for inpt rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
alchem0x2A committed Sep 28, 2023
1 parent c568b42 commit a6d97e0
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -771,3 +771,5 @@ ex0-*/
al-eos-sparc.traj
*/ex1-sparc/
examples/ex1-ase/
/SPARC-master/
/master.zip
2 changes: 1 addition & 1 deletion sparc/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def write_ion(filename, atoms, **kwargs):
"""
label = Path(filename).with_suffix("").name
parent_dir = Path(filename).parent
sb = SparcBundle(directory=parent_dir, label=label)
sb = SparcBundle(directory=parent_dir, label=label, mode="w")
sb._write_ion_and_inpt(atoms, **kwargs)
return atoms

Expand Down
61 changes: 61 additions & 0 deletions tests/test_read_all_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
The ref
"""
import pytest
import numpy as np
from pathlib import Path
import os
import tempfile
import shutil

skipped_names = ["Si2_domain_paral", "Si2_kpt_paral",
"SiH4", "SiH4_quick",
"H2O_sheet_quick", "H2O_sheet",
"CdS_bandstruct"]

def test_read_all_tests():
"""Search all .inpt files within the tests dir."""
Expand All @@ -24,6 +29,7 @@ def test_read_all_tests():
pytest.skip(allow_module_level=True)

tests_dir = Path(tests_dir)
failed_counts = 0
for inpt_file in tests_dir.glob("**/*.inpt"):
workdir = inpt_file.parent
parent_name = inpt_file.parents[1].name
Expand All @@ -45,3 +51,58 @@ def test_read_all_tests():
except Exception as e:
print("Failed: ", parent_name, workdir)
print("\t: Error is ", e)
failed_counts += 1
if failed_counts > 0:
raise RuntimeError("More than 1 test in output read test failed")

def test_write_all_inputs():
"""Search all .inpt files within the tests dir."""

from sparc.io import read_sparc, read_ion, write_ion
from sparc.sparc_parsers.inpt import _read_inpt

# Skipped tests are to avoid unwanted keywords
tests_dir = os.environ.get("SPARC_TESTS_DIR", "")
failed_counts = 0
print(f"Current test dir is {tests_dir}")
if len(tests_dir) == 0:
pytest.skip(allow_module_level=True)

tests_dir = Path(tests_dir)
for inpt_file in tests_dir.glob("**/*.inpt"):
workdir = inpt_file.parent
parent_name = inpt_file.parents[1].name
ion_file = inpt_file.with_suffix(".ion")
if parent_name in skipped_names:
continue
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
origin_atoms = read_ion(ion_file)
origin_inpt_dict = _read_inpt(inpt_file)
for key in ["CELL", "LATVEC_SCALE", "LATVEC", "BC"]:
origin_inpt_dict["inpt"]["params"].pop(key, None)
# Re-write the ion and inpt files
try:
write_ion(tmpdir / "test.ion", origin_atoms, **origin_inpt_dict["inpt"]["params"])
new_atoms = read_ion(tmpdir / "test.ion")
new_inpt_dict = _read_inpt(tmpdir / "test.inpt")
assert np.all(origin_atoms.pbc == new_atoms.pbc)
for key in origin_inpt_dict["inpt"]["params"].keys():
origin_val = origin_inpt_dict["inpt"]["params"][key]
new_val = new_inpt_dict["inpt"]["params"][key]
if isinstance(origin_val, (int, bool)):
assert origin_val == new_val
elif isinstance(origin_val, float):
assert np.isclose(origin_val, new_val, 1e-6)
elif isinstance(origin_val, str):
assert origin_val == new_val
# Vector types can be list compared
elif isinstance(origin_val, (list, np.ndarray)):
assert np.all(origin_val == new_val)

except Exception as e:
print("Failed: ", parent_name, workdir)
print("\t: Error is ", e)
failed_counts += 1
if failed_counts > 0:
raise RuntimeError("More than 1 test in inpt write test failed")

0 comments on commit a6d97e0

Please sign in to comment.