diff --git a/.github/workflows/test-scripts.yml b/.github/workflows/test-scripts.yml
index 8b2c9c32..40f91255 100644
--- a/.github/workflows/test-scripts.yml
+++ b/.github/workflows/test-scripts.yml
@@ -14,7 +14,7 @@ jobs:
fail-fast: false
matrix:
script:
- - scripts/model_figs/make_metrics_tables.py
+ - scripts/model_figs/metrics_tables.py
- scripts/model_figs/rolling_mae_vs_hull_dist_models.py
steps:
- name: Check out repository
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b43212bd..e1864521 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -56,7 +56,7 @@ repos:
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$
- repo: https://github.com/pre-commit/mirrors-eslint
- rev: v9.0.0-beta.0
+ rev: v9.0.0-beta.1
hooks:
- id: eslint
types: [file]
@@ -78,3 +78,9 @@ repos:
files: ^models/(.+)/\1.*\.yml$
args: [--schemafile, tests/model-schema.yml]
- id: check-github-actions
+
+ - repo: https://github.com/RobertCraigie/pyright-python
+ rev: v1.1.351
+ hooks:
+ - id: pyright
+ args: [--level, error]
diff --git a/data/mp/build_phase_diagram.py b/data/mp/build_phase_diagram.py
index 4b63265f..fdc775f8 100644
--- a/data/mp/build_phase_diagram.py
+++ b/data/mp/build_phase_diagram.py
@@ -17,9 +17,10 @@
from pymatgen.ext.matproj import MPRester
from tqdm import tqdm
-from matbench_discovery import MP_DIR, ROOT, Key, today
+from matbench_discovery import MP_DIR, ROOT, today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries
+from matbench_discovery.enums import Key
module_dir = os.path.dirname(__file__)
diff --git a/data/mp/eda_mp_trj.py b/data/mp/eda_mp_trj.py
index 309098a6..52dcea39 100644
--- a/data/mp/eda_mp_trj.py
+++ b/data/mp/eda_mp_trj.py
@@ -21,8 +21,9 @@
from pymatviz.utils import si_fmt
from tqdm import tqdm
-from matbench_discovery import MP_DIR, PDF_FIGS, ROOT, SITE_FIGS, Key
+from matbench_discovery import MP_DIR, PDF_FIGS, ROOT, SITE_FIGS
from matbench_discovery.data import DATA_FILES, df_wbm
+from matbench_discovery.enums import Key
__author__ = "Janosh Riebesell"
__date__ = "2023-11-22"
@@ -108,10 +109,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
# %% plot per-element magmom histograms
ptable_magmom_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-magmoms.json.bz2"
+srs_mp_trj_elem_magmoms = locals().get("srs_mp_trj_elem_magmoms")
if os.path.isfile(ptable_magmom_hist_path):
srs_mp_trj_elem_magmoms = pd.read_json(ptable_magmom_hist_path, typ="series")
-elif "srs_mp_trj_elem_magmoms" not in locals():
+if srs_mp_trj_elem_magmoms is None:
# project magmoms onto symbols in dict
df_mp_trj_elem_magmom = pd.DataFrame(
[
@@ -151,10 +153,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
# %% plot per-element force histograms
ptable_force_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-forces.json.bz2"
+srs_mp_trj_elem_forces = locals().get("srs_mp_trj_elem_forces")
if os.path.isfile(ptable_force_hist_path):
srs_mp_trj_elem_forces = pd.read_json(ptable_force_hist_path, typ="series")
-elif "srs_mp_trj_elem_forces" not in locals():
+if srs_mp_trj_elem_forces is None:
df_mp_trj_elem_forces = pd.DataFrame(
[
dict(zip(elems, np.abs(forces).mean(axis=1)))
@@ -193,10 +196,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
# %% plot histogram of number of sites per element
ptable_n_sites_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-n-sites.json.bz2"
+srs_mp_trj_elem_n_sites = locals().get("srs_mp_trj_elem_n_sites")
if os.path.isfile(ptable_n_sites_hist_path):
srs_mp_trj_elem_n_sites = pd.read_json(ptable_n_sites_hist_path, typ="series")
-elif "mp_trj_elem_n_sites" not in locals():
+elif srs_mp_trj_elem_n_sites is None:
# construct a series of lists of site numbers per element (i.e. how often each
# element appears in a structure with n sites)
# create all df cols as int dtype
@@ -320,8 +324,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
pdf_kwds = dict(width=500, height=300)
x_col, y_col = "Eform (eV/atom)", count_col
+df_e_form = locals().get("df_e_form")
-if "df_e_form" not in locals(): # only compute once for speed
+if df_e_form is None: # only compute once for speed
e_form_hist = np.histogram(df_mp_trj[Key.e_form], bins=300)
df_e_form = pd.DataFrame(e_form_hist, index=[y_col, x_col]).T.round(3)
@@ -340,8 +345,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
# %% plot forces distribution
# use numpy to pre-compute histogram
x_col, y_col = "|Forces| (eV/Å)", count_col
+df_forces = locals().get("df_forces")
-if "df_forces" not in locals(): # only compute once for speed
+if df_forces is None: # only compute once for speed
forces_hist = np.histogram(
df_mp_trj[Key.forces].explode().explode().abs(), bins=300
)
@@ -361,8 +367,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
# %% plot hydrostatic stress distribution
x_col, y_col = "1/3 Tr(σ) (eV/ų)", count_col # noqa: RUF001
+df_stresses = locals().get("df_stresses")
-if "df_stresses" not in locals(): # only compute once for speed
+if df_stresses is None: # only compute once for speed
stresses_hist = np.histogram(df_mp_trj[Key.stress_trace], bins=300)
df_stresses = pd.DataFrame(stresses_hist, index=[y_col, x_col]).T.round(3)
@@ -381,8 +388,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
# %% plot magmoms distribution
x_col, y_col = "Magmoms (μB)", count_col
+df_magmoms = locals().get("df_magmoms")
-if "df_magmoms" not in locals(): # only compute once for speed
+if df_magmoms is None: # only compute once for speed
magmoms_hist = np.histogram(df_mp_trj[Key.magmoms].dropna().explode(), bins=300)
df_magmoms = pd.DataFrame(magmoms_hist, index=[y_col, x_col]).T.round(3)
diff --git a/data/mp/get_mp_energies.py b/data/mp/get_mp_energies.py
index 17bbab32..b87334e6 100644
--- a/data/mp/get_mp_energies.py
+++ b/data/mp/get_mp_energies.py
@@ -8,8 +8,9 @@
from pymatviz.utils import annotate_metrics
from tqdm import tqdm
-from matbench_discovery import STABILITY_THRESHOLD, Key, today
+from matbench_discovery import STABILITY_THRESHOLD, today
from matbench_discovery.data import DATA_FILES
+from matbench_discovery.enums import Key
"""
Download all MP formation and above hull energies on 2023-01-10.
diff --git a/data/mp/get_mp_traj.py b/data/mp/get_mp_traj.py
index 9dab8cb0..1bcab981 100644
--- a/data/mp/get_mp_traj.py
+++ b/data/mp/get_mp_traj.py
@@ -17,7 +17,8 @@
from pymongo.database import Database
from tqdm import tqdm, trange
-from matbench_discovery import ROOT, Key, today
+from matbench_discovery import ROOT, today
+from matbench_discovery.enums import Key
__author__ = "Janosh Riebesell"
__date__ = "2023-03-15"
diff --git a/data/wbm/compare_cse_vs_ce_mp_2020_corrections.py b/data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
index 564cdfa2..e7b31c62 100644
--- a/data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
+++ b/data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
@@ -10,9 +10,10 @@
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
from tqdm import tqdm
-from matbench_discovery import ROOT, Key, today
+from matbench_discovery import ROOT, today
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.energy import get_e_form_per_atom
+from matbench_discovery.enums import Key
from matbench_discovery.plots import plt
"""
diff --git a/data/wbm/compile_wbm_test_set.py b/data/wbm/compile_wbm_test_set.py
index c4101408..7ac06fe1 100644
--- a/data/wbm/compile_wbm_test_set.py
+++ b/data/wbm/compile_wbm_test_set.py
@@ -19,17 +19,19 @@
from pymatviz.io import save_fig
from tqdm import tqdm
-from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, Key, today
+from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.energy import get_e_form_per_atom
+from matbench_discovery.enums import Key
try:
import gdown
-except ImportError:
- print(
+except ImportError as exc:
+ exc.add_note(
"gdown not installed. Needed for downloading WBM initial + relaxed structures "
"from Google Drive."
)
+ raise
"""
Dataset generated with DFT and published in Jan 2021 as
@@ -90,8 +92,8 @@
18198704957443186264,
)
-if "dfs_wbm_structs" not in locals():
- dfs_wbm_structs = {}
+dfs_wbm_structs = locals().get("dfs_wbm_structs", {})
+
for json_path in json_paths:
step = int(json_path.split(".json.bz2")[0][-1])
assert step in range(1, 6)
@@ -179,8 +181,8 @@ def increment_wbm_material_id(wbm_id: str) -> str:
print(f"{file_path} already exists, skipping")
continue
+ url = f"{mat_cloud_url}&filename={filename}"
try:
- url = f"{mat_cloud_url}&filename={filename}"
urllib.request.urlretrieve(url, file_path)
except urllib.error.HTTPError as exc:
print(f"failed to download {url=}: {exc}")
diff --git a/matbench_discovery/__init__.py b/matbench_discovery/__init__.py
index 41bb796a..3dacffaa 100644
--- a/matbench_discovery/__init__.py
+++ b/matbench_discovery/__init__.py
@@ -13,15 +13,7 @@
import plotly.io as pio
import pymatviz # noqa: F401
-from matbench_discovery.enums import ( # noqa: F401
- Key,
- Model,
- ModelType,
- Open,
- Quantity,
- Targets,
- Task,
-)
+from matbench_discovery.enums import Model, Quantity
PKG_NAME = "matbench-discovery"
direct_url = Distribution.from_name(PKG_NAME).read_text("direct_url.json") or "{}"
@@ -69,7 +61,7 @@
FIGSHARE_URLS = json.load(file)
# --- start global plot settings
-px.defaults.labels = Quantity.val_dict() | Model.val_dict()
+px.defaults.labels = Quantity.key_val_dict() | Model.key_val_dict()
global_layout = dict(
paper_bgcolor="rgba(0,0,0,0)",
@@ -77,9 +69,9 @@
# increase legend marker size and make background transparent
legend=dict(itemsizing="constant", bgcolor="rgba(0, 0, 0, 0)"),
)
-pio.templates["global"] = dict(layout=global_layout)
-pio.templates.default = "pymatviz_dark+global"
-px.defaults.template = "pymatviz_dark+global"
+pio.templates["mbd_global"] = dict(layout=global_layout)
+pio.templates.default = "pymatviz_dark+mbd_global"
+px.defaults.template = "pymatviz_dark+mbd_global"
# https://github.com/plotly/Kaleido/issues/122#issuecomment-994906924
# when seeing MathJax "loading" message in exported PDFs,
diff --git a/matbench_discovery/data.py b/matbench_discovery/data.py
index 05e381f7..b0315825 100644
--- a/matbench_discovery/data.py
+++ b/matbench_discovery/data.py
@@ -14,7 +14,8 @@
from monty.json import MontyDecoder
from tqdm import tqdm
-from matbench_discovery import FIGSHARE_DIR, Key
+from matbench_discovery import FIGSHARE_DIR
+from matbench_discovery.enums import Key
if TYPE_CHECKING:
from pathlib import Path
diff --git a/matbench_discovery/enums.py b/matbench_discovery/enums.py
index ea8e42a3..82615f3c 100644
--- a/matbench_discovery/enums.py
+++ b/matbench_discovery/enums.py
@@ -29,14 +29,24 @@ def description(self) -> str:
return self.__dict__["desc"]
@classmethod
- def val_dict(cls) -> dict[str, str]:
- """Return the Enum as dictionary."""
+ def key_val_dict(cls) -> dict[str, str]:
+ """Map of keys to values."""
return {key: str(val) for key, val in cls.__members__.items()}
@classmethod
- def label_dict(cls) -> dict[str, str]:
- """Return the Enum as dictionary."""
- return {str(val): val.label for key, val in cls.__members__.items()}
+ def val_label_dict(cls) -> dict[str, str | None]:
+ """Map of values to labels."""
+ return {str(val): val.label for val in cls.__members__.values()}
+
+ @classmethod
+ def val_desc_dict(cls) -> dict[str, str | None]:
+ """Map of values to descriptions."""
+ return {str(val): val.description for val in cls.__members__.values()}
+
+ @classmethod
+ def label_desc_dict(cls) -> dict[str, str | None]:
+ """Map of str(label) to description; a None label is keyed as "None"."""
+ return {str(val.label): val.description for val in cls.__members__.values()}
@unique
diff --git a/matbench_discovery/plots.py b/matbench_discovery/plots.py
index 9bc1fe5a..feba1303 100644
--- a/matbench_discovery/plots.py
+++ b/matbench_discovery/plots.py
@@ -103,8 +103,10 @@ def hist_classified_stable_vs_hull_dist(
fixed in Inkscape or similar by merging regions by color.
"""
x_col = dict(true=each_true_col, pred=each_pred_col)[which_energy]
+ clf_col, value_name = "classified", "count"
df_plot = pd.DataFrame()
+
for facet, df_group in (
df.groupby(kwargs["facet_col"]) if "facet_col" in kwargs else [(None, df)]
):
@@ -113,16 +115,16 @@ def hist_classified_stable_vs_hull_dist(
)
# switch between hist of DFT-computed and model-predicted convex hull distance
- e_above_hull = df_group[x_col]
- each_true_pos = e_above_hull[true_pos]
- each_true_neg = e_above_hull[true_neg]
- each_false_neg = e_above_hull[false_neg]
- each_false_pos = e_above_hull[false_pos]
+ srs_each = df_group[x_col]
+ each_true_pos = srs_each[true_pos]
+ each_true_neg = srs_each[true_neg]
+ each_false_neg = srs_each[false_neg]
+ each_false_pos = srs_each[false_pos]
# n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
# sum, (true_pos, false_pos, true_neg, false_neg)
# )
- df_group[(clf_col := "classified")] = np.array(clf_labels)[
+ df_group[clf_col] = np.array(clf_labels)[
true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3
]
@@ -144,7 +146,6 @@ def hist_classified_stable_vs_hull_dist(
index=clf_labels,
).T
df_hist[x_col] = bin_edges[:-1]
- value_name = "count"
df_melt = df_hist.melt(
id_vars=x_col,
value_vars=clf_labels,
@@ -714,6 +715,7 @@ def cumulative_metrics(
)
df = dfs[metric]
ax.set(ylim=(0, 1), xlim=(0, None), ylabel=metric)
+ bbox = dict(facecolor="white", alpha=0.5, edgecolor="none")
for model in df_preds:
# TODO is this really necessary?
if len(df[model].dropna()) == 0:
@@ -722,7 +724,6 @@ def cumulative_metrics(
y_end = df[model].dropna().iloc[-1]
# add some visual guidelines to the plot
intersect_kwargs = dict(linestyle=":", alpha=0.4, linewidth=2)
- bbox = dict(facecolor="white", alpha=0.5, edgecolor="none")
# place model name at the end of every line
ax.text(x_end, y_end, model, va="bottom", rotation=30, bbox=bbox)
if "x" in project_end_point:
diff --git a/matbench_discovery/preds.py b/matbench_discovery/preds.py
index 4cf52089..5570852e 100644
--- a/matbench_discovery/preds.py
+++ b/matbench_discovery/preds.py
@@ -5,8 +5,9 @@
import pandas as pd
from tqdm import tqdm
-from matbench_discovery import ROOT, STABILITY_THRESHOLD, Key, Model
+from matbench_discovery import ROOT, STABILITY_THRESHOLD, Model
from matbench_discovery.data import Files, df_wbm, glob_to_df
+from matbench_discovery.enums import Key
from matbench_discovery.metrics import stable_metrics
from matbench_discovery.plots import plotly_colors, plotly_line_styles, plotly_markers
@@ -65,7 +66,7 @@ class PredFiles(Files):
# key_map maps model keys to pretty labels
-PRED_FILES = PredFiles(root=f"{ROOT}/models", key_map=Model.val_dict())
+PRED_FILES = PredFiles(root=f"{ROOT}/models", key_map=Model.key_val_dict())
def load_df_wbm_with_preds(
@@ -101,7 +102,6 @@ def load_df_wbm_with_preds(
)
dfs: dict[str, pd.DataFrame] = {}
-
try:
for model_name in (bar := tqdm(models, disable=not pbar, desc="Loading preds")):
bar.set_postfix_str(model_name)
@@ -109,7 +109,7 @@ def load_df_wbm_with_preds(
df = df.set_index(id_col)
dfs[model_name] = df
except Exception as exc:
- raise RuntimeError(f"Failed to load {model_name=}") from exc
+ raise RuntimeError(f"Failed to load {locals().get('model_name')=}") from exc
from matbench_discovery.data import df_wbm
diff --git a/models/alignn/test_alignn.py b/models/alignn/test_alignn.py
index cb29bf90..250b0802 100644
--- a/models/alignn/test_alignn.py
+++ b/models/alignn/test_alignn.py
@@ -17,8 +17,9 @@
from sklearn.metrics import r2_score
from tqdm import tqdm
-from matbench_discovery import Key, Task, today
+from matbench_discovery import today
from matbench_discovery.data import DATA_FILES, df_wbm
+from matbench_discovery.enums import Key, Task
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
diff --git a/models/alignn/train_alignn.py b/models/alignn/train_alignn.py
index 5c5c1d81..cbc34286 100644
--- a/models/alignn/train_alignn.py
+++ b/models/alignn/train_alignn.py
@@ -18,8 +18,9 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
-from matbench_discovery import Key, today
+from matbench_discovery import today
from matbench_discovery.data import DATA_FILES
+from matbench_discovery.enums import Key
from matbench_discovery.slurm import slurm_submit
__author__ = "Philipp Benner, Janosh Riebesell"
diff --git a/models/alignn_ff/alignn_ff_relax.py b/models/alignn_ff/alignn_ff_relax.py
index 4cb8f8c6..1fc6f125 100644
--- a/models/alignn_ff/alignn_ff_relax.py
+++ b/models/alignn_ff/alignn_ff_relax.py
@@ -9,8 +9,9 @@
from pymatgen.io.jarvis import JarvisAtomsAdaptor
from tqdm import tqdm
-from matbench_discovery import Key, Task, today
+from matbench_discovery import today
from matbench_discovery.data import DATA_FILES, df_wbm
+from matbench_discovery.enums import Key, Task
__author__ = "Janosh Riebesell, Philipp Benner"
__date__ = "2023-07-11"
diff --git a/models/alignn_ff/test_alignn_ff.py b/models/alignn_ff/test_alignn_ff.py
index 467d47b6..c1c9cfb0 100644
--- a/models/alignn_ff/test_alignn_ff.py
+++ b/models/alignn_ff/test_alignn_ff.py
@@ -18,8 +18,9 @@
from sklearn.metrics import r2_score
from tqdm import tqdm
-from matbench_discovery import Key, Task, today
+from matbench_discovery import today
from matbench_discovery.data import DATA_FILES, df_wbm
+from matbench_discovery.enums import Key, Task
from matbench_discovery.plots import wandb_scatter
__author__ = "Philipp Benner, Janosh Riebesell"
diff --git a/models/bowsr/join_bowsr_results.py b/models/bowsr/join_bowsr_results.py
index e6619e69..08d3df60 100644
--- a/models/bowsr/join_bowsr_results.py
+++ b/models/bowsr/join_bowsr_results.py
@@ -8,8 +8,8 @@
import pymatviz
from tqdm import tqdm
-from matbench_discovery import Model, Task
-from matbench_discovery.data import DATA_FILES, Key
+from matbench_discovery.data import DATA_FILES
+from matbench_discovery.enums import Key, Model, Task
__author__ = "Janosh Riebesell"
__date__ = "2022-09-22"
diff --git a/models/bowsr/test_bowsr.py b/models/bowsr/test_bowsr.py
index 4769d213..613b38aa 100644
--- a/models/bowsr/test_bowsr.py
+++ b/models/bowsr/test_bowsr.py
@@ -14,8 +14,9 @@
from pymatgen.core import Structure
from tqdm import tqdm
-from matbench_discovery import Key, Model, Task, timestamp, today
+from matbench_discovery import Model, timestamp, today
from matbench_discovery.data import DATA_FILES, as_dict_handler
+from matbench_discovery.enums import Key, Task
from matbench_discovery.slurm import slurm_submit
__author__ = "Janosh Riebesell"
@@ -127,11 +128,7 @@
with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
optimizer.optimize(**optimize_kwargs)
- try:
- struct_bowsr, energy_bowsr = optimizer.get_optimized_structure_and_energy()
- except Exception as exc:
- print(f"Failed to relax {material_id}: {exc!r}")
-
+ struct_bowsr, energy_bowsr = optimizer.get_optimized_structure_and_energy()
results = {
f"e_form_per_atom_bowsr_{energy_model}": model.predict_energy(struct_bowsr),
"structure_bowsr": struct_bowsr,
@@ -139,6 +136,7 @@
}
relax_results[material_id] = results
+
except Exception as exc:
print(f"{material_id=} raised {exc=}")
diff --git a/models/cgcnn/test_cgcnn.py b/models/cgcnn/test_cgcnn.py
index 7203a7fb..ab35c87d 100644
--- a/models/cgcnn/test_cgcnn.py
+++ b/models/cgcnn/test_cgcnn.py
@@ -14,8 +14,9 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
-from matbench_discovery import CHECKPOINT_DIR, WANDB_PATH, WBM_DIR, Key, Task, today
+from matbench_discovery import CHECKPOINT_DIR, WANDB_PATH, WBM_DIR, today
from matbench_discovery.data import DATA_FILES, df_wbm
+from matbench_discovery.enums import Key, Task
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
diff --git a/models/cgcnn/train_cgcnn.py b/models/cgcnn/train_cgcnn.py
index 974b420e..6ed885b3 100644
--- a/models/cgcnn/train_cgcnn.py
+++ b/models/cgcnn/train_cgcnn.py
@@ -11,8 +11,9 @@
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
-from matbench_discovery import WANDB_PATH, Key, timestamp, today
+from matbench_discovery import WANDB_PATH, timestamp, today
from matbench_discovery.data import DATA_FILES
+from matbench_discovery.enums import Key
from matbench_discovery.slurm import slurm_submit
from matbench_discovery.structure import perturb_structure
diff --git a/models/chgnet/analyze_chgnet.py b/models/chgnet/analyze_chgnet.py
index 0dde7109..a3a90b1d 100644
--- a/models/chgnet/analyze_chgnet.py
+++ b/models/chgnet/analyze_chgnet.py
@@ -10,9 +10,10 @@
from pymatviz import density_scatter, plot_structure_2d, ptable_heatmap_plotly
from pymatviz.io import save_fig
-from matbench_discovery import PDF_FIGS, Key
+from matbench_discovery import PDF_FIGS
from matbench_discovery import plots as plots
from matbench_discovery.data import DATA_FILES, df_wbm
+from matbench_discovery.enums import Key
from matbench_discovery.preds import PRED_FILES, df_preds
__author__ = "Janosh Riebesell"
diff --git a/models/chgnet/ctk_structure_viewer.py b/models/chgnet/ctk_structure_viewer.py
index 2d40a402..52fc3ead 100644
--- a/models/chgnet/ctk_structure_viewer.py
+++ b/models/chgnet/ctk_structure_viewer.py
@@ -3,7 +3,7 @@
import pandas as pd
from crystal_toolkit.helpers.utils import hook_up_fig_with_struct_viewer
-from matbench_discovery import Key
+from matbench_discovery.enums import Key
from matbench_discovery.preds import PRED_FILES
__author__ = "Janosh Riebesell"
diff --git a/models/chgnet/ctk_trajectory_viewer.py b/models/chgnet/ctk_trajectory_viewer.py
index c53a3276..83d2004b 100644
--- a/models/chgnet/ctk_trajectory_viewer.py
+++ b/models/chgnet/ctk_trajectory_viewer.py
@@ -16,8 +16,8 @@
from m3gnet.models import Relaxer as M3gnetRelaxer
from pymatgen.core import Lattice, Structure
-from matbench_discovery import Key
from matbench_discovery.data import df_wbm
+from matbench_discovery.enums import Key
if TYPE_CHECKING:
from chgnet.model.dynamics import TrajectoryObserver
diff --git a/models/chgnet/join_chgnet_results.py b/models/chgnet/join_chgnet_results.py
index 7fd971f2..5ee7495f 100644
--- a/models/chgnet/join_chgnet_results.py
+++ b/models/chgnet/join_chgnet_results.py
@@ -13,18 +13,17 @@
from pymatviz import density_scatter
from tqdm import tqdm
-from matbench_discovery import Key, Task
from matbench_discovery.data import as_dict_handler
from matbench_discovery.energy import get_e_form_per_atom
+from matbench_discovery.enums import Key, Task
from matbench_discovery.preds import df_preds
__author__ = "Janosh Riebesell"
__date__ = "2023-03-01"
-df_chgnet = pd.read_csv("2023-12-05-chgnet-0.3.0-wbm-IS2RE-static.csv.gz").set_index(
- Key.mat_id
-)
+csv_path = "2023-12-05-chgnet-0.3.0-wbm-IS2RE-static.csv.gz"
+df_chgnet = pd.read_csv(csv_path).set_index(Key.mat_id)
# %%
diff --git a/models/chgnet/test_chgnet.py b/models/chgnet/test_chgnet.py
index 1a0193eb..f53e9745 100644
--- a/models/chgnet/test_chgnet.py
+++ b/models/chgnet/test_chgnet.py
@@ -21,8 +21,9 @@
from pymatgen.core import Structure
from tqdm import tqdm
-from matbench_discovery import Key, Task, timestamp, today
+from matbench_discovery import timestamp, today
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
+from matbench_discovery.enums import Key, Task
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
diff --git a/models/gnome/test_gnome.py b/models/gnome/test_gnome.py
index 1f12d54b..f615a005 100644
--- a/models/gnome/test_gnome.py
+++ b/models/gnome/test_gnome.py
@@ -1,7 +1,7 @@
# %%
from pymatviz import density_scatter
-from matbench_discovery import Key, Model
+from matbench_discovery.enums import Key, Model
from matbench_discovery.preds import df_preds
__author__ = "Janosh Riebesell"
diff --git a/models/m3gnet/join_m3gnet_results.py b/models/m3gnet/join_m3gnet_results.py
index 68614fa6..608770e7 100644
--- a/models/m3gnet/join_m3gnet_results.py
+++ b/models/m3gnet/join_m3gnet_results.py
@@ -16,9 +16,9 @@
from pymatgen.entries.computed_entries import ComputedStructureEntry
from tqdm import tqdm
-from matbench_discovery import Key, Task
from matbench_discovery.data import DATA_FILES, as_dict_handler
from matbench_discovery.energy import get_e_form_per_atom
+from matbench_discovery.enums import Key, Task
__author__ = "Janosh Riebesell"
__date__ = "2022-08-16"
@@ -34,8 +34,7 @@
print(f"Found {len(file_paths):,} files for {glob_pattern = }")
# prevent accidental overwrites
-if "dfs" not in locals():
- dfs: dict[str, pd.DataFrame] = {}
+dfs: dict[str, pd.DataFrame] = locals().get("dfs", {})
# %%
diff --git a/models/m3gnet/pre_vs_post_m3gnet_relaxation.py b/models/m3gnet/pre_vs_post_m3gnet_relaxation.py
index 8da7d4dd..3a611a1a 100644
--- a/models/m3gnet/pre_vs_post_m3gnet_relaxation.py
+++ b/models/m3gnet/pre_vs_post_m3gnet_relaxation.py
@@ -12,8 +12,9 @@
from pymatviz.utils import add_identity_line
from sklearn.metrics import r2_score
-from matbench_discovery import ROOT, SITE_FIGS, Key, plots
+from matbench_discovery import ROOT, SITE_FIGS, plots
from matbench_discovery.data import DATA_FILES
+from matbench_discovery.enums import Key
__author__ = "Janosh Riebesell"
__date__ = "2022-06-18"
diff --git a/models/m3gnet/test_m3gnet.py b/models/m3gnet/test_m3gnet.py
index e39185d8..59fa5f6c 100644
--- a/models/m3gnet/test_m3gnet.py
+++ b/models/m3gnet/test_m3gnet.py
@@ -20,8 +20,9 @@
from pymatgen.core import Structure
from tqdm import tqdm
-from matbench_discovery import ROOT, Key, Task, timestamp, today
+from matbench_discovery import ROOT, timestamp, today
from matbench_discovery.data import DATA_FILES, as_dict_handler
+from matbench_discovery.enums import Key, Task
from matbench_discovery.slurm import slurm_submit
__author__ = "Janosh Riebesell"
diff --git a/models/mace/analyze_mace.py b/models/mace/analyze_mace.py
index 6fae2ca7..408b4e06 100644
--- a/models/mace/analyze_mace.py
+++ b/models/mace/analyze_mace.py
@@ -8,9 +8,9 @@
from pymatviz import density_scatter, ptable_heatmap_plotly, spacegroup_sunburst
from pymatviz.io import save_fig
-from matbench_discovery import Key
from matbench_discovery import plots as plots
from matbench_discovery.data import df_wbm
+from matbench_discovery.enums import Key
from matbench_discovery.preds import PRED_FILES
__author__ = "Janosh Riebesell"
diff --git a/models/mace/join_mace_results.py b/models/mace/join_mace_results.py
index d0fafa1e..ec37c5e1 100644
--- a/models/mace/join_mace_results.py
+++ b/models/mace/join_mace_results.py
@@ -16,9 +16,9 @@
from pymatviz import density_scatter
from tqdm import tqdm
-from matbench_discovery import Key, Task
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
from matbench_discovery.energy import get_e_form_per_atom
+from matbench_discovery.enums import Key, Task
__author__ = "Janosh Riebesell"
__date__ = "2023-03-01"
diff --git a/models/mace/json_to_extxyz.py b/models/mace/json_to_extxyz.py
index a583d957..a8a0091d 100644
--- a/models/mace/json_to_extxyz.py
+++ b/models/mace/json_to_extxyz.py
@@ -12,7 +12,8 @@
from pymatviz.io import TqdmDownload
from tqdm import tqdm
-from matbench_discovery import FIGSHARE_URLS, MP_DIR, Key
+from matbench_discovery import FIGSHARE_URLS, MP_DIR
+from matbench_discovery.enums import Key
__author__ = "Yuan Chiang"
__date__ = "2023-08-10"
diff --git a/models/mace/test_mace.py b/models/mace/test_mace.py
index c2762bd7..bf7051ab 100644
--- a/models/mace/test_mace.py
+++ b/models/mace/test_mace.py
@@ -18,8 +18,9 @@
from pymatgen.io.ase import AseAtomsAdaptor
from tqdm import tqdm
-from matbench_discovery import ROOT, Key, Task, timestamp, today
+from matbench_discovery import ROOT, timestamp, today
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
+from matbench_discovery.enums import Key, Task
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
@@ -138,7 +139,9 @@
)
relax_results[material_id] = {"structure": mace_struct, "energy": mace_energy}
- if record_traj and len(coords) > 0:
+ # look up each name directly: locals() inside a genexpr only sees the genexpr scope
+ coords, lattices = locals().get("coords", []), locals().get("lattices", [])
+ if record_traj and coords and lattices:
mace_traj = Trajectory(
species=structs[material_id].species,
coords=coords,
diff --git a/models/mace/train_mace.py b/models/mace/train_mace.py
index 3d051f59..0cce4e03 100644
--- a/models/mace/train_mace.py
+++ b/models/mace/train_mace.py
@@ -7,12 +7,13 @@
from typing import Any
import mace
+import mace.data
import numpy as np
import torch.distributed
import torch.nn.functional
from e3nn import o3
-from mace import data, modules, tools
-from mace.data import HDF5Dataset
+from mace import modules, tools
+from mace.data import HDF5Dataset
from mace.tools import torch_geometric
from mace.tools.scripts_utils import (
LRScheduler,
@@ -87,6 +88,7 @@ def main(**kwargs: Any) -> None:
torch.distributed.init_process_group(backend="nccl")
else:
rank = 0
+ local_rank = world_size = 1
# Setup
tools.set_seeds(args.seed)
@@ -208,11 +210,11 @@ def main(**kwargs: Any) -> None:
if args.train_file.endswith(".xyz"):
train_set = [
- data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max)
+ mace.data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max)
for config in collections.train
]
valid_set = [
- data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max)
+ mace.data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max)
for config in collections.valid
]
else:
@@ -221,7 +223,7 @@ def main(**kwargs: Any) -> None:
train_sampler, valid_sampler = None, None
if args.distributed:
- train_sampler = torch.utils.data.distributed.DistributedSampler(
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
train_set,
num_replicas=world_size,
rank=rank,
@@ -613,7 +615,9 @@ def main(**kwargs: Any) -> None:
if args.train_file.endswith(".xyz"):
for name, subset in collections.tests:
test_sets[name] = [
- data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max)
+ mace.data.AtomicData.from_config(
+ config, z_table=z_table, cutoff=args.r_max
+ )
for config in subset
]
else:
diff --git a/models/megnet/test_megnet.py b/models/megnet/test_megnet.py
index ce040ccc..9cf352f0 100644
--- a/models/megnet/test_megnet.py
+++ b/models/megnet/test_megnet.py
@@ -20,8 +20,9 @@
from sklearn.metrics import r2_score
from tqdm import tqdm
-from matbench_discovery import Key, Task, timestamp, today
+from matbench_discovery import timestamp, today
from matbench_discovery.data import DATA_FILES, df_wbm
+from matbench_discovery.enums import Key, Task
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.preds import PRED_FILES
from matbench_discovery.slurm import slurm_submit
diff --git a/models/voronoi_rf/join_voronoi_features.py b/models/voronoi_rf/join_voronoi_features.py
index 26c153aa..9c1bc7c6 100644
--- a/models/voronoi_rf/join_voronoi_features.py
+++ b/models/voronoi_rf/join_voronoi_features.py
@@ -12,7 +12,7 @@
import pandas as pd
from tqdm import tqdm
-from matbench_discovery import Key
+from matbench_discovery.enums import Key
__author__ = "Janosh Riebesell"
__date__ = "2022-08-16"
diff --git a/models/voronoi_rf/train_test_voronoi_rf.py b/models/voronoi_rf/train_test_voronoi_rf.py
index 5e598ce4..7ceea672 100644
--- a/models/voronoi_rf/train_test_voronoi_rf.py
+++ b/models/voronoi_rf/train_test_voronoi_rf.py
@@ -13,8 +13,9 @@
from sklearn.metrics import r2_score
from sklearn.pipeline import Pipeline
-from matbench_discovery import ROOT, Key, Task, today
+from matbench_discovery import ROOT, today
from matbench_discovery.data import DATA_FILES, df_wbm, glob_to_df
+from matbench_discovery.enums import Key, Task
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
diff --git a/models/voronoi_rf/voronoi_featurize_dataset.py b/models/voronoi_rf/voronoi_featurize_dataset.py
index 17be9c1f..25530dd1 100644
--- a/models/voronoi_rf/voronoi_featurize_dataset.py
+++ b/models/voronoi_rf/voronoi_featurize_dataset.py
@@ -14,8 +14,9 @@
from pymatgen.core import Structure
from tqdm import tqdm
-from matbench_discovery import ROOT, Key, today
+from matbench_discovery import ROOT, today
from matbench_discovery.data import DATA_FILES
+from matbench_discovery.enums import Key
from matbench_discovery.slurm import slurm_submit
sys.path.append(f"{ROOT}/models")
@@ -71,6 +72,8 @@
struct_dicts = [cse["structure"] for cse in df_in[Key.cse]]
elif data_name == "wbm" and input_col == Key.init_struct:
struct_dicts = df_in[Key.init_struct]
+else:
+ raise ValueError(f"Invalid {data_name=}, {input_col=} combo")
df_in[input_col] = [
Structure.from_dict(dct) for dct in tqdm(struct_dicts, disable=None)
diff --git a/models/wrenformer/analyze_wrenformer.py b/models/wrenformer/analyze_wrenformer.py
index ecd87723..307cfbe4 100644
--- a/models/wrenformer/analyze_wrenformer.py
+++ b/models/wrenformer/analyze_wrenformer.py
@@ -10,8 +10,9 @@
from pymatviz.ptable import ptable_heatmap_plotly
from pymatviz.utils import add_identity_line, bin_df_cols
-from matbench_discovery import PDF_FIGS, SITE_FIGS, Key, Model
+from matbench_discovery import PDF_FIGS, SITE_FIGS, Model
from matbench_discovery.data import DATA_FILES, df_wbm
+from matbench_discovery.enums import Key
from matbench_discovery.preds import df_each_pred, df_preds
__author__ = "Janosh Riebesell"
diff --git a/models/wrenformer/test_wrenformer.py b/models/wrenformer/test_wrenformer.py
index 6941c222..78a6f03b 100644
--- a/models/wrenformer/test_wrenformer.py
+++ b/models/wrenformer/test_wrenformer.py
@@ -17,8 +17,9 @@
from aviary.wrenformer.data import df_to_in_mem_dataloader
from aviary.wrenformer.model import Wrenformer
-from matbench_discovery import CHECKPOINT_DIR, WANDB_PATH, Key, Task, today
+from matbench_discovery import CHECKPOINT_DIR, WANDB_PATH, today
from matbench_discovery.data import df_wbm
+from matbench_discovery.enums import Key, Task
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
diff --git a/models/wrenformer/train_wrenformer.py b/models/wrenformer/train_wrenformer.py
index 4f6b60cb..c042b303 100644
--- a/models/wrenformer/train_wrenformer.py
+++ b/models/wrenformer/train_wrenformer.py
@@ -8,8 +8,9 @@
import pandas as pd
from aviary.train import df_train_test_split, train_wrenformer
-from matbench_discovery import WANDB_PATH, Key, timestamp, today
+from matbench_discovery import WANDB_PATH, timestamp, today
from matbench_discovery.data import DATA_FILES
+from matbench_discovery.enums import Key
from matbench_discovery.slurm import slurm_submit
__author__ = "Janosh Riebesell"
diff --git a/pyproject.toml b/pyproject.toml
index 6a15952c..861bff66 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,8 +78,10 @@ universal = true
[tool.ruff]
target-version = "py39"
-lint.select = ["ALL"]
-lint.ignore = [
+
+[tool.ruff.lint]
+select = ["ALL"]
+ignore = [
"ANN101",
"ANN102",
"ANN401",
@@ -115,9 +117,9 @@ lint.ignore = [
"TRY003",
"TRY301",
]
-lint.pydocstyle.convention = "google"
-lint.isort.known-third-party = ["wandb"]
-lint.isort.split-on-trailing-comma = false
+pydocstyle.convention = "google"
+isort.known-third-party = ["wandb"]
+isort.split-on-trailing-comma = false
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D", "S101"]
@@ -145,3 +147,9 @@ markers = [
"slow: deselect slow tests with -m 'not slow'",
"very_slow: select with -m 'very_slow'",
]
+
+[tool.pyright]
+typeCheckingMode = "off"
+reportPossiblyUnboundVariable = true
+reportUnboundVariable = true
+reportMissingImports = false
diff --git a/scripts/analyze_model_failure_cases.py b/scripts/analyze_model_failure_cases.py
index 15f92fa0..e9dfa0c0 100644
--- a/scripts/analyze_model_failure_cases.py
+++ b/scripts/analyze_model_failure_cases.py
@@ -16,8 +16,9 @@
from pymatviz.io import save_fig
from tqdm import tqdm
-from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, Key
+from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR
from matbench_discovery.data import DATA_FILES, df_wbm
+from matbench_discovery.enums import Key
from matbench_discovery.metrics import classify_stable
from matbench_discovery.preds import df_each_err, df_each_pred, df_metrics, df_preds
diff --git a/scripts/compute_struct_fingerprints.py b/scripts/compute_struct_fingerprints.py
index 00322869..50d097b4 100644
--- a/scripts/compute_struct_fingerprints.py
+++ b/scripts/compute_struct_fingerprints.py
@@ -15,8 +15,9 @@
from pymatgen.core import Structure
from tqdm import tqdm
-from matbench_discovery import DATA_DIR, Key, timestamp
+from matbench_discovery import DATA_DIR, timestamp
from matbench_discovery.data import DATA_FILES
+from matbench_discovery.enums import Key
from matbench_discovery.slurm import slurm_submit
__author__ = "Janosh Riebesell"
diff --git a/scripts/hist_classified_stable_vs_hull_dist.py b/scripts/hist_classified_stable_vs_hull_dist.py
index d32e1b82..20e3e2a0 100644
--- a/scripts/hist_classified_stable_vs_hull_dist.py
+++ b/scripts/hist_classified_stable_vs_hull_dist.py
@@ -10,8 +10,9 @@
from pymatviz.io import save_fig
-from matbench_discovery import PDF_FIGS, Key
+from matbench_discovery import PDF_FIGS
from matbench_discovery.data import df_wbm
+from matbench_discovery.enums import Key
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist
from matbench_discovery.preds import df_each_pred
diff --git a/scripts/hist_classified_stable_vs_hull_dist_batches.py b/scripts/hist_classified_stable_vs_hull_dist_batches.py
index 1c226353..84e4cc0c 100644
--- a/scripts/hist_classified_stable_vs_hull_dist_batches.py
+++ b/scripts/hist_classified_stable_vs_hull_dist_batches.py
@@ -12,7 +12,8 @@
import pandas as pd
from pymatviz.io import save_fig
-from matbench_discovery import PDF_FIGS, Key
+from matbench_discovery import PDF_FIGS
+from matbench_discovery.enums import Key
from matbench_discovery.plots import hist_classified_stable_vs_hull_dist
from matbench_discovery.preds import df_preds
diff --git a/scripts/model_figs/analyze_model_disagreement.py b/scripts/model_figs/analyze_model_disagreement.py
index 18f7d6f4..9faaaa74 100644
--- a/scripts/model_figs/analyze_model_disagreement.py
+++ b/scripts/model_figs/analyze_model_disagreement.py
@@ -46,7 +46,7 @@
"hydrides": r".*H\d.*",
"oxynitrides": r".*[ON]\d.*",
}
-n_structs = 200
+n_structs, fig = 200, None
for material_cls, pattern in material_classes.items():
df_subset = df_preds[df_preds[Key.formula].str.match(pattern)]
diff --git a/scripts/model_figs/compile_model_stats.py b/scripts/model_figs/compile_model_stats.py
index b2805009..f5dfffca 100644
--- a/scripts/model_figs/compile_model_stats.py
+++ b/scripts/model_figs/compile_model_stats.py
@@ -62,6 +62,7 @@
# trained from scratch. Their run times only indicate the time needed to predict the
# test set.
+time_col = "Run Time (h)"
for label, stats, raw_filters in (
("train", train_stats, train_run_filters),
("test", test_stats, test_run_filters),
@@ -98,7 +99,7 @@
n_gpu, n_cpu = metadata.get("gpu_count", 0), metadata.get("cpu_count", 0)
stats[model] = {
- (time_col := "Run Time (h)"): run_time_total / 3600,
+ time_col: run_time_total / 3600,
"GPU": n_gpu,
"CPU": n_cpu,
"Slurm Jobs": n_runs,
@@ -114,18 +115,16 @@
# %%
-for df_tmp, label in (
- (df_metrics, ""),
- (df_metrics_10k, "-10k"),
- (df_metrics_uniq_protos, "-uniq-protos"),
-):
- df_tmp = pd.concat(
- [
- df_tmp,
- pd.DataFrame(train_stats).add_prefix("Train ", axis="index"),
- pd.DataFrame(test_stats).add_prefix("Test ", axis="index"),
- ],
- ).T
+stats_dict = {
+ "": df_metrics,
+ "-10k": df_metrics_10k,
+ "-uniq-protos": df_metrics_uniq_protos,
+}
+for label, df_tmp in stats_dict.items():
+ df_train_stats = pd.DataFrame(train_stats).add_prefix("Train ", axis="index")
+ df_test_stats = pd.DataFrame(test_stats).add_prefix("Test ", axis="index")
+ df_tmp = pd.concat([df_tmp, df_train_stats, df_test_stats]).T
+
df_tmp[time_col] = df_tmp.filter(like=time_col).sum(axis="columns")
# write model metrics to json for website use
@@ -136,9 +135,11 @@
df_tmp.attrs["All Models Run Time"] = df_tmp[time_col].sum()
+ # write stats for different data subsets to JSON
df_tmp.round(2).to_json(f"{SITE_LIB}/model-stats{label}.json", orient="index")
- if label == "":
- df_stats = df_tmp
+ stats_dict[label] = df_tmp
+
+df_stats = stats_dict[""]
# %%
diff --git a/scripts/model_figs/make_hull_dist_box_plot.py b/scripts/model_figs/hull_dist_box_plot.py
similarity index 100%
rename from scripts/model_figs/make_hull_dist_box_plot.py
rename to scripts/model_figs/hull_dist_box_plot.py
diff --git a/scripts/model_figs/make_metrics_tables.py b/scripts/model_figs/metrics_tables.py
similarity index 100%
rename from scripts/model_figs/make_metrics_tables.py
rename to scripts/model_figs/metrics_tables.py
diff --git a/scripts/model_figs/parity_energy_models.py b/scripts/model_figs/parity_energy_models.py
index cbb70555..bdc3fdad 100644
--- a/scripts/model_figs/parity_energy_models.py
+++ b/scripts/model_figs/parity_energy_models.py
@@ -28,7 +28,8 @@
if which_energy == "each":
e_pred_col = Key.each_pred
e_true_col = Key.each_true
-if which_energy == "e-form":
+else:
+ assert which_energy == "e-form", f"Invalid {which_energy=}"
e_true_col = Key.e_form
e_pred_col = Key.e_form_pred
diff --git a/scripts/project_compositions.py b/scripts/project_compositions.py
index 9b3ad095..8f2be1cb 100644
--- a/scripts/project_compositions.py
+++ b/scripts/project_compositions.py
@@ -11,8 +11,9 @@
from pymatgen.core import Composition
from tqdm import tqdm
-from matbench_discovery import DATA_DIR, Key
+from matbench_discovery import DATA_DIR
from matbench_discovery.data import DATA_FILES
+from matbench_discovery.enums import Key
from matbench_discovery.slurm import slurm_submit
__author__ = "Janosh Riebesell"
diff --git a/scripts/rolling_mae_vs_hull_dist.py b/scripts/rolling_mae_vs_hull_dist.py
index 6edbcc82..bb8fe230 100644
--- a/scripts/rolling_mae_vs_hull_dist.py
+++ b/scripts/rolling_mae_vs_hull_dist.py
@@ -4,7 +4,8 @@
# %%
from pymatviz.io import save_fig
-from matbench_discovery import PDF_FIGS, SITE_FIGS, Key, Model
+from matbench_discovery import PDF_FIGS, SITE_FIGS, Model
+from matbench_discovery.enums import Key
from matbench_discovery.plots import rolling_mae_vs_hull_dist
from matbench_discovery.preds import df_each_pred, df_metrics, df_wbm
diff --git a/scripts/update_wandb_runs.py b/scripts/update_wandb_runs.py
index cec11aa2..c42f7da8 100644
--- a/scripts/update_wandb_runs.py
+++ b/scripts/update_wandb_runs.py
@@ -8,7 +8,8 @@
import wandb
from wandb.wandb_run import Run
-from matbench_discovery import WANDB_PATH, Task
+from matbench_discovery import WANDB_PATH
+from matbench_discovery.enums import Task
__author__ = "Janosh Riebesell"
__date__ = "2022-09-21"
diff --git a/scripts/upload_to_figshare.py b/scripts/upload_to_figshare.py
index 6113fbec..387128f7 100644
--- a/scripts/upload_to_figshare.py
+++ b/scripts/upload_to_figshare.py
@@ -158,11 +158,12 @@ def main(pyproject: dict[str, Any], urls_json_path: str) -> int:
with open(urls_json_path, "w") as file:
json.dump(figshare_urls, file)
except Exception as exc: # prompt to delete article if something went wrong
- answer = ""
- print(f"Encountered {exc=} for {file_path=}")
- while answer not in ("y", "n"):
+ if file_path := str(locals().get("file_path", "")):
+ exc.add_note(f"{file_path=}")
+ answer, article_id = "", int(locals().get("article_id", 0))
+ while article_id and answer.lower() not in ("y", "n"):
answer = input("Delete article? [y/n] ")
- if answer == "y":
+ if answer.lower() == "y":
make_request("DELETE", f"{BASE_URL}/account/articles/{article_id}")
return 0
diff --git a/scripts/wbm_umap_projection.py b/scripts/wbm_umap_projection.py
index 6eccb4db..aef163d3 100644
--- a/scripts/wbm_umap_projection.py
+++ b/scripts/wbm_umap_projection.py
@@ -23,8 +23,9 @@
from pymatviz.io import save_fig
from tqdm import tqdm
-from matbench_discovery import MP_DIR, PDF_FIGS, WBM_DIR, Key
+from matbench_discovery import MP_DIR, PDF_FIGS, WBM_DIR
from matbench_discovery.data import DATA_FILES
+from matbench_discovery.enums import Key
__author__ = "Philipp Benner, Janosh Riebesell"
__date__ = "2023-11-28"
diff --git a/site/tsconfig.json b/site/tsconfig.json
index 80cd5a98..ed06dd89 100644
--- a/site/tsconfig.json
+++ b/site/tsconfig.json
@@ -12,6 +12,6 @@
"forceConsistentCasingInFileNames": true,
"resolveJsonModule": true,
- "allowSyntheticDefaultImports": true,
- },
+ "allowSyntheticDefaultImports": true
+ }
}
diff --git a/tests/conftest.py b/tests/conftest.py
index 589a8a7b..73345ad1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,7 +6,7 @@
from pymatgen.core import Lattice, Structure
from pymatgen.entries.computed_entries import ComputedStructureEntry
-from matbench_discovery import Key
+from matbench_discovery.enums import Key
@pytest.fixture()
diff --git a/tests/test_data.py b/tests/test_data.py
index 7ba8cd0b..8edc421e 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -11,7 +11,7 @@
from pymatgen.core import Lattice, Structure
from pytest import CaptureFixture
-from matbench_discovery import FIGSHARE_DIR, ROOT, Key
+from matbench_discovery import FIGSHARE_DIR, ROOT
from matbench_discovery.data import (
DATA_FILES,
as_dict_handler,
@@ -20,6 +20,7 @@
glob_to_df,
load,
)
+from matbench_discovery.enums import Key
if TYPE_CHECKING:
from pathlib import Path
diff --git a/tests/test_plots.py b/tests/test_plots.py
index 96350a66..86985a35 100644
--- a/tests/test_plots.py
+++ b/tests/test_plots.py
@@ -7,7 +7,7 @@
import plotly.graph_objects as go
import pytest
-from matbench_discovery import Key
+from matbench_discovery.enums import Key
from matbench_discovery.plots import (
Backend,
cumulative_metrics,
@@ -89,14 +89,13 @@ def test_rolling_mae_vs_hull_dist(
ax = plt.figure().gca() # new figure ensures test functions use different axes
kwargs["ax"] = ax
- for model_name in models:
- ax, df_err, df_std = rolling_mae_vs_hull_dist(
- e_above_hull_true=df_wbm[model_name],
- e_above_hull_preds=df_wbm[models],
- x_lim=x_lim,
- show_dft_acc=show_dft_acc,
- **kwargs, # type: ignore[arg-type]
- )
+ ax, df_err, df_std = rolling_mae_vs_hull_dist(
+ e_above_hull_true=df_wbm[models[0]],
+ e_above_hull_preds=df_wbm[models],
+ x_lim=x_lim,
+ show_dft_acc=show_dft_acc,
+ **kwargs, # type: ignore[arg-type]
+ )
assert isinstance(df_err, pd.DataFrame)
assert isinstance(df_std, pd.DataFrame)
diff --git a/tests/test_preds.py b/tests/test_preds.py
index e438e32f..49f73708 100644
--- a/tests/test_preds.py
+++ b/tests/test_preds.py
@@ -2,8 +2,8 @@
import pytest
-from matbench_discovery import Key
from matbench_discovery.data import df_wbm
+from matbench_discovery.enums import Key
from matbench_discovery.preds import (
PRED_FILES,
df_each_err,