diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py
index f67bbc93a4..01bd60c777 100644
--- a/deepmd/dpmodel/fitting/dipole_fitting.py
+++ b/deepmd/dpmodel/fitting/dipole_fitting.py
@@ -105,7 +105,6 @@ def __init__(
         r_differentiable: bool = True,
         c_differentiable: bool = True,
         type_map: Optional[list[str]] = None,
-        old_impl=False,
         seed: Optional[Union[int, list[int]]] = None,
     ):
         if tot_ener_zero:
@@ -141,7 +140,6 @@ def __init__(
             type_map=type_map,
             seed=seed,
         )
-        self.old_impl = False

     def _net_out_dim(self):
         """Set the FittingNet output dim."""
@@ -151,7 +149,6 @@ def serialize(self) -> dict:
         data = super().serialize()
         data["type"] = "dipole"
         data["embedding_width"] = self.embedding_width
-        data["old_impl"] = self.old_impl
         data["r_differentiable"] = self.r_differentiable
         data["c_differentiable"] = self.c_differentiable
         return data
diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py
index 2ff5052a83..73a691f482 100644
--- a/deepmd/dpmodel/fitting/polarizability_fitting.py
+++ b/deepmd/dpmodel/fitting/polarizability_fitting.py
@@ -107,7 +106,6 @@ def __init__(
         spin: Any = None,
         mixed_types: bool = False,
         exclude_types: list[int] = [],
-        old_impl: bool = False,
         fit_diag: bool = True,
         scale: Optional[list[float]] = None,
         shift_diag: bool = True,
@@ -165,7 +164,6 @@ def __init__(
             type_map=type_map,
             seed=seed,
         )
-        self.old_impl = False

     def _net_out_dim(self):
         """Set the FittingNet output dim."""
@@ -192,7 +190,6 @@ def serialize(self) -> dict:
         data["type"] = "polar"
         data["@version"] = 3
         data["embedding_width"] = self.embedding_width
-        data["old_impl"] = self.old_impl
         data["fit_diag"] = self.fit_diag
         data["shift_diag"] = self.shift_diag
         data["@variables"]["scale"] = self.scale
diff --git a/deepmd/pt/model/backbone/__init__.py b/deepmd/pt/model/backbone/__init__.py
deleted file mode 100644
index a76bdb2a2d..0000000000
--- a/deepmd/pt/model/backbone/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# SPDX-License-Identifier: LGPL-3.0-or-later
-from .backbone import (
-    BackBone,
-)
-from .evoformer2b import (
-    Evoformer2bBackBone,
-)
-
-__all__ = [
-    "BackBone",
-    "Evoformer2bBackBone",
-]
diff --git a/deepmd/pt/model/backbone/backbone.py b/deepmd/pt/model/backbone/backbone.py
deleted file mode 100644
index ddeedfeff5..0000000000
--- a/deepmd/pt/model/backbone/backbone.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# SPDX-License-Identifier: LGPL-3.0-or-later
-import torch
-
-
-class BackBone(torch.nn.Module):
-    def __init__(self, **kwargs):
-        """BackBone base method."""
-        super().__init__()
-
-    def forward(self, **kwargs):
-        """Calculate backBone."""
-        raise NotImplementedError
diff --git a/deepmd/pt/model/backbone/evoformer2b.py b/deepmd/pt/model/backbone/evoformer2b.py
deleted file mode 100644
index 1146b3a298..0000000000
--- a/deepmd/pt/model/backbone/evoformer2b.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# SPDX-License-Identifier: LGPL-3.0-or-later
-from deepmd.pt.model.backbone import (
-    BackBone,
-)
-from deepmd.pt.model.network.network import (
-    Evoformer2bEncoder,
-)
-
-
-class Evoformer2bBackBone(BackBone):
-    def __init__(
-        self,
-        nnei,
-        layer_num=6,
-        attn_head=8,
-        atomic_dim=1024,
-        pair_dim=100,
-        feature_dim=1024,
-        ffn_dim=2048,
-        post_ln=False,
-        final_layer_norm=True,
-        final_head_layer_norm=False,
-        emb_layer_norm=False,
-        atomic_residual=False,
-        evo_residual=False,
-        residual_factor=1.0,
-        activation_function="gelu",
-        **kwargs,
-    ):
-        """Construct an evoformer backBone."""
-        super().__init__()
-        self.nnei = nnei
-        self.layer_num = layer_num
-        self.attn_head = attn_head
-        self.atomic_dim = atomic_dim
-        self.pair_dim = pair_dim
-        self.feature_dim = feature_dim
-        self.head_dim = feature_dim // attn_head
-        assert (
-            feature_dim % attn_head == 0
-        ), f"feature_dim {feature_dim} must be divided by attn_head {attn_head}!"
-        self.ffn_dim = ffn_dim
-        self.post_ln = post_ln
-        self.final_layer_norm = final_layer_norm
-        self.final_head_layer_norm = final_head_layer_norm
-        self.emb_layer_norm = emb_layer_norm
-        self.activation_function = activation_function
-        self.atomic_residual = atomic_residual
-        self.evo_residual = evo_residual
-        self.residual_factor = float(residual_factor)
-        self.encoder = Evoformer2bEncoder(
-            nnei=self.nnei,
-            layer_num=self.layer_num,
-            attn_head=self.attn_head,
-            atomic_dim=self.atomic_dim,
-            pair_dim=self.pair_dim,
-            feature_dim=self.feature_dim,
-            ffn_dim=self.ffn_dim,
-            post_ln=self.post_ln,
-            final_layer_norm=self.final_layer_norm,
-            final_head_layer_norm=self.final_head_layer_norm,
-            emb_layer_norm=self.emb_layer_norm,
-            atomic_residual=self.atomic_residual,
-            evo_residual=self.evo_residual,
-            residual_factor=self.residual_factor,
-            activation_function=self.activation_function,
-        )
-
-    def forward(self, atomic_rep, pair_rep, nlist, nlist_type, nlist_mask):
-        """Encoder the atomic and pair representations.
-
-        Args:
-        - atomic_rep: Atomic representation with shape [nframes, nloc, atomic_dim].
-        - pair_rep: Pair representation with shape [nframes, nloc, nnei, pair_dim].
-        - nlist: Neighbor list with shape [nframes, nloc, nnei].
-        - nlist_type: Neighbor types with shape [nframes, nloc, nnei].
-        - nlist_mask: Neighbor mask with shape [nframes, nloc, nnei], `False` if blank.
-
-        Returns
-        -------
-        - atomic_rep: Atomic representation after encoder with shape [nframes, nloc, feature_dim].
-        - transformed_atomic_rep: Transformed atomic representation after encoder with shape [nframes, nloc, atomic_dim].
-        - pair_rep: Pair representation after encoder with shape [nframes, nloc, nnei, attn_head].
-        - delta_pair_rep: Delta pair representation after encoder with shape [nframes, nloc, nnei, attn_head].
-        - norm_x: Normalization loss of atomic_rep.
-        - norm_delta_pair_rep: Normalization loss of delta_pair_rep.
-        """
-        (
-            atomic_rep,
-            transformed_atomic_rep,
-            pair_rep,
-            delta_pair_rep,
-            norm_x,
-            norm_delta_pair_rep,
-        ) = self.encoder(atomic_rep, pair_rep, nlist, nlist_type, nlist_mask)
-        return (
-            atomic_rep,
-            transformed_atomic_rep,
-            pair_rep,
-            delta_pair_rep,
-            norm_x,
-            norm_delta_pair_rep,
-        )
diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py
index 779e7a562c..4ffa937bcb 100644
--- a/deepmd/pt/model/descriptor/__init__.py
+++ b/deepmd/pt/model/descriptor/__init__.py
@@ -16,9 +16,6 @@
 from .env_mat import (
     prod_env_mat,
 )
-from .gaussian_lcc import (
-    DescrptGaussianLcc,
-)
 from .hybrid import (
     DescrptHybrid,
 )
@@ -59,6 +56,5 @@
     "DescrptDPA2",
     "DescrptHybrid",
     "prod_env_mat",
-    "DescrptGaussianLcc",
     "DescrptBlockRepformers",
 ]
diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py
index 617e8b49b6..322fa3a12d 100644
--- a/deepmd/pt/model/descriptor/dpa1.py
+++ b/deepmd/pt/model/descriptor/dpa1.py
@@ -245,7 +245,6 @@ def __init__(
         # not implemented
         spin=None,
         type: Optional[str] = None,
-        old_impl: bool = False,
     ):
         super().__init__()
         # Ensure compatibility with the deprecated stripped_type_embedding option.
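The fitting-net hunks above remove `old_impl` from both `__init__` and `serialize()`, so newly written checkpoints no longer carry the key, while checkpoints produced by older releases still may. If a strict `deserialize()` rejects unknown keys, a shim like the following can scrub the stale flag first; `strip_old_impl` is a hypothetical helper sketched for illustration, not part of this patch:

# Hypothetical helper, not part of this patch: drop the retired
# "old_impl" flag from serialized fitting data before deserializing.
def strip_old_impl(data: dict) -> dict:
    data = dict(data)  # shallow copy; leave the caller's dict untouched
    data.pop("old_impl", None)  # absent in new checkpoints; harmless either way
    return data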
@@ -290,7 +289,6 @@ def __init__( trainable_ln=trainable_ln, ln_eps=ln_eps, seed=child_seed(seed, 1), - old_impl=old_impl, ) self.use_econf_tebd = use_econf_tebd self.use_tebd_bias = use_tebd_bias diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index f1ef200b09..632efe5dbf 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -92,7 +92,6 @@ def __init__( use_econf_tebd: bool = False, use_tebd_bias: bool = False, type_map: Optional[list[str]] = None, - old_impl: bool = False, ): r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492. @@ -235,7 +234,6 @@ def init_subclass_params(sub_data, sub_class): g1_out_conv=self.repformer_args.g1_out_conv, g1_out_mlp=self.repformer_args.g1_out_mlp, seed=child_seed(seed, 1), - old_impl=old_impl, ) self.rcsl_list = [ (self.repformers.get_rcut(), self.repformers.get_nsel()), diff --git a/deepmd/pt/model/descriptor/gaussian_lcc.py b/deepmd/pt/model/descriptor/gaussian_lcc.py deleted file mode 100644 index 8ac52215c0..0000000000 --- a/deepmd/pt/model/descriptor/gaussian_lcc.py +++ /dev/null @@ -1,319 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Optional, -) - -import torch -import torch.nn as nn - -from deepmd.pt.model.descriptor.base_descriptor import ( - BaseDescriptor, -) -from deepmd.pt.model.network.network import ( - Evoformer3bEncoder, - GaussianEmbedding, - TypeEmbedNet, -) -from deepmd.pt.utils import ( - env, -) -from deepmd.utils.path import ( - DPPath, -) - - -class DescrptGaussianLcc(torch.nn.Module, BaseDescriptor): - def __init__( - self, - rcut, - rcut_smth, - sel: int, - ntypes: int, - num_pair: int, - embed_dim: int = 768, - kernel_num: int = 128, - pair_embed_dim: int = 64, - num_block: int = 1, - layer_num: int = 12, - attn_head: int = 48, - pair_hidden_dim: int = 16, - ffn_embedding_dim: int = 768, - dropout: float = 0.0, - droppath_prob: float = 0.1, - pair_dropout: float = 0.25, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - pre_ln: bool = True, - do_tag_embedding: bool = False, - tag_ener_pref: bool = False, - atomic_sum_gbf: bool = False, - pre_add_seq: bool = True, - tri_update: bool = True, - **kwargs, - ): - """Construct a descriptor of Gaussian Based Local Cluster. - - Args: - - rcut: Cut-off radius. - - rcut_smth: Smooth hyper-parameter for pair force & energy. **Not used in this descriptor**. - - sel: For each element type, how many atoms is selected as neighbors. - - ntypes: Number of atom types. - - num_pair: Number of atom type pairs. Default is 2 * ntypes. - - kernel_num: Number of gaussian kernels. - - embed_dim: Dimension of atomic representation. - - pair_embed_dim: Dimension of pair representation. - - num_block: Number of evoformer blocks. - - layer_num: Number of attention layers. - - attn_head: Number of attention heads. - - pair_hidden_dim: Hidden dimension of pair representation during attention process. - - ffn_embedding_dim: Dimension during feed forward network. - - dropout: Dropout probability of atomic representation. - - droppath_prob: If not zero, it will use drop paths (Stochastic Depth) per sample and ignore `dropout`. - - pair_dropout: Dropout probability of pair representation during triangular update. - - attention_dropout: Dropout probability during attetion process. - - activation_dropout: Dropout probability of pair feed forward network. - - pre_ln: Do previous layer norm or not. - - do_tag_embedding: Add tag embedding to atomic and pair representations. 
(`tags`, `tags2`, `tags3` must exist) - - atomic_sum_gbf: Add sum of gaussian outputs to atomic representation or not. - - pre_add_seq: Add output of other descriptor (if has) to the atomic representation before attention. - """ - super().__init__() - self.rcut = rcut - self.rcut_smth = rcut_smth - self.embed_dim = embed_dim - self.num_pair = num_pair - self.kernel_num = kernel_num - self.pair_embed_dim = pair_embed_dim - self.num_block = num_block - self.layer_num = layer_num - self.attention_heads = attn_head - self.pair_hidden_dim = pair_hidden_dim - self.ffn_embedding_dim = ffn_embedding_dim - self.dropout = dropout - self.droppath_prob = droppath_prob - self.pair_dropout = pair_dropout - self.attention_dropout = attention_dropout - self.activation_dropout = activation_dropout - self.pre_ln = pre_ln - self.do_tag_embedding = do_tag_embedding - self.tag_ener_pref = tag_ener_pref - self.atomic_sum_gbf = atomic_sum_gbf - self.local_cluster = True - self.pre_add_seq = pre_add_seq - self.tri_update = tri_update - - if isinstance(sel, int): - sel = [sel] - - self.ntypes = ntypes - self.sec = torch.tensor(sel) # pylint: disable=no-explicit-dtype,no-explicit-device - self.nnei = sum(sel) - - if self.do_tag_embedding: - self.tag_encoder = nn.Embedding(3, self.embed_dim) - self.tag_encoder2 = nn.Embedding(2, self.embed_dim) - self.tag_type_embedding = TypeEmbedNet(10, pair_embed_dim) - self.edge_type_embedding = nn.Embedding( - (ntypes + 1) * (ntypes + 1), - pair_embed_dim, - padding_idx=(ntypes + 1) * (ntypes + 1) - 1, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - ) - self.gaussian_encoder = GaussianEmbedding( - rcut, - kernel_num, - num_pair, - embed_dim, - pair_embed_dim, - sel, - ntypes, - atomic_sum_gbf, - ) - self.backbone = Evoformer3bEncoder( - self.nnei, - layer_num=self.layer_num, - attn_head=self.attention_heads, - atomic_dim=self.embed_dim, - pair_dim=self.pair_embed_dim, - pair_hidden_dim=self.pair_hidden_dim, - ffn_embedding_dim=self.ffn_embedding_dim, - dropout=self.dropout, - droppath_prob=self.droppath_prob, - pair_dropout=self.pair_dropout, - attention_dropout=self.attention_dropout, - activation_dropout=self.activation_dropout, - pre_ln=self.pre_ln, - tri_update=self.tri_update, - ) - - @property - def dim_out(self): - """Returns the output dimension of atomic representation.""" - return self.embed_dim - - @property - def dim_in(self): - """Returns the atomic input dimension of this descriptor.""" - return self.embed_dim - - @property - def dim_emb(self): - """Returns the output dimension of pair representation.""" - return self.pair_embed_dim - - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): - """Update mean and stddev for descriptor elements.""" - pass - - def forward( - self, - extended_coord, - nlist, - atype, - nlist_type, - nlist_loc=None, - atype_tebd=None, - nlist_tebd=None, - seq_input=None, - ): - """Calculate the atomic and pair representations of this descriptor. - - Args: - - extended_coord: Copied atom coordinates with shape [nframes, nall, 3]. - - nlist: Neighbor list with shape [nframes, nloc, nnei]. - - atype: Atom type with shape [nframes, nloc]. - - nlist_type: Atom type of neighbors with shape [nframes, nloc, nnei]. - - nlist_loc: Local index of neighbor list with shape [nframes, nloc, nnei]. - - atype_tebd: Atomic type embedding with shape [nframes, nloc, tebd_dim]. - - nlist_tebd: Type embeddings of neighbor with shape [nframes, nloc, nnei, tebd_dim]. 
- - seq_input: The sequential input from other descriptor with - shape [nframes, nloc, tebd_dim] or [nframes * nloc, 1 + nnei, tebd_dim] - - Returns - ------- - - result: descriptor with shape [nframes, nloc, self.filter_neuron[-1] * self.axis_neuron]. - - ret: environment matrix with shape [nframes, nloc, self.neei, out_size] - """ - nframes, nloc = nlist.shape[:2] - nall = extended_coord.shape[1] - nlist2 = torch.cat( - [ - torch.arange(0, nloc, device=nlist.device) # pylint: disable=no-explicit-dtype - .reshape(1, nloc, 1) - .expand(nframes, -1, -1), - nlist, - ], - dim=-1, - ) - nlist_loc2 = torch.cat( - [ - torch.arange(0, nloc, device=nlist_loc.device) # pylint: disable=no-explicit-dtype - .reshape(1, nloc, 1) - .expand(nframes, -1, -1), - nlist_loc, - ], - dim=-1, - ) - nlist_type2 = torch.cat([atype.reshape(nframes, nloc, 1), nlist_type], dim=-1) - nnei2_mask = nlist2 != -1 - padding_mask = nlist2 == -1 - nlist2 = nlist2 * nnei2_mask - nlist_loc2 = nlist_loc2 * nnei2_mask - - # nframes x nloc x (1 + nnei2) x (1 + nnei2) - pair_mask = nnei2_mask.unsqueeze(-1) * nnei2_mask.unsqueeze(-2) - # nframes x nloc x (1 + nnei2) x (1 + nnei2) x head - attn_mask = torch.zeros( - [nframes, nloc, 1 + self.nnei, 1 + self.nnei, self.attention_heads], - device=nlist.device, - dtype=extended_coord.dtype, - ) - attn_mask.masked_fill_(padding_mask.unsqueeze(2).unsqueeze(-1), float("-inf")) - # (nframes x nloc) x head x (1 + nnei2) x (1 + nnei2) - attn_mask = ( - attn_mask.reshape( - nframes * nloc, 1 + self.nnei, 1 + self.nnei, self.attention_heads - ) - .permute(0, 3, 1, 2) - .contiguous() - ) - - # Atomic feature - # [(nframes x nloc) x (1 + nnei2) x tebd_dim] - atom_feature = torch.gather( - atype_tebd, - dim=1, - index=nlist_loc2.reshape(nframes, -1) - .unsqueeze(-1) - .expand(-1, -1, self.embed_dim), - ).reshape(nframes * nloc, 1 + self.nnei, self.embed_dim) - if self.pre_add_seq and seq_input is not None: - first_dim = seq_input.shape[0] - if first_dim == nframes * nloc: - atom_feature += seq_input - elif first_dim == nframes: - atom_feature_seq = torch.gather( - seq_input, - dim=1, - index=nlist_loc2.reshape(nframes, -1) - .unsqueeze(-1) - .expand(-1, -1, self.embed_dim), - ).reshape(nframes * nloc, 1 + self.nnei, self.embed_dim) - atom_feature += atom_feature_seq - else: - raise RuntimeError - atom_feature = atom_feature * nnei2_mask.reshape( - nframes * nloc, 1 + self.nnei, 1 - ) - - # Pair feature - # [(nframes x nloc) x (1 + nnei2)] - nlist_type2_reshape = nlist_type2.reshape(nframes * nloc, 1 + self.nnei) - # [(nframes x nloc) x (1 + nnei2) x (1 + nnei2)] - edge_type = nlist_type2_reshape.unsqueeze(-1) * ( - self.ntypes + 1 - ) + nlist_type2_reshape.unsqueeze(-2) - # [(nframes x nloc) x (1 + nnei2) x (1 + nnei2) x pair_dim] - edge_feature = self.edge_type_embedding(edge_type) - - # [(nframes x nloc) x (1 + nnei2) x (1 + nnei2) x 2] - edge_type_2dim = torch.cat( - [ - nlist_type2_reshape.view(nframes * nloc, 1 + self.nnei, 1, 1).expand( - -1, -1, 1 + self.nnei, -1 - ), - nlist_type2_reshape.view(nframes * nloc, 1, 1 + self.nnei, 1).expand( - -1, 1 + self.nnei, -1, -1 - ) - + self.ntypes, - ], - dim=-1, - ) - # [(nframes x nloc) x (1 + nnei2) x 3] - coord_selected = torch.gather( - extended_coord.unsqueeze(1) - .expand(-1, nloc, -1, -1) - .reshape(nframes * nloc, nall, 3), - dim=1, - index=nlist2.reshape(nframes * nloc, 1 + self.nnei, 1).expand(-1, -1, 3), - ) - - # Update pair features (or and atomic features) with gbf features - # delta_pos: [(nframes x nloc) x (1 + nnei2) x (1 + nnei2) x 
3]. - atomic_feature, pair_feature, delta_pos = self.gaussian_encoder( - coord_selected, atom_feature, edge_type_2dim, edge_feature - ) - # [(nframes x nloc) x (1 + nnei2) x (1 + nnei2) x pair_dim] - attn_bias = pair_feature - - # output: [(nframes x nloc) x (1 + nnei2) x tebd_dim] - # pair: [(nframes x nloc) x (1 + nnei2) x (1 + nnei2) x pair_dim] - output, pair = self.backbone( - atomic_feature, - pair=attn_bias, - attn_mask=attn_mask, - pair_mask=pair_mask, - atom_mask=nnei2_mask.reshape(nframes * nloc, 1 + self.nnei), - ) - - return output, pair, delta_pos, None diff --git a/deepmd/pt/model/descriptor/repformer_layer_old_impl.py b/deepmd/pt/model/descriptor/repformer_layer_old_impl.py deleted file mode 100644 index 47b20f7b03..0000000000 --- a/deepmd/pt/model/descriptor/repformer_layer_old_impl.py +++ /dev/null @@ -1,744 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Callable, -) - -import torch - -from deepmd.pt.model.network.network import ( - SimpleLinear, -) -from deepmd.pt.utils import ( - env, -) -from deepmd.pt.utils.utils import ( - ActivationFn, -) - - -def _make_nei_g1( - g1_ext: torch.Tensor, - nlist: torch.Tensor, -) -> torch.Tensor: - # nlist: nb x nloc x nnei - nb, nloc, nnei = nlist.shape - # g1_ext: nb x nall x ng1 - ng1 = g1_ext.shape[-1] - # index: nb x (nloc x nnei) x ng1 - index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) - # gg1 : nb x (nloc x nnei) x ng1 - gg1 = torch.gather(g1_ext, dim=1, index=index) - # gg1 : nb x nloc x nnei x ng1 - gg1 = gg1.view(nb, nloc, nnei, ng1) - return gg1 - - -def _apply_nlist_mask( - gg: torch.Tensor, - nlist_mask: torch.Tensor, -) -> torch.Tensor: - # gg: nf x nloc x nnei x ng - # msk: nf x nloc x nnei - return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0) - - -def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: - # gg: nf x nloc x nnei x ng - # sw: nf x nloc x nnei - return gg * sw.unsqueeze(-1) - - -def _apply_h_norm( - hh: torch.Tensor, # nf x nloc x nnei x 3 -) -> torch.Tensor: - """Normalize h by the std of vector length. - do not have an idea if this is a good way. 
- """ - nf, nl, nnei, _ = hh.shape - # nf x nloc x nnei - normh = torch.linalg.norm(hh, dim=-1) - # nf x nloc - std = torch.std(normh, dim=-1) - # nf x nloc x nnei x 3 - hh = hh[:, :, :, :] / (1.0 + std[:, :, None, None]) - return hh - - -class Atten2Map(torch.nn.Module): - def __init__( - self, - ni: int, - nd: int, - nh: int, - has_gate: bool = False, # apply gate to attn map - smooth: bool = True, - attnw_shift: float = 20.0, - ): - super().__init__() - self.ni = ni - self.nd = nd - self.nh = nh - self.mapqk = SimpleLinear(ni, nd * 2 * nh, bias=False) # todo - self.has_gate = has_gate - self.smooth = smooth - self.attnw_shift = attnw_shift - - def forward( - self, - g2: torch.Tensor, # nb x nloc x nnei x ng2 - h2: torch.Tensor, # nb x nloc x nnei x 3 - nlist_mask: torch.Tensor, # nb x nloc x nnei - sw: torch.Tensor, # nb x nloc x nnei - ) -> torch.Tensor: - ( - nb, - nloc, - nnei, - _, - ) = g2.shape - nd, nh = self.nd, self.nh - # nb x nloc x nnei x nd x (nh x 2) - g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2) - # nb x nloc x (nh x 2) x nnei x nd - g2qk = torch.permute(g2qk, (0, 1, 4, 2, 3)) - # nb x nloc x nh x nnei x nd - g2q, g2k = torch.split(g2qk, nh, dim=2) - # g2q = torch.nn.functional.normalize(g2q, dim=-1) - # g2k = torch.nn.functional.normalize(g2k, dim=-1) - # nb x nloc x nh x nnei x nnei - attnw = torch.matmul(g2q, torch.transpose(g2k, -1, -2)) / nd**0.5 - if self.has_gate: - gate = torch.matmul(h2, torch.transpose(h2, -1, -2)).unsqueeze(-3) - attnw = attnw * gate - # mask the attenmap, nb x nloc x 1 x 1 x nnei - attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2) - # mask the attenmap, nb x nloc x 1 x nnei x 1 - attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1) - if self.smooth: - attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ - :, :, None, None, : - ] - self.attnw_shift - else: - attnw = attnw.masked_fill( - attnw_mask, - float("-inf"), - ) - attnw = torch.softmax(attnw, dim=-1) - attnw = attnw.masked_fill( - attnw_mask, - 0.0, - ) - # nb x nloc x nh x nnei x nnei - attnw = attnw.masked_fill( - attnw_mask_c, - 0.0, - ) - if self.smooth: - attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] - # nb x nloc x nnei x nnei - h2h2t = torch.matmul(h2, torch.transpose(h2, -1, -2)) / 3.0**0.5 - # nb x nloc x nh x nnei x nnei - ret = attnw * h2h2t[:, :, None, :, :] - # ret = torch.softmax(g2qk, dim=-1) - # nb x nloc x nnei x nnei x nh - ret = torch.permute(ret, (0, 1, 3, 4, 2)) - return ret - - -class Atten2MultiHeadApply(torch.nn.Module): - def __init__( - self, - ni: int, - nh: int, - ): - super().__init__() - self.ni = ni - self.nh = nh - self.mapv = SimpleLinear(ni, ni * nh, bias=False) - self.head_map = SimpleLinear(ni * nh, ni) - - def forward( - self, - AA: torch.Tensor, # nf x nloc x nnei x nnei x nh - g2: torch.Tensor, # nf x nloc x nnei x ng2 - ) -> torch.Tensor: - nf, nloc, nnei, ng2 = g2.shape - nh = self.nh - # nf x nloc x nnei x ng2 x nh - g2v = self.mapv(g2).view(nf, nloc, nnei, ng2, nh) - # nf x nloc x nh x nnei x ng2 - g2v = torch.permute(g2v, (0, 1, 4, 2, 3)) - # g2v = torch.nn.functional.normalize(g2v, dim=-1) - # nf x nloc x nh x nnei x nnei - AA = torch.permute(AA, (0, 1, 4, 2, 3)) - # nf x nloc x nh x nnei x ng2 - ret = torch.matmul(AA, g2v) - # nf x nloc x nnei x ng2 x nh - ret = torch.permute(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh)) - # nf x nloc x nnei x ng2 - return self.head_map(ret) - - -class Atten2EquiVarApply(torch.nn.Module): - def __init__( - self, - ni: int, - nh: int, - ): - 
super().__init__() - self.ni = ni - self.nh = nh - self.head_map = SimpleLinear(nh, 1, bias=False) - - def forward( - self, - AA: torch.Tensor, # nf x nloc x nnei x nnei x nh - h2: torch.Tensor, # nf x nloc x nnei x 3 - ) -> torch.Tensor: - nf, nloc, nnei, _ = h2.shape - nh = self.nh - # nf x nloc x nh x nnei x nnei - AA = torch.permute(AA, (0, 1, 4, 2, 3)) - h2m = torch.unsqueeze(h2, dim=2) - # nf x nloc x nh x nnei x 3 - h2m = torch.tile(h2m, [1, 1, nh, 1, 1]) - # nf x nloc x nh x nnei x 3 - ret = torch.matmul(AA, h2m) - # nf x nloc x nnei x 3 x nh - ret = torch.permute(ret, (0, 1, 3, 4, 2)).view(nf, nloc, nnei, 3, nh) - # nf x nloc x nnei x 3 - return torch.squeeze(self.head_map(ret), dim=-1) - - -class LocalAtten(torch.nn.Module): - def __init__( - self, - ni: int, - nd: int, - nh: int, - smooth: bool = True, - attnw_shift: float = 20.0, - ): - super().__init__() - self.ni = ni - self.nd = nd - self.nh = nh - self.mapq = SimpleLinear(ni, nd * 1 * nh, bias=False) - self.mapkv = SimpleLinear(ni, (nd + ni) * nh, bias=False) - self.head_map = SimpleLinear(ni * nh, ni) - self.smooth = smooth - self.attnw_shift = attnw_shift - - def forward( - self, - g1: torch.Tensor, # nb x nloc x ng1 - gg1: torch.Tensor, # nb x nloc x nnei x ng1 - nlist_mask: torch.Tensor, # nb x nloc x nnei - sw: torch.Tensor, # nb x nloc x nnei - ) -> torch.Tensor: - nb, nloc, nnei = nlist_mask.shape - ni, nd, nh = self.ni, self.nd, self.nh - assert ni == g1.shape[-1] - assert ni == gg1.shape[-1] - # nb x nloc x nd x nh - g1q = self.mapq(g1).view(nb, nloc, nd, nh) - # nb x nloc x nh x nd - g1q = torch.permute(g1q, (0, 1, 3, 2)) - # nb x nloc x nnei x (nd+ni) x nh - gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh) - gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3)) - # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1 - gg1k, gg1v = torch.split(gg1kv, [nd, ni], dim=-1) - - # nb x nloc x nh x 1 x nnei - attnw = torch.matmul(g1q.unsqueeze(-2), torch.transpose(gg1k, -1, -2)) / nd**0.5 - # nb x nloc x nh x nnei - attnw = attnw.squeeze(-2) - # mask the attenmap, nb x nloc x 1 x nnei - attnw_mask = ~nlist_mask.unsqueeze(-2) - # nb x nloc x nh x nnei - if self.smooth: - attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift - else: - attnw = attnw.masked_fill( - attnw_mask, - float("-inf"), - ) - attnw = torch.softmax(attnw, dim=-1) - attnw = attnw.masked_fill( - attnw_mask, - 0.0, - ) - if self.smooth: - attnw = attnw * sw.unsqueeze(-2) - - # nb x nloc x nh x ng1 - ret = ( - torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) - ) - # nb x nloc x ng1 - ret = self.head_map(ret) - return ret - - -class RepformerLayer(torch.nn.Module): - def __init__( - self, - rcut, - rcut_smth, - sel: int, - ntypes: int, - g1_dim=128, - g2_dim=16, - axis_neuron: int = 4, - update_chnnl_2: bool = True, - do_bn_mode: str = "no", - bn_momentum: float = 0.1, - update_g1_has_conv: bool = True, - update_g1_has_drrd: bool = True, - update_g1_has_grrg: bool = True, - update_g1_has_attn: bool = True, - update_g2_has_g1g1: bool = True, - update_g2_has_attn: bool = True, - update_h2: bool = False, - attn1_hidden: int = 64, - attn1_nhead: int = 4, - attn2_hidden: int = 16, - attn2_nhead: int = 4, - attn2_has_gate: bool = False, - activation_function: str = "tanh", - update_style: str = "res_avg", - set_davg_zero: bool = True, # TODO - smooth: bool = True, - ): - super().__init__() - self.epsilon = 1e-4 # protection of 1./nnei - self.rcut = rcut - self.rcut_smth = rcut_smth - self.ntypes = ntypes - sel = 
[sel] if isinstance(sel, int) else sel - self.nnei = sum(sel) - assert len(sel) == 1 - self.sel = torch.tensor(sel, device=env.DEVICE) # pylint: disable=no-explicit-dtype - self.sec = self.sel - self.axis_neuron = axis_neuron - self.set_davg_zero = set_davg_zero - self.do_bn_mode = do_bn_mode - self.bn_momentum = bn_momentum - self.act = ActivationFn(activation_function) - self.update_g1_has_grrg = update_g1_has_grrg - self.update_g1_has_drrd = update_g1_has_drrd - self.update_g1_has_conv = update_g1_has_conv - self.update_g1_has_attn = update_g1_has_attn - self.update_chnnl_2 = update_chnnl_2 - self.update_g2_has_g1g1 = update_g2_has_g1g1 if self.update_chnnl_2 else False - self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False - self.update_h2 = update_h2 if self.update_chnnl_2 else False - del update_g2_has_g1g1, update_g2_has_attn, update_h2 - self.update_style = update_style - self.smooth = smooth - self.g1_dim = g1_dim - self.g2_dim = g2_dim - - g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron) - self.linear1 = SimpleLinear(g1_in_dim, g1_dim) - self.linear2 = None - self.proj_g1g2 = None - self.proj_g1g1g2 = None - self.attn2g_map = None - self.attn2_mh_apply = None - self.attn2_lm = None - self.attn2h_map = None - self.attn2_ev_apply = None - self.loc_attn = None - - if self.update_chnnl_2: - self.linear2 = SimpleLinear(g2_dim, g2_dim) - if self.update_g1_has_conv: - self.proj_g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) - if self.update_g2_has_g1g1: - self.proj_g1g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) - if self.update_g2_has_attn: - self.attn2g_map = Atten2Map( - g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth - ) - self.attn2_mh_apply = Atten2MultiHeadApply(g2_dim, attn2_nhead) - self.attn2_lm = torch.nn.LayerNorm( - g2_dim, - elementwise_affine=True, - device=env.DEVICE, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - ) - if self.update_h2: - self.attn2h_map = Atten2Map( - g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth - ) - self.attn2_ev_apply = Atten2EquiVarApply(g2_dim, attn2_nhead) - if self.update_g1_has_attn: - self.loc_attn = LocalAtten(g1_dim, attn1_hidden, attn1_nhead, self.smooth) - - if self.do_bn_mode == "uniform": - self.bn1 = self._bn_layer() - self.bn2 = self._bn_layer() - elif self.do_bn_mode == "component": - self.bn1 = self._bn_layer(nf=g1_dim) - self.bn2 = self._bn_layer(nf=g2_dim) - elif self.do_bn_mode == "no": - self.bn1, self.bn2 = None, None - else: - raise RuntimeError(f"unknown bn_mode {self.do_bn_mode}") - - def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: - ret = g1d - if self.update_g1_has_grrg: - ret += g2d * ax - if self.update_g1_has_drrd: - ret += g1d * ax - if self.update_g1_has_conv: - ret += g2d - return ret - - def _update_h2( - self, - g2: torch.Tensor, - h2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - ) -> torch.Tensor: - assert self.attn2h_map is not None - assert self.attn2_ev_apply is not None - nb, nloc, nnei, _ = g2.shape - # # nb x nloc x nnei x nh2 - # h2_1 = self.attn2_ev_apply(AA, h2) - # h2_update.append(h2_1) - # nb x nloc x nnei x nnei x nh - AAh = self.attn2h_map(g2, h2, nlist_mask, sw) - # nb x nloc x nnei x nh2 - h2_1 = self.attn2_ev_apply(AAh, h2) - return h2_1 - - def _update_g1_conv( - self, - gg1: torch.Tensor, - g2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - ) -> torch.Tensor: - assert self.proj_g1g2 is not None - nb, nloc, nnei, _ = g2.shape - ng1 = gg1.shape[-1] - ng2 = g2.shape[-1] - # gg1 : nb x nloc 
x nnei x ng2 - gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) - # nb x nloc x nnei x ng2 - gg1 = _apply_nlist_mask(gg1, nlist_mask) - if not self.smooth: - # normalized by number of neighbors, not smooth - # nb x nloc x 1 - invnnei = 1.0 / ( - self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1) - ).unsqueeze(-1) - else: - gg1 = _apply_switch(gg1, sw) - invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=gg1.device - ) - # nb x nloc x ng2 - g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei - return g1_11 - - def _cal_h2g2( - self, - g2: torch.Tensor, - h2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - ) -> torch.Tensor: - # g2: nf x nloc x nnei x ng2 - # h2: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei - nb, nloc, nnei, _ = g2.shape - ng2 = g2.shape[-1] - # nb x nloc x nnei x ng2 - g2 = _apply_nlist_mask(g2, nlist_mask) - if not self.smooth: - # nb x nloc - invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) - # nb x nloc x 1 x 1 - invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) - else: - g2 = _apply_switch(g2, sw) - invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=g2.device - ) - # nb x nloc x 3 x ng2 - h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei - return h2g2 - - def _cal_grrg(self, h2g2: torch.Tensor) -> torch.Tensor: - # nb x nloc x 3 x ng2 - nb, nloc, _, ng2 = h2g2.shape - # nb x nloc x 3 x axis - h2g2m = torch.split(h2g2, self.axis_neuron, dim=-1)[0] - # nb x nloc x axis x ng2 - g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) - # nb x nloc x (axisxng2) - g1_13 = g1_13.view(nb, nloc, self.axis_neuron * ng2) - return g1_13 - - def _update_g1_grrg( - self, - g2: torch.Tensor, - h2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - ) -> torch.Tensor: - # g2: nf x nloc x nnei x ng2 - # h2: nf x nloc x nnei x 3 - # msk: nf x nloc x nnei - nb, nloc, nnei, _ = g2.shape - ng2 = g2.shape[-1] - # nb x nloc x 3 x ng2 - h2g2 = self._cal_h2g2(g2, h2, nlist_mask, sw) - # nb x nloc x (axisxng2) - g1_13 = self._cal_grrg(h2g2) - return g1_13 - - def _update_g2_g1g1( - self, - g1: torch.Tensor, # nb x nloc x ng1 - gg1: torch.Tensor, # nb x nloc x nnei x ng1 - nlist_mask: torch.Tensor, # nb x nloc x nnei - sw: torch.Tensor, # nb x nloc x nnei - ) -> torch.Tensor: - ret = g1.unsqueeze(-2) * gg1 - # nb x nloc x nnei x ng1 - ret = _apply_nlist_mask(ret, nlist_mask) - if self.smooth: - ret = _apply_switch(ret, sw) - return ret - - def _apply_bn( - self, - bn_number: int, - gg: torch.Tensor, - ): - if self.do_bn_mode == "uniform": - return self._apply_bn_uni(bn_number, gg) - elif self.do_bn_mode == "component": - return self._apply_bn_comp(bn_number, gg) - else: - return gg - - def _apply_nb_1(self, bn_number: int, gg: torch.Tensor) -> torch.Tensor: - nb, nl, nf = gg.shape - gg = gg.view([nb, 1, nl * nf]) - if bn_number == 1: - assert self.bn1 is not None - gg = self.bn1(gg) - else: - assert self.bn2 is not None - gg = self.bn2(gg) - return gg.view([nb, nl, nf]) - - def _apply_nb_2( - self, - bn_number: int, - gg: torch.Tensor, - ) -> torch.Tensor: - nb, nl, nnei, nf = gg.shape - gg = gg.view([nb, 1, nl * nnei * nf]) - if bn_number == 1: - assert self.bn1 is not None - gg = self.bn1(gg) - else: - assert self.bn2 is not None - gg = self.bn2(gg) - return gg.view([nb, nl, nnei, nf]) - - def _apply_bn_uni( - self, - bn_number: int, - gg: torch.Tensor, - mode: str = "1", - ) -> torch.Tensor: - if 
len(gg.shape) == 3: - return self._apply_nb_1(bn_number, gg) - elif len(gg.shape) == 4: - return self._apply_nb_2(bn_number, gg) - else: - raise RuntimeError(f"unsupported input shape {gg.shape}") - - def _apply_bn_comp( - self, - bn_number: int, - gg: torch.Tensor, - ) -> torch.Tensor: - ss = gg.shape - nf = ss[-1] - gg = gg.view([-1, nf]) - if bn_number == 1: - assert self.bn1 is not None - gg = self.bn1(gg).view(ss) - else: - assert self.bn2 is not None - gg = self.bn2(gg).view(ss) - return gg - - def forward( - self, - g1_ext: torch.Tensor, # nf x nall x ng1 - g2: torch.Tensor, # nf x nloc x nnei x ng2 - h2: torch.Tensor, # nf x nloc x nnei x 3 - nlist: torch.Tensor, # nf x nloc x nnei - nlist_mask: torch.Tensor, # nf x nloc x nnei - sw: torch.Tensor, # switch func, nf x nloc x nnei - ): - """ - Parameters - ---------- - g1_ext : nf x nall x ng1 extended single-atom chanel - g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant - h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant - nlist : nf x nloc x nnei neighbor list (padded neis are set to 0) - nlist_mask : nf x nloc x nnei masks of the neighbor list. real nei 1 otherwise 0 - sw : nf x nloc x nnei switch function - - Returns - ------- - g1: nf x nloc x ng1 updated single-atom chanel - g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant - h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant - """ - cal_gg1 = ( - self.update_g1_has_drrd - or self.update_g1_has_conv - or self.update_g1_has_attn - or self.update_g2_has_g1g1 - ) - - nb, nloc, nnei, _ = g2.shape - nall = g1_ext.shape[1] - g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) - assert (nb, nloc) == g1.shape[:2] - assert (nb, nloc, nnei) == h2.shape[:3] - ng1 = g1.shape[-1] - ng2 = g2.shape[-1] - nh2 = h2.shape[-1] - - if self.bn1 is not None: - g1 = self._apply_bn(1, g1) - if self.bn2 is not None: - g2 = self._apply_bn(2, g2) - if self.update_h2: - h2 = _apply_h_norm(h2) - - g2_update: list[torch.Tensor] = [g2] - h2_update: list[torch.Tensor] = [h2] - g1_update: list[torch.Tensor] = [g1] - g1_mlp: list[torch.Tensor] = [g1] - - if cal_gg1: - gg1 = _make_nei_g1(g1_ext, nlist) - else: - gg1 = None - - if self.update_chnnl_2: - # nb x nloc x nnei x ng2 - assert self.linear2 is not None - g2_1 = self.act(self.linear2(g2)) - g2_update.append(g2_1) - - if self.update_g2_has_g1g1: - assert gg1 is not None - assert self.proj_g1g1g2 is not None - g2_update.append( - self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw)) - ) - - if self.update_g2_has_attn: - assert self.attn2g_map is not None - assert self.attn2_mh_apply is not None - assert self.attn2_lm is not None - # nb x nloc x nnei x nnei x nh - AAg = self.attn2g_map(g2, h2, nlist_mask, sw) - # nb x nloc x nnei x ng2 - g2_2 = self.attn2_mh_apply(AAg, g2) - g2_2 = self.attn2_lm(g2_2) - g2_update.append(g2_2) - - if self.update_h2: - h2_update.append(self._update_h2(g2, h2, nlist_mask, sw)) - - if self.update_g1_has_conv: - assert gg1 is not None - g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) - - if self.update_g1_has_grrg: - g1_mlp.append(self._update_g1_grrg(g2, h2, nlist_mask, sw)) - - if self.update_g1_has_drrd: - assert gg1 is not None - g1_mlp.append(self._update_g1_grrg(gg1, h2, nlist_mask, sw)) - - # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] - # conv grrg drrd - g1_1 = self.act(self.linear1(torch.cat(g1_mlp, dim=-1))) - g1_update.append(g1_1) - - if self.update_g1_has_attn: - assert gg1 is not None - assert self.loc_attn is not None - 
g1_update.append(self.loc_attn(g1, gg1, nlist_mask, sw)) - - # update - if self.update_chnnl_2: - g2_new = self.list_update(g2_update) - h2_new = self.list_update(h2_update) - else: - g2_new, h2_new = g2, h2 - g1_new = self.list_update(g1_update) - return g1_new, g2_new, h2_new - - @torch.jit.export - def list_update_res_avg( - self, - update_list: list[torch.Tensor], - ) -> torch.Tensor: - nitem = len(update_list) - uu = update_list[0] - for ii in range(1, nitem): - uu = uu + update_list[ii] - return uu / (float(nitem) ** 0.5) - - @torch.jit.export - def list_update_res_incr(self, update_list: list[torch.Tensor]) -> torch.Tensor: - nitem = len(update_list) - uu = update_list[0] - scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 - for ii in range(1, nitem): - uu = uu + scale * update_list[ii] - return uu - - @torch.jit.export - def list_update(self, update_list: list[torch.Tensor]) -> torch.Tensor: - if self.update_style == "res_avg": - return self.list_update_res_avg(update_list) - elif self.update_style == "res_incr": - return self.list_update_res_incr(update_list) - else: - raise RuntimeError(f"unknown update style {self.update_style}") - - def _bn_layer( - self, - nf: int = 1, - ) -> Callable: - return torch.nn.BatchNorm1d( - nf, - eps=1e-5, - momentum=self.bn_momentum, - affine=False, - track_running_stats=True, - device=env.DEVICE, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - ) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 64965825a0..ad4ead4d74 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -41,7 +41,6 @@ from .repformer_layer import ( RepformerLayer, ) -from .repformer_layer_old_impl import RepformerLayer as RepformerLayerOld if not hasattr(torch.ops.deepmd, "border_op"): @@ -106,7 +105,6 @@ def __init__( use_sqrt_nnei: bool = True, g1_out_conv: bool = True, g1_out_mlp: bool = True, - old_impl: bool = False, ): r""" The repformer descriptor block. 
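In the rewritten loop below, each layer is constructed with `seed=child_seed(child_seed(seed, 1), ii)`: the block seed is first narrowed to the repformer-layer branch (index 1) and then to layer `ii`, so every layer draws from an independent, reproducible stream regardless of how many layers are built. A toy model of that derivation (illustration only; deepmd's actual `child_seed` may derive seeds differently):

# Illustration only: derive a distinct child seed per (branch, layer) path.
def toy_child_seed(seed, idx):
    if seed is None:
        return None  # unseeded stays unseeded
    path = [seed] if isinstance(seed, int) else list(seed)
    return [*path, idx]  # append this child's index to the path

per_layer = [toy_child_seed(toy_child_seed(1234, 1), ii) for ii in range(3)]
# per_layer == [[1234, 1, 0], [1234, 1, 1], [1234, 1, 2]]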
@@ -240,78 +238,48 @@ def __init__(
         self.ln_eps = ln_eps
         self.epsilon = 1e-4
         self.seed = seed
-        self.old_impl = old_impl

         self.g2_embd = MLPLayer(
             1, self.g2_dim, precision=precision, seed=child_seed(seed, 0)
         )
         layers = []
         for ii in range(nlayers):
-            if self.old_impl:
-                layers.append(
-                    RepformerLayerOld(
-                        self.rcut,
-                        self.rcut_smth,
-                        self.sel,
-                        self.ntypes,
-                        self.g1_dim,
-                        self.g2_dim,
-                        axis_neuron=self.axis_neuron,
-                        update_chnnl_2=(ii != nlayers - 1),
-                        update_g1_has_conv=self.update_g1_has_conv,
-                        update_g1_has_drrd=self.update_g1_has_drrd,
-                        update_g1_has_grrg=self.update_g1_has_grrg,
-                        update_g1_has_attn=self.update_g1_has_attn,
-                        update_g2_has_g1g1=self.update_g2_has_g1g1,
-                        update_g2_has_attn=self.update_g2_has_attn,
-                        update_h2=self.update_h2,
-                        attn1_hidden=self.attn1_hidden,
-                        attn1_nhead=self.attn1_nhead,
-                        attn2_has_gate=self.attn2_has_gate,
-                        attn2_hidden=self.attn2_hidden,
-                        attn2_nhead=self.attn2_nhead,
-                        activation_function=self.activation_function,
-                        update_style=self.update_style,
-                        smooth=self.smooth,
-                    )
-                )
-            else:
-                layers.append(
-                    RepformerLayer(
-                        self.rcut,
-                        self.rcut_smth,
-                        self.sel,
-                        self.ntypes,
-                        self.g1_dim,
-                        self.g2_dim,
-                        axis_neuron=self.axis_neuron,
-                        update_chnnl_2=(ii != nlayers - 1),
-                        update_g1_has_conv=self.update_g1_has_conv,
-                        update_g1_has_drrd=self.update_g1_has_drrd,
-                        update_g1_has_grrg=self.update_g1_has_grrg,
-                        update_g1_has_attn=self.update_g1_has_attn,
-                        update_g2_has_g1g1=self.update_g2_has_g1g1,
-                        update_g2_has_attn=self.update_g2_has_attn,
-                        update_h2=self.update_h2,
-                        attn1_hidden=self.attn1_hidden,
-                        attn1_nhead=self.attn1_nhead,
-                        attn2_has_gate=self.attn2_has_gate,
-                        attn2_hidden=self.attn2_hidden,
-                        attn2_nhead=self.attn2_nhead,
-                        activation_function=self.activation_function,
-                        update_style=self.update_style,
-                        update_residual=self.update_residual,
-                        update_residual_init=self.update_residual_init,
-                        smooth=self.smooth,
-                        trainable_ln=self.trainable_ln,
-                        ln_eps=self.ln_eps,
-                        precision=precision,
-                        use_sqrt_nnei=self.use_sqrt_nnei,
-                        g1_out_conv=self.g1_out_conv,
-                        g1_out_mlp=self.g1_out_mlp,
-                        seed=child_seed(child_seed(seed, 1), ii),
-                    )
+            layers.append(
+                RepformerLayer(
+                    self.rcut,
+                    self.rcut_smth,
+                    self.sel,
+                    self.ntypes,
+                    self.g1_dim,
+                    self.g2_dim,
+                    axis_neuron=self.axis_neuron,
+                    update_chnnl_2=(ii != nlayers - 1),
+                    update_g1_has_conv=self.update_g1_has_conv,
+                    update_g1_has_drrd=self.update_g1_has_drrd,
+                    update_g1_has_grrg=self.update_g1_has_grrg,
+                    update_g1_has_attn=self.update_g1_has_attn,
+                    update_g2_has_g1g1=self.update_g2_has_g1g1,
+                    update_g2_has_attn=self.update_g2_has_attn,
+                    update_h2=self.update_h2,
+                    attn1_hidden=self.attn1_hidden,
+                    attn1_nhead=self.attn1_nhead,
+                    attn2_has_gate=self.attn2_has_gate,
+                    attn2_hidden=self.attn2_hidden,
+                    attn2_nhead=self.attn2_nhead,
+                    activation_function=self.activation_function,
+                    update_style=self.update_style,
+                    update_residual=self.update_residual,
+                    update_residual_init=self.update_residual_init,
+                    smooth=self.smooth,
+                    trainable_ln=self.trainable_ln,
+                    ln_eps=self.ln_eps,
+                    precision=precision,
+                    use_sqrt_nnei=self.use_sqrt_nnei,
+                    g1_out_conv=self.g1_out_conv,
+                    g1_out_mlp=self.g1_out_mlp,
+                    seed=child_seed(child_seed(seed, 1), ii),
                 )
+            )
         self.layers = torch.nn.ModuleList(layers)

         wanted_shape = (self.ntypes, self.nnei, 4)
diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py
index 1b51acfa21..e939a2541b 100644
--- a/deepmd/pt/model/descriptor/se_a.py
+++ b/deepmd/pt/model/descriptor/se_a.py
@@ -55,9 +55,6 @@
     EmbeddingNet,
NetworkCollection, ) -from deepmd.pt.model.network.network import ( - TypeFilter, -) from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) @@ -83,7 +80,6 @@ def __init__( resnet_dt: bool = False, exclude_types: list[tuple[int, int]] = [], env_protection: float = 0.0, - old_impl: bool = False, type_one_side: bool = True, trainable: bool = True, seed: Optional[Union[int, list[int]]] = None, @@ -109,7 +105,6 @@ def __init__( resnet_dt=resnet_dt, exclude_types=exclude_types, env_protection=env_protection, - old_impl=old_impl, type_one_side=type_one_side, trainable=trainable, seed=seed, @@ -385,7 +380,6 @@ def __init__( resnet_dt: bool = False, exclude_types: list[tuple[int, int]] = [], env_protection: float = 0.0, - old_impl: bool = False, type_one_side: bool = True, trainable: bool = True, seed: Optional[Union[int, list[int]]] = None, @@ -411,7 +405,6 @@ def __init__( self.precision = precision self.prec = PRECISION_DICT[self.precision] self.resnet_dt = resnet_dt - self.old_impl = old_impl self.env_protection = env_protection self.ntypes = len(sel) self.type_one_side = type_one_side @@ -431,39 +424,23 @@ def __init__( stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) - self.filter_layers_old = None - self.filter_layers = None - - if self.old_impl: - if not self.type_one_side: - raise ValueError( - "The old implementation does not support type_one_side=False." - ) - filter_layers = [] - # TODO: remove - start_index = 0 - for type_i in range(self.ntypes): - one = TypeFilter(start_index, sel[type_i], self.filter_neuron) - filter_layers.append(one) - start_index += sel[type_i] - self.filter_layers_old = torch.nn.ModuleList(filter_layers) - else: - ndim = 1 if self.type_one_side else 2 - filter_layers = NetworkCollection( - ndim=ndim, ntypes=len(sel), network_type="embedding_network" + + ndim = 1 if self.type_one_side else 2 + filter_layers = NetworkCollection( + ndim=ndim, ntypes=len(sel), network_type="embedding_network" + ) + for ii, embedding_idx in enumerate( + itertools.product(range(self.ntypes), repeat=ndim) + ): + filter_layers[embedding_idx] = EmbeddingNet( + 1, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + seed=child_seed(self.seed, ii), ) - for ii, embedding_idx in enumerate( - itertools.product(range(self.ntypes), repeat=ndim) - ): - filter_layers[embedding_idx] = EmbeddingNet( - 1, - self.filter_neuron, - activation_function=self.activation_function, - precision=self.precision, - resnet_dt=self.resnet_dt, - seed=child_seed(self.seed, ii), - ) - self.filter_layers = filter_layers + self.filter_layers = filter_layers self.stats = None # set trainable for param in self.parameters(): @@ -632,66 +609,49 @@ def forward( protection=self.env_protection, ) - if self.old_impl: - assert self.filter_layers_old is not None - dmatrix = dmatrix.view( - -1, self.ndescrpt - ) # shape is [nframes*nall, self.ndescrpt] - xyz_scatter = torch.empty( # pylint: disable=no-explicit-dtype - 1, - device=env.DEVICE, - ) - ret = self.filter_layers_old[0](dmatrix) - xyz_scatter = ret - for ii, transform in enumerate(self.filter_layers_old[1:]): - # shape is [nframes*nall, 4, self.filter_neuron[-1]] - ret = transform.forward(dmatrix) - xyz_scatter = xyz_scatter + ret - else: - assert self.filter_layers is not None - dmatrix = dmatrix.view(-1, self.nnei, 4) - dmatrix = dmatrix.to(dtype=self.prec) - nfnl = dmatrix.shape[0] - # 
pre-allocate a shape to pass jit - xyz_scatter = torch.zeros( - [nfnl, 4, self.filter_neuron[-1]], - dtype=self.prec, - device=extended_coord.device, - ) - # nfnl x nnei - exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei) - for embedding_idx, ll in enumerate(self.filter_layers.networks): - if self.type_one_side: - ii = embedding_idx - # torch.jit is not happy with slice(None) - # ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device) - # applying a mask seems to cause performance degradation - ti_mask = None - else: - # ti: center atom type, ii: neighbor type... - ii = embedding_idx // self.ntypes - ti = embedding_idx % self.ntypes - ti_mask = atype.ravel().eq(ti) - # nfnl x nt - if ti_mask is not None: - mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]] - else: - mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] - # nfnl x nt x 4 - if ti_mask is not None: - rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :] - else: - rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] - rr = rr * mm[:, :, None] - ss = rr[:, :, :1] - # nfnl x nt x ng - gg = ll.forward(ss) - # nfnl x 4 x ng - gr = torch.matmul(rr.permute(0, 2, 1), gg) - if ti_mask is not None: - xyz_scatter[ti_mask] += gr - else: - xyz_scatter += gr + dmatrix = dmatrix.view(-1, self.nnei, 4) + dmatrix = dmatrix.to(dtype=self.prec) + nfnl = dmatrix.shape[0] + # pre-allocate a shape to pass jit + xyz_scatter = torch.zeros( + [nfnl, 4, self.filter_neuron[-1]], + dtype=self.prec, + device=extended_coord.device, + ) + # nfnl x nnei + exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei) + for embedding_idx, ll in enumerate(self.filter_layers.networks): + if self.type_one_side: + ii = embedding_idx + # torch.jit is not happy with slice(None) + # ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device) + # applying a mask seems to cause performance degradation + ti_mask = None + else: + # ti: center atom type, ii: neighbor type... + ii = embedding_idx // self.ntypes + ti = embedding_idx % self.ntypes + ti_mask = atype.ravel().eq(ti) + # nfnl x nt + if ti_mask is not None: + mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]] + else: + mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] + # nfnl x nt x 4 + if ti_mask is not None: + rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :] + else: + rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] + rr = rr * mm[:, :, None] + ss = rr[:, :, :1] + # nfnl x nt x ng + gg = ll.forward(ss) + # nfnl x 4 x ng + gr = torch.matmul(rr.permute(0, 2, 1), gg) + if ti_mask is not None: + xyz_scatter[ti_mask] += gr + else: + xyz_scatter += gr xyz_scatter /= self.nnei xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index c760f7330b..c028230e9b 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -26,10 +26,6 @@ MLPLayer, NetworkCollection, ) -from deepmd.pt.model.network.network import ( - NeighborWiseAttention, - TypeFilter, -) from deepmd.pt.utils import ( env, ) @@ -85,7 +81,6 @@ def __init__( ln_eps: Optional[float] = 1e-5, seed: Optional[Union[int, list[int]]] = None, type: Optional[str] = None, - old_impl: bool = False, ): r"""Construct an embedding net of type `se_atten`. 
@@ -182,7 +177,6 @@ def __init__( if ln_eps is None: ln_eps = 1e-5 self.ln_eps = ln_eps - self.old_impl = old_impl if isinstance(sel, int): sel = [sel] @@ -195,40 +189,22 @@ def __init__( self.ndescrpt = self.nnei * 4 # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) - if self.old_impl: - assert self.tebd_input_mode in [ - "concat" - ], "Old implementation does not support tebd_input_mode != 'concat'." - self.dpa1_attention = NeighborWiseAttention( - self.attn_layer, - self.nnei, - self.filter_neuron[-1], - self.attn_dim, - dotr=self.attn_dotr, - do_mask=self.attn_mask, - activation=self.activation_function, - scaling_factor=self.scaling_factor, - normalize=self.normalize, - temperature=self.temperature, - smooth=self.smooth, - ) - else: - self.dpa1_attention = NeighborGatedAttention( - self.attn_layer, - self.nnei, - self.filter_neuron[-1], - self.attn_dim, - dotr=self.attn_dotr, - do_mask=self.attn_mask, - scaling_factor=self.scaling_factor, - normalize=self.normalize, - temperature=self.temperature, - trainable_ln=self.trainable_ln, - ln_eps=self.ln_eps, - smooth=self.smooth, - precision=self.precision, - seed=child_seed(self.seed, 0), - ) + self.dpa1_attention = NeighborGatedAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + smooth=self.smooth, + precision=self.precision, + seed=child_seed(self.seed, 0), + ) wanted_shape = (self.ntypes, self.nnei, 4) mean = torch.zeros( @@ -245,48 +221,32 @@ def __init__( else: self.embd_input_dim = 1 - self.filter_layers_old = None - self.filter_layers = None self.filter_layers_strip = None - if self.old_impl: - filter_layers = [] - one = TypeFilter( - 0, - self.nnei, - self.filter_neuron, - return_G=True, - tebd_dim=self.tebd_dim, - use_tebd=True, - tebd_mode=self.tebd_input_mode, - ) - filter_layers.append(one) - self.filter_layers_old = torch.nn.ModuleList(filter_layers) - else: - filter_layers = NetworkCollection( + filter_layers = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers[0] = EmbeddingNet( + self.embd_input_dim, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + seed=child_seed(self.seed, 1), + ) + self.filter_layers = filter_layers + if self.tebd_input_mode in ["strip"]: + filter_layers_strip = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network" ) - filter_layers[0] = EmbeddingNet( - self.embd_input_dim, + filter_layers_strip[0] = EmbeddingNet( + self.tebd_dim_input, self.filter_neuron, activation_function=self.activation_function, precision=self.precision, resnet_dt=self.resnet_dt, - seed=child_seed(self.seed, 1), + seed=child_seed(self.seed, 2), ) - self.filter_layers = filter_layers - if self.tebd_input_mode in ["strip"]: - filter_layers_strip = NetworkCollection( - ndim=0, ntypes=self.ntypes, network_type="embedding_network" - ) - filter_layers_strip[0] = EmbeddingNet( - self.tebd_dim_input, - self.filter_neuron, - activation_function=self.activation_function, - precision=self.precision, - resnet_dt=self.resnet_dt, - seed=child_seed(self.seed, 2), - ) - self.filter_layers_strip = filter_layers_strip + self.filter_layers_strip = filter_layers_strip self.stats = None def get_rcut(self) -> float: @@ 
-500,75 +460,51 @@ def forward( sw = sw.masked_fill(~nlist_mask, 0.0) # (nb x nloc) x nnei exclude_mask = exclude_mask.view(nb * nloc, nnei) - if self.old_impl: - assert self.filter_layers_old is not None - dmatrix = dmatrix.view( - -1, self.ndescrpt - ) # shape is [nframes*nall, self.ndescrpt] - gg = self.filter_layers_old[0]( - dmatrix, - atype_tebd=atype_tebd_nnei, - nlist_tebd=atype_tebd_nlist, - ) # shape is [nframes*nall, self.neei, out_size] - input_r = torch.nn.functional.normalize( - dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 - ) - gg = self.dpa1_attention( - gg, nlist_mask, input_r=input_r, sw=sw - ) # shape is [nframes*nloc, self.neei, out_size] - inputs_reshape = dmatrix.view(-1, self.nnei, 4).permute( - 0, 2, 1 - ) # shape is [nframes*natoms[0], 4, self.neei] - xyz_scatter = torch.matmul( - inputs_reshape, gg - ) # shape is [nframes*natoms[0], 4, out_size] - else: - assert self.filter_layers is not None - # nfnl x nnei x 4 - dmatrix = dmatrix.view(-1, self.nnei, 4) - nfnl = dmatrix.shape[0] - # nfnl x nnei x 4 - rr = dmatrix - rr = rr * exclude_mask[:, :, None] - ss = rr[:, :, :1] - nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim) - atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim) - if self.tebd_input_mode in ["concat"]: - if not self.type_one_side: - # nfnl x nnei x (1 + tebd_dim * 2) - ss = torch.concat([ss, nlist_tebd, atype_tebd], dim=2) - else: - # nfnl x nnei x (1 + tebd_dim) - ss = torch.concat([ss, nlist_tebd], dim=2) - # nfnl x nnei x ng - gg = self.filter_layers.networks[0](ss) - elif self.tebd_input_mode in ["strip"]: - # nfnl x nnei x ng - gg_s = self.filter_layers.networks[0](ss) - assert self.filter_layers_strip is not None - if not self.type_one_side: - # nfnl x nnei x (tebd_dim * 2) - tt = torch.concat([nlist_tebd, atype_tebd], dim=2) - else: - # nfnl x nnei x tebd_dim - tt = nlist_tebd - # nfnl x nnei x ng - gg_t = self.filter_layers_strip.networks[0](tt) - if self.smooth: - gg_t = gg_t * sw.reshape(-1, self.nnei, 1) - # nfnl x nnei x ng - gg = gg_s * gg_t + gg_s + # nfnl x nnei x 4 + dmatrix = dmatrix.view(-1, self.nnei, 4) + nfnl = dmatrix.shape[0] + # nfnl x nnei x 4 + rr = dmatrix + rr = rr * exclude_mask[:, :, None] + ss = rr[:, :, :1] + nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim) + atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim) + if self.tebd_input_mode in ["concat"]: + if not self.type_one_side: + # nfnl x nnei x (1 + tebd_dim * 2) + ss = torch.concat([ss, nlist_tebd, atype_tebd], dim=2) + else: + # nfnl x nnei x (1 + tebd_dim) + ss = torch.concat([ss, nlist_tebd], dim=2) + # nfnl x nnei x ng + gg = self.filter_layers.networks[0](ss) + elif self.tebd_input_mode in ["strip"]: + # nfnl x nnei x ng + gg_s = self.filter_layers.networks[0](ss) + assert self.filter_layers_strip is not None + if not self.type_one_side: + # nfnl x nnei x (tebd_dim * 2) + tt = torch.concat([nlist_tebd, atype_tebd], dim=2) else: - raise NotImplementedError + # nfnl x nnei x tebd_dim + tt = nlist_tebd + # nfnl x nnei x ng + gg_t = self.filter_layers_strip.networks[0](tt) + if self.smooth: + gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + # nfnl x nnei x ng + gg = gg_s * gg_t + gg_s + else: + raise NotImplementedError - input_r = torch.nn.functional.normalize( - rr.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 - ) - gg = self.dpa1_attention( - gg, nlist_mask, input_r=input_r, sw=sw - ) # shape is [nframes*nloc, self.neei, out_size] - # nfnl x 4 x ng - xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) + 
input_r = torch.nn.functional.normalize( + rr.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x 4 x ng + xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) xyz_scatter = xyz_scatter / self.nnei xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) rot_mat = xyz_scatter_1[:, :, 1:4] diff --git a/deepmd/pt/model/descriptor/se_atten_v2.py b/deepmd/pt/model/descriptor/se_atten_v2.py index f73ff255e6..11d783261e 100644 --- a/deepmd/pt/model/descriptor/se_atten_v2.py +++ b/deepmd/pt/model/descriptor/se_atten_v2.py @@ -71,7 +71,6 @@ def __init__( # not implemented spin=None, type: Optional[str] = None, - old_impl: bool = False, ) -> None: r"""Construct smooth version of embedding net of type `se_atten_v2`. @@ -191,7 +190,6 @@ def __init__( # not implemented spin=spin, type=type, - old_impl=old_impl, ) def serialize(self) -> dict: diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index b873ee20b8..e82bb23dac 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -68,7 +68,6 @@ def __init__( resnet_dt: bool = False, exclude_types: list[tuple[int, int]] = [], env_protection: float = 0.0, - old_impl: bool = False, trainable: bool = True, seed: Optional[Union[int, list[int]]] = None, type_map: Optional[list[str]] = None, @@ -84,7 +83,6 @@ def __init__( self.precision = precision self.prec = PRECISION_DICT[self.precision] self.resnet_dt = resnet_dt - self.old_impl = False # this does not support old implementation. self.exclude_types = exclude_types self.ntypes = len(sel) self.type_map = type_map diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py index ef50274b03..12e1eabf22 100644 --- a/deepmd/pt/model/network/network.py +++ b/deepmd/pt/model/network/network.py @@ -26,10 +26,6 @@ except ImportError: from torch.jit import Final -from functools import ( - partial, -) - import torch.utils.checkpoint from deepmd.dpmodel.utils.type_embed import ( @@ -48,247 +44,6 @@ def Tensor(*shape): return torch.empty(shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE) -class Dropout(nn.Module): - def __init__(self, p): - super().__init__() - self.p = p - - def forward(self, x, inplace: bool = False): - if self.p > 0 and self.training: - return F.dropout(x, p=self.p, training=True, inplace=inplace) - else: - return x - - -class Identity(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x - - -class DropPath(torch.nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, prob=None): - super().__init__() - self.drop_prob = prob - - def forward(self, x): - if self.drop_prob == 0.0 or not self.training: - return x - keep_prob = 1 - self.drop_prob - shape = (x.shape[0],) + (1,) * ( - x.ndim - 1 - ) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - def extra_repr(self) -> str: - return f"prob={self.drop_prob}" - - -def softmax_dropout( - input_x, dropout_prob, is_training=True, mask=None, bias=None, inplace=True -): - input_x = input_x.contiguous() - if not inplace: - input_x = input_x.clone() - if mask is not None: - input_x += mask - if bias is not None: - input_x += bias - return 
F.dropout(F.softmax(input_x, dim=-1), p=dropout_prob, training=is_training) - - -def checkpoint_sequential( - functions, - input_x, - enabled=True, -): - def wrap_tuple(a): - return (a,) if type(a) is not tuple else a - - def exec(func, a): - return wrap_tuple(func(*a)) - - def get_wrap_exec(func): - def wrap_exec(*a): - return exec(func, a) - - return wrap_exec - - input_x = wrap_tuple(input_x) - - is_grad_enabled = torch.is_grad_enabled() - - if enabled and is_grad_enabled: - for func in functions: - input_x = torch.utils.checkpoint.checkpoint(get_wrap_exec(func), *input_x) - else: - for func in functions: - input_x = exec(func, input_x) - return input_x - - -class ResidualLinear(nn.Module): - resnet: Final[int] - - def __init__(self, num_in, num_out, bavg=0.0, stddev=1.0, resnet_dt=False): - """Construct a residual linear layer. - - Args: - - num_in: Width of input tensor. - - num_out: Width of output tensor. - - resnet_dt: Using time-step in the ResNet construction. - """ - super().__init__() - self.num_in = num_in - self.num_out = num_out - self.resnet = resnet_dt - - self.matrix = nn.Parameter(data=Tensor(num_in, num_out)) - nn.init.normal_(self.matrix.data, std=stddev / np.sqrt(num_out + num_in)) - self.bias = nn.Parameter(data=Tensor(1, num_out)) - nn.init.normal_(self.bias.data, mean=bavg, std=stddev) - if self.resnet: - self.idt = nn.Parameter(data=Tensor(1, num_out)) - nn.init.normal_(self.idt.data, mean=1.0, std=0.001) - - def forward(self, inputs): - """Return X ?+ X*W+b.""" - xw_plus_b = torch.matmul(inputs, self.matrix) + self.bias - hidden = torch.tanh(xw_plus_b) - if self.resnet: - hidden = hidden * self.idt - if self.num_in == self.num_out: - return inputs + hidden - elif self.num_in * 2 == self.num_out: - return torch.cat([inputs, inputs], dim=1) + hidden - else: - return hidden - - -class TypeFilter(nn.Module): - use_tebd: Final[bool] - tebd_mode: Final[str] - - def __init__( - self, - offset, - length, - neuron, - return_G=False, - tebd_dim=0, - use_tebd=False, - tebd_mode="concat", - ): - """Construct a filter on the given element as neighbor. - - Args: - - offset: Element offset in the descriptor matrix. - - length: Atom count of this element. - - neuron: Number of neurons in each hidden layers of the embedding net. - """ - super().__init__() - self.offset = offset - self.length = length - self.tebd_dim = tebd_dim - self.use_tebd = use_tebd - self.tebd_mode = tebd_mode - supported_tebd_mode = ["concat", "dot", "dot_residual_s", "dot_residual_t"] - assert ( - tebd_mode in supported_tebd_mode - ), f"Unknown tebd_mode {tebd_mode}! Supported are {supported_tebd_mode}." - if use_tebd and tebd_mode == "concat": - self.neuron = [1 + tebd_dim * 2, *neuron] - else: - self.neuron = [1, *neuron] - - deep_layers = [] - for ii in range(1, len(self.neuron)): - one = ResidualLinear(self.neuron[ii - 1], self.neuron[ii]) - deep_layers.append(one) - self.deep_layers = nn.ModuleList(deep_layers) - - deep_layers_t = [] - if use_tebd and tebd_mode in ["dot", "dot_residual_s", "dot_residual_t"]: - self.neuron_t = [tebd_dim * 2, *neuron] - for ii in range(1, len(self.neuron_t)): - one = ResidualLinear(self.neuron_t[ii - 1], self.neuron_t[ii]) - deep_layers_t.append(one) - self.deep_layers_t = nn.ModuleList(deep_layers_t) - - self.return_G = return_G - - def forward( - self, - inputs, - atype_tebd: Optional[torch.Tensor] = None, - nlist_tebd: Optional[torch.Tensor] = None, - ): - """Calculate decoded embedding for each atom. - - Args: - - inputs: Descriptor matrix. 
Its shape is [nframes*natoms[0], len_descriptor]. - - Returns - ------- - - `torch.Tensor`: Embedding contributed by me. Its shape is [nframes*natoms[0], 4, self.neuron[-1]]. - """ - inputs_i = inputs[:, self.offset * 4 : (self.offset + self.length) * 4] - inputs_reshape = inputs_i.reshape( - -1, 4 - ) # shape is [nframes*natoms[0]*self.length, 4] - xyz_scatter = inputs_reshape[:, 0:1] - - # concat the tebd as input - if self.use_tebd and self.tebd_mode == "concat": - assert nlist_tebd is not None and atype_tebd is not None - nlist_tebd = nlist_tebd.reshape(-1, self.tebd_dim) - atype_tebd = atype_tebd.reshape(-1, self.tebd_dim) - # [nframes * nloc * nnei, 1 + tebd_dim * 2] - xyz_scatter = torch.concat([xyz_scatter, nlist_tebd, atype_tebd], dim=1) - - for linear in self.deep_layers: - xyz_scatter = linear(xyz_scatter) - # [nframes * nloc * nnei, out_size] - - # dot the tebd output - if self.use_tebd and self.tebd_mode in [ - "dot", - "dot_residual_s", - "dot_residual_t", - ]: - assert nlist_tebd is not None and atype_tebd is not None - nlist_tebd = nlist_tebd.reshape(-1, self.tebd_dim) - atype_tebd = atype_tebd.reshape(-1, self.tebd_dim) - # [nframes * nloc * nnei, tebd_dim * 2] - two_side_tebd = torch.concat([nlist_tebd, atype_tebd], dim=1) - for linear in self.deep_layers_t: - two_side_tebd = linear(two_side_tebd) - # [nframes * nloc * nnei, out_size] - if self.tebd_mode == "dot": - xyz_scatter = xyz_scatter * two_side_tebd - elif self.tebd_mode == "dot_residual_s": - xyz_scatter = xyz_scatter * two_side_tebd + xyz_scatter - elif self.tebd_mode == "dot_residual_t": - xyz_scatter = xyz_scatter * two_side_tebd + two_side_tebd - - xyz_scatter = xyz_scatter.view( - -1, self.length, self.neuron[-1] - ) # shape is [nframes*natoms[0], self.length, self.neuron[-1]] - if self.return_G: - return xyz_scatter - else: - # shape is [nframes*natoms[0], 4, self.length] - inputs_reshape = inputs_i.view(-1, self.length, 4).permute(0, 2, 1) - return torch.matmul(inputs_reshape, xyz_scatter) - - class SimpleLinear(nn.Module): use_timestep: Final[bool] @@ -396,53 +151,6 @@ def _normal_init(self): nn.init.kaiming_normal_(self.weight, nonlinearity="linear") -class Transition(nn.Module): - def __init__(self, d_in, n, dropout=0.0): - super().__init__() - - self.d_in = d_in - self.n = n - - self.linear_1 = Linear(self.d_in, self.n * self.d_in, init="relu") - self.act = nn.GELU() - self.linear_2 = Linear(self.n * self.d_in, d_in, init="final") - self.dropout = dropout - - def _transition(self, x): - x = self.linear_1(x) - x = self.act(x) - x = F.dropout(x, p=self.dropout, training=self.training) - x = self.linear_2(x) - return x - - def forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: - x = self._transition(x=x) - return x - - -class Embedding(nn.Embedding): - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - dtype=torch.float64, - ): - super().__init__( - num_embeddings, embedding_dim, padding_idx=padding_idx, dtype=dtype - ) - self._normal_init() - - if padding_idx is not None: - self.weight.data[self.padding_idx].zero_() - - def _normal_init(self, std=0.02): - nn.init.normal_(self.weight, mean=0.0, std=std) - - class NonLinearHead(nn.Module): def __init__(self, input_dim, out_dim, activation_fn, hidden=None): super().__init__() @@ -456,27 +164,6 @@ def forward(self, x): return x -class NonLinear(nn.Module): - def __init__(self, input, output_size, hidden=None): - super().__init__() - - if hidden is None: - hidden = input - self.layer1 = 
Linear(input, hidden, init="relu") - self.layer2 = Linear(hidden, output_size, init="final") - - def forward(self, x): - x = F.linear(x, self.layer1.weight) - # x = fused_ops.bias_torch_gelu(x, self.layer1.bias) - x = nn.GELU()(x) + self.layer1.bias - x = self.layer2(x) - return x - - def zero_init(self): - nn.init.zeros_(self.layer2.weight) - nn.init.zeros_(self.layer2.bias) - - class MaskLMHead(nn.Module): """Head for masked language modeling.""" @@ -844,1327 +531,3 @@ def serialize(self) -> dict: "type_map": self.type_map, "embedding": self.embedding_net.serialize(), } - - -@torch.jit.script -def gaussian(x, mean, std: float): - pi = 3.14159 - a = (2 * pi) ** 0.5 - return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std) - - -class GaussianKernel(nn.Module): - def __init__(self, K=128, num_pair=512, std_width=1.0, start=0.0, stop=9.0): - super().__init__() - self.K = K - std_width = std_width - start = start - stop = stop - mean = torch.linspace(start, stop, K, dtype=env.GLOBAL_PT_FLOAT_PRECISION) # pylint: disable=no-explicit-device - self.std = (std_width * (mean[1] - mean[0])).item() - self.register_buffer("mean", mean) - self.mul = Embedding( - num_pair + 1, 1, padding_idx=num_pair, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.bias = Embedding( - num_pair + 1, 1, padding_idx=num_pair, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - nn.init.constant_(self.bias.weight, 0) - nn.init.constant_(self.mul.weight, 1.0) - - def forward(self, x, atom_pair): - mul = self.mul(atom_pair).abs().sum(dim=-2) - bias = self.bias(atom_pair).sum(dim=-2) - x = mul * x.unsqueeze(-1) + bias - # [nframes, nloc, nnei, K] - x = x.expand(-1, -1, -1, self.K) - mean = self.mean.view(-1) - return gaussian(x, mean, self.std) - - -class GaussianEmbedding(nn.Module): - def __init__( - self, - rcut, - kernel_num, - num_pair, - embed_dim, - pair_embed_dim, - sel, - ntypes, - atomic_sum_gbf, - ): - """Construct a gaussian kernel based embedding of pair representation. - - Args: - rcut: Radial cutoff. - kernel_num: Number of gaussian kernels. - num_pair: Number of different pairs. - embed_dim: Dimension of atomic representation. - pair_embed_dim: Dimension of pair representation. - sel: Number of neighbors. - ntypes: Number of atom types. - """ - super().__init__() - self.gbf = GaussianKernel(K=kernel_num, num_pair=num_pair, stop=rcut) - self.gbf_proj = NonLinear(kernel_num, pair_embed_dim) - self.embed_dim = embed_dim - self.pair_embed_dim = pair_embed_dim - self.atomic_sum_gbf = atomic_sum_gbf - if self.atomic_sum_gbf: - if kernel_num != self.embed_dim: - self.edge_proj = torch.nn.Linear( - kernel_num, self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - else: - self.edge_proj = None - self.ntypes = ntypes - self.nnei = sel - - def forward(self, coord_selected, atom_feature, edge_type_2dim, edge_feature): - ## local cluster forward - """Calculate decoded embedding for each atom. - Args: - coord_selected: Clustered atom coordinates with shape [nframes*nloc, natoms, 3]. - atom_feature: Previous calculated atomic features with shape [nframes*nloc, natoms, embed_dim]. - edge_type_2dim: Edge index for gbf calculation with shape [nframes*nloc, natoms, natoms, 2]. - edge_feature: Previous calculated edge features with shape [nframes*nloc, natoms, natoms, pair_dim]. - - Returns - ------- - atom_feature: Updated atomic features with shape [nframes*nloc, natoms, embed_dim]. - attn_bias: Updated edge features as attention bias with shape [nframes*nloc, natoms, natoms, pair_dim]. 
- delta_pos: Delta position for force/vector prediction with shape [nframes*nloc, natoms, natoms, 3]. - """ - ncluster, natoms, _ = coord_selected.shape - # ncluster x natoms x natoms x 3 - delta_pos = coord_selected.unsqueeze(1) - coord_selected.unsqueeze(2) - # (ncluster x natoms x natoms - dist = delta_pos.norm(dim=-1).view(-1, natoms, natoms) - # [ncluster, natoms, natoms, K] - gbf_feature = self.gbf(dist, edge_type_2dim) - if self.atomic_sum_gbf: - edge_features = gbf_feature - # [ncluster, natoms, K] - sum_edge_features = edge_features.sum(dim=-2) - if self.edge_proj is not None: - sum_edge_features = self.edge_proj(sum_edge_features) - # [ncluster, natoms, embed_dim] - atom_feature = atom_feature + sum_edge_features - - # [ncluster, natoms, natoms, pair_dim] - gbf_result = self.gbf_proj(gbf_feature) - - attn_bias = gbf_result + edge_feature - return atom_feature, attn_bias, delta_pos - - -class NeighborWiseAttention(nn.Module): - def __init__( - self, - layer_num, - nnei, - embed_dim, - hidden_dim, - dotr=False, - do_mask=False, - post_ln=True, - ffn=False, - ffn_embed_dim=1024, - activation="tanh", - scaling_factor=1.0, - head_num=1, - normalize=True, - temperature=None, - smooth=True, - ): - """Construct a neighbor-wise attention net.""" - super().__init__() - self.layer_num = layer_num - attention_layers = [] - for i in range(self.layer_num): - attention_layers.append( - NeighborWiseAttentionLayer( - nnei, - embed_dim, - hidden_dim, - dotr=dotr, - do_mask=do_mask, - post_ln=post_ln, - ffn=ffn, - ffn_embed_dim=ffn_embed_dim, - activation=activation, - scaling_factor=scaling_factor, - head_num=head_num, - normalize=normalize, - temperature=temperature, - smooth=smooth, - ) - ) - self.attention_layers = nn.ModuleList(attention_layers) - - def forward( - self, - input_G, - nei_mask, - input_r: Optional[torch.Tensor] = None, - sw: Optional[torch.Tensor] = None, - ): - """ - Args: - input_G: Input G, [nframes * nloc, nnei, embed_dim]. - nei_mask: neighbor mask, [nframes * nloc, nnei]. - input_r: normalized radial, [nframes, nloc, nei, 3]. 
- - Returns - ------- - out: Output G, [nframes * nloc, nnei, embed_dim] - - """ - out = input_G - # https://github.com/pytorch/pytorch/issues/39165#issuecomment-635472592 - for layer in self.attention_layers: - out = layer(out, nei_mask, input_r=input_r, sw=sw) - return out - - -class NeighborWiseAttentionLayer(nn.Module): - ffn: Final[bool] - - def __init__( - self, - nnei, - embed_dim, - hidden_dim, - dotr=False, - do_mask=False, - post_ln=True, - ffn=False, - ffn_embed_dim=1024, - activation="tanh", - scaling_factor=1.0, - head_num=1, - normalize=True, - temperature=None, - smooth=True, - ): - """Construct a neighbor-wise attention layer.""" - super().__init__() - self.nnei = nnei - self.embed_dim = embed_dim - self.hidden_dim = hidden_dim - self.dotr = dotr - self.do_mask = do_mask - self.post_ln = post_ln - self.ffn = ffn - self.smooth = smooth - self.attention_layer = GatedSelfAttetion( - nnei, - embed_dim, - hidden_dim, - dotr=dotr, - do_mask=do_mask, - scaling_factor=scaling_factor, - head_num=head_num, - normalize=normalize, - temperature=temperature, - smooth=smooth, - ) - self.attn_layer_norm = nn.LayerNorm( - self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE - ) - if self.ffn: - self.ffn_embed_dim = ffn_embed_dim - self.fc1 = nn.Linear( - self.embed_dim, self.ffn_embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.activation_fn = ActivationFn(activation) - self.fc2 = nn.Linear( - self.ffn_embed_dim, self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.final_layer_norm = nn.LayerNorm( - self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - - def forward( - self, - x, - nei_mask, - input_r: Optional[torch.Tensor] = None, - sw: Optional[torch.Tensor] = None, - ): - residual = x - if not self.post_ln: - x = self.attn_layer_norm(x) - x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) - x = residual + x - if self.post_ln: - x = self.attn_layer_norm(x) - if self.ffn: - residual = x - if not self.post_ln: - x = self.final_layer_norm(x) - x = self.fc1(x) - x = self.activation_fn(x) - x = self.fc2(x) - x = residual + x - if self.post_ln: - x = self.final_layer_norm(x) - return x - - -class GatedSelfAttetion(nn.Module): - def __init__( - self, - nnei, - embed_dim, - hidden_dim, - dotr=False, - do_mask=False, - scaling_factor=1.0, - head_num=1, - normalize=True, - temperature=None, - bias=True, - smooth=True, - ): - """Construct a neighbor-wise attention net.""" - super().__init__() - self.nnei = nnei - self.embed_dim = embed_dim - self.hidden_dim = hidden_dim - self.head_num = head_num - self.dotr = dotr - self.do_mask = do_mask - if temperature is None: - self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 - else: - self.scaling = temperature - self.normalize = normalize - self.in_proj = SimpleLinear( - embed_dim, - hidden_dim * 3, - bavg=0.0, - stddev=1.0, - use_timestep=False, - bias=bias, - ) - self.out_proj = SimpleLinear( - hidden_dim, embed_dim, bavg=0.0, stddev=1.0, use_timestep=False, bias=bias - ) - self.smooth = smooth - - def forward( - self, - query, - nei_mask, - input_r: Optional[torch.Tensor] = None, - sw: Optional[torch.Tensor] = None, - attnw_shift: float = 20.0, - ): - """ - Args: - query: input G, [nframes * nloc, nnei, embed_dim]. - nei_mask: neighbor mask, [nframes * nloc, nnei]. - input_r: normalized radial, [nframes, nloc, nei, 3]. 
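A minimal standalone sketch (toy shapes; not part of this diff) of the smooth-switch trick implemented in the forward pass below, where the envelope `sw` pushes attention weights continuously to zero at the cutoff instead of hard-masking them:

import torch

nt, nnei, shift = 6, 10, 20.0                  # toy sizes; 20.0 is the attnw_shift default
logits = torch.randn(nt, nnei, nnei)           # raw q @ k^T attention scores
sw = torch.rand(nt, nnei)                      # switch function values in [0, 1]
# lift the scores, damp them by sw along both the query and key axes, shift back
smoothed = (logits + shift) * sw[:, :, None] * sw[:, None, :] - shift
probs = torch.softmax(smoothed, dim=-1) * sw[:, :, None] * sw[:, None, :]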
- - Returns - ------- - type_embedding: - - """ - q, k, v = self.in_proj(query).chunk(3, dim=-1) - # [nframes * nloc, nnei, hidden_dim] - q = q.view(-1, self.nnei, self.hidden_dim) - k = k.view(-1, self.nnei, self.hidden_dim) - v = v.view(-1, self.nnei, self.hidden_dim) - if self.normalize: - q = F.normalize(q, dim=-1) - k = F.normalize(k, dim=-1) - v = F.normalize(v, dim=-1) - q = q * self.scaling - k = k.transpose(1, 2) - # [nframes * nloc, nnei, nnei] - attn_weights = torch.bmm(q, k) - # [nframes * nloc, nnei] - nei_mask = nei_mask.view(-1, self.nnei) - if self.smooth: - # [nframes * nloc, nnei] - assert sw is not None - sw = sw.view([-1, self.nnei]) - attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[ - :, None, : - ] - attnw_shift - else: - attn_weights = attn_weights.masked_fill( - ~nei_mask.unsqueeze(1), float("-inf") - ) - attn_weights = F.softmax(attn_weights, dim=-1) - attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), 0.0) - if self.smooth: - assert sw is not None - attn_weights = attn_weights * sw[:, :, None] * sw[:, None, :] - if self.dotr: - assert input_r is not None, "input_r must be provided when dotr is True!" - angular_weight = torch.bmm(input_r, input_r.transpose(1, 2)) - attn_weights = attn_weights * angular_weight - o = torch.bmm(attn_weights, v) - output = self.out_proj(o) - return output - - -class LocalSelfMultiheadAttention(nn.Module): - def __init__(self, feature_dim, attn_head, scaling_factor=1.0): - super().__init__() - self.feature_dim = feature_dim - self.attn_head = attn_head - self.head_dim = feature_dim // attn_head - assert ( - feature_dim % attn_head == 0 - ), f"feature_dim {feature_dim} must be divided by attn_head {attn_head}!" - self.scaling = (self.head_dim * scaling_factor) ** -0.5 - self.in_proj = SimpleLinear(self.feature_dim, self.feature_dim * 3) - # TODO debug - # self.out_proj = SimpleLinear(self.feature_dim, self.feature_dim) - - def forward( - self, - query, - attn_bias: Optional[torch.Tensor] = None, - nlist_mask: Optional[torch.Tensor] = None, - nlist: Optional[torch.Tensor] = None, - return_attn=True, - ): - nframes, nloc, feature_dim = query.size() - _, _, nnei = nlist.size() - assert feature_dim == self.feature_dim - # [nframes, nloc, feature_dim] - q, k, v = self.in_proj(query).chunk(3, dim=-1) - # [nframes * attn_head * nloc, 1, head_dim] - q = ( - q.view(nframes, nloc, self.attn_head, self.head_dim) - .transpose(1, 2) - .contiguous() - .view(nframes * self.attn_head * nloc, 1, self.head_dim) - * self.scaling - ) - # [nframes, nloc, feature_dim] --> [nframes, nloc + 1, feature_dim] - # with nlist [nframes, nloc, nnei] --> [nframes, nloc, nnei, feature_dim] - # padding = torch.zeros(feature_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION).to(k.device) - # k = torch.concat([k, padding.unsqueeze(0).unsqueeze(1)], dim=1) - # v = torch.concat([v, padding.unsqueeze(0).unsqueeze(1)], dim=1) - - # [nframes, nloc * nnei, feature_dim] - index = nlist.view(nframes, -1).unsqueeze(-1).expand(-1, -1, feature_dim) - k = torch.gather(k, dim=1, index=index) - # [nframes, nloc * nnei, feature_dim] - v = torch.gather(v, dim=1, index=index) - # [nframes * attn_head * nloc, nnei, head_dim] - k = ( - k.view(nframes, nloc, nnei, self.attn_head, self.head_dim) - .permute(0, 3, 1, 2, 4) - .contiguous() - .view(nframes * self.attn_head * nloc, nnei, self.head_dim) - ) - v = ( - v.view(nframes, nloc, nnei, self.attn_head, self.head_dim) - .permute(0, 3, 1, 2, 4) - .contiguous() - .view(nframes * self.attn_head * nloc, nnei, self.head_dim) 
- ) - # [nframes * attn_head * nloc, 1, nnei] - attn_weights = torch.bmm(q, k.transpose(1, 2)) - # maskfill - # [nframes, attn_head, nloc, nnei] - attn_weights = attn_weights.view( - nframes, self.attn_head, nloc, nnei - ).masked_fill(~nlist_mask.unsqueeze(1), float("-inf")) - # add bias - if return_attn: - attn_weights = attn_weights + attn_bias - # softmax - # [nframes * attn_head * nloc, 1, nnei] - attn = F.softmax(attn_weights, dim=-1).view( - nframes * self.attn_head * nloc, 1, nnei - ) - # bmm - # [nframes * attn_head * nloc, 1, head_dim] - o = torch.bmm(attn, v) - assert list(o.size()) == [nframes * self.attn_head * nloc, 1, self.head_dim] - # [nframes, nloc, feature_dim] - o = ( - o.view(nframes, self.attn_head, nloc, self.head_dim) - .transpose(1, 2) - .contiguous() - .view(nframes, nloc, self.feature_dim) - ) - # out - ## TODO debug: - # o = self.out_proj(o) - if not return_attn: - return o - else: - return o, attn_weights, attn - - -class NodeTaskHead(nn.Module): - def __init__( - self, - embed_dim: int, - pair_dim: int, - num_head: int, - ): - super().__init__() - self.layer_norm = nn.LayerNorm(embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) - self.pair_norm = nn.LayerNorm(pair_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) - self.embed_dim = embed_dim - self.q_proj = Linear(embed_dim, embed_dim, bias=False, init="glorot") - self.k_proj = Linear(embed_dim, embed_dim, bias=False, init="glorot") - self.v_proj = Linear(embed_dim, embed_dim, bias=False, init="glorot") - self.num_heads = num_head - self.head_dim = embed_dim // num_head - self.scaling = self.head_dim**-0.5 - self.force_proj = Linear(embed_dim, 1, init="final", bias=False) - self.linear_bias = Linear(pair_dim, num_head) - self.dropout = 0.1 - - def zero_init(self): - nn.init.zeros_(self.force_proj.weight) - - def forward( - self, - query: Tensor, - pair: Tensor, - delta_pos: Tensor, - attn_mask: Tensor = None, - ) -> Tensor: - ncluster, natoms, _ = query.size() - query = self.layer_norm(query) - # [ncluster, natoms, natoms, pair_dim] - pair = self.pair_norm(pair) - - # [ncluster, attn_head, natoms, head_dim] - q = ( - self.q_proj(query) - .view(ncluster, natoms, self.num_heads, -1) - .transpose(1, 2) - * self.scaling - ) - # [ncluster, attn_head, natoms, head_dim] - k = ( - self.k_proj(query) - .view(ncluster, natoms, self.num_heads, -1) - .transpose(1, 2) - ) - v = ( - self.v_proj(query) - .view(ncluster, natoms, self.num_heads, -1) - .transpose(1, 2) - ) - # [ncluster, attn_head, natoms, natoms] - attn = q @ k.transpose(-1, -2) - del q, k - # [ncluster, attn_head, natoms, natoms] - bias = self.linear_bias(pair).permute(0, 3, 1, 2).contiguous() - - # [ncluster, attn_head, natoms, natoms] - attn_probs = softmax_dropout( - attn, - self.dropout, - self.training, - mask=attn_mask, - bias=bias.contiguous(), - ).view(ncluster, self.num_heads, natoms, natoms) - - # delta_pos: [ncluster, natoms, natoms, 3] - # [ncluster, attn_head, natoms, natoms, 3] - rot_attn_probs = attn_probs.unsqueeze(-1) * delta_pos.unsqueeze(1).type_as( - attn_probs - ) - # [ncluster, attn_head, 3, natoms, natoms] - rot_attn_probs = rot_attn_probs.permute(0, 1, 4, 2, 3) - # [ncluster, attn_head, 3, natoms, head_dim] - x = rot_attn_probs @ v.unsqueeze(2) - # [ncluster, natoms, 3, embed_dim] - x = x.permute(0, 3, 2, 1, 4).contiguous().view(ncluster, natoms, 3, -1) - cur_force = self.force_proj(x).view(ncluster, natoms, 3) - return cur_force - - -class EnergyHead(nn.Module): - def __init__( - self, - input_dim, - output_dim, - ): - super().__init__() - 
self.layer_norm = nn.LayerNorm(input_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) - self.linear_in = Linear(input_dim, input_dim, init="relu") - - self.linear_out = Linear(input_dim, output_dim, bias=True, init="final") - - def forward(self, x): - x = x.type(self.linear_in.weight.dtype) - x = F.gelu(self.layer_norm(self.linear_in(x))) - x = self.linear_out(x) - return x - - -class OuterProduct(nn.Module): - def __init__(self, d_atom, d_pair, d_hid=32): - super().__init__() - - self.d_atom = d_atom - self.d_pair = d_pair - self.d_hid = d_hid - - self.linear_in = nn.Linear( - d_atom, d_hid * 2, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.linear_out = nn.Linear( - d_hid**2, d_pair, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.act = nn.GELU() - - def _opm(self, a, b): - # [nframes, nloc, d] - nframes, nloc, d = a.shape - a = a.view(nframes, nloc, 1, d, 1) - b = b.view(nframes, 1, nloc, 1, d) - # [nframes, nloc, nloc, d, d] - outer = a * b - outer = outer.view(outer.shape[:-2] + (-1,)) - outer = self.linear_out(outer) - return outer - - def forward( - self, - m: torch.Tensor, - nlist: torch.Tensor, - op_mask: float, - op_norm: float, - ) -> torch.Tensor: - ab = self.linear_in(m) - ab = ab * op_mask - a, b = ab.chunk(2, dim=-1) - # [ncluster, natoms, natoms, d_pair] - z = self._opm(a, b) - z *= op_norm - return z - - -class Attention(nn.Module): - def __init__( - self, - q_dim: int, - k_dim: int, - v_dim: int, - head_dim: int, - num_heads: int, - gating: bool = False, - dropout: float = 0.0, - ): - super().__init__() - - self.num_heads = num_heads - self.head_dim = head_dim - total_dim = head_dim * self.num_heads - self.total_dim = total_dim - self.q_dim = q_dim - self.gating = gating - self.linear_q = Linear(q_dim, total_dim, bias=False, init="glorot") - self.linear_k = Linear(k_dim, total_dim, bias=False, init="glorot") - self.linear_v = Linear(v_dim, total_dim, bias=False, init="glorot") - self.linear_o = Linear(total_dim, q_dim, init="final") - self.linear_g = None - if self.gating: - self.linear_g = Linear(q_dim, total_dim, init="gating") - # precompute the 1/sqrt(head_dim) - self.norm = head_dim**-0.5 - self.dropout = dropout - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - bias: torch.Tensor, - mask: torch.Tensor = None, - ) -> torch.Tensor: - nframes, nloc, embed_dim = q.size() - g = None - if self.linear_g is not None: - # gating, use raw query input - # [nframes, nloc, total_dim] - g = self.linear_g(q) - # [nframes, nloc, total_dim] - q = self.linear_q(q) - q *= self.norm - # [nframes, nloc, total_dim] - k = self.linear_k(k) - # [nframes, nloc, total_dim] - v = self.linear_v(v) - # global - # q [nframes, h, nloc, d] - # k [nframes, h, nloc, d] - # v [nframes, h, nloc, d] - # attn [nframes, h, nloc, nloc] - # o [nframes, h, nloc, d] - - # [nframes, h, nloc, d] - q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous() - k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous() - v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3) - # [nframes, h, nloc, nloc] - attn = torch.matmul(q, k.transpose(-1, -2)) - del q, k - # [nframes, h, nloc, nloc] - attn = softmax_dropout(attn, self.dropout, self.training, mask=mask, bias=bias) - # [nframes, h, nloc, d] - o = torch.matmul(attn, v) - del attn, v - - # local - # q [nframes, h, nloc, 1, d] - # k [nframes, h, nloc, nnei, d] - # v [nframes, h, nloc, nnei, d] - # attn [nframes, h, nloc, nnei] - # o [nframes, h, nloc, d] - - assert list(o.size()) == 
[nframes, self.num_heads, nloc, self.head_dim] - # [nframes, nloc, total_dim] - o = o.transpose(-2, -3).contiguous() - o = o.view(*o.shape[:-2], -1) - - if g is not None: - o = torch.sigmoid(g) * o - - # merge heads - o = self.linear_o(o) - return o - - -class AtomAttention(nn.Module): - def __init__( - self, - q_dim: int, - k_dim: int, - v_dim: int, - pair_dim: int, - head_dim: int, - num_heads: int, - gating: bool = False, - dropout: float = 0.0, - ): - super().__init__() - - self.mha = Attention( - q_dim, k_dim, v_dim, head_dim, num_heads, gating=gating, dropout=dropout - ) - self.layer_norm = nn.LayerNorm(pair_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) - self.linear_bias = Linear(pair_dim, num_heads) - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - nlist: torch.Tensor, - pair: torch.Tensor, - mask: torch.Tensor = None, - ) -> torch.Tensor: - pair = self.layer_norm(pair) - bias = self.linear_bias(pair).permute(0, 3, 1, 2).contiguous() - return self.mha(q, k, v, bias=bias, mask=mask) - - -class TriangleMultiplication(nn.Module): - def __init__(self, d_pair, d_hid): - super().__init__() - - self.linear_ab_p = Linear(d_pair, d_hid * 2) - self.linear_ab_g = Linear(d_pair, d_hid * 2, init="gating") - - self.linear_g = Linear(d_pair, d_pair, init="gating") - self.linear_z = Linear(d_hid, d_pair, init="final") - - self.layer_norm_out = nn.LayerNorm(d_hid, dtype=env.GLOBAL_PT_FLOAT_PRECISION) - - def forward( - self, - z: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # z : [nframes, nloc, nloc, pair_dim] - - # [nframes, nloc, nloc, pair_dim] - g = self.linear_g(z) - if self.training: - ab = self.linear_ab_p(z) * torch.sigmoid(self.linear_ab_g(z)) - else: - ab = self.linear_ab_p(z) - ab *= torch.sigmoid(self.linear_ab_g(z)) - # [nframes, nloc, nloc, d] - a, b = torch.chunk(ab, 2, dim=-1) - del z, ab - - # [nframes, d, nloc_i, nloc_k] row not trans - a1 = a.permute(0, 3, 1, 2) - # [nframes, d, nloc_k, nloc_j(i)] trans - b1 = b.transpose(-1, -3) - # [nframes, d, nloc_i, nloc_j] - x = torch.matmul(a1, b1) - del a1, b1 - - # [nframes, d, nloc_k, nloc_j(i)] not trans - b2 = b.permute(0, 3, 1, 2) - # [nframes, d, nloc_i, nloc_k] col trans # check TODO - a2 = a.transpose(-1, -3) - - # [nframes, d, nloc_i, nloc_j] - x = x + torch.matmul(a2, b2) - del a, b, a2, b2 - - # [nframes, nloc_i, nloc_j, d] - x = x.permute(0, 2, 3, 1) - - x = self.layer_norm_out(x) - x = self.linear_z(x) - return g * x - - -class EvoformerEncoderLayer(nn.Module): - def __init__( - self, - feature_dim: int = 768, - ffn_dim: int = 2048, - attn_head: int = 8, - activation_fn: str = "gelu", - post_ln: bool = False, - ): - super().__init__() - self.feature_dim = feature_dim - self.ffn_dim = ffn_dim - self.attn_head = attn_head - self.activation_fn = ( - ActivationFn(activation_fn) if activation_fn is not None else None - ) - self.post_ln = post_ln - self.self_attn_layer_norm = nn.LayerNorm( - self.feature_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - - self.self_attn = LocalSelfMultiheadAttention( - self.feature_dim, - self.attn_head, - ) - self.final_layer_norm = nn.LayerNorm( - self.feature_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.fc1 = SimpleLinear(self.feature_dim, self.ffn_dim) - self.fc2 = SimpleLinear(self.ffn_dim, self.feature_dim) - - def forward( - self, - x, - attn_bias: Optional[torch.Tensor] = None, - nlist_mask: Optional[torch.Tensor] = None, - nlist: Optional[torch.Tensor] = None, - return_attn=True, - ): - residual = x - if not self.post_ln: - 
x = self.self_attn_layer_norm(x) - x = self.self_attn( - query=x, - attn_bias=attn_bias, - nlist_mask=nlist_mask, - nlist=nlist, - return_attn=return_attn, - ) - if return_attn: - x, attn_weights, attn_probs = x - x = residual + x - if self.post_ln: - x = self.self_attn_layer_norm(x) - - residual = x - if not self.post_ln: - x = self.final_layer_norm(x) - x = self.fc1(x) - x = self.activation_fn(x) - x = self.fc2(x) - x = residual + x - if self.post_ln: - x = self.final_layer_norm(x) - if not return_attn: - return x - else: - return x, attn_weights, attn_probs - - -# output: atomic_rep, transformed_atomic_rep, pair_rep, delta_pair_rep, norm_x, norm_delta_pair_rep, -class Evoformer2bEncoder(nn.Module): - def __init__( - self, - nnei: int, - layer_num: int = 6, - attn_head: int = 8, - atomic_dim: int = 1024, - pair_dim: int = 100, - feature_dim: int = 1024, - ffn_dim: int = 2048, - post_ln: bool = False, - final_layer_norm: bool = True, - final_head_layer_norm: bool = False, - emb_layer_norm: bool = False, - atomic_residual: bool = False, - evo_residual: bool = False, - residual_factor: float = 1.0, - activation_function: str = "gelu", - ): - super().__init__() - self.nnei = nnei - self.layer_num = layer_num - self.attn_head = attn_head - self.atomic_dim = atomic_dim - self.pair_dim = pair_dim - self.feature_dim = feature_dim - self.ffn_dim = ffn_dim - self.post_ln = post_ln - self._final_layer_norm = final_layer_norm - self._final_head_layer_norm = final_head_layer_norm - self._emb_layer_norm = emb_layer_norm - self.activation_function = activation_function - self.evo_residual = evo_residual - self.residual_factor = residual_factor - if atomic_residual and atomic_dim == feature_dim: - self.atomic_residual = True - else: - self.atomic_residual = False - self.in_proj = SimpleLinear( - self.atomic_dim, - self.feature_dim, - bavg=0.0, - stddev=1.0, - use_timestep=False, - activate="tanh", - ) # TODO - self.out_proj = SimpleLinear( - self.feature_dim, - self.atomic_dim, - bavg=0.0, - stddev=1.0, - use_timestep=False, - activate="tanh", - ) - if self._emb_layer_norm: - self.emb_layer_norm = nn.LayerNorm( - self.feature_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - - ## TODO debug : self.in_proj_pair = NonLinearHead(self.pair_dim, self.attn_head, activation_fn=None) - self.in_proj_pair = SimpleLinear(self.pair_dim, self.attn_head, activate=None) - evoformer_encoder_layers = [] - for i in range(self.layer_num): - evoformer_encoder_layers.append( - EvoformerEncoderLayer( - feature_dim=self.feature_dim, - ffn_dim=self.ffn_dim, - attn_head=self.attn_head, - activation_fn=self.activation_function, - post_ln=self.post_ln, - ) - ) - self.evoformer_encoder_layers = nn.ModuleList(evoformer_encoder_layers) - if self._final_layer_norm: - self.final_layer_norm = nn.LayerNorm( - self.feature_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - if self._final_head_layer_norm: - self.final_head_layer_norm = nn.LayerNorm( - self.attn_head, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - - def forward(self, atomic_rep, pair_rep, nlist, nlist_type, nlist_mask): - """Encoder the atomic and pair representations. - - Args: - - atomic_rep: Atomic representation with shape [nframes, nloc, atomic_dim]. - - pair_rep: Pair representation with shape [nframes, nloc, nnei, pair_dim]. - - nlist: Neighbor list with shape [nframes, nloc, nnei]. - - nlist_type: Neighbor types with shape [nframes, nloc, nnei]. - - nlist_mask: Neighbor mask with shape [nframes, nloc, nnei], `False` if blank. 
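The encoder below projects the pair representation to one scalar per attention head and uses it as an attention bias, masking padded neighbors to -inf before the softmax; a toy sketch (sizes assumed, illustration only) of that masking step:

import torch

nframes, nloc, nnei, attn_head = 1, 3, 5, 2
pair_rep = torch.randn(nframes, nloc, nnei, attn_head)   # output of the pair projection
nlist_mask = torch.rand(nframes, nloc, nnei) > 0.3       # False marks padded neighbors
# [nframes, attn_head, nloc, nnei]; padded neighbors receive a -inf bias
bias = pair_rep.permute(0, 3, 1, 2).masked_fill(~nlist_mask.unsqueeze(1), float("-inf"))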
- - Returns - ------- - - atomic_rep: Atomic representation after encoder with shape [nframes, nloc, feature_dim]. - - transformed_atomic_rep: Transformed atomic representation after encoder with shape [nframes, nloc, atomic_dim]. - - pair_rep: Pair representation after encoder with shape [nframes, nloc, nnei, attn_head]. - - delta_pair_rep: Delta pair representation after encoder with shape [nframes, nloc, nnei, attn_head]. - - norm_x: Normalization loss of atomic_rep. - - norm_delta_pair_rep: Normalization loss of delta_pair_rep. - """ - # Global branch - nframes, nloc, _ = atomic_rep.size() - nnei = pair_rep.shape[2] - input_atomic_rep = atomic_rep - # [nframes, nloc, feature_dim] - if self.atomic_residual: - atomic_rep = atomic_rep + self.in_proj(atomic_rep) - else: - atomic_rep = self.in_proj(atomic_rep) - - if self._emb_layer_norm: - atomic_rep = self.emb_layer_norm(atomic_rep) - - # Local branch - # [nframes, nloc, nnei, attn_head] - pair_rep = self.in_proj_pair(pair_rep) - # [nframes, attn_head, nloc, nnei] - pair_rep = pair_rep.permute(0, 3, 1, 2).contiguous() - input_pair_rep = pair_rep - pair_rep = pair_rep.masked_fill(~nlist_mask.unsqueeze(1), float("-inf")) - - for i in range(self.layer_num): - atomic_rep, pair_rep, _ = self.evoformer_encoder_layers[i]( - atomic_rep, - attn_bias=pair_rep, - nlist_mask=nlist_mask, - nlist=nlist, - return_attn=True, - ) - - def norm_loss(x, eps=1e-10, tolerance=1.0): - # x = x.float() - max_norm = x.shape[-1] ** 0.5 - norm = torch.sqrt(torch.sum(x**2, dim=-1) + eps) - error = F.relu((norm - max_norm).abs() - tolerance) - return error - - def masked_mean(mask, value, dim=-1, eps=1e-10): - return ( - torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) - ).mean() - - # atomic_rep shape: [nframes, nloc, feature_dim] - # pair_rep shape: [nframes, attn_head, nloc, nnei] - - norm_x = torch.mean(norm_loss(atomic_rep)) - if self._final_layer_norm: - atomic_rep = self.final_layer_norm(atomic_rep) - - delta_pair_rep = pair_rep - input_pair_rep - delta_pair_rep = delta_pair_rep.masked_fill(~nlist_mask.unsqueeze(1), 0) - # [nframes, nloc, nnei, attn_head] - delta_pair_rep = ( - delta_pair_rep.view(nframes, self.attn_head, nloc, nnei) - .permute(0, 2, 3, 1) - .contiguous() - ) - - # [nframes, nloc, nnei] - norm_delta_pair_rep = norm_loss(delta_pair_rep) - norm_delta_pair_rep = masked_mean(mask=nlist_mask, value=norm_delta_pair_rep) - if self._final_head_layer_norm: - delta_pair_rep = self.final_head_layer_norm(delta_pair_rep) - - if self.atomic_residual: - transformed_atomic_rep = atomic_rep + self.out_proj(atomic_rep) - else: - transformed_atomic_rep = self.out_proj(atomic_rep) - - if self.evo_residual: - transformed_atomic_rep = ( - self.residual_factor * transformed_atomic_rep + input_atomic_rep - ) * (1 / np.sqrt(2)) - - return ( - atomic_rep, - transformed_atomic_rep, - pair_rep, - delta_pair_rep, - norm_x, - norm_delta_pair_rep, - ) - - -class Evoformer3bEncoderLayer(nn.Module): - def __init__( - self, - nnei, - embedding_dim: int = 768, - pair_dim: int = 64, - pair_hidden_dim: int = 32, - ffn_embedding_dim: int = 3072, - num_attention_heads: int = 8, - dropout: float = 0.1, - droppath_prob: float = 0.0, - pair_dropout: float = 0.25, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - pre_ln: bool = True, - tri_update: bool = True, - ): - super().__init__() - # Initialize parameters - self.nnei = nnei - self.embedding_dim = embedding_dim - self.num_attention_heads = num_attention_heads - self.attention_dropout = 
attention_dropout - - # self.dropout = dropout - self.activation_dropout = activation_dropout - - if droppath_prob > 0.0: - self.dropout_module = DropPath(droppath_prob) - else: - self.dropout_module = Dropout(dropout) - - # self.self_attn = AtomAttentionLocal(embedding_dim, embedding_dim, embedding_dim, pair_dim, - # embedding_dim // num_attention_heads, num_attention_heads, - # gating=False, dropout=attention_dropout) - self.self_attn = AtomAttention( - embedding_dim, - embedding_dim, - embedding_dim, - pair_dim, - embedding_dim // num_attention_heads, - num_attention_heads, - gating=False, - dropout=attention_dropout, - ) - # layer norm associated with the self attention layer - self.pre_ln = pre_ln - self.self_attn_layer_norm = nn.LayerNorm( - self.embedding_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.fc1 = nn.Linear( - self.embedding_dim, ffn_embedding_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.fc2 = nn.Linear( - ffn_embedding_dim, self.embedding_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.final_layer_norm = nn.LayerNorm( - self.embedding_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - - self.x_layer_norm_opm = nn.LayerNorm( - self.embedding_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - # self.opm = OuterProductLocal(self.embedding_dim, pair_dim, d_hid=pair_hidden_dim) - self.opm = OuterProduct(self.embedding_dim, pair_dim, d_hid=pair_hidden_dim) - # self.pair_layer_norm_opm = nn.LayerNorm(pair_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) - self.pair_layer_norm_ffn = nn.LayerNorm( - pair_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.pair_ffn = Transition( - pair_dim, - 1, - dropout=activation_dropout, - ) - self.pair_dropout = pair_dropout - self.tri_update = tri_update - if self.tri_update: - self.pair_layer_norm_trimul = nn.LayerNorm( - pair_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION - ) - self.pair_tri_mul = TriangleMultiplication(pair_dim, pair_hidden_dim) - - def update_pair( - self, - x, - pair, - nlist, - op_mask, - op_norm, - ): - # local: - # [nframes, nloc, nnei, pair_dim] - # global: - # [nframes, nloc, nloc, pair_dim] - pair = pair + self.dropout_module( - self.opm(self.x_layer_norm_opm(x), nlist, op_mask, op_norm) - ) - if not self.pre_ln: - pair = self.pair_layer_norm_opm(pair) - return x, pair - - def shared_dropout(self, x, shared_dim, dropout): - shape = list(x.shape) - shape[shared_dim] = 1 - with torch.no_grad(): - mask = x.new_ones(shape) - return F.dropout(mask, p=dropout, training=self.training) * x - - def forward( - self, - x: torch.Tensor, - pair: torch.Tensor, - nlist: torch.Tensor = None, - attn_mask: Optional[torch.Tensor] = None, - pair_mask: Optional[torch.Tensor] = None, - op_mask: float = 1.0, - op_norm: float = 1.0, - ): - """Encoder the atomic and pair representations. - - Args: - - x: Atomic representation with shape [ncluster, natoms, embed_dim]. - - pair: Pair representation with shape [ncluster, natoms, natoms, pair_dim]. - - attn_mask: Attention mask with shape [ncluster, head, natoms, natoms]. - - pair_mask: Neighbor mask with shape [ncluster, natoms, natoms]. 
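Every sub-layer in the encoder layer below uses the same residual wiring, switching between pre- and post-layer-norm placement with the `pre_ln` flag; a condensed sketch (hypothetical helper, not in this codebase) of that pattern:

import torch

def residual_block(x, norm: torch.nn.Module, fn, pre_ln: bool):
    residual = x
    if pre_ln:          # pre-LN: normalize the input, add the residual afterwards
        x = norm(x)
    x = fn(x)
    x = residual + x
    if not pre_ln:      # post-LN: normalize the summed output instead
        x = norm(x)
    return x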
- - """ - # [ncluster, natoms, embed_dim] - residual = x - if self.pre_ln: - x = self.self_attn_layer_norm(x) - x = self.self_attn( - x, - x, - x, - nlist=nlist, - pair=pair, - mask=attn_mask, - ) - # x = F.dropout(x, p=self.dropout, training=self.training) - x = self.dropout_module(x) - x = residual + x - if not self.pre_ln: - x = self.self_attn_layer_norm(x) - - residual = x - if self.pre_ln: - x = self.final_layer_norm(x) - x = F.linear(x, self.fc1.weight) - # x = fused_ops.bias_torch_gelu(x, self.fc1.bias) - x = nn.GELU()(x) + self.fc1.bias - x = F.dropout(x, p=self.activation_dropout, training=self.training) - x = self.fc2(x) - # x = F.dropout(x, p=self.dropout, training=self.training) - x = self.dropout_module(x) - - x = residual + x - if not self.pre_ln: - x = self.final_layer_norm(x) - - block = [ - partial( - self.update_pair, - nlist=nlist, - op_mask=op_mask, - op_norm=op_norm, - ) - ] - - x, pair = checkpoint_sequential( - block, - input_x=(x, pair), - ) - - if self.tri_update: - residual_pair = pair - if self.pre_ln: - pair = self.pair_layer_norm_trimul(pair) - - pair = self.shared_dropout( - self.pair_tri_mul(pair, pair_mask), -3, self.pair_dropout - ) - pair = residual_pair + pair - if not self.pre_ln: - pair = self.pair_layer_norm_trimul(pair) - - residual_pair = pair - if self.pre_ln: - pair = self.pair_layer_norm_ffn(pair) - pair = self.dropout_module(self.pair_ffn(pair)) - pair = residual_pair + pair - if not self.pre_ln: - pair = self.pair_layer_norm_ffn(pair) - return x, pair - - -class Evoformer3bEncoder(nn.Module): - def __init__( - self, - nnei, - layer_num=6, - attn_head=8, - atomic_dim=768, - pair_dim=64, - pair_hidden_dim=32, - ffn_embedding_dim=3072, - dropout: float = 0.1, - droppath_prob: float = 0.0, - pair_dropout: float = 0.25, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - pre_ln: bool = True, - tri_update: bool = True, - **kwargs, - ): - super().__init__() - self.nnei = nnei - if droppath_prob > 0: - droppath_probs = [ - x.item() - for x in torch.linspace(0, droppath_prob, layer_num) # pylint: disable=no-explicit-dtype,no-explicit-device - ] - else: - droppath_probs = None - - self.layers = nn.ModuleList( - [ - Evoformer3bEncoderLayer( - nnei, - atomic_dim, - pair_dim, - pair_hidden_dim, - ffn_embedding_dim, - num_attention_heads=attn_head, - dropout=dropout, - droppath_prob=droppath_probs[_], - pair_dropout=pair_dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - pre_ln=pre_ln, - tri_update=tri_update, - ) - for _ in range(layer_num) - ] - ) - - def forward(self, x, pair, attn_mask=None, pair_mask=None, atom_mask=None): - """Encoder the atomic and pair representations. - - Args: - x: Atomic representation with shape [ncluster, natoms, atomic_dim]. - pair: Pair representation with shape [ncluster, natoms, natoms, pair_dim]. - attn_mask: Attention mask (with -inf for softmax) with shape [ncluster, head, natoms, natoms]. - pair_mask: Pair mask (with 1 for real atom pair and 0 for padding) with shape [ncluster, natoms, natoms]. - atom_mask: Atom mask (with 1 for real atom and 0 for padding) with shape [ncluster, natoms]. - - Returns - ------- - x: Atomic representation with shape [ncluster, natoms, atomic_dim]. - pair: Pair representation with shape [ncluster, natoms, natoms, pair_dim]. 
- - """ - # [ncluster, natoms, 1] - op_mask = atom_mask.unsqueeze(-1) - op_mask = op_mask * (op_mask.size(-2) ** -0.5) - eps = 1e-3 - # [ncluster, natoms, natoms, 1] - op_norm = 1.0 / (eps + torch.einsum("...bc,...dc->...bdc", op_mask, op_mask)) - for layer in self.layers: - x, pair = layer( - x, - pair, - nlist=None, - attn_mask=attn_mask, - pair_mask=pair_mask, - op_mask=op_mask, - op_norm=op_norm, - ) - return x, pair diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index 572dc60d56..02d852eab7 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -1,7 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from .atten_lcc import ( - FittingNetAttenLcc, -) from .base_fitting import ( BaseFitting, ) @@ -32,7 +29,6 @@ ) __all__ = [ - "FittingNetAttenLcc", "DenoiseNet", "DipoleFittingNet", "EnergyFittingNet", diff --git a/deepmd/pt/model/task/atten_lcc.py b/deepmd/pt/model/task/atten_lcc.py deleted file mode 100644 index 4f54038548..0000000000 --- a/deepmd/pt/model/task/atten_lcc.py +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import torch -import torch.nn as nn - -from deepmd.pt.model.network.network import ( - EnergyHead, - NodeTaskHead, -) -from deepmd.pt.model.task.fitting import ( - Fitting, -) -from deepmd.pt.utils import ( - env, -) - - -class FittingNetAttenLcc(Fitting): - def __init__( - self, embedding_width, bias_atom_e, pair_embed_dim, attention_heads, **kwargs - ): - super().__init__() - self.embedding_width = embedding_width - self.engergy_proj = EnergyHead(self.embedding_width, 1) - self.energe_agg_factor = nn.Embedding(4, 1, dtype=env.GLOBAL_PT_FLOAT_PRECISION) - nn.init.normal_(self.energe_agg_factor.weight, 0, 0.01) - bias_atom_e = torch.tensor(bias_atom_e) # pylint: disable=no-explicit-dtype,no-explicit-device - self.register_buffer("bias_atom_e", bias_atom_e) - self.pair_embed_dim = pair_embed_dim - self.attention_heads = attention_heads - self.node_proc = NodeTaskHead( - self.embedding_width, self.pair_embed_dim, self.attention_heads - ) - self.node_proc.zero_init() - - def forward(self, output, pair, delta_pos, atype, nframes, nloc): - # [nframes x nloc x tebd_dim] - output_nloc = (output[:, 0, :]).reshape(nframes, nloc, self.embedding_width) - # Optional: GRRG or mean of gbf TODO - - # energy outut - # [nframes, nloc] - energy_out = self.engergy_proj(output_nloc).view(nframes, nloc) - # [nframes, nloc] - energy_factor = self.energe_agg_factor(torch.zeros_like(atype)).view( - nframes, nloc - ) - energy_out = (energy_out * energy_factor) + self.bias_atom_e[atype] - energy_out = energy_out.sum(dim=-1) - - # vector output - # predict_force: [(nframes x nloc) x (1 + nnei2) x 3] - predict_force = self.node_proc(output, pair, delta_pos=delta_pos) - # predict_force_nloc: [nframes x nloc x 3] - predict_force_nloc = (predict_force[:, 0, :]).reshape(nframes, nloc, 3) - return energy_out, predict_force_nloc diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 56b14677b9..79f9a0a86c 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -113,7 +113,6 @@ def __init__( type_map=type_map, **kwargs, ) - self.old_impl = False # this only supports the new implementation. 
def _net_out_dim(self): """Set the FittingNet output dim.""" @@ -123,7 +122,6 @@ def serialize(self) -> dict: data = super().serialize() data["type"] = "dipole" data["embedding_width"] = self.embedding_width - data["old_impl"] = self.old_impl data["r_differentiable"] = self.r_differentiable data["c_differentiable"] = self.c_differentiable return data diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 1827569a17..10f88519e1 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -19,9 +19,6 @@ FittingNet, NetworkCollection, ) -from deepmd.pt.model.network.network import ( - ResidualDeep, -) from deepmd.pt.model.task.base_fitting import ( BaseFitting, ) @@ -211,41 +208,24 @@ def __init__( in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam - self.old_impl = kwargs.get("old_impl", False) - if self.old_impl: - filter_layers = [] - for type_i in range(self.ntypes if not self.mixed_types else 1): - bias_type = 0.0 - one = ResidualDeep( - type_i, - self.dim_descrpt, + self.filter_layers = NetworkCollection( + 1 if not self.mixed_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + net_dim_out, self.neuron, - bias_type, - resnet_dt=self.resnet_dt, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + seed=child_seed(self.seed, ii), ) - filter_layers.append(one) - self.filter_layers_old = torch.nn.ModuleList(filter_layers) - self.filter_layers = None - else: - self.filter_layers = NetworkCollection( - 1 if not self.mixed_types else 0, - self.ntypes, - network_type="fitting_network", - networks=[ - FittingNet( - in_dim, - net_dim_out, - self.neuron, - self.activation_function, - self.resnet_dt, - self.precision, - bias_out=True, - seed=child_seed(self.seed, ii), - ) - for ii in range(self.ntypes if not self.mixed_types else 1) - ], - ) - self.filter_layers_old = None + for ii in range(self.ntypes if not self.mixed_types else 1) + ], + ) # set trainable for param in self.parameters(): param.requires_grad = self.trainable @@ -488,47 +468,29 @@ def _forward_common( dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=descriptor.device, ) # jit assertion - if self.old_impl: - assert self.filter_layers_old is not None - assert xx_zeros is None - if self.mixed_types: - atom_property = self.filter_layers_old[0](xx) + self.bias_atom_e[atype] - outs = outs + atom_property # Shape is [nframes, natoms[0], 1] - else: - for type_i, filter_layer in enumerate(self.filter_layers_old): - mask = atype == type_i - atom_property = filter_layer(xx) - atom_property = atom_property + self.bias_atom_e[type_i] - atom_property = atom_property * mask.unsqueeze(-1) - outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + if self.mixed_types: + atom_property = self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] + if xx_zeros is not None: + atom_property -= self.filter_layers.networks[0](xx_zeros) + outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out] else: - if self.mixed_types: - atom_property = ( - self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] - ) + for type_i, ll in enumerate(self.filter_layers.networks): + mask = (atype == type_i).unsqueeze(-1) + mask = torch.tile(mask, (1, 1, net_dim_out)) + atom_property = ll(xx) if xx_zeros is not None: - atom_property -= self.filter_layers.networks[0](xx_zeros) + # must assert, otherwise jit is not happy + assert self.remove_vaccum_contribution is not None + if not ( + 
len(self.remove_vaccum_contribution) > type_i + and not self.remove_vaccum_contribution[type_i] + ): + atom_property -= ll(xx_zeros) + atom_property = atom_property + self.bias_atom_e[type_i] + atom_property = atom_property * mask outs = ( outs + atom_property ) # Shape is [nframes, natoms[0], net_dim_out] - else: - for type_i, ll in enumerate(self.filter_layers.networks): - mask = (atype == type_i).unsqueeze(-1) - mask = torch.tile(mask, (1, 1, net_dim_out)) - atom_property = ll(xx) - if xx_zeros is not None: - # must assert, otherwise jit is not happy - assert self.remove_vaccum_contribution is not None - if not ( - len(self.remove_vaccum_contribution) > type_i - and not self.remove_vaccum_contribution[type_i] - ): - atom_property -= ll(xx_zeros) - atom_property = atom_property + self.bias_atom_e[type_i] - atom_property = atom_property * mask - outs = ( - outs + atom_property - ) # Shape is [nframes, natoms[0], net_dim_out] # nf x nloc mask = self.emask(atype) # nf x nloc x nod diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index a16ab886d4..512044efbd 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -138,7 +138,6 @@ def __init__( type_map=type_map, **kwargs, ) - self.old_impl = False # this only supports the new implementation. def _net_out_dim(self): """Set the FittingNet output dim.""" @@ -195,7 +194,6 @@ def serialize(self) -> dict: data["type"] = "polar" data["@version"] = 3 data["embedding_width"] = self.embedding_width - data["old_impl"] = self.old_impl data["fit_diag"] = self.fit_diag data["shift_diag"] = self.shift_diag data["@variables"]["scale"] = to_numpy_array(self.scale) diff --git a/source/tests/pt/model/test_descriptor_hybrid.py b/source/tests/pt/model/test_descriptor_hybrid.py index 5d03b28399..074af4da4e 100644 --- a/source/tests/pt/model/test_descriptor_hybrid.py +++ b/source/tests/pt/model/test_descriptor_hybrid.py @@ -41,7 +41,6 @@ def test_jit( self.rcut, self.rcut_smth, self.sel, - old_impl=False, ) ddsub1 = DescrptSeR( self.rcut, diff --git a/source/tests/pt/model/test_descriptor_se_r.py b/source/tests/pt/model/test_descriptor_se_r.py index f3692101c5..e4aa405dd8 100644 --- a/source/tests/pt/model/test_descriptor_se_r.py +++ b/source/tests/pt/model/test_descriptor_se_r.py @@ -61,7 +61,6 @@ def test_consistency( self.sel, precision=prec, resnet_dt=idt, - old_impl=False, exclude_mask=em, seed=GLOBAL_SEED, ).to(env.DEVICE) @@ -130,7 +129,6 @@ def test_load_stat(self): self.sel, precision=prec, resnet_dt=idt, - old_impl=False, seed=GLOBAL_SEED, ) dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) @@ -181,7 +179,6 @@ def test_jit( self.sel, precision=prec, resnet_dt=idt, - old_impl=False, seed=GLOBAL_SEED, ) dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py index b825885311..d168ceb2ae 100644 --- a/source/tests/pt/model/test_dpa1.py +++ b/source/tests/pt/model/test_dpa1.py @@ -70,7 +70,6 @@ def test_consistency( tebd_input_mode=tm, use_econf_tebd=ect, type_map=["O", "H"] if ect else None, - old_impl=False, seed=GLOBAL_SEED, ).to(env.DEVICE) dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) @@ -108,69 +107,6 @@ def test_consistency( atol=atol, err_msg=err_msg, ) - # old impl - if ( - idt is False - and prec == "float64" - and to is False - and tm == "concat" - and ect is False - ): - dd3 = DescrptDPA1( - self.rcut, - self.rcut_smth, - self.sel_mix, 
- self.nt, - attn_layer=2, - precision=prec, - resnet_dt=idt, - smooth_type_embedding=sm, - old_impl=True, - seed=GLOBAL_SEED, - ).to(env.DEVICE) - dd0_state_dict = dd0.se_atten.state_dict() - dd3_state_dict = dd3.se_atten.state_dict() - - dd0_state_dict_attn = dd0.se_atten.dpa1_attention.state_dict() - dd3_state_dict_attn = dd3.se_atten.dpa1_attention.state_dict() - for i in dd3_state_dict: - dd3_state_dict[i] = ( - dd0_state_dict[ - i.replace(".deep_layers.", ".layers.") - .replace("filter_layers_old.", "filter_layers._networks.") - .replace( - ".attn_layer_norm.weight", ".attn_layer_norm.matrix" - ) - ] - .detach() - .clone() - ) - if ".bias" in i and "attn_layer_norm" not in i: - dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0) - dd3.se_atten.load_state_dict(dd3_state_dict) - - dd0_state_dict_tebd = dd0.type_embedding.state_dict() - dd3_state_dict_tebd = dd3.type_embedding.state_dict() - for i in dd3_state_dict_tebd: - dd3_state_dict_tebd[i] = ( - dd0_state_dict_tebd[i.replace("embedding.weight", "matrix")] - .detach() - .clone() - ) - dd3.type_embedding.load_state_dict(dd3_state_dict_tebd) - - rd3, _, _, _, _ = dd3( - torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), - torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), - torch.tensor(self.nlist, dtype=int, device=env.DEVICE), - ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd3.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) def test_jit( self, @@ -211,7 +147,6 @@ def test_jit( tebd_input_mode=tm, use_econf_tebd=ect, type_map=["O", "H"] if ect else None, - old_impl=False, seed=GLOBAL_SEED, ) dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) diff --git a/source/tests/pt/model/test_dpa2.py b/source/tests/pt/model/test_dpa2.py index 0beb34c031..2eac49d573 100644 --- a/source/tests/pt/model/test_dpa2.py +++ b/source/tests/pt/model/test_dpa2.py @@ -154,7 +154,6 @@ def test_consistency( precision=prec, use_econf_tebd=ect, type_map=["O", "H"] if ect else None, - old_impl=False, seed=GLOBAL_SEED, ).to(env.DEVICE) @@ -193,45 +192,6 @@ def test_consistency( rtol=rtol, atol=atol, ) - # old impl - if prec == "float64" and rus == "res_avg" and ect is False and ns is False: - dd3 = DescrptDPA2( - self.nt, - repinit=repinit, - repformer=repformer, - # kwargs for descriptor - smooth=sm, - exclude_types=[], - add_tebd_to_repinit_out=False, - precision=prec, - old_impl=True, - seed=GLOBAL_SEED, - ).to(env.DEVICE) - dd0_state_dict = dd0.state_dict() - dd3_state_dict = dd3.state_dict() - for i in list(dd0_state_dict.keys()): - if ".bias" in i and ( - ".linear1." in i or ".linear2." in i or ".head_map." 
in i - ): - dd0_state_dict[i] = dd0_state_dict[i].unsqueeze(0) - if ".attn2_lm.matrix" in i: - dd0_state_dict[ - i.replace(".attn2_lm.matrix", ".attn2_lm.weight") - ] = dd0_state_dict.pop(i) - - dd3.load_state_dict(dd0_state_dict) - rd3, _, _, _, _ = dd3( - torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), - torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), - torch.tensor(self.nlist, dtype=int, device=env.DEVICE), - torch.tensor(self.mapping, dtype=int, device=env.DEVICE), - ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd3.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - ) def test_jit( self, @@ -350,7 +310,6 @@ def test_jit( precision=prec, use_econf_tebd=ect, type_map=["O", "H"] if ect else None, - old_impl=False, seed=GLOBAL_SEED, ).to(env.DEVICE) diff --git a/source/tests/pt/model/test_embedding_net.py b/source/tests/pt/model/test_embedding_net.py index 3605316437..1566eb2416 100644 --- a/source/tests/pt/model/test_embedding_net.py +++ b/source/tests/pt/model/test_embedding_net.py @@ -167,20 +167,15 @@ def test_consistency(self): ) # Reproduced - old_impl = False descriptor = DescrptSeA( self.rcut, self.rcut_smth, self.sel, neuron=self.filter_neuron, axis_neuron=self.axis_neuron, - old_impl=old_impl, ).to(DEVICE) for name, param in descriptor.named_parameters(): - if old_impl: - ms = re.findall(r"(\d)\.deep_layers\.(\d)\.([a-z]+)", name) - else: - ms = re.findall(r"(\d)\.layers\.(\d)\.([a-z]+)", name) + ms = re.findall(r"(\d)\.layers\.(\d)\.([a-z]+)", name) if len(ms) == 1: m = ms[0] key = gen_key(worb=m[2], depth=int(m[1]) + 1, elemid=int(m[0])) diff --git a/source/tests/pt/model/test_ener_fitting.py b/source/tests/pt/model/test_ener_fitting.py index 3255db2784..5c55766455 100644 --- a/source/tests/pt/model/test_ener_fitting.py +++ b/source/tests/pt/model/test_ener_fitting.py @@ -10,7 +10,6 @@ DescrptSeA, ) from deepmd.pt.model.task.ener import ( - EnergyFittingNet, InvarFitting, ) from deepmd.pt.utils import ( @@ -103,53 +102,6 @@ def test_consistency( ) self.assertEqual(ft0.get_sel_type(), ft1.get_sel_type()) - def test_new_old( - self, - ): - nf, nloc, nnei = self.nlist.shape - dd = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE) - rd0, _, _, _, _ = dd( - torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), - torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), - torch.tensor(self.nlist, dtype=int, device=env.DEVICE), - ) - atype = torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE) - - od = 1 - for foo, mixed_types in itertools.product( - [True], - [True, False], - ): - ft0 = EnergyFittingNet( - self.nt, - dd.dim_out, - mixed_types=mixed_types, - ).to(env.DEVICE) - ft1 = EnergyFittingNet( - self.nt, - dd.dim_out, - mixed_types=mixed_types, - old_impl=True, - ).to(env.DEVICE) - dd0 = ft0.state_dict() - dd1 = ft1.state_dict() - for kk, vv in dd1.items(): - new_kk = kk - new_kk = new_kk.replace("filter_layers_old", "filter_layers.networks") - new_kk = new_kk.replace("deep_layers", "layers") - new_kk = new_kk.replace("final_layer", "layers.3") - dd1[kk] = dd0[new_kk] - if kk.split(".")[-1] in ["idt", "bias"]: - dd1[kk] = dd1[kk].unsqueeze(0) - dd1["bias_atom_e"] = dd0["bias_atom_e"] - ft1.load_state_dict(dd1) - ret0 = ft0(rd0, atype) - ret1 = ft1(rd0, atype) - np.testing.assert_allclose( - to_numpy_array(ret0["energy"]), - to_numpy_array(ret1["energy"]), - ) - def test_jit( self, ): diff --git a/source/tests/pt/model/test_se_atten_v2.py b/source/tests/pt/model/test_se_atten_v2.py index 
f9857fc728..462b2aca34 100644 --- a/source/tests/pt/model/test_se_atten_v2.py +++ b/source/tests/pt/model/test_se_atten_v2.py @@ -66,7 +66,6 @@ def test_consistency( type_one_side=to, use_econf_tebd=ect, type_map=["O", "H"] if ect else None, - old_impl=False, seed=GLOBAL_SEED, ).to(env.DEVICE) dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) @@ -138,7 +137,6 @@ def test_jit( type_one_side=to, use_econf_tebd=ect, type_map=["O", "H"] if ect else None, - old_impl=False, seed=GLOBAL_SEED, ) dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) diff --git a/source/tests/pt/model/test_se_e2_a.py b/source/tests/pt/model/test_se_e2_a.py index abe13ce86e..da9e69243c 100644 --- a/source/tests/pt/model/test_se_e2_a.py +++ b/source/tests/pt/model/test_se_e2_a.py @@ -58,7 +58,6 @@ def test_consistency( self.sel, precision=prec, resnet_dt=idt, - old_impl=False, exclude_types=em, seed=GLOBAL_SEED, ).to(env.DEVICE) @@ -105,46 +104,6 @@ def test_consistency( atol=atol, err_msg=err_msg, ) - # old impl - if idt is False and prec == "float64" and em == []: - dd3 = DescrptSeA( - self.rcut, - self.rcut_smth, - self.sel, - precision=prec, - resnet_dt=idt, - old_impl=True, - seed=GLOBAL_SEED, - ).to(env.DEVICE) - dd0_state_dict = dd0.sea.state_dict() - dd3_state_dict = dd3.sea.state_dict() - for i in dd3_state_dict: - dd3_state_dict[i] = ( - dd0_state_dict[ - i.replace(".deep_layers.", ".layers.").replace( - "filter_layers_old.", "filter_layers.networks." - ) - ] - .detach() - .clone() - ) - if ".bias" in i: - dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0) - dd3.sea.load_state_dict(dd3_state_dict) - - rd3, gr3, _, _, sw3 = dd3( - torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), - torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), - torch.tensor(self.nlist, dtype=int, device=env.DEVICE), - ) - for aa, bb in zip([rd1, gr1, sw1], [rd3, gr3, sw3]): - np.testing.assert_allclose( - aa.detach().cpu().numpy(), - bb.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) def test_jit( self, @@ -169,7 +128,6 @@ def test_jit( self.sel, precision=prec, resnet_dt=idt, - old_impl=False, seed=GLOBAL_SEED, ) dd0.sea.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
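
Note for downstream users: `old_impl` is gone from every constructor touched above, so the new implementation is the only remaining code path. A minimal sketch of constructing a descriptor after this change, mirroring the keyword arguments the updated tests pass (the rcut/rcut_smth/sel/seed values below are placeholders, not values taken from this patch):

    from deepmd.pt.model.descriptor import DescrptSeA
    from deepmd.pt.utils import env

    # old_impl is no longer an accepted keyword; only the new
    # implementation remains.
    dd0 = DescrptSeA(
        6.0,        # rcut: cutoff radius (placeholder)
        0.5,        # rcut_smth: smoothing onset (placeholder)
        [46, 92],   # sel: neighbors selected per type (placeholder)
        precision="float64",
        resnet_dt=False,
        seed=1,     # placeholder seed
    ).to(env.DEVICE)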
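
The deleted consistency tests above also served as a record of how the legacy parameter names map onto the new module layout (`filter_layers_old.` vs. `filter_layers.networks.`, `deep_layers` vs. `layers`, legacy biases carrying an extra leading dimension). Anyone who still needs to load an `old_impl`-era se_e2_a state dict could invert that mapping with a one-off helper along these lines; this is an illustrative sketch derived from the removed tests, not an API shipped by this patch, and `old_sd` is a hypothetical legacy state dict:

    def migrate_old_se_a_state_dict(old_sd):
        """Rename legacy se_e2_a state-dict keys to the new layout (sketch only)."""
        new_sd = {}
        for key, value in old_sd.items():
            new_key = key.replace("filter_layers_old.", "filter_layers.networks.")
            new_key = new_key.replace(".deep_layers.", ".layers.")
            # The removed tests unsqueezed biases when translating new -> old,
            # so the extra leading dimension is dropped when going old -> new.
            if ".bias" in key:
                value = value.squeeze(0)
            new_sd[new_key] = value
        return new_sd

    # e.g. dd0.sea.load_state_dict(migrate_old_se_a_state_dict(old_sd))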