diff --git a/docs/snntorch.surrogate.rst b/docs/snntorch.surrogate.rst
index d7af5b78..ccb47502 100644
--- a/docs/snntorch.surrogate.rst
+++ b/docs/snntorch.surrogate.rst
@@ -7,7 +7,7 @@ The discrete nature of spikes makes it difficult for ``torch.autograd`` to calcu
 Alternative gradients are also available in the :mod:`snntorch.surrogate` module.
 These represent either approximations of the backward pass or probabilistic models of firing as a function of the membrane potential.
-
+Custom, user-defined surrogate gradients can also be implemented.
 
 At present, the surrogate gradient functions available include:
 
@@ -17,6 +17,7 @@ At present, the surrogate gradient functions available include:
 * `Straight Through Estimator `_
 * `Triangular `_
 * `SpikeRateEscape `_
+* `Custom Surrogate Gradients `_
 
 amongst several other options.
 
@@ -66,7 +67,53 @@ Example::
 
     net = Net().to(device)
 
+Custom Surrogate Gradients
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For flexibility, custom surrogate gradients can also be defined by the
+user and wrapped with ``custom_surrogate``.
+
+Example::
+
+    import snntorch as snn
+    from snntorch import surrogate
+    import torch
+    import torch.nn as nn
+
+    beta = 0.9
+
+    # Layer sizes and device (example values)
+    num_inputs = 784
+    num_hidden = 1000
+    num_outputs = 10
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Define custom surrogate gradient
+    def custom_fast_sigmoid(input_, grad_input, spikes):
+        # The hyperparameter slope is defined inside the function
+        slope = 25
+        grad = grad_input / (slope * torch.abs(input_) + 1.0) ** 2
+        return grad
+
+    spike_grad = surrogate.custom_surrogate(custom_fast_sigmoid)
+
+    # Define Network
+    class Net(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+            # Initialize layers
+            self.fc1 = nn.Linear(num_inputs, num_hidden)
+            self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
+            self.fc2 = nn.Linear(num_hidden, num_outputs)
+            self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
+
+        def forward(self, x, mem1, spk1, mem2):
+            cur1 = self.fc1(x)
+            spk1, mem1 = self.lif1(cur1, mem1)
+            cur2 = self.fc2(spk1)
+            spk2, mem2 = self.lif2(cur2, mem2)
+
+            return mem1, spk1, mem2, spk2
+
+    net = Net().to(device)
+
 .. automodule:: snntorch.surrogate
    :members:
    :undoc-members:
diff --git a/snntorch/surrogate.py b/snntorch/surrogate.py
index d6188bac..8edb1512 100644
--- a/snntorch/surrogate.py
+++ b/snntorch/surrogate.py
@@ -580,6 +580,93 @@ def inner(x):
     return inner
 
 
+class CustomSurrogate(torch.autograd.Function):
+    """
+    Surrogate gradient of the Heaviside step function.
+
+    **Forward pass:** Heaviside step function shifted to the threshold.
+
+        .. math::
+
+            S=\\begin{cases} 1 & \\text{if U ≥ U$_{\\rm thr}$} \\\\
+            0 & \\text{if U < U$_{\\rm thr}$}
+            \\end{cases}
+
+    **Backward pass:** User-defined custom surrogate gradient function.
+
+    The user defines the custom surrogate gradient in a separate function.
+    It is passed to the forward static method and used in the backward
+    static method.
+
+    The arguments of the custom surrogate gradient function are always
+    the input of the forward pass (``input_``), the gradient of the input
+    (``grad_input``) and the output of the forward pass (``out``).
+
+    **Note:** The hyperparameters of the custom surrogate gradient
+    function must be defined inside the function itself.
+
+    Example::
+
+        import torch
+        import torch.nn as nn
+        import snntorch as snn
+        from snntorch import surrogate
+
+        beta = 0.9
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        def custom_fast_sigmoid(input_, grad_input, spikes):
+            # The hyperparameter slope is defined inside the function
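+            # This reproduces the gradient of the fast sigmoid surrogate,
+            # dS/dU = 1 / (slope * |U| + 1)^2, applied to the incoming gradient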
+            slope = 25
+            grad = grad_input / (slope * torch.abs(input_) + 1.0) ** 2
+            return grad
+
+        spike_grad = surrogate.custom_surrogate(custom_fast_sigmoid)
+
+        net_seq = nn.Sequential(nn.Conv2d(1, 12, 5),
+                                nn.MaxPool2d(2),
+                                snn.Leaky(beta=beta,
+                                          spike_grad=spike_grad,
+                                          init_hidden=True),
+                                nn.Conv2d(12, 64, 5),
+                                nn.MaxPool2d(2),
+                                snn.Leaky(beta=beta,
+                                          spike_grad=spike_grad,
+                                          init_hidden=True),
+                                nn.Flatten(),
+                                nn.Linear(64 * 4 * 4, 10),
+                                snn.Leaky(beta=beta,
+                                          spike_grad=spike_grad,
+                                          init_hidden=True,
+                                          output=True)
+                                ).to(device)
+
+    """
+    @staticmethod
+    def forward(ctx, input_, custom_surrogate_function):
+        out = (input_ > 0).float()
+        ctx.save_for_backward(input_, out)
+        ctx.custom_surrogate_function = custom_surrogate_function
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, out = ctx.saved_tensors
+        custom_surrogate_function = ctx.custom_surrogate_function
+
+        grad_input = grad_output.clone()
+        grad = custom_surrogate_function(input_, grad_input, out)
+        # None: no gradient with respect to the function argument itself
+        return grad, None
+
+
+def custom_surrogate(custom_surrogate_function):
+    """Custom surrogate gradient enclosed within a wrapper."""
+    func = custom_surrogate_function
+
+    def inner(data):
+        return CustomSurrogate.apply(data, func)
+
+    return inner
+
+
 # class InverseSpikeOperator(torch.autograd.Function):
 #     """
 #     Surrogate gradient of the Heaviside step function.
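
A minimal end-to-end check of the new wrapper (a sketch: it assumes the
patched ``snntorch.surrogate`` above is importable, and the triangular
window in ``custom_triangular`` is an illustrative choice, not part of
the patch)::

    import torch
    from snntorch import surrogate

    # User-defined surrogate: scale the incoming gradient by a triangular
    # window around the threshold (the width of 1.0 is arbitrary)
    def custom_triangular(input_, grad_input, spikes):
        width = 1.0
        return grad_input * torch.clamp(width - torch.abs(input_), min=0.0)

    spike_grad = surrogate.custom_surrogate(custom_triangular)

    # Forward pass: Heaviside step on the threshold-shifted input
    u = torch.tensor([-0.5, 0.1, 2.0], requires_grad=True)
    spk = spike_grad(u)   # tensor([0., 1., 1.])

    # Backward pass: routed through custom_triangular
    spk.sum().backward()
    print(u.grad)         # tensor([0.5000, 0.9000, 0.0000])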