Merge pull request #291 from visdauas/atan-fix
Remove default values in autograd functions that break torch.compile
jeshraghian authored Feb 11, 2024
2 parents 70b1c65 + 98d4099 commit e9080bd
Showing 1 changed file with 12 additions and 11 deletions.
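The change follows a single pattern throughout snntorch/surrogate.py: the static forward of each torch.autograd.Function drops its default argument values, and the module-level wrapper function keeps the default and passes it to .apply() explicitly (the spike_rate_escape wrapper also gains the previously missing beta argument in its .apply() call). Below is a minimal sketch of the resulting pattern, using a hypothetical MySurrogate / my_surrogate pair that is not part of this diff:

import torch


class MySurrogate(torch.autograd.Function):
    """Hypothetical surrogate spike function illustrating the pattern."""

    @staticmethod
    def forward(ctx, input_, alpha):  # no default on alpha after this change
        ctx.save_for_backward(input_)
        ctx.alpha = alpha
        return (input_ > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (input_,) = ctx.saved_tensors
        # Illustrative ATan-style surrogate gradient; the real formulas
        # live in snntorch/surrogate.py.
        grad = grad_output * ctx.alpha / (2 * (1 + (ctx.alpha * input_) ** 2))
        return grad, None


def my_surrogate(alpha=2.0):
    """The wrapper keeps the default and forwards it explicitly."""

    def inner(x):
        return MySurrogate.apply(x, alpha)

    return inner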
23 changes: 12 additions & 11 deletions snntorch/surrogate.py
@@ -75,7 +75,7 @@ class Triangular(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, threshold=1):
+    def forward(ctx, input_, threshold):
         ctx.save_for_backward(input_)
         ctx.threshold = threshold
         out = (input_ > 0).float()
@@ -90,12 +90,13 @@ def backward(ctx, grad_output):
         return grad, None


-def triangular():
+def triangular(threshold=1):
     """Triangular surrogate gradient enclosed with
     a parameterized threshold."""
+    threshold = threshold

     def inner(x):
-        return Triangular.apply(x)
+        return Triangular.apply(x, threshold)

     return inner

@@ -128,7 +129,7 @@ class FastSigmoid(torch.autograd.Function):
     Multilayer Spiking Neural Networks. Neural Computation, pp. 1514-1541.*"""

     @staticmethod
-    def forward(ctx, input_, slope=25):
+    def forward(ctx, input_, slope):
         ctx.save_for_backward(input_)
         ctx.slope = slope
         out = (input_ > 0).float()
@@ -183,7 +184,7 @@ class ATan(torch.autograd.Function):
     Int. Conf. Computer Vision (ICCV), pp. 2661-2671.*"""

     @staticmethod
-    def forward(ctx, input_, alpha=2.0):
+    def forward(ctx, input_, alpha):
         ctx.save_for_backward(input_)
         ctx.alpha = alpha
         out = (input_ > 0).float()
@@ -290,7 +291,7 @@ class Sigmoid(torch.autograd.Function):
     Neural Networks. Neural Computation, pp. 1514-1541.*"""

     @staticmethod
-    def forward(ctx, input_, slope=25):
+    def forward(ctx, input_, slope):
         ctx.save_for_backward(input_)
         ctx.slope = slope
         out = (input_ > 0).float()
@@ -350,7 +351,7 @@ class SpikeRateEscape(torch.autograd.Function):
     Cambridge University Press, 2002.*"""

     @staticmethod
-    def forward(ctx, input_, beta=1, slope=25):
+    def forward(ctx, input_, beta, slope):
         ctx.save_for_backward(input_)
         ctx.beta = beta
         ctx.slope = slope
@@ -375,7 +376,7 @@ def spike_rate_escape(beta=1, slope=25):
     slope = slope

     def inner(x):
-        return SpikeRateEscape.apply(x, slope)
+        return SpikeRateEscape.apply(x, beta, slope)

     return inner

@@ -428,7 +429,7 @@ class StochasticSpikeOperator(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, mean=0, variance=0.2):
+    def forward(ctx, input_, mean, variance):
         out = (input_ > 0).float()
         ctx.save_for_backward(input_, out)
         ctx.mean = mean
@@ -490,7 +491,7 @@ class LeakySpikeOperator(torch.autograd.Function):
     The gradient is identical to that of a threshold-shifted Leaky ReLU."""

     @staticmethod
-    def forward(ctx, input_, slope=0.1):
+    def forward(ctx, input_, slope):
         out = (input_ > 0).float()
         ctx.save_for_backward(out)
         ctx.slope = slope
@@ -549,7 +550,7 @@ class SparseFastSigmoid(torch.autograd.Function):
     Gradient Descent. https://arxiv.org/pdf/2105.08810.pdf.*"""

     @staticmethod
-    def forward(ctx, input_, slope=25, B=1):
+    def forward(ctx, input_, slope, B):
         ctx.save_for_backward(input_)
         ctx.slope = slope
         ctx.B = B
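From the caller's side nothing changes, since the wrapper functions still carry the defaults. A short usage sketch follows, assuming the usual snnTorch spike_grad workflow; the beta value, input shape, and the torch.compile call are illustrative and not taken from this commit:

import torch
import snntorch as snn
from snntorch import surrogate

# The wrapper still owns the default, so existing user code is unaffected.
spike_grad = surrogate.atan(alpha=2.0)

# A leaky integrate-and-fire neuron using the surrogate gradient; wrapping it
# with torch.compile is the scenario the removed defaults were tripping up.
lif = snn.Leaky(beta=0.9, spike_grad=spike_grad)
compiled_lif = torch.compile(lif)

mem = lif.init_leaky()
spk, mem = compiled_lif(torch.randn(8, 10), mem)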
