Skip to content

Commit

Permalink
add unit test workflow for sparc official files
Browse files Browse the repository at this point in the history
  • Loading branch information
alchem0x2A committed Sep 28, 2023
1 parent a6d97e0 commit f6bbdc0
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 15 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/installation_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
- name: Install dependencies
run: |
# mamba install -c conda-forge ase>=3.22 pymatgen flake8 pytest
mamba install -c alchem0x2a sparc
mamba install -c conda-forge sparc-x
# pip install pyfakefs
- name: Install package
run: |
Expand All @@ -45,6 +45,7 @@ jobs:
run: |
# python -m pytest -svv tests/ --cov=sparc --cov-report=json --cov-report=html
export SPARC_TESTS_DIR="./SPARC-master/tests"
export ASE_SPARC_COMMAND="mpirun -n 1 sparc"
coverage run -a -m pytest -svv tests/
coverage json --omit="tests/*.py"
coverage html --omit="tests/*.py"
Expand Down
9 changes: 6 additions & 3 deletions sparc/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,15 @@ def _convert_special_params(self, atoms=None):
converted_sparc_params["EXCHANGE_CORRELATION"] = "PBE0"
elif xc.lower() == "hf":
converted_sparc_params["EXCHANGE_CORRELATION"] = "HF"
# backward compatibility for HSE03. Note HSE06 is not supported yet
# backward compatibility for HSE03. Note HSE06 is not supported yet
elif (xc.lower() == "hse") or (xc.lower() == "hse03"):
converted_sparc_params["EXCHANGE_CORRELATION"] = "HSE"
# backward compatibility for VASP-style XCs
elif (xc.lower() == "vdwdf1") or (xc.lower() == "vdw-df") or (xc.lower() == "vdw-df1"):
elif (
(xc.lower() == "vdwdf1")
or (xc.lower() == "vdw-df")
or (xc.lower() == "vdw-df1")
):
converted_sparc_params["EXCHANGE_CORRELATION"] = "vdWDF1"
elif (xc.lower() == "vdwdf2") or (xc.lower() == "vdw-df2"):
converted_sparc_params["EXCHANGE_CORRELATION"] = "vdWDF2"
Expand All @@ -499,7 +503,6 @@ def _convert_special_params(self, atoms=None):
else:
# TODO: alternative exception
raise ValueError(f"xc keyword value {xc} is invalid!")


# h --> gpts
if "h" in params:
Expand Down
9 changes: 7 additions & 2 deletions sparc/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,9 @@ def write_sparc(filename, images, **kwargs):
return


@deprecated("Reading individual .ion is not recommended. Please use read_sparc instead.")
@deprecated(
"Reading individual .ion is not recommended. Please use read_sparc instead."
)
def read_ion(filename, **kwargs):
"""Parse an .ion file inside the SPARC bundle using a wrapper around SparcBundle
The reader works only when other files (.inpt) exist.
Expand All @@ -683,7 +685,10 @@ def read_ion(filename, **kwargs):
atoms = sb._read_ion_and_inpt()
return atoms

@deprecated("Writing individual .ion file is not recommended. Please use write_sparc instead.")

@deprecated(
"Writing individual .ion file is not recommended. Please use write_sparc instead."
)
def write_ion(filename, atoms, **kwargs):
"""Write .ion and .inpt files using the SparcBundle wrapper.
Expand Down
3 changes: 1 addition & 2 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_h_parameter():
def test_xc_parameter():
from sparc.calculator import SPARC
from ase.build import bulk

atoms = bulk("Al", cubic=True)
with tempfile.TemporaryDirectory() as tmpdir:
calc = SPARC(directory=tmpdir)
Expand Down Expand Up @@ -118,8 +119,6 @@ def test_xc_parameter():
filecontent = open(Path(tmpdir) / "SPARC.inpt", "r").read()
assert "EXCHANGE_CORRELATION: SCAN" in filecontent




def test_conflict_param():
from sparc.calculator import SPARC
Expand Down
106 changes: 99 additions & 7 deletions tests/test_read_all_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,24 @@
import tempfile
import shutil

skipped_names = ["Si2_domain_paral", "Si2_kpt_paral",
"SiH4", "SiH4_quick",
"H2O_sheet_quick", "H2O_sheet",
"CdS_bandstruct"]
skipped_names = [
"Si2_domain_paral",
"Si2_kpt_paral",
"SiH4",
"SiH4_quick",
"H2O_sheet_quick",
"H2O_sheet",
"CdS_bandstruct",
]

selected_quick_tests = [
"AlSi_orthogonal_quick_scf/standard",
"AlSi_primitive_quick_relax/standard",
"Cu_FCC/standard",
"BaTiO3_quick/standard",
"H2O_wire_quick/standard",
]


