-
Notifications
You must be signed in to change notification settings - Fork 223
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' of https://github.com/jeshraghian/snntorch
- Loading branch information
Showing
15 changed files
with
133 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
# fmt: off | ||
__version__ = '0.6.4' | ||
__version__ = '0.7.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
|
||
|
||
# fmt: off | ||
__version__ = '0.6.4' | ||
__version__ = '0.7.0' | ||
# fmt: on | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,6 @@ h5py>=3.0.0 | |
matplotlib | ||
celluloid | ||
numpy>=1.17 | ||
tqdm | ||
tqdm | ||
nir | ||
nirtorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
snntorch.export | ||
------------------------ | ||
:mod:`snntorch.export` is a module that enables cross-compatibility with other SNN libraries by converting snntorch models to a `Neuromorphic Intermediate Representation (NIR) <https://nnir.readthedocs.io/en/latest/>`_ | ||
|
||
.. automodule:: snntorch.export | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
[bumpversion] | ||
current_version = 0.6.4 | ||
current_version = 0.7.0 | ||
commit = True | ||
tag = True | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
# fmt: off | ||
__version__ = '0.6.4' | ||
__version__ = '0.7.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from typing import Union, Optional | ||
from numbers import Number | ||
|
||
import torch | ||
import nir | ||
from nirtorch import extract_nir_graph | ||
|
||
from snntorch import Leaky, Synaptic | ||
|
||
# eqn is assumed to be: v_t+1 = (1-1/tau)*v_t + 1/tau * v_leak + I_in / C | ||
def _extract_snntorch_module(module:torch.nn.Module) -> Optional[nir.NIRNode]: | ||
if isinstance(module, Leaky): | ||
return nir.LIF( | ||
tau = 1 / (1 - module.beta).detach(), | ||
v_threshold = module.threshold.detach(), | ||
v_leak = torch.zeros_like(module.beta), | ||
r = module.beta.detach(), | ||
) | ||
|
||
if isinstance(module, Synaptic): | ||
return nir.CubaLIF( | ||
tau_syn = 1 / (1 - module.beta).detach(), | ||
tau_mem = 1 / (1 - module.alpha).detach(), | ||
v_threshold = module.threshold.detach(), | ||
v_leak = torch.zeros_like(module.beta), | ||
r = module.beta.detach(), | ||
) | ||
|
||
elif isinstance(module, torch.nn.Linear): | ||
if module.bias is None: # Add zero bias if none is present | ||
return nir.Affine( | ||
module.weight.detach(), torch.zeros(*module.weight.shape[:-1]) | ||
) | ||
else: | ||
return nir.Affine(module.weight.detach(), module.bias.detach()) | ||
|
||
return None | ||
|
||
|
||
def to_nir( | ||
module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch" | ||
) -> nir.NIRNode: | ||
"""Convert an snnTorch model to the Neuromorphic Intermediate Representation (NIR). | ||
Example:: | ||
import torch, torch.nn as nn | ||
import snntorch as snn | ||
from snntorch import export | ||
data_path = "untrained-snntorch.pt" | ||
net = nn.Sequential(nn.Linear(784, 128), | ||
snn.Leaky(beta=0.8, init_hidden=True), | ||
nn.Linear(128, 10), | ||
snn.Leaky(beta=0.8, init_hidden=True, output=True)) | ||
# save model in pt format | ||
torch.save(net.state_dict(), data_path) | ||
# load model (does nothing here, but shown for completeness) | ||
net.load_state_dict(torch.load(data_path)) | ||
# generate input tensor to dynamically construct graph | ||
x = torch.zeros(784) | ||
# generate NIR graph | ||
nir_net = export.to_nir(net, x) | ||
:param module: a saved snnTorch model as a parameter dictionary | ||
:type module: torch.nn.Module | ||
:param sample_data: sample input data to the model | ||
:type sample_data: torch.Tensor | ||
:param model_name: name of library used to train model, default: "snntorch" | ||
:type model_name: str, optional | ||
:return: NIR computational graph where torch modules are represented as NIR nodes | ||
:rtype: NIRGraph | ||
""" | ||
return extract_nir_graph( | ||
module, _extract_snntorch_module, sample_data, model_name=model_name | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -249,3 +249,4 @@ def _final_layer_check(net): | |
return 4 | ||
else: # if not from snn, assume from nn with 1 return | ||
return 1 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters