Commit
Merge pull request #334 from RaulPPelaez/swiglu
Add GLU, Swish, Mish and SwiGLU
RaulPPelaez authored Aug 16, 2024
2 parents 6dea4b6 + 352da8a commit 26206eb
Showing 1 changed file with 76 additions and 0 deletions.
torchmdnet/models/utils.py
@@ -388,6 +388,36 @@ def forward(self, dist):
        )


class GLU(nn.Module):
    r"""Applies the gated linear unit (GLU) function:

    .. math::
        \text{GLU}(x) = \text{Linear}_1(x) \otimes \sigma(\text{Linear}_2(x))

    where :math:`\otimes` is the element-wise multiplication operator and
    :math:`\sigma` is an activation function.

    Args:
        in_channels (int): Number of input features.
        hidden_channels (int, optional): Number of hidden features. Defaults to None, meaning hidden_channels=in_channels.
        activation (nn.Module, optional): Activation function to use. Defaults to Sigmoid.
    """

    def __init__(
        self, in_channels, hidden_channels=None, activation: Optional[nn.Module] = None
    ):
        super(GLU, self).__init__()
        self.act = nn.Sigmoid() if activation is None else activation
        hidden_channels = hidden_channels or in_channels
        self.W = nn.Linear(in_channels, hidden_channels)
        self.V = nn.Linear(in_channels, hidden_channels)

    def forward(self, x):
        return self.W(x) * self.act(self.V(x))
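
A minimal usage sketch (illustrative only, not part of the diff), assuming GLU is imported from this module; the shapes are arbitrary examples:

    import torch
    from torchmdnet.models.utils import GLU

    glu = GLU(in_channels=64, hidden_channels=128)
    x = torch.randn(10, 64)
    y = glu(x)  # element-wise product W(x) * sigmoid(V(x)), shape (10, 128)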


class ShiftedSoftplus(nn.Module):
    r"""Applies the ShiftedSoftplus function :math:`\text{ShiftedSoftplus}(x) = \frac{1}{\beta} *
    \log(1 + \exp(\beta * x)) - \log(2)` element-wise.
@@ -404,6 +434,50 @@ def forward(self, x):
        return F.softplus(x) - self.shift


class Swish(nn.Module):
    r"""Swish activation function as defined in https://arxiv.org/pdf/1710.05941:

    .. math::
        \text{Swish}(x) = x \cdot \sigma(\beta x)

    Args:
        beta (float, optional): Scaling factor for the Swish activation. Defaults to 1.0.
    """

    def __init__(self, beta=1.0):
        super(Swish, self).__init__()
        self.beta = beta

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)
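
Since Swish(x) = x * sigmoid(beta * x), setting beta=1 recovers PyTorch's built-in SiLU; a quick sketch of that identity (illustrative, not part of the diff):

    import torch
    import torch.nn.functional as F
    from torchmdnet.models.utils import Swish

    x = torch.randn(5)
    assert torch.allclose(Swish(beta=1.0)(x), F.silu(x))  # Swish with beta=1 equals SiLU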


class SwiGLU(nn.Module):
    r"""SwiGLU activation function as defined in https://arxiv.org/pdf/2002.05202:

    .. math::
        \text{SwiGLU}(x) = \text{Linear}_1(x) \otimes \text{Swish}(\text{Linear}_2(x))

    The two linear layers, :math:`W` and :math:`V`, each map in_features to hidden_features.

    Args:
        in_features (int): Number of input features.
        hidden_features (int, optional): Number of hidden features. Defaults to None, meaning hidden_features=in_features.
        beta (float, optional): Scaling factor for the Swish activation. Defaults to 1.0.
    """

    def __init__(self, in_features, hidden_features=None, beta=1.0):
        super().__init__()
        hidden_features = hidden_features or in_features
        act = Swish(beta)
        self.glu = GLU(in_features, hidden_features, activation=act)

    def forward(self, x):
        return self.glu(x)
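
A usage sketch (illustrative, not part of the diff): SwiGLU as the gated half of a feed-forward block. In the SwiGLU paper the gated output is typically projected back with a third linear layer, shown here as a hypothetical nn.Linear:

    import torch
    import torch.nn as nn
    from torchmdnet.models.utils import SwiGLU

    ff = nn.Sequential(
        SwiGLU(in_features=64, hidden_features=256, beta=1.0),  # W(x) * Swish(V(x))
        nn.Linear(256, 64),  # hypothetical projection back to the input width
    )
    x = torch.randn(10, 64)
    y = ff(x)  # shape (10, 64)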


class CosineCutoff(nn.Module):
    def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0):
        super(CosineCutoff, self).__init__()
@@ -615,6 +689,8 @@ def scatter(
"silu": nn.SiLU,
"tanh": nn.Tanh,
"sigmoid": nn.Sigmoid,
"swish": Swish,
"mish": nn.Mish,
}
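
The new "swish" and "mish" keys are looked up the same way as the existing ones; a hedged sketch (the mapping's name is outside this hunk, so act_class_mapping below is an assumption based on the rest of utils.py):

    act = act_class_mapping["swish"]()  # instantiates Swish(beta=1.0)
    mish = act_class_mapping["mish"]()  # instantiates nn.Mish()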

dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64}
