Skip to content

Commit

Permalink
Merge pull request #230 from genema/master
Browse files Browse the repository at this point in the history
Noisy spiking neural nets implementations.
  • Loading branch information
ahenkes1 authored Aug 15, 2023
2 parents 50042c9 + dd8d8b1 commit fb8f2f9
Show file tree
Hide file tree
Showing 4 changed files with 721 additions and 0 deletions.
2 changes: 2 additions & 0 deletions snntorch/_neurons/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"alpha",
"lapicque",
"leaky",
"noisyleaky",
"rleaky",
"rsynaptic",
"synaptic",
Expand All @@ -24,6 +25,7 @@
from .alpha import Alpha
from .lapicque import Lapicque
from .leaky import Leaky
from .noisyleaky import NoisyLeaky
from .synaptic import Synaptic

from .rleaky import RLeaky
Expand Down
244 changes: 244 additions & 0 deletions snntorch/_neurons/neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch
import torch.nn as nn

import math


__all__ = [
"SpikingNeuron",
Expand Down Expand Up @@ -435,7 +437,249 @@ def init_alpha():
mem = _SpikeTensor(init_flag=False)

return syn_exc, syn_inh, mem


class NoisyLIF(SpikingNeuron):
"""Parent class for noisy leaky integrate and fire neuron models."""

def __init__(
self,
beta,
threshold=1.0,
noise_type='gaussian',
noise_scale=0.3,
init_hidden=False,
inhibition=False,
learn_beta=False,
learn_threshold=False,
reset_mechanism="subtract",
state_quant=False,
output=False,
graded_spikes_factor=1.0,
learn_graded_spikes_factor=False,
):
super().__init__(
threshold,
None,
False,
init_hidden,
inhibition,
learn_threshold,
reset_mechanism,
state_quant,
output,
graded_spikes_factor,
learn_graded_spikes_factor,
)

self._lif_register_buffer(
beta,
learn_beta,
)
self._reset_mechanism = reset_mechanism
self._noise_scale = noise_scale

if noise_type == 'gaussian':
self.spike_grad = self.Gaussian.apply
elif noise_type == 'logistic':
self.spike_grad = self.Logistic.apply
elif noise_type == 'triangular':
self.spike_grad = self.Triangular.apply
elif noise_type == 'uniform':
self.spike_grad = self.Uniform.apply
elif noise_type == 'atan':
self.spike_grad = self.Arctangent.apply
else:
raise ValueError("Invalid noise type. Valid options: gaussian, logistic, triangular, \
uniform, atan")

if self.surrogate_disable:
self.spike_grad = self._surrogate_bypass

def _lif_register_buffer(
self,
beta,
learn_beta,
):
"""Set variables as learnable parameters else register them in the
buffer."""
self._beta_buffer(beta, learn_beta)

def _beta_buffer(self, beta, learn_beta):
if not isinstance(beta, torch.Tensor):
beta = torch.as_tensor(beta) # TODO: or .tensor() if no copy
if learn_beta:
self.beta = nn.Parameter(beta)
else:
self.register_buffer("beta", beta)

@staticmethod
def init_noisyleaky():
"""
Used to initialize mem as an empty SpikeTensor.
``init_flag`` is used as an attribute in the forward pass to convert
the hidden states to the same as the input.
"""
mem = _SpikeTensor(init_flag=False)

return mem

def mem_reset(self, mem):
"""Generates detached reset signal if mem > threshold.
Returns reset."""
mem_shift = mem - self.threshold
reset = self.spike_grad(mem_shift, 0, self._noise_scale).clone().detach()

return reset

@staticmethod
class Gaussian(torch.autograd.Function):
r"""
Gaussian noise. This is the original and default type because the iterative form is derived
from an Ito SDE. The noise is drawn from Gaus(mu, sigma**2).
Let us denote the cumulative distribution function of the noise by CDF,
its probability density function as PDF.
**Forward pass:** Probabilistic firing.
.. math::
S &~ \\text{Bernoulli}(P(\\text{spiking})) \\\\
P(\\text{firing}) = CDF$_{\\rm noise}$ (U-\\text{threshold})
**Backward pass:** Noise-driven learning corresponds to the specified membrane noise.
.. math::
\\frac{∂S}{∂U}&= PDF$_{\\rm noise}$ (U-\\text{threshold})
Refer to:
Ma et al. Exploiting Noise as a Resource for Computation and Learning in Spiking Neural
Networks. Patterns. Cell Press. 2023.
"""

@staticmethod
def forward(ctx, input_, mu=0, sigma=0.3):
input_ += -torch.normal(torch.ones_like(input_) * mu, sigma)
ctx.save_for_backward(input_)
ctx.mu = mu
ctx.sigma = sigma
p_spike = 1/2 * (
1 + torch.erf((input_ - mu) / (sigma * math.sqrt(2)))
)
return torch.bernoulli(p_spike)

@staticmethod
def backward(ctx, grad_output):
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()

temp = (
1 / (math.sqrt(2*math.pi) * ctx.sigma)
) * torch.exp(
-0.5 * ((input_ - ctx.mu) / ctx.sigma).pow_(2)
)
return grad_input*temp, None, None

@staticmethod
class Logistic(torch.autograd.Function):
r"""
Logistic neuronal noise. The resulting noise-driven learning covers the sigmoidal surrogate
gradients in training conventional deterministic spiking models. The noise parameter mu
should be zero, and the scale denotes the noise scale.
Refer to:
Ma et al. Exploiting Noise as a Resource for Computation and Learning in Spiking Neural
Networks. Patterns. Cell Press. 2023.
"""
@staticmethod
def forward(ctx, input_, mu=0, scale=0.4):
noise = torch.zeros_like(input_).uniform_(0, 1)
noise = mu + scale * (torch.log((noise+1e-8) / (1-noise+1e-8)))
input_ += noise
ctx.save_for_backward(input_)
ctx.mu = mu
ctx.scale = scale

p_spike = torch.special.expit((input_ - ctx.mu + 1e-8) / (ctx.scale + 1e-8)).nan_to_num_()
return torch.bernoulli(p_spike)

@staticmethod
def backward(ctx, grad_output):
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()

temp = torch.exp(
-(input_ - ctx.mu) / ctx.scale
) / ctx.scale / (1 + torch.exp(-(input_ - ctx.mu) / ctx.scale)).pow_(2)
return grad_input*temp, None, None


@staticmethod
class Triangular(torch.autograd.Function):
r"""
Triangular (symmetric) neuronal noise. The resulting noise-driven learning covers the
triangular surrogate gradients in training conventional deterministic spiking models.
The noise avg (mu) is zero.
"""
@staticmethod
def forward(ctx, input_, mu=0, a=0.3):
fc = 0.5
noise = torch.zeros_like(input_).uniform_(0, 1)
mask = (noise < fc).int()
noise = (-a * mask + (2 * a**2 * mask * noise).sqrt()) + \
((1-mask) * a - (2 * a**2 * (1-mask) * (1 - noise)).sqrt())
input_ += noise

ctx.save_for_backward(input_)
ctx.mu = mu
ctx.a = a
mask1 = (input_ < -a).int()
mask2 = (input_ >= a).int()
mask3 = ((input_ >= 0) & (input_ < a)).int()
p_spike = mask2 + \
(1-mask1)*(1-mask2)*(1-mask3) * (input_ + a)**2 / 2 / a**2 + \
mask3 * (1 - (input_ - a)**2 / 2 / a**2)
return torch.bernoulli(p_spike)

@staticmethod
def backward(ctx, grad_output):
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()

mask1 = (input_ < -ctx.a).int()
mask2 = (input_ >= ctx.a).int()
temp = (1-mask1)*(1-mask2) * (ctx.a - input_.abs()) / ctx.a**2
return grad_input*temp, None, None

@staticmethod
class Uniform(torch.autograd.Function):
r"""
Uniform (continuous uniform distrib.) neuronal noise. The resulting noise-driven learning
covers the Gate (rectangular) surrogate gradients. The noise parameters a (left), b (right),
here we use a=(right-left)/2 to denote the noise scale, the noise avg (mu) should be zero.
"""
@staticmethod
def forward(ctx, input_, mu=0, a=0.5):
input_ += -torch.zeros_like(input_).uniform_(a, a)
ctx.save_for_backward(input_)
ctx.mu = mu
ctx.a = a

p_spike = ((input_ - -a) / (a - -a)).clamp(0, 1)
return torch.bernoulli(p_spike)

@staticmethod
def backward(ctx, grad_output):
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()

temp = ((input_ >= -ctx.a).int() & (input_ <= ctx.a).int()) * (
1 / (ctx.a - -ctx.a)
)
return grad_input*temp, None, None


class _SpikeTensor(torch.Tensor):
"""Inherits from torch.Tensor with additional attributes.
Expand Down
Loading

0 comments on commit fb8f2f9

Please sign in to comment.