Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MatterSim, Allegro and OCP models (fairchem-core) to ase_calculators #1079

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ forcefields = [
"matgl>=1.1.3",
"torchdata<=0.7.1",
# quippy-ase support for py3.12 tracked in https://github.com/libAtoms/QUIP/issues/645
"mattersim>=1.0.1rc1",
"quippy-ase>=0.9.14; python_version < '3.12'",
"sevenn>=0.9.3",
"torchdata<=0.7.1", # TODO: remove when issue fixed
Expand Down Expand Up @@ -123,6 +124,7 @@ strict-forcefields = [
"chgnet==0.4.0",
"mace-torch>=0.3.6",
"matgl==1.1.3",
"mattersim==1.0.1rc1",
"quippy-ase==0.9.14; python_version < '3.12'",
"sevenn==0.10.1",
"torch==2.5.1",
Expand Down
2 changes: 1 addition & 1 deletion src/atomate2/common/jobs/qha.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Jobs for running qha calculations."""
"""Jobs for running QHA calculations."""

from __future__ import annotations

Expand Down
52 changes: 20 additions & 32 deletions src/atomate2/common/schemas/qha.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Schemas for qha documents."""
"""Schemas for QHA documents."""

import logging
from typing import Optional, Union
Expand All @@ -15,7 +15,7 @@


class PhononQHADoc(StructureMetadata, extra="allow"): # type: ignore[call-arg]
"""Collection of all data produced by the qha workflow."""
"""Collection of all data produced by the QHA workflow."""

structure: Optional[Structure] = Field(
None, description="Structure of Materials Project."
Expand Down Expand Up @@ -64,7 +64,7 @@ class PhononQHADoc(StructureMetadata, extra="allow"): # type: ignore[call-arg]
description="Gruneisen parameters at temperatures.Shape: (temperatures, )",
)
pressure: Optional[float] = Field(
None, description="Pressure in GPA at which Gibb's energy was computed"
None, description="Pressure in GPA at which Gibbs energy was computed"
)
t_max: Optional[float] = Field(
None,
Expand All @@ -75,17 +75,17 @@ class PhononQHADoc(StructureMetadata, extra="allow"): # type: ignore[call-arg]
free_energies: Optional[list[list[float]]] = Field(
None,
description="List of free energies in J/mol for per formula unit. "
"Shape: (temperatuers, volumes)",
"Shape: (temperatures, volumes)",
)
heat_capacities: Optional[list[list[float]]] = Field(
None,
description="List of heat capacities in J/K/mol per formula unit. "
"Shape: (temperatuers, volumes)",
"Shape: (temperatures, volumes)",
)
entropies: Optional[list[list[float]]] = Field(
None,
description="List of entropies in J/(K*mol) per formula unit. "
"Shape: (temperatuers, volumes) ",
"Shape: (temperatures, volumes) ",
)
formula_units: Optional[int] = Field(None, description="Formula units")

Expand All @@ -108,7 +108,7 @@ def from_phonon_runs(
eos_type: str = "vinet",
**kwargs,
) -> Self:
"""Generate qha results.
"""Generate QHA results.

Parameters
----------
Expand Down Expand Up @@ -151,35 +151,28 @@ def from_phonon_runs(

# create some plots here
# add kwargs to change the names and file types
fig_ext = kwargs.get("plot_type", "pdf")
qha.plot_helmholtz_volume().savefig(
f"{kwargs.get('helmholtz_volume_filename', 'helmholtz_volume')}"
f".{kwargs.get('plot_type', 'pdf')}"
f"{kwargs.get('helmholtz_volume_filename', 'helmholtz_volume')}.{fig_ext}"
)
qha.plot_volume_temperature().savefig(
f"{kwargs.get('volume_temperature_plot', 'volume_temperature')}"
f".{kwargs.get('plot_type', 'pdf')}"
f"{kwargs.get('volume_temperature_plot', 'volume_temperature')}.{fig_ext}"
)
qha.plot_thermal_expansion().savefig(
f"{kwargs.get('thermal_expansion_plot', 'thermal_expansion')}"
f".{kwargs.get('plot_type', 'pdf')}"
f"{kwargs.get('thermal_expansion_plot', 'thermal_expansion')}.{fig_ext}"
)
qha.plot_gibbs_temperature().savefig(
f"{kwargs.get('gibbs_temperature_plot', 'gibbs_temperature')}"
f".{kwargs.get('plot_type', 'pdf')}"
f"{kwargs.get('gibbs_temperature_plot', 'gibbs_temperature')}.{fig_ext}"
)
qha.plot_bulk_modulus_temperature().savefig(
f"{kwargs.get('bulk_modulus_plot', 'bulk_modulus_temperature')}"
f".{kwargs.get('plot_type', 'pdf')}"
f"{kwargs.get('bulk_modulus_plot', 'bulk_modulus_temperature')}.{fig_ext}"
)
qha.plot_heat_capacity_P_numerical().savefig(
f"{kwargs.get('heat_capacity_plot', 'heat_capacity_P_numerical')}"
f".{kwargs.get('plot_type', 'pdf')}"
f"{kwargs.get('heat_capacity_plot', 'heat_capacity_P_numerical')}.{fig_ext}"
)
# qha.plot_heat_capacity_P_polyfit().savefig("heat_capacity_P_polyfit.eps")
qha.plot_gruneisen_temperature().savefig(
f"{kwargs.get('gruneisen_temperature_plot', 'gruneisen_temperature')}"
f".{kwargs.get('plot_type', 'pdf')}"
)
ge_temp_plot = kwargs.get("gruneisen_temperature_plot", "gruneisen_temperature")
qha.plot_gruneisen_temperature().savefig(f"{ge_temp_plot}.{fig_ext}")

qha.write_helmholtz_volume(
filename=kwargs.get("helmholtz_volume_datafile", "helmholtz_volume.dat")
Expand All @@ -199,21 +192,16 @@ def from_phonon_runs(
qha.write_gibbs_temperature(
filename=kwargs.get("gibbs_temperature_datafile", "gibbs_temperature.dat")
)
qha.write_gruneisen_temperature(
filename=kwargs.get(
"gruneisen_temperature_datafile", "gruneisen_temperature.dat"
)
ge_temp_file = kwargs.get(
"gruneisen_temperature_datafile", "gruneisen_temperature.dat"
)
qha.write_gruneisen_temperature(filename=ge_temp_file)
qha.write_heat_capacity_P_numerical(
filename=kwargs.get(
"heat_capacity_datafile", "heat_capacity_P_numerical.dat"
)
)
qha.write_gruneisen_temperature(
filename=kwargs.get(
"gruneisen_temperature_datafile", "gruneisen_temperature.dat"
)
)
qha.write_gruneisen_temperature(filename=ge_temp_file)

# write files as well - might be easier for plotting

Expand Down
5 changes: 4 additions & 1 deletion src/atomate2/forcefields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ class MLFF(Enum): # TODO inherit from StrEnum when 3.11+
NEP = "NEP"
Nequip = "Nequip"
SevenNet = "SevenNet"
Allegro = "Allegro"
OCP = "OCP" # for loading model checkpoint with fairchem.core.OCPCalculator
MatterSim = "MatterSim"

@classmethod
def _missing_(cls, value: Any) -> Any:
"""Allow input of str(MLFF) as valid enum."""
"""Allow feeding output of str(MLFF.<model>) back into MLFF(...)."""
if isinstance(value, str):
value = value.split("MLFF.")[-1]
for member in cls:
Expand Down
17 changes: 17 additions & 0 deletions src/atomate2/forcefields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,28 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N

calculator = NequIPCalculator.from_deployed_model(**kwargs)

elif calculator_name == MLFF.Allegro:
from allegro.ase import AllegroCalculator

calculator = AllegroCalculator.from_deployed_model(**kwargs)

elif calculator_name == MLFF.SevenNet:
from sevenn.sevennet_calculator import SevenNetCalculator

calculator = SevenNetCalculator(**{"model": "7net-0"} | kwargs)

elif calculator_name == MLFF.OCP:
# this package is not available on PyPI, needs to be installed from source
# see https://github.com/FAIR-Chem/fairchem?tab=readme-ov-file#installation
from fairchem.core import OCPCalculator

calculator = OCPCalculator(**kwargs)

elif calculator_name == MLFF.MatterSim:
from mattersim.forcefield import MatterSimCalculator

calculator = MatterSimCalculator(**kwargs)

elif isinstance(calculator_meta, dict):
calc_cls = MontyDecoder().process_decoded(calculator_meta)
calculator = calc_cls(**kwargs)
Expand Down
2 changes: 1 addition & 1 deletion src/atomate2/vasp/flows/qha.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class QhaMaker(CommonQhaMaker):
First relax a structure using relax_maker.
Then perform a series of deformations on the relaxed structure, and
then compute harmonic phonons for each deformed structure.
Finally, compute Gibb's free energy.
Finally, compute Gibbs free energy.

Parameters
----------
Expand Down
16 changes: 8 additions & 8 deletions src/atomate2/vasp/sets/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class EosSetGenerator(VaspInputGenerator):
force_gamma: bool = True
auto_ismear: bool = False
auto_kspacing: bool = False
inherit_incar: bool = False
inherit_incar: bool | list[str] = False

@property
def incar_updates(self) -> dict:
Expand Down Expand Up @@ -60,7 +60,7 @@ class MPLegacyEosRelaxSetGenerator(VaspInputGenerator):
config_dict: dict = field(default_factory=lambda: MPRelaxSet.CONFIG)
auto_ismear: bool = False
auto_kspacing: bool = False
inherit_incar: bool = False
inherit_incar: bool | list[str] = False

@property
def incar_updates(self) -> dict:
Expand Down Expand Up @@ -103,7 +103,7 @@ class MPLegacyEosStaticSetGenerator(EosSetGenerator):
config_dict: dict = field(default_factory=lambda: MPRelaxSet.CONFIG)
auto_ismear: bool = False
auto_kspacing: bool = False
inherit_incar: bool = False
inherit_incar: bool | list[str] = False

@property
def incar_updates(self) -> dict:
Expand Down Expand Up @@ -138,7 +138,7 @@ class MPGGAEosRelaxSetGenerator(VaspInputGenerator):
config_dict: dict = field(default_factory=lambda: MPScanRelaxSet.CONFIG)
auto_ismear: bool = False
auto_kspacing: bool = False
inherit_incar: bool = False
inherit_incar: bool | list[str] = False

@property
def incar_updates(self) -> dict:
Expand Down Expand Up @@ -173,7 +173,7 @@ class MPGGAEosStaticSetGenerator(EosSetGenerator):
config_dict: dict = field(default_factory=lambda: MPScanRelaxSet.CONFIG)
auto_ismear: bool = False
auto_kspacing: bool = False
inherit_incar: bool = False
inherit_incar: bool | list[str] = False

@property
def incar_updates(self) -> dict:
Expand Down Expand Up @@ -207,7 +207,7 @@ class MPMetaGGAEosStaticSetGenerator(VaspInputGenerator):
config_dict: dict = field(default_factory=lambda: MPScanRelaxSet.CONFIG)
auto_ismear: bool = False
auto_kspacing: bool = False
inherit_incar: bool = False
inherit_incar: bool | list[str] = False

@property
def incar_updates(self) -> dict:
Expand Down Expand Up @@ -250,7 +250,7 @@ class MPMetaGGAEosRelaxSetGenerator(VaspInputGenerator):
bandgap_tol: float = 1e-4
auto_ismear: bool = False
auto_kspacing: bool = False
inherit_incar: bool = False
inherit_incar: bool | list[str] = False

@property
def incar_updates(self) -> dict:
Expand Down Expand Up @@ -295,7 +295,7 @@ class MPMetaGGAEosPreRelaxSetGenerator(VaspInputGenerator):
bandgap_tol: float = 1e-4
auto_ismear: bool = False
auto_kspacing: bool = False
inherit_incar: bool = False
inherit_incar: bool | list[str] = False

@property
def incar_updates(self) -> dict:
Expand Down
35 changes: 28 additions & 7 deletions tests/forcefields/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,48 @@
import numpy as np
import pytest

from atomate2.forcefields import MLFF
from atomate2.forcefields.utils import ase_calculator


@pytest.mark.parametrize(("force_field"), [mlff.value for mlff in MLFF])
def test_mlff(force_field: str):
mlff = MLFF(force_field)
@pytest.mark.parametrize("mlff", MLFF)
def test_mlff(mlff: MLFF):
assert mlff == MLFF(str(mlff)) == MLFF(str(mlff).split(".")[-1])


@pytest.mark.parametrize(("force_field"), ["CHGNet", "MACE"])
def test_ext_load(force_field: str):
@pytest.mark.parametrize("mlff", ["CHGNet", "MACE", MLFF.MatterSim, MLFF.SevenNet])
def test_ext_load(mlff: str):
from ase.build import bulk

decode_dict = {
"CHGNet": {"@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator"},
"MACE": {"@module": "mace.calculators", "@callable": "mace_mp"},
}[force_field]
MLFF.MatterSim: {
"@module": "mattersim.forcefield",
"@callable": "MatterSimCalculator",
},
MLFF.SevenNet: {
"@module": "sevenn.sevennet_calculator",
"@callable": "SevenNetCalculator",
},
}[mlff]
calc_from_decode = ase_calculator(decode_dict)
calc_from_preset = ase_calculator(str(MLFF(force_field)))
calc_from_preset = ase_calculator(str(MLFF(mlff)))
assert type(calc_from_decode) is type(calc_from_preset)
assert calc_from_decode.name == calc_from_preset.name
assert calc_from_decode.parameters == calc_from_preset.parameters == {}

atoms = bulk("Si", "diamond", a=5.43)

atoms.calc = calc_from_preset
energy = atoms.get_potential_energy()
forces = atoms.get_forces()

assert isinstance(energy, float | np.floating)
assert energy < 0
assert forces.shape == (2, 3)
assert abs(forces.sum()) < 1e-6, f"unexpectedly large net {forces=}"


def test_raises_error():
with pytest.raises(ValueError, match="Could not create"):
Expand Down
Loading