Merge branch 'master' of https://github.com/jeshraghian/snntorch
jeshraghian committed Aug 6, 2023
commit efc7bb9 (2 parents: 570239c + 1aea473)
Showing 15 changed files with 133 additions and 12 deletions.
4 changes: 4 additions & 0 deletions README.rst
@@ -54,6 +54,8 @@ snnTorch contains the following components:
- Description
* - `snntorch <https://snntorch.readthedocs.io/en/latest/snntorch.html>`_
- a spiking neuron library like torch.nn, deeply integrated with autograd
* - `snntorch.export <https://snntorch.readthedocs.io/en/latest/snntorch.export.html>`_
- enables cross-compatibility with other SNN libraries via `NIR <https://nnir.readthedocs.io/en/latest/>`_
* - `snntorch.functional <https://snntorch.readthedocs.io/en/latest/snntorch.functional.html>`_
- common arithmetic operations on spikes, e.g., loss, regularization, etc.
* - `snntorch.spikegen <https://snntorch.readthedocs.io/en/latest/snntorch.spikegen.html>`_
@@ -104,6 +106,8 @@ The following packages need to be installed to use snnTorch:
* pandas
* matplotlib
* math
* nir
* nirtorch

They are automatically installed if snnTorch is installed using the pip command. Ensure the correct version of torch is installed for your system to enable CUDA compatibility.
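As a minimal sanity check of the install flow described above (a hedged sketch; the exact CUDA-enabled torch wheel varies by platform, so consult pytorch.org for your system's command):

# pip install snntorch   # also pulls in torch, pandas, matplotlib, nir, nirtorch
import torch
import snntorch as snn  # confirms the package imports cleanly

# Confirm a CUDA-capable torch build is present if GPU support is needed.
print(torch.cuda.is_available())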

2 changes: 1 addition & 1 deletion _version.py
@@ -1,2 +1,2 @@
# fmt: off
- __version__ = '0.6.4'
+ __version__ = '0.7.0'
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -6,7 +6,7 @@


# fmt: off
- __version__ = '0.6.4'
+ __version__ = '0.7.0'
# fmt: on


2 changes: 1 addition & 1 deletion docs/index.rst
@@ -11,7 +11,7 @@ snnTorch Documentation
readme
installation
snntorch
- snntorch.backprop
+ snntorch.export
snntorch.functional
snntorch.spikegen
snntorch.spikeplot
4 changes: 3 additions & 1 deletion docs/requirements.txt
@@ -5,4 +5,6 @@ h5py>=3.0.0
matplotlib
celluloid
numpy>=1.17
- tqdm
+ tqdm
+ nir
+ nirtorch
8 changes: 8 additions & 0 deletions docs/snntorch.export.rst
@@ -0,0 +1,8 @@
snntorch.export
------------------------
:mod:`snntorch.export` is a module that enables cross-compatibility with other SNN libraries by converting snnTorch models to the `Neuromorphic Intermediate Representation (NIR) <https://nnir.readthedocs.io/en/latest/>`_.

.. automodule:: snntorch.export
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion examples/tutorial_regression_2.ipynb
@@ -455,7 +455,7 @@
"outputs": [],
"source": [
"batch_size = 128\n",
"data_path='/data/mnist'\n",
"data_path='data/mnist'\n",
"\n",
"# Define a transform\n",
"transform = transforms.Compose([\n",
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[bumpversion]
- current_version = 0.6.4
+ current_version = 0.7.0
commit = True
tag = True

4 changes: 3 additions & 1 deletion setup.py
@@ -11,14 +11,16 @@
# history = history_file.read()

# fmt: off
- __version__ = '0.6.4'
+ __version__ = '0.7.0'
# fmt: on

requirements = [
"torch>=1.1.0",
"pandas",
"matplotlib",
"numpy>=1.17",
"nir",
"nirtorch",
]


10 changes: 7 additions & 3 deletions snntorch/_neurons/leaky.py
@@ -1,5 +1,5 @@
from .neurons import _SpikeTensor, _SpikeTorchConv, LIF

import torch

class Leaky(LIF):
"""
@@ -73,8 +73,8 @@ def forward(self, x, mem1, spk1, mem2):
optional
:param surrogate_disable: Disables surrogate gradients regardless of
- `spike_grad` argument. Useful for ONNX compatibility. Defaults
- to False
+ `spike_grad` argument. Useful for ONNX compatibility. Defaults
+ to False
:type surrogate_disable: bool, Optional
:param init_hidden: Instantiates state variables as instance variables.
@@ -139,6 +139,8 @@ def __init__(
reset_mechanism="subtract",
state_quant=False,
output=False,
graded_spikes_factor=1.0,
learn_graded_spikes_factor=False,
):
super(Leaky, self).__init__(
beta,
@@ -152,6 +154,8 @@ def __init__(
reset_mechanism,
state_quant,
output,
graded_spikes_factor,
learn_graded_spikes_factor,
)

if self.init_hidden:
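A hedged sketch of the two constructor arguments added above; the scaling semantics are assumed from the parameter names, and the requires_grad behavior matches the new test in tests/test_snntorch/test_leaky.py further below:

import snntorch as snn

lif = snn.Leaky(
    beta=0.5,
    graded_spikes_factor=2.0,          # assumed: scales the amplitude of emitted spikes
    learn_graded_spikes_factor=True,   # registers the factor as a learnable parameter
)

assert lif.graded_spikes_factor.requires_grad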
2 changes: 1 addition & 1 deletion snntorch/_neurons/synaptic.py
@@ -49,7 +49,7 @@ class Synaptic(LIF):
# Define Network
class Net(nn.Module):
- def __init__(self):
+ def __init__(self, num_inputs, num_hidden, num_outputs):
super().__init__()
# initialize layers
2 changes: 1 addition & 1 deletion snntorch/_version.py
@@ -1,2 +1,2 @@
# fmt: off
- __version__ = '0.6.4'
+ __version__ = '0.7.0'
86 changes: 86 additions & 0 deletions snntorch/export.py
@@ -0,0 +1,86 @@
from typing import Union, Optional
from numbers import Number

import torch
import nir
from nirtorch import extract_nir_graph

from snntorch import Leaky, Synaptic

# eqn is assumed to be: v_{t+1} = (1 - 1/tau) * v_t + (1/tau) * v_leak + I_in / C
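# so beta = 1 - 1/tau, i.e. tau = 1 / (1 - beta); e.g. beta = 0.9 maps to tau = 10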
def _extract_snntorch_module(module:torch.nn.Module) -> Optional[nir.NIRNode]:
if isinstance(module, Leaky):
return nir.LIF(
tau = 1 / (1 - module.beta).detach(),
v_threshold = module.threshold.detach(),
v_leak = torch.zeros_like(module.beta),
r = module.beta.detach(),
)

if isinstance(module, Synaptic):
return nir.CubaLIF(
tau_syn = 1 / (1 - module.beta).detach(),
tau_mem = 1 / (1 - module.alpha).detach(),
v_threshold = module.threshold.detach(),
v_leak = torch.zeros_like(module.beta),
r = module.beta.detach(),
)

elif isinstance(module, torch.nn.Linear):
if module.bias is None: # Add zero bias if none is present
return nir.Affine(
module.weight.detach(), torch.zeros(*module.weight.shape[:-1])
)
else:
return nir.Affine(module.weight.detach(), module.bias.detach())

return None


def to_nir(
module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch"
) -> nir.NIRNode:
"""Convert an snnTorch model to the Neuromorphic Intermediate Representation (NIR).
Example::
import torch, torch.nn as nn
import snntorch as snn
from snntorch import export
data_path = "untrained-snntorch.pt"
net = nn.Sequential(nn.Linear(784, 128),
snn.Leaky(beta=0.8, init_hidden=True),
nn.Linear(128, 10),
snn.Leaky(beta=0.8, init_hidden=True, output=True))
# save model in pt format
torch.save(net.state_dict(), data_path)
# load model (does nothing here, but shown for completeness)
net.load_state_dict(torch.load(data_path))
# generate input tensor to dynamically construct graph
x = torch.zeros(784)
# generate NIR graph
nir_net = export.to_nir(net, x)
:param module: the snnTorch model (a ``torch.nn.Module``) to convert
:type module: torch.nn.Module
:param sample_data: sample input data to the model
:type sample_data: torch.Tensor
:param model_name: name of library used to train model, default: "snntorch"
:type model_name: str, optional
:return: NIR computational graph where torch modules are represented as NIR nodes
:rtype: NIRGraph
"""
return extract_nir_graph(
module, _extract_snntorch_module, sample_data, model_name=model_name
)
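An end-to-end sketch of the new export path, mirroring the docstring example above (hedged: nir.write is assumed to follow the NIR reference API for serializing a graph to disk):

import torch
import torch.nn as nn
import snntorch as snn
from snntorch import export
import nir

net = nn.Sequential(
    nn.Linear(784, 128),
    snn.Leaky(beta=0.8, init_hidden=True),
    nn.Linear(128, 10),
    snn.Leaky(beta=0.8, init_hidden=True, output=True),
)

# A sample input lets nirtorch trace the module graph dynamically.
x = torch.zeros(784)
nir_graph = export.to_nir(net, x)

# Assumed NIR serialization call; check the NIR docs if the signature differs.
nir.write("net.nir", nir_graph)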
1 change: 1 addition & 0 deletions snntorch/utils.py
@@ -249,3 +249,4 @@ def _final_layer_check(net):
return 4
else: # if not from snn, assume from nn with 1 return
return 1

14 changes: 14 additions & 0 deletions tests/test_snntorch/test_leaky.py
@@ -42,6 +42,13 @@ def leaky_hidden_reset_none_instance():
return snn.Leaky(beta=0.5, init_hidden=True, reset_mechanism="none")


@pytest.fixture(scope="module")
def leaky_hidden_learn_graded_instance():
return snn.Leaky(
beta=0.5, init_hidden=True, learn_graded_spikes_factor=True
)


class TestLeaky:
def test_leaky(self, leaky_instance, input_):
mem = leaky_instance.init_leaky()
@@ -117,3 +124,10 @@ def test_leaky_init_hidden_reset_none(
def test_leaky_cases(self, leaky_hidden_instance, input_):
with pytest.raises(TypeError):
leaky_hidden_instance(input_, input_)

def test_leaky_hidden_learn_graded_instance(
self, leaky_hidden_learn_graded_instance
):
factor = leaky_hidden_learn_graded_instance.graded_spikes_factor

assert factor.requires_grad
