Commit

Merge pull request #237 from mehranfaraji/custom_surrogate_function

Custom Surrogate Gradient Function

jeshraghian authored Sep 23, 2023
2 parents d8c11aa + 7bea7c9 commit b4a1d10
Showing 2 changed files with 135 additions and 1 deletion.
49 changes: 48 additions & 1 deletion docs/snntorch.surrogate.rst
@@ -7,7 +7,7 @@ The discrete nature of spikes makes it difficult for ``torch.autograd`` to calcu

Alternative gradients are also available in the :mod:`snntorch.surrogate` module.
These represent either approximations of the backward pass or probabilistic models of firing as a function of the membrane potential.

Custom, user-defined surrogate gradients can also be implemented.

At present, the surrogate gradient functions available include:

@@ -17,6 +17,7 @@ At present, the surrogate gradient functions available include:
* `Straight Through Estimator <https://snntorch.readthedocs.io/en/latest/snntorch.surrogate.html#snntorch.surrogate.StraightThroughEstimator>`_
* `Triangular <https://snntorch.readthedocs.io/en/latest/snntorch.surrogate.html#snntorch.surrogate.Triangular>`_
* `SpikeRateEscape <https://snntorch.readthedocs.io/en/latest/snntorch.surrogate.html#snntorch.surrogate.SpikeRateEscape>`_
* `Custom Surrogate Gradients <https://snntorch.readthedocs.io/en/latest/snntorch.surrogate.html#snntorch.surrogate.CustomSurrogate>`_

amongst several other options.

@@ -66,7 +67,53 @@ Example::

net = Net().to(device)

Custom Surrogate Gradients
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For flexibility, custom surrogate gradients can also be defined by the user
and applied via ``custom_surrogate``.


Example::

    import snntorch as snn
    from snntorch import surrogate
    import torch
    import torch.nn as nn

    beta = 0.9

    # Define custom surrogate gradient
    def custom_fast_sigmoid(input_, grad_input, spikes):
        # The hyperparameter slope is defined inside the function.
        slope = 25
        grad = grad_input / (slope * torch.abs(input_) + 1.0) ** 2
        return grad

    spike_grad = surrogate.custom_surrogate(custom_fast_sigmoid)

    # Define Network
    class Net(nn.Module):
        def __init__(self):
            super().__init__()

            # Initialize layers
            self.fc1 = nn.Linear(num_inputs, num_hidden)
            self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
            self.fc2 = nn.Linear(num_hidden, num_outputs)
            self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)

        def forward(self, x, mem1, spk1, mem2):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            return mem1, spk1, mem2, spk2

    net = Net().to(device)
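
Because the hyperparameters must be defined inside the surrogate function
itself, a factory function that closes over them is a convenient way to make
them configurable. The following sketch is illustrative rather than part of
snnTorch; ``custom_fast_sigmoid_factory`` and the slope value passed to it
are assumptions::

    # Hypothetical factory: returns a surrogate function with `slope` baked in
    def custom_fast_sigmoid_factory(slope=25):
        def custom_fast_sigmoid(input_, grad_input, spikes):
            grad = grad_input / (slope * torch.abs(input_) + 1.0) ** 2
            return grad
        return custom_fast_sigmoid

    # The slope is fixed when the closure is created
    spike_grad = surrogate.custom_surrogate(custom_fast_sigmoid_factory(slope=10))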

.. automodule:: snntorch.surrogate
:members:
:undoc-members:
87 changes: 87 additions & 0 deletions snntorch/surrogate.py
@@ -580,6 +580,93 @@ def inner(x):
return inner


class CustomSurrogate(torch.autograd.Function):
    """
    Surrogate gradient of the Heaviside step function.

    **Forward pass:** Heaviside step function.

    .. math::

        S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
        0 & \\text{if U < U$_{\\rm thr}$}
        \\end{cases}

    **Backward pass:** User-defined custom surrogate gradient function.

    The user defines the custom surrogate gradient in a separate function.
    It is passed to the ``forward`` static method and used in the
    ``backward`` static method.

    The custom surrogate gradient function always takes three arguments:
    the input of the forward pass (``input_``), the gradient flowing back
    from the subsequent layer (``grad_input``), and the output of the
    forward pass, i.e., the spikes (``out``).

    **Important:** the hyperparameters of the custom surrogate gradient
    function must be defined inside the function itself.

    Example::

        import torch
        import torch.nn as nn
        import snntorch as snn
        from snntorch import surrogate

        beta = 0.9
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def custom_fast_sigmoid(input_, grad_input, spikes):
            # The hyperparameter slope is defined inside the function.
            slope = 25
            grad = grad_input / (slope * torch.abs(input_) + 1.0) ** 2
            return grad

        spike_grad = surrogate.custom_surrogate(custom_fast_sigmoid)

        net_seq = nn.Sequential(nn.Conv2d(1, 12, 5),
                                nn.MaxPool2d(2),
                                snn.Leaky(beta=beta,
                                          spike_grad=spike_grad,
                                          init_hidden=True),
                                nn.Conv2d(12, 64, 5),
                                nn.MaxPool2d(2),
                                snn.Leaky(beta=beta,
                                          spike_grad=spike_grad,
                                          init_hidden=True),
                                nn.Flatten(),
                                nn.Linear(64*4*4, 10),
                                snn.Leaky(beta=beta,
                                          spike_grad=spike_grad,
                                          init_hidden=True,
                                          output=True)
                                ).to(device)
    """
    @staticmethod
    def forward(ctx, input_, custom_surrogate_function):
        # Heaviside forward pass: spike wherever the shifted membrane
        # potential crosses zero.
        out = (input_ > 0).float()
        ctx.save_for_backward(input_, out)
        ctx.custom_surrogate_function = custom_surrogate_function
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input_, out = ctx.saved_tensors
        custom_surrogate_function = ctx.custom_surrogate_function

        # Delegate the surrogate gradient to the user-defined function.
        grad_input = grad_output.clone()
        grad = custom_surrogate_function(input_, grad_input, out)
        # The function argument itself receives no gradient.
        return grad, None


def custom_surrogate(custom_surrogate_function):
    """Custom surrogate gradient enclosed within a wrapper."""
    func = custom_surrogate_function

    def inner(data):
        return CustomSurrogate.apply(data, func)

    return inner
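
# A minimal usage sketch (illustrative, not part of the library): the closure
# returned by ``custom_surrogate`` applies the Heaviside forward pass and
# routes the backward pass through the user's function. The tensor names
# below are assumptions.
#
#     spike_grad = custom_surrogate(my_custom_grad_fn)
#     spk = spike_grad(mem - threshold)  # spikes where mem exceeds threshold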


# class InverseSpikeOperator(torch.autograd.Function):
# """
# Surrogate gradient of the Heaviside step function.
