From afac824865bbe4687e597e287e75b7a563f1c009 Mon Sep 17 00:00:00 2001 From: visdauas Date: Wed, 14 Feb 2024 13:30:05 +0100 Subject: [PATCH 01/17] Move from math to torch in ATan --- snntorch/surrogate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snntorch/surrogate.py b/snntorch/surrogate.py index 390eb1bd..31914e93 100644 --- a/snntorch/surrogate.py +++ b/snntorch/surrogate.py @@ -197,7 +197,7 @@ def backward(ctx, grad_output): grad = ( ctx.alpha / 2 - / (1 + (math.pi / 2 * ctx.alpha * input_).pow_(2)) + / (1 + (torch.pi / 2 * ctx.alpha * input_).pow_(2)) * grad_input ) return grad, None From 73c347adf9ad32e3252f5fb7af944b948b19dd4d Mon Sep 17 00:00:00 2001 From: visdauas Date: Wed, 14 Feb 2024 13:31:26 +0100 Subject: [PATCH 02/17] Finalize leaky fullgraph support --- snntorch/_neurons/leaky.py | 23 ++++++++++++++--------- tests/test_snntorch/test_leaky.py | 16 +++++++++++++++- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/snntorch/_neurons/leaky.py b/snntorch/_neurons/leaky.py index 7406fa82..41166462 100644 --- a/snntorch/_neurons/leaky.py +++ b/snntorch/_neurons/leaky.py @@ -2,6 +2,7 @@ import torch from torch import nn + class Leaky(LIF): """ First-order leaky integrate-and-fire neuron model. @@ -170,7 +171,6 @@ def __init__( ) self._init_mem() - self.init_hidden = init_hidden if self.reset_mechanism_val == 0: # reset by subtraction self.state_function = self._base_sub @@ -178,12 +178,13 @@ def __init__( self.state_function = self._base_zero elif self.reset_mechanism_val == 2: # no reset, pure integration self.state_function = self._base_int - + self.reset_delay = reset_delay if not self.reset_delay and self.init_hidden: - raise NotImplementedError("`reset_delay=True` is only supported for `init_hidden=False`") - + raise NotImplementedError( + "`reset_delay=True` is only supported for `init_hidden=False`" + ) def _init_mem(self): mem = torch.zeros(1) @@ -196,12 +197,12 @@ def init_leaky(self): """Deprecated, use :class:`Leaky.reset_mem` instead""" self.reset_mem() return self.mem - + def forward(self, input_, mem=None): if not mem == None: self.mem = mem - + if self.init_hidden and not mem == None: raise TypeError( "`mem` should not be passed as an argument while `init_hidden=True`" @@ -217,12 +218,16 @@ def forward(self, input_, mem=None): self.mem = self.state_quant(self.mem) if self.inhibition: - spk = self.fire_inhibition(self.mem.size(0), self.mem) # batch_size + spk = self.fire_inhibition( + self.mem.size(0), self.mem + ) # batch_size else: spk = self.fire(self.mem) - + if not self.reset_delay: - do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset + do_reset = ( + spk / self.graded_spikes_factor - self.reset + ) # avoid double reset if self.reset_mechanism_val == 0: # reset by subtraction self.mem = self.mem - do_reset * self.threshold elif self.reset_mechanism_val == 1: # reset to zero diff --git a/tests/test_snntorch/test_leaky.py b/tests/test_snntorch/test_leaky.py index a04ca744..30622bce 100644 --- a/tests/test_snntorch/test_leaky.py +++ b/tests/test_snntorch/test_leaky.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -17,6 +18,11 @@ def leaky_instance(): return snn.Leaky(beta=0.5) +@pytest.fixture(scope="module") +def leaky_instance_surrogate(): + return snn.Leaky(beta=0.5, surrogate_disable=True) + + @pytest.fixture(scope="module") def leaky_reset_zero_instance(): return snn.Leaky(beta=0.5, reset_mechanism="zero") @@ 
-126,8 +132,16 @@ def test_leaky_cases(self, leaky_hidden_instance, input_): leaky_hidden_instance(input_, input_) def test_leaky_hidden_learn_graded_instance( - self, leaky_hidden_learn_graded_instance + self, leaky_hidden_learn_graded_instance ): factor = leaky_hidden_learn_graded_instance.graded_spikes_factor assert factor.requires_grad + + def test_leaky_compile_fullgraph(self, leaky_instance_surrogate, input_): + # net = nn.Sequential( + # snn.Leaky(beta=0.5, init_hidden=True, surrogate_disable=True), + # ) + + explanation = dynamo.explain(leaky_instance_surrogate)(input_[0]) + assert explanation.graph_break_count == 0 From cce3134bb4c221869aa0952551f2099f30507f33 Mon Sep 17 00:00:00 2001 From: visdauas Date: Wed, 14 Feb 2024 13:31:55 +0100 Subject: [PATCH 03/17] Add alpha fullgraph support --- snntorch/_neurons/alpha.py | 223 +++++++++++++----------------- tests/test_snntorch/test_alpha.py | 11 ++ 2 files changed, 109 insertions(+), 125 deletions(-) diff --git a/snntorch/_neurons/alpha.py b/snntorch/_neurons/alpha.py index 2d339b87..4b0b0c91 100644 --- a/snntorch/_neurons/alpha.py +++ b/snntorch/_neurons/alpha.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +from .neurons import LIF class Alpha(LIF): @@ -119,118 +119,86 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) self._alpha_cases() - if self.init_hidden: - self.syn_exc, self.syn_inh, self.mem = self.init_alpha() + self._init_mem() - # if reset_mechanism == "subtract": - # self.mem_residual = False + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int - def forward(self, input_, syn_exc=False, syn_inh=False, mem=False): + def _init_mem(self): + syn_exc = torch.zeros(1) + syn_inh = torch.zeros(1) + mem = torch.zeros(1) - if ( - hasattr(syn_exc, "init_flag") - or hasattr(syn_inh, "init_flag") - or hasattr(mem, "init_flag") - ): # only triggered on first-pass - syn_exc, syn_inh, mem = _SpikeTorchConv( - syn_exc, syn_inh, mem, input_=input_ - ) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.syn_exc, self.syn_inh, self.mem = _SpikeTorchConv( - self.syn_exc, self.syn_inh, self.mem, input_=input_ - ) + self.register_buffer("syn_exc", syn_exc) + self.register_buffer("syn_inh", syn_inh) + self.register_buffer("mem", mem) - # if hidden states are passed externally - if not self.init_hidden: - self.reset = self.mem_reset(mem) - syn_exc, syn_inh, mem = self._build_state_function( - input_, syn_exc, syn_inh, mem - ) + def reset_mem(self): + self.syn_exc = torch.zeros_like( + self.syn_exc, device=self.syn_exc.device + ) + self.syn_inh = torch.zeros_like( + self.syn_inh, device=self.syn_inh.device + ) + self.mem = torch.zeros_like(self.mem, device=self.mem.device) - if self.state_quant: - syn_exc = self.state_quant(syn_exc) - syn_inh = self.state_quant(syn_inh) - mem = self.state_quant(mem) + def init_alpha(self): + """Deprecated, use :class:`Alpha.reset_mem` instead""" + self.reset_mem() + return self.syn_exc, self.syn_inh, self.mem - if self.inhibition: - spk = self.fire_inhibition(mem.size(0), mem) + def forward(self, input_, syn_exc=None, syn_inh=None, mem=None): - else: - spk = self.fire(mem) + if not syn_exc == None: + self.syn_exc = syn_exc - return spk, syn_exc, syn_inh, mem + 
if not syn_inh == None: + self.syn_inh = syn_inh - # if hidden states and outputs are instance variables - if self.init_hidden: - self._alpha_forward_cases(mem, syn_exc, syn_inh) + if not mem == None: + self.mem = mem - self.reset = self.mem_reset(self.mem) - ( - self.syn_exc, - self.syn_inh, - self.mem, - ) = self._build_state_function_hidden(input_) + if self.init_hidden and ( + not mem == None or not syn_exc == None or not syn_inh == None + ): + raise TypeError( + "When `init_hidden=True`, Alpha expects 1 input argument." + ) - if self.state_quant: - self.syn_exc = self.state_quant(self.syn_exc) - self.syn_inh = self.state_quant(self.syn_inh) - self.mem = self.state_quant(self.mem) + if not self.syn_exc.shape == input_.shape: + self.syn_exc = torch.zeros_like(input_, device=self.syn_exc.device) - if self.inhibition: - self.spk = self.fire_inhibition(self.mem.size(0), self.mem) + if not self.syn_inh.shape == input_.shape: + self.syn_inh = torch.zeros_like(input_, device=self.syn_inh.device) - else: - self.spk = self.fire(self.mem) + if not self.mem.shape == input_.shape: + self.mem = torch.zeros_like(input_, device=self.mem.device) - if self.output: - return self.spk, self.syn_exc, self.syn_inh, self.mem - else: - return self.spk + self.reset = self.mem_reset(self.mem) + self.syn_exc, self.syn_inh, self.mem = self.state_function(input_) - def _base_state_function(self, input_, syn_exc, syn_inh, mem): - base_fn_syn_exc = self.alpha.clamp(0, 1) * syn_exc + input_ - base_fn_syn_inh = self.beta.clamp(0, 1) * syn_inh - input_ - tau_alpha = ( - torch.log(self.alpha.clamp(0, 1)) - / ( - torch.log(self.beta.clamp(0, 1)) - - torch.log(self.alpha.clamp(0, 1)) - ) - + 1 - ) - base_fn_mem = tau_alpha * (base_fn_syn_exc + base_fn_syn_inh) - return base_fn_syn_exc, base_fn_syn_inh, base_fn_mem + if self.state_quant: + self.syn_exc = self.state_quant(self.syn_exc) + self.syn_inh = self.state_quant(self.syn_inh) + self.mem = self.state_quant(self.mem) - def _base_state_reset_sub_function(self, input_, syn_inh): - syn_exc_reset = self.threshold - syn_inh_reset = self.beta.clamp(0, 1) * syn_inh - input_ - mem_reset = 0 - return syn_exc_reset, syn_inh_reset, mem_reset + if self.inhibition: + spk = self.fire_inhibition(self.mem.size(0), self.mem) + else: + spk = self.fire(self.mem) - def _build_state_function(self, input_, syn_exc, syn_inh, mem): - if self.reset_mechanism_val == 0: - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function(input_, syn_exc, syn_inh, mem), - self._base_state_reset_sub_function(input_, syn_inh), - ) - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function(input_, syn_exc, syn_inh, mem), - self._base_state_function(input_, syn_exc, syn_inh, mem), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function(input_, syn_exc, syn_inh, mem) - return state_fn + if self.output: + return spk, self.syn_exc, self.syn_inh, self.mem + elif self.init_hidden: + return spk + else: + return spk, self.syn_exc, self.syn_inh, self.mem - def _base_state_function_hidden(self, input_): + def _base_state_function(self, input_): base_fn_syn_exc = self.alpha.clamp(0, 1) * self.syn_exc + input_ base_fn_syn_inh = self.beta.clamp(0, 1) * self.syn_inh - input_ tau_alpha = ( @@ -244,32 +212,34 @@ def _base_state_function_hidden(self, input_): base_fn_mem = tau_alpha * (base_fn_syn_exc + base_fn_syn_inh) return base_fn_syn_exc, 
base_fn_syn_inh, base_fn_mem - def _base_state_reset_sub_function_hidden(self, input_): + def _base_state_reset_sub_function(self, input_): syn_exc_reset = self.threshold syn_inh_reset = self.beta.clamp(0, 1) * self.syn_inh - input_ mem_reset = -self.syn_inh return syn_exc_reset, syn_inh_reset, mem_reset - def _build_state_function_hidden(self, input_): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function_hidden(input_), - self._base_state_reset_sub_function_hidden(input_), - ) - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function_hidden(input_), - self._base_state_function_hidden(input_), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn + def _base_sub(self, input_): + syn_exec, syn_inh, mem = self._base_state_function(input_) + syn_exec2, syn_inh2, mem2 = self._base_state_reset_sub_function(input_) + + syn_exec -= syn_exec2 * self.reset + syn_inh -= syn_inh2 * self.reset + mem -= mem2 * self.reset + + return syn_exec, syn_inh, mem + + def _base_zero(self, input_): + syn_exec, syn_inh, mem = self._base_state_function(input_) + syn_exec2, syn_inh2, mem2 = self._base_state_function(input_) + + syn_exec -= syn_exec2 * self.reset + syn_inh -= syn_inh2 * self.reset + mem -= mem2 * self.reset + + return syn_exec, syn_inh, mem + + def _base_int(self, input_): + return self._base_state_function(input_) def _alpha_register_buffer(self, alpha, learn_alpha): if not isinstance(alpha, torch.Tensor): @@ -291,12 +261,6 @@ def _alpha_cases(self): "tau_alpha = log(alpha)/log(beta) - log(alpha) + 1" ) - def _alpha_forward_cases(self, mem, syn_exc, syn_inh): - if mem is not False or syn_exc is not False or syn_inh is not False: - raise TypeError( - "When `init_hidden=True`, Alpha expects 1 input argument." - ) - @classmethod def detach_hidden(cls): """Used to detach hidden states from the current graph. 
@@ -315,6 +279,15 @@ def reset_hidden(cls): variables.""" for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], Alpha): - cls.instances[layer].syn_exc = _SpikeTensor(init_flag=False) - cls.instances[layer].syn_inh = _SpikeTensor(init_flag=False) - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].syn_exc = torch.zeros_like( + cls.instances[layer].syn_exc, + device=cls.instances[layer].syn_exc.device, + ) + cls.instances[layer].syn_inh = torch.zeros_like( + cls.instances[layer].syn_inh, + device=cls.instances[layer].syn_inh.device, + ) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) diff --git a/tests/test_snntorch/test_alpha.py b/tests/test_snntorch/test_alpha.py index e96f8c4c..29e59935 100644 --- a/tests/test_snntorch/test_alpha.py +++ b/tests/test_snntorch/test_alpha.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -16,6 +17,10 @@ def input_(): def alpha_instance(): return snn.Alpha(alpha=0.6, beta=0.5, reset_mechanism="subtract") +@pytest.fixture(scope="module") +def alpha_instance_surrogate(): + return snn.Alpha(alpha=0.6, beta=0.5, reset_mechanism="subtract", surrogate_disable=True) + @pytest.fixture(scope="module") def alpha_reset_zero_instance(): @@ -136,3 +141,9 @@ def test_alpha_init_hidden_reset_none( def test_alpha_cases(self, alpha_hidden_instance, input_): with pytest.raises(TypeError): alpha_hidden_instance(input_, input_) + + + def test_alpha_compile_fullgraph(self, alpha_instance_surrogate, input_): + explanation = dynamo.explain(alpha_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 \ No newline at end of file From 7b5fdcab325641f36a0a962308d68bad96d733bc Mon Sep 17 00:00:00 2001 From: visdauas Date: Wed, 14 Feb 2024 13:32:19 +0100 Subject: [PATCH 04/17] Add lapicque fullgraph support --- snntorch/_neurons/lapicque.py | 133 ++++++++++++--------------- tests/test_snntorch/test_lapicque.py | 11 +++ 2 files changed, 68 insertions(+), 76 deletions(-) diff --git a/snntorch/_neurons/lapicque.py b/snntorch/_neurons/lapicque.py index bacbfc52..cacec197 100644 --- a/snntorch/_neurons/lapicque.py +++ b/snntorch/_neurons/lapicque.py @@ -1,5 +1,5 @@ import torch -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +from .neurons import LIF class Lapicque(LIF): @@ -216,93 +216,77 @@ def __init__( self._lapicque_cases(time_step, beta, R, C) - if self.init_hidden: - self.mem = self.init_lapicque() + self._init_mem() - def forward(self, input_, mem=False): + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int - if hasattr(mem, "init_flag"): # only triggered on first-pass - mem = _SpikeTorchConv(mem, input_=input_) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.mem = _SpikeTorchConv(self.mem, input_=input_) + def _init_mem(self): + mem = torch.zeros(1) + self.register_buffer("mem", mem) - if not self.init_hidden: - self.reset = self.mem_reset(mem) - mem = self._build_state_function(input_, mem) + def reset_mem(self): + self.mem = torch.zeros_like(self.mem, device=self.mem.device) - if self.state_quant: - mem = self.state_quant(mem) + def init_lapicque(self): + 
"""Deprecated, use :class:`Lapicque.reset_mem` instead""" + self.reset_mem() + return self.mem - if self.inhibition: - spk = self.fire_inhibition(mem.size(0), mem) - else: - spk = self.fire(mem) + def forward(self, input_, mem=None): - return spk, mem + if not mem == None: + self.mem = mem - # intended for truncated-BPTT where instance variables are hidden - # states - if self.init_hidden: - self._lapicque_forward_cases(mem) - self.reset = self.mem_reset(self.mem) - self.mem = self._build_state_function_hidden(input_) + if self.init_hidden and not mem == None: + raise TypeError( + "`mem` should not be passed as an argument while `init_hidden=True`" + ) - if self.state_quant: - self.mem = self.state_quant(self.mem) + if not self.mem.shape == input_.shape: + self.mem = torch.zeros_like(input_, device=self.mem.device) - if self.inhibition: - self.spk = self.fire_inhibition(self.mem.size(0), self.mem) - else: - self.spk = self.fire(self.mem) + self.reset = self.mem_reset(self.mem) + self.mem = self.state_function(input_) - if self.output: - return self.spk, self.mem - else: - return self.spk + if self.state_quant: + self.mem = self.state_quant(self.mem) - def _base_state_function(self, input_, mem): - base_fn = ( - input_ * self.R * (1 / (self.R * self.C)) * self.time_step - + (1 - (self.time_step / (self.R * self.C))) * mem - ) - return base_fn + if self.inhibition: + spk = self.fire_inhibition( + self.mem.size(0), self.mem + ) # batch_size + else: + spk = self.fire(self.mem) - def _build_state_function(self, input_, mem): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = ( - self._base_state_function(input_, mem) - - self.reset * self.threshold - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = self._base_state_function( - input_, mem - ) - self.reset * self._base_state_function(input_, mem) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function(input_, mem) - return state_fn + if self.output: + return spk, self.mem + elif self.init_hidden: + return spk + else: + return spk, self.mem - def _base_state_function_hidden(self, input_): + def _base_state_function(self, input_): base_fn = ( input_ * self.R * (1 / (self.R * self.C)) * self.time_step + (1 - (self.time_step / (self.R * self.C))) * self.mem ) return base_fn - def _build_state_function_hidden(self, input_): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = ( - self._base_state_function_hidden(input_) - - self.reset * self.threshold - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = self._base_state_function_hidden( - input_ - ) - self.reset * self._base_state_function_hidden(input_) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn + def _base_sub(self, input_): + return self._base_state_function(input_) - self.reset * self.threshold + + def _base_zero(self, input_): + return self._base_state_function( + input_ + ) - self.reset * self._base_state_function(input_) + + def _base_int(self, input_): + return self._base_state_function(input_) def _lapicque_cases(self, time_step, beta, R, C): if not isinstance(time_step, torch.Tensor): @@ -357,12 +341,6 @@ def _lapicque_cases(self, time_step, beta, R, C): R = self.time_step / (C * torch.log(1 / self.beta)) self.register_buffer("R", R) - def _lapicque_forward_cases(self, mem): - if mem is not False: - raise TypeError( - "When `init_hidden=True`, Lapicque expects 1 input 
argument." - ) - @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. @@ -381,4 +359,7 @@ def reset_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], Lapicque): - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) diff --git a/tests/test_snntorch/test_lapicque.py b/tests/test_snntorch/test_lapicque.py index 36345975..f76dd232 100644 --- a/tests/test_snntorch/test_lapicque.py +++ b/tests/test_snntorch/test_lapicque.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -17,6 +18,11 @@ def lapicque_instance(): return snn.Lapicque(beta=0.5) +@pytest.fixture(scope="module") +def lapicque_instance_surrogate(): + return snn.Lapicque(beta=0.5, surrogate_disable=True) + + @pytest.fixture(scope="module") def lapicque_reset_zero_instance(): return snn.Lapicque(beta=0.5, reset_mechanism="zero") @@ -128,3 +134,8 @@ def test_lapicque_init_hidden_reset_none( def test_lapicque_cases(self, lapicque_hidden_instance, input_): with pytest.raises(TypeError): lapicque_hidden_instance(input_, input_) + + def test_lapicque_compile_fullgraph(self, lapicque_instance_surrogate, input_): + explanation = dynamo.explain(lapicque_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 From 583a860c15eea33ddd1fc4a46a6835a8341ee8cb Mon Sep 17 00:00:00 2001 From: visdauas Date: Wed, 14 Feb 2024 13:32:36 +0100 Subject: [PATCH 05/17] Add synaptic fullgraph support --- snntorch/_neurons/synaptic.py | 215 ++++++++++++--------------- tests/test_snntorch/test_synaptic.py | 10 ++ 2 files changed, 107 insertions(+), 118 deletions(-) diff --git a/snntorch/_neurons/synaptic.py b/snntorch/_neurons/synaptic.py index 4e3be032..2bf3ba07 100644 --- a/snntorch/_neurons/synaptic.py +++ b/snntorch/_neurons/synaptic.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +from .neurons import LIF class Synaptic(LIF): @@ -186,132 +186,111 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) + self._init_mem() + + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int + self.reset_delay = reset_delay if not reset_delay and self.init_hidden: - raise NotImplementedError('no reset_delay only supported for init_hidden=False') - - if self.init_hidden: - self.syn, self.mem = self.init_synaptic() - - def forward(self, input_, syn=False, mem=False): - - if hasattr(syn, "init_flag") or hasattr( - mem, "init_flag" - ): # only triggered on first-pass - syn, mem = _SpikeTorchConv(syn, mem, input_=input_) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.syn, self.mem = _SpikeTorchConv( - self.syn, self.mem, input_=input_ + raise NotImplementedError( + "no reset_delay only supported for init_hidden=False" ) - if not self.init_hidden: - self.reset = self.mem_reset(mem) - syn, mem = self._build_state_function(input_, syn, mem) - - if self.state_quant: - syn = self.state_quant(syn) - mem = self.state_quant(mem) - - if self.inhibition: - spk = self.fire_inhibition(mem.size(0), mem) - else: - spk = 
self.fire(mem)
-
-            if not self.reset_delay:
-                # reset membrane potential _right_ after spike
-                do_reset = spk / self.graded_spikes_factor - self.reset  # avoid double reset
-                if self.reset_mechanism_val == 0:  # reset by subtraction
-                    mem = mem - do_reset * self.threshold
-                elif self.reset_mechanism_val == 1:  # reset to zero
-                    mem = mem - do_reset * mem
-
-            return spk, syn, mem
-
-        # intended for truncated-BPTT where instance variables are
-        # hidden states
-        if self.init_hidden:
-            self._synaptic_forward_cases(mem, syn)
-            self.reset = self.mem_reset(self.mem)
-            self.syn, self.mem = self._build_state_function_hidden(input_)
-
-            if self.state_quant:
-                self.syn = self.state_quant(self.syn)
-                self.mem = self.state_quant(self.mem)
-
-            if self.inhibition:
-                self.spk = self.fire_inhibition(self.mem.size(0), self.mem)
-            else:
-                self.spk = self.fire(self.mem)
-
-            if self.output:
-                return self.spk, self.syn, self.mem
-            else:
-                return self.spk
-
-    def _base_state_function(self, input_, syn, mem):
-        base_fn_syn = self.alpha.clamp(0, 1) * syn + input_
-        base_fn_mem = self.beta.clamp(0, 1) * mem + base_fn_syn
-        return base_fn_syn, base_fn_mem
+    def _init_mem(self):
+        syn = torch.zeros(1)
+        mem = torch.zeros(1)
+        self.register_buffer("syn", syn)
+        self.register_buffer("mem", mem)
 
-    def _base_state_reset_zero(self, input_, syn, mem):
-        base_fn_syn = self.alpha.clamp(0, 1) * syn + input_
-        base_fn_mem = self.beta.clamp(0, 1) * mem + base_fn_syn
-        return 0, base_fn_mem
+    def reset_mem(self):
+        self.syn = torch.zeros_like(self.syn, device=self.syn.device)
+        self.mem = torch.zeros_like(self.mem, device=self.mem.device)
 
-    def _build_state_function(self, input_, syn, mem):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function(input_, syn, mem),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function(input_, syn, mem),
-                    self._base_state_reset_zero(input_, syn, mem),
-                )
+    def init_synaptic(self):
+        """Deprecated, use :class:`Synaptic.reset_mem` instead"""
+        self.reset_mem()
+        return self.syn, self.mem
+
+    def forward(self, input_, syn=None, mem=None):
+
+        if not syn == None:
+            self.syn = syn
+
+        if not mem == None:
+            self.mem = mem
+
+        if self.init_hidden and (not mem == None or not syn == None):
+            raise TypeError(
+                "`mem` or `syn` should not be passed as an argument while `init_hidden=True`"
             )
-        elif self.reset_mechanism_val == 2:  # no reset, pure integration
-            state_fn = self._base_state_function(input_, syn, mem)
-        return state_fn
 
-    def _base_state_function_hidden(self, input_):
+        if not self.syn.shape == input_.shape:
+            self.syn = torch.zeros_like(input_, device=self.syn.device)
+
+        if not self.mem.shape == input_.shape:
+            self.mem = torch.zeros_like(input_, device=self.mem.device)
+
+        self.reset = self.mem_reset(self.mem)
+        self.syn, self.mem = self.state_function(input_)
+
+        if self.state_quant:
+            self.mem = self.state_quant(self.mem)
+            self.syn = self.state_quant(self.syn)
+
+        if self.inhibition:
+            spk = self.fire_inhibition(
+                self.mem.size(0), self.mem
+            )  # batch_size
+        else:
+            spk = self.fire(self.mem)
+
+        if not self.reset_delay:
+            # reset membrane potential _right_ after spike
+            do_reset = (
+                spk / self.graded_spikes_factor - self.reset
+            )  # avoid double reset
+            if self.reset_mechanism_val == 0:  # reset by subtraction
+                self.mem = self.mem - do_reset * self.threshold
+            elif self.reset_mechanism_val == 1:  # reset to zero
+                self.mem = self.mem - do_reset * self.mem
+
+        if self.output:
+            return spk, self.syn, self.mem
+        elif self.init_hidden:
+            return spk
+        else:
+            return spk, self.syn, self.mem
+
+    def _base_state_function(self, input_):
         base_fn_syn = self.alpha.clamp(0, 1) * self.syn + input_
         base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
         return base_fn_syn, base_fn_mem
 
-    def _base_state_reset_zero_hidden(self, input_):
+    def _base_state_reset_zero(self, input_):
         base_fn_syn = self.alpha.clamp(0, 1) * self.syn + input_
         base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
         return 0, base_fn_mem
 
-    def _build_state_function_hidden(self, input_):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function_hidden(input_),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function_hidden(input_),
-                    self._base_state_reset_zero_hidden(input_),
-                )
-            )
-        elif self.reset_mechanism_val == 2:  # no reset, pure integration
-            state_fn = self._base_state_function_hidden(input_)
-        return state_fn
+    def _base_sub(self, input_):
+        syn, mem = self._base_state_function(input_)
+        mem = mem - self.reset * self.threshold
+        return syn, mem
+
+    def _base_zero(self, input_):
+        syn, mem = self._base_state_function(input_)
+        syn2, mem2 = self._base_state_reset_zero(input_)
+        syn -= syn2 * self.reset
+        mem -= mem2 * self.reset
+        return syn, mem
+
+    def _base_int(self, input_):
+        return self._base_state_function(input_)
 
     def _alpha_register_buffer(self, alpha, learn_alpha):
         if not isinstance(alpha, torch.Tensor):
@@ -321,12 +300,6 @@ def _alpha_register_buffer(self, alpha, learn_alpha):
         else:
             self.register_buffer("alpha", alpha)
 
-    def _synaptic_forward_cases(self, mem, syn):
-        if mem is not False or syn is not False:
-            raise TypeError(
-                "When `init_hidden=True`, Synaptic expects 1 input argument."
-            )
-
     @classmethod
    def detach_hidden(cls):
        """Returns the hidden states, detached from the current graph.
@@ -346,5 +319,11 @@ def reset_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], Synaptic): - cls.instances[layer].syn = _SpikeTensor(init_flag=False) - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].syn = torch.zeros_like( + cls.instances[layer].syn, + device=cls.instances[layer].syn.device, + ) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) diff --git a/tests/test_snntorch/test_synaptic.py b/tests/test_snntorch/test_synaptic.py index 6ece23cc..262f5ca8 100644 --- a/tests/test_snntorch/test_synaptic.py +++ b/tests/test_snntorch/test_synaptic.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -16,6 +17,10 @@ def input_(): def synaptic_instance(): return snn.Synaptic(alpha=0.5, beta=0.5) +@pytest.fixture(scope="module") +def synaptic_instance_surrogate(): + return snn.Synaptic(alpha=0.5, beta=0.5, surrogate_disable=True) + @pytest.fixture(scope="module") def synaptic_reset_zero_instance(): @@ -123,3 +128,8 @@ def test_synaptic_init_hidden_reset_none( def test_synaptic_cases(self, synaptic_hidden_instance, input_): with pytest.raises(TypeError): synaptic_hidden_instance(input_, input_) + + def test_synaptic_compile_fullgraph(self, synaptic_instance_surrogate, input_): + explanation = dynamo.explain(synaptic_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 \ No newline at end of file From d2133d2ea0a6bf5ab74784ae5d24831e88b1f56e Mon Sep 17 00:00:00 2001 From: visdauas Date: Wed, 14 Feb 2024 13:33:10 +0100 Subject: [PATCH 06/17] Remove old init functions --- snntorch/_neurons/neurons.py | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/snntorch/_neurons/neurons.py b/snntorch/_neurons/neurons.py index ee58079c..1ebeeac3 100644 --- a/snntorch/_neurons/neurons.py +++ b/snntorch/_neurons/neurons.py @@ -311,18 +311,6 @@ def init_rleaky(): return spk, mem - @staticmethod - def init_synaptic(): - """Used to initialize syn and mem as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - - syn = _SpikeTensor(init_flag=False) - mem = _SpikeTensor(init_flag=False) - - return syn, mem - @staticmethod def init_rsynaptic(): """ @@ -336,30 +324,6 @@ def init_rsynaptic(): return spk, syn, mem - @staticmethod - def init_lapicque(): - """ - Used to initialize mem as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - - mem = _SpikeTensor(init_flag=False) - - return mem - - @staticmethod - def init_alpha(): - """Used to initialize syn_exc, syn_inh and mem as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - syn_exc = _SpikeTensor(init_flag=False) - syn_inh = _SpikeTensor(init_flag=False) - mem = _SpikeTensor(init_flag=False) - - return syn_exc, syn_inh, mem - class _SpikeTensor(torch.Tensor): """Inherits from torch.Tensor with additional attributes. 
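
Taken together, patches 02 through 06 replace the `_SpikeTensor` init-flag
mechanism with plain registered buffers, which is what lets TorchDynamo trace
each neuron without graph breaks. Below is a minimal sketch of the usage the
new `test_*_compile_fullgraph` tests verify; the batch and feature sizes are
illustrative assumptions, not taken from the patches:

    import torch
    import torch._dynamo as dynamo
    import snntorch as snn

    lif = snn.Leaky(beta=0.5, surrogate_disable=True)
    x = torch.rand(4, 10)  # assumed (batch, features) input for one time step

    # same check the new unit tests perform: tracing the neuron
    # produces zero graph breaks
    explanation = dynamo.explain(lif)(x)
    assert explanation.graph_break_count == 0

    # a break-free module can then be compiled end to end; the membrane
    # state persists between calls in the registered `mem` buffer
    compiled_lif = torch.compile(lif, fullgraph=True)
    spk, mem = compiled_lif(x)
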
From d7d8bb11c73bc38d5bd89747bbf2239454fa21ef Mon Sep 17 00:00:00 2001 From: visdauas Date: Fri, 16 Feb 2024 11:11:55 +0100 Subject: [PATCH 07/17] Add RLeaky fullgraph support --- snntorch/_neurons/neurons.py | 12 --- snntorch/_neurons/rleaky.py | 166 +++++++++++++---------------- tests/test_snntorch/test_rleaky.py | 15 +++ 3 files changed, 92 insertions(+), 101 deletions(-) diff --git a/snntorch/_neurons/neurons.py b/snntorch/_neurons/neurons.py index 1ebeeac3..894ca36f 100644 --- a/snntorch/_neurons/neurons.py +++ b/snntorch/_neurons/neurons.py @@ -299,18 +299,6 @@ def _V_register_buffer(self, V, learn_V): else: self.register_buffer("V", V) - @staticmethod - def init_rleaky(): - """ - Used to initialize spk and mem as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - spk = _SpikeTensor(init_flag=False) - mem = _SpikeTensor(init_flag=False) - - return spk, mem - @staticmethod def init_rsynaptic(): """ diff --git a/snntorch/_neurons/rleaky.py b/snntorch/_neurons/rleaky.py index 48383ada..7cd188ec 100644 --- a/snntorch/_neurons/rleaky.py +++ b/snntorch/_neurons/rleaky.py @@ -2,7 +2,7 @@ import torch.nn as nn # from torch import functional as F -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +from .neurons import LIF class RLeaky(LIF): @@ -280,72 +280,86 @@ def __init__( if not learn_recurrent: self._disable_recurrent_grad() + self._init_mem() + + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int + self.reset_delay = reset_delay if not self.reset_delay and self.init_hidden: - raise NotImplementedError('no reset_delay only supported for init_hidden=False') - - if self.init_hidden: - self.spk, self.mem = self.init_rleaky() - # self.state_fn = self._build_state_function_hidden - # else: - # self.state_fn = self._build_state_function - - def forward(self, input_, spk=False, mem=False): - if hasattr(spk, "init_flag") or hasattr( - mem, "init_flag" - ): # only triggered on first-pass - spk, mem = _SpikeTorchConv(spk, mem, input_=input_) - # init_hidden case - elif mem is False and hasattr(self.mem, "init_flag"): - self.spk, self.mem = _SpikeTorchConv( - self.spk, self.mem, input_=input_ + raise NotImplementedError( + "no reset_delay only supported for init_hidden=False" ) - # TO-DO: alternatively, we could do torch.exp(-1 / - # self.beta.clamp_min(0)), giving actual time constants instead of - # values in [0, 1] as initial beta beta = self.beta.clamp(0, 1) + def _init_mem(self): + spk = torch.zeros(1) + mem = torch.zeros(1) + self.register_buffer("spk", spk) + self.register_buffer("mem", mem) - if not self.init_hidden: - self.reset = self.mem_reset(mem) - mem = self._build_state_function(input_, spk, mem) + def reset_mem(self): + self.spk = torch.zeros_like(self.spk, device=self.spk.device) + self.mem = torch.zeros_like(self.mem, device=self.mem.device) - if self.state_quant: - mem = self.state_quant(mem) + def init_rleaky(self): + """Deprecated, use :class:`RLeaky.reset_mem` instead""" + self.reset_mem() + return self.spk, self.mem - if self.inhibition: - spk = self.fire_inhibition(mem.size(0), mem) # batch_size - else: - spk = self.fire(mem) + def forward(self, input_, spk=None, mem=None): - if not self.reset_delay: - do_reset = spk / 
self.graded_spikes_factor - self.reset # avoid double reset - if self.reset_mechanism_val == 0: # reset by subtraction - mem = mem - do_reset * self.threshold - elif self.reset_mechanism_val == 1: # reset to zero - mem = mem - do_reset * mem + if not spk == None: + self.spk = spk - return spk, mem + if not mem == None: + self.mem = mem - # intended for truncated-BPTT where instance variables are hidden - # states - if self.init_hidden: - self._rleaky_forward_cases(spk, mem) - self.reset = self.mem_reset(self.mem) - self.mem = self._build_state_function_hidden(input_) + if self.init_hidden and (not mem == None or not spk == None): + raise TypeError( + "When `init_hidden=True`," "RLeaky expects 1 input argument." + ) + + if not self.spk.shape == input_.shape: + self.spk = torch.zeros_like(input_, device=self.spk.device) + + if not self.mem.shape == input_.shape: + self.mem = torch.zeros_like(input_, device=self.mem.device) + + # TO-DO: alternatively, we could do torch.exp(-1 / + # self.beta.clamp_min(0)), giving actual time constants instead of + # values in [0, 1] as initial beta beta = self.beta.clamp(0, 1) - if self.state_quant: - self.mem = self.state_quant(self.mem) + self.reset = self.mem_reset(self.mem) + self.mem = self.state_function(input_) - if self.inhibition: - self.spk = self.fire_inhibition(self.mem.size(0), self.mem) - else: - self.spk = self.fire(self.mem) + if self.state_quant: + self.mem = self.state_quant(self.mem) - if self.output: # read-out layer returns output+states - return self.spk, self.mem - else: # hidden layer e.g., in nn.Sequential, only returns output - return self.spk + if self.inhibition: + self.spk = self.fire_inhibition(self.mem.size(0), self.mem) + else: + self.spk = self.fire(self.mem) + + if not self.reset_delay: + do_reset = ( + self.spk / self.graded_spikes_factor - self.reset + ) # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + self.mem = self.mem - do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + self.mem = self.mem - do_reset * self.mem + + if self.output: + return self.spk, self.mem + elif self.init_hidden: + return self.spk + else: + return self.spk, self.mem def _init_recurrent_net(self): if self.all_to_all: @@ -381,24 +395,7 @@ def _disable_recurrent_grad(self): for param in self.recurrent.parameters(): param.requires_grad = False - def _base_state_function(self, input_, spk, mem): - base_fn = self.beta.clamp(0, 1) * mem + input_ + self.recurrent(spk) - return base_fn - - def _build_state_function(self, input_, spk, mem): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = self._base_state_function( - input_, spk, mem - self.reset * self.threshold - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = self._base_state_function( - input_, spk, mem - ) - self.reset * self._base_state_function(input_, spk, mem) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function(input_, spk, mem) - return state_fn - - def _base_state_function_hidden(self, input_): + def _base_state_function(self, input_): base_fn = ( self.beta.clamp(0, 1) * self.mem + input_ @@ -406,25 +403,16 @@ def _base_state_function_hidden(self, input_): ) return base_fn - def _build_state_function_hidden(self, input_): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = ( - self._base_state_function_hidden(input_) - - self.reset * self.threshold - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = 
self._base_state_function_hidden( - input_ - ) - self.reset * self._base_state_function_hidden(input_) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn + def _base_sub(self, input_): + return self._base_state_function(input_) - self.reset * self.threshold - def _rleaky_forward_cases(self, spk, mem): - if mem is not False or spk is not False: - raise TypeError( - "When `init_hidden=True`," "RLeaky expects 1 input argument." - ) + def _base_zero(self, input_): + return self._base_state_function( + input_ + ) - self.reset * self._base_state_function(input_) + + def _base_int(self, input_): + return self._base_state_function(input_) def _rleaky_init_cases(self): all_to_all_bool = bool(self.all_to_all) diff --git a/tests/test_snntorch/test_rleaky.py b/tests/test_snntorch/test_rleaky.py index 5e336372..617488d4 100644 --- a/tests/test_snntorch/test_rleaky.py +++ b/tests/test_snntorch/test_rleaky.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -17,6 +18,13 @@ def rleaky_instance(): return snn.RLeaky(beta=0.5, V=0.5, all_to_all=False) +@pytest.fixture(scope="module") +def rleaky_instance_surrogate(): + return snn.RLeaky( + beta=0.5, V=0.5, all_to_all=False, surrogate_disable=True + ) + + @pytest.fixture(scope="module") def rleaky_reset_zero_instance(): return snn.RLeaky( @@ -133,3 +141,10 @@ def test_rleaky_init_hidden_reset_none( def test_lreaky_cases(self, rleaky_hidden_instance, input_): with pytest.raises(TypeError): rleaky_hidden_instance(input_, input_, input_) + + def test_rleaky_compile_fullgraph( + self, rleaky_instance_surrogate, input_ + ): + explanation = dynamo.explain(rleaky_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 From 4db294e51007021ae8b282c73b735c237d8c8422 Mon Sep 17 00:00:00 2001 From: visdauas Date: Fri, 16 Feb 2024 18:57:07 +0100 Subject: [PATCH 08/17] Add RSynaptic fullgraph support --- snntorch/_neurons/rsynaptic.py | 237 ++++++++++++-------------- tests/test_snntorch/test_rsynaptic.py | 15 ++ 2 files changed, 125 insertions(+), 127 deletions(-) diff --git a/snntorch/_neurons/rsynaptic.py b/snntorch/_neurons/rsynaptic.py index 654fd3d2..141bf805 100644 --- a/snntorch/_neurons/rsynaptic.py +++ b/snntorch/_neurons/rsynaptic.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +from .neurons import LIF class RSynaptic(LIF): @@ -295,72 +295,94 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) + self._init_mem() + + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int + self.reset_delay = reset_delay if not reset_delay and self.init_hidden: - raise NotImplementedError('no reset_delay only supported for init_hidden=False') - - if self.init_hidden: - self.spk, self.syn, self.mem = self.init_rsynaptic() - - def forward(self, input_, spk=False, syn=False, mem=False): - if ( - hasattr(spk, "init_flag") - or hasattr(syn, "init_flag") - or hasattr(mem, "init_flag") - ): # only triggered on first-pass - spk, syn, mem = _SpikeTorchConv(spk, syn, mem, input_=input_) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.spk, self.syn, 
self.mem = _SpikeTorchConv( - self.spk, self.syn, self.mem, input_=input_ + raise NotImplementedError( + "no reset_delay only supported for init_hidden=False" ) - if not self.init_hidden: - self.reset = self.mem_reset(mem) - syn, mem = self._build_state_function(input_, spk, syn, mem) - - if self.state_quant: - syn = self.state_quant(syn) - mem = self.state_quant(mem) - - if self.inhibition: - spk = self.fire_inhibition(mem.size(0), mem) - else: - spk = self.fire(mem) - - if not self.reset_delay: - # reset membrane potential _right_ after spike - do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset - if self.reset_mechanism_val == 0: # reset by subtraction - mem = mem - do_reset * self.threshold - elif self.reset_mechanism_val == 1: # reset to zero - # mem -= do_reset * mem - mem = mem - do_reset * mem - - return spk, syn, mem - - # intended for truncated-BPTT where instance variables are hidden - # states - if self.init_hidden: - self._rsynaptic_forward_cases(spk, mem, syn) - self.reset = self.mem_reset(self.mem) - self.syn, self.mem = self._build_state_function_hidden(input_) - - if self.state_quant: - self.syn = self.state_quant(self.syn) - self.mem = self.state_quant(self.mem) - - if self.inhibition: - self.spk = self.fire_inhibition(self.mem.size(0), self.mem) - else: - self.spk = self.fire(self.mem) - - if self.output: - return self.spk, self.syn, self.mem - else: - return self.spk + def _init_mem(self): + spk = torch.zeros(1) + syn = torch.zeros(1) + mem = torch.zeros(1) + + self.register_buffer("spk", spk) + self.register_buffer("syn", syn) + self.register_buffer("mem", mem) + + def reset_mem(self): + self.spk = torch.zeros_like(self.spk, device=self.spk.device) + self.syn = torch.zeros_like(self.syn, device=self.syn.device) + self.mem = torch.zeros_like(self.mem, device=self.mem.device) + + def init_rsynaptic(self): + """Deprecated, use :class:`RSynaptic.reset_mem` instead""" + self.reset_mem() + return self.spk, self.syn, self.mem + + def forward(self, input_, spk=None, syn=None, mem=None): + if not spk == None: + self.spk = spk + + if not syn == None: + self.syn = syn + + if not mem == None: + self.mem = mem + + if self.init_hidden and (not spk == None or not syn == None or not mem == None): + raise TypeError( + "When `init_hidden=True`, RSynaptic expects 1 input argument." 
+            )
+
+        if not self.spk.shape == input_.shape:
+            self.spk = torch.zeros_like(input_, device=self.spk.device)
+
+        if not self.syn.shape == input_.shape:
+            self.syn = torch.zeros_like(input_, device=self.syn.device)
+
+        if not self.mem.shape == input_.shape:
+            self.mem = torch.zeros_like(input_, device=self.mem.device)
+
+        self.reset = self.mem_reset(self.mem)
+        self.syn, self.mem = self.state_function(input_)
+
+        if self.state_quant:
+            self.syn = self.state_quant(self.syn)
+            self.mem = self.state_quant(self.mem)
+
+        if self.inhibition:
+            self.spk = self.fire_inhibition(self.mem.size(0), self.mem)
+        else:
+            self.spk = self.fire(self.mem)
+
+        if not self.reset_delay:
+            # reset membrane potential _right_ after spike
+            do_reset = (
+                self.spk / self.graded_spikes_factor - self.reset
+            )  # avoid double reset
+            if self.reset_mechanism_val == 0:  # reset by subtraction
+                self.mem = self.mem - do_reset * self.threshold
+            elif self.reset_mechanism_val == 1:  # reset to zero
+                # self.mem -= do_reset * self.mem
+                self.mem = self.mem - do_reset * self.mem
+
+        if self.output:
+            return self.spk, self.syn, self.mem
+        elif self.init_hidden:
+            return self.spk
+        else:
+            return self.spk, self.syn, self.mem
 
     def _init_recurrent_net(self):
         if self.all_to_all:
@@ -396,42 +418,7 @@ def _disable_recurrent_grad(self):
         for param in self.recurrent.parameters():
             param.requires_grad = False
 
-    def _base_state_function(self, input_, spk, syn, mem):
-        base_fn_syn = (
-            self.alpha.clamp(0, 1) * syn + input_ + self.recurrent(spk)
-        )
-        base_fn_mem = self.beta.clamp(0, 1) * mem + base_fn_syn
-        return base_fn_syn, base_fn_mem
-
-    def _base_state_reset_zero(self, input_, spk, syn, mem):
-        base_fn_syn = (
-            self.alpha.clamp(0, 1) * syn + input_ + self.recurrent(spk)
-        )
-        base_fn_mem = self.beta.clamp(0, 1) * mem + base_fn_syn
-        return 0, base_fn_mem
-
-    def _build_state_function(self, input_, spk, syn, mem):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function(input_, spk, syn, mem),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function(input_, spk, syn, mem),
-                    self._base_state_reset_zero(input_, spk, syn, mem),
-                )
-            )
-        elif self.reset_mechanism_val == 2:  # no reset, pure integration
-            state_fn = self._base_state_function(input_, spk, syn, mem)
-        return state_fn
-
-    def _base_state_function_hidden(self, input_):
+    def _base_state_function(self, input_):
         base_fn_syn = (
             self.alpha.clamp(0, 1)
             * self.syn
             + input_
@@ -440,7 +427,7 @@ def _base_state_function_hidden(self, input_):
         base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
         return base_fn_syn, base_fn_mem
 
-    def _base_state_reset_zero_hidden(self, input_):
+    def _base_state_reset_zero(self, input_):
         base_fn_syn = (
             self.alpha.clamp(0, 1)
             * self.syn
             + input_
@@ -449,26 +436,22 @@ def _base_state_reset_zero_hidden(self, input_):
         base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
         return 0, base_fn_mem
 
-    def _build_state_function_hidden(self, input_):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function_hidden(input_),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function_hidden(input_),
-                    self._base_state_function_hidden(input_),
-                )
-            )
-        elif self.reset_mechanism_val == 
2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn + def _base_sub(self, input_): + syn, mem = self._base_state_function(input_) + mem -= self.reset * self.threshold + return syn, mem + + def _base_zero(self, input_): + syn, mem = self._base_state_function(input_) + syn2, mem2 = self._base_state_reset_zero(input_) + syn2 *= self.reset + mem2 *= self.reset + syn -= syn2 + mem -= mem2 + return syn, mem + + def _base_int(self, input_): + return self._base_state_function(input_) def _alpha_register_buffer(self, alpha, learn_alpha): if not isinstance(alpha, torch.Tensor): @@ -478,12 +461,6 @@ def _alpha_register_buffer(self, alpha, learn_alpha): else: self.register_buffer("alpha", alpha) - def _rsynaptic_forward_cases(self, spk, mem, syn): - if mem is not False or syn is not False or spk is not False: - raise TypeError( - "When `init_hidden=True`, RSynaptic expects 1 input argument." - ) - def _rsynaptic_init_cases(self): all_to_all_bool = bool(self.all_to_all) linear_features_bool = self.linear_features @@ -545,8 +522,14 @@ def reset_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], RSynaptic): - cls.instances[layer].syn = _SpikeTensor(init_flag=False) - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].syn = torch.zeros_like( + cls.instances[layer].syn, + device=cls.instances[layer].syn.device, + ) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) class RecurrentOneToOne(nn.Module): diff --git a/tests/test_snntorch/test_rsynaptic.py b/tests/test_snntorch/test_rsynaptic.py index 54e981d6..9bd386b6 100644 --- a/tests/test_snntorch/test_rsynaptic.py +++ b/tests/test_snntorch/test_rsynaptic.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -22,6 +23,13 @@ def rsynaptic_instance(): ) +@pytest.fixture(scope="module") +def rsynaptic_instance_surrogate(): + return snn.RSynaptic( + alpha=0.5, beta=0.5, V=0.5, all_to_all=False, surrogate_disable=True + ) + + @pytest.fixture(scope="module") def rsynaptic_reset_zero_instance(): return snn.RSynaptic( @@ -144,3 +152,10 @@ def test_rsynaptic_init_hidden_reset_none( def test_rsynaptic_cases(self, rsynaptic_hidden_instance, input_): with pytest.raises(TypeError): rsynaptic_hidden_instance(input_, input_) + + def test_rsynaptic_compile_fullgraph( + self, rsynaptic_instance_surrogate, input_ + ): + explanation = dynamo.explain(rsynaptic_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 From 0110f9fc6bcf4896e9f498897dcb18c7f3797748 Mon Sep 17 00:00:00 2001 From: visdauas Date: Fri, 16 Feb 2024 19:01:23 +0100 Subject: [PATCH 09/17] Enable reset_delay for init_hidden=True --- snntorch/_neurons/leaky.py | 5 ----- snntorch/_neurons/rleaky.py | 5 ----- snntorch/_neurons/rsynaptic.py | 5 ----- snntorch/_neurons/synaptic.py | 5 ----- 4 files changed, 20 deletions(-) diff --git a/snntorch/_neurons/leaky.py b/snntorch/_neurons/leaky.py index 41166462..75fe4fd3 100644 --- a/snntorch/_neurons/leaky.py +++ b/snntorch/_neurons/leaky.py @@ -181,11 +181,6 @@ def __init__( self.reset_delay = reset_delay - if not self.reset_delay and self.init_hidden: - raise NotImplementedError( - "`reset_delay=True` is only supported for `init_hidden=False`" - ) - def _init_mem(self): mem = torch.zeros(1) self.register_buffer("mem", mem) diff --git a/snntorch/_neurons/rleaky.py 
b/snntorch/_neurons/rleaky.py index 7cd188ec..66046bc0 100644 --- a/snntorch/_neurons/rleaky.py +++ b/snntorch/_neurons/rleaky.py @@ -291,11 +291,6 @@ def __init__( self.reset_delay = reset_delay - if not self.reset_delay and self.init_hidden: - raise NotImplementedError( - "no reset_delay only supported for init_hidden=False" - ) - def _init_mem(self): spk = torch.zeros(1) mem = torch.zeros(1) diff --git a/snntorch/_neurons/rsynaptic.py b/snntorch/_neurons/rsynaptic.py index 141bf805..0020a3ff 100644 --- a/snntorch/_neurons/rsynaptic.py +++ b/snntorch/_neurons/rsynaptic.py @@ -306,11 +306,6 @@ def __init__( self.reset_delay = reset_delay - if not reset_delay and self.init_hidden: - raise NotImplementedError( - "no reset_delay only supported for init_hidden=False" - ) - def _init_mem(self): spk = torch.zeros(1) syn = torch.zeros(1) diff --git a/snntorch/_neurons/synaptic.py b/snntorch/_neurons/synaptic.py index 2bf3ba07..6209ca32 100644 --- a/snntorch/_neurons/synaptic.py +++ b/snntorch/_neurons/synaptic.py @@ -197,11 +197,6 @@ def __init__( self.reset_delay = reset_delay - if not reset_delay and self.init_hidden: - raise NotImplementedError( - "no reset_delay only supported for init_hidden=False" - ) - def _init_mem(self): syn = torch.zeros(1) mem = torch.zeros(1) From 2dc9ba56439382363e254a6c91bd0aab31fd6907 Mon Sep 17 00:00:00 2001 From: visdauas Date: Thu, 22 Feb 2024 11:08:16 +0100 Subject: [PATCH 10/17] Add slstm fullgraph support --- snntorch/_neurons/slstm.py | 181 +++++++++++++----------------- tests/test_snntorch/test_slstm.py | 13 +++ 2 files changed, 88 insertions(+), 106 deletions(-) diff --git a/snntorch/_neurons/slstm.py b/snntorch/_neurons/slstm.py index dd29ba6b..dceff42f 100644 --- a/snntorch/_neurons/slstm.py +++ b/snntorch/_neurons/slstm.py @@ -1,12 +1,9 @@ import torch -from torch._C import Value import torch.nn as nn -import torch.nn.functional as F -from .neurons import _SpikeTensor, _SpikeTorchConv, SpikingNeuron +from .neurons import SpikingNeuron class SLSTM(SpikingNeuron): - """ A spiking long short-term memory cell. 
Hidden states are membrane potential and synaptic current
@@ -188,8 +185,14 @@ def __init__(
             output,
         )
 
-        if self.init_hidden:
-            self.syn, self.mem = self.init_slstm()
+        self._init_mem()
+
+        if self.reset_mechanism_val == 0:  # reset by subtraction
+            self.state_function = self._base_sub
+        elif self.reset_mechanism_val == 1:  # reset to zero
+            self.state_function = self._base_zero
+        elif self.reset_mechanism_val == 2:  # no reset, pure integration
+            self.state_function = self._base_int
 
         self.input_size = input_size
         self.hidden_size = hidden_size
@@ -199,122 +202,82 @@ def __init__(
             self.input_size, self.hidden_size, bias=self.bias
         )
 
-    def forward(self, input_, syn=False, mem=False):
-        if hasattr(mem, "init_flag") or hasattr(
-            syn, "init_flag"
-        ):  # only triggered on first-pass
-
-            syn, mem = _SpikeTorchConv(
-                syn, mem, input_=self._reshape_input(input_)
-            )
-        elif mem is False and hasattr(
-            self.mem, "init_flag"
-        ):  # init_hidden case
-            self.syn, self.mem = _SpikeTorchConv(
-                self.syn, self.mem, input_=self._reshape_input(input_)
-            )
+    def _init_mem(self):
+        syn = torch.zeros(1)
+        mem = torch.zeros(1)
+        self.register_buffer("syn", syn)
+        self.register_buffer("mem", mem)
 
-        if not self.init_hidden:
-            self.reset = self.mem_reset(mem)
-            syn, mem = self._build_state_function(input_, syn, mem)
+    def reset_mem(self):
+        self.syn = torch.zeros_like(self.syn, device=self.syn.device)
+        self.mem = torch.zeros_like(self.mem, device=self.mem.device)
 
-            if self.state_quant:
-                syn = self.state_quant(syn)
-                mem = self.state_quant(mem)
+    def init_slstm(self):
+        """Deprecated, use :class:`SLSTM.reset_mem` instead"""
+        self.reset_mem()
+        return self.syn, self.mem
 
-            spk = self.fire(mem)
-            return spk, syn, mem
+    def forward(self, input_, syn=None, mem=None):
+        if not syn == None:
+            self.syn = syn
 
-        if self.init_hidden:
-            # self._slstm_forward_cases(mem, syn)
-            self.reset = self.mem_reset(self.mem)
-            self.syn, self.mem = self._build_state_function_hidden(input_)
+        if not mem == None:
+            self.mem = mem
 
-            if self.state_quant:
-                self.syn = self.state_quant(self.syn)
-                self.mem = self.state_quant(self.mem)
+        if self.init_hidden and (not mem == None or not syn == None):
+            raise TypeError(
+                "`mem` or `syn` should not be passed as an argument while `init_hidden=True`"
+            )
 
-            self.spk = self.fire(self.mem)
+        size = input_.size()
+        correct_shape = (size[0], self.hidden_size)
 
-            if self.output:
-                return self.spk, self.syn, self.mem
-            else:
-                return self.spk
+        if not self.syn.shape == correct_shape:
+            self.syn = torch.zeros(correct_shape, device=self.syn.device)
 
-    def _base_state_function(self, input_, syn, mem):
-        base_fn_mem, base_fn_syn = self.lstm_cell(input_, (mem, syn))
-        return base_fn_syn, base_fn_mem
+        if not self.mem.shape == correct_shape:
+            self.mem = torch.zeros(correct_shape, device=self.mem.device)
 
-    def _base_state_reset_zero(self, input_, syn, mem):
-        base_fn_mem, _ = self.lstm_cell(input_, (mem, syn))
-        return 0, base_fn_mem
+        self.reset = self.mem_reset(self.mem)
+        self.syn, self.mem = self.state_function(input_)
 
-    def _build_state_function(self, input_, syn, mem):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function(input_, syn, mem),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function(input_, 
syn, mem), - self._base_state_reset_zero(input_, syn, mem), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function(input_, syn, mem) - return state_fn + self.spk = self.fire(self.mem) - def _base_state_function_hidden(self, input_): + if self.output: + return self.spk, self.syn, self.mem + elif self.init_hidden: + return self.spk + else: + return self.spk, self.syn, self.mem + + def _base_state_function(self, input_): base_fn_mem, base_fn_syn = self.lstm_cell(input_, (self.mem, self.syn)) return base_fn_syn, base_fn_mem - def _base_state_reset_zero_hidden(self, input_): + def _base_state_reset_zero(self, input_): base_fn_mem, _ = self.lstm_cell(input_, (self.mem, self.syn)) return 0, base_fn_mem - def _build_state_function_hidden(self, input_): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = tuple( - map( - lambda x, y: x - y, - self._base_state_function_hidden(input_), - (0, self.reset * self.threshold), - ) - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function_hidden(input_), - self._base_state_reset_zero_hidden(input_), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn - - def _reshape_input(self, input_): - device = input_.device - b, _ = input_.size() - return torch.zeros(b, self.hidden_size).to(device) - - @staticmethod - def init_slstm(): - """ - Used to initialize mem and syn as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - mem = _SpikeTensor(init_flag=False) - syn = _SpikeTensor(init_flag=False) - - return mem, syn + def _base_sub(self, input_): + syn, mem = self._base_state_function(input_) + mem -= self.reset * self.threshold + return syn, mem + + def _base_zero(self, input_): + syn, mem = self._base_state_function(input_) + syn2, mem2 = self._base_state_reset_zero(input_) + syn2 *= self.reset + mem2 *= self.reset + syn -= syn2 + mem -= mem2 + return syn, mem + + def _base_int(self, input_): + return self._base_state_function(input_) @classmethod def detach_hidden(cls): @@ -335,5 +298,11 @@ def reset_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], SLSTM): - cls.instances[layer].syn = _SpikeTensor(init_flag=False) - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].syn = torch.zeros_like( + cls.instances[layer].syn, + device=cls.instances[layer].syn.device, + ) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) diff --git a/tests/test_snntorch/test_slstm.py b/tests/test_snntorch/test_slstm.py index 32007172..7f31f096 100644 --- a/tests/test_snntorch/test_slstm.py +++ b/tests/test_snntorch/test_slstm.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo # TO-DO: add avg/max-pooling tests @@ -20,6 +21,11 @@ def slstm_instance(): return snn.SLSTM(1, 2) +@pytest.fixture(scope="module") +def slstm_instance_surrogate(): + return snn.SLSTM(1, 2, surrogate_disable=True) + + @pytest.fixture(scope="module") def slstm_reset_zero_instance(): return snn.SLSTM(1, 2, reset_mechanism="zero") @@ -120,3 +126,10 @@ def test_sconv2dlstm_init_hidden_reset_subtract( spk_rec.append(spk) assert spk_rec[0].size() == (1, 2) + + def 
test_slstm_compile_fullgraph( + self, slstm_instance_surrogate, input_ + ): + explanation = dynamo.explain(slstm_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 From 90ad720959bdf5ddb486aaeb7046c32231092218 Mon Sep 17 00:00:00 2001 From: visdauas Date: Thu, 22 Feb 2024 11:08:35 +0100 Subject: [PATCH 11/17] Add conv2dlstm fullgraph support --- snntorch/_neurons/sconv2dlstm.py | 233 +++++++++--------------- tests/test_snntorch/test_sconv2dlstm.py | 13 ++ 2 files changed, 97 insertions(+), 149 deletions(-) diff --git a/snntorch/_neurons/sconv2dlstm.py b/snntorch/_neurons/sconv2dlstm.py index 73d91a7d..a38e4b03 100644 --- a/snntorch/_neurons/sconv2dlstm.py +++ b/snntorch/_neurons/sconv2dlstm.py @@ -1,12 +1,10 @@ import torch -from torch._C import Value import torch.nn as nn import torch.nn.functional as F -from .neurons import _SpikeTensor, _SpikeTorchConv, SpikingNeuron +from .neurons import SpikingNeuron class SConv2dLSTM(SpikingNeuron): - """ A spiking 2d convolutional long short-term memory cell. Hidden states are membrane potential and synaptic current @@ -240,8 +238,14 @@ def __init__( output, ) - if self.init_hidden: - self.syn, self.mem = self.init_sconv2dlstm() + self._init_mem() + + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int self.in_channels = in_channels self.out_channels = out_channels @@ -268,117 +272,63 @@ def __init__( bias=self.bias, ) - def forward(self, input_, syn=False, mem=False): - if hasattr(mem, "init_flag") or hasattr( - syn, "init_flag" - ): # only triggered on first-pass - - syn, mem = _SpikeTorchConv( - syn, mem, input_=self._reshape_input(input_) - ) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.syn, self.mem = _SpikeTorchConv( - self.syn, self.mem, input_=self._reshape_input(input_) - ) + def _init_mem(self): + syn = torch.zeros(1) + mem = torch.zeros(1) + self.register_buffer("syn", syn) + self.register_buffer("mem", mem) - if not self.init_hidden: - self.reset = self.mem_reset(mem) - syn, mem = self._build_state_function(input_, syn, mem) - - if self.state_quant: - syn = self.state_quant(syn) - mem = self.state_quant(mem) - - if self.max_pool: - spk = self.fire(F.max_pool2d(mem, self.max_pool)) - elif self.avg_pool: - spk = self.fire(F.avg_pool2d(mem, self.avg_pool)) - else: - spk = self.fire(mem) - return spk, syn, mem - - if self.init_hidden: - # self._sconv2dlstm_forward_cases(mem, c) - self.reset = self.mem_reset(self.mem) - self.syn, self.mem = self._build_state_function_hidden(input_) - - if self.state_quant: - self.syn = self.state_quant(self.syn) - self.mem = self.state_quant(self.mem) - - if self.max_pool: - self.spk = self.fire(F.max_pool2d(self.mem, self.max_pool)) - elif self.avg_pool: - self.spk = self.fire(F.avg_pool2d(self.mem, self.avg_pool)) - else: - self.spk = self.fire(self.mem) - - if self.output: - return self.spk, self.syn, self.mem - else: - return self.spk - - def _base_state_function(self, input_, syn, mem): + def reset_mem(self): + self.syn = torch.zeros_like(self.syn, device=self.syn.device) + self.mem = torch.zeros_like(self.mem, device=self.mem.device) - combined = torch.cat( - [input_, mem], dim=1 - ) # concatenate along channel axis (BxCxHxW) - combined_conv = self.conv(combined) - cc_i, cc_f, cc_o, cc_g = 
torch.split( - combined_conv, self.out_channels, dim=1 - ) - i = torch.sigmoid(cc_i) - f = torch.sigmoid(cc_f) - o = torch.sigmoid(cc_o) - g = torch.tanh(cc_g) - - base_fn_syn = f * syn + i * g - base_fn_mem = o * torch.tanh(base_fn_syn) - - return base_fn_syn, base_fn_mem - - def _base_state_reset_zero(self, input_, syn, mem): - combined = torch.cat( - [input_, mem], dim=1 - ) # concatenate along channel axis - combined_conv = self.conv(combined) - cc_i, cc_f, cc_o, cc_g = torch.split( - combined_conv, self.out_channels, dim=1 - ) - i = torch.sigmoid(cc_i) - f = torch.sigmoid(cc_f) - o = torch.sigmoid(cc_o) - g = torch.tanh(cc_g) + def init_sconv2dlstm(self): + """Deprecated, use :class:`SConv2dLSTM.reset_mem` instead""" + self.reset_mem() + return self.syn, self.mem - base_fn_syn = f * syn + i * g - base_fn_mem = o * torch.tanh(base_fn_syn) + def forward(self, input_, syn=None, mem=None): + if not syn == None: + self.syn = syn - return 0, base_fn_mem + if not mem == None: + self.mem = mem - def _build_state_function(self, input_, syn, mem): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = tuple( - map( - lambda x, y: x - y, - self._base_state_function(input_, syn, mem), - (0, self.reset * self.threshold), - ) + if self.init_hidden and (not mem == None or not syn == None): + raise TypeError( + "`mem` or `syn` should not be passed as an argument while `init_hidden=True`" ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function(input_, syn, mem), - self._base_state_reset_zero(input_, syn, mem), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function(input_, syn, mem) - return state_fn + + size = input_.size() + correct_shape = (size[0], self.out_channels, size[2], size[3]) + if not self.syn.shape == correct_shape: + self.syn = torch.zeros(correct_shape, device=self.syn.device) + + if not self.mem.shape == correct_shape: + self.mem = torch.zeros(correct_shape, device=self.mem.device) + + self.reset = self.mem_reset(self.mem) + self.syn, self.mem = self.state_function(input_) + + if self.state_quant: + self.syn = self.state_quant(self.syn) + self.mem = self.state_quant(self.mem) + + if self.max_pool: + self.spk = self.fire(F.max_pool2d(self.mem, self.max_pool)) + elif self.avg_pool: + self.spk = self.fire(F.avg_pool2d(self.mem, self.avg_pool)) + else: + self.spk = self.fire(self.mem) - def _base_state_function_hidden(self, input_): + if self.output: + return self.spk, self.syn, self.mem + elif self.init_hidden: + return self.spk + else: + return self.spk, self.syn, self.mem + + def _base_state_function(self, input_): combined = torch.cat( [input_, self.mem], dim=1 ) # concatenate along channel axis @@ -396,7 +346,7 @@ def _base_state_function_hidden(self, input_): return base_fn_syn, base_fn_mem - def _base_state_reset_zero_hidden(self, input_): + def _base_state_reset_zero(self, input_): combined = torch.cat( [input_, self.mem], dim=1 ) # concatenate along channel axis @@ -414,43 +364,22 @@ def _base_state_reset_zero_hidden(self, input_): return 0, base_fn_mem - def _build_state_function_hidden(self, input_): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = tuple( - map( - lambda x, y: x - y, - self._base_state_function_hidden(input_), - (0, self.reset * self.threshold), - ) - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - 
self._base_state_function_hidden(input_), - self._base_state_reset_zero_hidden(input_), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn - - @staticmethod - def init_sconv2dlstm(): - """ - Used to initialize h and c as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - mem = _SpikeTensor(init_flag=False) - syn = _SpikeTensor(init_flag=False) - - return mem, syn - - def _reshape_input(self, input_): - device = input_.device - b, _, h, w = input_.size() - return torch.zeros(b, self.out_channels, h, w).to(device) + def _base_sub(self, input_): + syn, mem = self._base_state_function(input_) + mem -= self.reset * self.threshold + return syn, mem + + def _base_zero(self, input_): + syn, mem = self._base_state_function(input_) + syn2, mem2 = self._base_state_reset_zero(input_) + syn2 *= self.reset + mem2 *= self.reset + syn -= syn2 + mem -= mem2 + return syn, mem + + def _base_int(self, input_): + return self._base_state_function(input_) def _sconv2dlstm_cases(self): if self.max_pool and self.avg_pool: @@ -478,5 +407,11 @@ def reset_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], SConv2dLSTM): - cls.instances[layer].syn = _SpikeTensor(init_flag=False) - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].syn = torch.zeros_like( + cls.instances[layer].syn, + device=cls.instances[layer].syn.device, + ) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) diff --git a/tests/test_snntorch/test_sconv2dlstm.py b/tests/test_snntorch/test_sconv2dlstm.py index 8a504923..7c8b96f4 100644 --- a/tests/test_snntorch/test_sconv2dlstm.py +++ b/tests/test_snntorch/test_sconv2dlstm.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo # TO-DO: add avg/max-pooling tests @@ -20,6 +21,11 @@ def sconv2dlstm_instance(): return snn.SConv2dLSTM(1, 8, 3) +@pytest.fixture(scope="module") +def sconv2dlstm_instance_surrogate(): + return snn.SConv2dLSTM(1, 8, 3, surrogate_disable=True) + + @pytest.fixture(scope="module") def sconv2dlstm_reset_zero_instance(): return snn.SConv2dLSTM(1, 8, 3, reset_mechanism="zero") @@ -124,3 +130,10 @@ def test_sconv2dlstm_init_hidden_reset_subtract( spk_rec.append(spk) assert spk_rec[0].size() == (1, 8, 4, 4) + + def test_sconv2dlstm_compile_fullgraph( + self, sconv2dlstm_instance_surrogate, input_ + ): + explanation = dynamo.explain(sconv2dlstm_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 From 1d600aa1c5115d654edc0a3bc1dd097209496e0f Mon Sep 17 00:00:00 2001 From: visdauas Date: Thu, 22 Feb 2024 11:09:45 +0100 Subject: [PATCH 12/17] Misc fixes & cleanup --- snntorch/_neurons/neurons.py | 62 +------------------------------ tests/test_snntorch/test_leaky.py | 5 +-- tests/test_snntorch/test_slstm.py | 10 ++--- 3 files changed, 8 insertions(+), 69 deletions(-) diff --git a/snntorch/_neurons/neurons.py b/snntorch/_neurons/neurons.py index 894ca36f..18fa4216 100644 --- a/snntorch/_neurons/neurons.py +++ b/snntorch/_neurons/neurons.py @@ -1,6 +1,5 @@ -import inspect from warnings import warn -from snntorch.surrogate import StraightThroughEstimator, atan, straight_through_estimator +from snntorch.surrogate import atan import torch import torch.nn as nn @@ -8,8 +7,6 @@ __all__ = [ "SpikingNeuron", 
"LIF", - "_SpikeTensor", - "_SpikeTorchConv", ] dtype = torch.float @@ -234,6 +231,7 @@ def zeros(*args): def _surrogate_bypass(input_): return (input_ > 0).float() + class LIF(SpikingNeuron): """Parent class for leaky integrate and fire neuron models.""" @@ -298,59 +296,3 @@ def _V_register_buffer(self, V, learn_V): self.V = nn.Parameter(V) else: self.register_buffer("V", V) - - @staticmethod - def init_rsynaptic(): - """ - Used to initialize spk, syn and mem as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - spk = _SpikeTensor(init_flag=False) - syn = _SpikeTensor(init_flag=False) - mem = _SpikeTensor(init_flag=False) - - return spk, syn, mem - - -class _SpikeTensor(torch.Tensor): - """Inherits from torch.Tensor with additional attributes. - ``init_flag`` is set at the time of initialization. - When called in the forward function of any neuron, they are parsed and - replaced with a torch.Tensor variable. - """ - - @staticmethod - def __new__(cls, *args, init_flag=False, **kwargs): - return super().__new__(cls, *args, **kwargs) - - def __init__( - self, - *args, - init_flag=True, - ): - # super().__init__() # optional - self.init_flag = init_flag - - -def _SpikeTorchConv(*args, input_): - """Convert SpikeTensor to torch.Tensor of the same size as ``input_``.""" - - states = [] - # if len(input_.size()) == 0: - # _batch_size = 1 # assume batch_size=1 if 1D input - # else: - # _batch_size = input_.size(0) - if ( - len(args) == 1 and type(args) is not tuple - ): # if only one hidden state, make it iterable - args = (args,) - for arg in args: - arg = arg.to("cpu") - arg = torch.Tensor(arg) # wash away the SpikeTensor class - arg = torch.zeros_like(input_, requires_grad=True) - states.append(arg) - if len(states) == 1: # otherwise, list isn't unpacked - return states[0] - - return states diff --git a/tests/test_snntorch/test_leaky.py b/tests/test_snntorch/test_leaky.py index 30622bce..2bfc26ca 100644 --- a/tests/test_snntorch/test_leaky.py +++ b/tests/test_snntorch/test_leaky.py @@ -139,9 +139,6 @@ def test_leaky_hidden_learn_graded_instance( assert factor.requires_grad def test_leaky_compile_fullgraph(self, leaky_instance_surrogate, input_): - # net = nn.Sequential( - # snn.Leaky(beta=0.5, init_hidden=True, surrogate_disable=True), - # ) - explanation = dynamo.explain(leaky_instance_surrogate)(input_[0]) + assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_slstm.py b/tests/test_snntorch/test_slstm.py index 7f31f096..8aa46b79 100644 --- a/tests/test_snntorch/test_slstm.py +++ b/tests/test_snntorch/test_slstm.py @@ -52,7 +52,7 @@ def slstm_hidden_reset_subtract_instance(): class TestSLSTM: - def test_sconv2dlstm(self, slstm_instance, input_): + def test_slstm(self, slstm_instance, input_): c, h = slstm_instance.init_slstm() h_rec = [] @@ -70,7 +70,7 @@ def test_sconv2dlstm(self, slstm_instance, input_): assert h.size() == (1, 2) assert spk.size() == (1, 2) - def test_sconv2dlstm_reset( + def test_slstm_reset( self, slstm_instance, slstm_reset_zero_instance, @@ -93,7 +93,7 @@ def test_sconv2dlstm_reset( assert lif2.reset_mechanism_val == 1 assert lif3.reset_mechanism_val == 0 - def test_sconv2dlstm_init_hidden(self, slstm_hidden_instance, input_): + def test_slstm_init_hidden(self, slstm_hidden_instance, input_): spk_rec = [] @@ -103,7 +103,7 @@ def test_sconv2dlstm_init_hidden(self, slstm_hidden_instance, input_): assert spk_rec[0].size() == (1, 2) - def 
test_sconv2dlstm_init_hidden_reset_zero( + def test_slstm_init_hidden_reset_zero( self, slstm_hidden_reset_zero_instance, input_ ): @@ -115,7 +115,7 @@ def test_sconv2dlstm_init_hidden_reset_zero( assert spk_rec[0].size() == (1, 2) - def test_sconv2dlstm_init_hidden_reset_subtract( + def test_slstm_init_hidden_reset_subtract( self, slstm_hidden_reset_subtract_instance, input_ ): From 75a0e17e89a0f666abf8ddf4224e2502c5ce88a2 Mon Sep 17 00:00:00 2001 From: jeshraghian Date: Tue, 12 Mar 2024 13:59:14 -0700 Subject: [PATCH 13/17] =?UTF-8?q?Bump=20version:=200.7.0=20=E2=86=92=200.8?= =?UTF-8?q?.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- _version.py | 2 +- docs/conf.py | 2 +- setup.cfg | 2 +- setup.py | 2 +- snntorch/_version.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/_version.py b/_version.py index 46d9bfe6..409e6b2d 100644 --- a/_version.py +++ b/_version.py @@ -1,2 +1,2 @@ # fmt: off -__version__ = '0.7.0' +__version__ = '0.8.0' diff --git a/docs/conf.py b/docs/conf.py index a404c254..0e508082 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -6,7 +6,7 @@ # fmt: off -__version__ = '0.7.0' +__version__ = '0.8.0' # fmt: on diff --git a/setup.cfg b/setup.cfg index 912e11e7..d5968e88 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.7.0 +current_version = 0.8.0 commit = True tag = True diff --git a/setup.py b/setup.py index e78d41dd..801000b6 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ # history = history_file.read() # fmt: off -__version__ = '0.7.0' +__version__ = '0.8.0' # fmt: on requirements = [ diff --git a/snntorch/_version.py b/snntorch/_version.py index 46d9bfe6..409e6b2d 100644 --- a/snntorch/_version.py +++ b/snntorch/_version.py @@ -1,2 +1,2 @@ # fmt: off -__version__ = '0.7.0' +__version__ = '0.8.0' From 71afaa0b27636abfa85e9623abf3c8c4c894dabc Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 16 Mar 2024 19:14:11 -0700 Subject: [PATCH 14/17] remove python 3.7 from test matrix --- .github/workflows/build-tag.yml | 2 +- {examples_private => examples}/rnn_test.py | 0 {examples_private => examples}/rnn_test_train.py | 0 setup.py | 3 +-- 4 files changed, 2 insertions(+), 3 deletions(-) rename {examples_private => examples}/rnn_test.py (100%) rename {examples_private => examples}/rnn_test_train.py (100%) diff --git a/.github/workflows/build-tag.yml b/.github/workflows/build-tag.yml index 5f69e260..2c9d91c4 100644 --- a/.github/workflows/build-tag.yml +++ b/.github/workflows/build-tag.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.8', '3.9', '3.10', '3.11'] steps: - uses: actions/checkout@v2 diff --git a/examples_private/rnn_test.py b/examples/rnn_test.py similarity index 100% rename from examples_private/rnn_test.py rename to examples/rnn_test.py diff --git a/examples_private/rnn_test_train.py b/examples/rnn_test_train.py similarity index 100% rename from examples_private/rnn_test_train.py rename to examples/rnn_test_train.py diff --git a/setup.py b/setup.py index 801000b6..da410c51 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ setup( author="Jason K. 
Eshraghian", author_email="jeshragh@ucsc.edu", - python_requires=">=3.7", + python_requires=">=3.8", classifiers=[ "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", @@ -42,7 +42,6 @@ "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", From 94fc028960f3222c23ff88ea59cc05907a576d3a Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 16 Mar 2024 19:14:29 -0700 Subject: [PATCH 15/17] =?UTF-8?q?Bump=20version:=200.8.0=20=E2=86=92=200.8?= =?UTF-8?q?.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- _version.py | 2 +- docs/conf.py | 2 +- setup.cfg | 3 ++- setup.py | 2 +- snntorch/_version.py | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/_version.py b/_version.py index 409e6b2d..edb9e336 100644 --- a/_version.py +++ b/_version.py @@ -1,2 +1,2 @@ # fmt: off -__version__ = '0.8.0' +__version__ = '0.8.1' diff --git a/docs/conf.py b/docs/conf.py index 0e508082..932e83cb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -6,7 +6,7 @@ # fmt: off -__version__ = '0.8.0' +__version__ = '0.8.1' # fmt: on diff --git a/setup.cfg b/setup.cfg index d5968e88..db7ed9ac 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.8.0 +current_version = 0.8.1 commit = True tag = True @@ -31,3 +31,4 @@ test = pytest [tool:pytest] testpaths = tests addopts = --ignore=setup.py + diff --git a/setup.py b/setup.py index da410c51..8467aa89 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ # history = history_file.read() # fmt: off -__version__ = '0.8.0' +__version__ = '0.8.1' # fmt: on requirements = [ diff --git a/snntorch/_version.py b/snntorch/_version.py index 409e6b2d..edb9e336 100644 --- a/snntorch/_version.py +++ b/snntorch/_version.py @@ -1,2 +1,2 @@ # fmt: off -__version__ = '0.8.0' +__version__ = '0.8.1' From 5b9874c16418942ba8bbd81417e0839f50761a0b Mon Sep 17 00:00:00 2001 From: Ray Rubens <148541376+peppermintbird@users.noreply.github.com> Date: Mon, 18 Mar 2024 22:21:51 -0300 Subject: [PATCH 16/17] Update tutorial_exoplanet_hunter.rst Correcting minor typos. --- docs/tutorials/tutorial_exoplanet_hunter.rst | 24 ++++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/tutorials/tutorial_exoplanet_hunter.rst b/docs/tutorials/tutorial_exoplanet_hunter.rst index 8f4545a9..92927d8b 100644 --- a/docs/tutorials/tutorial_exoplanet_hunter.rst +++ b/docs/tutorials/tutorial_exoplanet_hunter.rst @@ -83,20 +83,20 @@ Before diving into the code, let's gain an understanding of what Exoplanet Detec The transit method is a widely used and successful technique for detecting exoplanets. When an exoplanet transits its host star, it causes a temporary reduction in the star's light flux (brightness). -Compared to other techniques, the transmit method has has discovered +Compared to other techniques, the transit method has discovered the largest number of planets. Astronomers use telescopes equipped with photometers or spectrophotometers to continuously monitor the brightness of a star over -time. Repeated observations of multiple transits allows astronomers to +time. 
Repeated observations of multiple transits allow astronomers to gather more detailed information about the exoplanet, such as its atmosphere and the presence of moons. Space telescopes like NASA's Kepler and TESS (Transiting Exoplanet Survey Satellite) have been instrumental in discovering thousands of -exoplanets using the transit method. Without the Earth's atmosphere in the way, -there is less interference and more precise measurements are possible. -The transit method continues to be a key tool in advancing our understanding of +exoplanets using the transit method. Without the Earth's atmosphere to hinder observations, +there is minimal interference, allowing for more precise measurements. +The transit method remains a key tool in furthering our comprehension of exoplanetary systems. For more information about transit method, you can visit `NASA Exoplanet Exploration Page `__. @@ -107,7 +107,7 @@ Page `__. The drawback of this method is that the angle between the planet's orbital plane and the direction of the observer's line of sight must be sufficiently small. Therefore, the chance of this phenomenon occurring is not -high. Thus more time and resources must be spent to detect and confirm +high. Thus, more time and resources must be allocated to detect and confirm the existence of an exoplanet. These resources include the Kepler telescope and ESA's CoRoT when they were still operational. @@ -202,8 +202,8 @@ datasets `__ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Given the low chance of detecting exoplanets, this dataset is very imbalanced. -Most samples are negative, i.e., there are very few exoplanets from the observed -light intensity data. If your model was to simply predict 'no exoplanet' for every sample, +Most samples are negative, meaning there are very few exoplanets from the observed +light intensity data. If your model were to simply predict 'no exoplanet' for every sample, then it would achieve very high accuracy. This indicates that accuracy is a poor metric for success. Let's first probe our data to gain insight into how imbalanced it is. @@ -245,7 +245,7 @@ To deal with the imbalance of our dataset, let's Synthetic Minority Over-Sampling Technique (SMOTE). SMOTE works by generating synthetic samples from the minority class to balance the distribution (typically implemented using the nearest neighbors -strategy). By implementing SMOTE, we attempt to reduce bias towards +strategy). By implementing SMOTE, we attempt to reduce bias toward stars without exoplanets (the majority class). .. code:: python @@ -368,9 +368,9 @@ After loading the data, let's see what our data looks like. The code block below follows the same syntax as with the `official snnTorch tutorial `__. -In contrast to other tutorials however, this model passes data across the entire sequence in parallel. -In that sense, it is more akin to how attention-based mechanisms take data. -Turning this into a more 'online' method would likely involve pre-processing to downsample the exceedingly long sequence length. +In contrast to other tutorials, however, this model concurrently processes data across the entire sequence. +In that sense, it is more akin to how attention-based mechanisms handle data. +Turning this into a more 'online' method would likely involve preprocessing to downsample the exceedingly long sequence length. .. 
code:: python

From 25f43892b07971e506f2ff6eaaf2a5cd15feda00 Mon Sep 17 00:00:00 2001
From: Gabriel
Date: Mon, 1 Apr 2024 22:31:22 -0300
Subject: [PATCH 17/17] fix tutorial 2

Add the batch dimension to the input current as discussed in #306

It seems the tutorial broke after a change in Torch: the input current must
now have a batch dimension even when that dimension is 1, so I added a batch
dimension to the input current in the tutorial 2 notebook and the tutorial 2
rst file.

Fix a typo in the tutorial 2 rst file, which said "spk_in" instead of
"cur_in".
---
 docs/tutorials/tutorial_2.rst | 24 ++++++++++++------------
 examples/tutorial_2_lif_neuron.ipynb | 24 ++++++++++++------------
 2 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/docs/tutorials/tutorial_2.rst b/docs/tutorials/tutorial_2.rst
index 36efc7f8..9ec3b7f2 100644
--- a/docs/tutorials/tutorial_2.rst
+++ b/docs/tutorials/tutorial_2.rst
@@ -307,7 +307,7 @@ The neuron model is now stored in ``lif1``. To use this neuron:

 **Inputs**

-* ``spk_in``: each element of :math:`I_{\rm in}` is sequentially passed as an input (0 for now)
+* ``cur_in``: each element of :math:`I_{\rm in}` is sequentially passed as an input (0 for now)
 * ``mem``: the membrane potential, previously :math:`U[t]`, is also passed as input. Initialize it arbitrarily as :math:`U[0] = 0.9~V`.

 **Outputs**
@@ -321,7 +321,7 @@ These all need to be of type ``torch.Tensor``.

     # Initialize membrane, input, and output
     mem = torch.ones(1) * 0.9 # U=0.9 at t=0
-    cur_in = torch.zeros(num_steps) # I=0 for all t
+    cur_in = torch.zeros(num_steps, 1) # I=0 for all t
     spk_out = torch.zeros(1) # initialize output spikes

 These values are only for the initial time step :math:`t=0`.
@@ -382,7 +382,7 @@ Let’s visualize what this looks like by triggering a current pulse of

 ::

     # Initialize input current pulse
-    cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.1), 0) # input current turns on at t=10
+    cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.1), 0) # input current turns on at t=10

     # Initialize membrane, output and recordings
     mem = torch.zeros(1) # membrane potential of 0 at t=0
@@ -430,7 +430,7 @@ Now what if the step input was clipped at :math:`t=30ms`?

 ::

     # Initialize current pulse, membrane and outputs
-    cur_in1 = torch.cat((torch.zeros(10), torch.ones(20)*(0.1), torch.zeros(170)), 0) # input turns on at t=10, off at t=30
+    cur_in1 = torch.cat((torch.zeros(10, 1), torch.ones(20, 1)*(0.1), torch.zeros(170, 1)), 0) # input turns on at t=10, off at t=30
     mem = torch.zeros(1)
     spk_out = torch.zeros(1)
     mem_rec1 = [mem]
@@ -462,7 +462,7 @@ time window must be decreased.

 ::

     # Increase amplitude of current pulse; half the time.
-    cur_in2 = torch.cat((torch.zeros(10), torch.ones(10)*0.111, torch.zeros(180)), 0) # input turns on at t=10, off at t=20
+    cur_in2 = torch.cat((torch.zeros(10, 1), torch.ones(10, 1)*0.111, torch.zeros(180, 1)), 0) # input turns on at t=10, off at t=20
     mem = torch.zeros(1)
     spk_out = torch.zeros(1)
     mem_rec2 = [mem]
@@ -487,7 +487,7 @@ amplitude:

 ::

     # Increase amplitude of current pulse; quarter the time.
- cur_in3 = torch.cat((torch.zeros(10), torch.ones(5)*0.147, torch.zeros(185)), 0) # input turns on at t=10, off at t=15 + cur_in3 = torch.cat((torch.zeros(10, 1), torch.ones(5, 1)*0.147, torch.zeros(185, 1)), 0) # input turns on at t=10, off at t=15 mem = torch.zeros(1) spk_out = torch.zeros(1) mem_rec3 = [mem] @@ -526,7 +526,7 @@ membrane potential will jump straight up in virtually zero rise time: :: # Current spike input - cur_in4 = torch.cat((torch.zeros(10), torch.ones(1)*0.5, torch.zeros(189)), 0) # input only on for 1 time step + cur_in4 = torch.cat((torch.zeros(10, 1), torch.ones(1, 1)*0.5, torch.zeros(189, 1)), 0) # input only on for 1 time step mem = torch.zeros(1) spk_out = torch.zeros(1) mem_rec4 = [mem] @@ -685,7 +685,7 @@ As before, all of that code is condensed by calling the built-in Lapicque neuron :: # Initialize inputs and outputs - cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.2), 0) + cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.2), 0) mem = torch.zeros(1) spk_out = torch.zeros(1) mem_rec = [mem] @@ -732,7 +732,7 @@ approaches the threshold :math:`U_{\rm thr}` faster: :: # Initialize inputs and outputs - cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) # increased current + cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0) # increased current mem = torch.zeros(1) spk_out = torch.zeros(1) mem_rec = [mem] @@ -766,7 +766,7 @@ rest of the code block is the exact same as above: lif3 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5) # Initialize inputs and outputs - cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) + cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0) mem = torch.zeros(1) spk_out = torch.zeros(1) mem_rec = [mem] @@ -806,7 +806,7 @@ generated input spikes. :: # Create a 1-D random spike train. Each element has a probability of 40% of firing. - spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40) + spk_in = spikegen.rate_conv(torch.ones((num_steps,1)) * 0.40) Run the following code block to see how many spikes have been generated. 
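
To make the effect of the corrected shapes easy to check, here is a minimal
standalone sketch (not part of the patch itself): it assumes snnTorch is
installed and simply reuses the ``Lapicque`` parameters and the new
``(num_steps, 1)`` current shape from the hunks above::

    # Hedged sanity check: drive a Lapicque neuron with the corrected
    # (num_steps, 1) input current and record the membrane potential.
    import torch
    import snntorch as snn

    num_steps = 200
    lif1 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3)

    # Step current with an explicit batch dimension, as in the patch
    cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1) * 0.1), 0)
    mem = torch.zeros(1)
    mem_rec = [mem]

    for step in range(num_steps):
        spk_out, mem = lif1(cur_in[step], mem)  # one Euler step per call
        mem_rec.append(mem)

    print(torch.stack(mem_rec[1:]).size())  # torch.Size([200, 1])

The stacked recording comes out as ``(200, 1)``, confirming that the explicit
batch dimension propagates through the neuron update.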
@@ -889,7 +889,7 @@ This can be explicitly overridden by passing the argument lif4 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5, reset_mechanism="zero") # Initialize inputs and outputs - spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40) + spk_in = spikegen.rate_conv(torch.ones((num_steps, 1)) * 0.40) mem = torch.ones(1)*0.5 spk_out = torch.zeros(1) mem_rec0 = [mem] diff --git a/examples/tutorial_2_lif_neuron.ipynb b/examples/tutorial_2_lif_neuron.ipynb index 00c46ace..0fb5250f 100644 --- a/examples/tutorial_2_lif_neuron.ipynb +++ b/examples/tutorial_2_lif_neuron.ipynb @@ -591,7 +591,7 @@ "source": [ "# Initialize membrane, input, and output\n", "mem = torch.ones(1) * 0.9 # U=0.9 at t=0\n", - "cur_in = torch.zeros(num_steps) # I=0 for all t \n", + "cur_in = torch.zeros(num_steps, 1) # I=0 for all t \n", "spk_out = torch.zeros(1) # initialize output spikes" ] }, @@ -688,7 +688,7 @@ "outputs": [], "source": [ "# Initialize input current pulse\n", - "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.1), 0) # input current turns on at t=10\n", + "cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.1), 0) # input current turns on at t=10\n", "\n", "# Initialize membrane, output and recordings\n", "mem = torch.zeros(1) # membrane potential of 0 at t=0\n", @@ -776,7 +776,7 @@ "outputs": [], "source": [ "# Initialize current pulse, membrane and outputs\n", - "cur_in1 = torch.cat((torch.zeros(10), torch.ones(20)*(0.1), torch.zeros(170)), 0) # input turns on at t=10, off at t=30\n", + "cur_in1 = torch.cat((torch.zeros(10, 1), torch.ones(20, 1)*(0.1), torch.zeros(170, 1)), 0) # input turns on at t=10, off at t=30\n", "mem = torch.zeros(1)\n", "spk_out = torch.zeros(1)\n", "mem_rec1 = [mem]" @@ -820,7 +820,7 @@ "outputs": [], "source": [ "# Increase amplitude of current pulse; half the time.\n", - "cur_in2 = torch.cat((torch.zeros(10), torch.ones(10)*0.111, torch.zeros(180)), 0) # input turns on at t=10, off at t=20\n", + "cur_in2 = torch.cat((torch.zeros(10, 1), torch.ones(10, 1)*0.111, torch.zeros(180, 1)), 0) # input turns on at t=10, off at t=20\n", "mem = torch.zeros(1)\n", "spk_out = torch.zeros(1)\n", "mem_rec2 = [mem]\n", @@ -853,7 +853,7 @@ "outputs": [], "source": [ "# Increase amplitude of current pulse; quarter the time.\n", - "cur_in3 = torch.cat((torch.zeros(10), torch.ones(5)*0.147, torch.zeros(185)), 0) # input turns on at t=10, off at t=15\n", + "cur_in3 = torch.cat((torch.zeros(10, 1), torch.ones(5, 1)*0.147, torch.zeros(185, 1)), 0) # input turns on at t=10, off at t=15\n", "mem = torch.zeros(1)\n", "spk_out = torch.zeros(1)\n", "mem_rec3 = [mem]\n", @@ -907,7 +907,7 @@ "outputs": [], "source": [ "# Current spike input\n", - "cur_in4 = torch.cat((torch.zeros(10), torch.ones(1)*0.5, torch.zeros(189)), 0) # input only on for 1 time step\n", + "cur_in4 = torch.cat((torch.zeros(10, 1), torch.ones(1, 1)*0.5, torch.zeros(189, 1)), 0) # input only on for 1 time step\n", "mem = torch.zeros(1) \n", "spk_out = torch.zeros(1)\n", "mem_rec4 = [mem]\n", @@ -1120,7 +1120,7 @@ "outputs": [], "source": [ "# Initialize inputs and outputs\n", - "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.2), 0)\n", + "cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.2), 0)\n", "mem = torch.zeros(1)\n", "spk_out = torch.zeros(1) \n", "mem_rec = [mem]\n", @@ -1180,7 +1180,7 @@ "outputs": [], "source": [ "# Initialize inputs and outputs\n", - "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) # increased current\n", + "cur_in = 
torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0) # increased current\n", "mem = torch.zeros(1)\n", "spk_out = torch.zeros(1) \n", "mem_rec = [mem]\n", @@ -1222,7 +1222,7 @@ "lif3 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5)\n", "\n", "# Initialize inputs and outputs\n", - "cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) \n", + "cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0) \n", "mem = torch.zeros(1)\n", "spk_out = torch.zeros(1) \n", "mem_rec = [mem]\n", @@ -1278,7 +1278,7 @@ "outputs": [], "source": [ "# Create a 1-D random spike train. Each element has a probability of 40% of firing.\n", - "spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)" + "spk_in = spikegen.rate_conv(torch.ones((num_steps,1)) * 0.40)" ] }, { @@ -1372,7 +1372,7 @@ "lif4 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5, reset_mechanism=\"zero\")\n", "\n", "# Initialize inputs and outputs\n", - "spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)\n", + "spk_in = spikegen.rate_conv(torch.ones((num_steps, 1)) * 0.40)\n", "mem = torch.ones(1)*0.5\n", "spk_out = torch.zeros(1)\n", "mem_rec0 = [mem]\n", @@ -1466,7 +1466,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.0" + "version": "3.11.8" }, "vscode": { "interpreter": {
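
The fullgraph tests added throughout this series all follow the same recipe,
which also works outside of pytest. The following is an illustrative sketch
rather than part of any patch: it assumes a snnTorch build that contains
these changes (0.8.x) and a PyTorch 2.x release where
``torch._dynamo.explain`` and ``torch.compile`` behave as shown::

    # Illustrative sketch: verify that the reworked SLSTM traces without
    # graph breaks, then compile it end-to-end.
    import torch
    import torch._dynamo as dynamo
    import snntorch as snn

    slstm = snn.SLSTM(input_size=1, hidden_size=2, surrogate_disable=True)
    x = torch.rand(1, 1)  # (batch, input_size)

    explanation = dynamo.explain(slstm)(x)
    print(explanation.graph_break_count)  # 0 means fullgraph-compatible

    slstm.reset_mem()  # replaces the deprecated init_slstm() initializer
    compiled = torch.compile(slstm, fullgraph=True)
    spk, syn, mem = compiled(x)

A graph-break count of zero is exactly what each new
``test_*_compile_fullgraph`` test asserts, and ``reset_mem()`` is the
replacement for the removed ``_SpikeTensor``-based initializers.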