Merge pull request #246 from stevenabreu7/nir
Update NIR <> snnTorch
jeshraghian authored Feb 5, 2024
2 parents 2e8828e + 40ad136 commit 70b1c65
Showing 10 changed files with 754 additions and 90 deletions.
2 changes: 2 additions & 0 deletions snntorch/__init__.py
@@ -1,3 +1,5 @@
from ._version import __version__
from ._neurons import *
from ._layers import *
from .export_nir import export_to_nir
from .import_nir import import_from_nir
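These two imports expose the new NIR bridge at the package top level. A round-trip sketch of how they might be called — the `sample_data` argument and the `Sequential` topology are assumptions for illustration, not taken from this diff:

```python
import torch
import torch.nn as nn
import snntorch as snn

# Hypothetical two-layer spiking classifier; sizes are arbitrary.
net = nn.Sequential(
    nn.Linear(784, 128),
    snn.Leaky(beta=0.9, init_hidden=True),
    nn.Linear(128, 10),
    snn.Leaky(beta=0.9, init_hidden=True, output=True),
)

# Export to a NIR graph; a sample input is assumed to be needed
# so the exporter can trace shapes.
nir_graph = snn.export_to_nir(net, torch.randn(1, 784))

# Rebuild a torch-executable module from the graph.
net2 = snn.import_from_nir(nir_graph)
```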
31 changes: 27 additions & 4 deletions snntorch/_neurons/leaky.py
@@ -106,6 +106,15 @@ def forward(self, x, mem1, spk1, mem2):
returned when neuron is called. Defaults to False
:type output: bool, optional
:param graded_spikes_factor: output spikes are scaled by this value, if specified. Defaults to 1.0
:type graded_spikes_factor: float or torch.Tensor
:param learn_graded_spikes_factor: Option to enable learnable graded spikes. Defaults to False
:type learn_graded_spikes_factor: bool, optional
:param reset_delay: If `True`, a spike is returned with a one-step delay after the threshold is reached;
if `False`, the membrane potential is reset in the same step the threshold is crossed.
Defaults to True
:type reset_delay: bool, optional
Inputs: \\input_, mem_0
- **input_** of shape `(batch, input_size)`: tensor containing input
@@ -142,6 +151,7 @@ def __init__(
output=False,
graded_spikes_factor=1.0,
learn_graded_spikes_factor=False,
reset_delay=True,
):
super().__init__(
beta,
@@ -169,6 +179,11 @@ def __init__(
elif self.reset_mechanism_val == 2: # no reset, pure integration
self.state_function = self._base_int

self.reset_delay = reset_delay

if not self.reset_delay and self.init_hidden:
raise NotImplementedError("`reset_delay=False` is only supported for `init_hidden=False`")


def _init_mem(self):
mem = torch.zeros(1)
@@ -178,17 +193,18 @@ def reset_mem(self):
self.mem = torch.zeros_like(self.mem, device=self.mem.device)

def init_leaky(self):
"""Deprecated, please use :class:`Leaky.reset_mem` instead"""
"""Deprecated, use :class:`Leaky.reset_mem` instead"""
self.reset_mem()
return self.mem

def forward(self, input_, mem=None):

if mem is not None:
self.mem = mem

if self.init_hidden and mem is not None:
raise TypeError(
"`mem` should not be passed as an argument while `init_hidden=True`"
)

if self.mem.shape != input_.shape:
@@ -201,9 +217,16 @@ def forward(self, input_, mem=None):
self.mem = self.state_quant(self.mem)

if self.inhibition:
spk = self.fire_inhibition(self.mem.size(0), self.mem)  # batch_size
else:
spk = self.fire(self.mem)

if not self.reset_delay:
do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset
if self.reset_mechanism_val == 0: # reset by subtraction
self.mem = self.mem - do_reset * self.threshold
elif self.reset_mechanism_val == 1: # reset to zero
self.mem = self.mem - do_reset * self.mem

if self.output:
return spk, self.mem
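A quick way to see the new behavior: with `reset_delay=False`, the membrane is pushed back down in the very step the spike is emitted instead of one step later. A minimal sketch, assuming the constructor defaults shown above and `init_hidden=False` (which the new check requires):

```python
import torch
import snntorch as snn

lif = snn.Leaky(beta=0.9, threshold=1.0, reset_delay=False)
mem = torch.zeros(1)

for step in range(5):
    # Constant supra-threshold drive; arbitrary value for illustration.
    spk, mem = lif(torch.tensor([0.6]), mem)
    # When spk == 1, mem has already had the threshold subtracted,
    # so the printed value never lingers above threshold after a spike.
    print(step, spk.item(), mem.item())
```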
13 changes: 13 additions & 0 deletions snntorch/_neurons/rleaky.py
@@ -241,6 +241,7 @@ def __init__(
reset_mechanism="subtract",
state_quant=False,
output=False,
reset_delay=True,
):
super().__init__(
beta,
@@ -279,6 +280,11 @@ def __init__(
if not learn_recurrent:
self._disable_recurrent_grad()

self.reset_delay = reset_delay

if not self.reset_delay and self.init_hidden:
raise NotImplementedError("`reset_delay=False` is only supported for `init_hidden=False`")

if self.init_hidden:
self.spk, self.mem = self.init_rleaky()
# self.state_fn = self._build_state_function_hidden
@@ -312,6 +318,13 @@ def forward(self, input_, spk=False, mem=False):
else:
spk = self.fire(mem)

if not self.reset_delay:
do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset
if self.reset_mechanism_val == 0: # reset by subtraction
mem = mem - do_reset * self.threshold
elif self.reset_mechanism_val == 1: # reset to zero
mem = mem - do_reset * mem

return spk, mem

# intended for truncated-BPTT where instance variables are hidden
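The `do_reset` expression is identical in all four neurons, so it is worth unpacking once. Dividing `spk` by `graded_spikes_factor` recovers a 0/1 spike indicator from a graded spike, and subtracting `self.reset` (the reset flag the base class computed for this step) is what the inline comment means by avoiding a double reset. A numeric sketch, assuming those flag semantics:

```python
# Assumed semantics, illustrated with plain floats:
graded_spikes_factor = 2.0
spk = 2.0    # graded spike emitted this step
reset = 0.0  # base-class reset flag: no reset applied yet

do_reset = spk / graded_spikes_factor - reset  # 1.0 -> apply one reset

# Reset by subtraction (reset_mechanism_val == 0), threshold = 1.0:
mem = 1.3
mem = mem - do_reset * 1.0  # membrane drops by one threshold, to 0.3

# Had the base class already reset this step (reset = 1.0),
# do_reset would be 0.0 and mem would be left unchanged.
```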
15 changes: 15 additions & 0 deletions snntorch/_neurons/rsynaptic.py
@@ -254,6 +254,7 @@ def __init__(
reset_mechanism="subtract",
state_quant=False,
output=False,
reset_delay=True,
):
super().__init__(
beta,
@@ -294,6 +295,11 @@

self._alpha_register_buffer(alpha, learn_alpha)

self.reset_delay = reset_delay

if not reset_delay and self.init_hidden:
raise NotImplementedError("`reset_delay=False` is only supported for `init_hidden=False`")

if self.init_hidden:
self.spk, self.syn, self.mem = self.init_rsynaptic()

@@ -324,6 +330,15 @@ def forward(self, input_, spk=False, syn=False, mem=False):
else:
spk = self.fire(mem)

if not self.reset_delay:
# reset membrane potential _right_ after spike
do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset
if self.reset_mechanism_val == 0: # reset by subtraction
mem = mem - do_reset * self.threshold
elif self.reset_mechanism_val == 1: # reset to zero
mem = mem - do_reset * mem

return spk, syn, mem

# intended for truncated-BPTT where instance variables are hidden
14 changes: 14 additions & 0 deletions snntorch/_neurons/synaptic.py
@@ -168,6 +168,7 @@ def __init__(
reset_mechanism="subtract",
state_quant=False,
output=False,
reset_delay=True,
):
super().__init__(
beta,
@@ -185,6 +186,11 @@

self._alpha_register_buffer(alpha, learn_alpha)

self.reset_delay = reset_delay

if not reset_delay and self.init_hidden:
raise NotImplementedError("`reset_delay=False` is only supported for `init_hidden=False`")

if self.init_hidden:
self.syn, self.mem = self.init_synaptic()

@@ -214,6 +220,14 @@ def forward(self, input_, syn=False, mem=False):
else:
spk = self.fire(mem)

if not self.reset_delay:
# reset membrane potential _right_ after spike
do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset
if self.reset_mechanism_val == 0: # reset by subtraction
mem = mem - do_reset * self.threshold
elif self.reset_mechanism_val == 1: # reset to zero
mem = mem - do_reset * mem

return spk, syn, mem

# intended for truncated-BPTT where instance variables are
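Since the new check forbids `reset_delay=False` together with `init_hidden=True`, both the synaptic current `syn` and the membrane potential `mem` must be threaded through each call explicitly. A usage sketch for `Synaptic` under that assumption (input values are arbitrary):

```python
import torch
import snntorch as snn

# Immediate reset requires managing the state tuple by hand.
syn_lif = snn.Synaptic(alpha=0.8, beta=0.9, reset_delay=False)
syn, mem = syn_lif.init_synaptic()

for step in range(5):
    spk, syn, mem = syn_lif(torch.tensor([0.5]), syn, mem)
    print(step, spk.item(), mem.item())
```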
86 changes: 0 additions & 86 deletions snntorch/export.py

This file was deleted.

