Skip to content

Commit

Permalink
Save current model.
Browse files Browse the repository at this point in the history
  • Loading branch information
knc6 committed Oct 2, 2023
1 parent aa9c36b commit 60cd9bc
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 2 additions & 2 deletions alignn/ff/ff.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def get_figshare_model_ff(model_name="alignnff_fmult", dir_path=None):

def default_path():
"""Get default model path."""
# dpath = get_figshare_model_ff(model_name="alignnff_wt10")
dpath = get_figshare_model_ff(model_name="alignnff_fmult")
dpath = get_figshare_model_ff(model_name="alignnff_wt10")
# dpath = get_figshare_model_ff(model_name="alignnff_fmult")
print("model_path", dpath)
return dpath

Expand Down
5 changes: 5 additions & 0 deletions alignn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,11 @@ def get_batch_errors(dat=[]):
mean_out, mean_atom, mean_grad, mean_stress = get_batch_errors(
val_result
)
current_model_name = "current_model.pt"
torch.save(
net.state_dict(),
os.path.join(config.output_dir, current_model_name),
)
if val_loss < best_loss:
best_loss = val_loss
best_model_name = "best_model.pt"
Expand Down

0 comments on commit 60cd9bc

Please sign in to comment.