Skip to content

Commit

Permalink
Add ALIGNN_FF2, radius_graph_jarvis.
Browse files Browse the repository at this point in the history
  • Loading branch information
knc6 committed Oct 27, 2024
1 parent d599de3 commit 77837a3
Show file tree
Hide file tree
Showing 9 changed files with 765 additions and 99 deletions.
16 changes: 4 additions & 12 deletions alignn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,10 @@
from typing import Literal
from alignn.utils import BaseSettings
from alignn.models.alignn import ALIGNNConfig
from alignn.models.alignn_ff2 import ALIGNNFF2Config
from alignn.models.alignn_atomwise import ALIGNNAtomWiseConfig

# from alignn.models.modified_cgcnn import CGCNNConfig
# from alignn.models.icgcnn import ICGCNNConfig
# from alignn.models.gcn import SimpleGCNConfig
# from alignn.models.densegcn import DenseGCNConfig
# from pydantic import model_validator
# from alignn.models.dense_alignn import DenseALIGNNConfig
# from alignn.models.alignn_cgcnn import ACGCNNConfig
# from alignn.models.alignn_layernorm import ALIGNNConfig as ALIGNN_LN_Config

# from typing import List
# import torch

try:
VERSION = (
Expand Down Expand Up @@ -167,9 +159,8 @@ class TrainingConfig(BaseSettings):
] = "k-nearest"
id_tag: Literal["jid", "id", "_oqmd_entry_id"] = "jid"

# logging configuration

