Noisy spiking neural network implementations #230

Merged: 11 commits, Aug 15, 2023
2 changes: 2 additions & 0 deletions snntorch/_neurons/__init__.py
@@ -12,6 +12,7 @@
"alpha",
"lapicque",
"leaky",
"noisyleaky",
"rleaky",
"rsynaptic",
"synaptic",
@@ -24,6 +25,7 @@
from .alpha import Alpha
from .lapicque import Lapicque
from .leaky import Leaky
from .noisyleaky import NoisyLeaky
from .synaptic import Synaptic

from .rleaky import RLeaky
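These exports make the new neuron importable like the existing classes. A minimal usage sketch, assuming `NoisyLeaky` is re-exported at the package top level the same way as `Leaky`, with constructor arguments mirroring `NoisyLIF.__init__` from this PR:

```python
import snntorch as snn

# Assumed re-export; arguments mirror NoisyLIF below (beta, noise_type, noise_scale).
lif = snn.NoisyLeaky(beta=0.9, noise_type="gaussian", noise_scale=0.3)
```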
244 changes: 244 additions & 0 deletions snntorch/_neurons/neurons.py
@@ -2,6 +2,8 @@
import torch
import torch.nn as nn

import math


__all__ = [
"SpikingNeuron",
@@ -435,7 +437,249 @@ def init_alpha():
mem = _SpikeTensor(init_flag=False)

return syn_exc, syn_inh, mem


class NoisyLIF(SpikingNeuron):
"""Parent class for noisy leaky integrate and fire neuron models."""

def __init__(
self,
beta,
threshold=1.0,
noise_type='gaussian',
noise_scale=0.3,
init_hidden=False,
inhibition=False,
learn_beta=False,
learn_threshold=False,
reset_mechanism="subtract",
state_quant=False,
output=False,
graded_spikes_factor=1.0,
learn_graded_spikes_factor=False,
):
super().__init__(
threshold,
None,  # spike_grad; assigned below based on noise_type
False,  # surrogate_disable
init_hidden,
inhibition,
learn_threshold,
reset_mechanism,
state_quant,
output,
graded_spikes_factor,
learn_graded_spikes_factor,
)

self._lif_register_buffer(
beta,
learn_beta,
)
self._reset_mechanism = reset_mechanism
self._noise_scale = noise_scale

if noise_type == 'gaussian':
self.spike_grad = self.Gaussian.apply
elif noise_type == 'logistic':
self.spike_grad = self.Logistic.apply
elif noise_type == 'triangular':
self.spike_grad = self.Triangular.apply
elif noise_type == 'uniform':
self.spike_grad = self.Uniform.apply
elif noise_type == 'atan':
self.spike_grad = self.Arctangent.apply
else:
raise ValueError(
"Invalid noise type. Valid options: gaussian, logistic, "
"triangular, uniform, atan"
)

if self.surrogate_disable:
self.spike_grad = self._surrogate_bypass

def _lif_register_buffer(
self,
beta,
learn_beta,
):
"""Set variables as learnable parameters else register them in the
buffer."""
self._beta_buffer(beta, learn_beta)

def _beta_buffer(self, beta, learn_beta):
if not isinstance(beta, torch.Tensor):
beta = torch.as_tensor(beta) # TODO: or .tensor() if no copy
if learn_beta:
self.beta = nn.Parameter(beta)
else:
self.register_buffer("beta", beta)

@staticmethod
def init_noisyleaky():
"""
Used to initialize mem as an empty SpikeTensor.
``init_flag`` is used as an attribute in the forward pass to convert
the hidden states to the same as the input.
"""
mem = _SpikeTensor(init_flag=False)

return mem

def mem_reset(self, mem):
"""Generates detached reset signal if mem > threshold.
Returns reset."""
mem_shift = mem - self.threshold
reset = self.spike_grad(mem_shift, 0, self._noise_scale).clone().detach()

return reset
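# Note: spike_grad is stochastic here, so the reset signal is itself a
# Bernoulli sample drawn with fresh noise; .clone().detach() keeps the
# reset out of the autograd graph.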

class Gaussian(torch.autograd.Function):
r"""
Gaussian membrane noise. This is the original and default noise type because the
neuron's iterative form is derived from an Itô SDE. The noise is drawn from
N(mu, sigma**2). Let CDF denote the cumulative distribution function of the
noise and PDF its probability density function.

**Forward pass:** Probabilistic firing.

.. math::

S \sim \text{Bernoulli}(P(\text{spike})) \\
P(\text{spike}) = \text{CDF}_{\text{noise}}(U - \text{threshold})

**Backward pass:** Noise-driven learning; the surrogate gradient equals the PDF of the
specified membrane noise.

.. math::

\frac{\partial S}{\partial U} = \text{PDF}_{\text{noise}}(U - \text{threshold})

Refer to:

Ma et al. Exploiting Noise as a Resource for Computation and Learning in Spiking Neural
Networks. Patterns. Cell Press. 2023.
"""

@staticmethod
def forward(ctx, input_, mu=0, sigma=0.3):
# inject Gaussian membrane noise; avoid in-place mutation of the input tensor
input_ = input_ - torch.normal(torch.ones_like(input_) * mu, sigma)
ctx.save_for_backward(input_)
ctx.mu = mu
ctx.sigma = sigma
p_spike = 1/2 * (
1 + torch.erf((input_ - mu) / (sigma * math.sqrt(2)))
)
return torch.bernoulli(p_spike)

@staticmethod
def backward(ctx, grad_output):
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()

temp = (
1 / (math.sqrt(2*math.pi) * ctx.sigma)
) * torch.exp(
-0.5 * ((input_ - ctx.mu) / ctx.sigma).pow_(2)
)
return grad_input*temp, None, None
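# Worked example (illustrative): with mu=0 and sigma=0.3, a shifted membrane
# of 0 fires with probability CDF(0) = 0.5, and the backward pass scales
# grad_output by PDF(0) = 1 / (sqrt(2*pi) * 0.3), roughly 1.33.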

class Logistic(torch.autograd.Function):
r"""
Logistic membrane noise. The resulting noise-driven learning covers the sigmoidal
surrogate gradients used to train conventional deterministic spiking models. The
location parameter mu should be zero; scale sets the noise scale.

Refer to:

Ma et al. Exploiting Noise as a Resource for Computation and Learning in Spiking Neural
Networks. Patterns. Cell Press. 2023.
"""
@staticmethod
def forward(ctx, input_, mu=0, scale=0.4):
noise = torch.zeros_like(input_).uniform_(0, 1)
noise = mu + scale * (torch.log((noise+1e-8) / (1-noise+1e-8)))
input_ = input_ + noise  # avoid in-place mutation of the input tensor
ctx.save_for_backward(input_)
ctx.mu = mu
ctx.scale = scale

p_spike = torch.special.expit((input_ - ctx.mu + 1e-8) / (ctx.scale + 1e-8)).nan_to_num_()
return torch.bernoulli(p_spike)

@staticmethod
def backward(ctx, grad_output):
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()

temp = torch.exp(
-(input_ - ctx.mu) / ctx.scale
) / ctx.scale / (1 + torch.exp(-(input_ - ctx.mu) / ctx.scale)).pow_(2)
return grad_input*temp, None, None
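# Sanity check (illustrative): temp above is the derivative of
# expit((x - mu) / scale), so this backward pass reduces to the familiar
# sigmoid surrogate gradient used for deterministic spiking models.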


class Triangular(torch.autograd.Function):
r"""
Triangular (symmetric) membrane noise. The resulting noise-driven learning covers the
triangular surrogate gradients used to train conventional deterministic spiking models.
The noise mean (mu) is zero; a sets the half-width of the support.
"""
@staticmethod
def forward(ctx, input_, mu=0, a=0.3):
fc = 0.5
noise = torch.zeros_like(input_).uniform_(0, 1)
mask = (noise < fc).int()
noise = (-a * mask + (2 * a**2 * mask * noise).sqrt()) + \
((1-mask) * a - (2 * a**2 * (1-mask) * (1 - noise)).sqrt())
input_ = input_ + noise  # avoid in-place mutation of the input tensor

ctx.save_for_backward(input_)
ctx.mu = mu
ctx.a = a
mask1 = (input_ < -a).int()
mask2 = (input_ >= a).int()
mask3 = ((input_ >= 0) & (input_ < a)).int()
p_spike = mask2 + \
(1-mask1)*(1-mask2)*(1-mask3) * (input_ + a)**2 / 2 / a**2 + \
mask3 * (1 - (input_ - a)**2 / 2 / a**2)
return torch.bernoulli(p_spike)

@staticmethod
def backward(ctx, grad_output):
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()

mask1 = (input_ < -ctx.a).int()
mask2 = (input_ >= ctx.a).int()
temp = (1-mask1)*(1-mask2) * (ctx.a - input_.abs()) / ctx.a**2
return grad_input*temp, None, None
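# Sanity check (illustrative): the gradient above is the triangular "hat"
# (a - |x|) / a**2 on [-a, a] and zero elsewhere, i.e. a triangular
# surrogate gradient that integrates to 1.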

class Uniform(torch.autograd.Function):
r"""
Uniform (continuous uniform distribution) membrane noise. The resulting noise-driven
learning covers the gate (rectangular) surrogate gradients. The noise is supported on
[-a, a], where a = (right - left) / 2 denotes the noise scale; the noise mean (mu)
should be zero.
"""
@staticmethod
def forward(ctx, input_, mu=0, a=0.5):
# draw noise from U(-a, a); uniform_(a, a) would sample the constant a (bug fix)
input_ = input_ - torch.zeros_like(input_).uniform_(-a, a)
ctx.save_for_backward(input_)
ctx.mu = mu
ctx.a = a

p_spike = ((input_ + a) / (2 * a)).clamp(0, 1)
return torch.bernoulli(p_spike)

@staticmethod
def backward(ctx, grad_output):
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()

temp = ((input_ >= -ctx.a) & (input_ <= ctx.a)) * (
1 / (2 * ctx.a)
)
return grad_input*temp, None, None
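# Usage sketch (illustrative, not part of the PR): the noise classes can be
# applied directly to a shifted membrane potential, e.g.
#
#   mem_shift = mem - threshold
#   spk = NoisyLIF.Gaussian.apply(mem_shift, 0, 0.3)
#
# spk is a 0/1 Bernoulli sample; backward() scales grad_output by the
# matching noise PDF (for Uniform above, 1 / (2 * a) inside [-a, a]).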


class _SpikeTensor(torch.Tensor):
"""Inherits from torch.Tensor with additional attributes.