Skip to content

Commit

Permalink
Add pyright pre-commit hook and fix possibly unbound variables (#93)
Browse files Browse the repository at this point in the history
* add pyright pre-commit hook

* mv scripts/model_figs/(make_metrics_tables->metrics_tables).py

mv scripts/model_figs/(make_hull_dist_box_plot->hull_dist_box_plot).py

* fix most pyright PossiblyUnboundVariable

* remove __init__.py convenience re-exports of enums

* LabelEnum add new to_dict methods val_desc_dict + label_desc_dict

renamed key_val_dict, val_label_dict
  • Loading branch information
janosh authored Mar 1, 2024
1 parent e596392 commit c9dc8ee
Show file tree
Hide file tree
Showing 62 changed files with 210 additions and 145 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-scripts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
fail-fast: false
matrix:
script:
- scripts/model_figs/make_metrics_tables.py
- scripts/model_figs/metrics_tables.py
- scripts/model_figs/rolling_mae_vs_hull_dist_models.py
steps:
- name: Check out repository
Expand Down
8 changes: 7 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ repos:
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.0.0-beta.0
rev: v9.0.0-beta.1
hooks:
- id: eslint
types: [file]
Expand All @@ -78,3 +78,9 @@ repos:
files: ^models/(.+)/\1.*\.yml$
args: [--schemafile, tests/model-schema.yml]
- id: check-github-actions

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.351
hooks:
- id: pyright
args: [--level, error]
3 changes: 2 additions & 1 deletion data/mp/build_phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from pymatgen.ext.matproj import MPRester
from tqdm import tqdm

from matbench_discovery import MP_DIR, ROOT, Key, today
from matbench_discovery import MP_DIR, ROOT, today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries
from matbench_discovery.enums import Key

module_dir = os.path.dirname(__file__)

Expand Down
24 changes: 16 additions & 8 deletions data/mp/eda_mp_trj.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from pymatviz.utils import si_fmt
from tqdm import tqdm

from matbench_discovery import MP_DIR, PDF_FIGS, ROOT, SITE_FIGS, Key
from matbench_discovery import MP_DIR, PDF_FIGS, ROOT, SITE_FIGS
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.enums import Key

__author__ = "Janosh Riebesell"
__date__ = "2023-11-22"
Expand Down Expand Up @@ -108,10 +109,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:

# %% plot per-element magmom histograms
ptable_magmom_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-magmoms.json.bz2"
srs_mp_trj_elem_magmoms = locals().get("srs_mp_trj_elem_magmoms")

if os.path.isfile(ptable_magmom_hist_path):
srs_mp_trj_elem_magmoms = pd.read_json(ptable_magmom_hist_path, typ="series")
elif "srs_mp_trj_elem_magmoms" not in locals():
if srs_mp_trj_elem_magmoms is None:
# project magmoms onto symbols in dict
df_mp_trj_elem_magmom = pd.DataFrame(
[
Expand Down Expand Up @@ -151,10 +153,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:

# %% plot per-element force histograms
ptable_force_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-forces.json.bz2"
srs_mp_trj_elem_forces = locals().get("srs_mp_trj_elem_forces")

if os.path.isfile(ptable_force_hist_path):
srs_mp_trj_elem_forces = pd.read_json(ptable_force_hist_path, typ="series")
elif "srs_mp_trj_elem_forces" not in locals():
if srs_mp_trj_elem_forces is None:
df_mp_trj_elem_forces = pd.DataFrame(
[
dict(zip(elems, np.abs(forces).mean(axis=1)))
Expand Down Expand Up @@ -193,10 +196,11 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:

# %% plot histogram of number of sites per element
ptable_n_sites_hist_path = f"{MP_DIR}/mp-trj-2022-09-elem-n-sites.json.bz2"
srs_mp_trj_elem_n_sites = locals().get("srs_mp_trj_elem_n_sites")

if os.path.isfile(ptable_n_sites_hist_path):
srs_mp_trj_elem_n_sites = pd.read_json(ptable_n_sites_hist_path, typ="series")
elif "mp_trj_elem_n_sites" not in locals():
elif srs_mp_trj_elem_n_sites is None:
# construct a series of lists of site numbers per element (i.e. how often each
# element appears in a structure with n sites)
# create all df cols as int dtype
Expand Down Expand Up @@ -320,8 +324,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
pdf_kwds = dict(width=500, height=300)

x_col, y_col = "E<sub>form</sub> (eV/atom)", count_col
df_e_form = locals().get("df_e_form")

if "df_e_form" not in locals(): # only compute once for speed
if df_e_form is None: # only compute once for speed
e_form_hist = np.histogram(df_mp_trj[Key.e_form], bins=300)
df_e_form = pd.DataFrame(e_form_hist, index=[y_col, x_col]).T.round(3)

Expand All @@ -340,8 +345,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:
# %% plot forces distribution
# use numpy to pre-compute histogram
x_col, y_col = "|Forces| (eV/Å)", count_col
df_forces = locals().get("df_forces")

if "df_forces" not in locals(): # only compute once for speed
if df_forces is None: # only compute once for speed
forces_hist = np.histogram(
df_mp_trj[Key.forces].explode().explode().abs(), bins=300
)
Expand All @@ -361,8 +367,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:

# %% plot hydrostatic stress distribution
x_col, y_col = "1/3 Tr(σ) (eV/ų)", count_col # noqa: RUF001
df_stresses = locals().get("df_stresses")

if "df_stresses" not in locals(): # only compute once for speed
if df_stresses is None: # only compute once for speed
stresses_hist = np.histogram(df_mp_trj[Key.stress_trace], bins=300)
df_stresses = pd.DataFrame(stresses_hist, index=[y_col, x_col]).T.round(3)

Expand All @@ -381,8 +388,9 @@ def tile_count_anno(hist_vals: list[Any]) -> dict[str, Any]:

# %% plot magmoms distribution
x_col, y_col = "Magmoms (μ<sub>B</sub>)", count_col
df_magmoms = locals().get("df_magmoms")

if "df_magmoms" not in locals(): # only compute once for speed
if df_magmoms is None: # only compute once for speed
magmoms_hist = np.histogram(df_mp_trj[Key.magmoms].dropna().explode(), bins=300)
df_magmoms = pd.DataFrame(magmoms_hist, index=[y_col, x_col]).T.round(3)

Expand Down
3 changes: 2 additions & 1 deletion data/mp/get_mp_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from pymatviz.utils import annotate_metrics
from tqdm import tqdm

from matbench_discovery import STABILITY_THRESHOLD, Key, today
from matbench_discovery import STABILITY_THRESHOLD, today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.enums import Key

"""
Download all MP formation and above hull energies on 2023-01-10.
Expand Down
3 changes: 2 additions & 1 deletion data/mp/get_mp_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from pymongo.database import Database
from tqdm import tqdm, trange

from matbench_discovery import ROOT, Key, today
from matbench_discovery import ROOT, today
from matbench_discovery.enums import Key

__author__ = "Janosh Riebesell"
__date__ = "2023-03-15"
Expand Down
3 changes: 2 additions & 1 deletion data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
from tqdm import tqdm

from matbench_discovery import ROOT, Key, today
from matbench_discovery import ROOT, today
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.enums import Key
from matbench_discovery.plots import plt

"""
Expand Down
14 changes: 8 additions & 6 deletions data/wbm/compile_wbm_test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@
from pymatviz.io import save_fig
from tqdm import tqdm

from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, Key, today
from matbench_discovery import PDF_FIGS, SITE_FIGS, WBM_DIR, today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.energy import get_e_form_per_atom
from matbench_discovery.enums import Key

try:
import gdown
except ImportError:
print(
except ImportError as exc:
exc.add_note(
"gdown not installed. Needed for downloading WBM initial + relaxed structures "
"from Google Drive."
)
raise

"""
Dataset generated with DFT and published in Jan 2021 as
Expand Down Expand Up @@ -90,8 +92,8 @@
18198704957443186264,
)

if "dfs_wbm_structs" not in locals():
dfs_wbm_structs = {}
dfs_wbm_structs = locals().get("dfs_wbm_structs", {})

for json_path in json_paths:
step = int(json_path.split(".json.bz2")[0][-1])
assert step in range(1, 6)
Expand Down Expand Up @@ -179,8 +181,8 @@ def increment_wbm_material_id(wbm_id: str) -> str:
print(f"{file_path} already exists, skipping")
continue

url = f"{mat_cloud_url}&filename={filename}"
try:
url = f"{mat_cloud_url}&filename={filename}"
urllib.request.urlretrieve(url, file_path)
except urllib.error.HTTPError as exc:
print(f"failed to download {url=}: {exc}")
Expand Down
18 changes: 5 additions & 13 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,7 @@
import plotly.io as pio
import pymatviz # noqa: F401

from matbench_discovery.enums import ( # noqa: F401
Key,
Model,
ModelType,
Open,
Quantity,
Targets,
Task,
)
from matbench_discovery.enums import Model, Quantity

PKG_NAME = "matbench-discovery"
direct_url = Distribution.from_name(PKG_NAME).read_text("direct_url.json") or "{}"
Expand Down Expand Up @@ -69,17 +61,17 @@
FIGSHARE_URLS = json.load(file)

# --- start global plot settings
px.defaults.labels = Quantity.val_dict() | Model.val_dict()
px.defaults.labels = Quantity.key_val_dict() | Model.key_val_dict()

global_layout = dict(
paper_bgcolor="rgba(0,0,0,0)",
font_size=13,
# increase legend marker size and make background transparent
legend=dict(itemsizing="constant", bgcolor="rgba(0, 0, 0, 0)"),
)
pio.templates["global"] = dict(layout=global_layout)
pio.templates.default = "pymatviz_dark+global"
px.defaults.template = "pymatviz_dark+global"
pio.templates["mbd_global"] = dict(layout=global_layout)
pio.templates.default = "pymatviz_dark+mbd_global"
px.defaults.template = "pymatviz_dark+mbd_global"

# https://github.com/plotly/Kaleido/issues/122#issuecomment-994906924
# when seeing MathJax "loading" message in exported PDFs,
Expand Down
3 changes: 2 additions & 1 deletion matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from monty.json import MontyDecoder
from tqdm import tqdm

from matbench_discovery import FIGSHARE_DIR, Key
from matbench_discovery import FIGSHARE_DIR
from matbench_discovery.enums import Key

if TYPE_CHECKING:
from pathlib import Path
Expand Down
20 changes: 15 additions & 5 deletions matbench_discovery/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,24 @@ def description(self) -> str:
return self.__dict__["desc"]

@classmethod
def val_dict(cls) -> dict[str, str]:
"""Return the Enum as dictionary."""
def key_val_dict(cls) -> dict[str, str]:
"""Map of keys to values."""
return {key: str(val) for key, val in cls.__members__.items()}

@classmethod
def label_dict(cls) -> dict[str, str]:
"""Return the Enum as dictionary."""
return {str(val): val.label for key, val in cls.__members__.items()}
def val_label_dict(cls) -> dict[str, str | None]:
"""Map of values to labels."""
return {str(val): val.label for val in cls.__members__.values()}

@classmethod
def val_desc_dict(cls) -> dict[str, str | None]:
"""Map of values to descriptions."""
return {str(val): val.description for val in cls.__members__.values()}

@classmethod
def label_desc_dict(cls) -> dict[str | None, str | None]:
"""Map of labels to descriptions."""
return {str(val.label): val.description for val in cls.__members__.values()}


@unique
Expand Down
17 changes: 9 additions & 8 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ def hist_classified_stable_vs_hull_dist(
fixed in Inkscape or similar by merging regions by color.
"""
x_col = dict(true=each_true_col, pred=each_pred_col)[which_energy]
clf_col, value_name = "classified", "count"

df_plot = pd.DataFrame()

for facet, df_group in (
df.groupby(kwargs["facet_col"]) if "facet_col" in kwargs else [(None, df)]
):
Expand All @@ -113,16 +115,16 @@ def hist_classified_stable_vs_hull_dist(
)

# switch between hist of DFT-computed and model-predicted convex hull distance
e_above_hull = df_group[x_col]
each_true_pos = e_above_hull[true_pos]
each_true_neg = e_above_hull[true_neg]
each_false_neg = e_above_hull[false_neg]
each_false_pos = e_above_hull[false_pos]
srs_each = df_group[x_col]
each_true_pos = srs_each[true_pos]
each_true_neg = srs_each[true_neg]
each_false_neg = srs_each[false_neg]
each_false_pos = srs_each[false_pos]
# n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
# sum, (true_pos, false_pos, true_neg, false_neg)
# )

df_group[(clf_col := "classified")] = np.array(clf_labels)[
df_group[clf_col] = np.array(clf_labels)[
true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3
]

Expand All @@ -144,7 +146,6 @@ def hist_classified_stable_vs_hull_dist(
index=clf_labels,
).T
df_hist[x_col] = bin_edges[:-1]
value_name = "count"
df_melt = df_hist.melt(
id_vars=x_col,
value_vars=clf_labels,
Expand Down Expand Up @@ -714,6 +715,7 @@ def cumulative_metrics(
)
df = dfs[metric]
ax.set(ylim=(0, 1), xlim=(0, None), ylabel=metric)
bbox = dict(facecolor="white", alpha=0.5, edgecolor="none")
for model in df_preds:
# TODO is this really necessary?
if len(df[model].dropna()) == 0:
Expand All @@ -722,7 +724,6 @@ def cumulative_metrics(
y_end = df[model].dropna().iloc[-1]
# add some visual guidelines to the plot
intersect_kwargs = dict(linestyle=":", alpha=0.4, linewidth=2)
bbox = dict(facecolor="white", alpha=0.5, edgecolor="none")
# place model name at the end of every line
ax.text(x_end, y_end, model, va="bottom", rotation=30, bbox=bbox)
if "x" in project_end_point:
Expand Down
8 changes: 4 additions & 4 deletions matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import pandas as pd
from tqdm import tqdm

from matbench_discovery import ROOT, STABILITY_THRESHOLD, Key, Model
from matbench_discovery import ROOT, STABILITY_THRESHOLD, Model
from matbench_discovery.data import Files, df_wbm, glob_to_df
from matbench_discovery.enums import Key
from matbench_discovery.metrics import stable_metrics
from matbench_discovery.plots import plotly_colors, plotly_line_styles, plotly_markers

Expand Down Expand Up @@ -65,7 +66,7 @@ class PredFiles(Files):


# key_map maps model keys to pretty labels
PRED_FILES = PredFiles(root=f"{ROOT}/models", key_map=Model.val_dict())
PRED_FILES = PredFiles(root=f"{ROOT}/models", key_map=Model.key_val_dict())


def load_df_wbm_with_preds(
Expand Down Expand Up @@ -101,15 +102,14 @@ def load_df_wbm_with_preds(
)

dfs: dict[str, pd.DataFrame] = {}

try:
for model_name in (bar := tqdm(models, disable=not pbar, desc="Loading preds")):
bar.set_postfix_str(model_name)
df = glob_to_df(PRED_FILES[model_name], pbar=False, **kwargs)
df = df.set_index(id_col)
dfs[model_name] = df
except Exception as exc:
raise RuntimeError(f"Failed to load {model_name=}") from exc
raise RuntimeError(f"Failed to load {locals().get('model_name')=}") from exc

from matbench_discovery.data import df_wbm

Expand Down
3 changes: 2 additions & 1 deletion models/alignn/test_alignn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from sklearn.metrics import r2_score
from tqdm import tqdm

from matbench_discovery import Key, Task, today
from matbench_discovery import today
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.enums import Key, Task
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit

Expand Down
Loading

0 comments on commit c9dc8ee

Please sign in to comment.