diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py
index f34aa341..8309a136 100644
--- a/torchmdnet/models/utils.py
+++ b/torchmdnet/models/utils.py
@@ -388,6 +388,36 @@ def forward(self, dist):
         )
 
 
+class GLU(nn.Module):
+    r"""Applies the gated linear unit (GLU) function:
+
+    .. math::
+
+        \text{GLU}(x) = \text{Linear}_1(x) \otimes \sigma(\text{Linear}_2(x))
+
+
+    where :math:`\otimes` is the element-wise multiplication operator and
+    :math:`\sigma` is an activation function.
+
+    Args:
+        in_channels (int): Number of input features.
+        hidden_channels (int, optional): Number of hidden features. Defaults to None, meaning hidden_channels=in_channels.
+        activation (nn.Module, optional): Activation function to use. Defaults to Sigmoid.
+    """
+
+    def __init__(
+        self, in_channels, hidden_channels=None, activation: Optional[nn.Module] = None
+    ):
+        super(GLU, self).__init__()
+        self.act = nn.Sigmoid() if activation is None else activation
+        hidden_channels = hidden_channels or in_channels
+        self.W = nn.Linear(in_channels, hidden_channels)
+        self.V = nn.Linear(in_channels, hidden_channels)
+
+    def forward(self, x):
+        return self.W(x) * self.act(self.V(x))
+
+
 class ShiftedSoftplus(nn.Module):
     r"""Applies the ShiftedSoftplus function :math:`\text{ShiftedSoftplus}(x) = \frac{1}{\beta} *
     \log(1 + \exp(\beta * x))-\log(2)` element-wise.
@@ -404,6 +434,50 @@ def forward(self, x):
         return F.softplus(x) - self.shift
 
 
+class Swish(nn.Module):
+    r"""Swish activation function as defined in https://arxiv.org/pdf/1710.05941:
+
+    .. math::
+
+        \text{Swish}(x) = x \cdot \sigma(\beta x)
+
+    Args:
+        beta (float, optional): Scaling factor for the Swish activation. Defaults to 1.0.
+
+    """
+
+    def __init__(self, beta=1.0):
+        super(Swish, self).__init__()
+        self.beta = beta
+
+    def forward(self, x):
+        return x * torch.sigmoid(self.beta * x)
+
+
+class SwiGLU(nn.Module):
+    r"""SwiGLU activation function as defined in https://arxiv.org/pdf/2002.05202:
+
+    .. math::
+
+        \text{SwiGLU}(x) = \text{Linear}_1(x) \otimes \text{Swish}(\text{Linear}_2(x))
+
+    Both linear layers map in_features to hidden_features.
+    Args:
+        in_features (int): Number of input features.
+        hidden_features (int, optional): Number of hidden features. Defaults to None, meaning hidden_features=in_features.
+        beta (float, optional): Scaling factor for the Swish activation. Defaults to 1.0.
+    """
+
+    def __init__(self, in_features, hidden_features=None, beta=1.0):
+        super().__init__()
+        hidden_features = hidden_features or in_features
+        act = Swish(beta)
+        self.glu = GLU(in_features, hidden_features, activation=act)
+
+    def forward(self, x):
+        return self.glu(x)
+
+
 class CosineCutoff(nn.Module):
     def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
         super(CosineCutoff, self).__init__()
@@ -615,6 +689,8 @@ def scatter(
     "silu": nn.SiLU,
     "tanh": nn.Tanh,
     "sigmoid": nn.Sigmoid,
+    "swish": Swish,
+    "mish": nn.Mish,
 }
 
 dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64}
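
A minimal usage sketch of the new activations (not part of the diff). It assumes the classes are importable from `torchmdnet.models.utils` as added above, and that the activation registry edited in the last hunk is the module-level `act_class_mapping` dict; both names are assumptions for illustration. It also checks that `SwiGLU` reduces to `W(x) * Swish(V(x))` through the wrapped `GLU`'s own linear layers.

```python
# Illustrative sketch only; assumes the classes live in torchmdnet.models.utils
# and that the registry edited in the last hunk is named `act_class_mapping`.
import torch
from torchmdnet.models.utils import GLU, Swish, SwiGLU, act_class_mapping

x = torch.randn(8, 64)

# Swish is element-wise and shape-preserving.
swish = Swish(beta=1.0)
assert swish(x).shape == x.shape

# SwiGLU is a GLU with a Swish gate: W(x) * Swish(V(x)).
swiglu = SwiGLU(in_features=64, hidden_features=128)
out = swiglu(x)
assert out.shape == (8, 128)

# Recompute the same result directly from the wrapped GLU's linear layers.
manual = swiglu.glu.W(x) * Swish(1.0)(swiglu.glu.V(x))
assert torch.allclose(out, manual)

# The new registry entries resolve activation names to classes,
# e.g. when a model is configured with activation="swish" or "mish".
act = act_class_mapping["swish"]()
assert torch.allclose(act(x), swish(x))
```

Note that `SwiGLU` (and `GLU`) contain trainable `nn.Linear` layers, so unlike `Swish` or `nn.Mish` they change the feature dimension when `hidden_features != in_features` and cannot be used as drop-in element-wise activations.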