From c9dc8ee6623ed121bc671fa1bffb9219ae8ba532 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 1 Mar 2024 14:22:50 +0100 Subject: [PATCH] Add `pyright` pre-commit hook and fix possibly unbound variables (#93) * add pyright pre-commit hook * mv scripts/model_figs/(make_metrics_tables->metrics_tables).py mv scripts/model_figs/(make_hull_dist_box_plot->hull_dist_box_plot).py * fix most pyright PossiblyUnboundVariable * remove __init__.py convenience re-exports of enums * LabelEnum add new to_dict methods val_desc_dict + label_desc_dict renamed key_val_dict, val_label_dict --- .github/workflows/test-scripts.yml | 2 +- .pre-commit-config.yaml | 8 ++++- data/mp/build_phase_diagram.py | 3 +- data/mp/eda_mp_trj.py | 24 +++++++++----- data/mp/get_mp_energies.py | 3 +- data/mp/get_mp_traj.py | 3 +- .../compare_cse_vs_ce_mp_2020_corrections.py | 3 +- data/wbm/compile_wbm_test_set.py | 14 +++++---- matbench_discovery/__init__.py | 18 +++-------- matbench_discovery/data.py | 3 +- matbench_discovery/enums.py | 20 +++++++++--- matbench_discovery/plots.py | 17 +++++----- matbench_discovery/preds.py | 8 ++--- models/alignn/test_alignn.py | 3 +- models/alignn/train_alignn.py | 3 +- models/alignn_ff/alignn_ff_relax.py | 3 +- models/alignn_ff/test_alignn_ff.py | 3 +- models/bowsr/join_bowsr_results.py | 4 +-- models/bowsr/test_bowsr.py | 10 +++--- models/cgcnn/test_cgcnn.py | 3 +- models/cgcnn/train_cgcnn.py | 3 +- models/chgnet/analyze_chgnet.py | 3 +- models/chgnet/ctk_structure_viewer.py | 2 +- models/chgnet/ctk_trajectory_viewer.py | 2 +- models/chgnet/join_chgnet_results.py | 7 ++--- models/chgnet/test_chgnet.py | 3 +- models/gnome/test_gnome.py | 2 +- models/m3gnet/join_m3gnet_results.py | 5 ++- .../m3gnet/pre_vs_post_m3gnet_relaxation.py | 3 +- models/m3gnet/test_m3gnet.py | 3 +- models/mace/analyze_mace.py | 2 +- models/mace/join_mace_results.py | 2 +- models/mace/json_to_extxyz.py | 3 +- models/mace/test_mace.py | 7 +++-- models/mace/train_mace.py | 16 ++++++---- models/megnet/test_megnet.py | 3 +- models/voronoi_rf/join_voronoi_features.py | 2 +- models/voronoi_rf/train_test_voronoi_rf.py | 3 +- .../voronoi_rf/voronoi_featurize_dataset.py | 5 ++- models/wrenformer/analyze_wrenformer.py | 3 +- models/wrenformer/test_wrenformer.py | 3 +- models/wrenformer/train_wrenformer.py | 3 +- pyproject.toml | 18 ++++++++--- scripts/analyze_model_failure_cases.py | 3 +- scripts/compute_struct_fingerprints.py | 3 +- .../hist_classified_stable_vs_hull_dist.py | 3 +- ..._classified_stable_vs_hull_dist_batches.py | 3 +- .../model_figs/analyze_model_disagreement.py | 2 +- scripts/model_figs/compile_model_stats.py | 31 ++++++++++--------- ...dist_box_plot.py => hull_dist_box_plot.py} | 0 ...ke_metrics_tables.py => metrics_tables.py} | 0 scripts/model_figs/parity_energy_models.py | 3 +- scripts/project_compositions.py | 3 +- scripts/rolling_mae_vs_hull_dist.py | 3 +- scripts/update_wandb_runs.py | 3 +- scripts/upload_to_figshare.py | 9 +++--- scripts/wbm_umap_projection.py | 3 +- site/tsconfig.json | 4 +-- tests/conftest.py | 2 +- tests/test_data.py | 3 +- tests/test_plots.py | 17 +++++----- tests/test_preds.py | 2 +- 62 files changed, 210 insertions(+), 145 deletions(-) rename scripts/model_figs/{make_hull_dist_box_plot.py => hull_dist_box_plot.py} (100%) rename scripts/model_figs/{make_metrics_tables.py => metrics_tables.py} (100%) diff --git a/.github/workflows/test-scripts.yml b/.github/workflows/test-scripts.yml index 8b2c9c32..40f91255 100644 --- a/.github/workflows/test-scripts.yml +++ b/.github/workflows/test-scripts.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: script: - - scripts/model_figs/make_metrics_tables.py + - scripts/model_figs/metrics_tables.py - scripts/model_figs/rolling_mae_vs_hull_dist_models.py steps: - name: Check out repository diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b43212bd..e1864521 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -56,7 +56,7 @@ repos: exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$ - repo: https://github.com/pre-commit/mirrors-eslint - rev: v9.0.0-beta.0 + rev: v9.0.0-beta.1 hooks: - id: eslint types: [file] @@ -78,3 +78,9 @@ repos: files: ^models/(.+)/\1.*\.yml$ args: [--schemafile, tests/model-schema.yml] - id: check-github-actions + + - repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.351 + hooks: + - id: pyright + args: [--level, error] diff --git a/data/mp/build_phase_diagram.py b/data/mp/build_phase_diagram.py index 4b63265f..fdc775f8 100644 --- a/data/mp/build_phase_diagram.py +++ b/data/mp/build_phase_diagram.py @@ -17,9 +17,10 @@ from pymatgen.ext.matproj import MPRester from tqdm import tqdm -from matbench_discovery import MP_DIR, ROOT, Key, today +from matbench_discovery import MP_DIR, ROOT, today from matbench_discovery.data import DATA_FILES from matbench_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries +from matbench_discovery.enums import Key module_dir = os.path.dirname(__file__) diff --git a/data/mp/eda_mp_trj.py b/data/mp/eda_mp_trj.py index 309098a6..52dcea39 100644 --- a/data/mp/eda_mp_trj.py +++ b/data/mp/eda_mp_trj.py @@ -21,8 +21,9 @@ from pymatviz.utils import si_fmt from tqdm import tqdm -from matbench_discovery import MP_DIR, PDF_FIGS, ROOT, SITE_FIGS, Key +from matbench_discovery import MP_DIR, PDF_FIGS, ROOT, SITE_FIGS from matbench_discovery.data import DATA_FILES, df_wbm +from matbench_discovery.enums import Key __author__ = "Janosh Riebesell" __date__ = "2023-11-22" @@ -108,10 +109,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]: # %% plot per-element magmom histograms ptable_magmom_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-magmoms.json.bz2" +srs_mp_trj_elem_magmoms = locals().get("srs_mp_trj_elem_magmoms") if os.path.isfile(ptable_magmom_hist_path): srs_mp_trj_elem_magmoms = pd.read_json(ptable_magmom_hist_path, typ="series") -elif "srs_mp_trj_elem_magmoms" not in locals(): +if srs_mp_trj_elem_magmoms is None: # project magmoms onto symbols in dict df_mp_trj_elem_magmom = pd.DataFrame( [ @@ -151,10 +153,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]: # %% plot per-element force histograms ptable_force_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-forces.json.bz2" +srs_mp_trj_elem_forces = locals().get("srs_mp_trj_elem_forces") if os.path.isfile(ptable_force_hist_path): srs_mp_trj_elem_forces = pd.read_json(ptable_force_hist_path, typ="series") -elif "srs_mp_trj_elem_forces" not in locals(): +if srs_mp_trj_elem_forces is None: df_mp_trj_elem_forces = pd.DataFrame( [ dict(zip(elems, np.abs(forces).mean(axis=1))) @@ -193,10 +196,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]: # %% plot histogram of number of sites per element ptable_n_sites_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-n-sites.json.bz2" +srs_mp_trj_elem_n_sites = locals().get("srs_mp_trj_elem_n_sites") if os.path.isfile(ptable_n_sites_hist_path): srs_mp_trj_elem_n_sites = pd.read_json(ptable_n_sites_hist_path, typ="series") -elif "mp_trj_elem_n_sites" not in locals(): +elif srs_mp_trj_elem_n_sites is None: # construct a series of lists of site numbers per element (i.e. how often each # element appears in a structure with n sites) # create all df cols as int dtype @@ -320,8 +324,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]: pdf_kwds = dict(width=500, height=300) x_col, y_col = "Eform (eV/atom)", count_col +df_e_form = locals().get("df_e_form") -if "df_e_form" not in locals(): # only compute once for speed +if df_e_form is None: # only compute once for speed e_form_hist = np.histogram(df_mp_trj[Key.e_form], bins=300) df_e_form = pd.DataFrame(e_form_hist, index=[y_col, x_col]).T.round(3) @@ -340,8 +345,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]: # %% plot forces distribution # use numpy to pre-compute histogram x_col, y_col = "|Forces| (eV/Å)", count_col +df_forces = locals().get("df_forces") -if "df_forces" not in locals(): # only compute once for speed +if df_forces is None: # only compute once for speed forces_hist = np.histogram( df_mp_trj[Key.forces].explode().explode().abs(), bins=300 ) @@ -361,8 +367,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]: # %% plot hydrostatic stress distribution x_col, y_col = "1/3 Tr(σ) (eV/ų)", count_col # noqa: RUF001 +df_stresses = locals().get("df_stresses") -if "df_stresses" not in locals(): # only compute once for speed +if df_stresses is None: # only compute once for speed stresses_hist = np.histogram(df_mp_trj[Key.stress_trace], bins=300) df_stresses = pd.DataFrame(stresses_hist, index=[y_col, x_col]).T.round(3) @@ -381,8 +388,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]: # %% plot magmoms distribution x_col, y_col = "Magmoms (μB)", count_col +df_magmoms = locals().get("df_magmoms") -if "df_magmoms" not in locals(): # only compute once for speed +if df_magmoms is None: # only compute once for speed magmoms_hist = np.histogram(df_mp_trj[Key.magmoms].dropna().explode(), bins=300) df_magmoms = pd.DataFrame(magmoms_hist, index=[y_col, x_col]).T.round(3) diff --git a/data/mp/get_mp_energies.py b/data/mp/get_mp_energies.py index 17bbab32..b87334e6 100644 --- a/data/mp/get_mp_energies.py +++ b/data/mp/get_mp_energies.py @@ -8,8 +8,9 @@ from pymatviz.utils import annotate_metrics from tqdm import tqdm -from matbench_discovery import STABILITY_THRESHOLD, Key, today +from matbench_discovery import STABILITY_THRESHOLD, today from matbench_discovery.data import DATA_FILES +from matbench_discovery.enums import Key """ Download all MP formation and above hull energies on 2023-01-10. diff --git a/data/mp/get_mp_traj.py b/data/mp/get_mp_traj.py index 9dab8cb0..1bcab981 100644 --- a/data/mp/get_mp_traj.py +++ b/data/mp/get_mp_traj.py @@ -17,7 +17,8 @@ from pymongo.database import Database from tqdm import tqdm, trange -from matbench_discovery import ROOT, Key, today +from matbench_discovery import ROOT, today +from matbench_discovery.enums import Key __author__ = "Janosh Riebesell" __date__ = "2023-03-15" diff --git a/data/wbm/compare_cse_vs_ce_mp_2020_corrections.py b/data/wbm/compare_cse_vs_ce_mp_2020_corrections.py index 564cdfa2..e7b31c62 100644 --- a/data/wbm/compare_cse_vs_ce_mp_2020_corrections.py +++ b/data/wbm/compare_cse_vs_ce_mp_2020_corrections.py @@ -10,9 +10,10 @@ from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry from tqdm import tqdm -from matbench_discovery import ROOT, Key, today +from matbench_discovery import ROOT, today from matbench_discovery.data import DATA_FILES, df_wbm from matbench_discovery.energy import get_e_form_per_atom +from matbench_discovery.enums import Key from matbench_discovery.plots import plt """ diff --git a/data/wbm/compile_wbm_test_set.py b/data/wbm/compile_wbm_test_set.py index c4101408..7ac06fe1 100644 --- a/data/wbm/compile_wbm_test_set.py +++ b/data/wbm/compile_wbm_test_set.py @@ -19,17 +19,19 @@ from pymatviz.io import save_fig from tqdm import tqdm -from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, Key, today +from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, today from matbench_discovery.data import DATA_FILES from matbench_discovery.energy import get_e_form_per_atom +from matbench_discovery.enums import Key try: import gdown -except ImportError: - print( +except ImportError as exc: + exc.add_note( "gdown not installed. Needed for downloading WBM initial + relaxed structures " "from Google Drive." ) + raise """ Dataset generated with DFT and published in Jan 2021 as @@ -90,8 +92,8 @@ 18198704957443186264, ) -if "dfs_wbm_structs" not in locals(): - dfs_wbm_structs = {} +dfs_wbm_structs = locals().get("dfs_wbm_structs", {}) + for json_path in json_paths: step = int(json_path.split(".json.bz2")[0][-1]) assert step in range(1, 6) @@ -179,8 +181,8 @@ def increment_wbm_material_id(wbm_id: str) -> str: print(f"{file_path} already exists, skipping") continue + url = f"{mat_cloud_url}&filename={filename}" try: - url = f"{mat_cloud_url}&filename={filename}" urllib.request.urlretrieve(url, file_path) except urllib.error.HTTPError as exc: print(f"failed to download {url=}: {exc}") diff --git a/matbench_discovery/__init__.py b/matbench_discovery/__init__.py index 41bb796a..3dacffaa 100644 --- a/matbench_discovery/__init__.py +++ b/matbench_discovery/__init__.py @@ -13,15 +13,7 @@ import plotly.io as pio import pymatviz # noqa: F401 -from matbench_discovery.enums import ( # noqa: F401 - Key, - Model, - ModelType, - Open, - Quantity, - Targets, - Task, -) +from matbench_discovery.enums import Model, Quantity PKG_NAME = "matbench-discovery" direct_url = Distribution.from_name(PKG_NAME).read_text("direct_url.json") or "{}" @@ -69,7 +61,7 @@ FIGSHARE_URLS = json.load(file) # --- start global plot settings -px.defaults.labels = Quantity.val_dict() | Model.val_dict() +px.defaults.labels = Quantity.key_val_dict() | Model.key_val_dict() global_layout = dict( paper_bgcolor="rgba(0,0,0,0)", @@ -77,9 +69,9 @@ # increase legend marker size and make background transparent legend=dict(itemsizing="constant", bgcolor="rgba(0, 0, 0, 0)"), ) -pio.templates["global"] = dict(layout=global_layout) -pio.templates.default = "pymatviz_dark+global" -px.defaults.template = "pymatviz_dark+global" +pio.templates["mbd_global"] = dict(layout=global_layout) +pio.templates.default = "pymatviz_dark+mbd_global" +px.defaults.template = "pymatviz_dark+mbd_global" # https://github.com/plotly/Kaleido/issues/122#issuecomment-994906924 # when seeing MathJax "loading" message in exported PDFs, diff --git a/matbench_discovery/data.py b/matbench_discovery/data.py index 05e381f7..b0315825 100644 --- a/matbench_discovery/data.py +++ b/matbench_discovery/data.py @@ -14,7 +14,8 @@ from monty.json import MontyDecoder from tqdm import tqdm -from matbench_discovery import FIGSHARE_DIR, Key +from matbench_discovery import FIGSHARE_DIR +from matbench_discovery.enums import Key if TYPE_CHECKING: from pathlib import Path diff --git a/matbench_discovery/enums.py b/matbench_discovery/enums.py index ea8e42a3..82615f3c 100644 --- a/matbench_discovery/enums.py +++ b/matbench_discovery/enums.py @@ -29,14 +29,24 @@ def description(self) -> str: return self.__dict__["desc"] @classmethod - def val_dict(cls) -> dict[str, str]: - """Return the Enum as dictionary.""" + def key_val_dict(cls) -> dict[str, str]: + """Map of keys to values.""" return {key: str(val) for key, val in cls.__members__.items()} @classmethod - def label_dict(cls) -> dict[str, str]: - """Return the Enum as dictionary.""" - return {str(val): val.label for key, val in cls.__members__.items()} + def val_label_dict(cls) -> dict[str, str | None]: + """Map of values to labels.""" + return {str(val): val.label for val in cls.__members__.values()} + + @classmethod + def val_desc_dict(cls) -> dict[str, str | None]: + """Map of values to descriptions.""" + return {str(val): val.description for val in cls.__members__.values()} + + @classmethod + def label_desc_dict(cls) -> dict[str | None, str | None]: + """Map of labels to descriptions.""" + return {str(val.label): val.description for val in cls.__members__.values()} @unique diff --git a/matbench_discovery/plots.py b/matbench_discovery/plots.py index 9bc1fe5a..feba1303 100644 --- a/matbench_discovery/plots.py +++ b/matbench_discovery/plots.py @@ -103,8 +103,10 @@ def hist_classified_stable_vs_hull_dist( fixed in Inkscape or similar by merging regions by color. """ x_col = dict(true=each_true_col, pred=each_pred_col)[which_energy] + clf_col, value_name = "classified", "count" df_plot = pd.DataFrame() + for facet, df_group in ( df.groupby(kwargs["facet_col"]) if "facet_col" in kwargs else [(None, df)] ): @@ -113,16 +115,16 @@ def hist_classified_stable_vs_hull_dist( ) # switch between hist of DFT-computed and model-predicted convex hull distance - e_above_hull = df_group[x_col] - each_true_pos = e_above_hull[true_pos] - each_true_neg = e_above_hull[true_neg] - each_false_neg = e_above_hull[false_neg] - each_false_pos = e_above_hull[false_pos] + srs_each = df_group[x_col] + each_true_pos = srs_each[true_pos] + each_true_neg = srs_each[true_neg] + each_false_neg = srs_each[false_neg] + each_false_pos = srs_each[false_pos] # n_true_pos, n_false_pos, n_true_neg, n_false_neg = map( # sum, (true_pos, false_pos, true_neg, false_neg) # ) - df_group[(clf_col := "classified")] = np.array(clf_labels)[ + df_group[clf_col] = np.array(clf_labels)[ true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3 ] @@ -144,7 +146,6 @@ def hist_classified_stable_vs_hull_dist( index=clf_labels, ).T df_hist[x_col] = bin_edges[:-1] - value_name = "count" df_melt = df_hist.melt( id_vars=x_col, value_vars=clf_labels, @@ -714,6 +715,7 @@ def cumulative_metrics( ) df = dfs[metric] ax.set(ylim=(0, 1), xlim=(0, None), ylabel=metric) + bbox = dict(facecolor="white", alpha=0.5, edgecolor="none") for model in df_preds: # TODO is this really necessary? if len(df[model].dropna()) == 0: @@ -722,7 +724,6 @@ def cumulative_metrics( y_end = df[model].dropna().iloc[-1] # add some visual guidelines to the plot intersect_kwargs = dict(linestyle=":", alpha=0.4, linewidth=2) - bbox = dict(facecolor="white", alpha=0.5, edgecolor="none") # place model name at the end of every line ax.text(x_end, y_end, model, va="bottom", rotation=30, bbox=bbox) if "x" in project_end_point: diff --git a/matbench_discovery/preds.py b/matbench_discovery/preds.py index 4cf52089..5570852e 100644 --- a/matbench_discovery/preds.py +++ b/matbench_discovery/preds.py @@ -5,8 +5,9 @@ import pandas as pd from tqdm import tqdm -from matbench_discovery import ROOT, STABILITY_THRESHOLD, Key, Model +from matbench_discovery import ROOT, STABILITY_THRESHOLD, Model from matbench_discovery.data import Files, df_wbm, glob_to_df +from matbench_discovery.enums import Key from matbench_discovery.metrics import stable_metrics from matbench_discovery.plots import plotly_colors, plotly_line_styles, plotly_markers @@ -65,7 +66,7 @@ class PredFiles(Files): # key_map maps model keys to pretty labels -PRED_FILES = PredFiles(root=f"{ROOT}/models", key_map=Model.val_dict()) +PRED_FILES = PredFiles(root=f"{ROOT}/models", key_map=Model.key_val_dict()) def load_df_wbm_with_preds( @@ -101,7 +102,6 @@ def load_df_wbm_with_preds( ) dfs: dict[str, pd.DataFrame] = {} - try: for model_name in (bar := tqdm(models, disable=not pbar, desc="Loading preds")): bar.set_postfix_str(model_name) @@ -109,7 +109,7 @@ def load_df_wbm_with_preds( df = df.set_index(id_col) dfs[model_name] = df except Exception as exc: - raise RuntimeError(f"Failed to load {model_name=}") from exc + raise RuntimeError(f"Failed to load {locals().get('model_name')=}") from exc from matbench_discovery.data import df_wbm diff --git a/models/alignn/test_alignn.py b/models/alignn/test_alignn.py index cb29bf90..250b0802 100644 --- a/models/alignn/test_alignn.py +++ b/models/alignn/test_alignn.py @@ -17,8 +17,9 @@ from sklearn.metrics import r2_score from tqdm import tqdm -from matbench_discovery import Key, Task, today +from matbench_discovery import today from matbench_discovery.data import DATA_FILES, df_wbm +from matbench_discovery.enums import Key, Task from matbench_discovery.plots import wandb_scatter from matbench_discovery.slurm import slurm_submit diff --git a/models/alignn/train_alignn.py b/models/alignn/train_alignn.py index 5c5c1d81..cbc34286 100644 --- a/models/alignn/train_alignn.py +++ b/models/alignn/train_alignn.py @@ -18,8 +18,9 @@ from torch.utils.data import DataLoader from tqdm import tqdm -from matbench_discovery import Key, today +from matbench_discovery import today from matbench_discovery.data import DATA_FILES +from matbench_discovery.enums import Key from matbench_discovery.slurm import slurm_submit __author__ = "Philipp Benner, Janosh Riebesell" diff --git a/models/alignn_ff/alignn_ff_relax.py b/models/alignn_ff/alignn_ff_relax.py index 4cb8f8c6..1fc6f125 100644 --- a/models/alignn_ff/alignn_ff_relax.py +++ b/models/alignn_ff/alignn_ff_relax.py @@ -9,8 +9,9 @@ from pymatgen.io.jarvis import JarvisAtomsAdaptor from tqdm import tqdm -from matbench_discovery import Key, Task, today +from matbench_discovery import today from matbench_discovery.data import DATA_FILES, df_wbm +from matbench_discovery.enums import Key, Task __author__ = "Janosh Riebesell, Philipp Benner" __date__ = "2023-07-11" diff --git a/models/alignn_ff/test_alignn_ff.py b/models/alignn_ff/test_alignn_ff.py index 467d47b6..c1c9cfb0 100644 --- a/models/alignn_ff/test_alignn_ff.py +++ b/models/alignn_ff/test_alignn_ff.py @@ -18,8 +18,9 @@ from sklearn.metrics import r2_score from tqdm import tqdm -from matbench_discovery import Key, Task, today +from matbench_discovery import today from matbench_discovery.data import DATA_FILES, df_wbm +from matbench_discovery.enums import Key, Task from matbench_discovery.plots import wandb_scatter __author__ = "Philipp Benner, Janosh Riebesell" diff --git a/models/bowsr/join_bowsr_results.py b/models/bowsr/join_bowsr_results.py index e6619e69..08d3df60 100644 --- a/models/bowsr/join_bowsr_results.py +++ b/models/bowsr/join_bowsr_results.py @@ -8,8 +8,8 @@ import pymatviz from tqdm import tqdm -from matbench_discovery import Model, Task -from matbench_discovery.data import DATA_FILES, Key +from matbench_discovery.data import DATA_FILES +from matbench_discovery.enums import Key, Model, Task __author__ = "Janosh Riebesell" __date__ = "2022-09-22" diff --git a/models/bowsr/test_bowsr.py b/models/bowsr/test_bowsr.py index 4769d213..613b38aa 100644 --- a/models/bowsr/test_bowsr.py +++ b/models/bowsr/test_bowsr.py @@ -14,8 +14,9 @@ from pymatgen.core import Structure from tqdm import tqdm -from matbench_discovery import Key, Model, Task, timestamp, today +from matbench_discovery import Model, timestamp, today from matbench_discovery.data import DATA_FILES, as_dict_handler +from matbench_discovery.enums import Key, Task from matbench_discovery.slurm import slurm_submit __author__ = "Janosh Riebesell" @@ -127,11 +128,7 @@ with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull): optimizer.optimize(**optimize_kwargs) - try: - struct_bowsr, energy_bowsr = optimizer.get_optimized_structure_and_energy() - except Exception as exc: - print(f"Failed to relax {material_id}: {exc!r}") - + struct_bowsr, energy_bowsr = optimizer.get_optimized_structure_and_energy() results = { f"e_form_per_atom_bowsr_{energy_model}": model.predict_energy(struct_bowsr), "structure_bowsr": struct_bowsr, @@ -139,6 +136,7 @@ } relax_results[material_id] = results + except Exception as exc: print(f"{material_id=} raised {exc=}") diff --git a/models/cgcnn/test_cgcnn.py b/models/cgcnn/test_cgcnn.py index 7203a7fb..ab35c87d 100644 --- a/models/cgcnn/test_cgcnn.py +++ b/models/cgcnn/test_cgcnn.py @@ -14,8 +14,9 @@ from torch.utils.data import DataLoader from tqdm import tqdm -from matbench_discovery import CHECKPOINT_DIR, WANDB_PATH, WBM_DIR, Key, Task, today +from matbench_discovery import CHECKPOINT_DIR, WANDB_PATH, WBM_DIR, today from matbench_discovery.data import DATA_FILES, df_wbm +from matbench_discovery.enums import Key, Task from matbench_discovery.plots import wandb_scatter from matbench_discovery.slurm import slurm_submit diff --git a/models/cgcnn/train_cgcnn.py b/models/cgcnn/train_cgcnn.py index 974b420e..6ed885b3 100644 --- a/models/cgcnn/train_cgcnn.py +++ b/models/cgcnn/train_cgcnn.py @@ -11,8 +11,9 @@ from torch.utils.data import DataLoader from tqdm import tqdm, trange -from matbench_discovery import WANDB_PATH, Key, timestamp, today +from matbench_discovery import WANDB_PATH, timestamp, today from matbench_discovery.data import DATA_FILES +from matbench_discovery.enums import Key from matbench_discovery.slurm import slurm_submit from matbench_discovery.structure import perturb_structure diff --git a/models/chgnet/analyze_chgnet.py b/models/chgnet/analyze_chgnet.py index 0dde7109..a3a90b1d 100644 --- a/models/chgnet/analyze_chgnet.py +++ b/models/chgnet/analyze_chgnet.py @@ -10,9 +10,10 @@ from pymatviz import density_scatter, plot_structure_2d, ptable_heatmap_plotly from pymatviz.io import save_fig -from matbench_discovery import PDF_FIGS, Key +from matbench_discovery import PDF_FIGS from matbench_discovery import plots as plots from matbench_discovery.data import DATA_FILES, df_wbm +from matbench_discovery.enums import Key from matbench_discovery.preds import PRED_FILES, df_preds __author__ = "Janosh Riebesell" diff --git a/models/chgnet/ctk_structure_viewer.py b/models/chgnet/ctk_structure_viewer.py index 2d40a402..52fc3ead 100644 --- a/models/chgnet/ctk_structure_viewer.py +++ b/models/chgnet/ctk_structure_viewer.py @@ -3,7 +3,7 @@ import pandas as pd from crystal_toolkit.helpers.utils import hook_up_fig_with_struct_viewer -from matbench_discovery import Key +from matbench_discovery.enums import Key from matbench_discovery.preds import PRED_FILES __author__ = "Janosh Riebesell" diff --git a/models/chgnet/ctk_trajectory_viewer.py b/models/chgnet/ctk_trajectory_viewer.py index c53a3276..83d2004b 100644 --- a/models/chgnet/ctk_trajectory_viewer.py +++ b/models/chgnet/ctk_trajectory_viewer.py @@ -16,8 +16,8 @@ from m3gnet.models import Relaxer as M3gnetRelaxer from pymatgen.core import Lattice, Structure -from matbench_discovery import Key from matbench_discovery.data import df_wbm +from matbench_discovery.enums import Key if TYPE_CHECKING: from chgnet.model.dynamics import TrajectoryObserver diff --git a/models/chgnet/join_chgnet_results.py b/models/chgnet/join_chgnet_results.py index 7fd971f2..5ee7495f 100644 --- a/models/chgnet/join_chgnet_results.py +++ b/models/chgnet/join_chgnet_results.py @@ -13,18 +13,17 @@ from pymatviz import density_scatter from tqdm import tqdm -from matbench_discovery import Key, Task from matbench_discovery.data import as_dict_handler from matbench_discovery.energy import get_e_form_per_atom +from matbench_discovery.enums import Key, Task from matbench_discovery.preds import df_preds __author__ = "Janosh Riebesell" __date__ = "2023-03-01" -df_chgnet = pd.read_csv("2023-12-05-chgnet-0.3.0-wbm-IS2RE-static.csv.gz").set_index( - Key.mat_id -) +csv_path = "2023-12-05-chgnet-0.3.0-wbm-IS2RE-static.csv.gz" +df_chgnet = pd.read_csv(csv_path).set_index(Key.mat_id) # %% diff --git a/models/chgnet/test_chgnet.py b/models/chgnet/test_chgnet.py index 1a0193eb..f53e9745 100644 --- a/models/chgnet/test_chgnet.py +++ b/models/chgnet/test_chgnet.py @@ -21,8 +21,9 @@ from pymatgen.core import Structure from tqdm import tqdm -from matbench_discovery import Key, Task, timestamp, today +from matbench_discovery import timestamp, today from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm +from matbench_discovery.enums import Key, Task from matbench_discovery.plots import wandb_scatter from matbench_discovery.slurm import slurm_submit diff --git a/models/gnome/test_gnome.py b/models/gnome/test_gnome.py index 1f12d54b..f615a005 100644 --- a/models/gnome/test_gnome.py +++ b/models/gnome/test_gnome.py @@ -1,7 +1,7 @@ # %% from pymatviz import density_scatter -from matbench_discovery import Key, Model +from matbench_discovery.enums import Key, Model from matbench_discovery.preds import df_preds __author__ = "Janosh Riebesell" diff --git a/models/m3gnet/join_m3gnet_results.py b/models/m3gnet/join_m3gnet_results.py index 68614fa6..608770e7 100644 --- a/models/m3gnet/join_m3gnet_results.py +++ b/models/m3gnet/join_m3gnet_results.py @@ -16,9 +16,9 @@ from pymatgen.entries.computed_entries import ComputedStructureEntry from tqdm import tqdm -from matbench_discovery import Key, Task from matbench_discovery.data import DATA_FILES, as_dict_handler from matbench_discovery.energy import get_e_form_per_atom +from matbench_discovery.enums import Key, Task __author__ = "Janosh Riebesell" __date__ = "2022-08-16" @@ -34,8 +34,7 @@ print(f"Found {len(file_paths):,} files for {glob_pattern = }") # prevent accidental overwrites -if "dfs" not in locals(): - dfs: dict[str, pd.DataFrame] = {} +dfs: dict[str, pd.DataFrame] = locals().get("dfs", {}) # %% diff --git a/models/m3gnet/pre_vs_post_m3gnet_relaxation.py b/models/m3gnet/pre_vs_post_m3gnet_relaxation.py index 8da7d4dd..3a611a1a 100644 --- a/models/m3gnet/pre_vs_post_m3gnet_relaxation.py +++ b/models/m3gnet/pre_vs_post_m3gnet_relaxation.py @@ -12,8 +12,9 @@ from pymatviz.utils import add_identity_line from sklearn.metrics import r2_score -from matbench_discovery import ROOT, SITE_FIGS, Key, plots +from matbench_discovery import ROOT, SITE_FIGS, plots from matbench_discovery.data import DATA_FILES +from matbench_discovery.enums import Key __author__ = "Janosh Riebesell" __date__ = "2022-06-18" diff --git a/models/m3gnet/test_m3gnet.py b/models/m3gnet/test_m3gnet.py index e39185d8..59fa5f6c 100644 --- a/models/m3gnet/test_m3gnet.py +++ b/models/m3gnet/test_m3gnet.py @@ -20,8 +20,9 @@ from pymatgen.core import Structure from tqdm import tqdm -from matbench_discovery import ROOT, Key, Task, timestamp, today +from matbench_discovery import ROOT, timestamp, today from matbench_discovery.data import DATA_FILES, as_dict_handler +from matbench_discovery.enums import Key, Task from matbench_discovery.slurm import slurm_submit __author__ = "Janosh Riebesell" diff --git a/models/mace/analyze_mace.py b/models/mace/analyze_mace.py index 6fae2ca7..408b4e06 100644 --- a/models/mace/analyze_mace.py +++ b/models/mace/analyze_mace.py @@ -8,9 +8,9 @@ from pymatviz import density_scatter, ptable_heatmap_plotly, spacegroup_sunburst from pymatviz.io import save_fig -from matbench_discovery import Key from matbench_discovery import plots as plots from matbench_discovery.data import df_wbm +from matbench_discovery.enums import Key from matbench_discovery.preds import PRED_FILES __author__ = "Janosh Riebesell" diff --git a/models/mace/join_mace_results.py b/models/mace/join_mace_results.py index d0fafa1e..ec37c5e1 100644 --- a/models/mace/join_mace_results.py +++ b/models/mace/join_mace_results.py @@ -16,9 +16,9 @@ from pymatviz import density_scatter from tqdm import tqdm -from matbench_discovery import Key, Task from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm from matbench_discovery.energy import get_e_form_per_atom +from matbench_discovery.enums import Key, Task __author__ = "Janosh Riebesell" __date__ = "2023-03-01" diff --git a/models/mace/json_to_extxyz.py b/models/mace/json_to_extxyz.py index a583d957..a8a0091d 100644 --- a/models/mace/json_to_extxyz.py +++ b/models/mace/json_to_extxyz.py @@ -12,7 +12,8 @@ from pymatviz.io import TqdmDownload from tqdm import tqdm -from matbench_discovery import FIGSHARE_URLS, MP_DIR, Key +from matbench_discovery import FIGSHARE_URLS, MP_DIR +from matbench_discovery.enums import Key __author__ = "Yuan Chiang" __date__ = "2023-08-10" diff --git a/models/mace/test_mace.py b/models/mace/test_mace.py index c2762bd7..bf7051ab 100644 --- a/models/mace/test_mace.py +++ b/models/mace/test_mace.py @@ -18,8 +18,9 @@ from pymatgen.io.ase import AseAtomsAdaptor from tqdm import tqdm -from matbench_discovery import ROOT, Key, Task, timestamp, today +from matbench_discovery import ROOT, timestamp, today from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm +from matbench_discovery.enums import Key, Task from matbench_discovery.plots import wandb_scatter from matbench_discovery.slurm import slurm_submit @@ -138,7 +139,9 @@ ) relax_results[material_id] = {"structure": mace_struct, "energy": mace_energy} - if record_traj and len(coords) > 0: + + coords, lattices = (locals().get(key, []) for key in ("coords", "lattices")) + if record_traj and coords and lattices: mace_traj = Trajectory( species=structs[material_id].species, coords=coords, diff --git a/models/mace/train_mace.py b/models/mace/train_mace.py index 3d051f59..0cce4e03 100644 --- a/models/mace/train_mace.py +++ b/models/mace/train_mace.py @@ -7,12 +7,13 @@ from typing import Any import mace +import mace.data import numpy as np import torch.distributed import torch.nn.functional from e3nn import o3 -from mace import data, modules, tools -from mace.data import HDF5Dataset +from mace import modules, tools +from mace.mace.data import HDF5Dataset from mace.tools import torch_geometric from mace.tools.scripts_utils import ( LRScheduler, @@ -87,6 +88,7 @@ def main(**kwargs: Any) -> None: torch.distributed.init_process_group(backend="nccl") else: rank = 0 + local_rank = world_size = 1 # Setup tools.set_seeds(args.seed) @@ -208,11 +210,11 @@ def main(**kwargs: Any) -> None: if args.train_file.endswith(".xyz"): train_set = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + mace.data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) for config in collections.train ] valid_set = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + mace.data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) for config in collections.valid ] else: @@ -221,7 +223,7 @@ def main(**kwargs: Any) -> None: train_sampler, valid_sampler = None, None if args.distributed: - train_sampler = torch.utils.data.distributed.DistributedSampler( + train_sampler = torch.utils.mace.data.distributed.DistributedSampler( train_set, num_replicas=world_size, rank=rank, @@ -613,7 +615,9 @@ def main(**kwargs: Any) -> None: if args.train_file.endswith(".xyz"): for name, subset in collections.tests: test_sets[name] = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + mace.data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max + ) for config in subset ] else: diff --git a/models/megnet/test_megnet.py b/models/megnet/test_megnet.py index ce040ccc..9cf352f0 100644 --- a/models/megnet/test_megnet.py +++ b/models/megnet/test_megnet.py @@ -20,8 +20,9 @@ from sklearn.metrics import r2_score from tqdm import tqdm -from matbench_discovery import Key, Task, timestamp, today +from matbench_discovery import timestamp, today from matbench_discovery.data import DATA_FILES, df_wbm +from matbench_discovery.enums import Key, Task from matbench_discovery.plots import wandb_scatter from matbench_discovery.preds import PRED_FILES from matbench_discovery.slurm import slurm_submit diff --git a/models/voronoi_rf/join_voronoi_features.py b/models/voronoi_rf/join_voronoi_features.py index 26c153aa..9c1bc7c6 100644 --- a/models/voronoi_rf/join_voronoi_features.py +++ b/models/voronoi_rf/join_voronoi_features.py @@ -12,7 +12,7 @@ import pandas as pd from tqdm import tqdm -from matbench_discovery import Key +from matbench_discovery.enums import Key __author__ = "Janosh Riebesell" __date__ = "2022-08-16" diff --git a/models/voronoi_rf/train_test_voronoi_rf.py b/models/voronoi_rf/train_test_voronoi_rf.py index 5e598ce4..7ceea672 100644 --- a/models/voronoi_rf/train_test_voronoi_rf.py +++ b/models/voronoi_rf/train_test_voronoi_rf.py @@ -13,8 +13,9 @@ from sklearn.metrics import r2_score from sklearn.pipeline import Pipeline -from matbench_discovery import ROOT, Key, Task, today +from matbench_discovery import ROOT, today from matbench_discovery.data import DATA_FILES, df_wbm, glob_to_df +from matbench_discovery.enums import Key, Task from matbench_discovery.plots import wandb_scatter from matbench_discovery.slurm import slurm_submit diff --git a/models/voronoi_rf/voronoi_featurize_dataset.py b/models/voronoi_rf/voronoi_featurize_dataset.py index 17be9c1f..25530dd1 100644 --- a/models/voronoi_rf/voronoi_featurize_dataset.py +++ b/models/voronoi_rf/voronoi_featurize_dataset.py @@ -14,8 +14,9 @@ from pymatgen.core import Structure from tqdm import tqdm -from matbench_discovery import ROOT, Key, today +from matbench_discovery import ROOT, today from matbench_discovery.data import DATA_FILES +from matbench_discovery.enums import Key from matbench_discovery.slurm import slurm_submit sys.path.append(f"{ROOT}/models") @@ -71,6 +72,8 @@ struct_dicts = [cse["structure"] for cse in df_in[Key.cse]] elif data_name == "wbm" and input_col == Key.init_struct: struct_dicts = df_in[Key.init_struct] +else: + raise ValueError(f"Invalid {data_name=}, {input_col=} combo") df_in[input_col] = [ Structure.from_dict(dct) for dct in tqdm(struct_dicts, disable=None) diff --git a/models/wrenformer/analyze_wrenformer.py b/models/wrenformer/analyze_wrenformer.py index ecd87723..307cfbe4 100644 --- a/models/wrenformer/analyze_wrenformer.py +++ b/models/wrenformer/analyze_wrenformer.py @@ -10,8 +10,9 @@ from pymatviz.ptable import ptable_heatmap_plotly from pymatviz.utils import add_identity_line, bin_df_cols -from matbench_discovery import PDF_FIGS, SITE_FIGS, Key, Model +from matbench_discovery import PDF_FIGS, SITE_FIGS, Model from matbench_discovery.data import DATA_FILES, df_wbm +from matbench_discovery.enums import Key from matbench_discovery.preds import df_each_pred, df_preds __author__ = "Janosh Riebesell" diff --git a/models/wrenformer/test_wrenformer.py b/models/wrenformer/test_wrenformer.py index 6941c222..78a6f03b 100644 --- a/models/wrenformer/test_wrenformer.py +++ b/models/wrenformer/test_wrenformer.py @@ -17,8 +17,9 @@ from aviary.wrenformer.data import df_to_in_mem_dataloader from aviary.wrenformer.model import Wrenformer -from matbench_discovery import CHECKPOINT_DIR, WANDB_PATH, Key, Task, today +from matbench_discovery import CHECKPOINT_DIR, WANDB_PATH, today from matbench_discovery.data import df_wbm +from matbench_discovery.enums import Key, Task from matbench_discovery.plots import wandb_scatter from matbench_discovery.slurm import slurm_submit diff --git a/models/wrenformer/train_wrenformer.py b/models/wrenformer/train_wrenformer.py index 4f6b60cb..c042b303 100644 --- a/models/wrenformer/train_wrenformer.py +++ b/models/wrenformer/train_wrenformer.py @@ -8,8 +8,9 @@ import pandas as pd from aviary.train import df_train_test_split, train_wrenformer -from matbench_discovery import WANDB_PATH, Key, timestamp, today +from matbench_discovery import WANDB_PATH, timestamp, today from matbench_discovery.data import DATA_FILES +from matbench_discovery.enums import Key from matbench_discovery.slurm import slurm_submit __author__ = "Janosh Riebesell" diff --git a/pyproject.toml b/pyproject.toml index 6a15952c..861bff66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,8 +78,10 @@ universal = true [tool.ruff] target-version = "py39" -lint.select = ["ALL"] -lint.ignore = [ + +[tool.ruff.lint] +select = ["ALL"] +ignore = [ "ANN101", "ANN102", "ANN401", @@ -115,9 +117,9 @@ lint.ignore = [ "TRY003", "TRY301", ] -lint.pydocstyle.convention = "google" -lint.isort.known-third-party = ["wandb"] -lint.isort.split-on-trailing-comma = false +pydocstyle.convention = "google" +isort.known-third-party = ["wandb"] +isort.split-on-trailing-comma = false [tool.ruff.lint.per-file-ignores] "tests/*" = ["D", "S101"] @@ -145,3 +147,9 @@ markers = [ "slow: deselect slow tests with -m 'not slow'", "very_slow: select with -m 'very_slow'", ] + +[tool.pyright] +typeCheckingMode = "off" +reportPossiblyUnboundVariable = true +reportUnboundVariable = true +reportMissingImports = false diff --git a/scripts/analyze_model_failure_cases.py b/scripts/analyze_model_failure_cases.py index 15f92fa0..e9dfa0c0 100644 --- a/scripts/analyze_model_failure_cases.py +++ b/scripts/analyze_model_failure_cases.py @@ -16,8 +16,9 @@ from pymatviz.io import save_fig from tqdm import tqdm -from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, Key +from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR from matbench_discovery.data import DATA_FILES, df_wbm +from matbench_discovery.enums import Key from matbench_discovery.metrics import classify_stable from matbench_discovery.preds import df_each_err, df_each_pred, df_metrics, df_preds diff --git a/scripts/compute_struct_fingerprints.py b/scripts/compute_struct_fingerprints.py index 00322869..50d097b4 100644 --- a/scripts/compute_struct_fingerprints.py +++ b/scripts/compute_struct_fingerprints.py @@ -15,8 +15,9 @@ from pymatgen.core import Structure from tqdm import tqdm -from matbench_discovery import DATA_DIR, Key, timestamp +from matbench_discovery import DATA_DIR, timestamp from matbench_discovery.data import DATA_FILES +from matbench_discovery.enums import Key from matbench_discovery.slurm import slurm_submit __author__ = "Janosh Riebesell" diff --git a/scripts/hist_classified_stable_vs_hull_dist.py b/scripts/hist_classified_stable_vs_hull_dist.py index d32e1b82..20e3e2a0 100644 --- a/scripts/hist_classified_stable_vs_hull_dist.py +++ b/scripts/hist_classified_stable_vs_hull_dist.py @@ -10,8 +10,9 @@ from pymatviz.io import save_fig -from matbench_discovery import PDF_FIGS, Key +from matbench_discovery import PDF_FIGS from matbench_discovery.data import df_wbm +from matbench_discovery.enums import Key from matbench_discovery.plots import hist_classified_stable_vs_hull_dist from matbench_discovery.preds import df_each_pred diff --git a/scripts/hist_classified_stable_vs_hull_dist_batches.py b/scripts/hist_classified_stable_vs_hull_dist_batches.py index 1c226353..84e4cc0c 100644 --- a/scripts/hist_classified_stable_vs_hull_dist_batches.py +++ b/scripts/hist_classified_stable_vs_hull_dist_batches.py @@ -12,7 +12,8 @@ import pandas as pd from pymatviz.io import save_fig -from matbench_discovery import PDF_FIGS, Key +from matbench_discovery import PDF_FIGS +from matbench_discovery.enums import Key from matbench_discovery.plots import hist_classified_stable_vs_hull_dist from matbench_discovery.preds import df_preds diff --git a/scripts/model_figs/analyze_model_disagreement.py b/scripts/model_figs/analyze_model_disagreement.py index 18f7d6f4..9faaaa74 100644 --- a/scripts/model_figs/analyze_model_disagreement.py +++ b/scripts/model_figs/analyze_model_disagreement.py @@ -46,7 +46,7 @@ "hydrides": r".*H\d.*", "oxynitrides": r".*[ON]\d.*", } -n_structs = 200 +n_structs, fig = 200, None for material_cls, pattern in material_classes.items(): df_subset = df_preds[df_preds[Key.formula].str.match(pattern)] diff --git a/scripts/model_figs/compile_model_stats.py b/scripts/model_figs/compile_model_stats.py index b2805009..f5dfffca 100644 --- a/scripts/model_figs/compile_model_stats.py +++ b/scripts/model_figs/compile_model_stats.py @@ -62,6 +62,7 @@ # trained from scratch. Their run times only indicate the time needed to predict the # test set. +time_col = "Run Time (h)" for label, stats, raw_filters in ( ("train", train_stats, train_run_filters), ("test", test_stats, test_run_filters), @@ -98,7 +99,7 @@ n_gpu, n_cpu = metadata.get("gpu_count", 0), metadata.get("cpu_count", 0) stats[model] = { - (time_col := "Run Time (h)"): run_time_total / 3600, + time_col: run_time_total / 3600, "GPU": n_gpu, "CPU": n_cpu, "Slurm Jobs": n_runs, @@ -114,18 +115,16 @@ # %% -for df_tmp, label in ( - (df_metrics, ""), - (df_metrics_10k, "-10k"), - (df_metrics_uniq_protos, "-uniq-protos"), -): - df_tmp = pd.concat( - [ - df_tmp, - pd.DataFrame(train_stats).add_prefix("Train ", axis="index"), - pd.DataFrame(test_stats).add_prefix("Test ", axis="index"), - ], - ).T +stats_dict = { + "": df_metrics, + "-10k": df_metrics_10k, + "-uniq-protos": df_metrics_uniq_protos, +} +for label, df_tmp in stats_dict.items(): + df_train_stats = pd.DataFrame(train_stats).add_prefix("Train ", axis="index") + df_test_stats = pd.DataFrame(test_stats).add_prefix("Test ", axis="index") + df_tmp = pd.concat([df_tmp, df_train_stats, df_test_stats]).T + df_tmp[time_col] = df_tmp.filter(like=time_col).sum(axis="columns") # write model metrics to json for website use @@ -136,9 +135,11 @@ df_tmp.attrs["All Models Run Time"] = df_tmp[time_col].sum() + # write stats for different data subsets to JSON df_tmp.round(2).to_json(f"{SITE_LIB}/model-stats{label}.json", orient="index") - if label == "": - df_stats = df_tmp + stats_dict[label] = df_tmp + +df_stats = stats_dict[""] # %% diff --git a/scripts/model_figs/make_hull_dist_box_plot.py b/scripts/model_figs/hull_dist_box_plot.py similarity index 100% rename from scripts/model_figs/make_hull_dist_box_plot.py rename to scripts/model_figs/hull_dist_box_plot.py diff --git a/scripts/model_figs/make_metrics_tables.py b/scripts/model_figs/metrics_tables.py similarity index 100% rename from scripts/model_figs/make_metrics_tables.py rename to scripts/model_figs/metrics_tables.py diff --git a/scripts/model_figs/parity_energy_models.py b/scripts/model_figs/parity_energy_models.py index cbb70555..bdc3fdad 100644 --- a/scripts/model_figs/parity_energy_models.py +++ b/scripts/model_figs/parity_energy_models.py @@ -28,7 +28,8 @@ if which_energy == "each": e_pred_col = Key.each_pred e_true_col = Key.each_true -if which_energy == "e-form": +else: + assert which_energy == "e-form", f"Invalid {which_energy=}" e_true_col = Key.e_form e_pred_col = Key.e_form_pred diff --git a/scripts/project_compositions.py b/scripts/project_compositions.py index 9b3ad095..8f2be1cb 100644 --- a/scripts/project_compositions.py +++ b/scripts/project_compositions.py @@ -11,8 +11,9 @@ from pymatgen.core import Composition from tqdm import tqdm -from matbench_discovery import DATA_DIR, Key +from matbench_discovery import DATA_DIR from matbench_discovery.data import DATA_FILES +from matbench_discovery.enums import Key from matbench_discovery.slurm import slurm_submit __author__ = "Janosh Riebesell" diff --git a/scripts/rolling_mae_vs_hull_dist.py b/scripts/rolling_mae_vs_hull_dist.py index 6edbcc82..bb8fe230 100644 --- a/scripts/rolling_mae_vs_hull_dist.py +++ b/scripts/rolling_mae_vs_hull_dist.py @@ -4,7 +4,8 @@ # %% from pymatviz.io import save_fig -from matbench_discovery import PDF_FIGS, SITE_FIGS, Key, Model +from matbench_discovery import PDF_FIGS, SITE_FIGS, Model +from matbench_discovery.enums import Key from matbench_discovery.plots import rolling_mae_vs_hull_dist from matbench_discovery.preds import df_each_pred, df_metrics, df_wbm diff --git a/scripts/update_wandb_runs.py b/scripts/update_wandb_runs.py index cec11aa2..c42f7da8 100644 --- a/scripts/update_wandb_runs.py +++ b/scripts/update_wandb_runs.py @@ -8,7 +8,8 @@ import wandb from wandb.wandb_run import Run -from matbench_discovery import WANDB_PATH, Task +from matbench_discovery import WANDB_PATH +from matbench_discovery.enums import Task __author__ = "Janosh Riebesell" __date__ = "2022-09-21" diff --git a/scripts/upload_to_figshare.py b/scripts/upload_to_figshare.py index 6113fbec..387128f7 100644 --- a/scripts/upload_to_figshare.py +++ b/scripts/upload_to_figshare.py @@ -158,11 +158,12 @@ def main(pyproject: dict[str, Any], urls_json_path: str) -> int: with open(urls_json_path, "w") as file: json.dump(figshare_urls, file) except Exception as exc: # prompt to delete article if something went wrong - answer = "" - print(f"Encountered {exc=} for {file_path=}") - while answer not in ("y", "n"): + if file_path := str(locals().get("file_path", "")): + exc.add_note(f"{file_path=}") + answer, article_id = "", int(locals().get("article_id", 0)) + while article_id and answer.lower() not in ("y", "n"): answer = input("Delete article? [y/n] ") - if answer == "y": + if answer.lower() == "y": make_request("DELETE", f"{BASE_URL}/account/articles/{article_id}") return 0 diff --git a/scripts/wbm_umap_projection.py b/scripts/wbm_umap_projection.py index 6eccb4db..aef163d3 100644 --- a/scripts/wbm_umap_projection.py +++ b/scripts/wbm_umap_projection.py @@ -23,8 +23,9 @@ from pymatviz.io import save_fig from tqdm import tqdm -from matbench_discovery import MP_DIR, PDF_FIGS, WBM_DIR, Key +from matbench_discovery import MP_DIR, PDF_FIGS, WBM_DIR from matbench_discovery.data import DATA_FILES +from matbench_discovery.enums import Key __author__ = "Philipp Benner, Janosh Riebesell" __date__ = "2023-11-28" diff --git a/site/tsconfig.json b/site/tsconfig.json index 80cd5a98..ed06dd89 100644 --- a/site/tsconfig.json +++ b/site/tsconfig.json @@ -12,6 +12,6 @@ "forceConsistentCasingInFileNames": true, "resolveJsonModule": true, - "allowSyntheticDefaultImports": true, - }, + "allowSyntheticDefaultImports": true + } } diff --git a/tests/conftest.py b/tests/conftest.py index 589a8a7b..73345ad1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from pymatgen.core import Lattice, Structure from pymatgen.entries.computed_entries import ComputedStructureEntry -from matbench_discovery import Key +from matbench_discovery.enums import Key @pytest.fixture() diff --git a/tests/test_data.py b/tests/test_data.py index 7ba8cd0b..8edc421e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -11,7 +11,7 @@ from pymatgen.core import Lattice, Structure from pytest import CaptureFixture -from matbench_discovery import FIGSHARE_DIR, ROOT, Key +from matbench_discovery import FIGSHARE_DIR, ROOT from matbench_discovery.data import ( DATA_FILES, as_dict_handler, @@ -20,6 +20,7 @@ glob_to_df, load, ) +from matbench_discovery.enums import Key if TYPE_CHECKING: from pathlib import Path diff --git a/tests/test_plots.py b/tests/test_plots.py index 96350a66..86985a35 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -7,7 +7,7 @@ import plotly.graph_objects as go import pytest -from matbench_discovery import Key +from matbench_discovery.enums import Key from matbench_discovery.plots import ( Backend, cumulative_metrics, @@ -89,14 +89,13 @@ def test_rolling_mae_vs_hull_dist( ax = plt.figure().gca() # new figure ensures test functions use different axes kwargs["ax"] = ax - for model_name in models: - ax, df_err, df_std = rolling_mae_vs_hull_dist( - e_above_hull_true=df_wbm[model_name], - e_above_hull_preds=df_wbm[models], - x_lim=x_lim, - show_dft_acc=show_dft_acc, - **kwargs, # type: ignore[arg-type] - ) + ax, df_err, df_std = rolling_mae_vs_hull_dist( + e_above_hull_true=df_wbm[models[0]], + e_above_hull_preds=df_wbm[models], + x_lim=x_lim, + show_dft_acc=show_dft_acc, + **kwargs, # type: ignore[arg-type] + ) assert isinstance(df_err, pd.DataFrame) assert isinstance(df_std, pd.DataFrame) diff --git a/tests/test_preds.py b/tests/test_preds.py index e438e32f..49f73708 100644 --- a/tests/test_preds.py +++ b/tests/test_preds.py @@ -2,8 +2,8 @@ import pytest -from matbench_discovery import Key from matbench_discovery.data import df_wbm +from matbench_discovery.enums import Key from matbench_discovery.preds import ( PRED_FILES, df_each_err,