diff --git a/snntorch/__init__.py b/snntorch/__init__.py index ccea82c0..0482a113 100644 --- a/snntorch/__init__.py +++ b/snntorch/__init__.py @@ -1,3 +1,5 @@ from ._version import __version__ from ._neurons import * from ._layers import * +from .export_nir import export_to_nir +from .import_nir import import_from_nir \ No newline at end of file diff --git a/snntorch/_neurons/leaky.py b/snntorch/_neurons/leaky.py index 00380fcb..7406fa82 100644 --- a/snntorch/_neurons/leaky.py +++ b/snntorch/_neurons/leaky.py @@ -106,6 +106,15 @@ def forward(self, x, mem1, spk1, mem2): returned when neuron is called. Defaults to False :type output: bool, optional + :param graded_spikes_factor: output spikes are scaled by this value, if specified. Defaults to 1.0 + :type graded_spikes_factor: float or torch.tensor + + :param learn_graded_spikes_factor: Option to enable learnable graded spikes. Defaults to False + :type learn_graded_spikes_factor: bool, optional + + :param reset_delay: If `True`, a spike is returned with a one-step delay after the threshold is reached. + Defaults to True + :type reset_delay: bool, optional Inputs: \\input_, mem_0 - **input_** of shape `(batch, input_size)`: tensor containing input @@ -142,6 +151,7 @@ def __init__( output=False, graded_spikes_factor=1.0, learn_graded_spikes_factor=False, + reset_delay=True, ): super().__init__( beta, @@ -169,6 +179,11 @@ def __init__( elif self.reset_mechanism_val == 2: # no reset, pure integration self.state_function = self._base_int + self.reset_delay = reset_delay + + if not self.reset_delay and self.init_hidden: + raise NotImplementedError("`reset_delay=False` is only supported for `init_hidden=False`") + def _init_mem(self): mem = torch.zeros(1) @@ -178,17 +193,18 @@ def reset_mem(self): self.mem = torch.zeros_like(self.mem, device=self.mem.device) def init_leaky(self): - """Deprecated, please use :class:`Leaky.reset_mem` instead""" + """Deprecated, use :class:`Leaky.reset_mem` instead""" self.reset_mem() return self.mem - + def forward(self, input_, mem=None): + if not mem == None: self.mem = mem if self.init_hidden and not mem == None: raise TypeError( - "mem should not be passed as an argument while `init_hidden=True`" + "`mem` should not be passed as an argument while `init_hidden=True`" ) if not self.mem.shape == input_.shape: @@ -201,9 +217,16 @@ def forward(self, input_, mem=None): self.mem = self.state_quant(self.mem) if self.inhibition: - spk = self.fire_inhibition(self.mem.size(0), self.mem) + spk = self.fire_inhibition(self.mem.size(0), self.mem) # batch_size else: spk = self.fire(self.mem) + + if not self.reset_delay: + do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + self.mem = self.mem - do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + self.mem = self.mem - do_reset * self.mem if self.output: return spk, self.mem diff --git a/snntorch/_neurons/rleaky.py b/snntorch/_neurons/rleaky.py index 55de0788..48383ada 100644 --- a/snntorch/_neurons/rleaky.py +++ b/snntorch/_neurons/rleaky.py @@ -241,6 +241,7 @@ def __init__( reset_mechanism="subtract", state_quant=False, output=False, + reset_delay=True, ): super().__init__( beta, @@ -279,6 +280,11 @@ def __init__( if not learn_recurrent: self._disable_recurrent_grad() + self.reset_delay = reset_delay + + if not self.reset_delay and self.init_hidden: + raise NotImplementedError('`reset_delay=False` is only supported for `init_hidden=False`') + if self.init_hidden: self.spk, 
self.mem = self.init_rleaky() # self.state_fn = self._build_state_function_hidden @@ -312,6 +318,13 @@ def forward(self, input_, spk=False, mem=False): else: spk = self.fire(mem) + if not self.reset_delay: + do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + mem = mem - do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + mem = mem - do_reset * mem + return spk, mem # intended for truncated-BPTT where instance variables are hidden diff --git a/snntorch/_neurons/rsynaptic.py b/snntorch/_neurons/rsynaptic.py index 31de3510..654fd3d2 100644 --- a/snntorch/_neurons/rsynaptic.py +++ b/snntorch/_neurons/rsynaptic.py @@ -254,6 +254,7 @@ def __init__( reset_mechanism="subtract", state_quant=False, output=False, + reset_delay=True, ): super().__init__( beta, @@ -294,6 +295,11 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) + self.reset_delay = reset_delay + + if not reset_delay and self.init_hidden: + raise NotImplementedError('`reset_delay=False` is only supported for `init_hidden=False`') + if self.init_hidden: self.spk, self.syn, self.mem = self.init_rsynaptic() @@ -324,6 +330,15 @@ def forward(self, input_, spk=False, syn=False, mem=False): else: spk = self.fire(mem) + if not self.reset_delay: + # reset membrane potential _right_ after spike + do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + mem = mem - do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + # mem -= do_reset * mem + mem = mem - do_reset * mem + return spk, syn, mem # intended for truncated-BPTT where instance variables are hidden diff --git a/snntorch/_neurons/synaptic.py b/snntorch/_neurons/synaptic.py index be1d22fe..4e3be032 100644 --- a/snntorch/_neurons/synaptic.py +++ b/snntorch/_neurons/synaptic.py @@ -168,6 +168,7 @@ def __init__( reset_mechanism="subtract", state_quant=False, output=False, + reset_delay=True, ): super().__init__( beta, @@ -185,6 +186,11 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) + self.reset_delay = reset_delay + + if not reset_delay and self.init_hidden: + raise NotImplementedError('`reset_delay=False` is only supported for `init_hidden=False`') + if self.init_hidden: self.syn, self.mem = self.init_synaptic() @@ -214,6 +220,14 @@ def forward(self, input_, syn=False, mem=False): else: spk = self.fire(mem) + if not self.reset_delay: + # reset membrane potential _right_ after spike + do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + mem = mem - do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + mem = mem - do_reset * mem + return spk, syn, mem # intended for truncated-BPTT where instance variables are diff --git a/snntorch/export.py b/snntorch/export.py deleted file mode 100644 index b9cb94ce..00000000 --- a/snntorch/export.py +++ /dev/null @@ -1,86 +0,0 @@ -from typing import Union, Optional -from numbers import Number - -import torch -import nir -from nirtorch import extract_nir_graph - -from snntorch import Leaky, Synaptic - -# eqn is assumed to be: v_t+1 = (1-1/tau)*v_t + 1/tau * v_leak + I_in / C -def _extract_snntorch_module(module:torch.nn.Module) -> Optional[nir.NIRNode]: - if isinstance(module, Leaky): - return nir.LIF( - tau = 1 / (1 - module.beta).detach(), - v_threshold = module.threshold.detach(), - v_leak = 
torch.zeros_like(module.beta), - r = module.beta.detach(), - ) - - if isinstance(module, Synaptic): - return nir.CubaLIF( - tau_syn = 1 / (1 - module.beta).detach(), - tau_mem = 1 / (1 - module.alpha).detach(), - v_threshold = module.threshold.detach(), - v_leak = torch.zeros_like(module.beta), - r = module.beta.detach(), - ) - - elif isinstance(module, torch.nn.Linear): - if module.bias is None: # Add zero bias if none is present - return nir.Affine( - module.weight.detach(), torch.zeros(*module.weight.shape[:-1]) - ) - else: - return nir.Affine(module.weight.detach(), module.bias.detach()) - - return None - - -def to_nir( - module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch" -) -> nir.NIRNode: - """Convert an snnTorch model to the Neuromorphic Intermediate Representation (NIR). - - Example:: - - import torch, torch.nn as nn - import snntorch as snn - from snntorch import export - - data_path = "untrained-snntorch.pt" - - net = nn.Sequential(nn.Linear(784, 128), - snn.Leaky(beta=0.8, init_hidden=True), - nn.Linear(128, 10), - snn.Leaky(beta=0.8, init_hidden=True, output=True)) - - # save model in pt format - torch.save(net.state_dict(), data_path) - - # load model (does nothing here, but shown for completeness) - net.load_state_dict(torch.load(data_path)) - - # generate input tensor to dynamically construct graph - x = torch.zeros(784) - - # generate NIR graph - nir_net = export.to_nir(net, x) - - - :param module: a saved snnTorch model as a parameter dictionary - :type module: torch.nn.Module - - :param sample_data: sample input data to the model - :type sample_data: torch.Tensor - - :param model_name: name of library used to train model, default: "snntorch" - :type model_name: str, optional - - :return: NIR computational graph where torch modules are represented as NIR nodes - :rtype: NIRGraph - - """ - return extract_nir_graph( - module, _extract_snntorch_module, sample_data, model_name=model_name - ) diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py new file mode 100644 index 00000000..4f0a1614 --- /dev/null +++ b/snntorch/export_nir.py @@ -0,0 +1,185 @@ +from typing import Optional +import torch +import nir +import numpy as np +import nirtorch +import snntorch as snn + + +def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: + """Convert a single snnTorch module to the equivalent object in the Neuromorphic + Intermediate Representation (NIR). This function is used internally by the export_to_nir + function to convert each submodule/layer of the network to the NIR. + + Currently supported snnTorch modules: Leaky, Linear, Synaptic, RLeaky, RSynaptic. + + Note that recurrent layers such as RLeaky and RSynaptic will be converted to a NIR graph, + which will then be embedded as a subgraph into the main NIR graph. 
+ + :param module: snnTorch module + :type module: torch.nn.Module + + :return: return the NIR node + :rtype: Optional[nir.NIRNode] + """ + if isinstance(module, snn.Leaky): + dt = 1e-4 + + beta = module.beta.detach().numpy() + vthr = module.threshold.detach().numpy() + tau_mem = dt / (1 - beta) + r = tau_mem / dt + v_leak = np.zeros_like(beta) + + return nir.LIF( + tau=tau_mem, + v_threshold=vthr, + v_leak=v_leak, + r=r, + ) + + elif isinstance(module, torch.nn.Linear): + if module.bias is None: + return nir.Linear( + weight=module.weight.data.detach().numpy() + ) + else: + return nir.Affine( + weight=module.weight.data.detach().numpy(), + bias=module.bias.data.detach().numpy() + ) + + elif isinstance(module, snn.Synaptic): + dt = 1e-4 + + # TODO: assert that size of the current layer is correct + alpha = module.alpha.detach().numpy() + beta = module.beta.detach().numpy() + vthr = module.threshold.detach().numpy() + + tau_syn = dt / (1 - alpha) + tau_mem = dt / (1 - beta) + r = tau_mem / dt + v_leak = np.zeros_like(beta) + w_in = tau_syn / dt + + return nir.CubaLIF( + tau_syn=tau_syn, + tau_mem=tau_mem, + v_threshold=vthr, + v_leak=v_leak, + r=r, + w_in=w_in, + ) + + elif isinstance(module, snn.RLeaky): + # TODO(stevenabreu7): implement RLeaky + raise NotImplementedError('RLeaky not supported') + + elif isinstance(module, snn.RSynaptic): + if module.all_to_all: + w_rec = _extract_snntorch_module(module.recurrent) + n_neurons = w_rec.weight.shape[0] + else: + if len(module.recurrent.V.shape) == 0: + # TODO: handle this better - if V is a scalar, then the weight has wrong shape + raise ValueError('V must be a vector, cannot infer layer size for scalar V') + n_neurons = module.recurrent.V.shape[0] + w = np.diag(module.recurrent.V.data.detach().numpy()) + w_rec = nir.Linear(weight=w) + + dt = 1e-4 + + alpha = module.alpha.detach().numpy() + beta = module.beta.detach().numpy() + vthr = module.threshold.detach().numpy() + alpha = np.ones(n_neurons) * alpha + beta = np.ones(n_neurons) * beta + vthr = np.ones(n_neurons) * vthr + + tau_syn = dt / (1 - alpha) + tau_mem = dt / (1 - beta) + r = tau_mem / dt + v_leak = np.zeros_like(beta) + w_in = tau_syn / dt + + return nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=[n_neurons]), + 'lif': nir.CubaLIF( + v_threshold=vthr, + tau_mem=tau_mem, + tau_syn=tau_syn, + r=r, + v_leak=v_leak, + w_in=w_in, + ), + 'w_rec': w_rec, + 'output': nir.Output(output_type=[n_neurons]) + }, edges=[ + ('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output') + ]) + + else: + print(f'[WARNING] module not implemented: {module.__class__.__name__}') + return None + + +def export_to_nir( + module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch", + model_fwd_args=[], ignore_dims=[] +) -> nir.NIRNode: + """Convert an snnTorch module to the Neuromorphic Intermediate Representation (NIR). + This function uses nirtorch to extract the computational graph of the torch module, + and the _extract_snntorch_module method is used to convert each module in the graph + to the corresponding NIR module. + + The NIR is a graph-based representation of a spiking neural network, which can be used to + port the network to different neuromorphic hardware and software platforms. 
+ + Missing features: + - RLeaky + + Example:: + + import snntorch as snn + import torch + from snntorch import export_to_nir + + lif1 = snn.Leaky(beta=0.9, init_hidden=True) + lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) + + net = torch.nn.Sequential( + torch.nn.Flatten(), + torch.nn.Linear(784, 500), + lif1, + torch.nn.Linear(500, 10), + lif2 + ) + + sample_data = torch.randn(1, 784) + nir_graph = export_to_nir(net, sample_data) + + :param module: Network model (either wrapped in Sequential container or as a class) + :type module: torch.nn.Module + + :param sample_data: Sample input data to the network + :type sample_data: torch.Tensor + + :param model_name: Name of the model + :type model_name: str, optional + + :param model_fwd_args: Arguments to pass to the forward function of the model + :type model_fwd_args: list, optional + + :param ignore_dims: List of dimensions to ignore when extracting the NIR + :type ignore_dims: list, optional + + :return: return the NIR graph + :rtype: nir.NIRNode + """ + nir_graph = nirtorch.extract_nir_graph( + module, _extract_snntorch_module, sample_data, model_name=model_name, + ignore_submodules_of=[snn.RLeaky, snn.RSynaptic], + model_fwd_args=model_fwd_args, ignore_dims=ignore_dims + ) + return nir_graph diff --git a/snntorch/import_nir.py b/snntorch/import_nir.py new file mode 100644 index 00000000..154a6a17 --- /dev/null +++ b/snntorch/import_nir.py @@ -0,0 +1,389 @@ +import numpy as np +import nir +import nirtorch +import torch +import snntorch as snn + + +def _create_rnn_subgraph(graph: nir.NIRGraph, lif_nk: str, w_nk: str) -> nir.NIRGraph: + """Take a NIRGraph plus the node keys for a LIF and a W_rec, and return a new NIRGraph + which has the RNN subgraph replaced with a subgraph (i.e., a single NIRGraph node). + + The subgraph will have the following structure: + ``` + LIF -> W_rec -> LIF + ^ | + | v + Input Output + ``` + + :param graph: NIRGraph + :type graph: nir.NIRGraph + + :param lif_nk: key for the LIF node + :type lif_nk: str + + :param w_nk: key for the W_rec node + :type w_nk: str + + :return: NIRGraph with the RNN subgraph replaced with a single NIRGraph node + :rtype: nir.NIRGraph + """ + # NOTE: assuming that the LIF and W_rec have keys of form xyz.abc + sg_key = lif_nk.split('.')[0] # TODO: make this more general? + + # create subgraph for RNN + sg_edges = [ + (lif_nk, w_nk), (w_nk, lif_nk), (lif_nk, f'{sg_key}.output'), (f'{sg_key}.input', w_nk) + ] + sg_nodes = { + lif_nk: graph.nodes[lif_nk], + w_nk: graph.nodes[w_nk], + f'{sg_key}.input': nir.Input(graph.nodes[lif_nk].input_type), + f'{sg_key}.output': nir.Output(graph.nodes[lif_nk].output_type), + } + sg = nir.NIRGraph(nodes=sg_nodes, edges=sg_edges) + + # remove subgraph edges from graph + graph.edges = [e for e in graph.edges if e not in [(lif_nk, w_nk), (w_nk, lif_nk)]] + # remove subgraph nodes from graph + graph.nodes = {k: v for k, v in graph.nodes.items() if k not in [lif_nk, w_nk]} + + # change edges of type (x, lif_nk) to (x, sg_key) + graph.edges = [(e[0], sg_key) if e[1] == lif_nk else e for e in graph.edges] + # change edges of type (lif_nk, x) to (sg_key, x) + graph.edges = [(sg_key, e[1]) if e[0] == lif_nk else e for e in graph.edges] + + # insert subgraph into graph and return + graph.nodes[sg_key] = sg + return graph + + +def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: + """Take a NIRGraph and replace any RNN subgraphs with a single NIRGraph node. 
+ Goes through the NIRGraph to find any RNN subgraphs, and replaces them with a single NIRGraph node, + using the _create_rnn_subgraph function. + + :param graph: NIRGraph + :type graph: nir.NIRGraph + + :return: NIRGraph with RNN subgraphs replaced with a single NIRGraph node + :rtype: nir.NIRGraph + """ + print('replace rnn subgraph with nirgraph') + + if len(set(graph.edges)) != len(graph.edges): + print('[WARNING] duplicate edges found, removing') + graph.edges = list(set(graph.edges)) + + # find cycle of LIF <> Dense nodes + for edge1 in graph.edges: + for edge2 in graph.edges: + if not edge1 == edge2: + if edge1[0] == edge2[1] and edge1[1] == edge2[0]: + lif_nk = edge1[0] + lif_n = graph.nodes[lif_nk] + w_nk = edge1[1] + w_n = graph.nodes[w_nk] + is_lif = isinstance(lif_n, (nir.LIF, nir.CubaLIF)) + is_dense = isinstance(w_n, (nir.Affine, nir.Linear)) + # check if the dense only connects to the LIF + w_out_nk = [e[1] for e in graph.edges if e[0] == w_nk] + w_in_nk = [e[0] for e in graph.edges if e[1] == w_nk] + is_rnn = len(w_out_nk) == 1 and len(w_in_nk) == 1 + # check if we found an RNN - if so, then parse it + if is_rnn and is_lif and is_dense: + graph = _create_rnn_subgraph(graph, edge1[0], edge1[1]) + return graph + + +def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): + """Try parsing the presented graph as a RNN subgraph. Assumes the graph is a valid RNN subgraph + with four nodes in the following structure: + + ``` + Input -> LIF | CubaLIF -> Output + ^ + | + v + Affine | Linear + ``` + + :param graph: NIRGraph + :type graph: nir.NIRGraph + + :return: LIF | CubaLIF node + :rtype: nir.NIRNode + + :return: Affine | Linear node + :rtype: nir.NIRNode + + :return: int, number of neurons in the RNN + :rtype: int + """ + sub_nodes = graph.nodes.values() + assert len(sub_nodes) == 4, 'only 4-node RNN allowed in subgraph' + try: + input_node = [n for n in sub_nodes if isinstance(n, nir.Input)][0] + output_node = [n for n in sub_nodes if isinstance(n, nir.Output)][0] + lif_node = [n for n in sub_nodes if isinstance(n, (nir.LIF, nir.CubaLIF))][0] + wrec_node = [n for n in sub_nodes if isinstance(n, (nir.Affine, nir.Linear))][0] + except IndexError: + raise ValueError('invalid RNN subgraph - could not find all required nodes') + lif_size = list(input_node.input_type.values())[0][0] + assert lif_size == list(output_node.output_type.values())[0][0], 'output size mismatch' + assert lif_size == lif_node.v_threshold.size, 'lif size mismatch (v_threshold)' + assert lif_size == wrec_node.weight.shape[0], 'w_rec shape mismatch' + assert lif_size == wrec_node.weight.shape[1], 'w_rec shape mismatch' + + return lif_node, wrec_node, lif_size + + +def _nir_to_snntorch_module( + node: nir.NIRNode, hack_w_scale=True, init_hidden=False +) -> torch.nn.Module: + """Convert a NIR node to a snnTorch module. This function is used by the import_from_nir function. + + :param node: NIR node + :type node: nir.NIRNode + + :param hack_w_scale: if True, then the function will attempt to scale the weights to avoid scaling the inputs + :type hack_w_scale: bool + + :param init_hidden: the init_hidden flag of the snnTorch neuron. 
+ :type init_hidden: bool + + :return: snnTorch module + :rtype: torch.nn.Module + """ + if isinstance(node, nir.Input) or isinstance(node, nir.Output): + return None + + elif isinstance(node, nir.Affine): + assert node.bias is not None, 'bias must be specified for Affine layer' + + mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0]) + mod.weight.data = torch.Tensor(node.weight) + mod.bias.data = torch.Tensor(node.bias) + + return mod + + elif isinstance(node, nir.Linear): + mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=False) + mod.weight.data = torch.Tensor(node.weight) + + return mod + + elif isinstance(node, nir.Conv2d): + mod = torch.nn.Conv2d( + node.weight.shape[1], + node.weight.shape[0], + kernel_size=[*node.weight.shape[-2:]], + stride=node.stride, + padding=node.padding, + dilation=node.dilation, + groups=node.groups, + ) + mod.bias.data = torch.Tensor(node.bias) + mod.weight.data = torch.Tensor(node.weight) + return mod + + if isinstance(node, nir.Flatten): + return torch.nn.Flatten(node.start_dim, node.end_dim) + + if isinstance(node, nir.SumPool2d): + return torch.nn.AvgPool2d( + kernel_size=tuple(node.kernel_size), + stride=tuple(node.stride), + padding=tuple(node.padding), + divisor_override=1, # turn AvgPool into SumPool + ) + + elif isinstance(node, nir.IF): + assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' + assert np.unique(node.r).size == 1, 'r must be same for all neurons' + vthr = np.unique(node.v_threshold)[0] + r = np.unique(node.r)[0] + assert r == 1, 'r != 1 not supported' + mod = snn.Leaky( + beta=0.9, + threshold=vthr * r, + init_hidden=False, + reset_delay=False, + ) + return mod + + elif isinstance(node, nir.LIF): + dt = 1e-4 + + assert np.allclose(node.v_leak, 0.), 'v_leak not supported' + assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' + + beta = 1 - (dt / node.tau) + vthr = node.v_threshold + w_scale = node.r * dt / node.tau + + if not np.allclose(w_scale, 1.): + if hack_w_scale: + vthr = vthr / np.unique(w_scale)[0] + print('[warning] scaling weights to avoid scaling inputs') + print(f'w_scale: {w_scale}, r: {node.r}, dt: {dt}, tau: {node.tau}') + else: + raise NotImplementedError('w_scale must be 1, or the same for all neurons') + + assert np.unique(vthr).size == 1, 'LIF v_thr must be same for all neurons' + + return snn.Leaky( + beta=beta, + threshold=np.unique(vthr)[0], + reset_mechanism='zero', + init_hidden=init_hidden, + reset_delay=False, + ) + + elif isinstance(node, nir.CubaLIF): + dt = 1e-4 + + assert np.allclose(node.v_leak, 0), 'v_leak not supported' + assert np.allclose(node.r, node.tau_mem / dt), 'r not supported in CubaLIF' + + alpha = 1 - (dt / node.tau_syn) + beta = 1 - (dt / node.tau_mem) + vthr = node.v_threshold + w_scale = node.w_in * (dt / node.tau_syn) + + if not np.allclose(w_scale, 1.): + if hack_w_scale: + vthr = vthr / w_scale + print('[warning] scaling weights to avoid scaling inputs') + print(f'w_scale: {w_scale}, w_in: {node.w_in}, dt: {dt}, tau_syn: {node.tau_syn}') + else: + raise NotImplementedError('w_scale must be 1, or the same for all neurons') + + assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' + + if np.unique(alpha).size == 1: + alpha = float(np.unique(alpha)[0]) + if np.unique(beta).size == 1: + beta = float(np.unique(beta)[0]) + + return snn.Synaptic( + alpha=alpha, + beta=beta, + threshold=float(np.unique(vthr)[0]), + reset_mechanism='zero', + 
init_hidden=init_hidden, + reset_delay=False, + ) + + elif isinstance(node, nir.NIRGraph): + lif_node, wrec_node, lif_size = _parse_rnn_subgraph(node) + + if isinstance(lif_node, nir.LIF): + raise NotImplementedError('LIF in subgraph not supported') + + elif isinstance(lif_node, nir.CubaLIF): + dt = 1e-4 + + assert np.allclose(lif_node.v_leak, 0), 'v_leak not supported' + assert np.allclose(lif_node.r, lif_node.tau_mem / dt), 'r not supported in CubaLIF' + + alpha = 1 - (dt / lif_node.tau_syn) + beta = 1 - (dt / lif_node.tau_mem) + vthr = lif_node.v_threshold + w_scale = lif_node.w_in * (dt / lif_node.tau_syn) + + if not np.allclose(w_scale, 1.): + if hack_w_scale: + vthr = vthr / w_scale + print(f'[warning] scaling weights to avoid scaling inputs. w_scale: {w_scale}') + print(f'w_in: {lif_node.w_in}, dt: {dt}, tau_syn: {lif_node.tau_syn}') + else: + raise NotImplementedError('w_scale must be 1, or the same for all neurons') + + assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' + + diagonal = np.array_equal(wrec_node.weight, np.diag(np.diag(wrec_node.weight))) + + if np.unique(alpha).size == 1: + alpha = float(np.unique(alpha)[0]) + if np.unique(beta).size == 1: + beta = float(np.unique(beta)[0]) + + if diagonal: + V = torch.from_numpy(np.diag(wrec_node.weight)).to(dtype=torch.float32) + else: + V = None + + rsynaptic = snn.RSynaptic( + alpha=alpha, + beta=beta, + threshold=float(np.unique(vthr)[0]), + reset_mechanism='zero', + init_hidden=init_hidden, + all_to_all=not diagonal, + linear_features=lif_size if not diagonal else None, + V=V if diagonal else None, + reset_delay=False, + ) + + if isinstance(rsynaptic.recurrent, torch.nn.Linear): + rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) + if isinstance(wrec_node, nir.Affine): + rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias) + else: + rsynaptic.recurrent.bias.data = torch.zeros_like(rsynaptic.recurrent.bias) + else: + rsynaptic.recurrent.V.data = torch.diagonal(torch.Tensor(wrec_node.weight)) + + return rsynaptic + + elif node is None: + return torch.nn.Identity() + + else: + print('[WARNING] could not parse node of type:', node.__class__.__name__) + return None + + +def import_from_nir(graph: nir.NIRGraph) -> torch.nn.Module: + """Convert a NIRGraph to a snnTorch module. This function is the inverse of export_to_nir. + It proceeds by wrapping any recurrent connections into NIR sub-graphs, then converts each + NIR module into the equivalent snnTorch module, and wraps them into a torch.nn.Module + using the generic GraphExecutor from NIRTorch to execute all modules in the right order. 
+ + Missing features: + - RLeaky (LIF inside RNN) + + Example:: + + import snntorch as snn + import torch + from snntorch import export_to_nir, import_from_nir + + lif1 = snn.Leaky(beta=0.9, init_hidden=True) + lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) + + net = torch.nn.Sequential( + torch.nn.Flatten(), + torch.nn.Linear(784, 500), + lif1, + torch.nn.Linear(500, 10), + lif2 + ) + + sample_data = torch.randn(1, 784) + nir_graph = export_to_nir(net, sample_data, model_name="snntorch") + + net2 = import_from_nir(nir_graph) + + :param graph: NIR graph + :type graph: NIR.NIRGraph + + :return: snnTorch network + :rtype: torch.nn.Module + """ + # find valid RNN subgraphs, and replace them with a single NIRGraph node + graph = _replace_rnn_subgraph_with_nirgraph(graph) + # convert the NIR graph into a torch.nn.Module using snnTorch modules + return nirtorch.load(graph, _nir_to_snntorch_module) diff --git a/tests/lif.nir b/tests/lif.nir new file mode 100644 index 00000000..ce5f8e1c Binary files /dev/null and b/tests/lif.nir differ diff --git a/tests/test_nir.py b/tests/test_nir.py new file mode 100644 index 00000000..0fa1f635 --- /dev/null +++ b/tests/test_nir.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python + +"""Tests for NIR import and export.""" + +import nir +import pytest +import snntorch as snn +import torch + + +@pytest.fixture(scope="module") +def sample_data(): + return torch.ones((4, 784)) + + +@pytest.fixture(scope="module") +def snntorch_sequential(): + lif1 = snn.Leaky(beta=0.9, init_hidden=True) + lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) + + return torch.nn.Sequential( + torch.nn.Linear(784, 500), + lif1, + torch.nn.Linear(500, 10), + lif2, + ) + + +@pytest.fixture(scope="module") +def snntorch_recurrent(): + v = torch.ones((500,)) + lif1 = snn.RSynaptic( + alpha=0.5, beta=0.9, V=v, all_to_all=False, init_hidden=True + ) + lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) + + return torch.nn.Sequential( + torch.nn.Linear(784, 500), + lif1, + torch.nn.Linear(500, 10), + lif2, + ) + + +class TestNIR: + """Test import and export from snnTorch to NIR.""" + + def test_export_sequential(self, snntorch_sequential, sample_data): + nir_graph = snn.export_to_nir(snntorch_sequential, sample_data) + assert nir_graph is not None + assert set(nir_graph.nodes.keys()) == set( + ["input", "output"] + [str(i) for i in range(4)] + ), nir_graph.nodes.keys() + assert set(nir_graph.edges) == set( + [ + ("3", "output"), + ("input", "0"), + ("2", "3"), + ("1", "2"), + ("0", "1"), + ] + ) + assert isinstance(nir_graph.nodes["input"], nir.Input) + assert isinstance(nir_graph.nodes["output"], nir.Output) + assert isinstance(nir_graph.nodes["0"], nir.Affine) + assert isinstance(nir_graph.nodes["1"], nir.LIF) + assert isinstance(nir_graph.nodes["2"], nir.Affine) + assert isinstance(nir_graph.nodes["3"], nir.LIF) + + def test_export_recurrent(self, snntorch_recurrent, sample_data): + nir_graph = snn.export_to_nir(snntorch_recurrent, sample_data) + assert nir_graph is not None + assert set(nir_graph.nodes.keys()) == set( + ["input", "output", "0", "1.lif", "1.w_rec", "2", "3"] + ), nir_graph.nodes.keys() + assert isinstance(nir_graph.nodes["input"], nir.Input) + assert isinstance(nir_graph.nodes["output"], nir.Output) + assert isinstance(nir_graph.nodes["0"], nir.Affine) + assert isinstance(nir_graph.nodes["1.lif"], nir.CubaLIF) + assert isinstance(nir_graph.nodes["1.w_rec"], nir.Linear) + assert isinstance(nir_graph.nodes["2"], nir.Affine) + assert 
isinstance(nir_graph.nodes["3"], nir.LIF) + assert set(nir_graph.edges) == set( + [ + ("1.lif", "1.w_rec"), + ("1.w_rec", "1.lif"), + ("0", "1.lif"), + ("3", "output"), + ("2", "3"), + ("input", "0"), + ("1.lif", "2"), + ] + ) + + def test_import_nir(self): + graph = nir.read("tests/lif.nir") + net = snn.import_from_nir(graph) + out, _ = net(torch.ones(1, 1)) + assert out.shape == (1, 1), out.shape + + def test_commute_sequential(self, snntorch_sequential, sample_data): + x = torch.rand((4, 784)) + y_snn, state = snntorch_sequential(x) + assert y_snn.shape == (4, 10) + nir_graph = snn.export_to_nir(snntorch_sequential, sample_data) + net = snn.import_from_nir(nir_graph) + y_nir, state = net(x) + assert y_nir.shape == (4, 10), y_nir.shape + assert torch.allclose(y_snn, y_nir)
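Taken together, the new `export_to_nir` / `import_from_nir` entry points round-trip a network through NIR. The sketch below mirrors `test_commute_sequential` above and is illustrative only (untrained network, arbitrary shapes); it assumes `nir` and `nirtorch` are installed::

    import torch
    import snntorch as snn

    # small feedforward SNN with hidden-state neurons, as in the tests above
    net = torch.nn.Sequential(
        torch.nn.Linear(784, 500),
        snn.Leaky(beta=0.9, init_hidden=True),
        torch.nn.Linear(500, 10),
        snn.Leaky(beta=0.9, init_hidden=True, output=True),
    )

    sample_data = torch.ones((4, 784))

    # export to a NIR graph, then rebuild an equivalent module from it
    nir_graph = snn.export_to_nir(net, sample_data)
    net2 = snn.import_from_nir(nir_graph)

    # both return (spikes, state) for a single time step
    y_snn, _ = net(sample_data)
    y_nir, _ = net2(sample_data)
    assert y_snn.shape == y_nir.shape == (4, 10)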