Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pd): add se_atten_v2 #4558

Open
wants to merge 3 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/pd/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
DescrptBlockSeA,
DescrptSeA,
)
from .se_atten_v2 import (
DescrptSeAttenV2,
)
from .se_t_tebd import (
DescrptBlockSeTTebd,
DescrptSeTTebd,
Expand All @@ -37,6 +40,7 @@
"DescrptDPA1",
"DescrptDPA2",
"DescrptSeA",
"DescrptSeAttenV2",
"DescrptSeTTebd",
"prod_env_mat",
]
275 changes: 275 additions & 0 deletions deepmd/pd/model/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
Union,
)

import paddle

from deepmd.dpmodel.utils import EnvMat as DPEnvMat
from deepmd.pd.model.descriptor.dpa1 import (
DescrptDPA1,
)
from deepmd.pd.model.network.mlp import (
NetworkCollection,
)
from deepmd.pd.model.network.network import (
TypeEmbedNetConsistent,
)
from deepmd.pd.utils import (
env,
)
from deepmd.pd.utils.env import (
RESERVED_PRECISION_DICT,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .base_descriptor import (
BaseDescriptor,
)
from .se_atten import (
NeighborGatedAttention,
)


@BaseDescriptor.register("se_atten_v2")
class DescrptSeAttenV2(DescrptDPA1):
def __init__(
self,
rcut: float,
rcut_smth: float,
sel: Union[list[int], int],
ntypes: int,
neuron: list = [25, 50, 100],
njzjz marked this conversation as resolved.
Show resolved Hide resolved
axis_neuron: int = 16,
tebd_dim: int = 8,
set_davg_zero: bool = True,
attn: int = 128,
attn_layer: int = 2,
attn_dotr: bool = True,
attn_mask: bool = False,
activation_function: str = "tanh",
precision: str = "float64",
resnet_dt: bool = False,
exclude_types: list[tuple[int, int]] = [],
njzjz marked this conversation as resolved.
Show resolved Hide resolved
env_protection: float = 0.0,
scaling_factor: int = 1.0,
normalize=True,
temperature=None,
concat_output_tebd: bool = True,
trainable: bool = True,
trainable_ln: bool = True,
ln_eps: Optional[float] = 1e-5,
type_one_side: bool = False,
stripped_type_embedding: Optional[bool] = None,
seed: Optional[Union[int, list[int]]] = None,
use_econf_tebd: bool = False,
use_tebd_bias: bool = False,
type_map: Optional[list[str]] = None,
# not implemented
spin=None,
type: Optional[str] = None,
) -> None:
r"""Construct smooth version of embedding net of type `se_atten_v2`.

Parameters
----------
rcut : float
The cut-off radius :math:`r_c`
rcut_smth : float
From where the environment matrix should be smoothed :math:`r_s`
sel : list[int], int
list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius
int: the total maxmum number of atoms in the cut-off radius
ntypes : int
Number of element types
neuron : list[int]
Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}`
axis_neuron : int
Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix)
tebd_dim : int
Dimension of the type embedding
set_davg_zero : bool
Set the shift of embedding net input to zero.
attn : int
Hidden dimension of the attention vectors
attn_layer : int
Number of attention layers
attn_dotr : bool
If dot the angular gate to the attention weights
attn_mask : bool
(Only support False to keep consistent with other backend references.)
(Not used in this version.)
If mask the diagonal of attention weights
activation_function : str
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision : str
The precision of the embedding net parameters. Supported options are |PRECISION|
resnet_dt : bool
Time-step `dt` in the resnet construction:
y = x + dt * \phi (Wx + b)
exclude_types : list[list[int]]
The excluded pairs of types which have no interaction with each other.
For example, `[[0, 1]]` means no interaction between type 0 and type 1.
env_protection : float
Protection parameter to prevent division by zero errors during environment matrix calculations.
scaling_factor : float
The scaling factor of normalization in calculations of attention weights.
If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5
normalize : bool
Whether to normalize the hidden vectors in attention weights calculation.
temperature : float
If not None, the scaling of attention weights is `temperature` itself.
concat_output_tebd : bool
Whether to concat type embedding at the output of the descriptor.
trainable : bool
If the weights of this descriptors are trainable.
trainable_ln : bool
Whether to use trainable shift and scale weights in layer normalization.
ln_eps : float, Optional
The epsilon value for layer normalization.
type_one_side : bool
If 'False', type embeddings of both neighbor and central atoms are considered.
If 'True', only type embeddings of neighbor atoms are considered.
Default is 'False'.
stripped_type_embedding : bool, Optional
(Deprecated, kept only for compatibility.)
Whether to strip the type embedding into a separate embedding network.
Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'.
The default value is `None`, which means the `tebd_input_mode` setting will be used instead.
seed : int, Optional
Random seed for parameter initialization.
use_econf_tebd : bool, Optional
Whether to use electronic configuration type embedding.
use_tebd_bias : bool, Optional
Whether to use bias in the type embedding layer.
type_map : list[str], Optional
A list of strings. Give the name to each type of atoms.
spin
(Only support None to keep consistent with other backend references.)
(Not used in this version. Not-none option is not implemented.)
The old implementation of deepspin.
"""
DescrptDPA1.__init__(
self,
rcut,
rcut_smth,
sel,
ntypes,
neuron=neuron,
axis_neuron=axis_neuron,
tebd_dim=tebd_dim,
tebd_input_mode="strip",
set_davg_zero=set_davg_zero,
attn=attn,
attn_layer=attn_layer,
attn_dotr=attn_dotr,
attn_mask=attn_mask,
activation_function=activation_function,
precision=precision,
resnet_dt=resnet_dt,
exclude_types=exclude_types,
env_protection=env_protection,
scaling_factor=scaling_factor,
normalize=normalize,
temperature=temperature,
concat_output_tebd=concat_output_tebd,
trainable=trainable,
trainable_ln=trainable_ln,
ln_eps=ln_eps,
smooth_type_embedding=True,
type_one_side=type_one_side,
stripped_type_embedding=stripped_type_embedding,
seed=seed,
use_econf_tebd=use_econf_tebd,
use_tebd_bias=use_tebd_bias,
type_map=type_map,
# not implemented
spin=spin,
type=type,
)

def serialize(self) -> dict:
obj = self.se_atten
data = {
"@class": "Descriptor",
"type": "se_atten_v2",
"@version": 2,
"rcut": obj.rcut,
"rcut_smth": obj.rcut_smth,
"sel": obj.sel,
"ntypes": obj.ntypes,
"neuron": obj.neuron,
"axis_neuron": obj.axis_neuron,
"tebd_dim": obj.tebd_dim,
"set_davg_zero": obj.set_davg_zero,
"attn": obj.attn_dim,
"attn_layer": obj.attn_layer,
"attn_dotr": obj.attn_dotr,
"attn_mask": False,
"activation_function": obj.activation_function,
"resnet_dt": obj.resnet_dt,
"scaling_factor": obj.scaling_factor,
"normalize": obj.normalize,
"temperature": obj.temperature,
"trainable_ln": obj.trainable_ln,
"ln_eps": obj.ln_eps,
"type_one_side": obj.type_one_side,
"concat_output_tebd": self.concat_output_tebd,
"use_econf_tebd": self.use_econf_tebd,
"use_tebd_bias": self.use_tebd_bias,
"type_map": self.type_map,
# make deterministic
"precision": RESERVED_PRECISION_DICT[obj.prec],
"embeddings": obj.filter_layers.serialize(),
"embeddings_strip": obj.filter_layers_strip.serialize(),
"attention_layers": obj.dpa1_attention.serialize(),
"env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
"type_embedding": self.type_embedding.embedding.serialize(),
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": obj["davg"].detach().cpu().numpy(),
"dstd": obj["dstd"].detach().cpu().numpy(),
},
"trainable": self.trainable,
"spin": None,
}
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptSeAttenV2":
data = data.copy()
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
data.pop("env_mat")
embeddings_strip = data.pop("embeddings_strip")
# compat with version 1
if "use_tebd_bias" not in data:
data["use_tebd_bias"] = True
obj = cls(**data)

def t_cvt(xx):
return paddle.to_tensor(xx, dtype=obj.se_atten.prec, device=env.DEVICE)

obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
type_embedding
)
obj.se_atten["davg"] = t_cvt(variables["davg"])
obj.se_atten["dstd"] = t_cvt(variables["dstd"])
obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings)
obj.se_atten.filter_layers_strip = NetworkCollection.deserialize(
embeddings_strip
)
obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
attention_layers
)
return obj
46 changes: 46 additions & 0 deletions source/tests/consistent/descriptor/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PD,
INSTALLED_PT,
CommonTest,
parameterized,
Expand All @@ -44,6 +45,12 @@
)
else:
DescrptSeAttenV2Strict = None
if INSTALLED_PD:
from deepmd.pd.model.descriptor.se_atten_v2 import (
DescrptSeAttenV2 as DescrptSeAttenV2PD,
)
else:
DescrptSeAttenV2PD = None
DescrptSeAttenV2TF = None
from deepmd.utils.argcheck import (
descrpt_se_atten_args,
Expand Down Expand Up @@ -248,11 +255,40 @@ def skip_array_api_strict(self) -> bool:
)
)

@property
def skip_pd(self) -> bool:
njzjz marked this conversation as resolved.
Show resolved Hide resolved
(
tebd_dim,
resnet_dt,
type_one_side,
attn,
attn_layer,
attn_dotr,
excluded_types,
env_protection,
set_davg_zero,
scaling_factor,
normalize,
temperature,
ln_eps,
concat_output_tebd,
precision,
use_econf_tebd,
use_tebd_bias,
) = self.param
return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests(
njzjz marked this conversation as resolved.
Show resolved Hide resolved
attn_layer,
attn_dotr,
normalize,
temperature,
)

njzjz marked this conversation as resolved.
Show resolved Hide resolved
tf_class = DescrptSeAttenV2TF
dp_class = DescrptSeAttenV2DP
pt_class = DescrptSeAttenV2PT
jax_class = DescrptSeAttenV2JAX
array_api_strict_class = DescrptSeAttenV2Strict
pd_class = DescrptSeAttenV2PD
args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False))

def setUp(self) -> None:
Expand Down Expand Up @@ -339,6 +375,16 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
mixed_types=True,
)

def eval_pd(self, pd_obj: Any) -> Any:
return self.eval_pd_descriptor(
pd_obj,
self.natoms,
self.coords,
self.atype,
self.box,
mixed_types=True,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0], ret[1])

Expand Down
Loading