Add use_lmdb config option.
knc6 committed Apr 30, 2024
1 parent 4fa6acc commit 4f6f871
Showing 5 changed files with 21 additions and 11 deletions.
3 changes: 2 additions & 1 deletion alignn/config.py
@@ -207,7 +207,8 @@ class TrainingConfig(BaseSettings):
     distributed: bool = False
     data_parallel: bool = False
     n_early_stopping: Optional[int] = None  # typically 50
-    output_dir: str = os.path.abspath(".")  # typically 50
+    output_dir: str = os.path.abspath(".")
+    use_lmdb: bool = True
     # alignn_layers: int = 4
     # gcn_layers: int =4
     # edge_input_features: int= 80
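Note that `use_lmdb` defaults to True, so LMDB-backed data loading becomes opt-out rather than opt-in. A minimal sketch of overriding it when constructing the settings object (only `use_lmdb` is taken from this diff; all other `TrainingConfig` fields are assumed to keep their defaults):

from alignn.config import TrainingConfig

# use_lmdb defaults to True after this commit; pass False explicitly
# to fall back to the previous in-memory data pipeline.
cfg = TrainingConfig(use_lmdb=False)
print(cfg.use_lmdb)  # False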
4 changes: 2 additions & 2 deletions alignn/data.py
@@ -139,7 +139,7 @@ def get_train_val_loaders(
     save_dataloader: bool = False,
     filename: str = "sample",
     id_tag: str = "jid",
-    use_canonize: bool = False,
+    use_canonize: bool = True,
     # use_ddp: bool = False,
     cutoff: float = 8.0,
     cutoff_extra: float = 3.0,
@@ -152,7 +152,7 @@ def get_train_val_loaders(
     output_dir=None,
     world_size=0,
     rank=0,
-    use_lmdb: bool = False,
+    use_lmdb: bool = True,
 ):
     """Help function to set up JARVIS train and val dataloaders."""
     if use_lmdb:
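Both `use_canonize` and `use_lmdb` flip their defaults to True here. A hypothetical call under the new defaults (argument names come from the signature above; any argument not shown is assumed to keep its default, the dataset/target values are illustrative, and the unpacking assumes the function returns the three loaders plus a batch-preparation helper):

from alignn.data import get_train_val_loaders

# With the new defaults, graphs are canonized and cached in LMDB
# unless the caller overrides use_canonize/use_lmdb.
train_loader, val_loader, test_loader, prepare_batch = get_train_val_loaders(
    dataset="dft_3d",
    target="formation_energy_peratom",
)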
10 changes: 7 additions & 3 deletions alignn/examples/sample_data_ff/config_example_atomwise.json
@@ -31,12 +31,13 @@
"progress": true,
"log_tensorboard": false,
"standard_scalar_and_pca": false,
"use_canonize": false,
"use_canonize": true,
"num_workers": 0,
"cutoff": 8.0,
"max_neighbors": 12,
"keep_data_order": true,
"distributed":false,
"use_lmdb": true,
"model": {
"name": "alignn_atomwise",
"atom_input_features": 92,
@@ -48,7 +49,10 @@
"graphwise_weight":0.85,
"gradwise_weight":0.05,
"atomwise_weight":0.0,
"stresswise_weight":0.05

"stresswise_weight":0.05,
"add_reverse_forces":true,
"lg_on_fly":true


}
}
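Since this JSON feeds a pydantic settings class, a quick way to verify the edited file is still valid is to round-trip it through `TrainingConfig` (a sketch mirroring how train_alignn.py builds its settings from a JSON dict; the path is the file shown above):

import json

from alignn.config import TrainingConfig

with open(
    "alignn/examples/sample_data_ff/config_example_atomwise.json"
) as f:
    cfg = TrainingConfig(**json.load(f))

print(cfg.use_lmdb, cfg.use_canonize)  # True True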
10 changes: 7 additions & 3 deletions alignn/train.py
@@ -123,7 +123,8 @@ def train_dgl(
# print("rank", rank)
# setup(rank, world_size)
if rank == 0:
print(config)
print("config:")
# print(config)
if type(config) is dict:
try:
print(config)
@@ -136,7 +137,6 @@ def train_dgl(
     # checkpoint_dir = os.path.join(config.output_dir)
     # deterministic = False
     classification = False
-    print("config:")
     tmp = config.dict()
     f = open(os.path.join(config.output_dir, "config.json"), "w")
     f.write(json.dumps(tmp, indent=4))
@@ -195,7 +195,7 @@ def train_dgl(
             standard_scalar_and_pca=config.standard_scalar_and_pca,
             keep_data_order=config.keep_data_order,
             output_dir=config.output_dir,
-            # use_ddp=use_ddp,
+            use_lmdb=config.use_lmdb,
         )
     else:
         train_loader = train_val_test_loaders[0]
@@ -876,6 +876,10 @@ def get_batch_errors(dat=[]):
                 targets.append(ii)
                 predictions.append(jj)
         f.close()
+    if config.use_lmdb:
+        train_loader.dataset.close()
+        val_loader.dataset.close()
+        test_loader.dataset.close()


 if __name__ == "__main__":
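The cleanup block added at the end of training assumes each loader wraps an LMDB-backed dataset exposing a close() method. The actual dataset class lives elsewhere in the package; a minimal, hypothetical sketch of such a dataset (not the alignn implementation) looks like:

import pickle

import lmdb
from torch.utils.data import Dataset


class LMDBGraphDataset(Dataset):
    """Read-only LMDB store with an explicit close() hook."""

    def __init__(self, path):
        # lock=False allows concurrent read-only dataloader workers.
        self.env = lmdb.open(path, readonly=True, lock=False)

    def __len__(self):
        with self.env.begin() as txn:
            return txn.stat()["entries"]

    def __getitem__(self, idx):
        with self.env.begin() as txn:
            return pickle.loads(txn.get(str(idx).encode()))

    def close(self):
        # Release the memory-mapped environment once training is done,
        # matching the train_loader.dataset.close() calls above.
        self.env.close()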
5 changes: 3 additions & 2 deletions alignn/train_alignn.py
@@ -387,12 +387,13 @@ def train_for_folder(
         standard_scalar_and_pca=config.standard_scalar_and_pca,
         keep_data_order=config.keep_data_order,
         output_dir=config.output_dir,
+        use_lmdb=config.use_lmdb,
     )
     # print("dataset", dataset[0])
     t1 = time.time()
     # world_size = torch.cuda.device_count()
-    print("rank ht1", rank)
-    print("world_size ht1", world_size)
+    print("rank", rank)
+    print("world_size", world_size)
     train_dgl(
         config,
         model=model,
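With the option plumbed through here, LMDB caching is active by default when training from a folder. A hypothetical invocation (rank and world_size appear in the diff above; root_dir and config_name are assumed argument names, shown for illustration only):

from alignn.train_alignn import train_for_folder

# Sketch: train on the sample force-field data with the edited config;
# use_lmdb is picked up from the JSON via config.use_lmdb.
train_for_folder(
    rank=0,
    world_size=1,
    root_dir="alignn/examples/sample_data_ff",
    config_name="alignn/examples/sample_data_ff/config_example_atomwise.json",
)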