# training configuration
dtype: str = "float32"
random_seed: Optional[int] = 123
classification_threshold: Optional[float] = None
# target_range: Optional[List] = None
Expand Down Expand Up @@ -219,6 +210,7 @@ class TrainingConfig(BaseSettings):
# model configuration
model: Union[
ALIGNNConfig,
ALIGNNFF2Config,
ALIGNNAtomWiseConfig,
# CGCNNConfig,
# ICGCNNConfig,
Expand Down
4 changes: 4 additions & 0 deletions alignn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def get_train_val_loaders(
world_size=0,
rank=0,
use_lmdb: bool = True,
dtype="float32",
):
"""Help function to set up JARVIS train and val dataloaders."""
if use_lmdb:
Expand Down Expand Up @@ -383,6 +384,7 @@ def get_train_val_loaders(
output_dir=output_dir,
sampler=train_sampler,
tmp_name=tmp_name,
dtype=dtype,
# tmp_name="train_data",
)
tmp_name = filename + "val_data"
Expand All @@ -406,6 +408,7 @@ def get_train_val_loaders(
classification=classification_threshold is not None,
output_dir=output_dir,
tmp_name=tmp_name,
dtype=dtype,
# tmp_name="val_data",
)
if len(dataset_val) > 0
Expand All @@ -431,6 +434,7 @@ def get_train_val_loaders(
classification=classification_threshold is not None,
output_dir=output_dir,
tmp_name=tmp_name,
dtype=dtype,
# tmp_name="test_data",
)
if len(dataset_test) > 0
Expand Down
6 changes: 6 additions & 0 deletions alignn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def load_graphs(
id_tag="jid",
# extra_feats_json=None,
map_size=1e12,
dtype="float32",
):
"""Construct crystal graphs.
Expand Down Expand Up @@ -54,6 +55,7 @@ def atoms_to_graph(atoms):
compute_line_graph=False,
use_canonize=use_canonize,
neighbor_strategy=neighbor_strategy,
dtype=dtype,
)

if cachedir is not None:
Expand Down Expand Up @@ -84,6 +86,7 @@ def atoms_to_graph(atoms):
use_canonize=use_canonize,
neighbor_strategy=neighbor_strategy,
id=i[id_tag],
dtype=dtype,
)
# print ('ii',ii)
if "extra_features" in i:
Expand Down Expand Up @@ -124,6 +127,7 @@ def get_torch_dataset(
output_dir=".",
tmp_name="dataset",
sampler=None,
dtype="float32",
):
"""Get Torch Dataset."""
df = pd.DataFrame(dataset)
Expand All @@ -147,6 +151,7 @@ def get_torch_dataset(
cutoff_extra=cutoff_extra,
max_neighbors=max_neighbors,
id_tag=id_tag,
dtype=dtype,
)
data = StructureDataset(
df,
Expand All @@ -160,5 +165,6 @@ def get_torch_dataset(
id_tag=id_tag,
classification=classification,
sampler=sampler,
dtype=dtype,
)
return data
5 changes: 3 additions & 2 deletions alignn/examples/sample_data_ff/config_example_atomwise.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
"dataset": "user_data",
"target": "target",
"atom_features": "cgcnn",
"neighbor_strategy": "radius_graph",
"neighbor_strategy": "radius_graph_jarvis",
"id_tag": "jid",
"dtype": "float32",
"random_seed": 123,
"classification_threshold": null,
"n_val": null,
Expand Down Expand Up @@ -39,7 +40,7 @@
"distributed":false,
"use_lmdb": true,
"model": {
"name": "alignn_atomwise",
"name": "alignn_ff2",
"atom_input_features": 92,
"calculate_gradient":true,
"atomwise_output_features":0,
Expand Down
145 changes: 91 additions & 54 deletions alignn/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,100 @@
from jarvis.analysis.structure.neighbors import NeighborsAnalysis
from jarvis.core.specie import chem_data, get_node_attributes
import math

# from jarvis.core.atoms import Atoms
from collections import defaultdict
from typing import List, Tuple, Sequence, Optional
from dgl.data import DGLDataset

import torch
import dgl
from tqdm import tqdm


def temp_graph(atoms=None, cutoff=4.0, atom_features="cgcnn", dtype="float32"):
"""Helper function to construct a graph for a given cutoff."""
TORCH_DTYPES = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"bfloat": torch.bfloat16,
}
dtype = TORCH_DTYPES[dtype]
u, v, r, d, images, atom_feats = [], [], [], [], [], []
elements = atoms.elements

# Loop over each atom in the structure
for ii, i in enumerate(atoms.cart_coords):
# Get neighbors within the cutoff distance
neighs = atoms.lattice.get_points_in_sphere(
atoms.frac_coords, i, cutoff, distance_vector=True
)

# Filter out self-loops (exclude cases where atom is bonded to itself)
valid_indices = neighs[2] != ii

u.extend([ii] * np.sum(valid_indices))
d.extend(neighs[1][valid_indices])
v.extend(neighs[2][valid_indices])
images.extend(neighs[3][valid_indices])
r.extend(neighs[4][valid_indices])

feat = list(
get_node_attributes(elements[ii], atom_features=atom_features)
)
atom_feats.append(feat)

# Create DGL graph
g = dgl.graph((np.array(u), np.array(v)))

# Add data to the graph with the specified dtype
g.ndata["atom_features"] = torch.tensor(atom_feats, dtype=dtype)
g.edata["r"] = torch.tensor(r, dtype=dtype)
g.edata["d"] = torch.tensor(d, dtype=dtype)
g.edata["images"] = torch.tensor(images, dtype=dtype)
g.ndata["coords"] = torch.tensor(atoms.cart_coords, dtype=dtype)
g.ndata["V"] = torch.tensor([atoms.volume] * atoms.num_atoms, dtype=dtype)

return g, u, v, r


def radius_graph_jarvis(
atoms,
cutoff_extra=0.5,
cutoff=4.0,
atom_features="cgcnn",
line_graph=True,
dtype="float32",
):
"""Construct radius graph with dynamic cutoff."""

while True:
# try:
# Attempt to create the graph
g, u, v, r = temp_graph(
atoms=atoms,
cutoff=cutoff,
atom_features=atom_features,
dtype=dtype,
)
# Check if all atoms are included as nodes
if g.num_nodes() == len(atoms.elements):
# print(f"Graph constructed with cutoff: {cutoff}")
break # Exit the loop when successful
# Increment the cutoff if the graph is incomplete
cutoff += cutoff_extra
# print(f"Increasing cutoff to: {cutoff}")

# except Exception as exp:
# # Handle exceptions and try again
# print(f"Graph construction failed: {exp}")
# cutoff += cutoff_extra # Try with a larger cutoff

# Optional: Create a line graph if requested
if line_graph:
lg = g.line_graph(shared=True)
lg.apply_edges(compute_bond_cosines)
return g, lg

try:
from tqdm import tqdm
except Exception as exp:
print("tqdm is not installed.", exp)
pass
return g


def canonize_edge(
Expand Down Expand Up @@ -320,52 +400,6 @@ def radius_graph_old(


###
def radius_graph_jarvis(
atoms, cutoff=4, atom_features="cgcnn", line_graph=True
):
"""Construct edge list for radius graph."""
u, v, r, atom_feats = [], [], [], []
elements = atoms.elements

# Loop over each atom in the structure
for ii, i in enumerate(atoms.cart_coords):
# Get neighbors within the cutoff distance
neighs = atoms.lattice.get_points_in_sphere(
atoms.frac_coords, i, cutoff, distance_vector=True
)

# Filter out self-loops (where the neighbor is the same as the source atom)
valid_indices = neighs[2] != ii # Exclude self-loops

# Store source (u), destination (v), and distances (r) only for valid neighbors
u.extend(
[ii] * np.sum(valid_indices)
) # Add the source atom multiple times
v.extend(neighs[2][valid_indices]) # Add valid neighbors only
r.extend(neighs[-1][valid_indices]) # Add distances of valid neighbors

# Store atom features for the current atom
feat = list(
get_node_attributes(elements[ii], atom_features=atom_features)
)
atom_feats.append(feat)

# Create DGL graph
g = dgl.graph((np.array(u), np.array(v)))
g.ndata["atom_features"] = torch.tensor(atom_feats, dtype=torch.float32)
g.edata["r"] = torch.tensor(r, dtype=torch.float32)
g.ndata["coords"] = torch.tensor(atoms.cart_coords, dtype=torch.float32)
g.ndata["V"] = torch.tensor(
[atoms.volume] * atoms.num_atoms, dtype=torch.float32
)

# Optional: Create a line graph if requested
if line_graph:
lg = g.line_graph(shared=True)
lg.apply_edges(compute_bond_cosines)
return g, lg

return g


class Graph(object):
Expand Down Expand Up @@ -415,6 +449,7 @@ def atom_dgl_multigraph(
# use_canonize: bool = False,
use_lattice_prop: bool = False,
cutoff_extra=3.5,
dtype=torch.float32,
):
"""Obtain a DGLGraph for Atoms object."""
# print('id',id)
Expand All @@ -441,6 +476,7 @@ def atom_dgl_multigraph(
cutoff=cutoff,
atom_features=atom_features,
line_graph=compute_line_graph,
dtype=dtype,
)
return g, lg
else:
Expand Down Expand Up @@ -784,6 +820,7 @@ def __init__(
classification=False,
id_tag="jid",
sampler=None,
dtype="float32",
):
"""Pytorch Dataset for atomistic graphs.
Expand Down
2 changes: 2 additions & 0 deletions alignn/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def get_torch_dataset(
tmp_name="dataset",
map_size=1e12,
read_existing=True,
dtype="float32",
):
"""Get Torch Dataset with LMDB."""
vals = np.array([ii[target] for ii in dataset]) # df[target].values
Expand Down Expand Up @@ -151,6 +152,7 @@ def get_torch_dataset(
use_canonize=use_canonize,
cutoff_extra=cutoff_extra,
neighbor_strategy=neighbor_strategy,
dtype=dtype,
)
if line_graph:
g, lg = g
Expand Down
Loading

0 comments on commit 77837a3

Please sign in to comment.