Add synaptic fullgraph support
gekkom committed Feb 14, 2024
1 parent 7b5fdca commit 583a860
Showing 2 changed files with 107 additions and 118 deletions.
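Note: this commit replaces Synaptic's _SpikeTensor/_SpikeTorchConv first-pass initialization with registered buffers and a reset-mechanism dispatch chosen once in __init__, so that torch.compile can trace forward() as a single graph. A minimal usage sketch follows; the alpha/beta values and input shape are illustrative and not taken from the commit.

import torch
import snntorch as snn

# Illustrative decay constants and shapes -- any valid values behave the same way.
lif = snn.Synaptic(alpha=0.9, beta=0.8, surrogate_disable=True)
compiled = torch.compile(lif, fullgraph=True)  # fullgraph=True raises on any graph break

x = torch.rand(4, 10)          # (batch, features)
spk, syn, mem = compiled(x)    # hidden state is now held in registered buffers
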
215 changes: 97 additions & 118 deletions snntorch/_neurons/synaptic.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
-from .neurons import _SpikeTensor, _SpikeTorchConv, LIF
+from .neurons import LIF


class Synaptic(LIF):
@@ -186,132 +186,111 @@ def __init__(

        self._alpha_register_buffer(alpha, learn_alpha)

+        self._init_mem()
+
+        if self.reset_mechanism_val == 0:  # reset by subtraction
+            self.state_function = self._base_sub
+        elif self.reset_mechanism_val == 1:  # reset to zero
+            self.state_function = self._base_zero
+        elif self.reset_mechanism_val == 2:  # no reset, pure integration
+            self.state_function = self._base_int
+
        self.reset_delay = reset_delay

        if not reset_delay and self.init_hidden:
-            raise NotImplementedError('no reset_delay only supported for init_hidden=False')
+            raise NotImplementedError(
+                "no reset_delay only supported for init_hidden=False"
+            )

-        if self.init_hidden:
-            self.syn, self.mem = self.init_synaptic()
-
-    def forward(self, input_, syn=False, mem=False):
-
-        if hasattr(syn, "init_flag") or hasattr(
-            mem, "init_flag"
-        ):  # only triggered on first-pass
-            syn, mem = _SpikeTorchConv(syn, mem, input_=input_)
-        elif mem is False and hasattr(
-            self.mem, "init_flag"
-        ):  # init_hidden case
-            self.syn, self.mem = _SpikeTorchConv(
-                self.syn, self.mem, input_=input_
-            )
-
-        if not self.init_hidden:
-            self.reset = self.mem_reset(mem)
-            syn, mem = self._build_state_function(input_, syn, mem)
-
-            if self.state_quant:
-                syn = self.state_quant(syn)
-                mem = self.state_quant(mem)
-
-            if self.inhibition:
-                spk = self.fire_inhibition(mem.size(0), mem)
-            else:
-                spk = self.fire(mem)
-
-            if not self.reset_delay:
-                # reset membrane potential _right_ after spike
-                do_reset = spk / self.graded_spikes_factor - self.reset  # avoid double reset
-                if self.reset_mechanism_val == 0:  # reset by subtraction
-                    mem = mem - do_reset * self.threshold
-                elif self.reset_mechanism_val == 1:  # reset to zero
-                    mem = mem - do_reset * mem
-
-            return spk, syn, mem
-
-        # intended for truncated-BPTT where instance variables are
-        # hidden states
-        if self.init_hidden:
-            self._synaptic_forward_cases(mem, syn)
-            self.reset = self.mem_reset(self.mem)
-            self.syn, self.mem = self._build_state_function_hidden(input_)
-
-            if self.state_quant:
-                self.syn = self.state_quant(self.syn)
-                self.mem = self.state_quant(self.mem)
-
-            if self.inhibition:
-                self.spk = self.fire_inhibition(self.mem.size(0), self.mem)
-            else:
-                self.spk = self.fire(self.mem)
-
-            if self.output:
-                return self.spk, self.syn, self.mem
-            else:
-                return self.spk
-
-    def _base_state_function(self, input_, syn, mem):
-        base_fn_syn = self.alpha.clamp(0, 1) * syn + input_
-        base_fn_mem = self.beta.clamp(0, 1) * mem + base_fn_syn
-        return base_fn_syn, base_fn_mem
-
-    def _base_state_reset_zero(self, input_, syn, mem):
-        base_fn_syn = self.alpha.clamp(0, 1) * syn + input_
-        base_fn_mem = self.beta.clamp(0, 1) * mem + base_fn_syn
-        return 0, base_fn_mem
-
-    def _build_state_function(self, input_, syn, mem):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function(input_, syn, mem),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function(input_, syn, mem),
-                    self._base_state_reset_zero(input_, syn, mem),
-                )
-            )
-        elif self.reset_mechanism_val == 2:  # no reset, pure integration
-            state_fn = self._base_state_function(input_, syn, mem)
-        return state_fn
-
-    def _base_state_function_hidden(self, input_):
-        base_fn_syn = self.alpha.clamp(0, 1) * self.syn + input_
-        base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
-        return base_fn_syn, base_fn_mem
-
-    def _base_state_reset_zero_hidden(self, input_):
-        base_fn_syn = self.alpha.clamp(0, 1) * self.syn + input_
-        base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
-        return 0, base_fn_mem
-
-    def _build_state_function_hidden(self, input_):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function_hidden(input_),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function_hidden(input_),
-                    self._base_state_reset_zero_hidden(input_),
-                )
-            )
-        elif self.reset_mechanism_val == 2:  # no reset, pure integration
-            state_fn = self._base_state_function_hidden(input_)
-        return state_fn
+    def _init_mem(self):
+        syn = torch.zeros(1)
+        mem = torch.zeros(1)
+        self.register_buffer("syn", syn)
+        self.register_buffer("mem", mem)
+
+    def reset_mem(self):
+        self.syn = torch.zeros_like(self.syn, device=self.syn.device)
+        self.mem = torch.zeros_like(self.mem, device=self.mem.device)
+
+    def init_synaptic(self):
+        """Deprecated, use :class:`Synaptic.reset_mem` instead"""
+        self.reset_mem()
+        return self.syn, self.mem
+
+    def forward(self, input_, syn=None, mem=None):
+
+        if syn is not None:
+            self.syn = syn
+
+        if mem is not None:
+            self.mem = mem
+
+        if self.init_hidden and (mem is not None or syn is not None):
+            raise TypeError(
+                "`mem` or `syn` should not be passed as an argument while `init_hidden=True`"
+            )
+
+        if not self.syn.shape == input_.shape:
+            self.syn = torch.zeros_like(input_, device=self.syn.device)
+
+        if not self.mem.shape == input_.shape:
+            self.mem = torch.zeros_like(input_, device=self.mem.device)
+
+        self.reset = self.mem_reset(self.mem)
+        self.syn, self.mem = self.state_function(input_)
+
+        if self.state_quant:
+            self.mem = self.state_quant(self.mem)
+            self.syn = self.state_quant(self.syn)
+
+        if self.inhibition:
+            spk = self.fire_inhibition(
+                self.mem.size(0), self.mem
+            )  # batch_size
+        else:
+            spk = self.fire(self.mem)
+
+        if not self.reset_delay:
+            # reset membrane potential _right_ after spike
+            do_reset = (
+                spk / self.graded_spikes_factor - self.reset
+            )  # avoid double reset
+            if self.reset_mechanism_val == 0:  # reset by subtraction
+                self.mem = self.mem - do_reset * self.threshold
+            elif self.reset_mechanism_val == 1:  # reset to zero
+                self.mem = self.mem - do_reset * self.mem
+
+        if self.output:
+            return spk, self.syn, self.mem
+        elif self.init_hidden:
+            return spk
+        else:
+            return spk, self.syn, self.mem
+
+    def _base_state_function(self, input_):
+        base_fn_syn = self.alpha.clamp(0, 1) * self.syn + input_
+        base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
+        return base_fn_syn, base_fn_mem
+
+    def _base_state_reset_zero(self, input_):
+        base_fn_syn = self.alpha.clamp(0, 1) * self.syn + input_
+        base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
+        return 0, base_fn_mem
+
+    def _base_sub(self, input_):
+        syn, mem = self._base_state_function(input_)
+        mem = mem - self.reset * self.threshold
+        return syn, mem
+
+    def _base_zero(self, input_):
+        syn, mem = self._base_state_function(input_)
+        syn2, mem2 = self._base_state_reset_zero(input_)
+        syn -= syn2 * self.reset
+        mem -= mem2 * self.reset
+        return syn, mem
+
+    def _base_int(self, input_):
+        return self._base_state_function(input_)

def _alpha_register_buffer(self, alpha, learn_alpha):
if not isinstance(alpha, torch.Tensor):
@@ -321,12 +300,6 @@ def _alpha_register_buffer(self, alpha, learn_alpha):
else:
self.register_buffer("alpha", alpha)

-    def _synaptic_forward_cases(self, mem, syn):
-        if mem is not False or syn is not False:
-            raise TypeError(
-                "When `init_hidden=True`, Synaptic expects 1 input argument."
-            )

@classmethod
def detach_hidden(cls):
"""Returns the hidden states, detached from the current graph.
@@ -346,5 +319,11 @@ def reset_hidden(cls):

for layer in range(len(cls.instances)):
if isinstance(cls.instances[layer], Synaptic):
-                cls.instances[layer].syn = _SpikeTensor(init_flag=False)
-                cls.instances[layer].mem = _SpikeTensor(init_flag=False)
+                cls.instances[layer].syn = torch.zeros_like(
+                    cls.instances[layer].syn,
+                    device=cls.instances[layer].syn.device,
+                )
+                cls.instances[layer].mem = torch.zeros_like(
+                    cls.instances[layer].mem,
+                    device=cls.instances[layer].mem.device,
+                )
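With the state held in buffers, init_synaptic() survives only as a deprecated wrapper; the intended way to clear state between sequences is reset_mem(). A short sketch of the resulting loop, assuming an illustrative (time, batch, features) input:

import torch
import snntorch as snn

lif = snn.Synaptic(alpha=0.9, beta=0.8)   # illustrative decay constants

data = torch.rand(25, 4, 10)              # (time, batch, features) -- assumed shape
lif.reset_mem()                           # zero the syn/mem buffers before a new sequence

for step in range(data.size(0)):
    spk, syn, mem = lif(data[step])       # buffers carry syn/mem across time steps
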
10 changes: 10 additions & 0 deletions tests/test_snntorch/test_synaptic.py
@@ -5,6 +5,7 @@
import pytest
import snntorch as snn
import torch
+import torch._dynamo as dynamo


@pytest.fixture(scope="module")
@@ -16,6 +17,10 @@ def input_():
def synaptic_instance():
return snn.Synaptic(alpha=0.5, beta=0.5)

+@pytest.fixture(scope="module")
+def synaptic_instance_surrogate():
+    return snn.Synaptic(alpha=0.5, beta=0.5, surrogate_disable=True)


@pytest.fixture(scope="module")
def synaptic_reset_zero_instance():
@@ -123,3 +128,8 @@ def test_synaptic_init_hidden_reset_none(
def test_synaptic_cases(self, synaptic_hidden_instance, input_):
with pytest.raises(TypeError):
synaptic_hidden_instance(input_, input_)

+    def test_synaptic_compile_fullgraph(self, synaptic_instance_surrogate, input_):
+        explanation = dynamo.explain(synaptic_instance_surrogate)(input_[0])
+
+        assert explanation.graph_break_count == 0
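The same check can be run outside pytest to inspect why a model falls back to eager; the input shape below is assumed for illustration.

import torch
import torch._dynamo as dynamo
import snntorch as snn

lif = snn.Synaptic(alpha=0.5, beta=0.5, surrogate_disable=True)
explanation = dynamo.explain(lif)(torch.rand(1, 10))  # assumed (batch, features) input

print(explanation.graph_break_count)  # 0 when the layer compiles as a full graph
print(explanation.break_reasons)      # empty when nothing breaks the graph
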
