Skip to content

Commit

Permalink
add custom surrogate gradient section to documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
jeshraghian committed Sep 23, 2023
1 parent 239fb08 commit 7bea7c9
Showing 1 changed file with 47 additions and 1 deletion.
48 changes: 47 additions & 1 deletion docs/snntorch.surrogate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The discrete nature of spikes makes it difficult for ``torch.autograd`` to calcu

Alternative gradients are also available in the :mod:`snntorch.surrogate` module.
These represent either approximations of the backward pass or probabilistic models of firing as a function of the membrane potential.

Custom, user-defined surrogate gradients can also be implemented.

At present, the surrogate gradient functions available include:

Expand Down Expand Up @@ -67,7 +67,53 @@ Example::

net = Net().to(device)

Custom Surrogate Gradients
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For flexibility, custom surrogate gradients can also be defined by the user
using `custom_surrogate`.


Example::

import snntorch as snn
from snntorch import surrogate
import torch
import torch.nn as nn

beta = 0.9

# Define custom surrogate gradient
def custom_fast_sigmoid(input_, grad_input, spikes):
## The hyperparameter slope is defined inside the function.
slope = 25
grad = grad_input / (slope * torch.abs(input_) + 1.0) ** 2
return grad

spike_grad = surrogate.custom_surrogate(custom_fast_sigmoid)

# Define Network
class Net(nn.Module):
def __init__(self):
super().__init__()

# Initialize layers
self.fc1 = nn.Linear(num_inputs, num_hidden)
self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
self.fc2 = nn.Linear(num_hidden, num_outputs)
self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)

def forward(self, x, mem1, spk1, mem2):
cur1 = self.fc1(x)
spk1, mem1 = self.lif1(cur1, mem1)
cur2 = self.fc2(spk1)
spk2, mem2 = self.lif2(cur2, mem2)

return mem1, spk1, mem2, spk2

net = Net().to(device)

.. automodule:: snntorch.surrogate
:members:
:undoc-members:
Expand Down

0 comments on commit 7bea7c9

Please sign in to comment.