From 4f94380815f831605f4641b7193df2eccd5652a3 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 26 Nov 2021 17:53:57 +0300 Subject: [PATCH] Add clip\clamp activation (#518) --- segmentation_models_pytorch/base/modules.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index 4074b059..c3191bc6 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -73,6 +73,15 @@ def forward(self, x): return torch.argmax(x, dim=self.dim) +class Clamp(nn.Module): + def __init__(self, min=0, max=1): + super().__init__() + self.min, self.max = min, max + + def forward(self, x): + return torch.clamp(x, self.min, self.max) + + class Activation(nn.Module): def __init__(self, name, **params): @@ -95,6 +104,8 @@ def __init__(self, name, **params): self.activation = ArgMax(**params) elif name == 'argmax2d': self.activation = ArgMax(dim=1, **params) + elif name == 'clamp': + self.activation = Clamp(**params) elif callable(name): self.activation = name(**params) else: