Commit: tests
stevenabreu7 committed Feb 5, 2024
1 parent 11684cf commit d9df668
Showing 2 changed files with 88 additions and 73 deletions.
22 changes: 14 additions & 8 deletions snntorch/import_nir.py
@@ -322,22 +322,28 @@ def _nir_to_snntorch_module(
             reset_mechanism='zero',
             init_hidden=init_hidden,
             all_to_all=not diagonal,
-            linear_features=lif_size,
-            V=V,
+            linear_features=lif_size if not diagonal else None,
+            V=V if diagonal else None,
             reset_delay=False,
         )
 
-        rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight)
-        if isinstance(wrec_node, nir.Affine):
-            rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias)
-        else:
-            rsynaptic.recurrent.bias.data = torch.zeros_like(rsynaptic.recurrent.bias)
+        if isinstance(rsynaptic.recurrent, torch.nn.Linear):
+            rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight)
+            if isinstance(wrec_node, nir.Affine):
+                rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias)
+            else:
+                rsynaptic.recurrent.bias.data = torch.zeros_like(rsynaptic.recurrent.bias)
+        else:
+            rsynaptic.recurrent.V.data = torch.diagonal(torch.Tensor(wrec_node.weight))
 
         return rsynaptic
 
     elif node is None:
         return torch.nn.Identity()
 
     else:
         print('[WARNING] could not parse node of type:', node.__class__.__name__)
 
-        return None
+    return None
 
 
 def import_from_nir(graph: nir.NIRGraph) -> torch.nn.Module:
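The change above distinguishes two recurrence layouts when rebuilding an RSynaptic layer from NIR: a dense recurrent weight matrix maps to all_to_all=True, with the recurrent weights stored in a torch.nn.Linear (hence the new isinstance check), while a diagonal matrix maps to all_to_all=False with only the diagonal kept in V. A rough sketch of the two constructions, using only keyword arguments that appear in this diff and in the tests below (the feature size of 500 is illustrative, not taken from the importer):

import torch
import snntorch as snn

# Dense (all-to-all) recurrence: recurrent weights sit in a torch.nn.Linear,
# which the new isinstance(rsynaptic.recurrent, torch.nn.Linear) branch copies into.
rsyn_dense = snn.RSynaptic(
    alpha=0.5, beta=0.9, init_hidden=True,
    all_to_all=True, linear_features=500, V=None,
)

# Diagonal (one-to-one) recurrence: only per-neuron gains are kept, so the importer
# copies torch.diagonal(weight) into rsynaptic.recurrent.V instead.
rsyn_diag = snn.RSynaptic(
    alpha=0.5, beta=0.9, init_hidden=True,
    all_to_all=False, V=torch.ones(500), linear_features=None,
)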
139 changes: 74 additions & 65 deletions tests/test_nir.py
@@ -2,99 +2,108 @@

"""Tests for NIR import and export."""

import nir
import pytest
import snntorch as snn
import torch


@pytest.fixture(scope="module")
def snntorch_sequential():
lif1 = snn.Leaky(beta=0.9, init_hidden=False)
lif2 = snn.Leaky(beta=0.9, init_hidden=False, output=True)

return torch.nn.Sequential(
torch.nn.Flatten(),
torch.nn.Linear(784, 500),
lif1,
torch.nn.Linear(500, 10),
lif2
)
def sample_data():
return torch.ones((4, 784))


@pytest.fixture(scope="module")
def snntorch_sequential_hidden():
def snntorch_sequential():
lif1 = snn.Leaky(beta=0.9, init_hidden=True)
lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)

return torch.nn.Sequential(
torch.nn.Flatten(),
torch.nn.Linear(784, 500),
lif1,
torch.nn.Linear(500, 10),
lif2
lif2,
)


@pytest.fixture(scope="module")
def snntorch_recurrent():
lif1 = snn.RLeaky(beta=0.9, V=1, all_to_all=True, init_hidden=False)
lif2 = snn.Leaky(beta=0.9, init_hidden=False, output=True)

return torch.nn.Sequential(
torch.nn.Flatten(),
torch.nn.Linear(784, 500),
lif1,
torch.nn.Linear(500, 10),
lif2
v = torch.ones((500,))
lif1 = snn.RSynaptic(
alpha=0.5, beta=0.9, V=v, all_to_all=False, init_hidden=True
)


@pytest.fixture(scope="module")
def snntorch_recurrent_hidden():
lif1 = snn.RLeaky(beta=0.9, V=1, all_to_all=True, init_hidden=True)
lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)

return torch.nn.Sequential(
torch.nn.Flatten(),
torch.nn.Linear(784, 500),
lif1,
torch.nn.Linear(500, 10),
lif2
lif2,
)


class NIRTestExport:
"""Test exporting snnTorch network to NIR."""
def test_export_sequential(snntorch_sequential):
pass

def test_export_sequential_hidden(snntorch_sequential_hidden):
pass

def test_export_recurrent(snntorch_recurrent):
pass

def test_export_recurrent_hidden(snntorch_recurrent_hidden):
pass


class NIRTestImport:
"""Test importing NIR graph to snnTorch."""
def test_import_nir():
# load a NIR graph from a file?
pass


class NIRTestCommute:
"""Test that snnTorch -> NIR -> snnTorch doesn't change the network."""
def test_commute_sequential(snntorch_sequential):
pass

def test_commute_sequential_hidden(snntorch_sequential_hidden):
pass

def test_commute_recurrent(snntorch_recurrent):
pass

def test_commute_recurrent_hidden(snntorch_recurrent_hidden):
pass
class TestNIR:
"""Test import and export from snnTorch to NIR."""

def test_export_sequential(self, snntorch_sequential, sample_data):
nir_graph = snn.export_to_nir(snntorch_sequential, sample_data)
assert nir_graph is not None
assert set(nir_graph.nodes.keys()) == set(
["input", "output"] + [str(i) for i in range(4)]
), nir_graph.nodes.keys()
assert set(nir_graph.edges) == set(
[
("3", "output"),
("input", "0"),
("2", "3"),
("1", "2"),
("0", "1"),
]
)
assert isinstance(nir_graph.nodes["input"], nir.Input)
assert isinstance(nir_graph.nodes["output"], nir.Output)
assert isinstance(nir_graph.nodes["0"], nir.Affine)
assert isinstance(nir_graph.nodes["1"], nir.LIF)
assert isinstance(nir_graph.nodes["2"], nir.Affine)
assert isinstance(nir_graph.nodes["3"], nir.LIF)

def test_export_recurrent(self, snntorch_recurrent, sample_data):
nir_graph = snn.export_to_nir(snntorch_recurrent, sample_data)
assert nir_graph is not None
assert set(nir_graph.nodes.keys()) == set(
["input", "output", "0", "1.lif", "1.w_rec", "2", "3"]
), nir_graph.nodes.keys()
assert isinstance(nir_graph.nodes["input"], nir.Input)
assert isinstance(nir_graph.nodes["output"], nir.Output)
assert isinstance(nir_graph.nodes["0"], nir.Affine)
assert isinstance(nir_graph.nodes["1.lif"], nir.CubaLIF)
assert isinstance(nir_graph.nodes["1.w_rec"], nir.Linear)
assert isinstance(nir_graph.nodes["2"], nir.Affine)
assert isinstance(nir_graph.nodes["3"], nir.LIF)
assert set(nir_graph.edges) == set(
[
("1.lif", "1.w_rec"),
("1.w_rec", "1.lif"),
("0", "1.lif"),
("3", "output"),
("2", "3"),
("input", "0"),
("1.lif", "2"),
]
)

def test_import_nir(self):
graph = nir.read("tests/lif.nir")
net = snn.import_from_nir(graph)
out, _ = net(torch.ones(1, 1))
assert out.shape == (1, 1), out.shape

def test_commute_sequential(self, snntorch_sequential, sample_data):
x = torch.rand((4, 784))
y_snn, state = snntorch_sequential(x)
assert y_snn.shape == (4, 10)
nir_graph = snn.export_to_nir(snntorch_sequential, sample_data)
net = snn.import_from_nir(nir_graph)
y_nir, state = net(x)
assert y_nir.shape == (4, 10), y_nir.shape
assert torch.allclose(y_snn, y_nir)
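For a quick check outside pytest, the round trip exercised by test_commute_sequential can be reproduced as a standalone script; this sketch simply inlines the snntorch_sequential and sample_data fixtures and uses no calls beyond those in the tests:

import snntorch as snn
import torch

# Inlined snntorch_sequential fixture.
net = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 500),
    snn.Leaky(beta=0.9, init_hidden=True),
    torch.nn.Linear(500, 10),
    snn.Leaky(beta=0.9, init_hidden=True, output=True),
)

x = torch.rand((4, 784))
y_snn, _ = net(x)  # output of the original snnTorch network

# Export to NIR (torch.ones((4, 784)) plays the role of the sample_data fixture),
# import the graph back, and compare outputs, mirroring the allclose assertion above.
nir_graph = snn.export_to_nir(net, torch.ones((4, 784)))
net2 = snn.import_from_nir(nir_graph)
y_nir, _ = net2(x)

print(torch.allclose(y_snn, y_nir))  # True, per test_commute_sequential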
