Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing calculator_kwargs and remove outdated model/model_kwargs in ForceFieldRelaxMaker doc strings #830

Merged
merged 4 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions src/atomate2/forcefields/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ class ForceFieldStaticMaker(ForceFieldRelaxMaker):
The job name.
force_field_name : str
The name of the force field.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -209,6 +211,8 @@ class CHGNetRelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -236,6 +240,8 @@ class CHGNetStaticMaker(ForceFieldStaticMaker):
----------
name : str
The job name.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -272,6 +278,8 @@ class M3GNetRelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -314,6 +322,8 @@ class NequipRelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand All @@ -340,6 +350,8 @@ class NequipStaticMaker(ForceFieldStaticMaker):
The job name.
force_field_name : str
The name of the force field.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand All @@ -360,6 +372,8 @@ class M3GNetStaticMaker(ForceFieldStaticMaker):
The job name.
force_field_name : str
The name of the force field.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -396,16 +410,14 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator. E.g. the "model"
key configures which checkpoint to load with mace.calculators.MACECalculator().
Can be a URL starting with https://. If not set, loads the universal MACE-MP
trained for Matbench Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
model: str | Path | None
Checkpoint to load with :obj:`mace.calculators.MACECalculator()'`. Can be a URL
starting with https://. If None, loads the universal MACE trained for Matbench
Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
model_kwargs: dict[str, Any]
Further keywords (e.g. device, default_dtype, model) for
:obj:`mace.calculators.MACECalculator()'`.
"""

name: str = f"{MLFF.MACE} relax"
Expand All @@ -430,16 +442,14 @@ class MACEStaticMaker(ForceFieldStaticMaker):
The job name.
force_field_name : str
The name of the force field.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator. E.g. the "model"
key configures which checkpoint to load with mace.calculators.MACECalculator().
Can be a URL starting with https://. If not set, loads the universal MACE-MP
trained for Matbench Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
model: str | Path | None
Checkpoint to load with :obj:`mace.calculators.MACECalculator()'`. Can be a URL
starting with https://. If None, loads the universal MACE trained for Matbench
Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
model_kwargs: dict[str, Any]
Further keywords (e.g. device, default_dtype, model) for
:obj:`mace.calculators.MACECalculator()'`.
"""

name: str = f"{MLFF.MACE} static"
Expand Down Expand Up @@ -471,6 +481,8 @@ class GAPRelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -503,6 +515,8 @@ class GAPStaticMaker(ForceFieldStaticMaker):
The job name.
force_field_name : str
The name of the force field.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down
4 changes: 1 addition & 3 deletions src/atomate2/forcefields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,7 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
"""
calculator = None

if isinstance(calculator_meta, str) and calculator_meta in [
f"{name}" for name in MLFF
]:
if isinstance(calculator_meta, str) and calculator_meta in map(str, MLFF):
calculator_name = MLFF(calculator_meta.split("MLFF.")[-1])

if calculator_name == MLFF.CHGNet:
Expand Down
10 changes: 5 additions & 5 deletions src/atomate2/vasp/flows/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def make(
"""
md_job = None
md_jobs = []
for idx, maker in enumerate(self.md_makers, 1):
if md_job is None:
md_structure = structure
md_prev_dir = prev_dir
else:
md_structure = structure
md_prev_dir = prev_dir

for idx, maker in enumerate(self.md_makers, start=1):
if md_job is not None:
md_structure = md_job.output.structure
md_prev_dir = md_job.output.dir_name
md_job = maker.make(md_structure, prev_dir=md_prev_dir)
Expand Down
10 changes: 5 additions & 5 deletions src/atomate2/vasp/jobs/lobster.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ def get_basis_infos(
address_basis_file_min=address_min_basis,
)

nband_list = []
n_band_list: list[int] = []
for dict_for_basis in list_basis_dict:
basis = [f"{key} {value}" for key, value in dict_for_basis.items()]
lobsterin = Lobsterin(settingsdict={"basisfunctions": basis})
nbands = lobsterin._get_nbands(structure=structure)
nband_list.append(nbands)
n_bands = lobsterin._get_nbands(structure=structure)
n_band_list.append(n_bands)

return {"nbands": max(nband_list), "basis_dict": list_basis_dict}
return {"nbands": max(n_band_list), "basis_dict": list_basis_dict}


@job
Expand All @@ -143,7 +143,7 @@ def update_user_incar_settings_maker(
Parameters
----------
vasp_maker : .BaseVaspMaker
A maker for the static run with all parammeters
A maker for the static run with all parameters
relevant for Lobster.
nbands : int
integer indicating the correct number of bands
Expand Down
16 changes: 9 additions & 7 deletions tests/forcefields/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,17 @@ def test_relaxer(si_structure, test_dir, tmp_dir, optimizer, traj_file):
assert os.path.isfile(traj_file)


def test_ext_load():
force_field_to_callable = {
@pytest.mark.parametrize(("force_field"), ["CHGNet", "MACE"])
def test_ext_load(force_field: str):
decode_dict = {
"CHGNet": {"@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator"},
"MACE": {"@module": "mace.calculators", "@callable": "mace_mp"},
}
for force_field in ("CHGNet", "MACE"):
calc_from_decode = ase_calculator(force_field_to_callable[force_field])
calc_from_preset = ase_calculator(f"{MLFF(force_field)}")
assert isinstance(calc_from_decode, type(calc_from_preset))
}[force_field]
calc_from_decode = ase_calculator(decode_dict)
calc_from_preset = ase_calculator(str(MLFF(force_field)))
assert type(calc_from_decode) == type(calc_from_preset)
assert calc_from_decode.name == calc_from_preset.name
assert calc_from_decode.parameters == calc_from_preset.parameters == {}


@pytest.mark.parametrize(("fix_symmetry"), [True, False])
Expand Down
Loading