Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix activation lookup with Python 3.12.3 #375

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 10 additions & 41 deletions curated_transformers/layers/activations.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,13 @@
import math
from enum import Enum, EnumMeta
from enum import Enum
from typing import Type

import torch
from torch import Tensor
from torch.nn import Module


class _ActivationMeta(EnumMeta):
"""
``Enum`` metaclass to override the class ``__call__`` method with a more
fine-grained exception for unknown activation functions.
"""

def __call__(
cls,
value,
names=None,
*,
module=None,
qualname=None,
type=None,
start=1,
):
# Wrap superclass __call__ to give a nicer error message when
# an unknown activation is used.
if names is None:
try:
return EnumMeta.__call__(
cls,
value,
names,
module=module,
qualname=qualname,
type=type,
start=start,
)
except ValueError:
supported_activations = ", ".join(sorted(v.value for v in cls))
raise ValueError(
f"Invalid activation function `{value}`. "
f"Supported functions: {supported_activations}"
)
else:
return EnumMeta.__call__(cls, value, names, module, qualname, type, start)


class Activation(Enum, metaclass=_ActivationMeta):
class Activation(Enum):
"""
Activation functions.
Expand All @@ -71,6 +32,14 @@ class Activation(Enum, metaclass=_ActivationMeta):
#: Sigmoid Linear Unit (`Hendrycks et al., 2016`_).
SiLU = "silu"

@classmethod
def _missing_(cls, value):
svlandeg marked this conversation as resolved.
Show resolved Hide resolved
supported_activations = ", ".join(sorted(v.value for v in cls))
raise ValueError(
f"Invalid activation function `{value}`. "
f"Supported functions: {supported_activations}"
)

@property
def module(self) -> Type[torch.nn.Module]:
"""
Expand Down
Loading