From a4f363136a2d1f5b18f7f8b36dd0b193545006f9 Mon Sep 17 00:00:00 2001 From: Habibur Rahman Date: Fri, 20 Jan 2023 18:20:55 -0500 Subject: [PATCH 1/8] Update train.py Currently, the train data is not correctly printed out in the prediction_results_train_set.csv file. It shows some symmetric data points along the parity line. Also, it prints out the validation dataset rather than the training dataset. --- alignn/train.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/alignn/train.py b/alignn/train.py index 4fa072e4..71b3e29b 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -844,7 +844,7 @@ def log_results(engine): # history["validation"][metric].append(vmetrics[metric]) if config.store_outputs: - history["EOS"] = eos.data + history["EOS"] = train_eos.data history["trainEOS"] = train_eos.data dumpjson( filename=os.path.join(config.output_dir, "history_val.json"), @@ -1027,10 +1027,12 @@ def es_score(engine): ) # TODO: Add IDs f.write("target,prediction\n") - for i, j in zip(x, y): - f.write("%6f, %6f\n" % (j, i)) - line = str(i) + "," + str(j) + "\n" - f.write(line) + #for i, j in zip(x, y): + #f.write("%6f, %6f\n" % (j, i)) + #line = str(i) + "," + str(j) + "\n" + # f.write(line) + for target_val, predicted_val in zip(x, y): + print(f”{target_val}, {predicted_val}”, file=f) f.close() # TODO: Fix IDs for train loader From 76ffebdf2719de3eee8eb1a24fecb45e45da74a2 Mon Sep 17 00:00:00 2001 From: Brian DeCost Date: Fri, 20 Jan 2023 21:44:00 -0500 Subject: [PATCH 2/8] fix quotes and remove dead code --- alignn/train.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/alignn/train.py b/alignn/train.py index 71b3e29b..213de52e 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -844,7 +844,7 @@ def log_results(engine): # history["validation"][metric].append(vmetrics[metric]) if config.store_outputs: - history["EOS"] = train_eos.data + history["EOS"] = eos.data history["trainEOS"] = train_eos.data dumpjson( filename=os.path.join(config.output_dir, "history_val.json"), @@ -1014,7 +1014,7 @@ def es_score(engine): if config.store_outputs and not classification: x = [] y = [] - for i in history["EOS"]: + for i in history["trainEOS"]: x.append(i[0].cpu().numpy().tolist()) y.append(i[1].cpu().numpy().tolist()) x = np.array(x, dtype="float").flatten() @@ -1027,12 +1027,8 @@ def es_score(engine): ) # TODO: Add IDs f.write("target,prediction\n") - #for i, j in zip(x, y): - #f.write("%6f, %6f\n" % (j, i)) - #line = str(i) + "," + str(j) + "\n" - # f.write(line) for target_val, predicted_val in zip(x, y): - print(f”{target_val}, {predicted_val}”, file=f) + print(f"{target_val}, {predicted_val}", file=f) f.close() # TODO: Fix IDs for train loader From a58dec32daccf4e820cdd994e6883b6bc770d5ca Mon Sep 17 00:00:00 2001 From: Brian DeCost Date: Fri, 20 Jan 2023 21:47:44 -0500 Subject: [PATCH 3/8] improve readability of code that pulls data from epoch store --- alignn/train.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/alignn/train.py b/alignn/train.py index 213de52e..bcda0d64 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -1012,24 +1012,25 @@ def es_score(engine): mean_absolute_error(np.array(targets), np.array(predictions)), ) if config.store_outputs and not classification: - x = [] - y = [] - for i in history["trainEOS"]: - x.append(i[0].cpu().numpy().tolist()) - y.append(i[1].cpu().numpy().tolist()) - x = np.array(x, dtype="float").flatten() - y = np.array(y, dtype="float").flatten() - f = open( - os.path.join( - config.output_dir, "prediction_results_train_set.csv" - ), - "w", - ) + # save training targets and predictions here # TODO: Add IDs - f.write("target,prediction\n") - for target_val, predicted_val in zip(x, y): - print(f"{target_val}, {predicted_val}", file=f) - f.close() + resultsfile = os.path.join( + config.output_dir, "prediction_results_train_set.csv" + ) + + target_vals, predictions = [], [] + + for tgt, pred in history["trainEOS"]: + target_vals.append(tgt.cpu().numpy().tolist()) + predictions.append(pred.cpu().numpy().tolist()) + + target_vals = np.array(target_vals, dtype="float").flatten() + predictions = np.array(predictions, dtype="float").flatten() + + with open(resultsfile, "w") as f: + print("target,prediction", file=f) + for target_val, predicted_val in zip(target_vals, predictions): + print(f"{target_val}, {predicted_val}", file=f) # TODO: Fix IDs for train loader """ From 1b42d8b38628d1f3298d3d6aeba9330a87d28f2f Mon Sep 17 00:00:00 2001 From: Habibur Rahman Date: Sun, 5 Mar 2023 19:23:19 -0500 Subject: [PATCH 4/8] Add files via upload --- alignn/predict.py | 72 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 alignn/predict.py diff --git a/alignn/predict.py b/alignn/predict.py new file mode 100644 index 00000000..c688df12 --- /dev/null +++ b/alignn/predict.py @@ -0,0 +1,72 @@ +from alignn.models.alignn import ALIGNN, ALIGNNConfig +import torch +output_features = 1 +#directory of checkpoint_file, basically, your optimized model +filename = 'DataSet_A_Model/checkpoint_150.pt' +device = "cpu" +if torch.cuda.is_available(): + device = torch.device("cuda") +model = ALIGNN(ALIGNNConfig(name="alignn", output_features=output_features)) +model.load_state_dict(torch.load(filename, map_location=device)["model"]) +model.eval() + +import os +import csv +from jarvis.core.atoms import Atoms +from alignn.graphs import Graph +import re + +cutoff = 8.0 +max_neighbors = 12 + +#directory where you have all the poscar/cif and you would like to apply your optimized model on them. + +sample_data_folder = '/Users/habibur/Habibur_Python_Scripts/alignn/alignn/data/' + +# id_prop.csv; a csv file where you have all the ids of the poscar/cif in the first column and corresponding properties in the second column. + +csv_file = 'id_prop.csv' +# In this output.csv file, all the ids and corresponding properties will be printed out. +output_file = 'output.csv' + +with open(os.path.join(sample_data_folder, csv_file), newline='') as f: + reader = csv.reader(f) + file_list = [row[0] for row in reader] + +atoms_list = [] +for file in file_list: + atoms = Atoms.from_cif(os.path.join(sample_data_folder, file)) + atoms_list.append(atoms) + +g_list = [] +lg_list = [] +for atoms in atoms_list: + g, lg = Graph.atom_dgl_multigraph( + atoms, cutoff=float(cutoff), max_neighbors=max_neighbors + ) + g_list.append(g) + lg_list.append(lg) + +out_data_list = [] +for g, lg in zip(g_list, lg_list): + out_data = ( + model([g.to(device), lg.to(device)]) + .detach() + .cpu() + .numpy() + .flatten() + .tolist() + ) + out_data_str = str(out_data) + # Extract data within square brackets + match = re.search(r'\[(.*)\]', out_data_str) + if match: + out_data_list.append(match.group(1)) + else: + out_data_list.append('') + +with open(os.path.join(sample_data_folder, output_file), mode='w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['Filename', 'Output']) + for i, file in enumerate(file_list): + writer.writerow([file, out_data_list[i]]) From 34ca917548cd7e51dd74e25dc8e579bcf8023c38 Mon Sep 17 00:00:00 2001 From: Habibur Rahman Date: Thu, 9 Mar 2023 12:21:26 -0500 Subject: [PATCH 5/8] Add files via upload --- move.py | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 move.py diff --git a/move.py b/move.py new file mode 100644 index 00000000..32a5b522 --- /dev/null +++ b/move.py @@ -0,0 +1,51 @@ +import os +import random +import shutil +import pandas as pd + +# set the path to the folder containing the .cif files and csv file +folder_path = "/Users/habibur/Habibur_Python_Scripts/alignn/alignn/DataSet_B" + +# set the path to the new folder where you want to move the selected files +new_folder_path = "/Users/habibur/Habibur_Python_Scripts/alignn/alignn/20%/" + +# set the percentage of files you want to select +percent_to_select = 10 + +# read the csv file into a pandas dataframe +csv_file_path = os.path.join(folder_path, "id_prop.csv") +df = pd.read_csv(csv_file_path, index_col=0) + +# get a list of all the files in the folder with the .cif extension +file_list = [f for f in os.listdir(folder_path) if f.endswith(".cif")] + +# calculate the number of files to select +num_to_select = int(len(file_list) * (percent_to_select / 100)) + +# randomly select the files to move +files_to_move = random.sample(file_list, num_to_select) + +# move the selected files to the new folder +moved_files = [] +for file_name in files_to_move: + file_path = os.path.join(folder_path, file_name) + new_file_path = os.path.join(new_folder_path, file_name) + shutil.move(file_path, new_file_path) + moved_files.append(file_name) + +# create a new dataframe with the moved files and their corresponding values +moved_df = df.loc[moved_files] + +# write the moved files and values to a new csv file +moved_csv_path = os.path.join(new_folder_path, "moved_files.csv") +moved_df.to_csv(moved_csv_path) + + + + + + + + + + From 87514d3ca45bc019ad84d48879faf82960bc8116 Mon Sep 17 00:00:00 2001 From: Habibur Rahman Date: Sat, 18 Mar 2023 09:20:00 -0400 Subject: [PATCH 6/8] Add files via upload --- update_id_prop.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 update_id_prop.py diff --git a/update_id_prop.py b/update_id_prop.py new file mode 100644 index 00000000..843d3f3f --- /dev/null +++ b/update_id_prop.py @@ -0,0 +1,22 @@ +import pandas as pd + +# Set the path to the id_prop.csv file +id_prop_csv_path = "/Users/habibur/Habibur_Python_Scripts/alignn/alignn/id_prop.csv" + +# Set the path to the moved_files.csv file +moved_files_csv_path = "/Users/habibur/Habibur_Python_Scripts/alignn/alignn/moved_files.csv" + +# Read in the id_prop.csv file as a pandas DataFrame +id_prop_df = pd.read_csv(id_prop_csv_path, index_col=0) + +# Read in the moved_files.csv file as a pandas DataFrame +moved_files_df = pd.read_csv(moved_files_csv_path, index_col=0) + +# Get a list of the files that were moved +moved_files = moved_files_df.index.tolist() + +# Drop the rows in id_prop_df that correspond to the moved files +id_prop_df = id_prop_df.drop(moved_files, errors="ignore") + +# Write the updated id_prop.csv file back to disk +id_prop_df.to_csv(id_prop_csv_path) From 97b20a6945dd095c6d08d99cbd0d1da7e19216ea Mon Sep 17 00:00:00 2001 From: Habibur Rahman Date: Sun, 27 Oct 2024 23:01:37 -0400 Subject: [PATCH 7/8] Created using Colab --- examples/fine_tuning.ipynb | 1074 ++++++++++++++++++++++++++++++++++++ 1 file changed, 1074 insertions(+) create mode 100644 examples/fine_tuning.ipynb diff --git a/examples/fine_tuning.ipynb b/examples/fine_tuning.ipynb new file mode 100644 index 00000000..39379fd6 --- /dev/null +++ b/examples/fine_tuning.ipynb @@ -0,0 +1,1074 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "id": "0", + "metadata": { + "id": "0" + }, + "source": [ + "# Fine-tune the pretrained CHGNet for better accuracy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1", + "metadata": { + "id": "1", + "outputId": "dbab2444-2e16-4783-e5ec-26bbb5396286", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting chgnet\n", + " Downloading chgnet-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)\n", + "Collecting ase>=3.23.0 (from chgnet)\n", + " Downloading ase-3.23.0-py3-none-any.whl.metadata (3.8 kB)\n", + "Requirement already satisfied: cython>=3 in /usr/local/lib/python3.10/dist-packages (from chgnet) (3.0.11)\n", + "Requirement already satisfied: numpy>=1.26 in /usr/local/lib/python3.10/dist-packages (from chgnet) (1.26.4)\n", + "Collecting nvidia-ml-py3>=7.352.0 (from chgnet)\n", + " Downloading nvidia-ml-py3-7.352.0.tar.gz (19 kB)\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Collecting pymatgen>=2024.9.10 (from chgnet)\n", + " Downloading pymatgen-2024.10.27-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)\n", + "Requirement already satisfied: torch>=2.4.1 in /usr/local/lib/python3.10/dist-packages (from chgnet) (2.5.0+cu121)\n", + "Requirement already satisfied: typing-extensions>=4.12 in /usr/local/lib/python3.10/dist-packages (from chgnet) (4.12.2)\n", + "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from ase>=3.23.0->chgnet) (1.13.1)\n", + "Requirement already satisfied: matplotlib>=3.3.4 in /usr/local/lib/python3.10/dist-packages (from ase>=3.23.0->chgnet) (3.7.1)\n", + "Requirement already satisfied: joblib>=1 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (1.4.2)\n", + "Collecting matplotlib>=3.3.4 (from ase>=3.23.0->chgnet)\n", + " Downloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n", + "Collecting monty>=2024.7.29 (from pymatgen>=2024.9.10->chgnet)\n", + " Downloading monty-2024.10.21-py3-none-any.whl.metadata (3.6 kB)\n", + "Requirement already satisfied: networkx>=3 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (3.4.2)\n", + "Collecting palettable>=3.3.3 (from pymatgen>=2024.9.10->chgnet)\n", + " Downloading palettable-3.3.3-py2.py3-none-any.whl.metadata (3.3 kB)\n", + "Requirement already satisfied: pandas>=2 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (2.2.2)\n", + "Requirement already satisfied: plotly>=4.5.0 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (5.24.1)\n", + "Collecting pybtex>=0.24.0 (from pymatgen>=2024.9.10->chgnet)\n", + " Downloading pybtex-0.24.0-py2.py3-none-any.whl.metadata (2.0 kB)\n", + "Requirement already satisfied: requests>=2.32 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (2.32.3)\n", + "Collecting ruamel.yaml>=0.17.0 (from pymatgen>=2024.9.10->chgnet)\n", + " Downloading ruamel.yaml-0.18.6-py3-none-any.whl.metadata (23 kB)\n", + "Collecting spglib>=2.5.0 (from pymatgen>=2024.9.10->chgnet)\n", + " Downloading spglib-2.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)\n", + "Requirement already satisfied: sympy>=1.2 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (1.13.1)\n", + "Requirement already satisfied: tabulate>=0.9 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (0.9.0)\n", + "Requirement already satisfied: tqdm>=4.60 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (4.66.5)\n", + "Collecting uncertainties>=3.1.4 (from pymatgen>=2024.9.10->chgnet)\n", + " Downloading uncertainties-3.2.2-py3-none-any.whl.metadata (6.9 kB)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.4.1->chgnet) (3.16.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.4.1->chgnet) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.4.1->chgnet) (2024.6.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy>=1.2->pymatgen>=2024.9.10->chgnet) (1.3.0)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (1.3.0)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (4.54.1)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (1.4.7)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (24.1)\n", + "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (10.4.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (3.2.0)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=2->pymatgen>=2024.9.10->chgnet) (2024.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=2->pymatgen>=2024.9.10->chgnet) (2024.2)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly>=4.5.0->pymatgen>=2024.9.10->chgnet) (9.0.0)\n", + "Requirement already satisfied: PyYAML>=3.01 in /usr/local/lib/python3.10/dist-packages (from pybtex>=0.24.0->pymatgen>=2024.9.10->chgnet) (6.0.2)\n", + "Collecting latexcodec>=1.0.4 (from pybtex>=0.24.0->pymatgen>=2024.9.10->chgnet)\n", + " Downloading latexcodec-3.0.0-py3-none-any.whl.metadata (4.9 kB)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from pybtex>=0.24.0->pymatgen>=2024.9.10->chgnet) (1.16.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32->pymatgen>=2024.9.10->chgnet) (3.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32->pymatgen>=2024.9.10->chgnet) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32->pymatgen>=2024.9.10->chgnet) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32->pymatgen>=2024.9.10->chgnet) (2024.8.30)\n", + "Collecting ruamel.yaml.clib>=0.2.7 (from ruamel.yaml>=0.17.0->pymatgen>=2024.9.10->chgnet)\n", + " Downloading ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.7 kB)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.4.1->chgnet) (3.0.2)\n", + "Downloading chgnet-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.2/9.2 MB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading ase-3.23.0-py3-none-any.whl (2.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.9/2.9 MB\u001b[0m \u001b[31m65.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pymatgen-2024.10.27-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m63.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.3/8.3 MB\u001b[0m \u001b[31m77.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading monty-2024.10.21-py3-none-any.whl (68 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m68.5/68.5 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading palettable-3.3.3-py2.py3-none-any.whl (332 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m332.3/332.3 kB\u001b[0m \u001b[31m25.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pybtex-0.24.0-py2.py3-none-any.whl (561 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m561.4/561.4 kB\u001b[0m \u001b[31m35.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading ruamel.yaml-0.18.6-py3-none-any.whl (117 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m117.8/117.8 kB\u001b[0m \u001b[31m10.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading spglib-2.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m52.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading uncertainties-3.2.2-py3-none-any.whl (58 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading latexcodec-3.0.0-py3-none-any.whl (18 kB)\n", + "Downloading ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (722 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m722.2/722.2 kB\u001b[0m \u001b[31m36.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hBuilding wheels for collected packages: nvidia-ml-py3\n", + " Building wheel for nvidia-ml-py3 (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for nvidia-ml-py3: filename=nvidia_ml_py3-7.352.0-py3-none-any.whl size=19173 sha256=66b17dd83b3e7b77db77b5bea01959b7680f8981bbe285546e23fcf0b71f7d37\n", + " Stored in directory: /root/.cache/pip/wheels/5c/d8/c0/46899f8be7a75a2ffd197a23c8797700ea858b9b34819fbf9e\n", + "Successfully built nvidia-ml-py3\n", + "Installing collected packages: nvidia-ml-py3, uncertainties, spglib, ruamel.yaml.clib, palettable, latexcodec, ruamel.yaml, pybtex, matplotlib, monty, ase, pymatgen, chgnet\n", + " Attempting uninstall: matplotlib\n", + " Found existing installation: matplotlib 3.7.1\n", + " Uninstalling matplotlib-3.7.1:\n", + " Successfully uninstalled matplotlib-3.7.1\n", + "Successfully installed ase-3.23.0 chgnet-0.4.0 latexcodec-3.0.0 matplotlib-3.9.2 monty-2024.10.21 nvidia-ml-py3-7.352.0 palettable-3.3.3 pybtex-0.24.0 pymatgen-2024.10.27 ruamel.yaml-0.18.6 ruamel.yaml.clib-0.2.12 spglib-2.5.0 uncertainties-3.2.2\n" + ] + } + ], + "source": [ + "try:\n", + " from chgnet.model import CHGNet\n", + "except ImportError:\n", + " # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n", + " !pip install chgnet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": { + "id": "2" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from pymatgen.core import Structure\n", + "\n", + "# If the above line fails in Google Colab due to numpy version issue,\n", + "# please restart the runtime, and the problem will be solved" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": { + "id": "3" + }, + "source": [ + "## 0. Parse DFT outputs to CHGNet readable formats\n" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": { + "id": "4" + }, + "source": [ + "CHGNet is interfaced to [Pymatgen](https://pymatgen.org/), the training samples (normally coming from different DFTs like VASP),\n", + "need to be converted to [pymatgen.core.structure](https://pymatgen.org/pymatgen.core.html#module-pymatgen.core.structure).\n", + "\n", + "To convert VASP calculation to pymatgen structures and CHGNet labels, you can use the following [code](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/utils/vasp_utils.py):\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "id": "5" + }, + "outputs": [], + "source": [ + "from chgnet.utils import parse_vasp_dir\n", + "\n", + "# ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.\n", + "dataset_dict = parse_vasp_dir(\n", + " file_root=\"./my_vasp_calc_dir\", save_path=\"./my_vasp_calc_dir/chgnet_dataset.json\"\n", + ")\n", + "print(list(dataset_dict))" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": { + "id": "6" + }, + "source": [ + "The parsed python dictionary includes information for CHGNet inputs (structures), and CHGNet prediction labels (energy, force, stress ,magmom).\n", + "\n", + "we can save the parsed structures and labels to disk, so that they can be easily reloaded during multiple rounds of training.\n", + "\n", + "The json file can be saved by providing the save_path\n" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": { + "id": "7" + }, + "source": [ + "The Pymatgen structures can be saved separately if you're interested to take a look into each structure.\n", + "\n", + "Below are the example codes to save the structures in either json, pickle, cif, or CHGNet graph.\n", + "\n", + "For super-large training dataset, like MPtrj dataset, we recommend [converting them to CHGNet graphs](https://github.com/CederGroupHub/chgnet/blob/main/examples/make_graphs.py). This will save significant memory and graph computing time.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "id": "8" + }, + "outputs": [], + "source": [ + "# Structure to json\n", + "from chgnet.utils import write_json\n", + "\n", + "dict_to_json = [struct.as_dict() for struct in dataset_dict[\"structure\"]]\n", + "write_json(dict_to_json, \"CHGNet_structures.json\")\n", + "\n", + "\n", + "# Structure to pickle\n", + "import pickle\n", + "\n", + "with open(\"CHGNet_structures.p\", \"wb\") as f:\n", + " pickle.dump(dataset_dict, f)\n", + "\n", + "\n", + "# Structure to cif\n", + "for idx, struct in enumerate(dataset_dict[\"structure\"]):\n", + " struct.to(filename=f\"{idx}.cif\")\n", + "\n", + "\n", + "# Structure to CHGNet graph\n", + "from chgnet.graph import CrystalGraphConverter\n", + "\n", + "converter = CrystalGraphConverter()\n", + "for idx, struct in enumerate(dataset_dict[\"structure\"]):\n", + " graph = converter(struct)\n", + " graph.save(fname=f\"{idx}.pt\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": { + "id": "9" + }, + "source": [ + "For other types of DFT calculations, please refer to their interfaces\n", + "in [pymatgen.io](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io).\n", + "\n", + "see: [Quantum Espresso](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io.pwscf)\n", + "\n", + "see: [CP2K](https://pymatgen.org/pymatgen.io.cp2k.html#module-pymatgen.io.cp2k)\n", + "\n", + "see: [Gaussian](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io.gaussian)\n" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": { + "id": "10" + }, + "source": [ + "## 1. Prepare Training Data\n" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": { + "id": "11" + }, + "source": [ + "If you have parsed your VASP labels from step 0, you can reload the saved json file.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": { + "id": "12" + }, + "outputs": [], + "source": [ + "from chgnet.utils import read_json\n", + "\n", + "dataset_dict = read_json(\"./my_vasp_calc_dir/chgnet_dataset.json\")\n", + "structures = [Structure.from_dict(struct) for struct in dataset_dict[\"structure\"]]\n", + "energies = dataset_dict[\"energy_per_atom\"]\n", + "forces = dataset_dict[\"force\"]\n", + "stresses = dataset_dict.get(\"stress\") or None\n", + "magmoms = dataset_dict.get(\"magmom\") or None" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": { + "id": "13" + }, + "source": [ + "If you don't have any DFT calculations now, we can create a dummy fine-tuning dataset by using CHGNet prediction with some random noise.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": { + "id": "14" + }, + "outputs": [], + "source": [ + "try:\n", + " from chgnet import ROOT\n", + "\n", + " lmo = Structure.from_file(f\"{ROOT}/examples/mp-18767-LiMnO2.cif\")\n", + "except Exception:\n", + " from urllib.request import urlopen\n", + "\n", + " url = \"https://raw.githubusercontent.com/CederGroupHub/chgnet/main/examples/mp-18767-LiMnO2.cif\"\n", + " cif = urlopen(url).read().decode(\"utf-8\")\n", + " lmo = Structure.from_str(cif, fmt=\"cif\")\n", + "\n", + "structures, energies_per_atom, forces, stresses, magmoms = [], [], [], [], []\n", + "chgnet = CHGNet.load()\n", + "for _ in range(100):\n", + " structure = lmo.copy()\n", + " # stretch the cell by a small amount\n", + " structure.apply_strain(np.random.uniform(-0.1, 0.1, size=3))\n", + " # perturb all atom positions by a small amount\n", + " structure.perturb(0.1)\n", + "\n", + " pred = chgnet.predict_structure(structure)\n", + "\n", + " structures.append(structure)\n", + " energies_per_atom.append(pred[\"e\"] + np.random.uniform(-0.1, 0.1, size=1))\n", + " forces.append(pred[\"f\"] + np.random.uniform(-0.01, 0.01, size=pred[\"f\"].shape))\n", + " stresses.append(\n", + " pred[\"s\"] * -10 + np.random.uniform(-0.05, 0.05, size=pred[\"s\"].shape)\n", + " )\n", + " magmoms.append(pred[\"m\"] + np.random.uniform(-0.03, 0.03, size=pred[\"m\"].shape))" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": { + "id": "15" + }, + "source": [ + "Note that the stress output from CHGNet is in unit of GPa, here the -10 unit conversion\n", + "modifies it to be kbar in VASP raw unit.\n", + "If you're using stress labels from VASP, you don't need to do any unit conversions\n", + "StructureData dataset class takes in VASP units.\n" + ] + }, + { + "cell_type": "markdown", + "id": "16", + "metadata": { + "id": "16" + }, + "source": [ + "## 2. Define DataSet\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": { + "id": "17" + }, + "outputs": [], + "source": [ + "from chgnet.data.dataset import StructureData, get_train_val_test_loader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": { + "id": "18", + "outputId": "5ae32ebb-a111-4fc0-9d1e-df91bd2bee9b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "100 structures imported\n" + ] + } + ], + "source": [ + "dataset = StructureData(\n", + " structures=structures,\n", + " energies=energies_per_atom,\n", + " forces=forces,\n", + " stresses=stresses, # can be None\n", + " magmoms=magmoms, # can be None\n", + ")\n", + "train_loader, val_loader, test_loader = get_train_val_test_loader(\n", + " dataset, batch_size=8, train_ratio=0.9, val_ratio=0.05\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": { + "id": "19" + }, + "source": [ + "Alternatively, the dataset can be directly created from VASP calculation dir.\n", + "This function essentially parse the VASP directory first, save the labels to json file, and create the StructureData class\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": { + "id": "20" + }, + "outputs": [], + "source": [ + "dataset = StructureData.from_vasp(\n", + " file_root=\"./my_vasp_calc_dir\", save_path=\"./my_vasp_calc_dir/chgnet_dataset.json\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": { + "id": "21" + }, + "source": [ + "The training set is used to optimize the CHGNet through gradient descent, the validation set is used to see validation error at the end of each epoch, and the test set is used to see the final test error at the end of training. The test set can be optional.\n", + "\n", + "The `batch_size` is defined to be 8 for small GPU-memory. If > 10 GB memory is available, we highly recommend to increase `batch_size` for better speed.\n", + "\n", + "If you have very large numbers (>100K) of structures (which is typical for AIMD), putting them all in a python list can quickly run into memory issues. In this case we highly recommend you to pre-convert all the structures into graphs and save them as shown in `examples/make_graphs.py`. Then directly train CHGNet by loading the graphs from disk instead of memory using the `GraphData` class defined in `data/dataset.py`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "22", + "metadata": { + "id": "22" + }, + "source": [ + "## 3. Define model and trainer\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": { + "id": "23", + "outputId": "9134d4fb-a0de-45d3-d268-e856bfad62e2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CHGNet v0.3.0 initialized with 412,525 parameters\n", + "CHGNet will run on cpu\n" + ] + } + ], + "source": [ + "from chgnet.model import CHGNet\n", + "from chgnet.trainer import Trainer\n", + "\n", + "# Load pretrained CHGNet\n", + "chgnet = CHGNet.load()" + ] + }, + { + "cell_type": "markdown", + "id": "24", + "metadata": { + "id": "24" + }, + "source": [ + "It's optional to freeze the weights inside some layers. This is a common technique to retain the learned knowledge during fine-tuning in large pretrained neural networks. You can choose the layers you want to freeze.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": { + "id": "25" + }, + "outputs": [], + "source": [ + "# Optionally fix the weights of some layers\n", + "for layer in [\n", + " chgnet.atom_embedding,\n", + " chgnet.bond_embedding,\n", + " chgnet.angle_embedding,\n", + " chgnet.bond_basis_expansion,\n", + " chgnet.angle_basis_expansion,\n", + " chgnet.atom_conv_layers[:-1],\n", + " chgnet.bond_conv_layers,\n", + " chgnet.angle_layers,\n", + "]:\n", + " for param in layer.parameters():\n", + " param.requires_grad = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": { + "id": "26" + }, + "outputs": [], + "source": [ + "# Define Trainer\n", + "trainer = Trainer(\n", + " model=chgnet,\n", + " targets=\"efsm\",\n", + " optimizer=\"Adam\",\n", + " scheduler=\"CosLR\",\n", + " criterion=\"MSE\",\n", + " epochs=5,\n", + " learning_rate=1e-2,\n", + " use_device=\"cpu\",\n", + " print_freq=6,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "27", + "metadata": { + "id": "27" + }, + "source": [ + "## 4. Start training\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28", + "metadata": { + "id": "28", + "outputId": "cc06272b-897a-49ec-90b5-8521b6523eb3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Begin Training: using cpu device\n", + "training targets: efsm\n", + "Epoch: [0][1/12]\tTime (0.476) Data (0.016) Loss 0.0033 (0.0033) MAEs: e 0.053 (0.053) f 0.004 (0.004) s 0.002 (0.002) m 0.016 (0.016) \n", + "Epoch: [0][6/12]\tTime (0.426) Data (0.015) Loss 0.0040 (0.0039) MAEs: e 0.054 (0.056) f 0.005 (0.005) s 0.002 (0.002) m 0.015 (0.015) \n", + "Epoch: [0][12/12]\tTime (0.414) Data (0.014) Loss 0.0040 (0.0038) MAEs: e 0.054 (0.054) f 0.005 (0.005) s 0.002 (0.002) m 0.015 (0.014) \n", + "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", + "Epoch: [1][1/12]\tTime (0.409) Data (0.000) Loss 0.0052 (0.0052) MAEs: e 0.064 (0.064) f 0.005 (0.005) s 0.002 (0.002) m 0.013 (0.013) \n", + "Epoch: [1][6/12]\tTime (0.393) Data (0.000) Loss 0.0036 (0.0039) MAEs: e 0.053 (0.055) f 0.005 (0.005) s 0.002 (0.002) m 0.014 (0.014) \n", + "Epoch: [1][12/12]\tTime (0.371) Data (0.000) Loss 0.0029 (0.0038) MAEs: e 0.053 (0.054) f 0.005 (0.005) s 0.003 (0.002) m 0.012 (0.014) \n", + "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", + "Epoch: [2][1/12]\tTime (0.389) Data (0.000) Loss 0.0056 (0.0056) MAEs: e 0.065 (0.065) f 0.005 (0.005) s 0.002 (0.002) m 0.015 (0.015) \n", + "Epoch: [2][6/12]\tTime (0.377) Data (0.000) Loss 0.0042 (0.0046) MAEs: e 0.059 (0.062) f 0.005 (0.005) s 0.002 (0.002) m 0.014 (0.014) \n", + "Epoch: [2][12/12]\tTime (0.350) Data (0.000) Loss 0.0025 (0.0038) MAEs: e 0.048 (0.054) f 0.005 (0.005) s 0.002 (0.002) m 0.008 (0.014) \n", + "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", + "Epoch: [3][1/12]\tTime (0.363) Data (0.000) Loss 0.0049 (0.0049) MAEs: e 0.065 (0.065) f 0.005 (0.005) s 0.002 (0.002) m 0.014 (0.014) \n", + "Epoch: [3][6/12]\tTime (0.359) Data (0.000) Loss 0.0050 (0.0042) MAEs: e 0.066 (0.057) f 0.005 (0.005) s 0.003 (0.002) m 0.014 (0.014) \n", + "Epoch: [3][12/12]\tTime (0.355) Data (0.000) Loss 0.0045 (0.0038) MAEs: e 0.059 (0.054) f 0.004 (0.005) s 0.003 (0.002) m 0.012 (0.014) \n", + "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", + "Epoch: [4][1/12]\tTime (0.384) Data (0.000) Loss 0.0033 (0.0033) MAEs: e 0.051 (0.051) f 0.005 (0.005) s 0.003 (0.003) m 0.015 (0.015) \n", + "Epoch: [4][6/12]\tTime (0.384) Data (0.000) Loss 0.0016 (0.0033) MAEs: e 0.035 (0.051) f 0.005 (0.005) s 0.002 (0.002) m 0.012 (0.014) \n", + "Epoch: [4][12/12]\tTime (0.351) Data (0.000) Loss 0.0011 (0.0038) MAEs: e 0.033 (0.054) f 0.004 (0.005) s 0.002 (0.002) m 0.014 (0.014) \n", + "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", + "---------Evaluate Model on Test Set---------------\n", + "** e_MAE (0.056) \tf_MAE (0.005) \ts_MAE (0.003) \tm_MAE (0.015) \t\n" + ] + } + ], + "source": [ + "trainer.train(train_loader, val_loader, test_loader)" + ] + }, + { + "cell_type": "markdown", + "id": "29", + "metadata": { + "id": "29" + }, + "source": [ + "After training, the trained model can be found in the directory of today's date. Or it can be accessed by:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": { + "id": "30" + }, + "outputs": [], + "source": [ + "model = trainer.model\n", + "best_model = trainer.best_model # best model based on validation energy MAE" + ] + }, + { + "cell_type": "markdown", + "id": "31", + "metadata": { + "id": "31" + }, + "source": [ + "## Extras 1: GGA / GGA+U compatibility\n" + ] + }, + { + "cell_type": "markdown", + "id": "32", + "metadata": { + "id": "32" + }, + "source": [ + "### Q: Why and when do you care about this?\n", + "\n", + "**When**: If you want to fine-tune the pretrained CHGNet with your own GGA+U VASP calculations, and you want to keep your VASP energy compatible to the pretrained dataset. In case your dataset is so large that the pretrained knowledge does not matter to you, you can ignore this.\n", + "\n", + "**Why**: CHGNet is trained on both GGA and GGA+U calculations from Materials Project. And there has been developed methods in solving the compatibility between GGA and GGA+U calculations which makes the energies universally applicable for cross-chemistry comparison and phase-diagram constructions. Please refer to:\n", + "\n", + "https://journals.aps.org/prb/abstract/10.1103/PhysRevB.84.045115\n", + "\n", + "Below we show an example to apply the compatibility.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33", + "metadata": { + "id": "33", + "outputId": "e2ab424d-7e76-4735-c793-1105a548b5eb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The raw total energy from VASP of LMO is: -58.97 eV\n" + ] + } + ], + "source": [ + "# Imagine this is the VASP raw energy\n", + "vasp_raw_energy = -58.97\n", + "\n", + "print(f\"The raw total energy from VASP of LMO is: {vasp_raw_energy} eV\")" + ] + }, + { + "cell_type": "markdown", + "id": "34", + "metadata": { + "id": "34" + }, + "source": [ + "You can look for the energy correction applied to each element in :\n", + "\n", + "https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/entries/MP2020Compatibility.yaml\n", + "\n", + "Here LiMnO2 applies to both Mn in transition metal oxides correction and oxide correction.\n" + ] + }, + { + "cell_type": "markdown", + "id": "35", + "metadata": { + "id": "35" + }, + "source": [ + "To demystify `MaterialsProject2020Compatibility`, basically all that's happening is:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", + "metadata": { + "id": "36", + "outputId": "329e0ad2-10a2-4571-9b3d-ff0edf4afab9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The corrected total energy after MP2020 = -65.05 eV\n" + ] + } + ], + "source": [ + "Mn_correction_in_TMO = -1.668\n", + "oxide_correction = -0.687\n", + "_, num_Mn, num_O = lmo.composition.values()\n", + "\n", + "\n", + "corrected_energy = (\n", + " vasp_raw_energy + num_Mn * Mn_correction_in_TMO + num_O * oxide_correction\n", + ")\n", + "print(f\"The corrected total energy after MP2020 = {corrected_energy:.4} eV\")" + ] + }, + { + "cell_type": "markdown", + "id": "37", + "metadata": { + "id": "37" + }, + "source": [ + "You can also apply the `MaterialsProject2020Compatibility` through pymatgen\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": { + "id": "38", + "outputId": "7e4dc685-ce0f-4da3-a9f7-1c0cb2c3fdf4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The total energy of LMO after MP2020Compatibility correction = -62.31 eV\n" + ] + } + ], + "source": [ + "from pymatgen.entries.compatibility import MaterialsProject2020Compatibility\n", + "from pymatgen.entries.computed_entries import ComputedStructureEntry\n", + "\n", + "params = {\"hubbards\": {\"Mn\": 3.9, \"O\": 0, \"Li\": 0}, \"run_type\": \"GGA+U\"}\n", + "\n", + "cse = ComputedStructureEntry(lmo, vasp_raw_energy, parameters=params)\n", + "\n", + "MaterialsProject2020Compatibility(check_potcar=False).process_entries(cse)\n", + "print(\n", + " f\"The total energy of LMO after MP2020Compatibility correction = {cse.energy:.4} eV\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "39", + "metadata": { + "id": "39" + }, + "source": [ + "Now use this corrected energy as labels to tune CHGNet, you're good to go!\n" + ] + }, + { + "cell_type": "markdown", + "id": "40", + "metadata": { + "id": "40" + }, + "source": [ + "## Extras 2: AtomRef\n" + ] + }, + { + "cell_type": "markdown", + "id": "41", + "metadata": { + "id": "41" + }, + "source": [ + "### Q: Why and when do you care about this?\n", + "\n", + "**When**: When you fine tune CHGNet to DFT labels that are incompatible with Materials Project, like r2SCAN functional, or other DFTs like Gaussian or QE. The large shifts in elemental energy is not of our interest and should be reconciled. For example, Li has -0.95 eV/atom in GGA (https://next-gen.materialsproject.org/materials/mp-135/tasks/mp-990455) and -1.17 eV/atom in R2SCAN (https://next-gen.materialsproject.org/materials/mp-135/tasks/mp-1943895)\n", + "\n", + "**Why**: The GNN learns the interaction between the atoms and the composition model (AtomRef) in CHGNet is used to normalize the elemental energy contribution, similar to a formation-energy-like calculation. During fine-tuning, we want to keep the most of knowledge unchanged in the GNN and allow the AtomRef to shift for the elemental energy change. So that the finetuning on the graph layers can be focused on energy contribution from atom-atom interaction instead of meaningless atom reference energies.\n", + "\n", + "Below I will show an example to fit the AtomRef layer:\n" + ] + }, + { + "cell_type": "markdown", + "id": "42", + "metadata": { + "id": "42" + }, + "source": [ + "### A quick and easy way to turn on training of AtomRef in the trainer (this is by default off):\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43", + "metadata": { + "id": "43", + "outputId": "7c84164d-9e80-46e3-beec-513850a04e3e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Begin Training: using cpu device\n", + "training targets: efsm\n", + "Epoch: [0][1/12]\tTime (0.475) Data (0.001) Loss 0.0028 (0.0028) MAEs: e 0.047 (0.047) f 0.005 (0.005) s 0.003 (0.003) m 0.014 (0.014) \n", + "Epoch: [0][6/12]\tTime (0.379) Data (0.000) Loss 0.0027 (0.0037) MAEs: e 0.046 (0.053) f 0.005 (0.005) s 0.002 (0.002) m 0.015 (0.014) \n", + "Epoch: [0][12/12]\tTime (0.359) Data (0.000) Loss 0.0010 (0.0038) MAEs: e 0.030 (0.054) f 0.005 (0.005) s 0.003 (0.002) m 0.012 (0.014) \n", + "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", + "Epoch: [1][1/12]\tTime (0.417) Data (0.000) Loss 0.0011 (0.0011) MAEs: e 0.027 (0.027) f 0.004 (0.004) s 0.002 (0.002) m 0.015 (0.015) \n", + "Epoch: [1][6/12]\tTime (0.359) Data (0.000) Loss 0.0049 (0.0040) MAEs: e 0.062 (0.056) f 0.005 (0.005) s 0.003 (0.002) m 0.015 (0.015) \n", + "Epoch: [1][12/12]\tTime (0.351) Data (0.000) Loss 0.0054 (0.0038) MAEs: e 0.073 (0.054) f 0.004 (0.005) s 0.002 (0.002) m 0.013 (0.014) \n", + "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", + "Epoch: [2][1/12]\tTime (0.368) Data (0.000) Loss 0.0027 (0.0027) MAEs: e 0.043 (0.043) f 0.005 (0.005) s 0.003 (0.003) m 0.016 (0.016) \n", + "Epoch: [2][6/12]\tTime (0.388) Data (0.000) Loss 0.0042 (0.0034) MAEs: e 0.056 (0.051) f 0.005 (0.005) s 0.003 (0.003) m 0.014 (0.015) \n", + "Epoch: [2][12/12]\tTime (0.354) Data (0.000) Loss 0.0033 (0.0038) MAEs: e 0.054 (0.054) f 0.004 (0.005) s 0.003 (0.002) m 0.013 (0.014) \n", + "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", + "Epoch: [3][1/12]\tTime (0.351) Data (0.000) Loss 0.0032 (0.0032) MAEs: e 0.048 (0.048) f 0.005 (0.005) s 0.003 (0.003) m 0.014 (0.014) \n", + "Epoch: [3][6/12]\tTime (0.371) Data (0.000) Loss 0.0046 (0.0035) MAEs: e 0.064 (0.052) f 0.005 (0.005) s 0.002 (0.003) m 0.016 (0.014) \n", + "Epoch: [3][12/12]\tTime (0.351) Data (0.000) Loss 0.0088 (0.0038) MAEs: e 0.093 (0.054) f 0.005 (0.005) s 0.002 (0.002) m 0.016 (0.014) \n", + "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", + "Epoch: [4][1/12]\tTime (0.376) Data (0.000) Loss 0.0048 (0.0048) MAEs: e 0.066 (0.066) f 0.005 (0.005) s 0.002 (0.002) m 0.013 (0.013) \n", + "Epoch: [4][6/12]\tTime (0.375) Data (0.000) Loss 0.0017 (0.0036) MAEs: e 0.030 (0.053) f 0.005 (0.005) s 0.003 (0.002) m 0.016 (0.014) \n", + "Epoch: [4][12/12]\tTime (0.351) Data (0.000) Loss 0.0006 (0.0038) MAEs: e 0.020 (0.054) f 0.005 (0.005) s 0.003 (0.002) m 0.013 (0.014) \n", + "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", + "---------Evaluate Model on Test Set---------------\n", + "** e_MAE (0.056) \tf_MAE (0.005) \ts_MAE (0.003) \tm_MAE (0.015) \t\n" + ] + } + ], + "source": [ + "trainer.train(train_loader, val_loader, test_loader, train_composition_model=True)" + ] + }, + { + "cell_type": "markdown", + "id": "44", + "metadata": { + "id": "44" + }, + "source": [ + "### The more regorous way is to solve for the per-atom contribution by linear regression in your fine-tuning dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45", + "metadata": { + "id": "45", + "outputId": "12947804-631a-46e5-b356-712beecd6639" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The pretrained Atom_Ref (per atom reference energy):\n", + "Parameter containing:\n", + "tensor([[ -3.4431, -0.1279, -2.8300, -3.4737, -7.4946, -8.2354, -8.1611,\n", + " -8.3861, -5.7498, -0.0236, -1.7406, -1.6788, -4.2833, -6.2002,\n", + " -6.1315, -5.8405, -3.8795, -0.0703, -1.5668, -3.4451, -7.0549,\n", + " -9.1465, -9.2594, -9.3514, -8.9843, -8.0228, -6.4955, -5.6057,\n", + " -3.4002, -0.9217, -3.2499, -4.9164, -4.7810, -5.0191, -3.3316,\n", + " 0.5130, -1.4043, -3.2175, -7.4994, -9.3816, -10.4386, -9.9539,\n", + " -7.9555, -8.5440, -7.3245, -5.2771, -1.9014, -0.4034, -2.6002,\n", + " -4.0054, -4.1156, -3.9928, -2.7003, 2.2170, -1.9671, -3.7180,\n", + " -6.8133, -7.3502, -6.0712, -6.1699, -5.1471, -6.1925, -11.5829,\n", + " -15.8841, -5.9994, -6.0798, -5.9513, -6.0400, -5.9773, -2.5091,\n", + " -6.0767, -10.6666, -11.8761, -11.8491, -10.7397, -9.6100, -8.4755,\n", + " -6.2070, -3.0337, 0.4726, -1.6425, -3.1295, -3.3328, -0.1221,\n", + " -0.3448, -0.4364, -0.1661, -0.3680, -4.1869, -8.4233, -10.0467,\n", + " -12.0953, -12.5228, -14.2530]], requires_grad=True)\n" + ] + } + ], + "source": [ + "print(\"The pretrained Atom_Ref (per atom reference energy):\")\n", + "for param in chgnet.composition_model.parameters():\n", + " print(param)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46", + "metadata": { + "id": "46" + }, + "outputs": [], + "source": [ + "# A list of structures / graphs\n", + "structures = [\n", + " lmo,\n", + " Structure(\n", + " species=[\"Li\", \"Mn\", \"Mn\", \"O\", \"O\", \"O\"],\n", + " lattice=np.random.rand(3, 3),\n", + " coords=np.random.rand(6, 3),\n", + " ),\n", + " Structure(\n", + " species=[\"Li\", \"Li\", \"Mn\", \"O\", \"O\", \"O\"],\n", + " lattice=np.random.rand(3, 3),\n", + " coords=np.random.rand(6, 3),\n", + " ),\n", + " Structure(\n", + " species=[\"Li\", \"Mn\", \"Mn\", \"O\", \"O\", \"O\", \"O\"],\n", + " lattice=np.random.rand(3, 3),\n", + " coords=np.random.rand(7, 3),\n", + " ),\n", + "]\n", + "\n", + "# A list of energy_per_atom values (random values here)\n", + "energies_per_atom = [5.5, 6, 4.8, 5.6]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47", + "metadata": { + "id": "47", + "outputId": "44bcff3a-a208-45ce-e6c0-5d35e68ca6bf" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "We initialize another identical AtomRef layers\n", + "tensor([[-3.4431, -0.1279, -2.8300]], grad_fn=)\n" + ] + } + ], + "source": [ + "from chgnet.model.composition_model import AtomRef\n", + "\n", + "print(\"We initialize another identical AtomRef layers\")\n", + "new_atom_ref = AtomRef(is_intensive=True)\n", + "new_atom_ref.initialize_from_MPtrj()\n", + "for param in new_atom_ref.parameters():\n", + " print(param[:, :3])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48", + "metadata": { + "id": "48", + "outputId": "7af4aa6d-c839-4530-b063-6b852205f5fd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "After refitting, the AtomRef looks like:\n", + "Parameter containing:\n", + "tensor([[ 0.0000e+00, 0.0000e+00, 4.2667e+00, -3.3299e-15, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 2.9999e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1467e+01,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],\n", + " requires_grad=True)\n" + ] + } + ], + "source": [ + "# Solve linear regression to find the per atom contribution in your fine-tuning dataset\n", + "\n", + "new_atom_ref.fit(structures, energies_per_atom)\n", + "print(\"After refitting, the AtomRef looks like:\")\n", + "for param in new_atom_ref.parameters():\n", + " print(param)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + }, + "colab": { + "provenance": [], + "gpuType": "T4", + "include_colab_link": true + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file From 4332d174b78d44fd1a360a93b94852bfdd709b2f Mon Sep 17 00:00:00 2001 From: Habibur Rahman Date: Sun, 27 Oct 2024 23:02:10 -0400 Subject: [PATCH 8/8] Delete examples directory --- examples/fine_tuning.ipynb | 1074 ------------------------------------ 1 file changed, 1074 deletions(-) delete mode 100644 examples/fine_tuning.ipynb diff --git a/examples/fine_tuning.ipynb b/examples/fine_tuning.ipynb deleted file mode 100644 index 39379fd6..00000000 --- a/examples/fine_tuning.ipynb +++ /dev/null @@ -1,1074 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "id": "0", - "metadata": { - "id": "0" - }, - "source": [ - "# Fine-tune the pretrained CHGNet for better accuracy\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "1", - "metadata": { - "id": "1", - "outputId": "dbab2444-2e16-4783-e5ec-26bbb5396286", - "colab": { - "base_uri": "https://localhost:8080/" - } - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting chgnet\n", - " Downloading chgnet-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)\n", - "Collecting ase>=3.23.0 (from chgnet)\n", - " Downloading ase-3.23.0-py3-none-any.whl.metadata (3.8 kB)\n", - "Requirement already satisfied: cython>=3 in /usr/local/lib/python3.10/dist-packages (from chgnet) (3.0.11)\n", - "Requirement already satisfied: numpy>=1.26 in /usr/local/lib/python3.10/dist-packages (from chgnet) (1.26.4)\n", - "Collecting nvidia-ml-py3>=7.352.0 (from chgnet)\n", - " Downloading nvidia-ml-py3-7.352.0.tar.gz (19 kB)\n", - " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "Collecting pymatgen>=2024.9.10 (from chgnet)\n", - " Downloading pymatgen-2024.10.27-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)\n", - "Requirement already satisfied: torch>=2.4.1 in /usr/local/lib/python3.10/dist-packages (from chgnet) (2.5.0+cu121)\n", - "Requirement already satisfied: typing-extensions>=4.12 in /usr/local/lib/python3.10/dist-packages (from chgnet) (4.12.2)\n", - "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from ase>=3.23.0->chgnet) (1.13.1)\n", - "Requirement already satisfied: matplotlib>=3.3.4 in /usr/local/lib/python3.10/dist-packages (from ase>=3.23.0->chgnet) (3.7.1)\n", - "Requirement already satisfied: joblib>=1 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (1.4.2)\n", - "Collecting matplotlib>=3.3.4 (from ase>=3.23.0->chgnet)\n", - " Downloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n", - "Collecting monty>=2024.7.29 (from pymatgen>=2024.9.10->chgnet)\n", - " Downloading monty-2024.10.21-py3-none-any.whl.metadata (3.6 kB)\n", - "Requirement already satisfied: networkx>=3 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (3.4.2)\n", - "Collecting palettable>=3.3.3 (from pymatgen>=2024.9.10->chgnet)\n", - " Downloading palettable-3.3.3-py2.py3-none-any.whl.metadata (3.3 kB)\n", - "Requirement already satisfied: pandas>=2 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (2.2.2)\n", - "Requirement already satisfied: plotly>=4.5.0 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (5.24.1)\n", - "Collecting pybtex>=0.24.0 (from pymatgen>=2024.9.10->chgnet)\n", - " Downloading pybtex-0.24.0-py2.py3-none-any.whl.metadata (2.0 kB)\n", - "Requirement already satisfied: requests>=2.32 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (2.32.3)\n", - "Collecting ruamel.yaml>=0.17.0 (from pymatgen>=2024.9.10->chgnet)\n", - " Downloading ruamel.yaml-0.18.6-py3-none-any.whl.metadata (23 kB)\n", - "Collecting spglib>=2.5.0 (from pymatgen>=2024.9.10->chgnet)\n", - " Downloading spglib-2.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)\n", - "Requirement already satisfied: sympy>=1.2 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (1.13.1)\n", - "Requirement already satisfied: tabulate>=0.9 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (0.9.0)\n", - "Requirement already satisfied: tqdm>=4.60 in /usr/local/lib/python3.10/dist-packages (from pymatgen>=2024.9.10->chgnet) (4.66.5)\n", - "Collecting uncertainties>=3.1.4 (from pymatgen>=2024.9.10->chgnet)\n", - " Downloading uncertainties-3.2.2-py3-none-any.whl.metadata (6.9 kB)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.4.1->chgnet) (3.16.1)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.4.1->chgnet) (3.1.4)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.4.1->chgnet) (2024.6.1)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy>=1.2->pymatgen>=2024.9.10->chgnet) (1.3.0)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (1.3.0)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (4.54.1)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (1.4.7)\n", - "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (24.1)\n", - "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (10.4.0)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (3.2.0)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3.4->ase>=3.23.0->chgnet) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=2->pymatgen>=2024.9.10->chgnet) (2024.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=2->pymatgen>=2024.9.10->chgnet) (2024.2)\n", - "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly>=4.5.0->pymatgen>=2024.9.10->chgnet) (9.0.0)\n", - "Requirement already satisfied: PyYAML>=3.01 in /usr/local/lib/python3.10/dist-packages (from pybtex>=0.24.0->pymatgen>=2024.9.10->chgnet) (6.0.2)\n", - "Collecting latexcodec>=1.0.4 (from pybtex>=0.24.0->pymatgen>=2024.9.10->chgnet)\n", - " Downloading latexcodec-3.0.0-py3-none-any.whl.metadata (4.9 kB)\n", - "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from pybtex>=0.24.0->pymatgen>=2024.9.10->chgnet) (1.16.0)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32->pymatgen>=2024.9.10->chgnet) (3.4.0)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32->pymatgen>=2024.9.10->chgnet) (3.10)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32->pymatgen>=2024.9.10->chgnet) (2.2.3)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32->pymatgen>=2024.9.10->chgnet) (2024.8.30)\n", - "Collecting ruamel.yaml.clib>=0.2.7 (from ruamel.yaml>=0.17.0->pymatgen>=2024.9.10->chgnet)\n", - " Downloading ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.7 kB)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.4.1->chgnet) (3.0.2)\n", - "Downloading chgnet-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.2 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.2/9.2 MB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading ase-3.23.0-py3-none-any.whl (2.9 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.9/2.9 MB\u001b[0m \u001b[31m65.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading pymatgen-2024.10.27-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.9 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m63.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.3/8.3 MB\u001b[0m \u001b[31m77.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading monty-2024.10.21-py3-none-any.whl (68 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m68.5/68.5 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading palettable-3.3.3-py2.py3-none-any.whl (332 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m332.3/332.3 kB\u001b[0m \u001b[31m25.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading pybtex-0.24.0-py2.py3-none-any.whl (561 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m561.4/561.4 kB\u001b[0m \u001b[31m35.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading ruamel.yaml-0.18.6-py3-none-any.whl (117 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m117.8/117.8 kB\u001b[0m \u001b[31m10.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading spglib-2.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m52.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading uncertainties-3.2.2-py3-none-any.whl (58 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading latexcodec-3.0.0-py3-none-any.whl (18 kB)\n", - "Downloading ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (722 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m722.2/722.2 kB\u001b[0m \u001b[31m36.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hBuilding wheels for collected packages: nvidia-ml-py3\n", - " Building wheel for nvidia-ml-py3 (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for nvidia-ml-py3: filename=nvidia_ml_py3-7.352.0-py3-none-any.whl size=19173 sha256=66b17dd83b3e7b77db77b5bea01959b7680f8981bbe285546e23fcf0b71f7d37\n", - " Stored in directory: /root/.cache/pip/wheels/5c/d8/c0/46899f8be7a75a2ffd197a23c8797700ea858b9b34819fbf9e\n", - "Successfully built nvidia-ml-py3\n", - "Installing collected packages: nvidia-ml-py3, uncertainties, spglib, ruamel.yaml.clib, palettable, latexcodec, ruamel.yaml, pybtex, matplotlib, monty, ase, pymatgen, chgnet\n", - " Attempting uninstall: matplotlib\n", - " Found existing installation: matplotlib 3.7.1\n", - " Uninstalling matplotlib-3.7.1:\n", - " Successfully uninstalled matplotlib-3.7.1\n", - "Successfully installed ase-3.23.0 chgnet-0.4.0 latexcodec-3.0.0 matplotlib-3.9.2 monty-2024.10.21 nvidia-ml-py3-7.352.0 palettable-3.3.3 pybtex-0.24.0 pymatgen-2024.10.27 ruamel.yaml-0.18.6 ruamel.yaml.clib-0.2.12 spglib-2.5.0 uncertainties-3.2.2\n" - ] - } - ], - "source": [ - "try:\n", - " from chgnet.model import CHGNet\n", - "except ImportError:\n", - " # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n", - " !pip install chgnet" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2", - "metadata": { - "id": "2" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "from pymatgen.core import Structure\n", - "\n", - "# If the above line fails in Google Colab due to numpy version issue,\n", - "# please restart the runtime, and the problem will be solved" - ] - }, - { - "cell_type": "markdown", - "id": "3", - "metadata": { - "id": "3" - }, - "source": [ - "## 0. Parse DFT outputs to CHGNet readable formats\n" - ] - }, - { - "cell_type": "markdown", - "id": "4", - "metadata": { - "id": "4" - }, - "source": [ - "CHGNet is interfaced to [Pymatgen](https://pymatgen.org/), the training samples (normally coming from different DFTs like VASP),\n", - "need to be converted to [pymatgen.core.structure](https://pymatgen.org/pymatgen.core.html#module-pymatgen.core.structure).\n", - "\n", - "To convert VASP calculation to pymatgen structures and CHGNet labels, you can use the following [code](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/utils/vasp_utils.py):\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "5" - }, - "outputs": [], - "source": [ - "from chgnet.utils import parse_vasp_dir\n", - "\n", - "# ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.\n", - "dataset_dict = parse_vasp_dir(\n", - " file_root=\"./my_vasp_calc_dir\", save_path=\"./my_vasp_calc_dir/chgnet_dataset.json\"\n", - ")\n", - "print(list(dataset_dict))" - ] - }, - { - "cell_type": "markdown", - "id": "6", - "metadata": { - "id": "6" - }, - "source": [ - "The parsed python dictionary includes information for CHGNet inputs (structures), and CHGNet prediction labels (energy, force, stress ,magmom).\n", - "\n", - "we can save the parsed structures and labels to disk, so that they can be easily reloaded during multiple rounds of training.\n", - "\n", - "The json file can be saved by providing the save_path\n" - ] - }, - { - "cell_type": "markdown", - "id": "7", - "metadata": { - "id": "7" - }, - "source": [ - "The Pymatgen structures can be saved separately if you're interested to take a look into each structure.\n", - "\n", - "Below are the example codes to save the structures in either json, pickle, cif, or CHGNet graph.\n", - "\n", - "For super-large training dataset, like MPtrj dataset, we recommend [converting them to CHGNet graphs](https://github.com/CederGroupHub/chgnet/blob/main/examples/make_graphs.py). This will save significant memory and graph computing time.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "8" - }, - "outputs": [], - "source": [ - "# Structure to json\n", - "from chgnet.utils import write_json\n", - "\n", - "dict_to_json = [struct.as_dict() for struct in dataset_dict[\"structure\"]]\n", - "write_json(dict_to_json, \"CHGNet_structures.json\")\n", - "\n", - "\n", - "# Structure to pickle\n", - "import pickle\n", - "\n", - "with open(\"CHGNet_structures.p\", \"wb\") as f:\n", - " pickle.dump(dataset_dict, f)\n", - "\n", - "\n", - "# Structure to cif\n", - "for idx, struct in enumerate(dataset_dict[\"structure\"]):\n", - " struct.to(filename=f\"{idx}.cif\")\n", - "\n", - "\n", - "# Structure to CHGNet graph\n", - "from chgnet.graph import CrystalGraphConverter\n", - "\n", - "converter = CrystalGraphConverter()\n", - "for idx, struct in enumerate(dataset_dict[\"structure\"]):\n", - " graph = converter(struct)\n", - " graph.save(fname=f\"{idx}.pt\")" - ] - }, - { - "cell_type": "markdown", - "id": "9", - "metadata": { - "id": "9" - }, - "source": [ - "For other types of DFT calculations, please refer to their interfaces\n", - "in [pymatgen.io](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io).\n", - "\n", - "see: [Quantum Espresso](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io.pwscf)\n", - "\n", - "see: [CP2K](https://pymatgen.org/pymatgen.io.cp2k.html#module-pymatgen.io.cp2k)\n", - "\n", - "see: [Gaussian](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io.gaussian)\n" - ] - }, - { - "cell_type": "markdown", - "id": "10", - "metadata": { - "id": "10" - }, - "source": [ - "## 1. Prepare Training Data\n" - ] - }, - { - "cell_type": "markdown", - "id": "11", - "metadata": { - "id": "11" - }, - "source": [ - "If you have parsed your VASP labels from step 0, you can reload the saved json file.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12", - "metadata": { - "id": "12" - }, - "outputs": [], - "source": [ - "from chgnet.utils import read_json\n", - "\n", - "dataset_dict = read_json(\"./my_vasp_calc_dir/chgnet_dataset.json\")\n", - "structures = [Structure.from_dict(struct) for struct in dataset_dict[\"structure\"]]\n", - "energies = dataset_dict[\"energy_per_atom\"]\n", - "forces = dataset_dict[\"force\"]\n", - "stresses = dataset_dict.get(\"stress\") or None\n", - "magmoms = dataset_dict.get(\"magmom\") or None" - ] - }, - { - "cell_type": "markdown", - "id": "13", - "metadata": { - "id": "13" - }, - "source": [ - "If you don't have any DFT calculations now, we can create a dummy fine-tuning dataset by using CHGNet prediction with some random noise.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "14", - "metadata": { - "id": "14" - }, - "outputs": [], - "source": [ - "try:\n", - " from chgnet import ROOT\n", - "\n", - " lmo = Structure.from_file(f\"{ROOT}/examples/mp-18767-LiMnO2.cif\")\n", - "except Exception:\n", - " from urllib.request import urlopen\n", - "\n", - " url = \"https://raw.githubusercontent.com/CederGroupHub/chgnet/main/examples/mp-18767-LiMnO2.cif\"\n", - " cif = urlopen(url).read().decode(\"utf-8\")\n", - " lmo = Structure.from_str(cif, fmt=\"cif\")\n", - "\n", - "structures, energies_per_atom, forces, stresses, magmoms = [], [], [], [], []\n", - "chgnet = CHGNet.load()\n", - "for _ in range(100):\n", - " structure = lmo.copy()\n", - " # stretch the cell by a small amount\n", - " structure.apply_strain(np.random.uniform(-0.1, 0.1, size=3))\n", - " # perturb all atom positions by a small amount\n", - " structure.perturb(0.1)\n", - "\n", - " pred = chgnet.predict_structure(structure)\n", - "\n", - " structures.append(structure)\n", - " energies_per_atom.append(pred[\"e\"] + np.random.uniform(-0.1, 0.1, size=1))\n", - " forces.append(pred[\"f\"] + np.random.uniform(-0.01, 0.01, size=pred[\"f\"].shape))\n", - " stresses.append(\n", - " pred[\"s\"] * -10 + np.random.uniform(-0.05, 0.05, size=pred[\"s\"].shape)\n", - " )\n", - " magmoms.append(pred[\"m\"] + np.random.uniform(-0.03, 0.03, size=pred[\"m\"].shape))" - ] - }, - { - "cell_type": "markdown", - "id": "15", - "metadata": { - "id": "15" - }, - "source": [ - "Note that the stress output from CHGNet is in unit of GPa, here the -10 unit conversion\n", - "modifies it to be kbar in VASP raw unit.\n", - "If you're using stress labels from VASP, you don't need to do any unit conversions\n", - "StructureData dataset class takes in VASP units.\n" - ] - }, - { - "cell_type": "markdown", - "id": "16", - "metadata": { - "id": "16" - }, - "source": [ - "## 2. Define DataSet\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": { - "id": "17" - }, - "outputs": [], - "source": [ - "from chgnet.data.dataset import StructureData, get_train_val_test_loader" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18", - "metadata": { - "id": "18", - "outputId": "5ae32ebb-a111-4fc0-9d1e-df91bd2bee9b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "100 structures imported\n" - ] - } - ], - "source": [ - "dataset = StructureData(\n", - " structures=structures,\n", - " energies=energies_per_atom,\n", - " forces=forces,\n", - " stresses=stresses, # can be None\n", - " magmoms=magmoms, # can be None\n", - ")\n", - "train_loader, val_loader, test_loader = get_train_val_test_loader(\n", - " dataset, batch_size=8, train_ratio=0.9, val_ratio=0.05\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "19", - "metadata": { - "id": "19" - }, - "source": [ - "Alternatively, the dataset can be directly created from VASP calculation dir.\n", - "This function essentially parse the VASP directory first, save the labels to json file, and create the StructureData class\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20", - "metadata": { - "id": "20" - }, - "outputs": [], - "source": [ - "dataset = StructureData.from_vasp(\n", - " file_root=\"./my_vasp_calc_dir\", save_path=\"./my_vasp_calc_dir/chgnet_dataset.json\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "21", - "metadata": { - "id": "21" - }, - "source": [ - "The training set is used to optimize the CHGNet through gradient descent, the validation set is used to see validation error at the end of each epoch, and the test set is used to see the final test error at the end of training. The test set can be optional.\n", - "\n", - "The `batch_size` is defined to be 8 for small GPU-memory. If > 10 GB memory is available, we highly recommend to increase `batch_size` for better speed.\n", - "\n", - "If you have very large numbers (>100K) of structures (which is typical for AIMD), putting them all in a python list can quickly run into memory issues. In this case we highly recommend you to pre-convert all the structures into graphs and save them as shown in `examples/make_graphs.py`. Then directly train CHGNet by loading the graphs from disk instead of memory using the `GraphData` class defined in `data/dataset.py`.\n" - ] - }, - { - "cell_type": "markdown", - "id": "22", - "metadata": { - "id": "22" - }, - "source": [ - "## 3. Define model and trainer\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "23", - "metadata": { - "id": "23", - "outputId": "9134d4fb-a0de-45d3-d268-e856bfad62e2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CHGNet v0.3.0 initialized with 412,525 parameters\n", - "CHGNet will run on cpu\n" - ] - } - ], - "source": [ - "from chgnet.model import CHGNet\n", - "from chgnet.trainer import Trainer\n", - "\n", - "# Load pretrained CHGNet\n", - "chgnet = CHGNet.load()" - ] - }, - { - "cell_type": "markdown", - "id": "24", - "metadata": { - "id": "24" - }, - "source": [ - "It's optional to freeze the weights inside some layers. This is a common technique to retain the learned knowledge during fine-tuning in large pretrained neural networks. You can choose the layers you want to freeze.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25", - "metadata": { - "id": "25" - }, - "outputs": [], - "source": [ - "# Optionally fix the weights of some layers\n", - "for layer in [\n", - " chgnet.atom_embedding,\n", - " chgnet.bond_embedding,\n", - " chgnet.angle_embedding,\n", - " chgnet.bond_basis_expansion,\n", - " chgnet.angle_basis_expansion,\n", - " chgnet.atom_conv_layers[:-1],\n", - " chgnet.bond_conv_layers,\n", - " chgnet.angle_layers,\n", - "]:\n", - " for param in layer.parameters():\n", - " param.requires_grad = False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26", - "metadata": { - "id": "26" - }, - "outputs": [], - "source": [ - "# Define Trainer\n", - "trainer = Trainer(\n", - " model=chgnet,\n", - " targets=\"efsm\",\n", - " optimizer=\"Adam\",\n", - " scheduler=\"CosLR\",\n", - " criterion=\"MSE\",\n", - " epochs=5,\n", - " learning_rate=1e-2,\n", - " use_device=\"cpu\",\n", - " print_freq=6,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "27", - "metadata": { - "id": "27" - }, - "source": [ - "## 4. Start training\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28", - "metadata": { - "id": "28", - "outputId": "cc06272b-897a-49ec-90b5-8521b6523eb3" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Begin Training: using cpu device\n", - "training targets: efsm\n", - "Epoch: [0][1/12]\tTime (0.476) Data (0.016) Loss 0.0033 (0.0033) MAEs: e 0.053 (0.053) f 0.004 (0.004) s 0.002 (0.002) m 0.016 (0.016) \n", - "Epoch: [0][6/12]\tTime (0.426) Data (0.015) Loss 0.0040 (0.0039) MAEs: e 0.054 (0.056) f 0.005 (0.005) s 0.002 (0.002) m 0.015 (0.015) \n", - "Epoch: [0][12/12]\tTime (0.414) Data (0.014) Loss 0.0040 (0.0038) MAEs: e 0.054 (0.054) f 0.005 (0.005) s 0.002 (0.002) m 0.015 (0.014) \n", - "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", - "Epoch: [1][1/12]\tTime (0.409) Data (0.000) Loss 0.0052 (0.0052) MAEs: e 0.064 (0.064) f 0.005 (0.005) s 0.002 (0.002) m 0.013 (0.013) \n", - "Epoch: [1][6/12]\tTime (0.393) Data (0.000) Loss 0.0036 (0.0039) MAEs: e 0.053 (0.055) f 0.005 (0.005) s 0.002 (0.002) m 0.014 (0.014) \n", - "Epoch: [1][12/12]\tTime (0.371) Data (0.000) Loss 0.0029 (0.0038) MAEs: e 0.053 (0.054) f 0.005 (0.005) s 0.003 (0.002) m 0.012 (0.014) \n", - "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", - "Epoch: [2][1/12]\tTime (0.389) Data (0.000) Loss 0.0056 (0.0056) MAEs: e 0.065 (0.065) f 0.005 (0.005) s 0.002 (0.002) m 0.015 (0.015) \n", - "Epoch: [2][6/12]\tTime (0.377) Data (0.000) Loss 0.0042 (0.0046) MAEs: e 0.059 (0.062) f 0.005 (0.005) s 0.002 (0.002) m 0.014 (0.014) \n", - "Epoch: [2][12/12]\tTime (0.350) Data (0.000) Loss 0.0025 (0.0038) MAEs: e 0.048 (0.054) f 0.005 (0.005) s 0.002 (0.002) m 0.008 (0.014) \n", - "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", - "Epoch: [3][1/12]\tTime (0.363) Data (0.000) Loss 0.0049 (0.0049) MAEs: e 0.065 (0.065) f 0.005 (0.005) s 0.002 (0.002) m 0.014 (0.014) \n", - "Epoch: [3][6/12]\tTime (0.359) Data (0.000) Loss 0.0050 (0.0042) MAEs: e 0.066 (0.057) f 0.005 (0.005) s 0.003 (0.002) m 0.014 (0.014) \n", - "Epoch: [3][12/12]\tTime (0.355) Data (0.000) Loss 0.0045 (0.0038) MAEs: e 0.059 (0.054) f 0.004 (0.005) s 0.003 (0.002) m 0.012 (0.014) \n", - "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", - "Epoch: [4][1/12]\tTime (0.384) Data (0.000) Loss 0.0033 (0.0033) MAEs: e 0.051 (0.051) f 0.005 (0.005) s 0.003 (0.003) m 0.015 (0.015) \n", - "Epoch: [4][6/12]\tTime (0.384) Data (0.000) Loss 0.0016 (0.0033) MAEs: e 0.035 (0.051) f 0.005 (0.005) s 0.002 (0.002) m 0.012 (0.014) \n", - "Epoch: [4][12/12]\tTime (0.351) Data (0.000) Loss 0.0011 (0.0038) MAEs: e 0.033 (0.054) f 0.004 (0.005) s 0.002 (0.002) m 0.014 (0.014) \n", - "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", - "---------Evaluate Model on Test Set---------------\n", - "** e_MAE (0.056) \tf_MAE (0.005) \ts_MAE (0.003) \tm_MAE (0.015) \t\n" - ] - } - ], - "source": [ - "trainer.train(train_loader, val_loader, test_loader)" - ] - }, - { - "cell_type": "markdown", - "id": "29", - "metadata": { - "id": "29" - }, - "source": [ - "After training, the trained model can be found in the directory of today's date. Or it can be accessed by:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30", - "metadata": { - "id": "30" - }, - "outputs": [], - "source": [ - "model = trainer.model\n", - "best_model = trainer.best_model # best model based on validation energy MAE" - ] - }, - { - "cell_type": "markdown", - "id": "31", - "metadata": { - "id": "31" - }, - "source": [ - "## Extras 1: GGA / GGA+U compatibility\n" - ] - }, - { - "cell_type": "markdown", - "id": "32", - "metadata": { - "id": "32" - }, - "source": [ - "### Q: Why and when do you care about this?\n", - "\n", - "**When**: If you want to fine-tune the pretrained CHGNet with your own GGA+U VASP calculations, and you want to keep your VASP energy compatible to the pretrained dataset. In case your dataset is so large that the pretrained knowledge does not matter to you, you can ignore this.\n", - "\n", - "**Why**: CHGNet is trained on both GGA and GGA+U calculations from Materials Project. And there has been developed methods in solving the compatibility between GGA and GGA+U calculations which makes the energies universally applicable for cross-chemistry comparison and phase-diagram constructions. Please refer to:\n", - "\n", - "https://journals.aps.org/prb/abstract/10.1103/PhysRevB.84.045115\n", - "\n", - "Below we show an example to apply the compatibility.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33", - "metadata": { - "id": "33", - "outputId": "e2ab424d-7e76-4735-c793-1105a548b5eb" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The raw total energy from VASP of LMO is: -58.97 eV\n" - ] - } - ], - "source": [ - "# Imagine this is the VASP raw energy\n", - "vasp_raw_energy = -58.97\n", - "\n", - "print(f\"The raw total energy from VASP of LMO is: {vasp_raw_energy} eV\")" - ] - }, - { - "cell_type": "markdown", - "id": "34", - "metadata": { - "id": "34" - }, - "source": [ - "You can look for the energy correction applied to each element in :\n", - "\n", - "https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/entries/MP2020Compatibility.yaml\n", - "\n", - "Here LiMnO2 applies to both Mn in transition metal oxides correction and oxide correction.\n" - ] - }, - { - "cell_type": "markdown", - "id": "35", - "metadata": { - "id": "35" - }, - "source": [ - "To demystify `MaterialsProject2020Compatibility`, basically all that's happening is:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "36", - "metadata": { - "id": "36", - "outputId": "329e0ad2-10a2-4571-9b3d-ff0edf4afab9" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The corrected total energy after MP2020 = -65.05 eV\n" - ] - } - ], - "source": [ - "Mn_correction_in_TMO = -1.668\n", - "oxide_correction = -0.687\n", - "_, num_Mn, num_O = lmo.composition.values()\n", - "\n", - "\n", - "corrected_energy = (\n", - " vasp_raw_energy + num_Mn * Mn_correction_in_TMO + num_O * oxide_correction\n", - ")\n", - "print(f\"The corrected total energy after MP2020 = {corrected_energy:.4} eV\")" - ] - }, - { - "cell_type": "markdown", - "id": "37", - "metadata": { - "id": "37" - }, - "source": [ - "You can also apply the `MaterialsProject2020Compatibility` through pymatgen\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "38", - "metadata": { - "id": "38", - "outputId": "7e4dc685-ce0f-4da3-a9f7-1c0cb2c3fdf4" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The total energy of LMO after MP2020Compatibility correction = -62.31 eV\n" - ] - } - ], - "source": [ - "from pymatgen.entries.compatibility import MaterialsProject2020Compatibility\n", - "from pymatgen.entries.computed_entries import ComputedStructureEntry\n", - "\n", - "params = {\"hubbards\": {\"Mn\": 3.9, \"O\": 0, \"Li\": 0}, \"run_type\": \"GGA+U\"}\n", - "\n", - "cse = ComputedStructureEntry(lmo, vasp_raw_energy, parameters=params)\n", - "\n", - "MaterialsProject2020Compatibility(check_potcar=False).process_entries(cse)\n", - "print(\n", - " f\"The total energy of LMO after MP2020Compatibility correction = {cse.energy:.4} eV\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "39", - "metadata": { - "id": "39" - }, - "source": [ - "Now use this corrected energy as labels to tune CHGNet, you're good to go!\n" - ] - }, - { - "cell_type": "markdown", - "id": "40", - "metadata": { - "id": "40" - }, - "source": [ - "## Extras 2: AtomRef\n" - ] - }, - { - "cell_type": "markdown", - "id": "41", - "metadata": { - "id": "41" - }, - "source": [ - "### Q: Why and when do you care about this?\n", - "\n", - "**When**: When you fine tune CHGNet to DFT labels that are incompatible with Materials Project, like r2SCAN functional, or other DFTs like Gaussian or QE. The large shifts in elemental energy is not of our interest and should be reconciled. For example, Li has -0.95 eV/atom in GGA (https://next-gen.materialsproject.org/materials/mp-135/tasks/mp-990455) and -1.17 eV/atom in R2SCAN (https://next-gen.materialsproject.org/materials/mp-135/tasks/mp-1943895)\n", - "\n", - "**Why**: The GNN learns the interaction between the atoms and the composition model (AtomRef) in CHGNet is used to normalize the elemental energy contribution, similar to a formation-energy-like calculation. During fine-tuning, we want to keep the most of knowledge unchanged in the GNN and allow the AtomRef to shift for the elemental energy change. So that the finetuning on the graph layers can be focused on energy contribution from atom-atom interaction instead of meaningless atom reference energies.\n", - "\n", - "Below I will show an example to fit the AtomRef layer:\n" - ] - }, - { - "cell_type": "markdown", - "id": "42", - "metadata": { - "id": "42" - }, - "source": [ - "### A quick and easy way to turn on training of AtomRef in the trainer (this is by default off):\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43", - "metadata": { - "id": "43", - "outputId": "7c84164d-9e80-46e3-beec-513850a04e3e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Begin Training: using cpu device\n", - "training targets: efsm\n", - "Epoch: [0][1/12]\tTime (0.475) Data (0.001) Loss 0.0028 (0.0028) MAEs: e 0.047 (0.047) f 0.005 (0.005) s 0.003 (0.003) m 0.014 (0.014) \n", - "Epoch: [0][6/12]\tTime (0.379) Data (0.000) Loss 0.0027 (0.0037) MAEs: e 0.046 (0.053) f 0.005 (0.005) s 0.002 (0.002) m 0.015 (0.014) \n", - "Epoch: [0][12/12]\tTime (0.359) Data (0.000) Loss 0.0010 (0.0038) MAEs: e 0.030 (0.054) f 0.005 (0.005) s 0.003 (0.002) m 0.012 (0.014) \n", - "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", - "Epoch: [1][1/12]\tTime (0.417) Data (0.000) Loss 0.0011 (0.0011) MAEs: e 0.027 (0.027) f 0.004 (0.004) s 0.002 (0.002) m 0.015 (0.015) \n", - "Epoch: [1][6/12]\tTime (0.359) Data (0.000) Loss 0.0049 (0.0040) MAEs: e 0.062 (0.056) f 0.005 (0.005) s 0.003 (0.002) m 0.015 (0.015) \n", - "Epoch: [1][12/12]\tTime (0.351) Data (0.000) Loss 0.0054 (0.0038) MAEs: e 0.073 (0.054) f 0.004 (0.005) s 0.002 (0.002) m 0.013 (0.014) \n", - "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", - "Epoch: [2][1/12]\tTime (0.368) Data (0.000) Loss 0.0027 (0.0027) MAEs: e 0.043 (0.043) f 0.005 (0.005) s 0.003 (0.003) m 0.016 (0.016) \n", - "Epoch: [2][6/12]\tTime (0.388) Data (0.000) Loss 0.0042 (0.0034) MAEs: e 0.056 (0.051) f 0.005 (0.005) s 0.003 (0.003) m 0.014 (0.015) \n", - "Epoch: [2][12/12]\tTime (0.354) Data (0.000) Loss 0.0033 (0.0038) MAEs: e 0.054 (0.054) f 0.004 (0.005) s 0.003 (0.002) m 0.013 (0.014) \n", - "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", - "Epoch: [3][1/12]\tTime (0.351) Data (0.000) Loss 0.0032 (0.0032) MAEs: e 0.048 (0.048) f 0.005 (0.005) s 0.003 (0.003) m 0.014 (0.014) \n", - "Epoch: [3][6/12]\tTime (0.371) Data (0.000) Loss 0.0046 (0.0035) MAEs: e 0.064 (0.052) f 0.005 (0.005) s 0.002 (0.003) m 0.016 (0.014) \n", - "Epoch: [3][12/12]\tTime (0.351) Data (0.000) Loss 0.0088 (0.0038) MAEs: e 0.093 (0.054) f 0.005 (0.005) s 0.002 (0.002) m 0.016 (0.014) \n", - "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", - "Epoch: [4][1/12]\tTime (0.376) Data (0.000) Loss 0.0048 (0.0048) MAEs: e 0.066 (0.066) f 0.005 (0.005) s 0.002 (0.002) m 0.013 (0.013) \n", - "Epoch: [4][6/12]\tTime (0.375) Data (0.000) Loss 0.0017 (0.0036) MAEs: e 0.030 (0.053) f 0.005 (0.005) s 0.003 (0.002) m 0.016 (0.014) \n", - "Epoch: [4][12/12]\tTime (0.351) Data (0.000) Loss 0.0006 (0.0038) MAEs: e 0.020 (0.054) f 0.005 (0.005) s 0.003 (0.002) m 0.013 (0.014) \n", - "* e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n", - "---------Evaluate Model on Test Set---------------\n", - "** e_MAE (0.056) \tf_MAE (0.005) \ts_MAE (0.003) \tm_MAE (0.015) \t\n" - ] - } - ], - "source": [ - "trainer.train(train_loader, val_loader, test_loader, train_composition_model=True)" - ] - }, - { - "cell_type": "markdown", - "id": "44", - "metadata": { - "id": "44" - }, - "source": [ - "### The more regorous way is to solve for the per-atom contribution by linear regression in your fine-tuning dataset\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "45", - "metadata": { - "id": "45", - "outputId": "12947804-631a-46e5-b356-712beecd6639" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The pretrained Atom_Ref (per atom reference energy):\n", - "Parameter containing:\n", - "tensor([[ -3.4431, -0.1279, -2.8300, -3.4737, -7.4946, -8.2354, -8.1611,\n", - " -8.3861, -5.7498, -0.0236, -1.7406, -1.6788, -4.2833, -6.2002,\n", - " -6.1315, -5.8405, -3.8795, -0.0703, -1.5668, -3.4451, -7.0549,\n", - " -9.1465, -9.2594, -9.3514, -8.9843, -8.0228, -6.4955, -5.6057,\n", - " -3.4002, -0.9217, -3.2499, -4.9164, -4.7810, -5.0191, -3.3316,\n", - " 0.5130, -1.4043, -3.2175, -7.4994, -9.3816, -10.4386, -9.9539,\n", - " -7.9555, -8.5440, -7.3245, -5.2771, -1.9014, -0.4034, -2.6002,\n", - " -4.0054, -4.1156, -3.9928, -2.7003, 2.2170, -1.9671, -3.7180,\n", - " -6.8133, -7.3502, -6.0712, -6.1699, -5.1471, -6.1925, -11.5829,\n", - " -15.8841, -5.9994, -6.0798, -5.9513, -6.0400, -5.9773, -2.5091,\n", - " -6.0767, -10.6666, -11.8761, -11.8491, -10.7397, -9.6100, -8.4755,\n", - " -6.2070, -3.0337, 0.4726, -1.6425, -3.1295, -3.3328, -0.1221,\n", - " -0.3448, -0.4364, -0.1661, -0.3680, -4.1869, -8.4233, -10.0467,\n", - " -12.0953, -12.5228, -14.2530]], requires_grad=True)\n" - ] - } - ], - "source": [ - "print(\"The pretrained Atom_Ref (per atom reference energy):\")\n", - "for param in chgnet.composition_model.parameters():\n", - " print(param)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "46", - "metadata": { - "id": "46" - }, - "outputs": [], - "source": [ - "# A list of structures / graphs\n", - "structures = [\n", - " lmo,\n", - " Structure(\n", - " species=[\"Li\", \"Mn\", \"Mn\", \"O\", \"O\", \"O\"],\n", - " lattice=np.random.rand(3, 3),\n", - " coords=np.random.rand(6, 3),\n", - " ),\n", - " Structure(\n", - " species=[\"Li\", \"Li\", \"Mn\", \"O\", \"O\", \"O\"],\n", - " lattice=np.random.rand(3, 3),\n", - " coords=np.random.rand(6, 3),\n", - " ),\n", - " Structure(\n", - " species=[\"Li\", \"Mn\", \"Mn\", \"O\", \"O\", \"O\", \"O\"],\n", - " lattice=np.random.rand(3, 3),\n", - " coords=np.random.rand(7, 3),\n", - " ),\n", - "]\n", - "\n", - "# A list of energy_per_atom values (random values here)\n", - "energies_per_atom = [5.5, 6, 4.8, 5.6]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "47", - "metadata": { - "id": "47", - "outputId": "44bcff3a-a208-45ce-e6c0-5d35e68ca6bf" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "We initialize another identical AtomRef layers\n", - "tensor([[-3.4431, -0.1279, -2.8300]], grad_fn=)\n" - ] - } - ], - "source": [ - "from chgnet.model.composition_model import AtomRef\n", - "\n", - "print(\"We initialize another identical AtomRef layers\")\n", - "new_atom_ref = AtomRef(is_intensive=True)\n", - "new_atom_ref.initialize_from_MPtrj()\n", - "for param in new_atom_ref.parameters():\n", - " print(param[:, :3])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "48", - "metadata": { - "id": "48", - "outputId": "7af4aa6d-c839-4530-b063-6b852205f5fd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "After refitting, the AtomRef looks like:\n", - "Parameter containing:\n", - "tensor([[ 0.0000e+00, 0.0000e+00, 4.2667e+00, -3.3299e-15, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 2.9999e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1467e+01,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],\n", - " requires_grad=True)\n" - ] - } - ], - "source": [ - "# Solve linear regression to find the per atom contribution in your fine-tuning dataset\n", - "\n", - "new_atom_ref.fit(structures, energies_per_atom)\n", - "print(\"After refitting, the AtomRef looks like:\")\n", - "for param in new_atom_ref.parameters():\n", - " print(param)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.7" - }, - "colab": { - "provenance": [], - "gpuType": "T4", - "include_colab_link": true - }, - "accelerator": "GPU" - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file