From e07e855b2a15586c7e92737cacf5a088ac80f439 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 7 Dec 2024 14:13:40 -0500 Subject: [PATCH] add MatterSim and SevenNet to test_ext_load --- tests/forcefields/test_utils.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/forcefields/test_utils.py b/tests/forcefields/test_utils.py index b43eb6ff0..f2fe64eaf 100644 --- a/tests/forcefields/test_utils.py +++ b/tests/forcefields/test_utils.py @@ -4,20 +4,27 @@ from atomate2.forcefields.utils import ase_calculator -@pytest.mark.parametrize(("force_field"), [mlff.value for mlff in MLFF]) -def test_mlff(force_field: str): - mlff = MLFF(force_field) +@pytest.mark.parametrize("mlff", MLFF) +def test_mlff(mlff: MLFF): assert mlff == MLFF(str(mlff)) == MLFF(str(mlff).split(".")[-1]) -@pytest.mark.parametrize(("force_field"), ["CHGNet", "MACE"]) -def test_ext_load(force_field: str): +@pytest.mark.parametrize("mlff", ["CHGNet", "MACE", MLFF.MatterSim, MLFF.SevenNet]) +def test_ext_load(mlff: str): decode_dict = { "CHGNet": {"@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator"}, "MACE": {"@module": "mace.calculators", "@callable": "mace_mp"}, - }[force_field] + MLFF.MatterSim: { + "@module": "mattersim.forcefield", + "@callable": "MatterSimCalculator", + }, + MLFF.SevenNet: { + "@module": "sevenn.sevennet_calculator", + "@callable": "SevenNetCalculator", + }, + }[mlff] calc_from_decode = ase_calculator(decode_dict) - calc_from_preset = ase_calculator(str(MLFF(force_field))) + calc_from_preset = ase_calculator(str(MLFF(mlff))) assert type(calc_from_decode) is type(calc_from_preset) assert calc_from_decode.name == calc_from_preset.name assert calc_from_decode.parameters == calc_from_preset.parameters == {}