From 98d40999394ce6f947628d6ead38b07103b5908d Mon Sep 17 00:00:00 2001
From: visdauas
Date: Fri, 9 Feb 2024 18:18:26 +0100
Subject: [PATCH] Remove default values that break torch.compile

---
 snntorch/surrogate.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/snntorch/surrogate.py b/snntorch/surrogate.py
index 8edb1512..390eb1bd 100644
--- a/snntorch/surrogate.py
+++ b/snntorch/surrogate.py
@@ -75,7 +75,7 @@ class Triangular(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, threshold=1):
+    def forward(ctx, input_, threshold):
         ctx.save_for_backward(input_)
         ctx.threshold = threshold
         out = (input_ > 0).float()
@@ -90,12 +90,13 @@ def backward(ctx, grad_output):
         return grad, None
 
 
-def triangular():
+def triangular(threshold=1):
     """Triangular surrogate gradient enclosed with a parameterized
     threshold."""
+    threshold = threshold
 
     def inner(x):
-        return Triangular.apply(x)
+        return Triangular.apply(x, threshold)
 
     return inner
 
@@ -128,7 +129,7 @@ class FastSigmoid(torch.autograd.Function):
     Multilayer Spiking Neural Networks. Neural Computation, pp. 1514-1541.*"""
 
     @staticmethod
-    def forward(ctx, input_, slope=25):
+    def forward(ctx, input_, slope):
         ctx.save_for_backward(input_)
         ctx.slope = slope
         out = (input_ > 0).float()
@@ -183,7 +184,7 @@ class ATan(torch.autograd.Function):
     Int. Conf. Computer Vision (ICCV), pp. 2661-2671.*"""
 
     @staticmethod
-    def forward(ctx, input_, alpha=2.0):
+    def forward(ctx, input_, alpha):
         ctx.save_for_backward(input_)
         ctx.alpha = alpha
         out = (input_ > 0).float()
@@ -290,7 +291,7 @@ class Sigmoid(torch.autograd.Function):
     Neural Networks. Neural Computation, pp. 1514-1541.*"""
 
     @staticmethod
-    def forward(ctx, input_, slope=25):
+    def forward(ctx, input_, slope):
         ctx.save_for_backward(input_)
         ctx.slope = slope
         out = (input_ > 0).float()
@@ -350,7 +351,7 @@ class SpikeRateEscape(torch.autograd.Function):
     Cambridge University Press, 2002.*"""
 
     @staticmethod
-    def forward(ctx, input_, beta=1, slope=25):
+    def forward(ctx, input_, beta, slope):
         ctx.save_for_backward(input_)
         ctx.beta = beta
         ctx.slope = slope
@@ -375,7 +376,7 @@ def spike_rate_escape(beta=1, slope=25):
     slope = slope
 
     def inner(x):
-        return SpikeRateEscape.apply(x, slope)
+        return SpikeRateEscape.apply(x, beta, slope)
 
     return inner
 
@@ -428,7 +429,7 @@ class StochasticSpikeOperator(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, mean=0, variance=0.2):
+    def forward(ctx, input_, mean, variance):
         out = (input_ > 0).float()
         ctx.save_for_backward(input_, out)
         ctx.mean = mean
@@ -490,7 +491,7 @@ class LeakySpikeOperator(torch.autograd.Function):
     The gradient is identical to that of a threshold-shifted Leaky ReLU."""
 
     @staticmethod
-    def forward(ctx, input_, slope=0.1):
+    def forward(ctx, input_, slope):
         out = (input_ > 0).float()
         ctx.save_for_backward(out)
         ctx.slope = slope
@@ -549,7 +550,7 @@ class SparseFastSigmoid(torch.autograd.Function):
     Gradient Descent. https://arxiv.org/pdf/2105.08810.pdf.*"""
 
     @staticmethod
-    def forward(ctx, input_, slope=25, B=1):
+    def forward(ctx, input_, slope, B):
         ctx.save_for_backward(input_)
         ctx.slope = slope
         ctx.B = B
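
Cover note on the change: per the subject line, default argument values on torch.autograd.Function.forward interfere with torch.compile, so this patch strips them from every forward() and keeps the defaults only in the enclosing factory functions (triangular(), spike_rate_escape(), ...), which close over the value and always pass it to .apply() explicitly. Below is a minimal usage sketch of the patched API; it assumes snntorch and PyTorch >= 2.0 are installed, and the network shape, layer sizes, and beta are illustrative only, not taken from the patch.

import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate

# After this patch the default lives in the factory: triangular(threshold=1)
# closes over the value and always calls Triangular.apply(x, threshold), so the
# autograd Function's forward() carries no default arguments for torch.compile to trace.
spike_grad = surrogate.triangular(threshold=1)

# Hypothetical one-layer spiking network; sizes and beta are illustrative.
net = nn.Sequential(
    nn.Linear(10, 5),
    snn.Leaky(beta=0.9, spike_grad=spike_grad, init_hidden=True),
)

compiled = torch.compile(net)
spk = compiled(torch.randn(1, 10))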