From 615bf859b9f998b513c7bb13ffc61f4e8c230574 Mon Sep 17 00:00:00 2001 From: Yusong Wang Date: Mon, 6 Nov 2023 11:35:01 +0000 Subject: [PATCH] Add a new Equivariant GNN named ViSNet --- pahelix/model_zoo/visnet.py | 537 ++++++++++++++++++++ pahelix/networks/visnet_output_modules.py | 113 +++++ pahelix/networks/visnet_utils.py | 587 ++++++++++++++++++++++ 3 files changed, 1237 insertions(+) create mode 100755 pahelix/model_zoo/visnet.py create mode 100755 pahelix/networks/visnet_output_modules.py create mode 100755 pahelix/networks/visnet_utils.py diff --git a/pahelix/model_zoo/visnet.py b/pahelix/model_zoo/visnet.py new file mode 100755 index 00000000..5cce7290 --- /dev/null +++ b/pahelix/model_zoo/visnet.py @@ -0,0 +1,537 @@ +from typing import Optional, Tuple + +import paddle +import pgl +from paddle import Tensor, nn +from pgl.message import Message +from pgl.utils import op + +from pahelix.networks.visnet_output_modules import EquivariantScalar +from pahelix.networks.visnet_utils import (Atomref, CosineCutoff, Distance, + EdgeEmbedding, ExpNormalSmearing, + NeighborEmbedding, Sphere, + VecLayerNorm) + + +class ViS_Graph(pgl.Graph): + def recv(self, reduce_func, msg, recv_mode="dst"): + r"""Receives messages and reduces them.""" + if not self._is_tensor: + raise ValueError("You must call Graph.tensor()") + + if not isinstance(msg, dict): + raise TypeError( + "The input of msg should be a dict, but receives a %s" % (type(msg)) + ) + + if not callable(reduce_func): + raise TypeError("reduce_func should be callable") + + src, dst, eid = self.sorted_edges(sort_by=recv_mode) + msg = op.RowReader(msg, eid) + uniq_ind, segment_ids = self.get_segment_ids(src, dst, segment_by=recv_mode) + bucketed_msg = Message(msg, segment_ids) + output = reduce_func(bucketed_msg) + x, vec = output + + x_output_dim = x.shape[-1] + vec_output_dim1 = vec.shape[-1] + vec_output_dim2 = vec.shape[-2] + x_init_output = paddle.zeros( + shape=[self._num_nodes, x_output_dim], dtype=x.dtype + ) + x_final_output = paddle.scatter(x_init_output, uniq_ind, x) + + vec_init_output = paddle.zeros( + shape=[self._num_nodes, vec_output_dim2, vec_output_dim1], dtype=vec.dtype + ) + vec_final_output = paddle.scatter(vec_init_output, uniq_ind, vec) + + return x_final_output, vec_final_output + + +class ViS_MP(nn.Layer): + r"""The message passing module without vertex geometric features + of the equivariant vector-scalar interactive graph neural network (ViSNet) + from the `"Enhancing geometric representations for molecules + with equivariant vector-scalar interactive message passing" + `_ paper. + + Args: + num_heads (int): The number of attention heads. + hidden_channels (int): The number of hidden channels + in the node embeddings. + cutoff (float): The cutoff distance. + vecnorm_type (str): The type of normalization + to apply to the vectors. + trainable_vecnorm (bool): Whether the normalization weights + are trainable. + last_layer (bool): Whether this is the last layer + in the model. 
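+
+    Example::
+
+        # A minimal construction sketch; in practice the layer is built
+        # inside ViSNetBlock, which also prepares the node features
+        # (x, vec) and edge features (r_ij, f_ij, d_ij) it consumes.
+        layer = ViS_MP(num_heads=8, hidden_channels=256, cutoff=5.0,
+                       vecnorm_type="none", trainable_vecnorm=False,
+                       last_layer=False)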
+ """ + + def __init__( + self, + num_heads: int, + hidden_channels: int, + cutoff: float, + vecnorm_type: str, + trainable_vecnorm: bool, + last_layer: bool = False, + ): + super(ViS_MP, self).__init__() + assert hidden_channels % num_heads == 0, ( + f"The number of hidden channels ({hidden_channels}) " + f"must be evenly divisible by the number of " + f"attention heads ({num_heads})" + ) + + self.num_heads = num_heads + self.hidden_channels = hidden_channels + self.head_dim = hidden_channels // num_heads + self.last_layer = last_layer + + self.layernorm = nn.LayerNorm(hidden_channels) + self.vec_layernorm = VecLayerNorm( + hidden_channels, trainable=trainable_vecnorm, norm_type=vecnorm_type + ) + + self.act = nn.Silu() + self.attn_activation = nn.Silu() + + self.cutoff = CosineCutoff(cutoff) + + self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias_attr=False) + + self.q_proj = nn.Linear(hidden_channels, hidden_channels) + self.k_proj = nn.Linear(hidden_channels, hidden_channels) + self.v_proj = nn.Linear(hidden_channels, hidden_channels) + self.dk_proj = nn.Linear(hidden_channels, hidden_channels) + self.dv_proj = nn.Linear(hidden_channels, hidden_channels) + + self.s_proj = nn.Linear(hidden_channels, hidden_channels * 2) + if not self.last_layer: + self.f_proj = nn.Linear(hidden_channels, hidden_channels) + self.w_src_proj = nn.Linear( + hidden_channels, hidden_channels, bias_attr=False + ) + self.w_trg_proj = nn.Linear( + hidden_channels, hidden_channels, bias_attr=False + ) + + self.o_proj = nn.Linear(hidden_channels, hidden_channels * 3) + + @staticmethod + def vector_rejection(vec: Tensor, d_ij: Tensor): + r"""Computes the component of 'vec' orthogonal to 'd_ij'. + + Args: + vec (paddle.Tensor): The input vector. + d_ij (paddle.Tensor): The reference vector. + + Returns: + vec_rej (paddle.Tensor): The component of 'vec' + orthogonal to 'd_ij'. + """ + vec_proj = (vec * d_ij.unsqueeze(2)).sum(axis=1, keepdim=True) + return vec - vec_proj * d_ij.unsqueeze(2) + + def forward(self, graph: pgl.Graph) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """Computes the residual scalar and vector features + of the nodes and scalar featues of the edges. + + Args: + graph (pgl.Graph): + - num_nodes, + - edges <--> edge_index , + - node_feat <--> x, vec, + - edge_feat <--> r_ij, f_ij, d_ij, + + Returns: + dx (paddle.Tensor): The residual scalar features + of the nodes. + dvec (paddle.Tensor): The residual vector features + of the nodes. + df_ij (paddle.Tensor, optional): The residual scalar features + of the edges, or None if this is the last layer. 
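+
+        Example::
+
+            # A sketch of how ViSNetBlock consumes these residuals
+            # (attn is one ViS_MP layer, vis_graph a ViS_Graph holding
+            # the node/edge features listed above).
+            dx, dvec, df_ij = attn(vis_graph)
+            x = x + dx
+            vec = vec + dvec
+            if df_ij is not None:
+                f_ij = f_ij + df_ij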
+ """ + x, vec, r_ij, f_ij, d_ij = ( + graph.node_feat["x"], + graph.node_feat["vec"], + graph.edge_feat["r_ij"], + graph.edge_feat["f_ij"], + graph.edge_feat["d_ij"], + ) + x = self.layernorm(x) + vec = self.vec_layernorm(vec) + + q = self.q_proj(x).reshape([-1, self.num_heads, self.head_dim]) + k = self.k_proj(x).reshape([-1, self.num_heads, self.head_dim]) + v = self.v_proj(x).reshape([-1, self.num_heads, self.head_dim]) + dk = self.act(self.dk_proj(f_ij)).reshape([-1, self.num_heads, self.head_dim]) + dv = self.act(self.dv_proj(f_ij)).reshape([-1, self.num_heads, self.head_dim]) + + vec1, vec2, vec3 = paddle.split(self.vec_proj(vec), 3, axis=-1) + vec_dot = (vec1 * vec2).sum(axis=1) + + def _send_func(src_feat, dst_feat, edge_feat): + q_i = dst_feat["q"] + k_j, v_j, vec_j = (src_feat["k"], src_feat["v"], src_feat["vec"]) + dk, dv, r_ij, d_ij = ( + edge_feat["dk"], + edge_feat["dv"], + edge_feat["r_ij"], + edge_feat["d_ij"], + ) + + attn = (q_i * k_j * dk).sum(axis=-1) + attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1) + + v_j = v_j * dv + v_j = (v_j * attn.unsqueeze(2)).reshape([-1, self.hidden_channels]) + + s1, s2 = paddle.split(self.act(self.s_proj(v_j)), 2, axis=1) + vec_j = vec_j * s1.unsqueeze(1) + s2.unsqueeze(1) * d_ij.unsqueeze(2) + + return {"x": v_j, "vec": vec_j} + + def _recv_func(msg: pgl.Message): + x, vec = msg["x"], msg["vec"] + return msg.reduce(x, pool_type="sum"), msg.reduce(vec, pool_type="sum") + + msg = graph.send( + message_func=_send_func, + node_feat={"q": q, "k": k, "v": v, "vec": vec}, + edge_feat={"dk": dk, "dv": dv, "r_ij": r_ij, "d_ij": d_ij}, + ) + + x, vec_out = graph.recv(reduce_func=_recv_func, msg=msg) + + o1, o2, o3 = paddle.split(self.o_proj(x), 3, axis=1) + dx = vec_dot * o2 + o3 + dvec = vec3 * o1.unsqueeze(1) + vec_out + + def _send_func_dihedral(src_feat, dst_feat, edge_feat): + vec_i, vec_j = dst_feat["vec"], src_feat["vec"] + d_ij, f_ij = edge_feat["d_ij"], edge_feat["f_ij"] + + w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij) + w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij) + w_dot = (w1 * w2).sum(axis=1) + df_ij = self.act(self.f_proj(f_ij)) * w_dot + + return {"df_ij": df_ij} + + if not self.last_layer: + edge_msg = graph.send( + message_func=_send_func_dihedral, + node_feat={ + "vec": vec, + }, + edge_feat={"f_ij": f_ij, "d_ij": d_ij}, + ) + + df_ij = edge_msg["df_ij"] + + return dx, dvec, df_ij + + return dx, dvec, None + + +class ViSNetBlock(nn.Layer): + r"""The representation module of the equivariant vector-scalar + interactive graph neural network (ViSNet) from the + `"Enhancing geometric representations for molecules + with equivariant vector-scalar interactive message passing" + `_ paper. + + Args: + lmax (int): The maximum degree + of the spherical harmonics. + vecnorm_type (str): The type of normalization + to apply to the vectors. + trainable_vecnorm (bool): Whether the normalization weights + are trainable. + num_heads (int): The number of attention heads. + num_layers (int): The number of layers in the network. + hidden_channels (int): The number of hidden channels + in the node embeddings. + num_rbf (int): The number of radial basis functions. + trainable_rbf (bool):Whether the radial basis function + parameters are trainable. + max_z (int): The maximum atomic numbers. + cutoff (float): The cutoff distance. + max_num_neighbors (int):The maximum number of neighbors + considered for each atom. + vertex (bool): Whether to use vertex geometric features. 
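+
+    Example::
+
+        # A minimal usage sketch; `graph` is assumed to be a pgl.Graph
+        # whose node features hold atomic numbers `z` and positions
+        # `pos`, as described in `forward`.
+        block = ViSNetBlock(lmax=1, hidden_channels=256)
+        x, vec = block(graph)  # shapes [N, 256] and [N, 3, 256]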
+ """ + + def __init__( + self, + lmax: int = 1, + vecnorm_type: str = "none", + trainable_vecnorm: bool = False, + num_heads: int = 8, + num_layers: int = 9, + hidden_channels: int = 256, + num_rbf: int = 32, + trainable_rbf: bool = False, + max_z: int = 100, + cutoff: float = 5.0, + max_num_neighbors: int = 32, + ): + super().__init__() + self.lmax = lmax + self.vecnorm_type = vecnorm_type + self.trainable_vecnorm = trainable_vecnorm + self.num_heads = num_heads + self.num_layers = num_layers + self.hidden_channels = hidden_channels + self.num_rbf = num_rbf + self.trainable_rbf = trainable_rbf + self.cutoff = cutoff + self.max_num_neighbors = max_num_neighbors + + self.embedding = nn.Embedding(max_z, hidden_channels) + self.distance = Distance(cutoff, max_num_neighbors, loop=True) + self.sphere = Sphere(lmax=lmax) + self.distance_expansion = ExpNormalSmearing(cutoff, num_rbf, trainable_rbf) + self.neighbor_embedding = NeighborEmbedding( + hidden_channels, num_rbf, cutoff, max_z + ) + self.edge_embedding = EdgeEmbedding(num_rbf, hidden_channels) + + self.vis_mp_layers = nn.LayerList() + vis_mp_kwargs = dict( + num_heads=num_heads, + hidden_channels=hidden_channels, + cutoff=cutoff, + vecnorm_type=vecnorm_type, + trainable_vecnorm=trainable_vecnorm, + ) + vis_mp_class = ViS_MP + for _ in range(num_layers - 1): + layer = vis_mp_class(last_layer=False, **vis_mp_kwargs) + self.vis_mp_layers.append(layer) + self.vis_mp_layers.append(vis_mp_class(last_layer=True, **vis_mp_kwargs)) + + self.out_norm = nn.LayerNorm(hidden_channels) + self.vec_out_norm = VecLayerNorm( + hidden_channels, trainable=trainable_vecnorm, norm_type=vecnorm_type + ) + + def forward(self, graph: pgl.Graph): + r"""Computes the scalar and vector features of the nodes. + + Args: + graph (pgl.Graph): + - num_nodes, + - node_feat <--> z, pos, + + Returns: + x (paddle.Tensor): The scalar features of the nodes. + vec (paddle.Tensor): The vector features of the nodes. + """ + z, pos = graph.node_feat["z"], graph.node_feat["pos"] + + x = self.embedding(z) + edge_index, edge_weight, edge_vec = self.distance(pos, graph.graph_node_id) + edge_attr = self.distance_expansion(edge_weight) + mask = edge_index[0] != edge_index[1] + edge_vec[mask] = edge_vec[mask] / paddle.norm(edge_vec[mask], axis=1).unsqueeze( + 1 + ) + edge_vec = self.sphere(edge_vec) + x = self.neighbor_embedding(z, x, edge_index, edge_weight, edge_attr) + vec = paddle.zeros((x.shape[0], ((self.lmax + 1) ** 2) - 1, x.shape[1])) + edge_attr = self.edge_embedding(edge_index, edge_attr, x) + + vis_graph = ViS_Graph( + num_nodes=x.shape[0], + edges=edge_index.T, + node_feat={"x": x, "vec": vec}, + edge_feat={"r_ij": edge_weight, "f_ij": edge_attr, "d_ij": edge_vec}, + ) + vis_graph._graph_node_index = graph._graph_node_index + + # ViS-MP Layers + for attn in self.vis_mp_layers[:-1]: + dx, dvec, dedge_attr = attn(vis_graph) + x = x + dx + vec = vec + dvec + edge_attr = edge_attr + dedge_attr + vis_graph.node_feat["x"] = x + vis_graph.node_feat["vec"] = vec + vis_graph.edge_feat["f_ij"] = edge_attr + + dx, dvec, _ = self.vis_mp_layers[-1](vis_graph) + x = x + dx + vec = vec + dvec + + x = self.out_norm(x) + vec = self.vec_out_norm(vec) + + return x, vec + + +class ViSNet(nn.Layer): + r"""A PyTorch module that implements + the equivariant vector-scalar interactive graph neural network (ViSNet) + from the `"Enhancing geometric representations for molecules + with equivariant vector-scalar interactive message passing" + `_ paper. 
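+
+    The model chains a ViSNetBlock representation module, an
+    EquivariantScalar output head, an optional Atomref prior, and a
+    segment pooling step (reduce_op) that turns per-atom scalars into
+    one prediction per molecule; with derivative=True it also returns
+    the negative gradient of the output with respect to the positions.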
+ + Args: + lmax (int): The maximum degree + of the spherical harmonics. + vecnorm_type (str): The type of normalization + to apply to the vectors. + trainable_vecnorm (bool): Whether the normalization weights + are trainable. + num_heads (int): The number of attention heads. + num_layers (int): The number of layers in the network. + hidden_channels (int): The number of hidden channels + in the node embeddings. + num_rbf (int): The number of radial basis functions. + trainable_rbf (bool): Whether the radial basis function + parameters are trainable. + max_z (int): The maximum atomic numbers. + cutoff (float): The cutoff distance. + max_num_neighbors (int): The maximum number of neighbors + considered for each atom. + vertex (bool): Whether to use vertex geometric features. + atomref (paddle.Tensor, optional): A tensor of atom reference values, + or None if not provided. + reduce_op (str): The type of reduction operation to apply + ("sum", "mean"). + mean (float, optional): The mean of the output distribution, + or 0 if not provided. + std (float, optional): The standard deviation + of the output distribution, or 1 if not provided. + derivative (bool): Whether to compute the derivative of the output + with respect to the positions. + """ + + def __init__( + self, + lmax: int = 1, + vecnorm_type: str = "none", + trainable_vecnorm: bool = False, + num_heads: int = 8, + num_layers: int = 6, + hidden_channels: int = 128, + num_rbf: int = 32, + trainable_rbf: bool = False, + max_z: int = 100, + cutoff: float = 5.0, + max_num_neighbors: int = 32, + atomref: Optional[Tensor] = None, + reduce_op: str = "sum", + mean: Optional[float] = None, + std: Optional[float] = None, + derivative: bool = False, + ): + super(ViSNet, self).__init__() + self.representation_model = ViSNetBlock( + lmax=lmax, + vecnorm_type=vecnorm_type, + trainable_vecnorm=trainable_vecnorm, + num_heads=num_heads, + num_layers=num_layers, + hidden_channels=hidden_channels, + num_rbf=num_rbf, + trainable_rbf=trainable_rbf, + max_z=max_z, + cutoff=cutoff, + max_num_neighbors=max_num_neighbors, + ) + self.output_model = EquivariantScalar(hidden_channels=hidden_channels) + self.prior_model = Atomref(atomref=atomref, max_z=max_z) + self.reduce_op = reduce_op + self.derivative = derivative + + mean = paddle.to_tensor(0) if mean is None else paddle.to_tensor(mean) + self.register_buffer("mean", mean) + std = paddle.to_tensor(1) if std is None else paddle.to_tensor(std) + self.register_buffer("std", std) + + def forward(self, graph: pgl.Graph) -> Tuple[Tensor, Optional[Tensor]]: + r"""Computes the energies or properties (forces) + for a batch of molecules. + + Args: + graph (pgl.Graph): + - num_nodes, + - node_feat <--> z, pos, + + Returns: + y (paddle.Tensor): The energies or properties for each molecule. + dy (paddle.Tensor, optional): The negative derivative of energies. + """ + if self.derivative: + graph.node_feat["pos"].stop_gradient = False + + x, v = self.representation_model(graph) + x = self.output_model.pre_reduce(x, v) + x = x * self.std + + if self.prior_model is not None: + x = self.prior_model(x, z=graph.node_feat["z"]) + + y = pgl.math.segment_pool(x, graph.graph_node_id, pool_type=self.reduce_op) + y = y + self.mean + + if self.derivative: + try: + dy = paddle.grad( + [y], + [graph.node_feat["pos"]], + grad_outputs=[paddle.ones_like(y)], + create_graph=True, + retain_graph=True, + )[0] + return y, -dy + except RuntimeError: + print( + "Since the Op segment_pool_grad doesn't have any gradop. 
" + \ + "Can't compute the derivative of the energies with respect to the positions." + ) + print( + "The derivative of the energies with respect to the positions is None." + ) + return y, None + return y, None + +if __name__ == '__main__': + + graph = pgl.Graph( + num_nodes=5, + edges=paddle.to_tensor([]), + node_feat={ + "z": paddle.randint(0, 100, (5,)), + "pos": paddle.rand((5, 3)), + }, + ) + + model = ViSNet( + lmax=1, + vecnorm_type="none", + trainable_vecnorm=False, + num_heads=8, + num_layers=6, + hidden_channels=128, + num_rbf=32, + trainable_rbf=False, + max_z=100, + cutoff=5.0, + max_num_neighbors=32, + atomref=None, + reduce_op="sum", + mean=None, + std=None, + derivative=False, + ) + + out, _ = model(graph) + + assert out.shape == [1, 1] \ No newline at end of file diff --git a/pahelix/networks/visnet_output_modules.py b/pahelix/networks/visnet_output_modules.py new file mode 100755 index 00000000..d9d77fb5 --- /dev/null +++ b/pahelix/networks/visnet_output_modules.py @@ -0,0 +1,113 @@ +from typing import Optional, Tuple + +import paddle +from paddle import Tensor, nn + + +class GatedEquivariantBlock(nn.Layer): + r"""Applies a gated equivariant operation + to scalar features and vector features from the + `"Equivariant message passing for the prediction + of tensorial properties and molecular spectra" + `_ paper. + + Args: + hidden_channels (int): The number of hidden channels + in the node embeddings. + out_channels (int): The number of output channels. + intermediate_channels (int or None): The number of channels + in the intermediate layer, + or None to use the same number as 'hidden_channels'. + scalar_activation (bool): Whether to apply + a scalar activation function to the output node features. + """ + + def __init__( + self, + hidden_channels: int, + out_channels: int, + intermediate_channels: Optional[int] = None, + scalar_activation: bool = False, + ): + super().__init__() + self.out_channels = out_channels + + if intermediate_channels is None: + intermediate_channels = hidden_channels + + self.vec1_proj = nn.Linear(hidden_channels, hidden_channels, bias_attr=False) + self.vec2_proj = nn.Linear(hidden_channels, out_channels, bias_attr=False) + + self.update_net = nn.Sequential( + nn.Linear(hidden_channels * 2, intermediate_channels), + nn.Silu(), + nn.Linear(intermediate_channels, out_channels * 2), + ) + + self.act = nn.Silu() if scalar_activation else None + + + def forward(self, x: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]: + r"""Applies a gated equivariant operation + to node features and vector features. + + Args: + x (paddle.Tensor): The scalar features of the nodes. + v (paddle.Tensor): The vector features of the nodes. + + Returns: + x (paddle.Tensor): The updated scalar features of the nodes. + v (paddle.Tensor): The updated vector features of the nodes. + """ + vec1 = paddle.norm(self.vec1_proj(v), axis=-2) + vec2 = self.vec2_proj(v) + + x = paddle.concat([x, vec1], axis=-1) + x, v = paddle.split(self.update_net(x), 2, axis=-1) + v = v.unsqueeze(1) * vec2 + + if self.act is not None: + x = self.act(x) + return x, v + + +class EquivariantScalar(nn.Layer): + r"""Computes final scalar outputs based on + node features and vector features. + + Args: + hidden_channels (int): The number of hidden channels + in the node embeddings. 
+ """ + + def __init__(self, hidden_channels: int): + super(EquivariantScalar, self).__init__() + self.output_network = nn.LayerList( + [ + GatedEquivariantBlock( + hidden_channels, + hidden_channels // 2, + scalar_activation=True, + ), + GatedEquivariantBlock( + hidden_channels // 2, + 1, + scalar_activation=False, + ), + ] + ) + + def pre_reduce(self, x: Tensor, v: Tensor) -> Tensor: + r"""Computes the final scalar outputs. + + Args: + x (paddle.Tensor): The scalar features of the nodes. + v (paddle.Tensor): The vector features of the nodes. + + Returns: + out (paddle.Tensor): The final scalar outputs of the nodes. + """ + for layer in self.output_network: + x, v = layer(x, v) + + return x + v.sum() * 0 diff --git a/pahelix/networks/visnet_utils.py b/pahelix/networks/visnet_utils.py new file mode 100755 index 00000000..e69a2216 --- /dev/null +++ b/pahelix/networks/visnet_utils.py @@ -0,0 +1,587 @@ +import math +from typing import Optional, Tuple + +import numpy as np +import paddle +import paddle.nn.functional as F +import pgl +from paddle import Tensor, nn + + +def radius(x, y, r, max_num_neighbors): + assert x.dim() == 2 and y.dim() == 2, "Input must be 2-D tensor" + assert x.shape[1] == y.shape[1], "Input dimensions must match" + + row = paddle.full((y.shape[0] * max_num_neighbors,), -1, dtype="int64") + col = paddle.full((y.shape[0] * max_num_neighbors,), -1, dtype="int64") + + count = 0 + for n_y in range(y.shape[0]): + for n_x in range(x.shape[0]): + dist = paddle.sum((x[n_x, :] - y[n_y, :]) ** 2) + if dist < r: + row[n_y * max_num_neighbors + count] = n_y + col[n_y * max_num_neighbors + count] = n_x + count += 1 + if count >= max_num_neighbors: + break + + mask = row != -1 + return paddle.stack([row[mask], col[mask]], axis=0) + + +def batch_radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32): + assert x.dim() == 2, "Input must be 2-D tensor" + + if batch is None: + batch = paddle.zeros([x.shape[0]], dtype="int64") + + unique_batches = paddle.unique(batch) + num_batches = unique_batches.shape[0] + + all_edges = [] + + for batch_id in range(num_batches): + batch_mask = batch == batch_id + batch_x = x[batch_mask] + + edge_index = radius( + batch_x, batch_x, r, max_num_neighbors if loop else max_num_neighbors + 1 + ) + if not loop: + mask = edge_index[0, :] != edge_index[1, :] + edge_index = edge_index[:, mask] + + # Adjust the node indices for the current batch + edge_index += batch_id * batch_x.shape[0] + + all_edges.append(edge_index) + + return paddle.concat(all_edges, axis=1) + + +class CosineCutoff(nn.Layer): + r"""Appies a cosine cutoff to the input distances. + + .. math:: + \text{cutoffs} = + \begin{cases} + 0.5 * (\cos(\frac{\text{distances} * \pi}{\text{cutoff}}) + 1.0), + & \text{if } \text{distances} < \text{cutoff} \\ + 0, & \text{otherwise} + \end{cases} + + Args: + cutoff (float): A scalar that determines the point + at which the cutoff is applied. + """ + + def __init__(self, cutoff: float): + super(CosineCutoff, self).__init__() + self.cutoff = cutoff + + def forward(self, distances: Tensor): + r"""Applies a cosine cutoff to the input distances. + Args: + distances (paddle.Tensor): A tensor of distances. + Returns: + cutoffs (paddle.Tensor): A tensor where the cosine function + has been applied to the distances, + but any values that exceed the cutoff are set to 0. 
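+
+        Example::
+
+            # A quick numeric sketch with cutoff=5.0.
+            cutoff_fn = CosineCutoff(5.0)
+            d = paddle.to_tensor([0.0, 2.5, 5.0])
+            cutoff_fn(d)  # approximately [1.0, 0.5, 0.0]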
+ """ + cutoffs = 0.5 * (paddle.cos(distances * math.pi / self.cutoff) + 1.0) + cutoffs = cutoffs * (distances < self.cutoff).astype("float32") + return cutoffs + + +class ExpNormalSmearing(nn.Layer): + r"""Applies exponential normal smearing to the input distances. + + .. math:: + \text{smeared\_dist} = \text{CosineCutoff}(\text{dist}) + * e^{-\beta * (e^{\alpha * (-\text{dist})} - \text{means})^2} + + Args: + cutoff (float): A scalar that determines the point + at which the cutoff is applied. + num_rbf (int): The number of radial basis functions. + trainable (bool): If True, the means and betas of the RBFs + are trainable parameters. + """ + + def __init__(self, cutoff: float = 5.0, num_rbf: int = 128, trainable: bool = True): + super(ExpNormalSmearing, self).__init__() + self.cutoff = cutoff + self.num_rbf = num_rbf + self.trainable = trainable + + self.cutoff_fn = CosineCutoff(cutoff) + self.alpha = 5.0 / cutoff + + means, betas = self._initial_params() + if trainable: + self.add_parameter( + "means", + self.create_parameter( + shape=means.shape, default_initializer=nn.initializer.Assign(means) + ), + ) + self.add_parameter( + "betas", + self.create_parameter( + shape=means.shape, default_initializer=nn.initializer.Assign(betas) + ), + ) + else: + self.register_buffer("means", means) + self.register_buffer("betas", betas) + + def _initial_params(self) -> Tuple[Tensor, Tensor]: + r"""Initializes the means and betas + for the radial basis functions. + Returns: + means, betas (Tuple[paddle.Tensor, paddle.Tensor]): The + initialized means and betas. + """ + start_value = paddle.exp(paddle.to_tensor(-self.cutoff)) + means = paddle.linspace(start_value, 1, self.num_rbf) + betas = paddle.to_tensor( + [(2 / self.num_rbf * (1 - start_value)) ** -2] * self.num_rbf + ) + return means, betas + + def forward(self, dist: Tensor) -> Tensor: + r"""Applies the exponential normal smearing + to the input distance. + + Args: + dist (paddle.Tensor): A tensor of distances. + + Returns: + smeared_dist (paddle.Tensor): The smeared distances. + """ + dist = dist.unsqueeze(-1) + smeared_dist = self.cutoff_fn(dist) * paddle.exp( + -self.betas * (paddle.exp(self.alpha * (-dist)) - self.means) ** 2 + ) + return smeared_dist + + +class Sphere(nn.Layer): + r"""Computes spherical harmonics of the input data. + + This module computes the spherical harmonics up + to a given degree `lmax` for the input tensor of 3D vectors. + The vectors are assumed to be given in Cartesian coordinates. + See `Wikipedia + `_ + for mathematical details. + + Args: + lmax (int): The maximum degree of the spherical harmonics. + """ + + def __init__(self, lmax: int = 2): + super(Sphere, self).__init__() + self.lmax = lmax + + def forward(self, edge_vec: Tensor) -> Tensor: + r"""Computes the spherical harmonics of the input tensor. + + Args: + edge_vec (paddle.Tensor): A tensor of 3D vectors. + + Returns: + edge_sh (paddle.Tensor): The spherical harmonics + of the input tensor. + """ + edge_sh = self._spherical_harmonics( + self.lmax, edge_vec[..., 0], edge_vec[..., 1], edge_vec[..., 2] + ) + return edge_sh + + @staticmethod + def _spherical_harmonics(lmax: int, x: Tensor, y: Tensor, z: Tensor) -> Tensor: + r"""Computes the spherical harmonics + up to degree `lmax` of the input vectors. + + Args: + lmax (int): The maximum degree of the spherical harmonics. + x (paddle.Tensor): The x coordinates of the vectors. + y (paddle.Tensor): The y coordinates of the vectors. + z (paddle.Tensor): The z coordinates of the vectors. 
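+
+        Note:
+            The second-degree terms assume unit-length input vectors;
+            ViSNetBlock normalizes `edge_vec` (self-loops excluded)
+            before calling Sphere. For `lmax=1` the harmonics reduce
+            to `(x, y, z)`.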
+ + Returns: + sh (paddle.Tensor): The spherical harmonics of the input vectors. + """ + + sh_1_0, sh_1_1, sh_1_2 = x, y, z + + if lmax == 1: + return paddle.stack([sh_1_0, sh_1_1, sh_1_2], axis=-1) + + sh_2_0 = math.sqrt(3.0) * x * z + sh_2_1 = math.sqrt(3.0) * x * y + y2 = y.pow(2) + x2z2 = x.pow(2) + z.pow(2) + sh_2_2 = y2 - 0.5 * x2z2 + sh_2_3 = math.sqrt(3.0) * y * z + sh_2_4 = math.sqrt(3.0) / 2.0 * (z.pow(2) - x.pow(2)) + + if lmax == 2: + return paddle.stack( + [sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4], + axis=-1, + ) + + +class VecLayerNorm(nn.Layer): + r"""Applies layer normalization to the input data. + + This module applies a custom layer normalization to a tensor of vectors. + The normalization can either be "max_min" normalization, + or no normalization. + + Args: + hidden_channels (int): The number of hidden channels in the input. + trainable (bool): If True, the normalization weights + are trainable parameters. + norm_type (str): The type of normalization to apply. + Can be "max_min" or "none". + """ + + def __init__( + self, hidden_channels: int, trainable: bool, norm_type: str = "max_min" + ): + super(VecLayerNorm, self).__init__() + + self.hidden_channels = hidden_channels + self.eps = 1e-12 + + weight = paddle.ones(self.hidden_channels) + if trainable: + self.add_parameter( + "weight", + self.create_parameter( + shape=weight.shape, + default_initializer=nn.initializer.Assign(weight), + ), + ) + else: + self.register_buffer("weight", weight) + + if norm_type == "max_min": + self.norm = self.max_min_norm + else: + self.norm = self.none_norm + + def none_norm(self, vec: Tensor) -> Tensor: + r"""Applies no normalization to the input tensor. + + Args: + vec (paddle.Tensor): The input tensor. + + Returns: + vec (paddle.Tensor): The same input tensor. + """ + return vec + + def max_min_norm(self, vec: Tensor) -> Tensor: + r"""Applies max-min normalization to the input tensor. + + .. math:: + \text{dist} = ||\text{vec}||_2 + \text{direct} = \frac{\text{vec}}{\text{dist}} + \text{max\_val} = \max(\text{dist}) + \text{min\_val} = \min(\text{dist}) + \text{delta} = \text{max\_val} - \text{min\_val} + \text{dist} = \frac{\text{dist} - \text{min\_val}}{\text{delta}} + \text{normed\_vec} = \max(0, \text{dist}) \cdot \text{direct} + + Args: + vec (paddle.Tensor): The input tensor. + + Returns: + normed_vec (paddle.Tensor): The normalized tensor. + """ + dist = paddle.norm(vec, axis=1, keepdim=True) + + if (dist == 0).all(): + return paddle.zeros_like(vec) + + dist = paddle.clip(dist, min=self.eps) + direct = vec / dist + + max_val, _ = paddle.max(dist, axis=-1) + min_val, _ = paddle.min(dist, axis=-1) + delta = (max_val - min_val).view(-1) + delta = paddle.where(delta == 0, paddle.ones_like(delta), delta) + dist = (dist - min_val.view(-1, 1, 1)) / delta.view(-1, 1, 1) + + return F.relu(dist) * direct + + def forward(self, vec: Tensor) -> Tensor: + r"""Applies the layer normalization to the input tensor. + + Args: + vec (paddle.Tensor): The input tensor. + + Returns: + normed_vec (paddle.Tensor): The normalized tensor. 
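+
+        Example::
+
+            # A shape-level sketch; vec holds l=1 features (axis 1 of
+            # size 3) or concatenated l=1 and l=2 features (size 8).
+            norm = VecLayerNorm(128, trainable=False, norm_type="none")
+            out = norm(paddle.rand([10, 3, 128]))  # same shape as input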
+ """ + if vec.shape[1] == 3: + vec = self.norm(vec) + return vec * self.weight.unsqueeze(0).unsqueeze(0) + elif vec.shape[1] == 8: + vec1, vec2 = paddle.split(vec, [3, 5], axis=1) + vec1 = self.norm(vec1) + vec2 = self.norm(vec2) + vec = paddle.concat([vec1, vec2], axis=1) + return vec * self.weight.unsqueeze(0).unsqueeze(0) + else: + raise ValueError("VecLayerNorm only support 3 or 8 channels") + + +class Distance(nn.Layer): + r"""Computes the pairwise distances between atoms in a molecule. + + This module computes the pairwise distances between atoms in a molecule, + represented by their positions `pos`. + The distances are computed only between points + that are within a certain cutoff radius. + + Args: + cutoff (float): The cutoff radius beyond + which distances are not computed. + max_num_neighbors (int): The maximum number of neighbors + considered for each point. + loop (bool): Whether self-loops are included. + """ + + def __init__(self, cutoff: float, max_num_neighbors: int = 32, loop: bool = True): + super(Distance, self).__init__() + self.cutoff = cutoff + self.max_num_neighbors = max_num_neighbors + self.loop = loop + + def forward(self, pos: Tensor, batch: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + r"""Computes the pairwise distances between atoms in the molecule. + + Args: + pos (paddle.Tensor): The positions of the atoms + in the molecule. + batch (paddle.Tensor): A batch vector, + which assigns each node to a specific example. + + Returns: + edge_index (paddle.Tensor): The indices of the edges + in the graph. + edge_weight (paddle.Tensor): The distances + between connected nodes. + edge_vec (paddle.Tensor): The vector differences + between connected nodes. + """ + edge_index = batch_radius_graph( + pos, + r=self.cutoff, + batch=batch, + loop=self.loop, + max_num_neighbors=self.max_num_neighbors, + ) + edge_vec = pos[edge_index[0]] - pos[edge_index[1]] + + if self.loop: + mask = edge_index[0] != edge_index[1] + edge_weight = paddle.zeros(edge_vec.shape[0]) + edge_weight[mask] = paddle.norm(edge_vec[mask], axis=-1) + else: + edge_weight = paddle.norm(edge_vec, axis=-1) + + return edge_index, edge_weight, edge_vec + + +class NeighborEmbedding(nn.Layer): + r"""The `NeighborEmbedding` module from the + `"Enhancing geometric representations for molecules + with equivariant vector-scalar interactive message passing" + `_ paper. + + Args: + hidden_channels (int): The number of hidden channels + in the node embeddings. + num_rbf (int): The number of radial basis functions. + cutoff (float): The cutoff distance. + max_z (int): The maximum atomic numbers. + """ + + def __init__( + self, hidden_channels: int, num_rbf: int, cutoff: float, max_z: int = 100 + ): + super(NeighborEmbedding, self).__init__() + self.embedding = nn.Embedding(max_z, hidden_channels) + self.distance_proj = nn.Linear(num_rbf, hidden_channels) + self.combine = nn.Linear(hidden_channels * 2, hidden_channels) + self.cutoff = CosineCutoff(cutoff) + + def forward( + self, + z: Tensor, + x: Tensor, + edge_index: Tensor, + edge_weight: Tensor, + edge_attr: Tensor, + ) -> Tensor: + r"""Computes the neighborhood embedding of the nodes in the graph. + + Args: + z (paddle.Tensor): The atomic numbers. + x (paddle.Tensor): The node features. + edge_index (paddle.Tensor): The indices of the edges. + edge_weight (paddle.Tensor): The weights of the edges. + edge_attr (paddle.Tensor): The edge features. + + Returns: + x_neighbors (paddle.Tensor): The neighborhood embeddings + of the nodes. 
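+
+        In symbols (self-loops are removed first, and :math:`\Vert`
+        denotes concatenation along the feature axis):
+
+        .. math::
+            \text{x\_neighbors}_i = \text{combine}\Big[ x_i \,\Vert
+            \sum_{j \in N(i)} \text{CosineCutoff}(r_{ij}) \cdot
+            \text{distance\_proj}(\text{edge\_attr}_{ij}) \odot
+            \text{Embedding}(z_j) \Big]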
+ """ + mask = edge_index[0] != edge_index[1] + + if not mask.all(): + edge_index = edge_index.T[mask].T + edge_weight = edge_weight[mask] + edge_attr = edge_attr[mask] + + C = self.cutoff(edge_weight) + W = self.distance_proj(edge_attr) * C.unsqueeze(-1) + + x_neighbors = self.embedding(z) + + graph = pgl.Graph( + edges=edge_index.T, + node_feat={ + "x": x, + "z": z, + }, + edge_feat={ + "W": W, + }, + ) + + def _send_func(src_feat, dst_feat, edge_feat): + x_j = src_feat["x"] + W = edge_feat["W"] + + return {"x": x_j * W} + + def _recv_func(msg: pgl.Message): + x = msg["x"] + return msg.reduce(x, pool_type="sum") + + msg = graph.send( + message_func=_send_func, + node_feat={ + "x": x_neighbors, + }, + edge_feat={ + "W": W, + }, + ) + + x_neighbors = graph.recv(reduce_func=_recv_func, msg=msg) + x_neighbors = self.combine(paddle.concat([x, x_neighbors], axis=-1)) + + return x_neighbors + + +class EdgeEmbedding(nn.Layer): + r"""The `EdgeEmbedding` module + from the `"Enhancing geometric representations for molecules + with equivariant vector-scalar interactive message passing" + `_ paper. + + Args: + num_rbf (int): + The number of radial basis functions. + hidden_channels (int): + The number of hidden channels in the node embeddings. + """ + + def __init__(self, num_rbf: int, hidden_channels: int): + super(EdgeEmbedding, self).__init__() + self.edge_proj = nn.Linear(num_rbf, hidden_channels) + + def forward(self, edge_index: Tensor, edge_attr: Tensor, x: Tensor) -> Tensor: + r"""Computes the edge embeddings of the graph. + + Args: + edge_index (paddle.Tensor): The indices of the edges. + edge_attr (paddle.Tensor): The edge features. + x (paddle.Tensor): The node features. + + Returns: + out_edge_attr (paddle.Tensor): The edge embeddings. + """ + edges = edge_index.T + + graph = pgl.Graph( + edges=edges, + node_feat={ + "x": x, + }, + edge_feat={ + "edge_attr": edge_attr, + }, + ) + + def _send_func(src_feat, dst_feat, edge_feat): + edge_attr = edge_feat["edge_attr"] + x_i, x_j = src_feat["x"], dst_feat["x"] + edge_attr = (x_i + x_j) * self.edge_proj(edge_attr) + return {"edge_attr": edge_attr} + + msg = graph.send( + message_func=_send_func, + node_feat={ + "x": x, + }, + edge_feat={ + "edge_attr": edge_attr, + }, + ) + + return msg["edge_attr"] + + +class Atomref(nn.Layer): + r"""Adds atom reference values to atomic energies. + + Args: + atomref (paddle.Tensor, optional): A tensor of atom reference values, + or None if not provided. + max_z (int): The maximum atomic numbers. + """ + + def __init__(self, atomref: Optional[Tensor] = None, max_z: int = 100): + super(Atomref, self).__init__() + if atomref is None: + atomref = paddle.zeros((max_z, 1)) + else: + atomref = paddle.to_tensor(atomref) + + if atomref.ndim == 1: + atomref = atomref.reshape((-1, 1)) + self.register_buffer("initial_atomref", atomref) + self.atomref = nn.Embedding(len(atomref), 1) + paddle.assign(self.initial_atomref, self.atomref.weight) + + def forward(self, x: Tensor, z: Tensor) -> Tensor: + r"""Adds atom reference values to atomic energies. + + Args: + x (paddle.Tensor): The atomic energies. + z (paddle.Tensor): The atomic numbers. + + Returns: + x (paddle.Tensor): The updated atomic energies. + """ + return x + self.atomref(z)
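+
+
+if __name__ == "__main__":
+    # Minimal smoke test for Atomref (a sketch): with atomref=None the
+    # prior is initialized to zeros, so the output equals the input.
+    prior = Atomref(max_z=100)
+    x = paddle.rand([5, 1])            # per-atom energies
+    z = paddle.randint(0, 100, (5,))   # atomic numbers
+    assert prior(x, z).shape == [5, 1]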