Commit

solve conf
Chengqian-Zhang committed Aug 28, 2024
2 parents 6870ecd + 8bb4a06 commit 5d8c863
Showing 15 changed files with 566 additions and 165 deletions.
21 changes: 21 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
@@ -75,6 +75,11 @@ def __init__(
activation_function="tanh",
resnet_dt: bool = False,
type_one_side: bool = False,
use_three_body: bool = False,
three_body_neuron: List[int] = [2, 4, 8],
three_body_sel: int = 40,
three_body_rcut: float = 4.0,
three_body_rcut_smth: float = 0.5,
):
r"""The constructor for the RepinitArgs class, which defines the parameters of the repinit block in the DPA2 descriptor.
@@ -116,6 +121,11 @@ def __init__(
self.activation_function = activation_function
self.resnet_dt = resnet_dt
self.type_one_side = type_one_side
self.use_three_body = use_three_body
self.three_body_neuron = three_body_neuron
self.three_body_sel = three_body_sel
self.three_body_rcut = three_body_rcut
self.three_body_rcut_smth = three_body_rcut_smth

def __getitem__(self, key):
if hasattr(self, key):
@@ -136,6 +146,11 @@ def serialize(self) -> dict:
"activation_function": self.activation_function,
"resnet_dt": self.resnet_dt,
"type_one_side": self.type_one_side,
"use_three_body": self.use_three_body,
"three_body_neuron": self.three_body_neuron,
"three_body_sel": self.three_body_sel,
"three_body_rcut": self.three_body_rcut,
"three_body_rcut_smth": self.three_body_rcut_smth,
}

@classmethod
@@ -172,6 +187,9 @@ def __init__(
update_residual_init: str = "norm",
set_davg_zero: bool = True,
trainable_ln: bool = True,
use_sqrt_nnei: bool = False,
g1_out_conv: bool = False,
g1_out_mlp: bool = False,
ln_eps: Optional[float] = 1e-5,
):
r"""The constructor for the RepformerArgs class, which defines the parameters of the repformer block in the DPA2 descriptor.
@@ -265,6 +283,9 @@ def __init__(
self.update_residual_init = update_residual_init
self.set_davg_zero = set_davg_zero
self.trainable_ln = trainable_ln
self.use_sqrt_nnei = use_sqrt_nnei
self.g1_out_conv = g1_out_conv
self.g1_out_mlp = g1_out_mlp
# to keep consistent with the default value in this backend
if ln_eps is None:
ln_eps = 1e-5
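The three-body options added to RepinitArgs and the new RepformerArgs flags are plain constructor arguments and are carried through serialize(). A minimal sketch of passing them, assuming the constructors still take rcut, rcut_smth and nsel as their leading arguments (not shown in this hunk); the numeric values are placeholders, not recommendations:

from deepmd.dpmodel.descriptor.dpa2 import RepformerArgs, RepinitArgs

# Repinit block with the new three-body branch switched on (placeholder values).
repinit = RepinitArgs(
    rcut=9.0,
    rcut_smth=8.0,
    nsel=120,
    use_three_body=True,
    three_body_neuron=[2, 4, 8],
    three_body_sel=40,
    three_body_rcut=4.0,
    three_body_rcut_smth=0.5,
)

# Repformer block with the new structural flags enabled (placeholder values).
repformer = RepformerArgs(
    rcut=4.0,
    rcut_smth=3.5,
    nsel=40,
    use_sqrt_nnei=True,
    g1_out_conv=True,
    g1_out_mlp=True,
)

# serialize() now round-trips the new repinit keys.
assert repinit.serialize()["use_three_body"] is True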
2 changes: 1 addition & 1 deletion deepmd/dpmodel/utils/network.py
@@ -583,7 +583,7 @@ def __init__(
resnet_dt: bool = False,
precision: str = DEFAULT_PRECISION,
seed: Optional[Union[int, List[int]]] = None,
bias: bool = True,
bias=True,
):
layers = []
i_in = in_dim
4 changes: 2 additions & 2 deletions deepmd/dpmodel/utils/type_embed.py
@@ -224,14 +224,14 @@ def get_econf_tebd(type_map, precision: str = "default"):
type_map is not None
), "When using electronic configuration type embedding, type_map must be provided!"

missing_types = [t for t in type_map if t not in periodic_table]
missing_types = [t for t in type_map if t.split("_")[0] not in periodic_table]
assert not missing_types, (
"When using electronic configuration type embedding, "
"all element in type_map should be in periodic table! "
f"Found these invalid elements: {missing_types}"
)
econf_tebd = np.array(
[electronic_configuration_embedding[kk] for kk in type_map],
[electronic_configuration_embedding[kk.split("_")[0]] for kk in type_map],
dtype=PRECISION_DICT[precision],
)
econf_tebd /= econf_tebd.sum(-1, keepdims=True) # do normalization
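The get_econf_tebd change strips an underscore suffix before the periodic-table lookup, so type names derived from an element (for instance a hypothetical "O_1" or "H_surf") reuse the electronic-configuration embedding of their base element instead of failing the assertion. A rough sketch of just the name handling (the suffixed names are made-up examples):

# Rough sketch of the suffix handling above; the suffixed type names are
# made-up examples, not names required by the code.
type_map = ["O", "H", "O_1", "H_surf"]
base_elements = [t.split("_")[0] for t in type_map]
print(base_elements)  # ['O', 'H', 'O', 'H'] -- every entry now resolves in the periodic table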
5 changes: 4 additions & 1 deletion deepmd/pt/entrypoints/main.py
@@ -287,8 +287,11 @@ def train(FLAGS):
)

# argcheck
wandb_config = config["training"].pop("wandb_config", None)
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)
if wandb_config is not None:
config["training"]["wandb_config"] = wandb_config

# do neighbor stat
min_nbor_dist = None
@@ -334,7 +337,7 @@ def train(FLAGS):
shared_links=shared_links,
finetune_links=finetune_links,
)
# save min_nbor_dist
# save min_nbor_dist(
if min_nbor_dist is not None:
if not multi_task:
trainer.model.min_nbor_dist = min_nbor_dist
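The training config is popped free of its wandb_config block before update_deepmd_input and normalize run, presumably so the argument-schema check does not reject the unknown key, and the block is put back afterwards. A minimal sketch of the same pop-and-restore pattern in isolation; the keys inside wandb_config are hypothetical and only for illustration:

# Minimal sketch of the pop-and-restore pattern used above; the keys inside
# "wandb_config" are hypothetical.
config = {
    "training": {
        "numb_steps": 1000,
        "wandb_config": {"project": "dpa2-finetune", "name": "run-1"},
    }
}
wandb_config = config["training"].pop("wandb_config", None)
# ... config = update_deepmd_input(config, ...); config = normalize(config, ...) ...
if wandb_config is not None:
    config["training"]["wandb_config"] = wandb_config
assert "wandb_config" in config["training"]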
107 changes: 103 additions & 4 deletions deepmd/pt/loss/ener.py
@@ -21,6 +21,43 @@
)


def custom_huber_loss(predictions, targets, delta=1.0):
error = targets - predictions
abs_error = torch.abs(error)
quadratic_loss = 0.5 * torch.pow(error, 2)
linear_loss = delta * (abs_error - 0.5 * delta)
loss = torch.where(abs_error <= delta, quadratic_loss, linear_loss)
return torch.mean(loss)


def custom_step_huber_loss(predictions, targets, delta=1.0):
error = targets - predictions
abs_error = torch.abs(error)
abs_targets = torch.abs(targets)

# Define the different delta values based on the absolute value of targets
delta1 = delta
delta2 = 0.7 * delta
delta3 = 0.4 * delta
delta4 = 0.1 * delta

# Determine which delta to use based on the absolute value of targets
delta_values = torch.where(
abs_targets < 100,
delta1,
torch.where(
abs_targets < 200, delta2, torch.where(abs_targets < 300, delta3, delta4)
),
)

# Compute the quadratic and linear loss based on the dynamically selected delta values
quadratic_loss = 0.5 * torch.pow(error, 2)
linear_loss = delta_values * (abs_error - 0.5 * delta_values)
# Select the appropriate loss based on whether abs_error is less than or greater than delta_values
loss = torch.where(abs_error <= delta_values, quadratic_loss, linear_loss)
return torch.mean(loss)


class EnergyStdLoss(TaskLoss):
def __init__(
self,
@@ -42,6 +79,9 @@ def __init__(
numb_generalized_coord: int = 0,
use_l1_all: bool = False,
inference=False,
use_huber=False,
huber_delta=0.01,
torch_huber=False,
**kwargs,
):
r"""Construct a layer to compute loss on energy, force and virial.
@@ -119,6 +159,10 @@ def __init__(
)
self.use_l1_all = use_l1_all
self.inference = inference
self.huber = use_huber
self.huber_delta = huber_delta
self.torch_huber = torch_huber
self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta)

def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
"""Return loss on energy and force.
@@ -153,7 +197,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
pref_gf = self.limit_pref_gf + (self.start_pref_gf - self.limit_pref_gf) * coef

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
more_loss = {"ax_loss": 0.0}  # auxiliary loss accumulated with the stop (limit) prefactors
# more_loss['log_keys'] = [] # showed when validation on the fly
# more_loss['test_keys'] = [] # showed when doing dp test
atom_norm = 1.0 / natoms
@@ -181,7 +225,27 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_ener_loss"] = self.display_if_exist(
l2_ener_loss.detach(), find_energy
)
loss += atom_norm * (pref_e * l2_ener_loss)
if not self.huber:
loss += atom_norm * (pref_e * l2_ener_loss)
more_loss["ax_loss"] += atom_norm * (
self.limit_pref_e * find_energy * l2_ener_loss
)
else:
if self.torch_huber:
l_huber_loss = self.huber_loss(
atom_norm * model_pred["energy"],
atom_norm * label["energy"],
)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["energy"],
atom_norm * label["energy"],
delta=self.huber_delta,
)
loss += pref_e * l_huber_loss
more_loss["ax_loss"] += (
self.limit_pref_e * find_energy * l_huber_loss
)
rmse_e = l2_ener_loss.sqrt() * atom_norm
more_loss["rmse_e"] = self.display_if_exist(
rmse_e.detach(), find_energy
@@ -236,7 +300,24 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_force_loss"] = self.display_if_exist(
l2_force_loss.detach(), find_force
)
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
if not self.huber:
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
more_loss["ax_loss"] += (
self.limit_pref_f * find_force * l2_force_loss
).to(GLOBAL_PT_FLOAT_PRECISION)
else:
if self.torch_huber:
l_huber_loss = self.huber_loss(
model_pred["force"], label["force"]
)
else:
l_huber_loss = custom_huber_loss(
force_pred.reshape(-1),
force_label.reshape(-1),
delta=self.huber_delta,
)
loss += self.limit_pref_f * find_force * l_huber_loss
more_loss["ax_loss"] += pref_f * l_huber_loss
rmse_f = l2_force_loss.sqrt()
more_loss["rmse_f"] = self.display_if_exist(
rmse_f.detach(), find_force
@@ -304,7 +385,25 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
if not self.huber:
loss += atom_norm * (pref_v * l2_virial_loss)
more_loss["ax_loss"] += atom_norm * (
self.limit_pref_v * find_virial * l2_virial_loss
)
else:
if self.torch_huber:
l_huber_loss = self.huber_loss(
atom_norm * model_pred["virial"],
atom_norm * label["virial"],
)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["virial"].reshape(-1),
atom_norm * label["virial"].reshape(-1),
delta=self.huber_delta,
)
loss += pref_v * l_huber_loss
more_loss["ax_loss"] += self.limit_pref_v * find_virial * l_huber_loss
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
if mae:
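The helpers above implement the standard Huber form, 0.5 * e**2 when |e| <= delta and delta * (|e| - 0.5 * delta) otherwise, averaged over all elements; custom_step_huber_loss reuses that form but shrinks delta to 0.7x, 0.4x and 0.1x of the base value as |target| crosses 100, 200 and 300, so large labels are penalized closer to linearly. For a common delta, custom_huber_loss should therefore agree with torch.nn.HuberLoss; a quick sanity check on random data (illustrative only):

# Illustrative check: custom_huber_loss matches torch.nn.HuberLoss (mean
# reduction) for the same delta on random data.
import torch

from deepmd.pt.loss.ener import custom_huber_loss  # added in this commit

pred = torch.randn(100)
target = torch.randn(100)
delta = 0.01
reference = torch.nn.HuberLoss(reduction="mean", delta=delta)(pred, target)
assert torch.allclose(custom_huber_loss(pred, target, delta=delta), reference)

In the loss itself these branches are presumably switched on through the corresponding keys of the loss section (use_huber, huber_delta, torch_huber), mirroring the new EnergyStdLoss constructor arguments.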
2 changes: 2 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
@@ -245,6 +245,7 @@ def __init__(
use_econf_tebd: bool = False,
use_tebd_bias: bool = False,
type_map: Optional[List[str]] = None,
# not implemented
spin=None,
type: Optional[str] = None,
@@ -306,6 +307,7 @@ def __init__(
use_econf_tebd=use_econf_tebd,
use_tebd_bias=use_tebd_bias,
type_map=type_map,
bias=use_tebd_bias,
)
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
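The second addition threads the descriptor-level use_tebd_bias flag into the wrapped component as bias. A minimal sketch of toggling it when constructing the PyTorch descriptor, assuming DescrptDPA1 keeps rcut, rcut_smth, sel and ntypes as its leading arguments (placeholder values):

# Minimal sketch (placeholder values); assumes the leading arguments of the
# PyTorch DescrptDPA1 are rcut, rcut_smth, sel and ntypes.
from deepmd.pt.model.descriptor.dpa1 import DescrptDPA1

descriptor = DescrptDPA1(
    rcut=6.0,
    rcut_smth=0.5,
    sel=120,
    ntypes=2,
    use_tebd_bias=True,  # forwarded as `bias` in the diff above
)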