Skip to content

Commit

Permalink
Merge pull request #133 from sparks-baird/xtal
Browse files Browse the repository at this point in the history
refactor to accommodate structural inputs
  • Loading branch information
sgbaird authored Jan 8, 2023
2 parents c5464b5 + a50e12c commit f5d8a46
Show file tree
Hide file tree
Showing 10 changed files with 887 additions and 157 deletions.
2 changes: 1 addition & 1 deletion docs/source/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ How to `fit`/`predict`, use custom or built-in datasets, and perform adaptive de
```python
from mat_discover.mat_discover_ import Discover
disc = Discover(target_unit="GPa")
disc.fit(train_df) # DataFrames should have at minimum "formula" and "target" columns
disc.fit(train_df) # DataFrames should have at minimum ("formula" or "structure") and "target" columns
scores = disc.predict(val_df)
disc.plot()
disc.save()
Expand Down
2 changes: 1 addition & 1 deletion examples/elmd_densmap_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
x: disc.std_emb[:, 0],
y: disc.std_emb[:, 1],
"cluster ID": disc.labels,
"formula": disc.all_formula,
"formula": disc.all_inputs,
}
)
fig = pareto_plot(
Expand Down
29 changes: 29 additions & 0 deletions examples/mat_discover_xtal_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Load some data, fit Discover(), predict on validation data, make some plots, and save the model."""
# %% imports
import pandas as pd
from crabnet.data.materials_data import elasticity
from mat_discover.mat_discover_ import Discover

# %% setup
# set dummy to True for a quicker run --> small dataset, MDS instead of UMAP
dummy = False
# set gcv to False for a quicker run --> group-cross validation can take a while
gcv = False
# NOTE(review): device is hard-coded to "cuda"; presumably this needs a
# CUDA-capable GPU -- confirm whether Discover falls back to CPU on its own.
disc = Discover(dummy_run=dummy, device="cuda", target_unit="GPa")
train_df, val_df = disc.data(elasticity, fname="train.csv", dummy=dummy)
# Stack train and validation rows; only the optional cross-validation step
# below consumes the combined frame.
cat_df = pd.concat((train_df, val_df), axis=0)

# %% fit
disc.fit(train_df)

# %% predict
score = disc.predict(val_df, umap_random_state=42)

# %% leave-one-cluster-out cross-validation
if gcv:
    # Both statements belong to the guarded branch: the error is reported
    # only after the cross-validation has actually been run.
    disc.group_cross_val(cat_df, umap_random_state=42)
    print("scaled test error = ", disc.scaled_error)

# %% plot and save
disc.plot()
disc.save(dummy=dummy)
Loading

0 comments on commit f5d8a46

Please sign in to comment.