From a2d05dd0bf9ff42822538315187b153568fa8f06 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Mon, 9 Oct 2023 16:21:25 +0200 Subject: [PATCH 01/33] start import functionality for NIR --- snntorch/import.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 snntorch/import.py diff --git a/snntorch/import.py b/snntorch/import.py new file mode 100644 index 00000000..efe9b320 --- /dev/null +++ b/snntorch/import.py @@ -0,0 +1,81 @@ +import snntorch as snn +import numpy as np +import torch +import nir + + +class ImportedNetwork(torch.nn.Module): + def __init__(self, module_list): + super().__init__() + self.module_list = module_list + + def forward(self, x): + for module in self.module_list: + # TODO: this must be implemented in snnTorch (timestep) + x = module(x) + return x + + +def _to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: + """Convert a NIR node to a snnTorch module. + + Supported NIR nodes: Affine. + """ + if isinstance(node, nir.LIF): + return snn.Leaky() + + elif isinstance(node, nir.Affine): + if len(node.weight.shape) != 2: + raise NotImplementedError('only 2D weight matrices are supported') + has_bias = node.bias is not None and not np.alltrue(node.bias == 0) + linear = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=has_bias) + linear.weight.data = torch.Tensor(node.weight) + if has_bias: + linear.bias.data = torch.Tensor(node.bias) + return linear + + else: + raise NotImplementedError(f'node type {type(node).__name__} not supported') + + +def _get_next_node_key(node_key: str, graph: nir.ir.NIRGraph): + """Get the next node key in the NIR graph.""" + possible_next_node_keys = [edge[1] for edge in graph.edges if edge[0] == node_key] + assert len(possible_next_node_keys) <= 1, 'branching networks are not supported' + if len(possible_next_node_keys) == 0: + return None + else: + return possible_next_node_keys[0] + + +def from_nir(graph: nir.ir.NIRGraph) -> torch.nn.Module: + """Convert NIR graph to snnTorch module. 
+ + :param graph: a saved snnTorch model as a parameter dictionary + :type graph: nir.ir.NIRGraph + + :return: snnTorch module + :rtype: torch.nn.Module + """ + node_key = 'input' + visited_node_keys = [node_key] + module_list = [] + + while _get_next_node_key(node_key, graph.edges) is not None: + node_key = _get_next_node_key(node_key, graph.edges) + node = graph.nodes[node_key] + + if node_key in visited_node_keys: + raise NotImplementedError('cyclic NIR graphs are not supported') + + visited_node_keys.append(node_key) + print(f'node {node_key}: {type(node).__name__}') + if node_key == 'output': + continue + module = _to_snntorch_module(node) + module_list.append(module) + + if len(visited_node_keys) != len(graph.nodes): + raise ValueError('not all nodes visited') + + return ImportedNetwork(module_list) From f592cc3209faebc7aba5e395b54e1bf2f271bb56 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Tue, 10 Oct 2023 12:49:26 +0200 Subject: [PATCH 02/33] add support for RLeaky and RSynaptic --- snntorch/export.py | 97 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 81 insertions(+), 16 deletions(-) diff --git a/snntorch/export.py b/snntorch/export.py index b9cb94ce..de7c3e96 100644 --- a/snntorch/export.py +++ b/snntorch/export.py @@ -3,38 +3,87 @@ import torch import nir +import numpy as np from nirtorch import extract_nir_graph -from snntorch import Leaky, Synaptic +from snntorch import Leaky, Synaptic, RLeaky, RSynaptic + + +def _create_rnn_subgraph(module: torch.nn.Module, lif: Union[nir.LIF, nir.CubaLIF]) -> nir.NIRGraph: + """Create NIR Graph for RNN, from the snnTorch module and the extracted LIF/CubaLIF node.""" + if module.all_to_all: + lif_shape = module.recurrent.weight.shape[0] + weight_rec = module.recurrent.weight.data.detach().numpy() + if module.recurrent.bias is not None: + assert np.allclose(module.recurrent.bias.detach().numpy(), 0), 'bias not supported' + else: + if len(module.recurrent.V.shape) == 0: + lif_shape = None + weight_rec = np.eye(1) * module.recurrent.V.data.detach().numpy() + else: + lif_shape = module.recurrent.V.shape[0] + weight_rec = np.diag(module.recurrent.V.data.detach().numpy()) + + return nir.NIRGraph( + nodes={ + 'input': nir.Input(input_type=[lif_shape]), + 'lif': lif, + 'w_rec': nir.Linear(weight=weight_rec), + 'output': nir.Output(output_type=[lif_shape]) + }, + edges=[('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output')] + ) + # eqn is assumed to be: v_t+1 = (1-1/tau)*v_t + 1/tau * v_leak + I_in / C -def _extract_snntorch_module(module:torch.nn.Module) -> Optional[nir.NIRNode]: +def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: if isinstance(module, Leaky): return nir.LIF( - tau = 1 / (1 - module.beta).detach(), - v_threshold = module.threshold.detach(), - v_leak = torch.zeros_like(module.beta), - r = module.beta.detach(), + tau=1 / (1 - module.beta).detach(), + v_threshold=module.threshold.detach(), + v_leak=torch.zeros_like(module.beta), + r=module.beta.detach(), ) - if isinstance(module, Synaptic): + elif isinstance(module, RSynaptic): + lif = nir.CubaLIF( + tau_syn=1 / (1 - module.beta).detach(), + tau_mem=1 / (1 - module.alpha).detach(), + v_threshold=module.threshold.detach(), + v_leak=torch.zeros_like(module.beta), + r=module.beta.detach(), + ) + return _create_rnn_subgraph(module, lif) + + elif isinstance(module, RLeaky): + lif = nir.LIF( + tau=1 / (1 - module.beta).detach(), + v_threshold=module.threshold.detach(), + v_leak=torch.zeros_like(module.beta), + 
r=module.beta.detach(), + ) + return _create_rnn_subgraph(module, lif) + + elif isinstance(module, Synaptic): return nir.CubaLIF( - tau_syn = 1 / (1 - module.beta).detach(), - tau_mem = 1 / (1 - module.alpha).detach(), - v_threshold = module.threshold.detach(), - v_leak = torch.zeros_like(module.beta), - r = module.beta.detach(), + tau_syn=1 / (1 - module.beta).detach(), + tau_mem=1 / (1 - module.alpha).detach(), + v_threshold=module.threshold.detach(), + v_leak=torch.zeros_like(module.beta), + r=module.beta.detach(), ) elif isinstance(module, torch.nn.Linear): - if module.bias is None: # Add zero bias if none is present + if module.bias is None: # Add zero bias if none is present return nir.Affine( module.weight.detach(), torch.zeros(*module.weight.shape[:-1]) ) else: return nir.Affine(module.weight.detach(), module.bias.detach()) - return None + else: + print(f'[WARNING] unknown module type: {type(module).__name__} (ignored)') + return None def to_nir( @@ -81,6 +130,22 @@ def to_nir( :rtype: NIRGraph """ - return extract_nir_graph( - module, _extract_snntorch_module, sample_data, model_name=model_name + nir_graph = extract_nir_graph( + module, _extract_snntorch_module, sample_data, model_name=model_name, + ignore_submodules_of=[RLeaky, RSynaptic] ) + + # NOTE: this is a hack to make sure all input and output types are set correctly + for node_key, node in nir_graph.nodes.items(): + input_undef = node.input_type.get('input', [None]) == [None] + if isinstance(node, nir.Input) and input_undef and '.' in node_key: + print('WARNING: subgraph input type not set, inferring from previous node') + key = '.'.join(node_key.split('.')[:-1]) + prev_keys = [edge[0] for edge in nir_graph.edges if edge[1] == key] + assert len(prev_keys) == 1, 'multiple previous nodes not supported' + prev_node = nir_graph.nodes[prev_keys[0]] + cur_type = prev_node.output_type['output'] + node.input_type['input'] = cur_type + nir_graph.nodes[f'{key}.output'].output_type['output'] = cur_type + + return nir_graph From 7ef2e5d4e43d26f4777c807169ec5f82a7d0ed21 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Tue, 10 Oct 2023 13:07:24 +0200 Subject: [PATCH 03/33] minor fix --- snntorch/export.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/snntorch/export.py b/snntorch/export.py index de7c3e96..7c018c7c 100644 --- a/snntorch/export.py +++ b/snntorch/export.py @@ -13,22 +13,21 @@ def _create_rnn_subgraph(module: torch.nn.Module, lif: Union[nir.LIF, nir.CubaLI """Create NIR Graph for RNN, from the snnTorch module and the extracted LIF/CubaLIF node.""" if module.all_to_all: lif_shape = module.recurrent.weight.shape[0] - weight_rec = module.recurrent.weight.data.detach().numpy() - if module.recurrent.bias is not None: - assert np.allclose(module.recurrent.bias.detach().numpy(), 0), 'bias not supported' + w_rec = module.recurrent.weight.data.detach().numpy() + b = None if module.recurrent.bias is None else module.recurrent.bias.data.detach().numpy() else: if len(module.recurrent.V.shape) == 0: lif_shape = None - weight_rec = np.eye(1) * module.recurrent.V.data.detach().numpy() + w_rec = np.eye(1) * module.recurrent.V.data.detach().numpy() else: lif_shape = module.recurrent.V.shape[0] - weight_rec = np.diag(module.recurrent.V.data.detach().numpy()) + w_rec = np.diag(module.recurrent.V.data.detach().numpy()) return nir.NIRGraph( nodes={ 'input': nir.Input(input_type=[lif_shape]), 'lif': lif, - 'w_rec': nir.Linear(weight=weight_rec), + 'w_rec': nir.Linear(weight=w_rec) if b is None else 
nir.Affine(weight=w_rec, bias=b), 'output': nir.Output(output_type=[lif_shape]) }, edges=[('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output')] @@ -137,7 +136,8 @@ def to_nir( # NOTE: this is a hack to make sure all input and output types are set correctly for node_key, node in nir_graph.nodes.items(): - input_undef = node.input_type.get('input', [None]) == [None] + inp_type = node.input_type.get('input', [None]) + input_undef = len(inp_type) == 0 or inp_type[0] is None if isinstance(node, nir.Input) and input_undef and '.' in node_key: print('WARNING: subgraph input type not set, inferring from previous node') key = '.'.join(node_key.split('.')[:-1]) From 7701cc0fff1e56fcc06928b3a8a7ff4adbb7e16b Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Tue, 10 Oct 2023 13:09:41 +0200 Subject: [PATCH 04/33] minor fix #2 --- snntorch/export.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/snntorch/export.py b/snntorch/export.py index 7c018c7c..ce88cc1b 100644 --- a/snntorch/export.py +++ b/snntorch/export.py @@ -11,10 +11,12 @@ def _create_rnn_subgraph(module: torch.nn.Module, lif: Union[nir.LIF, nir.CubaLIF]) -> nir.NIRGraph: """Create NIR Graph for RNN, from the snnTorch module and the extracted LIF/CubaLIF node.""" + b = None if module.all_to_all: lif_shape = module.recurrent.weight.shape[0] w_rec = module.recurrent.weight.data.detach().numpy() - b = None if module.recurrent.bias is None else module.recurrent.bias.data.detach().numpy() + if module.recurrent.bias is not None: + b = module.recurrent.bias.data.detach().numpy() else: if len(module.recurrent.V.shape) == 0: lif_shape = None From d3f1a8c620f6d5c3a6846c140f1bfa37d1ed97bb Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Thu, 12 Oct 2023 12:16:45 +0200 Subject: [PATCH 05/33] rename export + fix subgraph --- snntorch/{export.py => export_nir.py} | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) rename snntorch/{export.py => export_nir.py} (86%) diff --git a/snntorch/export.py b/snntorch/export_nir.py similarity index 86% rename from snntorch/export.py rename to snntorch/export_nir.py index ce88cc1b..cb565431 100644 --- a/snntorch/export.py +++ b/snntorch/export_nir.py @@ -77,10 +77,10 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: elif isinstance(module, torch.nn.Linear): if module.bias is None: # Add zero bias if none is present return nir.Affine( - module.weight.detach(), torch.zeros(*module.weight.shape[:-1]) + module.weight.detach().numpy(), np.zeros(*module.weight.shape[:-1]) ) else: - return nir.Affine(module.weight.detach(), module.bias.detach()) + return nir.Affine(module.weight.detach().numpy(), module.bias.detach().numpy()) else: print(f'[WARNING] unknown module type: {type(module).__name__} (ignored)') @@ -150,4 +150,19 @@ def to_nir( node.input_type['input'] = cur_type nir_graph.nodes[f'{key}.output'].output_type['output'] = cur_type + # NOTE: hack to remove recurrent connections of subgraph to itself + for edge in nir_graph.edges: + if edge[0] not in nir_graph.nodes and edge[1] not in nir_graph.nodes: + nir_graph.edges.remove(edge) + + # NOTE: hack to rename input and output nodes of subgraphs + for edge in nir_graph.edges: + if edge[1] not in nir_graph.nodes: + nir_graph.edges.remove(edge) + nir_graph.edges.append((edge[0], f'{edge[1]}.input')) + for edge in nir_graph.edges: + if edge[0] not in nir_graph.nodes: + nir_graph.edges.remove(edge) + nir_graph.edges.append((f'{edge[0]}.output', edge[1])) + return nir_graph From 
d2755ff7e7c8bf2ad407b6d9c64505ca5c31cb06 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Thu, 12 Oct 2023 12:23:14 +0200 Subject: [PATCH 06/33] remove hack to rename subgraph edges --- snntorch/export_nir.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py index cb565431..0b1e8e0c 100644 --- a/snntorch/export_nir.py +++ b/snntorch/export_nir.py @@ -155,14 +155,14 @@ def to_nir( if edge[0] not in nir_graph.nodes and edge[1] not in nir_graph.nodes: nir_graph.edges.remove(edge) - # NOTE: hack to rename input and output nodes of subgraphs - for edge in nir_graph.edges: - if edge[1] not in nir_graph.nodes: - nir_graph.edges.remove(edge) - nir_graph.edges.append((edge[0], f'{edge[1]}.input')) - for edge in nir_graph.edges: - if edge[0] not in nir_graph.nodes: - nir_graph.edges.remove(edge) - nir_graph.edges.append((f'{edge[0]}.output', edge[1])) + # # NOTE: hack to rename input and output nodes of subgraphs (not needed) + # for edge in nir_graph.edges: + # if edge[1] not in nir_graph.nodes: + # nir_graph.edges.remove(edge) + # nir_graph.edges.append((edge[0], f'{edge[1]}.input')) + # for edge in nir_graph.edges: + # if edge[0] not in nir_graph.nodes: + # nir_graph.edges.remove(edge) + # nir_graph.edges.append((f'{edge[0]}.output', edge[1])) return nir_graph From e2e92c3aaf9452b8e5dfc1e12876f9be952095dd Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Thu, 12 Oct 2023 15:56:48 +0200 Subject: [PATCH 07/33] make import and export work for RNN --- snntorch/export_nir.py | 103 ++++++++++++++++--------- snntorch/import.py | 81 -------------------- snntorch/import_nir.py | 166 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 235 insertions(+), 115 deletions(-) delete mode 100644 snntorch/import.py create mode 100644 snntorch/import_nir.py diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py index 0b1e8e0c..ce3baec5 100644 --- a/snntorch/export_nir.py +++ b/snntorch/export_nir.py @@ -1,5 +1,4 @@ from typing import Union, Optional -from numbers import Number import torch import nir @@ -9,7 +8,9 @@ from snntorch import Leaky, Synaptic, RLeaky, RSynaptic -def _create_rnn_subgraph(module: torch.nn.Module, lif: Union[nir.LIF, nir.CubaLIF]) -> nir.NIRGraph: +def _create_rnn_subgraph( + module: torch.nn.Module, lif: Union[nir.LIF, nir.CubaLIF], n_neurons=-1 +) -> nir.NIRGraph: """Create NIR Graph for RNN, from the snnTorch module and the extracted LIF/CubaLIF node.""" b = None if module.all_to_all: @@ -18,9 +19,12 @@ def _create_rnn_subgraph(module: torch.nn.Module, lif: Union[nir.LIF, nir.CubaLI if module.recurrent.bias is not None: b = module.recurrent.bias.data.detach().numpy() else: - if len(module.recurrent.V.shape) == 0: + if len(module.recurrent.V.shape) == 0 and n_neurons == -1: lif_shape = None w_rec = np.eye(1) * module.recurrent.V.data.detach().numpy() + elif n_neurons != -1: + lif_shape = n_neurons + w_rec = np.eye(n_neurons) * module.recurrent.V.data.detach().numpy() else: lif_shape = module.recurrent.V.shape[0] w_rec = np.diag(module.recurrent.V.data.detach().numpy()) @@ -36,42 +40,83 @@ def _create_rnn_subgraph(module: torch.nn.Module, lif: Union[nir.LIF, nir.CubaLI ) +def _get_neuron_count(module: torch.nn.Module) -> int: + if isinstance(module, RLeaky) or isinstance(module, RSynaptic): + if module.all_to_all: + return module.linear_features + elif isinstance(module.recurrent.V, torch.Tensor) and len(module.recurrent.V.shape) > 0: + return module.recurrent.V.shape[0] + elif module.init_hidden 
is True: + return module.mem.shape[0] + else: + # not implemented + return -1 + else: + # not implemented + return -1 + + # eqn is assumed to be: v_t+1 = (1-1/tau)*v_t + 1/tau * v_leak + I_in / C def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: + """ + NOTE: it might leave the NIR node of neurons with an incompatible shape. This must be fixed. + """ if isinstance(module, Leaky): + # TODO + tau = 1 / (1 - module.beta.detach().numpy()) + r = module.beta.detach().numpy() + threshold = module.threshold.detach().numpy() return nir.LIF( - tau=1 / (1 - module.beta).detach(), - v_threshold=module.threshold.detach(), - v_leak=torch.zeros_like(module.beta), - r=module.beta.detach(), + tau=tau, + v_threshold=threshold, + v_leak=torch.zeros_like(tau), + r=r, ) elif isinstance(module, RSynaptic): + alpha = module.alpha.detach().numpy() + beta = module.beta.detach().numpy() + threshold = module.threshold.detach().numpy() + n_neurons = _get_neuron_count(module) + if len(alpha.shape) == 0 and n_neurons != -1: + alpha = np.ones(n_neurons) * alpha + if len(beta.shape) == 0 and n_neurons != -1: + beta = np.ones(n_neurons) * beta + if len(threshold.shape) == 0 and n_neurons != -1: + threshold = np.ones(n_neurons) * threshold lif = nir.CubaLIF( - tau_syn=1 / (1 - module.beta).detach(), - tau_mem=1 / (1 - module.alpha).detach(), - v_threshold=module.threshold.detach(), - v_leak=torch.zeros_like(module.beta), - r=module.beta.detach(), + tau_syn=1 / (1 - beta), + tau_mem=1 / (1 - alpha), + v_threshold=threshold, + v_leak=np.zeros_like(beta), + r=beta, ) - return _create_rnn_subgraph(module, lif) + return _create_rnn_subgraph(module, lif, n_neurons=n_neurons) elif isinstance(module, RLeaky): + beta = module.beta.detach().numpy() + threshold = module.threshold.detach().numpy() + n_neurons = _get_neuron_count(module) + if len(beta.shape) == 0 and n_neurons != -1: + beta = np.ones(n_neurons) * beta + if len(threshold.shape) == 0 and n_neurons != -1: + threshold = np.ones(n_neurons) * threshold lif = nir.LIF( - tau=1 / (1 - module.beta).detach(), - v_threshold=module.threshold.detach(), - v_leak=torch.zeros_like(module.beta), - r=module.beta.detach(), + tau=1 / (1 - beta), + v_threshold=threshold, + v_leak=np.zeros_like(beta), + r=beta, ) - return _create_rnn_subgraph(module, lif) + return _create_rnn_subgraph(module, lif, n_neurons=n_neurons) elif isinstance(module, Synaptic): + # TODO return nir.CubaLIF( - tau_syn=1 / (1 - module.beta).detach(), - tau_mem=1 / (1 - module.alpha).detach(), - v_threshold=module.threshold.detach(), - v_leak=torch.zeros_like(module.beta), - r=module.beta.detach(), + tau_syn=1 / (1 - module.alpha).detach().numpy(), + tau_mem=1 / (1 - module.beta).detach().numpy(), + v_threshold=module.threshold.detach().numpy(), + v_leak=np.zeros_like(module.beta), + r=module.beta.detach().numpy(), # NOTE: is this right? 
) elif isinstance(module, torch.nn.Linear): @@ -136,7 +181,7 @@ def to_nir( ignore_submodules_of=[RLeaky, RSynaptic] ) - # NOTE: this is a hack to make sure all input and output types are set correctly + # NOTE: hack to define subgraph I/O types for node_key, node in nir_graph.nodes.items(): inp_type = node.input_type.get('input', [None]) input_undef = len(inp_type) == 0 or inp_type[0] is None @@ -155,14 +200,4 @@ def to_nir( if edge[0] not in nir_graph.nodes and edge[1] not in nir_graph.nodes: nir_graph.edges.remove(edge) - # # NOTE: hack to rename input and output nodes of subgraphs (not needed) - # for edge in nir_graph.edges: - # if edge[1] not in nir_graph.nodes: - # nir_graph.edges.remove(edge) - # nir_graph.edges.append((edge[0], f'{edge[1]}.input')) - # for edge in nir_graph.edges: - # if edge[0] not in nir_graph.nodes: - # nir_graph.edges.remove(edge) - # nir_graph.edges.append((f'{edge[0]}.output', edge[1])) - return nir_graph diff --git a/snntorch/import.py b/snntorch/import.py deleted file mode 100644 index efe9b320..00000000 --- a/snntorch/import.py +++ /dev/null @@ -1,81 +0,0 @@ -import snntorch as snn -import numpy as np -import torch -import nir - - -class ImportedNetwork(torch.nn.Module): - def __init__(self, module_list): - super().__init__() - self.module_list = module_list - - def forward(self, x): - for module in self.module_list: - # TODO: this must be implemented in snnTorch (timestep) - x = module(x) - return x - - -def _to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: - """Convert a NIR node to a snnTorch module. - - Supported NIR nodes: Affine. - """ - if isinstance(node, nir.LIF): - return snn.Leaky() - - elif isinstance(node, nir.Affine): - if len(node.weight.shape) != 2: - raise NotImplementedError('only 2D weight matrices are supported') - has_bias = node.bias is not None and not np.alltrue(node.bias == 0) - linear = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=has_bias) - linear.weight.data = torch.Tensor(node.weight) - if has_bias: - linear.bias.data = torch.Tensor(node.bias) - return linear - - else: - raise NotImplementedError(f'node type {type(node).__name__} not supported') - - -def _get_next_node_key(node_key: str, graph: nir.ir.NIRGraph): - """Get the next node key in the NIR graph.""" - possible_next_node_keys = [edge[1] for edge in graph.edges if edge[0] == node_key] - assert len(possible_next_node_keys) <= 1, 'branching networks are not supported' - if len(possible_next_node_keys) == 0: - return None - else: - return possible_next_node_keys[0] - - -def from_nir(graph: nir.ir.NIRGraph) -> torch.nn.Module: - """Convert NIR graph to snnTorch module. 
- - :param graph: a saved snnTorch model as a parameter dictionary - :type graph: nir.ir.NIRGraph - - :return: snnTorch module - :rtype: torch.nn.Module - """ - node_key = 'input' - visited_node_keys = [node_key] - module_list = [] - - while _get_next_node_key(node_key, graph.edges) is not None: - node_key = _get_next_node_key(node_key, graph.edges) - node = graph.nodes[node_key] - - if node_key in visited_node_keys: - raise NotImplementedError('cyclic NIR graphs are not supported') - - visited_node_keys.append(node_key) - print(f'node {node_key}: {type(node).__name__}') - if node_key == 'output': - continue - module = _to_snntorch_module(node) - module_list.append(module) - - if len(visited_node_keys) != len(graph.nodes): - raise ValueError('not all nodes visited') - - return ImportedNetwork(module_list) diff --git a/snntorch/import_nir.py b/snntorch/import_nir.py new file mode 100644 index 00000000..a18a3a42 --- /dev/null +++ b/snntorch/import_nir.py @@ -0,0 +1,166 @@ +import snntorch as snn +import numpy as np +import torch +import nir +import typing + + +# TODO: implement this? +class ImportedNetwork(torch.nn.Module): + """Wrapper for a snnTorch network. NOTE: not working atm.""" + def __init__(self, module_list): + super().__init__() + self.module_list = module_list + + def forward(self, x): + for module in self.module_list: + x = module(x) + return x + + +def create_snntorch_network(module_list): + return torch.nn.Sequential(*module_list) + + +def _lif_to_snntorch_module( + lif: typing.Union[nir.LIF, nir.CubaLIF] +) -> torch.nn.Module: + """Parse a LIF node into snnTorch.""" + if isinstance(lif, nir.LIF): + assert np.alltrue(lif.v_leak == 0), 'v_leak not supported' + assert np.alltrue(lif.r == 1. - 1. / lif.tau), 'r not supported' + assert np.unique(lif.v_threshold).size == 1, 'v_threshold must be same for all neurons' + threshold = lif.v_threshold[0] + mod = snn.RLeaky( + beta=1. - 1. / lif.tau, + threshold=threshold, + all_to_all=True, + reset_mechanism='zero', + linear_features=lif.tau.shape[0] if len(lif.tau.shape) == 1 else None, + init_hidden=True, + ) + return mod + + elif isinstance(lif, nir.CubaLIF): + assert np.alltrue(lif.v_leak == 0), 'v_leak not supported' + assert np.alltrue(lif.r == 1. - 1. / lif.tau_mem), 'r not supported' # NOTE: is this right? + assert np.unique(lif.v_threshold).size == 1, 'v_threshold must be same for all neurons' + threshold = lif.v_threshold[0] + mod = snn.RSynaptic( + alpha=1. - 1. / lif.tau_syn, + beta=1. - 1. / lif.tau_mem, + threshold=threshold, + all_to_all=True, + reset_mechanism='zero', + linear_features=lif.tau_mem.shape[0] if len(lif.tau_mem.shape) == 1 else None, + init_hidden=True, + ) + return mod + + else: + raise ValueError('called _lif_to_snntorch_module on non-LIF node') + + +def _to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: + """Convert a NIR node to a snnTorch module. + + Supported NIR nodes: Affine. 
+ """ + if isinstance(node, (nir.LIF, nir.CubaLIF)): + return _lif_to_snntorch_module(node) + + elif isinstance(node, nir.Affine): + if len(node.weight.shape) != 2: + raise NotImplementedError('only 2D weight matrices are supported') + has_bias = node.bias is not None and not np.alltrue(node.bias == 0) + linear = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=has_bias) + linear.weight.data = torch.Tensor(node.weight) + if has_bias: + linear.bias.data = torch.Tensor(node.bias) + return linear + + else: + raise NotImplementedError(f'node type {type(node).__name__} not supported') + + +def _rnn_subgraph_to_snntorch_module( + lif: typing.Union[nir.LIF, nir.CubaLIF], w_rec: typing.Union[nir.Affine, nir.Linear] +) -> torch.nn.Module: + """Parse an RNN subgraph consisting of a LIF node and a recurrent weight matrix into snnTorch. + + NOTE: for now always set it as a recurrent linear layer (not RecurrentOneToOne) + """ + assert isinstance(lif, (nir.LIF, nir.CubaLIF)), 'only LIF or CubaLIF nodes supported as RNNs' + mod = _lif_to_snntorch_module(lif) + mod.recurrent.weight.data = torch.Tensor(w_rec.weight) + if isinstance(w_rec, nir.Linear): + mod.recurrent.register_parameter('bias', None) + mod.recurrent.reset_parameters() + else: + mod.recurrent.bias.data = torch.Tensor(w_rec.bias) + return mod + + +def _get_next_node_key(node_key: str, graph: nir.ir.NIRGraph): + """Get the next node key in the NIR graph.""" + possible_next_node_keys = [edge[1] for edge in graph.edges if edge[0] == node_key] + # possible_next_node_keys += [edge[1] + '.input' for edge in graph.edges if edge[0] == node_key] + assert len(possible_next_node_keys) <= 1, 'branching networks are not supported' + if len(possible_next_node_keys) == 0: + return None + else: + return possible_next_node_keys[0] + + +def from_nir(graph: nir.ir.NIRGraph) -> torch.nn.Module: + """Convert NIR graph to snnTorch module. 
+ + :param graph: a saved snnTorch model as a parameter dictionary + :type graph: nir.ir.NIRGraph + + :return: snnTorch module + :rtype: torch.nn.Module + """ + node_key = 'input' + visited_node_keys = [node_key] + module_list = [] + + while _get_next_node_key(node_key, graph) is not None: + node_key = _get_next_node_key(node_key, graph) + + assert node_key not in visited_node_keys, 'cyclic NIR graphs not supported' + + if node_key == 'output': + visited_node_keys.append(node_key) + continue + + if node_key in graph.nodes: + visited_node_keys.append(node_key) + node = graph.nodes[node_key] + print(f'simple node {node_key}: {type(node).__name__}') + module = _to_snntorch_module(node) + else: + # check if it's a nested node + print(f'potential subgraph node: {node_key}') + sub_node_keys = [n for n in graph.nodes if n.startswith(f'{node_key}.')] + assert len(sub_node_keys) > 0, f'no nodes found for subgraph {node_key}' + + # parse subgraph + # NOTE: for now only looking for RNN subgraphs + rnn_sub_node_keys = [f'{node_key}.{n}' for n in ['input', 'output', 'lif', 'w_rec']] + if set(sub_node_keys) != set(rnn_sub_node_keys): + raise NotImplementedError('only RNN subgraphs are supported') + print('found RNN subgraph') + module = _rnn_subgraph_to_snntorch_module( + graph.nodes[f'{node_key}.lif'], graph.nodes[f'{node_key}.w_rec'] + ) + for nk in sub_node_keys: + visited_node_keys.append(nk) + + module_list.append(module) + + if len(visited_node_keys) != len(graph.nodes): + print(graph.nodes.keys(), visited_node_keys) + raise ValueError('not all nodes visited') + + return create_snntorch_network(module_list) From d3ee326235248ff570962cfa30f59467d2d7d824 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Thu, 12 Oct 2023 18:38:07 +0200 Subject: [PATCH 08/33] import & export using nirtorch (instead of manual) --- snntorch/export_nirtorch.py | 73 +++++++++++++++++++++++++++++++++++++ snntorch/import_nirtorch.py | 53 +++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 snntorch/export_nirtorch.py create mode 100644 snntorch/import_nirtorch.py diff --git a/snntorch/export_nirtorch.py b/snntorch/export_nirtorch.py new file mode 100644 index 00000000..b405b427 --- /dev/null +++ b/snntorch/export_nirtorch.py @@ -0,0 +1,73 @@ +from typing import Union, Optional + +import torch +import nir +import numpy as np +from nirtorch import extract_nir_graph +import snntorch as snn + + +def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: + if isinstance(module, snn.Leaky): + print('leaky') + return None + + elif isinstance(module, torch.nn.Linear): + if module.bias is None: + return nir.Linear( + weight=module.weight.data.detach().numpy() + ) + else: + return nir.Affine( + weight=module.weight.data.detach().numpy(), + bias=module.bias.data.detach().numpy() + ) + + elif isinstance(module, snn.RLeaky): + if module.all_to_all: + w = module.recurrent.weight.data.detach().numpy() + n_neurons = w.shape[0] + print(module.linear_features, w.shape) + if module.recurrent.bias is None: + w_rec = nir.Linear(weight=w) + else: + b = module.recurrent.bias.data.detach().numpy() + w_rec = nir.Affine(weight=w, bias=b) + else: + # TODO: handle this better - if V is a scalar, then the weight has wrong shape + assert len(module.recurrent.V.shape) == 1, 'V must be a vector' + n_neurons = module.recurrent.V.shape[0] + w = np.diag(module.recurrent.V.data.detach().numpy()) + w_rec = nir.Linear(weight=w) + + # TODO: set the parameters correctly + v_thr = np.ones(n_neurons) * 
module.threshold.detach().numpy() + beta = np.ones(n_neurons) * module.beta.detach().numpy() + tau = 1 / (1 - beta) + r = beta + v_leak = beta + + return nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=[n_neurons]), + 'lif': nir.LIF(v_threshold=v_thr, tau=tau, r=r, v_leak=v_leak), + 'w_rec': w_rec, + 'output': nir.Output(output_type=[n_neurons]) + }, edges=[ + ('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output') + ]) + + else: + print(f'[WARNING] module not implemented: {module.__class__.__name__}') + return None + + +def to_nir( + module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch" +) -> nir.NIRNode: + """Convert an snnTorch model to the Neuromorphic Intermediate Representation (NIR). + """ + nir_graph = extract_nir_graph( + module, _extract_snntorch_module, sample_data, model_name=model_name, + ignore_submodules_of=[snn.RLeaky, snn.RSynaptic] + ) + return nir_graph diff --git a/snntorch/import_nirtorch.py b/snntorch/import_nirtorch.py new file mode 100644 index 00000000..f3b82096 --- /dev/null +++ b/snntorch/import_nirtorch.py @@ -0,0 +1,53 @@ +import numpy as np +import nir +import nirtorch +import torch +import snntorch as snn + + +def _nir_to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: + if isinstance(node, nir.Input) or isinstance(node, nir.Output): + return None + + elif isinstance(node, nir.Affine): + mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0]) + mod.weight.data = torch.Tensor(node.weight) + if node.bias is not None: + mod.bias.data = torch.Tensor(node.bias) + return mod + + elif isinstance(node, nir.Linear): + mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=False) + mod.weight.data = torch.Tensor(node.weight) + return mod + + elif isinstance(node, nir.LIF): + # NOTE: assuming that parameters are arrays of correct size + dt = 1e-4 + assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' + vthr = node.v_threshold + beta = 1 - (dt / node.tau) + w_scale = node.r * dt / node.tau + breakpoint() + if np.alltrue(w_scale == 1.) 
or np.unique(w_scale).size == 1: + # HACK to avoid scaling the inputs + print('[warning] scaling weights to avoid scaling inputs') + vthr = vthr / np.unique(w_scale)[0] + else: + raise NotImplementedError('w_scale must be 1, or the same for all neurons') + return snn.Leaky( + beta=beta, + threshold=vthr, + reset_mechanism='zero', + init_hidden=True, + # init_hidden=False, + ) + + else: + print(node.__class__.__name__, node) + + return None + + +def from_nir(graph: nir.NIRGraph) -> torch.nn.Module: + return nirtorch.load(graph, _nir_to_snntorch_module) From 14767e5568229fad1d2bd60da15ad7b6d19edc43 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Fri, 13 Oct 2023 10:54:39 +0200 Subject: [PATCH 09/33] RNN nirtorch export/import works but still buggy --- snntorch/export_nirtorch.py | 18 +++---- snntorch/import_nirtorch.py | 96 +++++++++++++++++++++++++++++++++++-- 2 files changed, 97 insertions(+), 17 deletions(-) diff --git a/snntorch/export_nirtorch.py b/snntorch/export_nirtorch.py index b405b427..30ae95ec 100644 --- a/snntorch/export_nirtorch.py +++ b/snntorch/export_nirtorch.py @@ -1,5 +1,4 @@ -from typing import Union, Optional - +from typing import Optional import torch import nir import numpy as np @@ -25,17 +24,12 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: elif isinstance(module, snn.RLeaky): if module.all_to_all: - w = module.recurrent.weight.data.detach().numpy() - n_neurons = w.shape[0] - print(module.linear_features, w.shape) - if module.recurrent.bias is None: - w_rec = nir.Linear(weight=w) - else: - b = module.recurrent.bias.data.detach().numpy() - w_rec = nir.Affine(weight=w, bias=b) + w_rec = _extract_snntorch_module(module.recurrent) + n_neurons = w_rec.weight.shape[0] else: - # TODO: handle this better - if V is a scalar, then the weight has wrong shape - assert len(module.recurrent.V.shape) == 1, 'V must be a vector' + if len(module.recurrent.V.shape) == 0: + # TODO: handle this better - if V is a scalar, then the weight has wrong shape + raise ValueError('V must be a vector, cannot infer layer size for scalar V') n_neurons = module.recurrent.V.shape[0] w = np.diag(module.recurrent.V.data.detach().numpy()) w_rec = nir.Linear(weight=w) diff --git a/snntorch/import_nirtorch.py b/snntorch/import_nirtorch.py index f3b82096..5d257b8a 100644 --- a/snntorch/import_nirtorch.py +++ b/snntorch/import_nirtorch.py @@ -5,15 +5,66 @@ import snntorch as snn +def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: + """Take a NIRGraph and replace any RNN subgraphs with a single NIRGraph node.""" + if len([e for e in graph.nodes.values() if isinstance(e, nir.Input)]) > 1: + cand_sg_nk = [e[1] for e in graph.edges if e[1] not in graph.nodes] + print('detected subgraph! 
candidates:', cand_sg_nk) + assert len(cand_sg_nk) == 1, 'only one subgraph allowed' + nk = cand_sg_nk[0] + nodes = {k: v for k, v in graph.nodes.items() if k.startswith(f'{nk}.')} + edges = [e for e in graph.edges if e[0].startswith(f'{nk}.') or e[1].startswith(f'{nk}.')] + valid_edges = all([e[0].startswith(f'{nk}.') for e in edges]) + valid_edges = valid_edges and all([e[1].startswith(f'{nk}.') for e in edges]) + assert valid_edges, 'subgraph edges must start with subgraph key' + sg_graph = nir.NIRGraph(nodes=nodes, edges=edges) + for k in nodes.keys(): + graph.nodes.pop(k) + for e in edges: + graph.edges.remove(e) + graph.nodes[nk] = sg_graph + return graph + + +def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): + """Try parsing the graph as a RNN subgraph. + + Assumes four nodes: Input, Output, LIF | CubaLIF, Affine | Linear + Checks that all nodes have consistent shapes. + Will throw an error if either not all nodes are found or consistent shapes are found. + + Returns: + lif_node: LIF | CubaLIF node + wrec_node: Affine | Linear node + lif_size: int, number of neurons in the RNN + """ + sub_nodes = graph.nodes.values() + assert len(sub_nodes) == 4, 'only 4-node RNN allowed in subgraph' + try: + input_node = [n for n in sub_nodes if isinstance(n, nir.Input)][0] + output_node = [n for n in sub_nodes if isinstance(n, nir.Output)][0] + lif_node = [n for n in sub_nodes if isinstance(n, (nir.LIF, nir.CubaLIF))][0] + wrec_node = [n for n in sub_nodes if isinstance(n, (nir.Affine, nir.Linear))][0] + except IndexError: + raise ValueError('invalid RNN subgraph - could not find all required nodes') + lif_size = list(input_node.input_type.values())[0].size + assert lif_size == list(output_node.output_type.values())[0].size, 'output size mismatch' + assert lif_size == lif_node.v_threshold.size, 'lif size mismatch (v_threshold)' + assert lif_size == wrec_node.weight.shape[0], 'w_rec shape mismatch' + assert lif_size == wrec_node.weight.shape[1], 'w_rec shape mismatch' + + return lif_node, wrec_node, lif_size + + def _nir_to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: if isinstance(node, nir.Input) or isinstance(node, nir.Output): return None elif isinstance(node, nir.Affine): + assert node.bias is not None, 'bias must be specified for Affine layer' mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0]) mod.weight.data = torch.Tensor(node.weight) - if node.bias is not None: - mod.bias.data = torch.Tensor(node.bias) + mod.bias.data = torch.Tensor(node.bias) return mod elif isinstance(node, nir.Linear): @@ -23,8 +74,8 @@ def _nir_to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: elif isinstance(node, nir.LIF): # NOTE: assuming that parameters are arrays of correct size - dt = 1e-4 assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' + dt = 1e-4 vthr = node.v_threshold beta = 1 - (dt / node.tau) w_scale = node.r * dt / node.tau @@ -40,14 +91,49 @@ def _nir_to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: threshold=vthr, reset_mechanism='zero', init_hidden=True, - # init_hidden=False, ) + elif isinstance(node, nir.NIRGraph): + lif_node, wrec_node, lif_size = _parse_rnn_subgraph(node) + + if isinstance(lif_node, nir.LIF): + # TODO: fix neuron parameters + rleaky = snn.RLeaky( + beta=1 - (1 / lif_node.tau), + threshold=lif_node.v_threshold, + reset_mechanism='zero', + init_hidden=True, + all_to_all=True, + linear_features=lif_size, + ) + rleaky.recurrent.weight.data = torch.Tensor(wrec_node.weight) + 
if isinstance(wrec_node, nir.Affine): + rleaky.recurrent.bias.data = torch.Tensor(wrec_node.bias) + return rleaky + + elif isinstance(lif_node, nir.CubaLIF): + # TODO: fix neuron parameters + rsynaptic = snn.RSynaptic( + alpha=1 - (1 / lif_node.tau_syn), + beta=1 - (1 / lif_node.tau_mem), + init_hidden=True, + reset_mechanism='zero', + all_to_all=True, + linear_features=lif_size, + ) + rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) + if isinstance(wrec_node, nir.Affine): + rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias) + return rsynaptic + else: - print(node.__class__.__name__, node) + print('[WARNING] could not parse node of type:', node.__class__.__name__) return None def from_nir(graph: nir.NIRGraph) -> torch.nn.Module: + # find valid RNN subgraphs, and replace them with a single NIRGraph node + graph = _replace_rnn_subgraph_with_nirgraph(graph) + # TODO: right now, the subgraph edges seem to not be parsed correctly - fix this return nirtorch.load(graph, _nir_to_snntorch_module) From 708faa2a49f59bda9561395afa00f5dd278345da Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Fri, 13 Oct 2023 11:27:02 +0200 Subject: [PATCH 10/33] minor --- snntorch/import_nirtorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snntorch/import_nirtorch.py b/snntorch/import_nirtorch.py index 5d257b8a..0b600f75 100644 --- a/snntorch/import_nirtorch.py +++ b/snntorch/import_nirtorch.py @@ -8,6 +8,7 @@ def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: """Take a NIRGraph and replace any RNN subgraphs with a single NIRGraph node.""" if len([e for e in graph.nodes.values() if isinstance(e, nir.Input)]) > 1: + print('found RNN subgraph, trying to parse') cand_sg_nk = [e[1] for e in graph.edges if e[1] not in graph.nodes] print('detected subgraph! candidates:', cand_sg_nk) assert len(cand_sg_nk) == 1, 'only one subgraph allowed' @@ -79,7 +80,6 @@ def _nir_to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: vthr = node.v_threshold beta = 1 - (dt / node.tau) w_scale = node.r * dt / node.tau - breakpoint() if np.alltrue(w_scale == 1.) or np.unique(w_scale).size == 1: # HACK to avoid scaling the inputs print('[warning] scaling weights to avoid scaling inputs') From 9004cdadb4186aea45e6e6fae5a5b913e5eb8347 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Tue, 17 Oct 2023 01:35:24 +0200 Subject: [PATCH 11/33] update NIRTorch import/export (not done) --- snntorch/export_nir.py | 3 +- snntorch/export_nirtorch.py | 83 ++++++++++++++++++++++++++++++++++++- snntorch/import_nirtorch.py | 67 +++++++++++++++++++++++------- 3 files changed, 136 insertions(+), 17 deletions(-) diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py index ce3baec5..71a7190a 100644 --- a/snntorch/export_nir.py +++ b/snntorch/export_nir.py @@ -111,10 +111,11 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: elif isinstance(module, Synaptic): # TODO + n_neurons = module.alpha.shape[0] return nir.CubaLIF( tau_syn=1 / (1 - module.alpha).detach().numpy(), tau_mem=1 / (1 - module.beta).detach().numpy(), - v_threshold=module.threshold.detach().numpy(), + v_threshold=np.ones_like(module.alpha) * module.threshold.detach().numpy(), v_leak=np.zeros_like(module.beta), r=module.beta.detach().numpy(), # NOTE: is this right? 
) diff --git a/snntorch/export_nirtorch.py b/snntorch/export_nirtorch.py index 30ae95ec..2c0a2c9a 100644 --- a/snntorch/export_nirtorch.py +++ b/snntorch/export_nirtorch.py @@ -8,8 +8,20 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: if isinstance(module, snn.Leaky): - print('leaky') - return None + # TODO + raise NotImplementedError('Leaky not supported') + beta = module.beta.detach().numpy() + vthr = module.threshold.detach().numpy() + + tau = 1 / (1 - module.beta.detach().numpy()) + r = module.beta.detach().numpy() + threshold = module.threshold.detach().numpy() + return nir.LIF( + tau=tau, + v_threshold=threshold, + v_leak=torch.zeros_like(tau), + r=r, + ) elif isinstance(module, torch.nn.Linear): if module.bias is None: @@ -22,7 +34,33 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: bias=module.bias.data.detach().numpy() ) + elif isinstance(module, snn.Synaptic): + dt = 1e-4 + + alpha = module.alpha.detach().numpy() + beta = module.beta.detach().numpy() + vthr = module.threshold.detach().numpy() + + # TODO: make sure alpha, beta, vthr are tensors of same size + alpha = np.ones(7) * alpha + beta = np.ones(7) * beta + vthr = np.ones(7) * vthr + + tau_syn = dt / (1 - alpha) + tau_mem = dt / (1 - beta) + r = tau_mem / dt + v_leak = np.zeros_like(beta) + + return nir.CubaLIF( + tau_syn=tau_syn, + tau_mem=tau_mem, + v_threshold=vthr, + v_leak=v_leak, + r=r, + ) + elif isinstance(module, snn.RLeaky): + raise NotImplementedError('RLeaky not supported') if module.all_to_all: w_rec = _extract_snntorch_module(module.recurrent) n_neurons = w_rec.weight.shape[0] @@ -50,6 +88,47 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: ('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output') ]) + elif isinstance(module, snn.RSynaptic): + if module.all_to_all: + w_rec = _extract_snntorch_module(module.recurrent) + n_neurons = w_rec.weight.shape[0] + else: + if len(module.recurrent.V.shape) == 0: + # TODO: handle this better - if V is a scalar, then the weight has wrong shape + raise ValueError('V must be a vector, cannot infer layer size for scalar V') + n_neurons = module.recurrent.V.shape[0] + w = np.diag(module.recurrent.V.data.detach().numpy()) + w_rec = nir.Linear(weight=w) + + dt = 1e-4 + + alpha = module.alpha.detach().numpy() + beta = module.beta.detach().numpy() + vthr = module.threshold.detach().numpy() + alpha = np.ones(n_neurons) * alpha + beta = np.ones(n_neurons) * beta + vthr = np.ones(n_neurons) * vthr + + tau_syn = dt / (1 - alpha) + tau_mem = dt / (1 - beta) + r = tau_mem / dt + v_leak = np.zeros_like(beta) + + return nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=[n_neurons]), + 'lif': nir.CubaLIF( + v_threshold=vthr, + tau_mem=tau_mem, + tau_syn=tau_syn, + r=r, + v_leak=v_leak, + ), + 'w_rec': w_rec, + 'output': nir.Output(output_type=[n_neurons]) + }, edges=[ + ('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output') + ]) + else: print(f'[WARNING] module not implemented: {module.__class__.__name__}') return None diff --git a/snntorch/import_nirtorch.py b/snntorch/import_nirtorch.py index 0b600f75..ab38ee98 100644 --- a/snntorch/import_nirtorch.py +++ b/snntorch/import_nirtorch.py @@ -9,7 +9,7 @@ def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: """Take a NIRGraph and replace any RNN subgraphs with a single NIRGraph node.""" if len([e for e in graph.nodes.values() if isinstance(e, nir.Input)]) > 1: print('found RNN 
subgraph, trying to parse') - cand_sg_nk = [e[1] for e in graph.edges if e[1] not in graph.nodes] + cand_sg_nk = list(set([e[1] for e in graph.edges if e[1] not in graph.nodes])) print('detected subgraph! candidates:', cand_sg_nk) assert len(cand_sg_nk) == 1, 'only one subgraph allowed' nk = cand_sg_nk[0] @@ -48,8 +48,8 @@ def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): wrec_node = [n for n in sub_nodes if isinstance(n, (nir.Affine, nir.Linear))][0] except IndexError: raise ValueError('invalid RNN subgraph - could not find all required nodes') - lif_size = list(input_node.input_type.values())[0].size - assert lif_size == list(output_node.output_type.values())[0].size, 'output size mismatch' + lif_size = list(input_node.input_type.values())[0][0] + assert lif_size == list(output_node.output_type.values())[0][0], 'output size mismatch' assert lif_size == lif_node.v_threshold.size, 'lif size mismatch (v_threshold)' assert lif_size == wrec_node.weight.shape[0], 'w_rec shape mismatch' assert lif_size == wrec_node.weight.shape[1], 'w_rec shape mismatch' @@ -57,38 +57,77 @@ def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): return lif_node, wrec_node, lif_size -def _nir_to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: +def _nir_to_snntorch_module(node: nir.NIRNode, hack_w_scale=True) -> torch.nn.Module: if isinstance(node, nir.Input) or isinstance(node, nir.Output): return None elif isinstance(node, nir.Affine): assert node.bias is not None, 'bias must be specified for Affine layer' + mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0]) mod.weight.data = torch.Tensor(node.weight) mod.bias.data = torch.Tensor(node.bias) + return mod elif isinstance(node, nir.Linear): mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=False) mod.weight.data = torch.Tensor(node.weight) + return mod elif isinstance(node, nir.LIF): - # NOTE: assuming that parameters are arrays of correct size - assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' dt = 1e-4 - vthr = node.v_threshold + + assert np.allclose(node.v_leak, 0.), 'v_leak not supported' + assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' + beta = 1 - (dt / node.tau) + vthr = node.v_threshold w_scale = node.r * dt / node.tau - if np.alltrue(w_scale == 1.) 
or np.unique(w_scale).size == 1: - # HACK to avoid scaling the inputs - print('[warning] scaling weights to avoid scaling inputs') - vthr = vthr / np.unique(w_scale)[0] - else: - raise NotImplementedError('w_scale must be 1, or the same for all neurons') + + if not np.alltrue(w_scale == 1.): + if hack_w_scale: + vthr = vthr / np.unique(w_scale)[0] + print('[warning] scaling weights to avoid scaling inputs') + print(f'w_scale: {w_scale}, r: {node.r}, dt: {dt}, tau: {node.tau}') + else: + raise NotImplementedError('w_scale must be 1, or the same for all neurons') + + assert np.unique(vthr).size == 1, 'LIF v_thr must be same for all neurons' + return snn.Leaky( beta=beta, - threshold=vthr, + threshold=np.unique(vthr)[0], + reset_mechanism='zero', + init_hidden=True, + ) + + elif isinstance(node, nir.CubaLIF): + dt = 1e-4 + + assert np.allclose(node.v_leak, 0), 'v_leak not supported' + assert np.allclose(node.r * dt / node.tau_mem, 1.), 'r not supported in CubaLIF' + + alpha = 1 - (1 / node.tau_syn) + beta = 1 - (1 / node.tau_mem) + vthr = node.v_threshold + w_scale = node.w_in * (dt / node.tau_syn) + + if not np.alltrue(w_scale == 1.): + if hack_w_scale: + vthr = vthr / w_scale + print('[warning] scaling weights to avoid scaling inputs') + print(f'w_scale: {w_scale}, w_in: {node.w_in}, dt: {dt}, tau_syn: {node.tau_syn}') + else: + raise NotImplementedError('w_scale must be 1, or the same for all neurons') + + assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' + + return snn.Synaptic( + alpha=alpha, + beta=beta, + threshold=np.unique(vthr)[0], reset_mechanism='zero', init_hidden=True, ) From a9b7cee9187d335dff3cc804336ff6c9c7bad91e Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Tue, 17 Oct 2023 02:09:37 +0200 Subject: [PATCH 12/33] version for braille-v2 --- snntorch/export_nirtorch.py | 4 ++++ snntorch/import_nirtorch.py | 45 ++++++++++++++++++++++++++++--------- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/snntorch/export_nirtorch.py b/snntorch/export_nirtorch.py index 2c0a2c9a..33336343 100644 --- a/snntorch/export_nirtorch.py +++ b/snntorch/export_nirtorch.py @@ -50,6 +50,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: tau_mem = dt / (1 - beta) r = tau_mem / dt v_leak = np.zeros_like(beta) + w_in = tau_syn / dt return nir.CubaLIF( tau_syn=tau_syn, @@ -57,6 +58,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: v_threshold=vthr, v_leak=v_leak, r=r, + w_in=w_in, ) elif isinstance(module, snn.RLeaky): @@ -113,6 +115,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: tau_mem = dt / (1 - beta) r = tau_mem / dt v_leak = np.zeros_like(beta) + w_in = tau_syn / dt return nir.NIRGraph(nodes={ 'input': nir.Input(input_type=[n_neurons]), @@ -122,6 +125,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: tau_syn=tau_syn, r=r, v_leak=v_leak, + w_in=w_in, ), 'w_rec': w_rec, 'output': nir.Output(output_type=[n_neurons]) diff --git a/snntorch/import_nirtorch.py b/snntorch/import_nirtorch.py index ab38ee98..74837fcd 100644 --- a/snntorch/import_nirtorch.py +++ b/snntorch/import_nirtorch.py @@ -86,7 +86,7 @@ def _nir_to_snntorch_module(node: nir.NIRNode, hack_w_scale=True) -> torch.nn.Mo vthr = node.v_threshold w_scale = node.r * dt / node.tau - if not np.alltrue(w_scale == 1.): + if not np.allclose(w_scale, 1.): if hack_w_scale: vthr = vthr / np.unique(w_scale)[0] print('[warning] scaling weights to avoid scaling inputs') 
@@ -107,14 +107,14 @@ def _nir_to_snntorch_module(node: nir.NIRNode, hack_w_scale=True) -> torch.nn.Mo dt = 1e-4 assert np.allclose(node.v_leak, 0), 'v_leak not supported' - assert np.allclose(node.r * dt / node.tau_mem, 1.), 'r not supported in CubaLIF' + assert np.allclose(node.r, node.tau_mem / dt), 'r not supported in CubaLIF' - alpha = 1 - (1 / node.tau_syn) - beta = 1 - (1 / node.tau_mem) + alpha = 1 - (dt / node.tau_syn) + beta = 1 - (dt / node.tau_mem) vthr = node.v_threshold w_scale = node.w_in * (dt / node.tau_syn) - if not np.alltrue(w_scale == 1.): + if not np.allclose(w_scale, 1.): if hack_w_scale: vthr = vthr / w_scale print('[warning] scaling weights to avoid scaling inputs') @@ -136,6 +136,7 @@ def _nir_to_snntorch_module(node: nir.NIRNode, hack_w_scale=True) -> torch.nn.Mo lif_node, wrec_node, lif_size = _parse_rnn_subgraph(node) if isinstance(lif_node, nir.LIF): + raise NotImplementedError('LIF in subgraph not supported') # TODO: fix neuron parameters rleaky = snn.RLeaky( beta=1 - (1 / lif_node.tau), @@ -151,15 +152,39 @@ def _nir_to_snntorch_module(node: nir.NIRNode, hack_w_scale=True) -> torch.nn.Mo return rleaky elif isinstance(lif_node, nir.CubaLIF): - # TODO: fix neuron parameters + dt = 1e-4 + + assert np.allclose(lif_node.v_leak, 0), 'v_leak not supported' + assert np.allclose(lif_node.r, lif_node.tau_mem / dt), 'r not supported in CubaLIF' + + alpha = 1 - (dt / lif_node.tau_syn) + beta = 1 - (dt / lif_node.tau_mem) + vthr = lif_node.v_threshold + w_scale = lif_node.w_in * (dt / lif_node.tau_syn) + + if not np.allclose(w_scale, 1.): + if hack_w_scale: + vthr = vthr / w_scale + print(f'[warning] scaling weights to avoid scaling inputs. w_scale: {w_scale}') + print(f'w_in: {lif_node.w_in}, dt: {dt}, tau_syn: {lif_node.tau_syn}') + else: + raise NotImplementedError('w_scale must be 1, or the same for all neurons') + + assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' + + diagonal = np.array_equal(wrec_node.weight, np.diag(np.diag(wrec_node.weight))) + rsynaptic = snn.RSynaptic( - alpha=1 - (1 / lif_node.tau_syn), - beta=1 - (1 / lif_node.tau_mem), - init_hidden=True, + alpha=alpha, + beta=beta, + threshold=np.unique(vthr)[0], reset_mechanism='zero', - all_to_all=True, + init_hidden=True, + all_to_all=not diagonal, linear_features=lif_size, + V=np.diag(wrec_node.weight) if diagonal else None, ) + rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) if isinstance(wrec_node, nir.Affine): rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias) From 0c7f69297ec5359bf26d18b9580339c45c20881a Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Thu, 19 Oct 2023 17:21:02 +0200 Subject: [PATCH 13/33] update to latest support! 
(use init_hidden=False) --- snntorch/export_nirtorch.py | 10 ++-- snntorch/import_nirtorch.py | 108 +++++++++++++++++++++++++++--------- 2 files changed, 89 insertions(+), 29 deletions(-) diff --git a/snntorch/export_nirtorch.py b/snntorch/export_nirtorch.py index 33336343..a8e3051c 100644 --- a/snntorch/export_nirtorch.py +++ b/snntorch/export_nirtorch.py @@ -2,7 +2,7 @@ import torch import nir import numpy as np -from nirtorch import extract_nir_graph +import nirtorch import snntorch as snn @@ -139,12 +139,14 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: def to_nir( - module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch" + module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch", + model_fwd_args=[], ignore_dims=[] ) -> nir.NIRNode: """Convert an snnTorch model to the Neuromorphic Intermediate Representation (NIR). """ - nir_graph = extract_nir_graph( + nir_graph = nirtorch.extract_nir_graph( module, _extract_snntorch_module, sample_data, model_name=model_name, - ignore_submodules_of=[snn.RLeaky, snn.RSynaptic] + ignore_submodules_of=[snn.RLeaky, snn.RSynaptic], + model_fwd_args=model_fwd_args, ignore_dims=ignore_dims ) return nir_graph diff --git a/snntorch/import_nirtorch.py b/snntorch/import_nirtorch.py index 74837fcd..82ff6545 100644 --- a/snntorch/import_nirtorch.py +++ b/snntorch/import_nirtorch.py @@ -5,25 +5,66 @@ import snntorch as snn +def _create_rnn_subgraph(graph: nir.NIRGraph, lif_nk: str, w_nk: str) -> nir.NIRGraph: + """Take a NIRGraph plus the node keys for a LIF and a W_rec, and return a new NIRGraph + which has the RNN subgraph replaced with a subgraph (i.e., a single NIRGraph node). + """ + # NOTE: assuming that the LIF and W_rec have keys of form xyz.abc + sg_key = lif_nk.split('.')[0] # TODO: make this more general? + + # create subgraph for RNN + sg_edges = [ + (lif_nk, w_nk), (w_nk, lif_nk), (lif_nk, f'{sg_key}.output'), (f'{sg_key}.input', w_nk) + ] + sg_nodes = { + lif_nk: graph.nodes[lif_nk], + w_nk: graph.nodes[w_nk], + f'{sg_key}.input': nir.Input(graph.nodes[lif_nk].input_type), + f'{sg_key}.output': nir.Output(graph.nodes[lif_nk].output_type), + } + sg = nir.NIRGraph(nodes=sg_nodes, edges=sg_edges) + + # remove subgraph edges from graph + graph.edges = [e for e in graph.edges if e not in [(lif_nk, w_nk), (w_nk, lif_nk)]] + # remove subgraph nodes from graph + graph.nodes = {k: v for k, v in graph.nodes.items() if k not in [lif_nk, w_nk]} + + # change edges of type (x, lif_nk) to (x, sg_key) + graph.edges = [(e[0], sg_key) if e[1] == lif_nk else e for e in graph.edges] + # change edges of type (lif_nk, x) to (sg_key, x) + graph.edges = [(sg_key, e[1]) if e[0] == lif_nk else e for e in graph.edges] + + # insert subgraph into graph and return + graph.nodes[sg_key] = sg + return graph + + def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: """Take a NIRGraph and replace any RNN subgraphs with a single NIRGraph node.""" - if len([e for e in graph.nodes.values() if isinstance(e, nir.Input)]) > 1: - print('found RNN subgraph, trying to parse') - cand_sg_nk = list(set([e[1] for e in graph.edges if e[1] not in graph.nodes])) - print('detected subgraph! 
candidates:', cand_sg_nk) - assert len(cand_sg_nk) == 1, 'only one subgraph allowed' - nk = cand_sg_nk[0] - nodes = {k: v for k, v in graph.nodes.items() if k.startswith(f'{nk}.')} - edges = [e for e in graph.edges if e[0].startswith(f'{nk}.') or e[1].startswith(f'{nk}.')] - valid_edges = all([e[0].startswith(f'{nk}.') for e in edges]) - valid_edges = valid_edges and all([e[1].startswith(f'{nk}.') for e in edges]) - assert valid_edges, 'subgraph edges must start with subgraph key' - sg_graph = nir.NIRGraph(nodes=nodes, edges=edges) - for k in nodes.keys(): - graph.nodes.pop(k) - for e in edges: - graph.edges.remove(e) - graph.nodes[nk] = sg_graph + print('replace rnn subgraph with nirgraph') + + if len(set(graph.edges)) != len(graph.edges): + print('[WARNING] duplicate edges found, removing') + graph.edges = list(set(graph.edges)) + + # find cycle of LIF <> Dense nodes + for edge1 in graph.edges: + for edge2 in graph.edges: + if not edge1 == edge2: + if edge1[0] == edge2[1] and edge1[1] == edge2[0]: + lif_nk = edge1[0] + lif_n = graph.nodes[lif_nk] + w_nk = edge1[1] + w_n = graph.nodes[w_nk] + is_lif = isinstance(lif_n, (nir.LIF, nir.CubaLIF)) + is_dense = isinstance(w_n, (nir.Affine, nir.Linear)) + # check if the dense only connects to the LIF + w_out_nk = [e[1] for e in graph.edges if e[0] == w_nk] + w_in_nk = [e[0] for e in graph.edges if e[1] == w_nk] + is_rnn = len(w_out_nk) == 1 and len(w_in_nk) == 1 + # check if we found an RNN - if so, then parse it + if is_rnn and is_lif and is_dense: + graph = _create_rnn_subgraph(graph, edge1[0], edge1[1]) return graph @@ -57,7 +98,9 @@ def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): return lif_node, wrec_node, lif_size -def _nir_to_snntorch_module(node: nir.NIRNode, hack_w_scale=True) -> torch.nn.Module: +def _nir_to_snntorch_module( + node: nir.NIRNode, hack_w_scale=True, init_hidden=False +) -> torch.nn.Module: if isinstance(node, nir.Input) or isinstance(node, nir.Output): return None @@ -100,7 +143,7 @@ def _nir_to_snntorch_module(node: nir.NIRNode, hack_w_scale=True) -> torch.nn.Mo beta=beta, threshold=np.unique(vthr)[0], reset_mechanism='zero', - init_hidden=True, + init_hidden=init_hidden, ) elif isinstance(node, nir.CubaLIF): @@ -124,12 +167,17 @@ def _nir_to_snntorch_module(node: nir.NIRNode, hack_w_scale=True) -> torch.nn.Mo assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' + if np.unique(alpha).size == 1: + alpha = float(np.unique(alpha)[0]) + if np.unique(beta).size == 1: + beta = float(np.unique(beta)[0]) + return snn.Synaptic( alpha=alpha, beta=beta, - threshold=np.unique(vthr)[0], + threshold=float(np.unique(vthr)[0]), reset_mechanism='zero', - init_hidden=True, + init_hidden=init_hidden, ) elif isinstance(node, nir.NIRGraph): @@ -142,7 +190,7 @@ def _nir_to_snntorch_module(node: nir.NIRNode, hack_w_scale=True) -> torch.nn.Mo beta=1 - (1 / lif_node.tau), threshold=lif_node.v_threshold, reset_mechanism='zero', - init_hidden=True, + init_hidden=init_hidden, all_to_all=True, linear_features=lif_size, ) @@ -174,15 +222,25 @@ def _nir_to_snntorch_module(node: nir.NIRNode, hack_w_scale=True) -> torch.nn.Mo diagonal = np.array_equal(wrec_node.weight, np.diag(np.diag(wrec_node.weight))) + if np.unique(alpha).size == 1: + alpha = float(np.unique(alpha)[0]) + if np.unique(beta).size == 1: + beta = float(np.unique(beta)[0]) + + if diagonal: + V = torch.from_numpy(np.diag(wrec_node.weight)).to(dtype=torch.float32) + else: + V = None + rsynaptic = snn.RSynaptic( alpha=alpha, 
beta=beta, - threshold=np.unique(vthr)[0], + threshold=float(np.unique(vthr)[0]), reset_mechanism='zero', - init_hidden=True, + init_hidden=init_hidden, all_to_all=not diagonal, linear_features=lif_size, - V=np.diag(wrec_node.weight) if diagonal else None, + V=V, ) rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) From 2ada175550b99a1cdf3de3d998e69e478871625f Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Thu, 19 Oct 2023 21:37:34 +0200 Subject: [PATCH 14/33] rm dead code --- snntorch/export_nirtorch.py | 39 ------------------------------------- snntorch/import_nirtorch.py | 13 ------------- 2 files changed, 52 deletions(-) diff --git a/snntorch/export_nirtorch.py b/snntorch/export_nirtorch.py index a8e3051c..455d0c4d 100644 --- a/snntorch/export_nirtorch.py +++ b/snntorch/export_nirtorch.py @@ -8,20 +8,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: if isinstance(module, snn.Leaky): - # TODO raise NotImplementedError('Leaky not supported') - beta = module.beta.detach().numpy() - vthr = module.threshold.detach().numpy() - - tau = 1 / (1 - module.beta.detach().numpy()) - r = module.beta.detach().numpy() - threshold = module.threshold.detach().numpy() - return nir.LIF( - tau=tau, - v_threshold=threshold, - v_leak=torch.zeros_like(tau), - r=r, - ) elif isinstance(module, torch.nn.Linear): if module.bias is None: @@ -63,32 +50,6 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: elif isinstance(module, snn.RLeaky): raise NotImplementedError('RLeaky not supported') - if module.all_to_all: - w_rec = _extract_snntorch_module(module.recurrent) - n_neurons = w_rec.weight.shape[0] - else: - if len(module.recurrent.V.shape) == 0: - # TODO: handle this better - if V is a scalar, then the weight has wrong shape - raise ValueError('V must be a vector, cannot infer layer size for scalar V') - n_neurons = module.recurrent.V.shape[0] - w = np.diag(module.recurrent.V.data.detach().numpy()) - w_rec = nir.Linear(weight=w) - - # TODO: set the parameters correctly - v_thr = np.ones(n_neurons) * module.threshold.detach().numpy() - beta = np.ones(n_neurons) * module.beta.detach().numpy() - tau = 1 / (1 - beta) - r = beta - v_leak = beta - - return nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=[n_neurons]), - 'lif': nir.LIF(v_threshold=v_thr, tau=tau, r=r, v_leak=v_leak), - 'w_rec': w_rec, - 'output': nir.Output(output_type=[n_neurons]) - }, edges=[ - ('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output') - ]) elif isinstance(module, snn.RSynaptic): if module.all_to_all: diff --git a/snntorch/import_nirtorch.py b/snntorch/import_nirtorch.py index 82ff6545..fb029a4f 100644 --- a/snntorch/import_nirtorch.py +++ b/snntorch/import_nirtorch.py @@ -185,19 +185,6 @@ def _nir_to_snntorch_module( if isinstance(lif_node, nir.LIF): raise NotImplementedError('LIF in subgraph not supported') - # TODO: fix neuron parameters - rleaky = snn.RLeaky( - beta=1 - (1 / lif_node.tau), - threshold=lif_node.v_threshold, - reset_mechanism='zero', - init_hidden=init_hidden, - all_to_all=True, - linear_features=lif_size, - ) - rleaky.recurrent.weight.data = torch.Tensor(wrec_node.weight) - if isinstance(wrec_node, nir.Affine): - rleaky.recurrent.bias.data = torch.Tensor(wrec_node.bias) - return rleaky elif isinstance(lif_node, nir.CubaLIF): dt = 1e-4 From 0d45ad18982c280bb70cddf00198969b9d05f331 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Mon, 23 Oct 2023 19:05:05 +0200 Subject: [PATCH 15/33] adapt NIR-standard thresholding for 
(r)synaptic --- snntorch/_neurons/rsynaptic.py | 14 ++++++++++++++ snntorch/_neurons/synaptic.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/snntorch/_neurons/rsynaptic.py b/snntorch/_neurons/rsynaptic.py index 5c8b6852..17274e44 100644 --- a/snntorch/_neurons/rsynaptic.py +++ b/snntorch/_neurons/rsynaptic.py @@ -254,6 +254,7 @@ def __init__( reset_mechanism="subtract", state_quant=False, output=False, + reset_after=False, ): super(RSynaptic, self).__init__( beta, @@ -294,6 +295,11 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) + self._reset_after = reset_after + + if reset_after and self.init_hidden: + raise NotImplementedError('reset_after not implemented for init_hidden=False') + if self.init_hidden: self.spk, self.syn, self.mem = self.init_rsynaptic() @@ -324,6 +330,14 @@ def forward(self, input_, spk=False, syn=False, mem=False): else: spk = self.fire(mem) + if self._reset_after: + # reset membrane potential _right_ after spike + do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + mem -= do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + mem -= do_reset * mem + return spk, syn, mem # intended for truncated-BPTT where instance variables are hidden diff --git a/snntorch/_neurons/synaptic.py b/snntorch/_neurons/synaptic.py index 2f904069..16fe8dcb 100644 --- a/snntorch/_neurons/synaptic.py +++ b/snntorch/_neurons/synaptic.py @@ -168,6 +168,7 @@ def __init__( reset_mechanism="subtract", state_quant=False, output=False, + reset_after=False, ): super(Synaptic, self).__init__( beta, @@ -185,6 +186,11 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) + self._reset_after = reset_after + + if reset_after and self.init_hidden: + raise NotImplementedError('reset_after only supported for init_hidden=False') + if self.init_hidden: self.syn, self.mem = self.init_synaptic() @@ -214,6 +220,14 @@ def forward(self, input_, syn=False, mem=False): else: spk = self.fire(mem) + if self._reset_after: + # reset membrane potential _right_ after spike + do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + mem -= do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + mem -= do_reset * mem + return spk, syn, mem # intended for truncated-BPTT where instance variables are From cb684c6ae4fdd96d1fbd961d265abdf4997cf8ca Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Mon, 23 Oct 2023 19:34:33 +0200 Subject: [PATCH 16/33] rename reset_after -> reset_delay (+ invert) --- snntorch/_neurons/rsynaptic.py | 10 +++++----- snntorch/_neurons/synaptic.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/snntorch/_neurons/rsynaptic.py b/snntorch/_neurons/rsynaptic.py index 17274e44..d46cc14c 100644 --- a/snntorch/_neurons/rsynaptic.py +++ b/snntorch/_neurons/rsynaptic.py @@ -254,7 +254,7 @@ def __init__( reset_mechanism="subtract", state_quant=False, output=False, - reset_after=False, + reset_delay=True, ): super(RSynaptic, self).__init__( beta, @@ -295,10 +295,10 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) - self._reset_after = reset_after + self.reset_delay = reset_delay - if reset_after and self.init_hidden: - raise NotImplementedError('reset_after not implemented for init_hidden=False') + if not reset_delay and self.init_hidden: + raise NotImplementedError('no reset_delay only supported for 
init_hidden=False') if self.init_hidden: self.spk, self.syn, self.mem = self.init_rsynaptic() @@ -330,7 +330,7 @@ def forward(self, input_, spk=False, syn=False, mem=False): else: spk = self.fire(mem) - if self._reset_after: + if not self.reset_delay: # reset membrane potential _right_ after spike do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset if self.reset_mechanism_val == 0: # reset by subtraction diff --git a/snntorch/_neurons/synaptic.py b/snntorch/_neurons/synaptic.py index 16fe8dcb..b5ee22c4 100644 --- a/snntorch/_neurons/synaptic.py +++ b/snntorch/_neurons/synaptic.py @@ -168,7 +168,7 @@ def __init__( reset_mechanism="subtract", state_quant=False, output=False, - reset_after=False, + reset_delay=True, ): super(Synaptic, self).__init__( beta, @@ -186,10 +186,10 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) - self._reset_after = reset_after + self.reset_delay = reset_delay - if reset_after and self.init_hidden: - raise NotImplementedError('reset_after only supported for init_hidden=False') + if not reset_delay and self.init_hidden: + raise NotImplementedError('no reset_delay only supported for init_hidden=False') if self.init_hidden: self.syn, self.mem = self.init_synaptic() @@ -220,7 +220,7 @@ def forward(self, input_, syn=False, mem=False): else: spk = self.fire(mem) - if self._reset_after: + if not self.reset_delay: # reset membrane potential _right_ after spike do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset if self.reset_mechanism_val == 0: # reset by subtraction From 4352c6017b49209d8c06b0b93ec0453293f54f36 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Tue, 24 Oct 2023 10:39:26 +0200 Subject: [PATCH 17/33] add conv/if/pool to import --- snntorch/import_nirtorch.py | 42 +++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/snntorch/import_nirtorch.py b/snntorch/import_nirtorch.py index fb029a4f..3787ad46 100644 --- a/snntorch/import_nirtorch.py +++ b/snntorch/import_nirtorch.py @@ -119,6 +119,45 @@ def _nir_to_snntorch_module( return mod + elif isinstance(node, nir.Conv2d): + mod = torch.nn.Conv2d( + node.weight.shape[1], + node.weight.shape[0], + kernel_size=[*node.weight.shape[-2:]], + stride=node.stride, + padding=node.padding, + dilation=node.dilation, + groups=node.groups, + ) + mod.bias.data = torch.Tensor(node.bias) + mod.weight.data = torch.Tensor(node.weight) + return mod + + if isinstance(node, nir.Flatten): + return torch.nn.Flatten(node.start_dim, node.end_dim) + + if isinstance(node, nir.SumPool2d): + return torch.nn.AvgPool2d( + kernel_size=tuple(node.kernel_size), + stride=tuple(node.stride), + padding=tuple(node.padding), + divisor_override=1, # turn AvgPool into SumPool + ) + + elif isinstance(node, nir.IF): + assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' + assert np.unique(node.r).size == 1, 'r must be same for all neurons' + vthr = np.unique(node.v_threshold)[0] + r = np.unique(node.r)[0] + assert r == 1, 'r != 1 not supported' + mod = snn.Leaky( + beta=0.9, + threshold=vthr * r, + init_hidden=False, + reset_delay=False, + ) + return mod + elif isinstance(node, nir.LIF): dt = 1e-4 @@ -144,6 +183,7 @@ def _nir_to_snntorch_module( threshold=np.unique(vthr)[0], reset_mechanism='zero', init_hidden=init_hidden, + reset_delay=False, ) elif isinstance(node, nir.CubaLIF): @@ -178,6 +218,7 @@ def _nir_to_snntorch_module( threshold=float(np.unique(vthr)[0]), reset_mechanism='zero', init_hidden=init_hidden, + 
reset_delay=False, ) elif isinstance(node, nir.NIRGraph): @@ -228,6 +269,7 @@ def _nir_to_snntorch_module( all_to_all=not diagonal, linear_features=lif_size, V=V, + reset_delay=False, ) rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) From a20445727de7ab91e040cd84d9266a620ce71f3a Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Tue, 24 Oct 2023 10:57:51 +0200 Subject: [PATCH 18/33] fix reset_delay (+ add for (r)leaky) --- snntorch/_neurons/leaky.py | 13 +++++++++++++ snntorch/_neurons/rleaky.py | 13 +++++++++++++ snntorch/_neurons/rsynaptic.py | 5 +++-- snntorch/_neurons/synaptic.py | 4 ++-- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/snntorch/_neurons/leaky.py b/snntorch/_neurons/leaky.py index 48f8ae4b..d51d5dc1 100644 --- a/snntorch/_neurons/leaky.py +++ b/snntorch/_neurons/leaky.py @@ -141,6 +141,7 @@ def __init__( output=False, graded_spikes_factor=1.0, learn_graded_spikes_factor=False, + reset_delay=True, ): super(Leaky, self).__init__( beta, @@ -158,6 +159,11 @@ def __init__( learn_graded_spikes_factor, ) + self.reset_delay = reset_delay + + if not self.reset_delay and self.init_hidden: + raise NotImplementedError('no reset_delay only supported for init_hidden=False') + if self.init_hidden: self.mem = self.init_leaky() @@ -188,6 +194,13 @@ def forward(self, input_, mem=False): else: spk = self.fire(mem) + if not self.reset_delay: + do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + mem = mem - do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + mem = mem - do_reset * mem + return spk, mem # intended for truncated-BPTT where instance variables are hidden diff --git a/snntorch/_neurons/rleaky.py b/snntorch/_neurons/rleaky.py index 38523849..6a72aaf0 100644 --- a/snntorch/_neurons/rleaky.py +++ b/snntorch/_neurons/rleaky.py @@ -241,6 +241,7 @@ def __init__( reset_mechanism="subtract", state_quant=False, output=False, + reset_delay=True, ): super(RLeaky, self).__init__( beta, @@ -279,6 +280,11 @@ def __init__( if not learn_recurrent: self._disable_recurrent_grad() + self.reset_delay = reset_delay + + if not self.reset_delay and self.init_hidden: + raise NotImplementedError('no reset_delay only supported for init_hidden=False') + if self.init_hidden: self.spk, self.mem = self.init_rleaky() # self.state_fn = self._build_state_function_hidden @@ -312,6 +318,13 @@ def forward(self, input_, spk=False, mem=False): else: spk = self.fire(mem) + if not self.reset_delay: + do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + mem = mem - do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + mem = mem - do_reset * mem + return spk, mem # intended for truncated-BPTT where instance variables are hidden diff --git a/snntorch/_neurons/rsynaptic.py b/snntorch/_neurons/rsynaptic.py index d46cc14c..f799118e 100644 --- a/snntorch/_neurons/rsynaptic.py +++ b/snntorch/_neurons/rsynaptic.py @@ -334,9 +334,10 @@ def forward(self, input_, spk=False, syn=False, mem=False): # reset membrane potential _right_ after spike do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset if self.reset_mechanism_val == 0: # reset by subtraction - mem -= do_reset * self.threshold + mem = mem - do_reset * self.threshold elif self.reset_mechanism_val == 1: # reset to zero - mem -= do_reset * mem + # mem -= do_reset * mem + mem = mem - do_reset * mem 
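            # --- Editor's note (illustrative, not part of this patch) ---------------
            # With the default reset_delay=True, the reset computed from this step's
            # spike only takes effect on the next forward call; reset_delay=False
            # applies it immediately, which is what the NIR importer assumes when it
            # constructs neurons with reset_delay=False.
            # `do_reset = spk / self.graded_spikes_factor - self.reset` subtracts the
            # reset already scheduled by the base class, so the membrane is not reset
            # twice. Writing `mem = mem - ...` instead of `mem -= ...` avoids an
            # in-place update of the membrane tensor (presumably friendlier to
            # autograd); the numerical result is identical.
            # -------------------------------------------------------------------------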
return spk, syn, mem diff --git a/snntorch/_neurons/synaptic.py b/snntorch/_neurons/synaptic.py index b5ee22c4..707e3ba3 100644 --- a/snntorch/_neurons/synaptic.py +++ b/snntorch/_neurons/synaptic.py @@ -224,9 +224,9 @@ def forward(self, input_, syn=False, mem=False): # reset membrane potential _right_ after spike do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset if self.reset_mechanism_val == 0: # reset by subtraction - mem -= do_reset * self.threshold + mem = mem - do_reset * self.threshold elif self.reset_mechanism_val == 1: # reset to zero - mem -= do_reset * mem + mem = mem - do_reset * mem return spk, syn, mem From f5ef7f4adfc0cb29d8e2eecd7c1ac8dd82aa5658 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Tue, 14 Nov 2023 07:00:27 -0600 Subject: [PATCH 19/33] bias bug fix --- snntorch/import_nirtorch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/snntorch/import_nirtorch.py b/snntorch/import_nirtorch.py index 3787ad46..21ef0534 100644 --- a/snntorch/import_nirtorch.py +++ b/snntorch/import_nirtorch.py @@ -275,6 +275,8 @@ def _nir_to_snntorch_module( rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) if isinstance(wrec_node, nir.Affine): rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias) + else: + rsynaptic.recurrent.bias.data = torch.zeros_like(rsynaptic.recurrent.bias) return rsynaptic else: From 8c7c7861fd1917651522b11049f2696ef3a026b7 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Fri, 26 Jan 2024 20:15:12 -0500 Subject: [PATCH 20/33] move to using nirtorch --- snntorch/export_nir.py | 259 ++++++++---------------- snntorch/export_nir_old.py | 204 +++++++++++++++++++ snntorch/export_nirtorch.py | 113 ----------- snntorch/import_nir.py | 390 ++++++++++++++++++++++++------------ snntorch/import_nir_old.py | 177 ++++++++++++++++ snntorch/import_nirtorch.py | 292 --------------------------- 6 files changed, 723 insertions(+), 712 deletions(-) create mode 100644 snntorch/export_nir_old.py delete mode 100644 snntorch/export_nirtorch.py create mode 100644 snntorch/import_nir_old.py delete mode 100644 snntorch/import_nirtorch.py diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py index 71a7190a..455d0c4d 100644 --- a/snntorch/export_nir.py +++ b/snntorch/export_nir.py @@ -1,204 +1,113 @@ -from typing import Union, Optional - +from typing import Optional import torch import nir import numpy as np -from nirtorch import extract_nir_graph - -from snntorch import Leaky, Synaptic, RLeaky, RSynaptic +import nirtorch +import snntorch as snn -def _create_rnn_subgraph( - module: torch.nn.Module, lif: Union[nir.LIF, nir.CubaLIF], n_neurons=-1 -) -> nir.NIRGraph: - """Create NIR Graph for RNN, from the snnTorch module and the extracted LIF/CubaLIF node.""" - b = None - if module.all_to_all: - lif_shape = module.recurrent.weight.shape[0] - w_rec = module.recurrent.weight.data.detach().numpy() - if module.recurrent.bias is not None: - b = module.recurrent.bias.data.detach().numpy() - else: - if len(module.recurrent.V.shape) == 0 and n_neurons == -1: - lif_shape = None - w_rec = np.eye(1) * module.recurrent.V.data.detach().numpy() - elif n_neurons != -1: - lif_shape = n_neurons - w_rec = np.eye(n_neurons) * module.recurrent.V.data.detach().numpy() - else: - lif_shape = module.recurrent.V.shape[0] - w_rec = np.diag(module.recurrent.V.data.detach().numpy()) - - return nir.NIRGraph( - nodes={ - 'input': nir.Input(input_type=[lif_shape]), - 'lif': lif, - 'w_rec': nir.Linear(weight=w_rec) if b is None else nir.Affine(weight=w_rec, bias=b), - 
'output': nir.Output(output_type=[lif_shape]) - }, - edges=[('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output')] - ) - +def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: + if isinstance(module, snn.Leaky): + raise NotImplementedError('Leaky not supported') -def _get_neuron_count(module: torch.nn.Module) -> int: - if isinstance(module, RLeaky) or isinstance(module, RSynaptic): - if module.all_to_all: - return module.linear_features - elif isinstance(module.recurrent.V, torch.Tensor) and len(module.recurrent.V.shape) > 0: - return module.recurrent.V.shape[0] - elif module.init_hidden is True: - return module.mem.shape[0] + elif isinstance(module, torch.nn.Linear): + if module.bias is None: + return nir.Linear( + weight=module.weight.data.detach().numpy() + ) else: - # not implemented - return -1 - else: - # not implemented - return -1 - + return nir.Affine( + weight=module.weight.data.detach().numpy(), + bias=module.bias.data.detach().numpy() + ) -# eqn is assumed to be: v_t+1 = (1-1/tau)*v_t + 1/tau * v_leak + I_in / C -def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: - """ - NOTE: it might leave the NIR node of neurons with an incompatible shape. This must be fixed. - """ - if isinstance(module, Leaky): - # TODO - tau = 1 / (1 - module.beta.detach().numpy()) - r = module.beta.detach().numpy() - threshold = module.threshold.detach().numpy() - return nir.LIF( - tau=tau, - v_threshold=threshold, - v_leak=torch.zeros_like(tau), - r=r, - ) + elif isinstance(module, snn.Synaptic): + dt = 1e-4 - elif isinstance(module, RSynaptic): alpha = module.alpha.detach().numpy() beta = module.beta.detach().numpy() - threshold = module.threshold.detach().numpy() - n_neurons = _get_neuron_count(module) - if len(alpha.shape) == 0 and n_neurons != -1: - alpha = np.ones(n_neurons) * alpha - if len(beta.shape) == 0 and n_neurons != -1: - beta = np.ones(n_neurons) * beta - if len(threshold.shape) == 0 and n_neurons != -1: - threshold = np.ones(n_neurons) * threshold - lif = nir.CubaLIF( - tau_syn=1 / (1 - beta), - tau_mem=1 / (1 - alpha), - v_threshold=threshold, - v_leak=np.zeros_like(beta), - r=beta, - ) - return _create_rnn_subgraph(module, lif, n_neurons=n_neurons) + vthr = module.threshold.detach().numpy() - elif isinstance(module, RLeaky): - beta = module.beta.detach().numpy() - threshold = module.threshold.detach().numpy() - n_neurons = _get_neuron_count(module) - if len(beta.shape) == 0 and n_neurons != -1: - beta = np.ones(n_neurons) * beta - if len(threshold.shape) == 0 and n_neurons != -1: - threshold = np.ones(n_neurons) * threshold - lif = nir.LIF( - tau=1 / (1 - beta), - v_threshold=threshold, - v_leak=np.zeros_like(beta), - r=beta, - ) - return _create_rnn_subgraph(module, lif, n_neurons=n_neurons) + # TODO: make sure alpha, beta, vthr are tensors of same size + alpha = np.ones(7) * alpha + beta = np.ones(7) * beta + vthr = np.ones(7) * vthr + + tau_syn = dt / (1 - alpha) + tau_mem = dt / (1 - beta) + r = tau_mem / dt + v_leak = np.zeros_like(beta) + w_in = tau_syn / dt - elif isinstance(module, Synaptic): - # TODO - n_neurons = module.alpha.shape[0] return nir.CubaLIF( - tau_syn=1 / (1 - module.alpha).detach().numpy(), - tau_mem=1 / (1 - module.beta).detach().numpy(), - v_threshold=np.ones_like(module.alpha) * module.threshold.detach().numpy(), - v_leak=np.zeros_like(module.beta), - r=module.beta.detach().numpy(), # NOTE: is this right? 
+ tau_syn=tau_syn, + tau_mem=tau_mem, + v_threshold=vthr, + v_leak=v_leak, + r=r, + w_in=w_in, ) - elif isinstance(module, torch.nn.Linear): - if module.bias is None: # Add zero bias if none is present - return nir.Affine( - module.weight.detach().numpy(), np.zeros(*module.weight.shape[:-1]) - ) + elif isinstance(module, snn.RLeaky): + raise NotImplementedError('RLeaky not supported') + + elif isinstance(module, snn.RSynaptic): + if module.all_to_all: + w_rec = _extract_snntorch_module(module.recurrent) + n_neurons = w_rec.weight.shape[0] else: - return nir.Affine(module.weight.detach().numpy(), module.bias.detach().numpy()) + if len(module.recurrent.V.shape) == 0: + # TODO: handle this better - if V is a scalar, then the weight has wrong shape + raise ValueError('V must be a vector, cannot infer layer size for scalar V') + n_neurons = module.recurrent.V.shape[0] + w = np.diag(module.recurrent.V.data.detach().numpy()) + w_rec = nir.Linear(weight=w) + + dt = 1e-4 + + alpha = module.alpha.detach().numpy() + beta = module.beta.detach().numpy() + vthr = module.threshold.detach().numpy() + alpha = np.ones(n_neurons) * alpha + beta = np.ones(n_neurons) * beta + vthr = np.ones(n_neurons) * vthr + + tau_syn = dt / (1 - alpha) + tau_mem = dt / (1 - beta) + r = tau_mem / dt + v_leak = np.zeros_like(beta) + w_in = tau_syn / dt + + return nir.NIRGraph(nodes={ + 'input': nir.Input(input_type=[n_neurons]), + 'lif': nir.CubaLIF( + v_threshold=vthr, + tau_mem=tau_mem, + tau_syn=tau_syn, + r=r, + v_leak=v_leak, + w_in=w_in, + ), + 'w_rec': w_rec, + 'output': nir.Output(output_type=[n_neurons]) + }, edges=[ + ('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output') + ]) else: - print(f'[WARNING] unknown module type: {type(module).__name__} (ignored)') + print(f'[WARNING] module not implemented: {module.__class__.__name__}') return None def to_nir( - module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch" + module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch", + model_fwd_args=[], ignore_dims=[] ) -> nir.NIRNode: """Convert an snnTorch model to the Neuromorphic Intermediate Representation (NIR). 
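    # --- Editor's note (illustrative, not part of this patch) --------------------
    # `model_fwd_args` and `ignore_dims` are forwarded unchanged to
    # nirtorch.extract_nir_graph; using ignore_dims=[0] to drop a batch dimension
    # is an assumption about typical use, not documented behaviour.
    # Minimal usage sketch (Leaky export is disabled above, so Synaptic is used;
    # the layer width of 7 matches the hard-coded broadcast in the Synaptic branch,
    # see the TODO in _extract_snntorch_module):
    #
    #   import torch, torch.nn as nn
    #   import snntorch as snn
    #   from snntorch import export_nir
    #
    #   net = nn.Sequential(
    #       nn.Linear(12, 7),
    #       snn.Synaptic(alpha=0.5, beta=0.9, init_hidden=True, output=True),
    #   )
    #   nir_graph = export_nir.to_nir(net, torch.zeros(12))
    # ------------------------------------------------------------------------------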
- - Example:: - - import torch, torch.nn as nn - import snntorch as snn - from snntorch import export - - data_path = "untrained-snntorch.pt" - - net = nn.Sequential(nn.Linear(784, 128), - snn.Leaky(beta=0.8, init_hidden=True), - nn.Linear(128, 10), - snn.Leaky(beta=0.8, init_hidden=True, output=True)) - - # save model in pt format - torch.save(net.state_dict(), data_path) - - # load model (does nothing here, but shown for completeness) - net.load_state_dict(torch.load(data_path)) - - # generate input tensor to dynamically construct graph - x = torch.zeros(784) - - # generate NIR graph - nir_net = export.to_nir(net, x) - - - :param module: a saved snnTorch model as a parameter dictionary - :type module: torch.nn.Module - - :param sample_data: sample input data to the model - :type sample_data: torch.Tensor - - :param model_name: name of library used to train model, default: "snntorch" - :type model_name: str, optional - - :return: NIR computational graph where torch modules are represented as NIR nodes - :rtype: NIRGraph - """ - nir_graph = extract_nir_graph( + nir_graph = nirtorch.extract_nir_graph( module, _extract_snntorch_module, sample_data, model_name=model_name, - ignore_submodules_of=[RLeaky, RSynaptic] + ignore_submodules_of=[snn.RLeaky, snn.RSynaptic], + model_fwd_args=model_fwd_args, ignore_dims=ignore_dims ) - - # NOTE: hack to define subgraph I/O types - for node_key, node in nir_graph.nodes.items(): - inp_type = node.input_type.get('input', [None]) - input_undef = len(inp_type) == 0 or inp_type[0] is None - if isinstance(node, nir.Input) and input_undef and '.' in node_key: - print('WARNING: subgraph input type not set, inferring from previous node') - key = '.'.join(node_key.split('.')[:-1]) - prev_keys = [edge[0] for edge in nir_graph.edges if edge[1] == key] - assert len(prev_keys) == 1, 'multiple previous nodes not supported' - prev_node = nir_graph.nodes[prev_keys[0]] - cur_type = prev_node.output_type['output'] - node.input_type['input'] = cur_type - nir_graph.nodes[f'{key}.output'].output_type['output'] = cur_type - - # NOTE: hack to remove recurrent connections of subgraph to itself - for edge in nir_graph.edges: - if edge[0] not in nir_graph.nodes and edge[1] not in nir_graph.nodes: - nir_graph.edges.remove(edge) - return nir_graph diff --git a/snntorch/export_nir_old.py b/snntorch/export_nir_old.py new file mode 100644 index 00000000..71a7190a --- /dev/null +++ b/snntorch/export_nir_old.py @@ -0,0 +1,204 @@ +from typing import Union, Optional + +import torch +import nir +import numpy as np +from nirtorch import extract_nir_graph + +from snntorch import Leaky, Synaptic, RLeaky, RSynaptic + + +def _create_rnn_subgraph( + module: torch.nn.Module, lif: Union[nir.LIF, nir.CubaLIF], n_neurons=-1 +) -> nir.NIRGraph: + """Create NIR Graph for RNN, from the snnTorch module and the extracted LIF/CubaLIF node.""" + b = None + if module.all_to_all: + lif_shape = module.recurrent.weight.shape[0] + w_rec = module.recurrent.weight.data.detach().numpy() + if module.recurrent.bias is not None: + b = module.recurrent.bias.data.detach().numpy() + else: + if len(module.recurrent.V.shape) == 0 and n_neurons == -1: + lif_shape = None + w_rec = np.eye(1) * module.recurrent.V.data.detach().numpy() + elif n_neurons != -1: + lif_shape = n_neurons + w_rec = np.eye(n_neurons) * module.recurrent.V.data.detach().numpy() + else: + lif_shape = module.recurrent.V.shape[0] + w_rec = np.diag(module.recurrent.V.data.detach().numpy()) + + return nir.NIRGraph( + nodes={ + 'input': 
nir.Input(input_type=[lif_shape]), + 'lif': lif, + 'w_rec': nir.Linear(weight=w_rec) if b is None else nir.Affine(weight=w_rec, bias=b), + 'output': nir.Output(output_type=[lif_shape]) + }, + edges=[('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output')] + ) + + +def _get_neuron_count(module: torch.nn.Module) -> int: + if isinstance(module, RLeaky) or isinstance(module, RSynaptic): + if module.all_to_all: + return module.linear_features + elif isinstance(module.recurrent.V, torch.Tensor) and len(module.recurrent.V.shape) > 0: + return module.recurrent.V.shape[0] + elif module.init_hidden is True: + return module.mem.shape[0] + else: + # not implemented + return -1 + else: + # not implemented + return -1 + + +# eqn is assumed to be: v_t+1 = (1-1/tau)*v_t + 1/tau * v_leak + I_in / C +def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: + """ + NOTE: it might leave the NIR node of neurons with an incompatible shape. This must be fixed. + """ + if isinstance(module, Leaky): + # TODO + tau = 1 / (1 - module.beta.detach().numpy()) + r = module.beta.detach().numpy() + threshold = module.threshold.detach().numpy() + return nir.LIF( + tau=tau, + v_threshold=threshold, + v_leak=torch.zeros_like(tau), + r=r, + ) + + elif isinstance(module, RSynaptic): + alpha = module.alpha.detach().numpy() + beta = module.beta.detach().numpy() + threshold = module.threshold.detach().numpy() + n_neurons = _get_neuron_count(module) + if len(alpha.shape) == 0 and n_neurons != -1: + alpha = np.ones(n_neurons) * alpha + if len(beta.shape) == 0 and n_neurons != -1: + beta = np.ones(n_neurons) * beta + if len(threshold.shape) == 0 and n_neurons != -1: + threshold = np.ones(n_neurons) * threshold + lif = nir.CubaLIF( + tau_syn=1 / (1 - beta), + tau_mem=1 / (1 - alpha), + v_threshold=threshold, + v_leak=np.zeros_like(beta), + r=beta, + ) + return _create_rnn_subgraph(module, lif, n_neurons=n_neurons) + + elif isinstance(module, RLeaky): + beta = module.beta.detach().numpy() + threshold = module.threshold.detach().numpy() + n_neurons = _get_neuron_count(module) + if len(beta.shape) == 0 and n_neurons != -1: + beta = np.ones(n_neurons) * beta + if len(threshold.shape) == 0 and n_neurons != -1: + threshold = np.ones(n_neurons) * threshold + lif = nir.LIF( + tau=1 / (1 - beta), + v_threshold=threshold, + v_leak=np.zeros_like(beta), + r=beta, + ) + return _create_rnn_subgraph(module, lif, n_neurons=n_neurons) + + elif isinstance(module, Synaptic): + # TODO + n_neurons = module.alpha.shape[0] + return nir.CubaLIF( + tau_syn=1 / (1 - module.alpha).detach().numpy(), + tau_mem=1 / (1 - module.beta).detach().numpy(), + v_threshold=np.ones_like(module.alpha) * module.threshold.detach().numpy(), + v_leak=np.zeros_like(module.beta), + r=module.beta.detach().numpy(), # NOTE: is this right? + ) + + elif isinstance(module, torch.nn.Linear): + if module.bias is None: # Add zero bias if none is present + return nir.Affine( + module.weight.detach().numpy(), np.zeros(*module.weight.shape[:-1]) + ) + else: + return nir.Affine(module.weight.detach().numpy(), module.bias.detach().numpy()) + + else: + print(f'[WARNING] unknown module type: {type(module).__name__} (ignored)') + return None + + +def to_nir( + module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch" +) -> nir.NIRNode: + """Convert an snnTorch model to the Neuromorphic Intermediate Representation (NIR). 
+ + Example:: + + import torch, torch.nn as nn + import snntorch as snn + from snntorch import export + + data_path = "untrained-snntorch.pt" + + net = nn.Sequential(nn.Linear(784, 128), + snn.Leaky(beta=0.8, init_hidden=True), + nn.Linear(128, 10), + snn.Leaky(beta=0.8, init_hidden=True, output=True)) + + # save model in pt format + torch.save(net.state_dict(), data_path) + + # load model (does nothing here, but shown for completeness) + net.load_state_dict(torch.load(data_path)) + + # generate input tensor to dynamically construct graph + x = torch.zeros(784) + + # generate NIR graph + nir_net = export.to_nir(net, x) + + + :param module: a saved snnTorch model as a parameter dictionary + :type module: torch.nn.Module + + :param sample_data: sample input data to the model + :type sample_data: torch.Tensor + + :param model_name: name of library used to train model, default: "snntorch" + :type model_name: str, optional + + :return: NIR computational graph where torch modules are represented as NIR nodes + :rtype: NIRGraph + + """ + nir_graph = extract_nir_graph( + module, _extract_snntorch_module, sample_data, model_name=model_name, + ignore_submodules_of=[RLeaky, RSynaptic] + ) + + # NOTE: hack to define subgraph I/O types + for node_key, node in nir_graph.nodes.items(): + inp_type = node.input_type.get('input', [None]) + input_undef = len(inp_type) == 0 or inp_type[0] is None + if isinstance(node, nir.Input) and input_undef and '.' in node_key: + print('WARNING: subgraph input type not set, inferring from previous node') + key = '.'.join(node_key.split('.')[:-1]) + prev_keys = [edge[0] for edge in nir_graph.edges if edge[1] == key] + assert len(prev_keys) == 1, 'multiple previous nodes not supported' + prev_node = nir_graph.nodes[prev_keys[0]] + cur_type = prev_node.output_type['output'] + node.input_type['input'] = cur_type + nir_graph.nodes[f'{key}.output'].output_type['output'] = cur_type + + # NOTE: hack to remove recurrent connections of subgraph to itself + for edge in nir_graph.edges: + if edge[0] not in nir_graph.nodes and edge[1] not in nir_graph.nodes: + nir_graph.edges.remove(edge) + + return nir_graph diff --git a/snntorch/export_nirtorch.py b/snntorch/export_nirtorch.py deleted file mode 100644 index 455d0c4d..00000000 --- a/snntorch/export_nirtorch.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import Optional -import torch -import nir -import numpy as np -import nirtorch -import snntorch as snn - - -def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: - if isinstance(module, snn.Leaky): - raise NotImplementedError('Leaky not supported') - - elif isinstance(module, torch.nn.Linear): - if module.bias is None: - return nir.Linear( - weight=module.weight.data.detach().numpy() - ) - else: - return nir.Affine( - weight=module.weight.data.detach().numpy(), - bias=module.bias.data.detach().numpy() - ) - - elif isinstance(module, snn.Synaptic): - dt = 1e-4 - - alpha = module.alpha.detach().numpy() - beta = module.beta.detach().numpy() - vthr = module.threshold.detach().numpy() - - # TODO: make sure alpha, beta, vthr are tensors of same size - alpha = np.ones(7) * alpha - beta = np.ones(7) * beta - vthr = np.ones(7) * vthr - - tau_syn = dt / (1 - alpha) - tau_mem = dt / (1 - beta) - r = tau_mem / dt - v_leak = np.zeros_like(beta) - w_in = tau_syn / dt - - return nir.CubaLIF( - tau_syn=tau_syn, - tau_mem=tau_mem, - v_threshold=vthr, - v_leak=v_leak, - r=r, - w_in=w_in, - ) - - elif isinstance(module, snn.RLeaky): - raise NotImplementedError('RLeaky not 
supported') - - elif isinstance(module, snn.RSynaptic): - if module.all_to_all: - w_rec = _extract_snntorch_module(module.recurrent) - n_neurons = w_rec.weight.shape[0] - else: - if len(module.recurrent.V.shape) == 0: - # TODO: handle this better - if V is a scalar, then the weight has wrong shape - raise ValueError('V must be a vector, cannot infer layer size for scalar V') - n_neurons = module.recurrent.V.shape[0] - w = np.diag(module.recurrent.V.data.detach().numpy()) - w_rec = nir.Linear(weight=w) - - dt = 1e-4 - - alpha = module.alpha.detach().numpy() - beta = module.beta.detach().numpy() - vthr = module.threshold.detach().numpy() - alpha = np.ones(n_neurons) * alpha - beta = np.ones(n_neurons) * beta - vthr = np.ones(n_neurons) * vthr - - tau_syn = dt / (1 - alpha) - tau_mem = dt / (1 - beta) - r = tau_mem / dt - v_leak = np.zeros_like(beta) - w_in = tau_syn / dt - - return nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=[n_neurons]), - 'lif': nir.CubaLIF( - v_threshold=vthr, - tau_mem=tau_mem, - tau_syn=tau_syn, - r=r, - v_leak=v_leak, - w_in=w_in, - ), - 'w_rec': w_rec, - 'output': nir.Output(output_type=[n_neurons]) - }, edges=[ - ('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output') - ]) - - else: - print(f'[WARNING] module not implemented: {module.__class__.__name__}') - return None - - -def to_nir( - module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch", - model_fwd_args=[], ignore_dims=[] -) -> nir.NIRNode: - """Convert an snnTorch model to the Neuromorphic Intermediate Representation (NIR). - """ - nir_graph = nirtorch.extract_nir_graph( - module, _extract_snntorch_module, sample_data, model_name=model_name, - ignore_submodules_of=[snn.RLeaky, snn.RSynaptic], - model_fwd_args=model_fwd_args, ignore_dims=ignore_dims - ) - return nir_graph diff --git a/snntorch/import_nir.py b/snntorch/import_nir.py index a18a3a42..21ef0534 100644 --- a/snntorch/import_nir.py +++ b/snntorch/import_nir.py @@ -1,166 +1,292 @@ -import snntorch as snn import numpy as np -import torch import nir -import typing +import nirtorch +import torch +import snntorch as snn + + +def _create_rnn_subgraph(graph: nir.NIRGraph, lif_nk: str, w_nk: str) -> nir.NIRGraph: + """Take a NIRGraph plus the node keys for a LIF and a W_rec, and return a new NIRGraph + which has the RNN subgraph replaced with a subgraph (i.e., a single NIRGraph node). + """ + # NOTE: assuming that the LIF and W_rec have keys of form xyz.abc + sg_key = lif_nk.split('.')[0] # TODO: make this more general? + + # create subgraph for RNN + sg_edges = [ + (lif_nk, w_nk), (w_nk, lif_nk), (lif_nk, f'{sg_key}.output'), (f'{sg_key}.input', w_nk) + ] + sg_nodes = { + lif_nk: graph.nodes[lif_nk], + w_nk: graph.nodes[w_nk], + f'{sg_key}.input': nir.Input(graph.nodes[lif_nk].input_type), + f'{sg_key}.output': nir.Output(graph.nodes[lif_nk].output_type), + } + sg = nir.NIRGraph(nodes=sg_nodes, edges=sg_edges) + # remove subgraph edges from graph + graph.edges = [e for e in graph.edges if e not in [(lif_nk, w_nk), (w_nk, lif_nk)]] + # remove subgraph nodes from graph + graph.nodes = {k: v for k, v in graph.nodes.items() if k not in [lif_nk, w_nk]} -# TODO: implement this? -class ImportedNetwork(torch.nn.Module): - """Wrapper for a snnTorch network. 
NOTE: not working atm.""" - def __init__(self, module_list): - super().__init__() - self.module_list = module_list + # change edges of type (x, lif_nk) to (x, sg_key) + graph.edges = [(e[0], sg_key) if e[1] == lif_nk else e for e in graph.edges] + # change edges of type (lif_nk, x) to (sg_key, x) + graph.edges = [(sg_key, e[1]) if e[0] == lif_nk else e for e in graph.edges] - def forward(self, x): - for module in self.module_list: - x = module(x) - return x + # insert subgraph into graph and return + graph.nodes[sg_key] = sg + return graph -def create_snntorch_network(module_list): - return torch.nn.Sequential(*module_list) +def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: + """Take a NIRGraph and replace any RNN subgraphs with a single NIRGraph node.""" + print('replace rnn subgraph with nirgraph') + + if len(set(graph.edges)) != len(graph.edges): + print('[WARNING] duplicate edges found, removing') + graph.edges = list(set(graph.edges)) + + # find cycle of LIF <> Dense nodes + for edge1 in graph.edges: + for edge2 in graph.edges: + if not edge1 == edge2: + if edge1[0] == edge2[1] and edge1[1] == edge2[0]: + lif_nk = edge1[0] + lif_n = graph.nodes[lif_nk] + w_nk = edge1[1] + w_n = graph.nodes[w_nk] + is_lif = isinstance(lif_n, (nir.LIF, nir.CubaLIF)) + is_dense = isinstance(w_n, (nir.Affine, nir.Linear)) + # check if the dense only connects to the LIF + w_out_nk = [e[1] for e in graph.edges if e[0] == w_nk] + w_in_nk = [e[0] for e in graph.edges if e[1] == w_nk] + is_rnn = len(w_out_nk) == 1 and len(w_in_nk) == 1 + # check if we found an RNN - if so, then parse it + if is_rnn and is_lif and is_dense: + graph = _create_rnn_subgraph(graph, edge1[0], edge1[1]) + return graph + + +def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): + """Try parsing the graph as a RNN subgraph. + + Assumes four nodes: Input, Output, LIF | CubaLIF, Affine | Linear + Checks that all nodes have consistent shapes. + Will throw an error if either not all nodes are found or consistent shapes are found. + + Returns: + lif_node: LIF | CubaLIF node + wrec_node: Affine | Linear node + lif_size: int, number of neurons in the RNN + """ + sub_nodes = graph.nodes.values() + assert len(sub_nodes) == 4, 'only 4-node RNN allowed in subgraph' + try: + input_node = [n for n in sub_nodes if isinstance(n, nir.Input)][0] + output_node = [n for n in sub_nodes if isinstance(n, nir.Output)][0] + lif_node = [n for n in sub_nodes if isinstance(n, (nir.LIF, nir.CubaLIF))][0] + wrec_node = [n for n in sub_nodes if isinstance(n, (nir.Affine, nir.Linear))][0] + except IndexError: + raise ValueError('invalid RNN subgraph - could not find all required nodes') + lif_size = list(input_node.input_type.values())[0][0] + assert lif_size == list(output_node.output_type.values())[0][0], 'output size mismatch' + assert lif_size == lif_node.v_threshold.size, 'lif size mismatch (v_threshold)' + assert lif_size == wrec_node.weight.shape[0], 'w_rec shape mismatch' + assert lif_size == wrec_node.weight.shape[1], 'w_rec shape mismatch' + return lif_node, wrec_node, lif_size -def _lif_to_snntorch_module( - lif: typing.Union[nir.LIF, nir.CubaLIF] + +def _nir_to_snntorch_module( + node: nir.NIRNode, hack_w_scale=True, init_hidden=False ) -> torch.nn.Module: - """Parse a LIF node into snnTorch.""" - if isinstance(lif, nir.LIF): - assert np.alltrue(lif.v_leak == 0), 'v_leak not supported' - assert np.alltrue(lif.r == 1. - 1. 
/ lif.tau), 'r not supported' - assert np.unique(lif.v_threshold).size == 1, 'v_threshold must be same for all neurons' - threshold = lif.v_threshold[0] - mod = snn.RLeaky( - beta=1. - 1. / lif.tau, - threshold=threshold, - all_to_all=True, - reset_mechanism='zero', - linear_features=lif.tau.shape[0] if len(lif.tau.shape) == 1 else None, - init_hidden=True, + if isinstance(node, nir.Input) or isinstance(node, nir.Output): + return None + + elif isinstance(node, nir.Affine): + assert node.bias is not None, 'bias must be specified for Affine layer' + + mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0]) + mod.weight.data = torch.Tensor(node.weight) + mod.bias.data = torch.Tensor(node.bias) + + return mod + + elif isinstance(node, nir.Linear): + mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=False) + mod.weight.data = torch.Tensor(node.weight) + + return mod + + elif isinstance(node, nir.Conv2d): + mod = torch.nn.Conv2d( + node.weight.shape[1], + node.weight.shape[0], + kernel_size=[*node.weight.shape[-2:]], + stride=node.stride, + padding=node.padding, + dilation=node.dilation, + groups=node.groups, ) + mod.bias.data = torch.Tensor(node.bias) + mod.weight.data = torch.Tensor(node.weight) return mod - elif isinstance(lif, nir.CubaLIF): - assert np.alltrue(lif.v_leak == 0), 'v_leak not supported' - assert np.alltrue(lif.r == 1. - 1. / lif.tau_mem), 'r not supported' # NOTE: is this right? - assert np.unique(lif.v_threshold).size == 1, 'v_threshold must be same for all neurons' - threshold = lif.v_threshold[0] - mod = snn.RSynaptic( - alpha=1. - 1. / lif.tau_syn, - beta=1. - 1. / lif.tau_mem, - threshold=threshold, - all_to_all=True, - reset_mechanism='zero', - linear_features=lif.tau_mem.shape[0] if len(lif.tau_mem.shape) == 1 else None, - init_hidden=True, + if isinstance(node, nir.Flatten): + return torch.nn.Flatten(node.start_dim, node.end_dim) + + if isinstance(node, nir.SumPool2d): + return torch.nn.AvgPool2d( + kernel_size=tuple(node.kernel_size), + stride=tuple(node.stride), + padding=tuple(node.padding), + divisor_override=1, # turn AvgPool into SumPool + ) + + elif isinstance(node, nir.IF): + assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' + assert np.unique(node.r).size == 1, 'r must be same for all neurons' + vthr = np.unique(node.v_threshold)[0] + r = np.unique(node.r)[0] + assert r == 1, 'r != 1 not supported' + mod = snn.Leaky( + beta=0.9, + threshold=vthr * r, + init_hidden=False, + reset_delay=False, ) return mod - else: - raise ValueError('called _lif_to_snntorch_module on non-LIF node') + elif isinstance(node, nir.LIF): + dt = 1e-4 + assert np.allclose(node.v_leak, 0.), 'v_leak not supported' + assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' -def _to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: - """Convert a NIR node to a snnTorch module. + beta = 1 - (dt / node.tau) + vthr = node.v_threshold + w_scale = node.r * dt / node.tau - Supported NIR nodes: Affine. 
- """ - if isinstance(node, (nir.LIF, nir.CubaLIF)): - return _lif_to_snntorch_module(node) + if not np.allclose(w_scale, 1.): + if hack_w_scale: + vthr = vthr / np.unique(w_scale)[0] + print('[warning] scaling weights to avoid scaling inputs') + print(f'w_scale: {w_scale}, r: {node.r}, dt: {dt}, tau: {node.tau}') + else: + raise NotImplementedError('w_scale must be 1, or the same for all neurons') - elif isinstance(node, nir.Affine): - if len(node.weight.shape) != 2: - raise NotImplementedError('only 2D weight matrices are supported') - has_bias = node.bias is not None and not np.alltrue(node.bias == 0) - linear = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=has_bias) - linear.weight.data = torch.Tensor(node.weight) - if has_bias: - linear.bias.data = torch.Tensor(node.bias) - return linear + assert np.unique(vthr).size == 1, 'LIF v_thr must be same for all neurons' - else: - raise NotImplementedError(f'node type {type(node).__name__} not supported') + return snn.Leaky( + beta=beta, + threshold=np.unique(vthr)[0], + reset_mechanism='zero', + init_hidden=init_hidden, + reset_delay=False, + ) + elif isinstance(node, nir.CubaLIF): + dt = 1e-4 -def _rnn_subgraph_to_snntorch_module( - lif: typing.Union[nir.LIF, nir.CubaLIF], w_rec: typing.Union[nir.Affine, nir.Linear] -) -> torch.nn.Module: - """Parse an RNN subgraph consisting of a LIF node and a recurrent weight matrix into snnTorch. + assert np.allclose(node.v_leak, 0), 'v_leak not supported' + assert np.allclose(node.r, node.tau_mem / dt), 'r not supported in CubaLIF' - NOTE: for now always set it as a recurrent linear layer (not RecurrentOneToOne) - """ - assert isinstance(lif, (nir.LIF, nir.CubaLIF)), 'only LIF or CubaLIF nodes supported as RNNs' - mod = _lif_to_snntorch_module(lif) - mod.recurrent.weight.data = torch.Tensor(w_rec.weight) - if isinstance(w_rec, nir.Linear): - mod.recurrent.register_parameter('bias', None) - mod.recurrent.reset_parameters() - else: - mod.recurrent.bias.data = torch.Tensor(w_rec.bias) - return mod + alpha = 1 - (dt / node.tau_syn) + beta = 1 - (dt / node.tau_mem) + vthr = node.v_threshold + w_scale = node.w_in * (dt / node.tau_syn) + if not np.allclose(w_scale, 1.): + if hack_w_scale: + vthr = vthr / w_scale + print('[warning] scaling weights to avoid scaling inputs') + print(f'w_scale: {w_scale}, w_in: {node.w_in}, dt: {dt}, tau_syn: {node.tau_syn}') + else: + raise NotImplementedError('w_scale must be 1, or the same for all neurons') -def _get_next_node_key(node_key: str, graph: nir.ir.NIRGraph): - """Get the next node key in the NIR graph.""" - possible_next_node_keys = [edge[1] for edge in graph.edges if edge[0] == node_key] - # possible_next_node_keys += [edge[1] + '.input' for edge in graph.edges if edge[0] == node_key] - assert len(possible_next_node_keys) <= 1, 'branching networks are not supported' - if len(possible_next_node_keys) == 0: - return None - else: - return possible_next_node_keys[0] + assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' + + if np.unique(alpha).size == 1: + alpha = float(np.unique(alpha)[0]) + if np.unique(beta).size == 1: + beta = float(np.unique(beta)[0]) + return snn.Synaptic( + alpha=alpha, + beta=beta, + threshold=float(np.unique(vthr)[0]), + reset_mechanism='zero', + init_hidden=init_hidden, + reset_delay=False, + ) -def from_nir(graph: nir.ir.NIRGraph) -> torch.nn.Module: - """Convert NIR graph to snnTorch module. 
+ elif isinstance(node, nir.NIRGraph): + lif_node, wrec_node, lif_size = _parse_rnn_subgraph(node) - :param graph: a saved snnTorch model as a parameter dictionary - :type graph: nir.ir.NIRGraph + if isinstance(lif_node, nir.LIF): + raise NotImplementedError('LIF in subgraph not supported') - :return: snnTorch module - :rtype: torch.nn.Module - """ - node_key = 'input' - visited_node_keys = [node_key] - module_list = [] - - while _get_next_node_key(node_key, graph) is not None: - node_key = _get_next_node_key(node_key, graph) - - assert node_key not in visited_node_keys, 'cyclic NIR graphs not supported' - - if node_key == 'output': - visited_node_keys.append(node_key) - continue - - if node_key in graph.nodes: - visited_node_keys.append(node_key) - node = graph.nodes[node_key] - print(f'simple node {node_key}: {type(node).__name__}') - module = _to_snntorch_module(node) - else: - # check if it's a nested node - print(f'potential subgraph node: {node_key}') - sub_node_keys = [n for n in graph.nodes if n.startswith(f'{node_key}.')] - assert len(sub_node_keys) > 0, f'no nodes found for subgraph {node_key}' - - # parse subgraph - # NOTE: for now only looking for RNN subgraphs - rnn_sub_node_keys = [f'{node_key}.{n}' for n in ['input', 'output', 'lif', 'w_rec']] - if set(sub_node_keys) != set(rnn_sub_node_keys): - raise NotImplementedError('only RNN subgraphs are supported') - print('found RNN subgraph') - module = _rnn_subgraph_to_snntorch_module( - graph.nodes[f'{node_key}.lif'], graph.nodes[f'{node_key}.w_rec'] + elif isinstance(lif_node, nir.CubaLIF): + dt = 1e-4 + + assert np.allclose(lif_node.v_leak, 0), 'v_leak not supported' + assert np.allclose(lif_node.r, lif_node.tau_mem / dt), 'r not supported in CubaLIF' + + alpha = 1 - (dt / lif_node.tau_syn) + beta = 1 - (dt / lif_node.tau_mem) + vthr = lif_node.v_threshold + w_scale = lif_node.w_in * (dt / lif_node.tau_syn) + + if not np.allclose(w_scale, 1.): + if hack_w_scale: + vthr = vthr / w_scale + print(f'[warning] scaling weights to avoid scaling inputs. 
w_scale: {w_scale}') + print(f'w_in: {lif_node.w_in}, dt: {dt}, tau_syn: {lif_node.tau_syn}') + else: + raise NotImplementedError('w_scale must be 1, or the same for all neurons') + + assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' + + diagonal = np.array_equal(wrec_node.weight, np.diag(np.diag(wrec_node.weight))) + + if np.unique(alpha).size == 1: + alpha = float(np.unique(alpha)[0]) + if np.unique(beta).size == 1: + beta = float(np.unique(beta)[0]) + + if diagonal: + V = torch.from_numpy(np.diag(wrec_node.weight)).to(dtype=torch.float32) + else: + V = None + + rsynaptic = snn.RSynaptic( + alpha=alpha, + beta=beta, + threshold=float(np.unique(vthr)[0]), + reset_mechanism='zero', + init_hidden=init_hidden, + all_to_all=not diagonal, + linear_features=lif_size, + V=V, + reset_delay=False, ) - for nk in sub_node_keys: - visited_node_keys.append(nk) - module_list.append(module) + rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) + if isinstance(wrec_node, nir.Affine): + rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias) + else: + rsynaptic.recurrent.bias.data = torch.zeros_like(rsynaptic.recurrent.bias) + return rsynaptic + + else: + print('[WARNING] could not parse node of type:', node.__class__.__name__) + + return None - if len(visited_node_keys) != len(graph.nodes): - print(graph.nodes.keys(), visited_node_keys) - raise ValueError('not all nodes visited') - return create_snntorch_network(module_list) +def from_nir(graph: nir.NIRGraph) -> torch.nn.Module: + # find valid RNN subgraphs, and replace them with a single NIRGraph node + graph = _replace_rnn_subgraph_with_nirgraph(graph) + # TODO: right now, the subgraph edges seem to not be parsed correctly - fix this + return nirtorch.load(graph, _nir_to_snntorch_module) diff --git a/snntorch/import_nir_old.py b/snntorch/import_nir_old.py new file mode 100644 index 00000000..f77b9df8 --- /dev/null +++ b/snntorch/import_nir_old.py @@ -0,0 +1,177 @@ +import snntorch as snn +import numpy as np +import torch +import nir +import typing + + +# TODO: implement this? +class ImportedNetwork(torch.nn.Module): + """Wrapper for a snnTorch network. NOTE: not working atm.""" + def __init__(self, module_list): + super().__init__() + self.module_list = module_list + + def forward(self, x): + for module in self.module_list: + x = module(x) + return x + + +def create_snntorch_network(module_list): + return torch.nn.Sequential(*module_list) + + +def _lif_to_snntorch_module( + lif: typing.Union[nir.LIF, nir.CubaLIF] +) -> torch.nn.Module: + """Parse a LIF node into snnTorch.""" + if isinstance(lif, nir.LIF): + assert np.alltrue(lif.v_leak == 0), 'v_leak not supported' + assert np.alltrue(lif.r == 1. - 1. / lif.tau), 'r not supported' + assert np.unique(lif.v_threshold).size == 1, 'v_threshold must be same for all neurons' + threshold = lif.v_threshold[0] + mod = snn.RLeaky( + beta=1. - 1. / lif.tau, + threshold=threshold, + all_to_all=True, + reset_mechanism='zero', + linear_features=lif.tau.shape[0] if len(lif.tau.shape) == 1 else None, + init_hidden=True, + ) + return mod + + elif isinstance(lif, nir.CubaLIF): + assert np.alltrue(lif.v_leak == 0), 'v_leak not supported' + assert np.alltrue(lif.r == 1. - 1. / lif.tau_mem), 'r not supported' # NOTE: is this right? + assert np.unique(lif.v_threshold).size == 1, 'v_threshold must be same for all neurons' + threshold = lif.v_threshold[0] + mod = snn.RSynaptic( + alpha=1. - 1. / lif.tau_syn, + beta=1. - 1. 
/ lif.tau_mem, + threshold=threshold, + all_to_all=True, + reset_mechanism='zero', + linear_features=lif.tau_mem.shape[0] if len(lif.tau_mem.shape) == 1 else None, + init_hidden=True, + ) + return mod + + elif isinstance(lif, nir.LI): + assert np.alltrue(lif.v_leak == 0), 'v_leak not supported' + assert np.allclose(lif.r , 1. - 1. / lif.tau), 'r not supported' + mod = snn.Leaky( + beta=1. - 1. / lif.tau, + reset_mechanism='none', + init_hidden=True, + output=True, + ) + return mod + + else: + raise ValueError('called _lif_to_snntorch_module on non-LIF node') + + +def _to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: + """Convert a NIR node to a snnTorch module. + + Supported NIR nodes: Affine. + """ + if isinstance(node, (nir.LIF, nir.CubaLIF)): + return _lif_to_snntorch_module(node) + + elif isinstance(node, nir.Affine): + if len(node.weight.shape) != 2: + raise NotImplementedError('only 2D weight matrices are supported') + has_bias = node.bias is not None and not np.alltrue(node.bias == 0) + linear = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=has_bias) + linear.weight.data = torch.Tensor(node.weight) + if has_bias: + linear.bias.data = torch.Tensor(node.bias) + return linear + + else: + raise NotImplementedError(f'node type {type(node).__name__} not supported') + + +def _rnn_subgraph_to_snntorch_module( + lif: typing.Union[nir.LIF, nir.CubaLIF], w_rec: typing.Union[nir.Affine, nir.Linear] +) -> torch.nn.Module: + """Parse an RNN subgraph consisting of a LIF node and a recurrent weight matrix into snnTorch. + + NOTE: for now always set it as a recurrent linear layer (not RecurrentOneToOne) + """ + assert isinstance(lif, (nir.LIF, nir.CubaLIF)), 'only LIF or CubaLIF nodes supported as RNNs' + mod = _lif_to_snntorch_module(lif) + mod.recurrent.weight.data = torch.Tensor(w_rec.weight) + if isinstance(w_rec, nir.Linear): + mod.recurrent.register_parameter('bias', None) + mod.recurrent.reset_parameters() + else: + mod.recurrent.bias.data = torch.Tensor(w_rec.bias) + return mod + + +def _get_next_node_key(node_key: str, graph: nir.ir.NIRGraph): + """Get the next node key in the NIR graph.""" + possible_next_node_keys = [edge[1] for edge in graph.edges if edge[0] == node_key] + # possible_next_node_keys += [edge[1] + '.input' for edge in graph.edges if edge[0] == node_key] + assert len(possible_next_node_keys) <= 1, 'branching networks are not supported' + if len(possible_next_node_keys) == 0: + return None + else: + return possible_next_node_keys[0] + + +def from_nir(graph: nir.ir.NIRGraph) -> torch.nn.Module: + """Convert NIR graph to snnTorch module. 
+ + :param graph: a saved snnTorch model as a parameter dictionary + :type graph: nir.ir.NIRGraph + + :return: snnTorch module + :rtype: torch.nn.Module + """ + node_key = 'input' + visited_node_keys = [node_key] + module_list = [] + + while _get_next_node_key(node_key, graph) is not None: + node_key = _get_next_node_key(node_key, graph) + + assert node_key not in visited_node_keys, 'cyclic NIR graphs not supported' + + if node_key == 'output': + visited_node_keys.append(node_key) + continue + + if node_key in graph.nodes: + visited_node_keys.append(node_key) + node = graph.nodes[node_key] + print(f'simple node {node_key}: {type(node).__name__}') + module = _to_snntorch_module(node) + else: + # check if it's a nested node + print(f'potential subgraph node: {node_key}') + sub_node_keys = [n for n in graph.nodes if n.startswith(f'{node_key}.')] + assert len(sub_node_keys) > 0, f'no nodes found for subgraph {node_key}' + + # parse subgraph + # NOTE: for now only looking for RNN subgraphs + rnn_sub_node_keys = [f'{node_key}.{n}' for n in ['input', 'output', 'lif', 'w_rec']] + if set(sub_node_keys) != set(rnn_sub_node_keys): + raise NotImplementedError('only RNN subgraphs are supported') + print('found RNN subgraph') + module = _rnn_subgraph_to_snntorch_module( + graph.nodes[f'{node_key}.lif'], graph.nodes[f'{node_key}.w_rec'] + ) + for nk in sub_node_keys: + visited_node_keys.append(nk) + + module_list.append(module) + + if len(visited_node_keys) != len(graph.nodes): + print(graph.nodes.keys(), visited_node_keys) + raise ValueError('not all nodes visited') + + return create_snntorch_network(module_list) diff --git a/snntorch/import_nirtorch.py b/snntorch/import_nirtorch.py deleted file mode 100644 index 21ef0534..00000000 --- a/snntorch/import_nirtorch.py +++ /dev/null @@ -1,292 +0,0 @@ -import numpy as np -import nir -import nirtorch -import torch -import snntorch as snn - - -def _create_rnn_subgraph(graph: nir.NIRGraph, lif_nk: str, w_nk: str) -> nir.NIRGraph: - """Take a NIRGraph plus the node keys for a LIF and a W_rec, and return a new NIRGraph - which has the RNN subgraph replaced with a subgraph (i.e., a single NIRGraph node). - """ - # NOTE: assuming that the LIF and W_rec have keys of form xyz.abc - sg_key = lif_nk.split('.')[0] # TODO: make this more general? 
- - # create subgraph for RNN - sg_edges = [ - (lif_nk, w_nk), (w_nk, lif_nk), (lif_nk, f'{sg_key}.output'), (f'{sg_key}.input', w_nk) - ] - sg_nodes = { - lif_nk: graph.nodes[lif_nk], - w_nk: graph.nodes[w_nk], - f'{sg_key}.input': nir.Input(graph.nodes[lif_nk].input_type), - f'{sg_key}.output': nir.Output(graph.nodes[lif_nk].output_type), - } - sg = nir.NIRGraph(nodes=sg_nodes, edges=sg_edges) - - # remove subgraph edges from graph - graph.edges = [e for e in graph.edges if e not in [(lif_nk, w_nk), (w_nk, lif_nk)]] - # remove subgraph nodes from graph - graph.nodes = {k: v for k, v in graph.nodes.items() if k not in [lif_nk, w_nk]} - - # change edges of type (x, lif_nk) to (x, sg_key) - graph.edges = [(e[0], sg_key) if e[1] == lif_nk else e for e in graph.edges] - # change edges of type (lif_nk, x) to (sg_key, x) - graph.edges = [(sg_key, e[1]) if e[0] == lif_nk else e for e in graph.edges] - - # insert subgraph into graph and return - graph.nodes[sg_key] = sg - return graph - - -def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: - """Take a NIRGraph and replace any RNN subgraphs with a single NIRGraph node.""" - print('replace rnn subgraph with nirgraph') - - if len(set(graph.edges)) != len(graph.edges): - print('[WARNING] duplicate edges found, removing') - graph.edges = list(set(graph.edges)) - - # find cycle of LIF <> Dense nodes - for edge1 in graph.edges: - for edge2 in graph.edges: - if not edge1 == edge2: - if edge1[0] == edge2[1] and edge1[1] == edge2[0]: - lif_nk = edge1[0] - lif_n = graph.nodes[lif_nk] - w_nk = edge1[1] - w_n = graph.nodes[w_nk] - is_lif = isinstance(lif_n, (nir.LIF, nir.CubaLIF)) - is_dense = isinstance(w_n, (nir.Affine, nir.Linear)) - # check if the dense only connects to the LIF - w_out_nk = [e[1] for e in graph.edges if e[0] == w_nk] - w_in_nk = [e[0] for e in graph.edges if e[1] == w_nk] - is_rnn = len(w_out_nk) == 1 and len(w_in_nk) == 1 - # check if we found an RNN - if so, then parse it - if is_rnn and is_lif and is_dense: - graph = _create_rnn_subgraph(graph, edge1[0], edge1[1]) - return graph - - -def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): - """Try parsing the graph as a RNN subgraph. - - Assumes four nodes: Input, Output, LIF | CubaLIF, Affine | Linear - Checks that all nodes have consistent shapes. - Will throw an error if either not all nodes are found or consistent shapes are found. 
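Concretely, a subgraph in the shape this parser accepts can be built directly with the nir node constructors used elsewhere in this module; the sketch below uses placeholder values, and note that only node types and shapes are checked here, not the edge list::

    import numpy as np
    import nir

    n = 3          # number of recurrent neurons (illustrative)
    dt = 1e-4      # fixed timestep assumed by the importer

    rnn_subgraph = nir.NIRGraph(
        nodes={
            'input': nir.Input(input_type=np.array([n])),
            'lif': nir.CubaLIF(
                tau_syn=np.full(n, 5e-3),
                tau_mem=np.full(n, 1e-2),
                v_threshold=np.ones(n),
                v_leak=np.zeros(n),
                r=np.full(n, 1e-2) / dt,   # r == tau_mem / dt, as asserted on import
            ),
            'w_rec': nir.Linear(weight=np.eye(n)),
            'output': nir.Output(output_type=np.array([n])),
        },
        edges=[('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output')],
    )

    # _parse_rnn_subgraph(rnn_subgraph) would return the CubaLIF node,
    # the Linear node, and lif_size == 3.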
- - Returns: - lif_node: LIF | CubaLIF node - wrec_node: Affine | Linear node - lif_size: int, number of neurons in the RNN - """ - sub_nodes = graph.nodes.values() - assert len(sub_nodes) == 4, 'only 4-node RNN allowed in subgraph' - try: - input_node = [n for n in sub_nodes if isinstance(n, nir.Input)][0] - output_node = [n for n in sub_nodes if isinstance(n, nir.Output)][0] - lif_node = [n for n in sub_nodes if isinstance(n, (nir.LIF, nir.CubaLIF))][0] - wrec_node = [n for n in sub_nodes if isinstance(n, (nir.Affine, nir.Linear))][0] - except IndexError: - raise ValueError('invalid RNN subgraph - could not find all required nodes') - lif_size = list(input_node.input_type.values())[0][0] - assert lif_size == list(output_node.output_type.values())[0][0], 'output size mismatch' - assert lif_size == lif_node.v_threshold.size, 'lif size mismatch (v_threshold)' - assert lif_size == wrec_node.weight.shape[0], 'w_rec shape mismatch' - assert lif_size == wrec_node.weight.shape[1], 'w_rec shape mismatch' - - return lif_node, wrec_node, lif_size - - -def _nir_to_snntorch_module( - node: nir.NIRNode, hack_w_scale=True, init_hidden=False -) -> torch.nn.Module: - if isinstance(node, nir.Input) or isinstance(node, nir.Output): - return None - - elif isinstance(node, nir.Affine): - assert node.bias is not None, 'bias must be specified for Affine layer' - - mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0]) - mod.weight.data = torch.Tensor(node.weight) - mod.bias.data = torch.Tensor(node.bias) - - return mod - - elif isinstance(node, nir.Linear): - mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=False) - mod.weight.data = torch.Tensor(node.weight) - - return mod - - elif isinstance(node, nir.Conv2d): - mod = torch.nn.Conv2d( - node.weight.shape[1], - node.weight.shape[0], - kernel_size=[*node.weight.shape[-2:]], - stride=node.stride, - padding=node.padding, - dilation=node.dilation, - groups=node.groups, - ) - mod.bias.data = torch.Tensor(node.bias) - mod.weight.data = torch.Tensor(node.weight) - return mod - - if isinstance(node, nir.Flatten): - return torch.nn.Flatten(node.start_dim, node.end_dim) - - if isinstance(node, nir.SumPool2d): - return torch.nn.AvgPool2d( - kernel_size=tuple(node.kernel_size), - stride=tuple(node.stride), - padding=tuple(node.padding), - divisor_override=1, # turn AvgPool into SumPool - ) - - elif isinstance(node, nir.IF): - assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' - assert np.unique(node.r).size == 1, 'r must be same for all neurons' - vthr = np.unique(node.v_threshold)[0] - r = np.unique(node.r)[0] - assert r == 1, 'r != 1 not supported' - mod = snn.Leaky( - beta=0.9, - threshold=vthr * r, - init_hidden=False, - reset_delay=False, - ) - return mod - - elif isinstance(node, nir.LIF): - dt = 1e-4 - - assert np.allclose(node.v_leak, 0.), 'v_leak not supported' - assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' - - beta = 1 - (dt / node.tau) - vthr = node.v_threshold - w_scale = node.r * dt / node.tau - - if not np.allclose(w_scale, 1.): - if hack_w_scale: - vthr = vthr / np.unique(w_scale)[0] - print('[warning] scaling weights to avoid scaling inputs') - print(f'w_scale: {w_scale}, r: {node.r}, dt: {dt}, tau: {node.tau}') - else: - raise NotImplementedError('w_scale must be 1, or the same for all neurons') - - assert np.unique(vthr).size == 1, 'LIF v_thr must be same for all neurons' - - return snn.Leaky( - beta=beta, - threshold=np.unique(vthr)[0], - 
reset_mechanism='zero', - init_hidden=init_hidden, - reset_delay=False, - ) - - elif isinstance(node, nir.CubaLIF): - dt = 1e-4 - - assert np.allclose(node.v_leak, 0), 'v_leak not supported' - assert np.allclose(node.r, node.tau_mem / dt), 'r not supported in CubaLIF' - - alpha = 1 - (dt / node.tau_syn) - beta = 1 - (dt / node.tau_mem) - vthr = node.v_threshold - w_scale = node.w_in * (dt / node.tau_syn) - - if not np.allclose(w_scale, 1.): - if hack_w_scale: - vthr = vthr / w_scale - print('[warning] scaling weights to avoid scaling inputs') - print(f'w_scale: {w_scale}, w_in: {node.w_in}, dt: {dt}, tau_syn: {node.tau_syn}') - else: - raise NotImplementedError('w_scale must be 1, or the same for all neurons') - - assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' - - if np.unique(alpha).size == 1: - alpha = float(np.unique(alpha)[0]) - if np.unique(beta).size == 1: - beta = float(np.unique(beta)[0]) - - return snn.Synaptic( - alpha=alpha, - beta=beta, - threshold=float(np.unique(vthr)[0]), - reset_mechanism='zero', - init_hidden=init_hidden, - reset_delay=False, - ) - - elif isinstance(node, nir.NIRGraph): - lif_node, wrec_node, lif_size = _parse_rnn_subgraph(node) - - if isinstance(lif_node, nir.LIF): - raise NotImplementedError('LIF in subgraph not supported') - - elif isinstance(lif_node, nir.CubaLIF): - dt = 1e-4 - - assert np.allclose(lif_node.v_leak, 0), 'v_leak not supported' - assert np.allclose(lif_node.r, lif_node.tau_mem / dt), 'r not supported in CubaLIF' - - alpha = 1 - (dt / lif_node.tau_syn) - beta = 1 - (dt / lif_node.tau_mem) - vthr = lif_node.v_threshold - w_scale = lif_node.w_in * (dt / lif_node.tau_syn) - - if not np.allclose(w_scale, 1.): - if hack_w_scale: - vthr = vthr / w_scale - print(f'[warning] scaling weights to avoid scaling inputs. 
w_scale: {w_scale}') - print(f'w_in: {lif_node.w_in}, dt: {dt}, tau_syn: {lif_node.tau_syn}') - else: - raise NotImplementedError('w_scale must be 1, or the same for all neurons') - - assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' - - diagonal = np.array_equal(wrec_node.weight, np.diag(np.diag(wrec_node.weight))) - - if np.unique(alpha).size == 1: - alpha = float(np.unique(alpha)[0]) - if np.unique(beta).size == 1: - beta = float(np.unique(beta)[0]) - - if diagonal: - V = torch.from_numpy(np.diag(wrec_node.weight)).to(dtype=torch.float32) - else: - V = None - - rsynaptic = snn.RSynaptic( - alpha=alpha, - beta=beta, - threshold=float(np.unique(vthr)[0]), - reset_mechanism='zero', - init_hidden=init_hidden, - all_to_all=not diagonal, - linear_features=lif_size, - V=V, - reset_delay=False, - ) - - rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) - if isinstance(wrec_node, nir.Affine): - rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias) - else: - rsynaptic.recurrent.bias.data = torch.zeros_like(rsynaptic.recurrent.bias) - return rsynaptic - - else: - print('[WARNING] could not parse node of type:', node.__class__.__name__) - - return None - - -def from_nir(graph: nir.NIRGraph) -> torch.nn.Module: - # find valid RNN subgraphs, and replace them with a single NIRGraph node - graph = _replace_rnn_subgraph_with_nirgraph(graph) - # TODO: right now, the subgraph edges seem to not be parsed correctly - fix this - return nirtorch.load(graph, _nir_to_snntorch_module) From f962f3cfc8a7bce58e1746b984c4971da02bb63b Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Fri, 26 Jan 2024 20:17:56 -0500 Subject: [PATCH 21/33] minor changes --- snntorch/export_nir.py | 6 +----- snntorch/import_nir.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py index 455d0c4d..7af6a3d7 100644 --- a/snntorch/export_nir.py +++ b/snntorch/export_nir.py @@ -24,15 +24,11 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: elif isinstance(module, snn.Synaptic): dt = 1e-4 + # TODO: assert that size of the current layer is correct alpha = module.alpha.detach().numpy() beta = module.beta.detach().numpy() vthr = module.threshold.detach().numpy() - # TODO: make sure alpha, beta, vthr are tensors of same size - alpha = np.ones(7) * alpha - beta = np.ones(7) * beta - vthr = np.ones(7) * vthr - tau_syn = dt / (1 - alpha) tau_mem = dt / (1 - beta) r = tau_mem / dt diff --git a/snntorch/import_nir.py b/snntorch/import_nir.py index 21ef0534..2ebe7ef8 100644 --- a/snntorch/import_nir.py +++ b/snntorch/import_nir.py @@ -288,5 +288,4 @@ def _nir_to_snntorch_module( def from_nir(graph: nir.NIRGraph) -> torch.nn.Module: # find valid RNN subgraphs, and replace them with a single NIRGraph node graph = _replace_rnn_subgraph_with_nirgraph(graph) - # TODO: right now, the subgraph edges seem to not be parsed correctly - fix this return nirtorch.load(graph, _nir_to_snntorch_module) From 208c87f63620af297e1d45cec3f85ee374bcbbfc Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Fri, 26 Jan 2024 20:20:58 -0500 Subject: [PATCH 22/33] + Leaky export --- snntorch/export_nir.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py index 7af6a3d7..aa7b8956 100644 --- a/snntorch/export_nir.py +++ b/snntorch/export_nir.py @@ -8,7 +8,20 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: if isinstance(module, 
snn.Leaky): - raise NotImplementedError('Leaky not supported') + dt = 1e-4 + + beta = module.beta.detach().numpy() + vthr = module.threshold.detach().numpy() + tau_mem = dt / (1 - beta) + r = tau_mem / dt + v_leak = np.zeros_like(beta) + + return nir.LIF( + tau=tau_mem, + v_threshold=vthr, + v_leak=v_leak, + r=r, + ) elif isinstance(module, torch.nn.Linear): if module.bias is None: From d831179ba7901932240ff8e912d81ee5d6bbaf28 Mon Sep 17 00:00:00 2001 From: Steve Abreu Date: Fri, 26 Jan 2024 20:28:48 -0500 Subject: [PATCH 23/33] remove old files --- snntorch/export_nir_old.py | 204 ------------------------------------- snntorch/import_nir_old.py | 177 -------------------------------- 2 files changed, 381 deletions(-) delete mode 100644 snntorch/export_nir_old.py delete mode 100644 snntorch/import_nir_old.py diff --git a/snntorch/export_nir_old.py b/snntorch/export_nir_old.py deleted file mode 100644 index 71a7190a..00000000 --- a/snntorch/export_nir_old.py +++ /dev/null @@ -1,204 +0,0 @@ -from typing import Union, Optional - -import torch -import nir -import numpy as np -from nirtorch import extract_nir_graph - -from snntorch import Leaky, Synaptic, RLeaky, RSynaptic - - -def _create_rnn_subgraph( - module: torch.nn.Module, lif: Union[nir.LIF, nir.CubaLIF], n_neurons=-1 -) -> nir.NIRGraph: - """Create NIR Graph for RNN, from the snnTorch module and the extracted LIF/CubaLIF node.""" - b = None - if module.all_to_all: - lif_shape = module.recurrent.weight.shape[0] - w_rec = module.recurrent.weight.data.detach().numpy() - if module.recurrent.bias is not None: - b = module.recurrent.bias.data.detach().numpy() - else: - if len(module.recurrent.V.shape) == 0 and n_neurons == -1: - lif_shape = None - w_rec = np.eye(1) * module.recurrent.V.data.detach().numpy() - elif n_neurons != -1: - lif_shape = n_neurons - w_rec = np.eye(n_neurons) * module.recurrent.V.data.detach().numpy() - else: - lif_shape = module.recurrent.V.shape[0] - w_rec = np.diag(module.recurrent.V.data.detach().numpy()) - - return nir.NIRGraph( - nodes={ - 'input': nir.Input(input_type=[lif_shape]), - 'lif': lif, - 'w_rec': nir.Linear(weight=w_rec) if b is None else nir.Affine(weight=w_rec, bias=b), - 'output': nir.Output(output_type=[lif_shape]) - }, - edges=[('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output')] - ) - - -def _get_neuron_count(module: torch.nn.Module) -> int: - if isinstance(module, RLeaky) or isinstance(module, RSynaptic): - if module.all_to_all: - return module.linear_features - elif isinstance(module.recurrent.V, torch.Tensor) and len(module.recurrent.V.shape) > 0: - return module.recurrent.V.shape[0] - elif module.init_hidden is True: - return module.mem.shape[0] - else: - # not implemented - return -1 - else: - # not implemented - return -1 - - -# eqn is assumed to be: v_t+1 = (1-1/tau)*v_t + 1/tau * v_leak + I_in / C -def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: - """ - NOTE: it might leave the NIR node of neurons with an incompatible shape. This must be fixed. 
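The Leaky export added above fixes dt = 1e-4 and stores the decay as a time constant; a quick numeric check with an illustrative beta shows that the importer's inverse mapping recovers beta and that the implied w_scale is 1, so thresholds need no rescaling::

    import numpy as np

    dt = 1e-4
    beta = np.array([0.9])            # snn.Leaky decay factor (illustrative)

    # export: snn.Leaky -> nir.LIF
    tau = dt / (1 - beta)             # 1e-3
    r = tau / dt                      # 10.0
    v_leak = np.zeros_like(beta)

    # import: nir.LIF -> snn.Leaky
    beta_back = 1 - (dt / tau)        # 0.9
    w_scale = r * dt / tau            # 1.0, so no threshold scaling is triggered

    assert np.allclose(beta_back, beta)
    assert np.allclose(w_scale, 1.0)

Keeping r = tau / dt on export is what lets the importer's hack_w_scale path leave the threshold untouched for graphs produced by this exporter.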
- """ - if isinstance(module, Leaky): - # TODO - tau = 1 / (1 - module.beta.detach().numpy()) - r = module.beta.detach().numpy() - threshold = module.threshold.detach().numpy() - return nir.LIF( - tau=tau, - v_threshold=threshold, - v_leak=torch.zeros_like(tau), - r=r, - ) - - elif isinstance(module, RSynaptic): - alpha = module.alpha.detach().numpy() - beta = module.beta.detach().numpy() - threshold = module.threshold.detach().numpy() - n_neurons = _get_neuron_count(module) - if len(alpha.shape) == 0 and n_neurons != -1: - alpha = np.ones(n_neurons) * alpha - if len(beta.shape) == 0 and n_neurons != -1: - beta = np.ones(n_neurons) * beta - if len(threshold.shape) == 0 and n_neurons != -1: - threshold = np.ones(n_neurons) * threshold - lif = nir.CubaLIF( - tau_syn=1 / (1 - beta), - tau_mem=1 / (1 - alpha), - v_threshold=threshold, - v_leak=np.zeros_like(beta), - r=beta, - ) - return _create_rnn_subgraph(module, lif, n_neurons=n_neurons) - - elif isinstance(module, RLeaky): - beta = module.beta.detach().numpy() - threshold = module.threshold.detach().numpy() - n_neurons = _get_neuron_count(module) - if len(beta.shape) == 0 and n_neurons != -1: - beta = np.ones(n_neurons) * beta - if len(threshold.shape) == 0 and n_neurons != -1: - threshold = np.ones(n_neurons) * threshold - lif = nir.LIF( - tau=1 / (1 - beta), - v_threshold=threshold, - v_leak=np.zeros_like(beta), - r=beta, - ) - return _create_rnn_subgraph(module, lif, n_neurons=n_neurons) - - elif isinstance(module, Synaptic): - # TODO - n_neurons = module.alpha.shape[0] - return nir.CubaLIF( - tau_syn=1 / (1 - module.alpha).detach().numpy(), - tau_mem=1 / (1 - module.beta).detach().numpy(), - v_threshold=np.ones_like(module.alpha) * module.threshold.detach().numpy(), - v_leak=np.zeros_like(module.beta), - r=module.beta.detach().numpy(), # NOTE: is this right? - ) - - elif isinstance(module, torch.nn.Linear): - if module.bias is None: # Add zero bias if none is present - return nir.Affine( - module.weight.detach().numpy(), np.zeros(*module.weight.shape[:-1]) - ) - else: - return nir.Affine(module.weight.detach().numpy(), module.bias.detach().numpy()) - - else: - print(f'[WARNING] unknown module type: {type(module).__name__} (ignored)') - return None - - -def to_nir( - module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch" -) -> nir.NIRNode: - """Convert an snnTorch model to the Neuromorphic Intermediate Representation (NIR). 
- - Example:: - - import torch, torch.nn as nn - import snntorch as snn - from snntorch import export - - data_path = "untrained-snntorch.pt" - - net = nn.Sequential(nn.Linear(784, 128), - snn.Leaky(beta=0.8, init_hidden=True), - nn.Linear(128, 10), - snn.Leaky(beta=0.8, init_hidden=True, output=True)) - - # save model in pt format - torch.save(net.state_dict(), data_path) - - # load model (does nothing here, but shown for completeness) - net.load_state_dict(torch.load(data_path)) - - # generate input tensor to dynamically construct graph - x = torch.zeros(784) - - # generate NIR graph - nir_net = export.to_nir(net, x) - - - :param module: a saved snnTorch model as a parameter dictionary - :type module: torch.nn.Module - - :param sample_data: sample input data to the model - :type sample_data: torch.Tensor - - :param model_name: name of library used to train model, default: "snntorch" - :type model_name: str, optional - - :return: NIR computational graph where torch modules are represented as NIR nodes - :rtype: NIRGraph - - """ - nir_graph = extract_nir_graph( - module, _extract_snntorch_module, sample_data, model_name=model_name, - ignore_submodules_of=[RLeaky, RSynaptic] - ) - - # NOTE: hack to define subgraph I/O types - for node_key, node in nir_graph.nodes.items(): - inp_type = node.input_type.get('input', [None]) - input_undef = len(inp_type) == 0 or inp_type[0] is None - if isinstance(node, nir.Input) and input_undef and '.' in node_key: - print('WARNING: subgraph input type not set, inferring from previous node') - key = '.'.join(node_key.split('.')[:-1]) - prev_keys = [edge[0] for edge in nir_graph.edges if edge[1] == key] - assert len(prev_keys) == 1, 'multiple previous nodes not supported' - prev_node = nir_graph.nodes[prev_keys[0]] - cur_type = prev_node.output_type['output'] - node.input_type['input'] = cur_type - nir_graph.nodes[f'{key}.output'].output_type['output'] = cur_type - - # NOTE: hack to remove recurrent connections of subgraph to itself - for edge in nir_graph.edges: - if edge[0] not in nir_graph.nodes and edge[1] not in nir_graph.nodes: - nir_graph.edges.remove(edge) - - return nir_graph diff --git a/snntorch/import_nir_old.py b/snntorch/import_nir_old.py deleted file mode 100644 index f77b9df8..00000000 --- a/snntorch/import_nir_old.py +++ /dev/null @@ -1,177 +0,0 @@ -import snntorch as snn -import numpy as np -import torch -import nir -import typing - - -# TODO: implement this? -class ImportedNetwork(torch.nn.Module): - """Wrapper for a snnTorch network. NOTE: not working atm.""" - def __init__(self, module_list): - super().__init__() - self.module_list = module_list - - def forward(self, x): - for module in self.module_list: - x = module(x) - return x - - -def create_snntorch_network(module_list): - return torch.nn.Sequential(*module_list) - - -def _lif_to_snntorch_module( - lif: typing.Union[nir.LIF, nir.CubaLIF] -) -> torch.nn.Module: - """Parse a LIF node into snnTorch.""" - if isinstance(lif, nir.LIF): - assert np.alltrue(lif.v_leak == 0), 'v_leak not supported' - assert np.alltrue(lif.r == 1. - 1. / lif.tau), 'r not supported' - assert np.unique(lif.v_threshold).size == 1, 'v_threshold must be same for all neurons' - threshold = lif.v_threshold[0] - mod = snn.RLeaky( - beta=1. - 1. 
/ lif.tau, - threshold=threshold, - all_to_all=True, - reset_mechanism='zero', - linear_features=lif.tau.shape[0] if len(lif.tau.shape) == 1 else None, - init_hidden=True, - ) - return mod - - elif isinstance(lif, nir.CubaLIF): - assert np.alltrue(lif.v_leak == 0), 'v_leak not supported' - assert np.alltrue(lif.r == 1. - 1. / lif.tau_mem), 'r not supported' # NOTE: is this right? - assert np.unique(lif.v_threshold).size == 1, 'v_threshold must be same for all neurons' - threshold = lif.v_threshold[0] - mod = snn.RSynaptic( - alpha=1. - 1. / lif.tau_syn, - beta=1. - 1. / lif.tau_mem, - threshold=threshold, - all_to_all=True, - reset_mechanism='zero', - linear_features=lif.tau_mem.shape[0] if len(lif.tau_mem.shape) == 1 else None, - init_hidden=True, - ) - return mod - - elif isinstance(lif, nir.LI): - assert np.alltrue(lif.v_leak == 0), 'v_leak not supported' - assert np.allclose(lif.r , 1. - 1. / lif.tau), 'r not supported' - mod = snn.Leaky( - beta=1. - 1. / lif.tau, - reset_mechanism='none', - init_hidden=True, - output=True, - ) - return mod - - else: - raise ValueError('called _lif_to_snntorch_module on non-LIF node') - - -def _to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module: - """Convert a NIR node to a snnTorch module. - - Supported NIR nodes: Affine. - """ - if isinstance(node, (nir.LIF, nir.CubaLIF)): - return _lif_to_snntorch_module(node) - - elif isinstance(node, nir.Affine): - if len(node.weight.shape) != 2: - raise NotImplementedError('only 2D weight matrices are supported') - has_bias = node.bias is not None and not np.alltrue(node.bias == 0) - linear = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=has_bias) - linear.weight.data = torch.Tensor(node.weight) - if has_bias: - linear.bias.data = torch.Tensor(node.bias) - return linear - - else: - raise NotImplementedError(f'node type {type(node).__name__} not supported') - - -def _rnn_subgraph_to_snntorch_module( - lif: typing.Union[nir.LIF, nir.CubaLIF], w_rec: typing.Union[nir.Affine, nir.Linear] -) -> torch.nn.Module: - """Parse an RNN subgraph consisting of a LIF node and a recurrent weight matrix into snnTorch. - - NOTE: for now always set it as a recurrent linear layer (not RecurrentOneToOne) - """ - assert isinstance(lif, (nir.LIF, nir.CubaLIF)), 'only LIF or CubaLIF nodes supported as RNNs' - mod = _lif_to_snntorch_module(lif) - mod.recurrent.weight.data = torch.Tensor(w_rec.weight) - if isinstance(w_rec, nir.Linear): - mod.recurrent.register_parameter('bias', None) - mod.recurrent.reset_parameters() - else: - mod.recurrent.bias.data = torch.Tensor(w_rec.bias) - return mod - - -def _get_next_node_key(node_key: str, graph: nir.ir.NIRGraph): - """Get the next node key in the NIR graph.""" - possible_next_node_keys = [edge[1] for edge in graph.edges if edge[0] == node_key] - # possible_next_node_keys += [edge[1] + '.input' for edge in graph.edges if edge[0] == node_key] - assert len(possible_next_node_keys) <= 1, 'branching networks are not supported' - if len(possible_next_node_keys) == 0: - return None - else: - return possible_next_node_keys[0] - - -def from_nir(graph: nir.ir.NIRGraph) -> torch.nn.Module: - """Convert NIR graph to snnTorch module. 
- - :param graph: a saved snnTorch model as a parameter dictionary - :type graph: nir.ir.NIRGraph - - :return: snnTorch module - :rtype: torch.nn.Module - """ - node_key = 'input' - visited_node_keys = [node_key] - module_list = [] - - while _get_next_node_key(node_key, graph) is not None: - node_key = _get_next_node_key(node_key, graph) - - assert node_key not in visited_node_keys, 'cyclic NIR graphs not supported' - - if node_key == 'output': - visited_node_keys.append(node_key) - continue - - if node_key in graph.nodes: - visited_node_keys.append(node_key) - node = graph.nodes[node_key] - print(f'simple node {node_key}: {type(node).__name__}') - module = _to_snntorch_module(node) - else: - # check if it's a nested node - print(f'potential subgraph node: {node_key}') - sub_node_keys = [n for n in graph.nodes if n.startswith(f'{node_key}.')] - assert len(sub_node_keys) > 0, f'no nodes found for subgraph {node_key}' - - # parse subgraph - # NOTE: for now only looking for RNN subgraphs - rnn_sub_node_keys = [f'{node_key}.{n}' for n in ['input', 'output', 'lif', 'w_rec']] - if set(sub_node_keys) != set(rnn_sub_node_keys): - raise NotImplementedError('only RNN subgraphs are supported') - print('found RNN subgraph') - module = _rnn_subgraph_to_snntorch_module( - graph.nodes[f'{node_key}.lif'], graph.nodes[f'{node_key}.w_rec'] - ) - for nk in sub_node_keys: - visited_node_keys.append(nk) - - module_list.append(module) - - if len(visited_node_keys) != len(graph.nodes): - print(graph.nodes.keys(), visited_node_keys) - raise ValueError('not all nodes visited') - - return create_snntorch_network(module_list) From 38f19cd6186cd883788cb13f32cc0328c712516a Mon Sep 17 00:00:00 2001 From: Steven Abreu Date: Sat, 3 Feb 2024 08:37:41 -0800 Subject: [PATCH 24/33] add docstrings --- snntorch/export_nir.py | 64 ++++++++++++++++++++++- snntorch/import_nir.py | 113 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 163 insertions(+), 14 deletions(-) diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py index aa7b8956..5ad28de2 100644 --- a/snntorch/export_nir.py +++ b/snntorch/export_nir.py @@ -7,6 +7,21 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: + """Convert a single snnTorch module to the equivalent object in the Neuromorphic + Intermediate Representation (NIR). This function is used internally by the export_to_nir + function to convert each submodule/layer of the network to the NIR. + + Currently supported snnTorch modules: Leaky, Linear, Synaptic, RLeaky, RSynaptic. + + Note that recurrent layers such as RLeaky and RSynaptic will be converted to a NIR graph, + which will then be embedded as a subgraph into the main NIR graph. + + :param module: snnTorch module + :type module: torch.nn.Module + + :return: return the NIR node + :rtype: Optional[nir.NIRNode] + """ if isinstance(module, snn.Leaky): dt = 1e-4 @@ -58,6 +73,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: ) elif isinstance(module, snn.RLeaky): + # TODO(stevenabreu7): implement RLeaky raise NotImplementedError('RLeaky not supported') elif isinstance(module, snn.RSynaptic): @@ -108,11 +124,55 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: return None -def to_nir( +def export_to_nir( module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch", model_fwd_args=[], ignore_dims=[] ) -> nir.NIRNode: - """Convert an snnTorch model to the Neuromorphic Intermediate Representation (NIR). 
+ """Convert an snnTorch module to the Neuromorphic Intermediate Representation (NIR). + This function uses nirtorch to extract the computational graph of the torch module, + and the _extract_snntorch_module method is used to convert each module in the graph + to the corresponding NIR module. + + The NIR is a graph-based representation of a spiking neural network, which can be used to + port the network to different neuromorphic hardware and software platforms. + + Example:: + + import snntorch as snn + import torch + from snntorch import export_to_nir + + lif1 = snn.Leaky(beta=0.9, init_hidden=True) + lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) + + net = torch.nn.Sequential( + torch.nn.Flatten(), + torch.nn.Linear(784, 500), + lif1, + torch.nn.Linear(500, 10), + lif2 + ) + + sample_data = torch.randn(1, 784) + nir_graph = export_to_nir(net, sample_data) + + :param module: Network model (either wrapped in Sequential container or as a class) + :type module: torch.nn.Module + + :param sample_data: Sample input data to the network + :type sample_data: torch.Tensor + + :param model_name: Name of the model + :type model_name: str, optional + + :param model_fwd_args: Arguments to pass to the forward function of the model + :type model_fwd_args: list, optional + + :param ignore_dims: List of dimensions to ignore when extracting the NIR + :type ignore_dims: list, optional + + :return: return the NIR graph + :rtype: nir.NIRNode """ nir_graph = nirtorch.extract_nir_graph( module, _extract_snntorch_module, sample_data, model_name=model_name, diff --git a/snntorch/import_nir.py b/snntorch/import_nir.py index 2ebe7ef8..2d32c5fa 100644 --- a/snntorch/import_nir.py +++ b/snntorch/import_nir.py @@ -8,6 +8,26 @@ def _create_rnn_subgraph(graph: nir.NIRGraph, lif_nk: str, w_nk: str) -> nir.NIRGraph: """Take a NIRGraph plus the node keys for a LIF and a W_rec, and return a new NIRGraph which has the RNN subgraph replaced with a subgraph (i.e., a single NIRGraph node). + + The subgraph will have the following structure: + ``` + LIF -> W_rec -> LIF + ^ | + | v + Input Output + ``` + + :param graph: NIRGraph + :type graph: nir.NIRGraph + + :param lif_nk: key for the LIF node + :type lif_nk: str + + :param w_nk: key for the W_rec node + :type w_nk: str + + :return: NIRGraph with the RNN subgraph replaced with a single NIRGraph node + :rtype: nir.NIRGraph """ # NOTE: assuming that the LIF and W_rec have keys of form xyz.abc sg_key = lif_nk.split('.')[0] # TODO: make this more general? @@ -40,7 +60,16 @@ def _create_rnn_subgraph(graph: nir.NIRGraph, lif_nk: str, w_nk: str) -> nir.NIR def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: - """Take a NIRGraph and replace any RNN subgraphs with a single NIRGraph node.""" + """Take a NIRGraph and replace any RNN subgraphs with a single NIRGraph node. + Goes through the NIRGraph to find any RNN subgraphs, and replaces them with a single NIRGraph node, + using the _create_rnn_subgraph function. + + :param graph: NIRGraph + :type graph: nir.NIRGraph + + :return: NIRGraph with RNN subgraphs replaced with a single NIRGraph node + :rtype: nir.NIRGraph + """ print('replace rnn subgraph with nirgraph') if len(set(graph.edges)) != len(graph.edges): @@ -69,16 +98,28 @@ def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): - """Try parsing the graph as a RNN subgraph. 
- - Assumes four nodes: Input, Output, LIF | CubaLIF, Affine | Linear - Checks that all nodes have consistent shapes. - Will throw an error if either not all nodes are found or consistent shapes are found. - - Returns: - lif_node: LIF | CubaLIF node - wrec_node: Affine | Linear node - lif_size: int, number of neurons in the RNN + """Try parsing the presented graph as a RNN subgraph. Assumes the graph is a valid RNN subgraph + with four nodes in the following structure: + + ``` + Input -> LIF | CubaLIF -> Output + ^ + | + v + Affine | Linear + ``` + + :param graph: NIRGraph + :type graph: nir.NIRGraph + + :return: LIF | CubaLIF node + :rtype: nir.NIRNode + + :return: Affine | Linear node + :rtype: nir.NIRNode + + :return: int, number of neurons in the RNN + :rtype: int """ sub_nodes = graph.nodes.values() assert len(sub_nodes) == 4, 'only 4-node RNN allowed in subgraph' @@ -101,6 +142,20 @@ def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): def _nir_to_snntorch_module( node: nir.NIRNode, hack_w_scale=True, init_hidden=False ) -> torch.nn.Module: + """Convert a NIR node to a snnTorch module. This function is used by the import_from_nir function. + + :param node: NIR node + :type node: nir.NIRNode + + :param hack_w_scale: if True, then the function will attempt to scale the weights to avoid scaling the inputs + :type hack_w_scale: bool + + :param init_hidden: the init_hidden flag of the snnTorch neuron. + :type init_hidden: bool + + :return: snnTorch module + :rtype: torch.nn.Module + """ if isinstance(node, nir.Input) or isinstance(node, nir.Output): return None @@ -285,7 +340,41 @@ def _nir_to_snntorch_module( return None -def from_nir(graph: nir.NIRGraph) -> torch.nn.Module: +def import_from_nir(graph: nir.NIRGraph) -> torch.nn.Module: + """Convert a NIRGraph to a snnTorch module. This function is the inverse of export_to_nir. + It proceeds by wrapping any recurrent connections into NIR sub-graphs, then converts each + NIR module into the equivalent snnTorch module, and wraps them into a torch.nn.Module + using the generic GraphExecutor from NIRTorch to execute all modules in the right order. 
+ + Example:: + + import snntorch as snn + import torch + from snntorch import export_to_nir, import_from_nir + + lif1 = snn.Leaky(beta=0.9, init_hidden=True) + lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) + + net = torch.nn.Sequential( + torch.nn.Flatten(), + torch.nn.Linear(784, 500), + lif1, + torch.nn.Linear(500, 10), + lif2 + ) + + sample_data = torch.randn(1, 784) + nir_graph = export_to_nir(net, sample_data, model_name="snntorch") + + net2 = import_from_nir(nir_graph) + + :param graph: NIR graph + :type graph: NIR.NIRGraph + + :return: snnTorch network + :rtype: torch.nn.Module + """ # find valid RNN subgraphs, and replace them with a single NIRGraph node graph = _replace_rnn_subgraph_with_nirgraph(graph) + # convert the NIR graph into a torch.nn.Module using snnTorch modules return nirtorch.load(graph, _nir_to_snntorch_module) From 04b1d024d2cba1570a73edfd4b3e4c7f88589d50 Mon Sep 17 00:00:00 2001 From: Steven Abreu Date: Sat, 3 Feb 2024 09:52:36 -0800 Subject: [PATCH 25/33] rename files --- snntorch/{export_nir.py => export.py} | 0 snntorch/{import_nir.py => import.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename snntorch/{export_nir.py => export.py} (100%) rename snntorch/{import_nir.py => import.py} (100%) diff --git a/snntorch/export_nir.py b/snntorch/export.py similarity index 100% rename from snntorch/export_nir.py rename to snntorch/export.py diff --git a/snntorch/import_nir.py b/snntorch/import.py similarity index 100% rename from snntorch/import_nir.py rename to snntorch/import.py From bdb15a37f68d469cbe4564c07f503bbff67e5a3f Mon Sep 17 00:00:00 2001 From: Steven Abreu Date: Sat, 3 Feb 2024 09:54:51 -0800 Subject: [PATCH 26/33] test suggestions --- tests/test_nir.py | 100 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 tests/test_nir.py diff --git a/tests/test_nir.py b/tests/test_nir.py new file mode 100644 index 00000000..c824501f --- /dev/null +++ b/tests/test_nir.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python + +"""Tests for NIR import and export.""" + +import pytest +import snntorch as snn +import torch + + +@pytest.fixture(scope="module") +def snntorch_sequential(): + lif1 = snn.Leaky(beta=0.9, init_hidden=False) + lif2 = snn.Leaky(beta=0.9, init_hidden=False, output=True) + + return torch.nn.Sequential( + torch.nn.Flatten(), + torch.nn.Linear(784, 500), + lif1, + torch.nn.Linear(500, 10), + lif2 + ) + + +@pytest.fixture(scope="module") +def snntorch_sequential_hidden(): + lif1 = snn.Leaky(beta=0.9, init_hidden=True) + lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) + + return torch.nn.Sequential( + torch.nn.Flatten(), + torch.nn.Linear(784, 500), + lif1, + torch.nn.Linear(500, 10), + lif2 + ) + + +@pytest.fixture(scope="module") +def snntorch_recurrent(): + lif1 = snn.RLeaky(beta=0.9, V=1, all_to_all=True, init_hidden=False) + lif2 = snn.Leaky(beta=0.9, init_hidden=False, output=True) + + return torch.nn.Sequential( + torch.nn.Flatten(), + torch.nn.Linear(784, 500), + lif1, + torch.nn.Linear(500, 10), + lif2 + ) + + +@pytest.fixture(scope="module") +def snntorch_recurrent_hidden(): + lif1 = snn.RLeaky(beta=0.9, V=1, all_to_all=True, init_hidden=True) + lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) + + return torch.nn.Sequential( + torch.nn.Flatten(), + torch.nn.Linear(784, 500), + lif1, + torch.nn.Linear(500, 10), + lif2 + ) + + +class NIRTestExport: + """Test exporting snnTorch network to NIR.""" + def test_export_sequential(snntorch_sequential): + pass + + def 
test_export_sequential_hidden(snntorch_sequential_hidden): + pass + + def test_export_recurrent(snntorch_recurrent): + pass + + def test_export_recurrent_hidden(snntorch_recurrent_hidden): + pass + + +class NIRTestImport: + """Test importing NIR graph to snnTorch.""" + def test_import_nir(): + # load a NIR graph from a file? + pass + + +class NIRTestCommute: + """Test that snnTorch -> NIR -> snnTorch doesn't change the network.""" + def test_commute_sequential(snntorch_sequential): + pass + + def test_commute_sequential_hidden(snntorch_sequential_hidden): + pass + + def test_commute_recurrent(snntorch_recurrent): + pass + + def test_commute_recurrent_hidden(snntorch_recurrent_hidden): + pass From 3710dfbf382677ee46722cb1534a5357f2b8a809 Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 3 Feb 2024 17:41:51 -0800 Subject: [PATCH 27/33] leaky syntax change --- snntorch/_neurons/leaky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snntorch/_neurons/leaky.py b/snntorch/_neurons/leaky.py index d51d5dc1..3e1371e1 100644 --- a/snntorch/_neurons/leaky.py +++ b/snntorch/_neurons/leaky.py @@ -162,7 +162,7 @@ def __init__( self.reset_delay = reset_delay if not self.reset_delay and self.init_hidden: - raise NotImplementedError('no reset_delay only supported for init_hidden=False') + raise NotImplementedError("`reset_delay=True` is only supported for `init_hidden=False`") if self.init_hidden: self.mem = self.init_leaky() From 9adbcef765a5547bd3a42e5cab113ae99ab4b8aa Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 3 Feb 2024 18:02:34 -0800 Subject: [PATCH 28/33] solve conflicts for updated leaky neuron forward-pass --- snntorch/_neurons/leaky.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/snntorch/_neurons/leaky.py b/snntorch/_neurons/leaky.py index bcc91906..11ac3f8a 100644 --- a/snntorch/_neurons/leaky.py +++ b/snntorch/_neurons/leaky.py @@ -106,6 +106,15 @@ def forward(self, x, mem1, spk1, mem2): returned when neuron is called. Defaults to False :type output: bool, optional + :param graded_spikes_factor: output spikes are scaled this value, if specified. Defaults to 1.0 + :type graded_spikes_factor: float or torch.tensor + + :param learn_graded_spikes_factor: Option to enable learnable graded spikes. Defaults to False + :type learn_graded_spikes_factor: bool, optional + + :param reset_delay: If `True`, a spike is returned with a one-step delay after the threshold is reached. 
+ Defaults to True + :type reset_delay: bool, optional Inputs: \\input_, mem_0 - **input_** of shape `(batch, input_size)`: tensor containing input From eec1045957f1ff1c0c1f1c4470ecc9a404d12eb4 Mon Sep 17 00:00:00 2001 From: Jason Eshraghian Date: Sat, 3 Feb 2024 18:15:14 -0800 Subject: [PATCH 29/33] membrane potential init bug fix --- snntorch/_neurons/leaky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snntorch/_neurons/leaky.py b/snntorch/_neurons/leaky.py index 11ac3f8a..7406fa82 100644 --- a/snntorch/_neurons/leaky.py +++ b/snntorch/_neurons/leaky.py @@ -197,7 +197,7 @@ def init_leaky(self): self.reset_mem() return self.mem - def forward(self, input_, mem=False): + def forward(self, input_, mem=None): if not mem == None: self.mem = mem From b6dc7d998ca9cc78e8bf8923cdf29d35a0abdb52 Mon Sep 17 00:00:00 2001 From: Steven Abreu Date: Sun, 4 Feb 2024 20:07:20 -0800 Subject: [PATCH 30/33] rename again (import.py is invalid) --- snntorch/__init__.py | 2 ++ snntorch/{export.py => export_nir.py} | 0 snntorch/{import.py => import_nir.py} | 0 3 files changed, 2 insertions(+) rename snntorch/{export.py => export_nir.py} (100%) rename snntorch/{import.py => import_nir.py} (100%) diff --git a/snntorch/__init__.py b/snntorch/__init__.py index ccea82c0..0482a113 100644 --- a/snntorch/__init__.py +++ b/snntorch/__init__.py @@ -1,3 +1,5 @@ from ._version import __version__ from ._neurons import * from ._layers import * +from .export_nir import export_to_nir +from .import_nir import import_from_nir \ No newline at end of file diff --git a/snntorch/export.py b/snntorch/export_nir.py similarity index 100% rename from snntorch/export.py rename to snntorch/export_nir.py diff --git a/snntorch/import.py b/snntorch/import_nir.py similarity index 100% rename from snntorch/import.py rename to snntorch/import_nir.py From 11684cf9a3e68721b27add595e0be470e550fd26 Mon Sep 17 00:00:00 2001 From: Steven Abreu Date: Sun, 4 Feb 2024 20:07:27 -0800 Subject: [PATCH 31/33] add lif nir graph --- tests/lif.nir | Bin 0 -> 17584 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/lif.nir diff --git a/tests/lif.nir b/tests/lif.nir new file mode 100644 index 0000000000000000000000000000000000000000..ce5f8e1c230f01b93edf3678479ebcfebfbfaa69 GIT binary patch literal 17584 zcmeHO&2JM&6o2awYSMztDGd^0q+Z;Lf}yB&TM!8(;7Accw1?ahqZn)vB~jv}?g@_O zmSc`N_SiomM-CuyeQ##=?d$b zm&S-wHcMwH7iRJIQPbbkQiC$umoyV7Z7TgMBbk8sNA$JUU)26xzTcnB%#h_o|Dbvq z70yU5iwaLq6#}K1smHAFf@%$WdZTs*Ft#fjo3*u7T3uVNa2@P&5)6IPfJUVc?I<9z zK=eCeEt~TbK@eItdyljU->Bv1awa3yS@nbwbmaIO(Qcm-9xwZEkLITX%0yB(MrC~9 zv8Uyn-XQp^8981#(>_|adC|o$X^xKUoO!p*@yySlzrxBR+%(^;G6H1+IR zZM8ync9n$d=4Ky+72pG%->QcleWCo2i)G;iIo$el(j1Murv1o<2Xi%1LHWK^ihfZ4 zmDWSvn(C$1^!@^Jo_Yj40v-X6fJeY1;1Tc$cmzBG9s!SlMLs@S>aMLK=0QDQJNDt5Ux}pc-P^Z!+OOk0`fDGDW9?jzemExIUgp?+dzqKK z_A=|xJ`Uai*lM3V0v-X6fJeY1;1TG6fV&^YV`4$~qqqdm|C`lv7zp6?m(yhU-~XSS zD{7rS+`l4ypkH@C1G)Zj_&V)pxTl8htI?mk&o(j~HSS0z;?weexVurN6qBQiqZlJ< zBhmm6c9BUXbS)8dhv(eKt}VwbD!L+fI|MJ z+cKQxxLxH#MFW|}BYTCNxfJK*FPbZQ7Szyqmd0NyT;11&k`Wf1Ak_<~fy{}}I37vs zq4-_@`p+3r(AB?rB6@clG-*Bj_b0hhtfpHpzJ6E{aMweuYfX(iLvO5y_3~C&wV@5$ ziz}7#3!?4CdUd0+SzTLM4qGtK80B+G+)rhO$90&O`u~3Cu?~rq03#2=dN>-rS1$bm zxjR#6SNyqiJ8A`0q66oCq8-#|X2HG{=#?-}RB!j Date: Sun, 4 Feb 2024 20:40:50 -0800 Subject: [PATCH 32/33] tests --- snntorch/import_nir.py | 22 ++++--- tests/test_nir.py | 139 ++++++++++++++++++++++------------------- 2 files changed, 88 insertions(+), 73 deletions(-) diff --git a/snntorch/import_nir.py b/snntorch/import_nir.py index 
2d32c5fa..5314adf7 100644 --- a/snntorch/import_nir.py +++ b/snntorch/import_nir.py @@ -322,22 +322,28 @@ def _nir_to_snntorch_module( reset_mechanism='zero', init_hidden=init_hidden, all_to_all=not diagonal, - linear_features=lif_size, - V=V, + linear_features=lif_size if not diagonal else None, + V=V if diagonal else None, reset_delay=False, ) - rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) - if isinstance(wrec_node, nir.Affine): - rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias) + if isinstance(rsynaptic.recurrent, torch.nn.Linear): + rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) + if isinstance(wrec_node, nir.Affine): + rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias) + else: + rsynaptic.recurrent.bias.data = torch.zeros_like(rsynaptic.recurrent.bias) else: - rsynaptic.recurrent.bias.data = torch.zeros_like(rsynaptic.recurrent.bias) + rsynaptic.recurrent.V.data = torch.diagonal(torch.Tensor(wrec_node.weight)) + return rsynaptic + elif node is None: + return torch.nn.Identity() + else: print('[WARNING] could not parse node of type:', node.__class__.__name__) - - return None + return None def import_from_nir(graph: nir.NIRGraph) -> torch.nn.Module: diff --git a/tests/test_nir.py b/tests/test_nir.py index c824501f..0fa1f635 100644 --- a/tests/test_nir.py +++ b/tests/test_nir.py @@ -2,99 +2,108 @@ """Tests for NIR import and export.""" +import nir import pytest import snntorch as snn import torch @pytest.fixture(scope="module") -def snntorch_sequential(): - lif1 = snn.Leaky(beta=0.9, init_hidden=False) - lif2 = snn.Leaky(beta=0.9, init_hidden=False, output=True) - - return torch.nn.Sequential( - torch.nn.Flatten(), - torch.nn.Linear(784, 500), - lif1, - torch.nn.Linear(500, 10), - lif2 - ) +def sample_data(): + return torch.ones((4, 784)) @pytest.fixture(scope="module") -def snntorch_sequential_hidden(): +def snntorch_sequential(): lif1 = snn.Leaky(beta=0.9, init_hidden=True) lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) return torch.nn.Sequential( - torch.nn.Flatten(), torch.nn.Linear(784, 500), lif1, torch.nn.Linear(500, 10), - lif2 + lif2, ) @pytest.fixture(scope="module") def snntorch_recurrent(): - lif1 = snn.RLeaky(beta=0.9, V=1, all_to_all=True, init_hidden=False) - lif2 = snn.Leaky(beta=0.9, init_hidden=False, output=True) - - return torch.nn.Sequential( - torch.nn.Flatten(), - torch.nn.Linear(784, 500), - lif1, - torch.nn.Linear(500, 10), - lif2 + v = torch.ones((500,)) + lif1 = snn.RSynaptic( + alpha=0.5, beta=0.9, V=v, all_to_all=False, init_hidden=True ) - - -@pytest.fixture(scope="module") -def snntorch_recurrent_hidden(): - lif1 = snn.RLeaky(beta=0.9, V=1, all_to_all=True, init_hidden=True) lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) return torch.nn.Sequential( - torch.nn.Flatten(), torch.nn.Linear(784, 500), lif1, torch.nn.Linear(500, 10), - lif2 + lif2, ) -class NIRTestExport: - """Test exporting snnTorch network to NIR.""" - def test_export_sequential(snntorch_sequential): - pass - - def test_export_sequential_hidden(snntorch_sequential_hidden): - pass - - def test_export_recurrent(snntorch_recurrent): - pass - - def test_export_recurrent_hidden(snntorch_recurrent_hidden): - pass - - -class NIRTestImport: - """Test importing NIR graph to snnTorch.""" - def test_import_nir(): - # load a NIR graph from a file? 
- pass - - -class NIRTestCommute: - """Test that snnTorch -> NIR -> snnTorch doesn't change the network.""" - def test_commute_sequential(snntorch_sequential): - pass - - def test_commute_sequential_hidden(snntorch_sequential_hidden): - pass - - def test_commute_recurrent(snntorch_recurrent): - pass - - def test_commute_recurrent_hidden(snntorch_recurrent_hidden): - pass +class TestNIR: + """Test import and export from snnTorch to NIR.""" + + def test_export_sequential(self, snntorch_sequential, sample_data): + nir_graph = snn.export_to_nir(snntorch_sequential, sample_data) + assert nir_graph is not None + assert set(nir_graph.nodes.keys()) == set( + ["input", "output"] + [str(i) for i in range(4)] + ), nir_graph.nodes.keys() + assert set(nir_graph.edges) == set( + [ + ("3", "output"), + ("input", "0"), + ("2", "3"), + ("1", "2"), + ("0", "1"), + ] + ) + assert isinstance(nir_graph.nodes["input"], nir.Input) + assert isinstance(nir_graph.nodes["output"], nir.Output) + assert isinstance(nir_graph.nodes["0"], nir.Affine) + assert isinstance(nir_graph.nodes["1"], nir.LIF) + assert isinstance(nir_graph.nodes["2"], nir.Affine) + assert isinstance(nir_graph.nodes["3"], nir.LIF) + + def test_export_recurrent(self, snntorch_recurrent, sample_data): + nir_graph = snn.export_to_nir(snntorch_recurrent, sample_data) + assert nir_graph is not None + assert set(nir_graph.nodes.keys()) == set( + ["input", "output", "0", "1.lif", "1.w_rec", "2", "3"] + ), nir_graph.nodes.keys() + assert isinstance(nir_graph.nodes["input"], nir.Input) + assert isinstance(nir_graph.nodes["output"], nir.Output) + assert isinstance(nir_graph.nodes["0"], nir.Affine) + assert isinstance(nir_graph.nodes["1.lif"], nir.CubaLIF) + assert isinstance(nir_graph.nodes["1.w_rec"], nir.Linear) + assert isinstance(nir_graph.nodes["2"], nir.Affine) + assert isinstance(nir_graph.nodes["3"], nir.LIF) + assert set(nir_graph.edges) == set( + [ + ("1.lif", "1.w_rec"), + ("1.w_rec", "1.lif"), + ("0", "1.lif"), + ("3", "output"), + ("2", "3"), + ("input", "0"), + ("1.lif", "2"), + ] + ) + + def test_import_nir(self): + graph = nir.read("tests/lif.nir") + net = snn.import_from_nir(graph) + out, _ = net(torch.ones(1, 1)) + assert out.shape == (1, 1), out.shape + + def test_commute_sequential(self, snntorch_sequential, sample_data): + x = torch.rand((4, 784)) + y_snn, state = snntorch_sequential(x) + assert y_snn.shape == (4, 10) + nir_graph = snn.export_to_nir(snntorch_sequential, sample_data) + net = snn.import_from_nir(nir_graph) + y_nir, state = net(x) + assert y_nir.shape == (4, 10), y_nir.shape + assert torch.allclose(y_snn, y_nir) From 40ad1360c644f948889cb9610121397c6ce943c8 Mon Sep 17 00:00:00 2001 From: Steven Abreu Date: Sun, 4 Feb 2024 20:42:54 -0800 Subject: [PATCH 33/33] missing features into docstring --- snntorch/export_nir.py | 3 +++ snntorch/import_nir.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py index 5ad28de2..4f0a1614 100644 --- a/snntorch/export_nir.py +++ b/snntorch/export_nir.py @@ -136,6 +136,9 @@ def export_to_nir( The NIR is a graph-based representation of a spiking neural network, which can be used to port the network to different neuromorphic hardware and software platforms. 
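As a concrete illustration of the RLeaky gap listed just below, the per-module converter still rejects recurrent Leaky layers; a small sketch of the current behaviour, with placeholder constructor arguments::

    import torch
    import snntorch as snn
    from snntorch.export_nir import _extract_snntorch_module

    # RLeaky is converted neither to a NIR node nor to an RNN subgraph yet
    rleaky = snn.RLeaky(beta=0.9, V=torch.ones(10), all_to_all=False, init_hidden=True)
    try:
        _extract_snntorch_module(rleaky)
    except NotImplementedError as err:
        print(err)   # 'RLeaky not supported'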
+ Missing features: + - RLeaky + Example:: import snntorch as snn diff --git a/snntorch/import_nir.py b/snntorch/import_nir.py index 5314adf7..154a6a17 100644 --- a/snntorch/import_nir.py +++ b/snntorch/import_nir.py @@ -352,6 +352,9 @@ def import_from_nir(graph: nir.NIRGraph) -> torch.nn.Module: NIR module into the equivalent snnTorch module, and wraps them into a torch.nn.Module using the generic GraphExecutor from NIRTorch to execute all modules in the right order. + Missing features: + - RLeaky (LIF inside RNN) + Example:: import snntorch as snn