Skip to content

Commit

Permalink
add MatterSim and SevenNet to test_ext_load
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Dec 7, 2024
1 parent e5945b4 commit e07e855
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions tests/forcefields/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,27 @@
from atomate2.forcefields.utils import ase_calculator


@pytest.mark.parametrize(("force_field"), [mlff.value for mlff in MLFF])
def test_mlff(force_field: str):
mlff = MLFF(force_field)
@pytest.mark.parametrize("mlff", MLFF)
def test_mlff(mlff: MLFF):
assert mlff == MLFF(str(mlff)) == MLFF(str(mlff).split(".")[-1])


@pytest.mark.parametrize(("force_field"), ["CHGNet", "MACE"])
def test_ext_load(force_field: str):
@pytest.mark.parametrize("mlff", ["CHGNet", "MACE", MLFF.MatterSim, MLFF.SevenNet])
def test_ext_load(mlff: str):
decode_dict = {
"CHGNet": {"@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator"},
"MACE": {"@module": "mace.calculators", "@callable": "mace_mp"},
}[force_field]
MLFF.MatterSim: {
"@module": "mattersim.forcefield",
"@callable": "MatterSimCalculator",
},
MLFF.SevenNet: {
"@module": "sevenn.sevennet_calculator",
"@callable": "SevenNetCalculator",
},
}[mlff]
calc_from_decode = ase_calculator(decode_dict)
calc_from_preset = ase_calculator(str(MLFF(force_field)))
calc_from_preset = ase_calculator(str(MLFF(mlff)))
assert type(calc_from_decode) is type(calc_from_preset)
assert calc_from_decode.name == calc_from_preset.name
assert calc_from_decode.parameters == calc_from_preset.parameters == {}
Expand Down

0 comments on commit e07e855

Please sign in to comment.