Skip to content

Commit

Permalink
Add extra radius_graph_cutoff.
Browse files Browse the repository at this point in the history
  • Loading branch information
knc6 committed Oct 2, 2023
1 parent 60cd9bc commit 4d095b2
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 1 deletion.
1 change: 1 addition & 0 deletions alignn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class TrainingConfig(BaseSettings):
use_canonize: bool = True
num_workers: int = 4
cutoff: float = 8.0
cutoff_extra: float = 3.0
max_neighbors: int = 12
keep_data_order: bool = True
normalize_graph_level_loss: bool = False
Expand Down
9 changes: 9 additions & 0 deletions alignn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def load_graphs(
name: str = "dft_3d",
neighbor_strategy: str = "k-nearest",
cutoff: float = 8,
cutoff_extra: float = 3,
max_neighbors: int = 12,
cachedir: Optional[Path] = None,
use_canonize: bool = False,
Expand Down Expand Up @@ -100,6 +101,7 @@ def atoms_to_graph(atoms):
return Graph.atom_dgl_multigraph(
structure,
cutoff=cutoff,
cutoff_extra=cutoff_extra,
atom_features="atomic_number",
max_neighbors=max_neighbors,
compute_line_graph=False,
Expand Down Expand Up @@ -128,6 +130,7 @@ def atoms_to_graph(atoms):
g = Graph.atom_dgl_multigraph(
structure,
cutoff=cutoff,
cutoff_extra=cutoff_extra,
atom_features="atomic_number",
max_neighbors=max_neighbors,
compute_line_graph=False,
Expand Down Expand Up @@ -229,6 +232,7 @@ def get_torch_dataset(
name="",
line_graph="",
cutoff=8.0,
cutoff_extra=3.0,
max_neighbors=12,
classification=False,
output_dir=".",
Expand All @@ -253,6 +257,7 @@ def get_torch_dataset(
neighbor_strategy=neighbor_strategy,
use_canonize=use_canonize,
cutoff=cutoff,
cutoff_extra=cutoff_extra,
max_neighbors=max_neighbors,
id_tag=id_tag,
)
Expand Down Expand Up @@ -297,6 +302,7 @@ def get_train_val_loaders(
id_tag: str = "jid",
use_canonize: bool = False,
cutoff: float = 8.0,
cutoff_extra: float = 3.0,
max_neighbors: int = 12,
classification_threshold: Optional[float] = None,
target_multiplication_factor: Optional[float] = None,
Expand Down Expand Up @@ -505,6 +511,7 @@ def get_train_val_loaders(
name=dataset,
line_graph=line_graph,
cutoff=cutoff,
cutoff_extra=cutoff_extra,
max_neighbors=max_neighbors,
classification=classification_threshold is not None,
output_dir=output_dir,
Expand All @@ -524,6 +531,7 @@ def get_train_val_loaders(
name=dataset,
line_graph=line_graph,
cutoff=cutoff,
cutoff_extra=cutoff_extra,
max_neighbors=max_neighbors,
classification=classification_threshold is not None,
output_dir=output_dir,
Expand All @@ -546,6 +554,7 @@ def get_train_val_loaders(
name=dataset,
line_graph=line_graph,
cutoff=cutoff,
cutoff_extra=cutoff_extra,
max_neighbors=max_neighbors,
classification=classification_threshold is not None,
output_dir=output_dir,
Expand Down
6 changes: 5 additions & 1 deletion alignn/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def radius_graph(
cutoff_extra=3.5,
):
"""Construct edge list for radius graph."""

def temp_graph(cutoff=5):
"""Construct edge list for radius graph."""
cart_coords = torch.tensor(atoms.cart_coords).type(
Expand Down Expand Up @@ -363,6 +364,7 @@ def atom_dgl_multigraph(
compute_line_graph: bool = True,
use_canonize: bool = False,
use_lattice_prop: bool = False,
cutoff_extra=3.5,
):
"""Obtain a DGLGraph for Atoms object."""
# print('id',id)
Expand All @@ -379,7 +381,9 @@ def atom_dgl_multigraph(
# print('HERE')
# import sys
# sys.exit()
u, v, r = radius_graph(atoms, cutoff=cutoff)
u, v, r = radius_graph(
atoms, cutoff=cutoff, cutoff_extra=cutoff_extra
)
else:
raise ValueError("Not implemented yet", neighbor_strategy)
# elif neighbor_strategy == "voronoi":
Expand Down
1 change: 1 addition & 0 deletions alignn/train_folder_ff.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def train_for_folder(
use_canonize=config.use_canonize,
filename=config.filename,
cutoff=config.cutoff,
cutoff_extra=config.cutoff_extra,
max_neighbors=config.max_neighbors,
output_features=config.model.output_features,
classification_threshold=config.classification_threshold,
Expand Down

0 comments on commit 4d095b2

Please sign in to comment.