def test_read_all_tests():
"""Search all .inpt files within the tests dir."""
Expand All @@ -39,7 +53,7 @@ def test_read_all_tests():
tmpdir = Path(tmpdir)
for ext in [".ion", ".inpt"]:
shutil.copy(workdir / f"{parent_name}{ext}", tmpdir)
for ext in [".refout", "refstatic", "refgeopt", "refaimd"]:
for ext in [".refout", ".refstatic", ".refgeopt", ".refaimd"]:
origin_file = workdir / f"{parent_name}{ext}"
new_ext = ext.replace("ref", "")
new_file = tmpdir / f"{parent_name}{new_ext}"
Expand All @@ -55,6 +69,7 @@ def test_read_all_tests():
if failed_counts > 0:
raise RuntimeError("More than 1 test in output read test failed")


def test_write_all_inputs():
"""Search all .inpt files within the tests dir."""

Expand Down Expand Up @@ -83,7 +98,11 @@ def test_write_all_inputs():
origin_inpt_dict["inpt"]["params"].pop(key, None)
# Re-write the ion and inpt files
try:
write_ion(tmpdir / "test.ion", origin_atoms, **origin_inpt_dict["inpt"]["params"])
write_ion(
tmpdir / "test.ion",
origin_atoms,
**origin_inpt_dict["inpt"]["params"],
)
new_atoms = read_ion(tmpdir / "test.ion")
new_inpt_dict = _read_inpt(tmpdir / "test.inpt")
assert np.all(origin_atoms.pbc == new_atoms.pbc)
Expand All @@ -99,10 +118,83 @@ def test_write_all_inputs():
# Vector types can be list compared
elif isinstance(origin_val, (list, np.ndarray)):
assert np.all(origin_val == new_val)

except Exception as e:
print("Failed: ", parent_name, workdir)
print("\t: Error is ", e)
failed_counts += 1
if failed_counts > 0:
raise RuntimeError("More than 1 test in inpt write test failed")


def test_quick_examples():
"""Perform quick tests on selected ref examples"""
from sparc.calculator import SPARC
from ase.build import molecule
from pathlib import Path
from sparc.io import read_sparc, read_ion, write_ion
from sparc.sparc_parsers.inpt import _read_inpt

dummy_calc = SPARC()
try:
cmd = dummy_calc._make_command()
except EnvironmentError:
print("Skip test since no sparc command found")
pytest.skip()

# Skipped tests are to avoid unwanted keywords
tests_dir = os.environ.get("SPARC_TESTS_DIR", "")
failed_counts = 0
print(f"Current test dir is {tests_dir}")
if len(tests_dir) == 0:
pytest.skip(allow_module_level=True)

tests_dir = Path(tests_dir)
for test_name in selected_quick_tests:
bundle = tests_dir / test_name
parent_name = test_name.split("/")[0]

atoms = read_sparc(bundle, index=0)
inpt_file = bundle / f"{parent_name}.inpt"
params = _read_inpt(inpt_file)["inpt"]["params"]
for key in ["CELL", "LATVEC_SCALE", "LATVEC", "BC"]:
params.pop(key, None)

with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
os.makedirs(tmpdir, exist_ok=True)
for ext in [".ion", ".inpt"]:
shutil.copy(bundle / f"{parent_name}{ext}", tmpdir)
for ext in [".refout", ".refstatic", ".refgeopt", ".refaimd"]:
origin_file = bundle / f"{parent_name}{ext}"
new_ext = ext.replace("ref", "")
new_file = tmpdir / f"{parent_name}{new_ext}"
if origin_file.is_file():
shutil.copy(origin_file, new_file)
old_atoms = read_sparc(tmpdir, index=-1, include_all_files=True)

with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
calc = SPARC(directory=tmpdir, label=parent_name, **params)
calc.calculate(atoms=atoms)
new_atoms = read_sparc(tmpdir, index=-1, include_all_files=True)

print("old atoms", old_atoms)
print("new atoms", new_atoms)
assert len(old_atoms) == len(new_atoms)
assert np.all(old_atoms.pbc) == np.all(new_atoms.pbc)
if "energy" in old_atoms.calc.results:
assert np.isclose(
old_atoms.get_potential_energy(),
new_atoms.get_potential_energy(),
rtol=1.0e-6,
atol=1.0e-3,
)
if "forces" in old_atoms.calc.results:
assert np.isclose(
old_atoms.get_forces(), new_atoms.get_forces(), rtol=1.0e-3, atol=1.0e-2
).all()
if "stress" in old_atoms.calc.results:
assert np.isclose(
old_atoms.get_stress(), new_atoms.get_stress(), rtol=1.0e-3, atol=1.0e-2
).all()

0 comments on commit f6bbdc0

Please sign in to comment.