Skip to content

Commit

Permalink
add Allegro, OCP and MatterSim to ase_calculator
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Dec 7, 2024
1 parent 78ba4f0 commit e5945b4
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ forcefields = [
"matgl>=1.1.3",
"torchdata<=0.7.1",
# quippy-ase support for py3.12 tracked in https://github.com/libAtoms/QUIP/issues/645
"mattersim>=1.0.0rc10.dev1",
"quippy-ase>=0.9.14; python_version < '3.12'",
"sevenn>=0.9.3",
"torchdata<=0.7.1", # TODO: remove when issue fixed
Expand Down Expand Up @@ -123,6 +124,7 @@ strict-forcefields = [
"chgnet==0.4.0",
"mace-torch>=0.3.6",
"matgl==1.1.3",
"mattersim==1.0.0rc10.dev1",
"quippy-ase==0.9.14; python_version < '3.12'",
"sevenn==0.10.1",
"torch==2.5.1",
Expand Down
5 changes: 4 additions & 1 deletion src/atomate2/forcefields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ class MLFF(Enum): # TODO inherit from StrEnum when 3.11+
NEP = "NEP"
Nequip = "Nequip"
SevenNet = "SevenNet"
Allegro = "Allegro"
OCP = "OCP" # for loading model checkpoint with fairchem.core.OCPCalculator
MatterSim = "MatterSim"

@classmethod
def _missing_(cls, value: Any) -> Any:
"""Allow input of str(MLFF) as valid enum."""
"""Allow feeding output of str(MLFF.<model>) back into MLFF(...)."""
if isinstance(value, str):
value = value.split("MLFF.")[-1]
for member in cls:
Expand Down
17 changes: 17 additions & 0 deletions src/atomate2/forcefields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,28 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N

calculator = NequIPCalculator.from_deployed_model(**kwargs)

elif calculator_name == MLFF.Allegro:
from allegro.ase import AllegroCalculator

calculator = AllegroCalculator.from_deployed_model(**kwargs)

elif calculator_name == MLFF.SevenNet:
from sevenn.sevennet_calculator import SevenNetCalculator

calculator = SevenNetCalculator(**{"model": "7net-0"} | kwargs)

elif calculator_name == MLFF.OCP:
# this package is not available on PyPI, needs to be installed from source
# see https://github.com/FAIR-Chem/fairchem?tab=readme-ov-file#installation
from fairchem.core import OCPCalculator

calculator = OCPCalculator(**kwargs)

elif calculator_name == MLFF.MatterSim:
from mattersim.forcefield import MatterSimCalculator

calculator = MatterSimCalculator(**kwargs)

elif isinstance(calculator_meta, dict):
calc_cls = MontyDecoder().process_decoded(calculator_meta)
calculator = calc_cls(**kwargs)
Expand Down

0 comments on commit e5945b4

Please sign in to comment.