Update NIR <> snnTorch #246

Merged: 34 commits, Feb 5, 2024

Commits
a2d05dd
start import functionality for NIR
stevenabreu7 Oct 9, 2023
f592cc3
add support for RLeaky and RSynaptic
stevenabreu7 Oct 10, 2023
7ef2e5d
minor fix
stevenabreu7 Oct 10, 2023
7701cc0
minor fix #2
stevenabreu7 Oct 10, 2023
d3f1a8c
rename export + fix subgraph
stevenabreu7 Oct 12, 2023
d2755ff
remove hack to rename subgraph edges
stevenabreu7 Oct 12, 2023
e2e92c3
make import and export work for RNN
stevenabreu7 Oct 12, 2023
d3ee326
import & export using nirtorch (instead of manual)
stevenabreu7 Oct 12, 2023
14767e5
RNN nirtorch export/import works but still buggy
stevenabreu7 Oct 13, 2023
708faa2
minor
stevenabreu7 Oct 13, 2023
9004cda
update NIRTorch import/export (not done)
stevenabreu7 Oct 16, 2023
a9b7cee
version for braille-v2
stevenabreu7 Oct 17, 2023
0c7f692
update to latest support! (use init_hidden=False)
stevenabreu7 Oct 19, 2023
2ada175
rm dead code
stevenabreu7 Oct 19, 2023
0d45ad1
adapt NIR-standard thresholding for (r)synaptic
stevenabreu7 Oct 23, 2023
cb684c6
rename reset_after -> reset_delay (+ invert)
stevenabreu7 Oct 23, 2023
4352c60
add conv/if/pool to import
stevenabreu7 Oct 24, 2023
a204457
fix reset_delay (+ add for (r)leaky)
stevenabreu7 Oct 24, 2023
f5ef7f4
bias bug fix
stevenabreu7 Nov 14, 2023
8c7c786
move to using nirtorch
stevenabreu7 Jan 27, 2024
f962f3c
minor changes
stevenabreu7 Jan 27, 2024
208c87f
+ Leaky export
stevenabreu7 Jan 27, 2024
d831179
remove old files
stevenabreu7 Jan 27, 2024
38f19cd
add docstrings
stevenabreu7 Feb 3, 2024
04b1d02
rename files
stevenabreu7 Feb 3, 2024
bdb15a3
test suggestions
stevenabreu7 Feb 3, 2024
3710dfb
leaky syntax change
jeshraghian Feb 4, 2024
35e6da0
Merge remote-tracking branch 'upstream/master' into nir
jeshraghian Feb 4, 2024
9adbcef
solve conflicts for updated leaky neuron forward-pass
jeshraghian Feb 4, 2024
eec1045
membrane potential init bug fix
jeshraghian Feb 4, 2024
b6dc7d9
rename again (import.py is invalid)
stevenabreu7 Feb 5, 2024
11684cf
add lif nir graph
stevenabreu7 Feb 5, 2024
d9df668
tests
stevenabreu7 Feb 5, 2024
40ad136
missing features into docstring
stevenabreu7 Feb 5, 2024
2 changes: 2 additions & 0 deletions snntorch/__init__.py
@@ -1,3 +1,5 @@
from ._version import __version__
from ._neurons import *
from ._layers import *
from .export_nir import export_to_nir
from .import_nir import import_from_nir
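As a usage note: the two new top-level entry points round-trip a network through the Neuromorphic Intermediate Representation (NIR). Below is a minimal sketch; only the names export_to_nir and import_from_nir come from this diff, while the sample-input argument, layer sizes, and variable names are illustrative assumptions.

import torch
import snntorch as snn
from snntorch import export_to_nir, import_from_nir

# small feed-forward spiking network: Linear -> Leaky
net = torch.nn.Sequential(
    torch.nn.Linear(784, 100),
    snn.Leaky(beta=0.9, init_hidden=True),
)

# trace the module with a sample input and obtain a NIR graph
sample_data = torch.randn(1, 784)
nir_graph = export_to_nir(net, sample_data)

# rebuild an equivalent torch module from the NIR graph
net_reloaded = import_from_nir(nir_graph)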
31 changes: 27 additions & 4 deletions snntorch/_neurons/leaky.py
@@ -106,6 +106,15 @@ def forward(self, x, mem1, spk1, mem2):
returned when neuron is called. Defaults to False
:type output: bool, optional

:param graded_spikes_factor: output spikes are scaled by this value, if specified. Defaults to 1.0
:type graded_spikes_factor: float or torch.tensor

:param learn_graded_spikes_factor: Option to enable learnable graded spikes. Defaults to False
:type learn_graded_spikes_factor: bool, optional

:param reset_delay: If `True`, the membrane potential is reset with a one-step delay after the threshold is reached.
Defaults to True
:type reset_delay: bool, optional

Inputs: \\input_, mem_0
- **input_** of shape `(batch, input_size)`: tensor containing input
@@ -142,6 +151,7 @@ def __init__(
output=False,
graded_spikes_factor=1.0,
learn_graded_spikes_factor=False,
reset_delay=True,
Review comment from @stevenabreu7 (Collaborator, Author), Jan 27, 2024:
as a clarifying comment: @jeshraghian and I discussed this a while ago - this flag is needed to match up the neuron dynamics with NIR. (with the default value, snntorch will behave exactly as before)

Reply from the repository Owner:
The latest version of the Leaky neuron had some minor refactoring to get it compatible with torch.compile(), so there will be some conflicts. I'll iron these out today.

):
super().__init__(
beta,
@@ -169,6 +179,11 @@ def __init__(
elif self.reset_mechanism_val == 2: # no reset, pure integration
self.state_function = self._base_int

self.reset_delay = reset_delay

if not self.reset_delay and self.init_hidden:
raise NotImplementedError("`reset_delay=True` is only supported for `init_hidden=False`")


def _init_mem(self):
mem = torch.zeros(1)
@@ -178,17 +193,18 @@ def reset_mem(self):
self.mem = torch.zeros_like(self.mem, device=self.mem.device)

def init_leaky(self):
"""Deprecated, please use :class:`Leaky.reset_mem` instead"""
"""Deprecated, use :class:`Leaky.reset_mem` instead"""
self.reset_mem()
return self.mem

def forward(self, input_, mem=None):

if not mem == None:
self.mem = mem

if self.init_hidden and not mem == None:
raise TypeError(
"mem should not be passed as an argument while `init_hidden=True`"
"`mem` should not be passed as an argument while `init_hidden=True`"
)

if not self.mem.shape == input_.shape:
@@ -201,9 +217,16 @@ def forward(self, input_, mem=None):
self.mem = self.state_quant(self.mem)

if self.inhibition:
spk = self.fire_inhibition(self.mem.size(0), self.mem)
spk = self.fire_inhibition(self.mem.size(0), self.mem) # batch_size
else:
spk = self.fire(self.mem)

if not self.reset_delay:
do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset
if self.reset_mechanism_val == 0: # reset by subtraction
self.mem = self.mem - do_reset * self.threshold
elif self.reset_mechanism_val == 1: # reset to zero
self.mem = self.mem - do_reset * self.mem

if self.output:
return spk, self.mem
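To make the new flag concrete, here is a rough sketch contrasting the default delayed reset with reset_delay=False on a plain Leaky neuron (init_hidden=False, as required by the check above). The beta value, constant input, and loop length are illustrative assumptions, not part of the diff.

import torch
import snntorch as snn

lif_delayed = snn.Leaky(beta=0.9)                       # default: reset_delay=True
lif_immediate = snn.Leaky(beta=0.9, reset_delay=False)  # reset applied in the same step

x = torch.tensor([0.6])   # constant drive, crosses the default threshold of 1.0 within a few steps
mem_d = torch.zeros(1)
mem_i = torch.zeros(1)

for t in range(5):
    spk_d, mem_d = lif_delayed(x, mem_d)
    spk_i, mem_i = lif_immediate(x, mem_i)
    # on the spiking step, mem_i has already been reset (subtract / zero),
    # while mem_d keeps the supra-threshold value until the next step
    print(t, spk_d.item(), mem_d.item(), spk_i.item(), mem_i.item())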
13 changes: 13 additions & 0 deletions snntorch/_neurons/rleaky.py
@@ -241,6 +241,7 @@ def __init__(
reset_mechanism="subtract",
state_quant=False,
output=False,
reset_delay=True,
):
super().__init__(
beta,
@@ -279,6 +280,11 @@ def __init__(
if not learn_recurrent:
self._disable_recurrent_grad()

self.reset_delay = reset_delay

if not self.reset_delay and self.init_hidden:
raise NotImplementedError('`reset_delay=False` is only supported for `init_hidden=False`')

if self.init_hidden:
self.spk, self.mem = self.init_rleaky()
# self.state_fn = self._build_state_function_hidden
@@ -312,6 +318,13 @@ def forward(self, input_, spk=False, mem=False):
else:
spk = self.fire(mem)

if not self.reset_delay:
do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset
if self.reset_mechanism_val == 0: # reset by subtraction
mem = mem - do_reset * self.threshold
elif self.reset_mechanism_val == 1: # reset to zero
mem = mem - do_reset * mem

return spk, mem

# intended for truncated-BPTT where instance variables are hidden
15 changes: 15 additions & 0 deletions snntorch/_neurons/rsynaptic.py
@@ -254,6 +254,7 @@ def __init__(
reset_mechanism="subtract",
state_quant=False,
output=False,
reset_delay=True,
):
super().__init__(
beta,
@@ -294,6 +295,11 @@ def __init__(

self._alpha_register_buffer(alpha, learn_alpha)

self.reset_delay = reset_delay

if not reset_delay and self.init_hidden:
raise NotImplementedError('`reset_delay=False` is only supported for `init_hidden=False`')

if self.init_hidden:
self.spk, self.syn, self.mem = self.init_rsynaptic()

@@ -324,6 +330,15 @@ def forward(self, input_, spk=False, syn=False, mem=False):
else:
spk = self.fire(mem)

if not self.reset_delay:
# reset membrane potential _right_ after spike
do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset
if self.reset_mechanism_val == 0: # reset by subtraction
mem = mem - do_reset * self.threshold
elif self.reset_mechanism_val == 1: # reset to zero
# mem -= do_reset * mem
mem = mem - do_reset * mem

return spk, syn, mem

# intended for truncated-BPTT where instance variables are hidden
14 changes: 14 additions & 0 deletions snntorch/_neurons/synaptic.py
@@ -168,6 +168,7 @@ def __init__(
reset_mechanism="subtract",
state_quant=False,
output=False,
reset_delay=True,
):
super().__init__(
beta,
@@ -185,6 +186,11 @@

self._alpha_register_buffer(alpha, learn_alpha)

self.reset_delay = reset_delay

if not reset_delay and self.init_hidden:
raise NotImplementedError('`reset_delay=False` is only supported for `init_hidden=False`')

if self.init_hidden:
self.syn, self.mem = self.init_synaptic()

@@ -214,6 +220,14 @@ def forward(self, input_, syn=False, mem=False):
else:
spk = self.fire(mem)

if not self.reset_delay:
# reset membrane potential _right_ after spike
do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset
if self.reset_mechanism_val == 0: # reset by subtraction
mem = mem - do_reset * self.threshold
elif self.reset_mechanism_val == 1: # reset to zero
mem = mem - do_reset * mem

return spk, syn, mem

# intended for truncated-BPTT where instance variables are
86 changes: 0 additions & 86 deletions snntorch/export.py

This file was deleted.
