diff --git a/.github/workflows/build-tag.yml b/.github/workflows/build-tag.yml index 2c9d91c4..a77344e4 100644 --- a/.github/workflows/build-tag.yml +++ b/.github/workflows/build-tag.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11'] + python-version: ['3.9', '3.10', '3.11'] steps: - uses: actions/checkout@v2 @@ -41,4 +41,4 @@ jobs: # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics # - name: Test with pytest # run: | - # pytest \ No newline at end of file + # pytest diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 0f26492f..3e7e0ff6 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -102,7 +102,7 @@ Before you submit a pull request, check that it meets these guidelines: 2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst. -3. The pull request should work for Python 3.8, 3.9, 3.10, 3.11 and for PyPy. Check +3. The pull request should work for Python 3.9, 3.10, 3.11 and for PyPy. Check https://github.com/jeshraghian/snntorch/actions and make sure that the tests pass for all supported Python versions. diff --git a/examples/dataloaders/DVS_Gesture.ipynb b/examples/dataloaders/DVS_Gesture.ipynb index c8110818..1a15f9cd 100644 --- a/examples/dataloaders/DVS_Gesture.ipynb +++ b/examples/dataloaders/DVS_Gesture.ipynb @@ -25,41 +25,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "DATADIR = \"/tmp/data\"" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Download Dataset using `spikedata` (deprecated)" - ] - }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "id": "kbHJ827iVcYY" }, - "outputs": [], + "outputs": [ + { + "ename": "AttributeError", + "evalue": "module 'torch' has no attribute '_six'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_2028996/4075440975.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# # note that a default transform is already applied to keep things easy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msnntorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspikevision\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mspikedata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mtrain_ds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspikedata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDVSGesture\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/tmp/data/dvsgesture/\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m500\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# ds: spatial compression; dt: temporal compressiondvs_test\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mtest_ds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspikedata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDVSGesture\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/tmp/data/dvsgesture/\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1800\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/snntorch/spikevision/spikedata/dvs_gesture.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, train, transform, target_transform, download_and_create, num_steps, dt, ds, return_meta, time_shuffle)\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0mtarget_transform\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCompose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mRepeat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtoOneHot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m11\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m super(DVSGesture, self).__init__(\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mroot\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"/\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhdf5_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/snntorch/spikevision/neuromorphic_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, transforms, transform, target_transform, transform_train, transform_test, target_transform_train, target_transform_test)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mtarget_transform_test\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m ):\n\u001b[0;32m--> 145\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_six\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstring_classes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mroot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexpanduser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: module 'torch' has no attribute '_six'" + ] + } + ], "source": [ "# from snntorch.spikevision import spikedata \n", "# # note that a default transform is already applied to keep things easy\n", - "\n", - "# train_ds = spikedata.DVSGesture(\"/tmp/data/dvsgesture\", train=True, dt=1000, num_steps=500, ds=1) # ds: spatial compression; dt: temporal compressiondvs_test\n", - "# test_ds = spikedata.DVSGesture(\"/tmp/data/dvsgesture\", train=False, dt=1000, num_steps=1800, ds=1)\n", - "# test_ds" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Download Dataset using `tonic`" + "from snntorch.spikevision import spikedata\n", + "train_ds = spikedata.DVSGesture(\"/tmp/data/dvsgesture/\", train=True, dt=1000, num_steps=500, ds=1) # ds: spatial compression; dt: temporal compressiondvs_test\n", + "test_ds = spikedata.DVSGesture(\"/tmp/data/dvsgesture/\", train=False, dt=1000, num_steps=1800, ds=1)\n" ] }, { @@ -225,7 +224,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/model_params.json b/examples/model_params.json new file mode 100644 index 00000000..bbf6a736 --- /dev/null +++ b/examples/model_params.json @@ -0,0 +1,11 @@ +{ + "nb_hidden": 55, + "alpha_r": 0.75, + "alpha_out": 0.45, + "beta_r": 0.85, + "beta_out": 0.7, + "lr": 0.001, + "reg_l1": 0.001, + "reg_l2": 0.000001, + "slope": 5 +} \ No newline at end of file diff --git a/examples/scnn_mnist.nir b/examples/scnn_mnist.nir new file mode 100644 index 00000000..89e0a618 Binary files /dev/null and b/examples/scnn_mnist.nir differ diff --git a/examples/test.py b/examples/test.py new file mode 100644 index 00000000..206123e1 --- /dev/null +++ b/examples/test.py @@ -0,0 +1,903 @@ +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import List + +import nir +import numpy as np + +import spinnaker2 +from spinnaker2 import ann2snn_helpers, hardware, snn + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class HardwareConstraintError(Exception): + pass + + +class NotFoundError(Exception): + pass + + +NEURON_NODES = (nir.LIF, nir.IF, nir.CubaLIF) +SPIKE_SOURCE_NODES = (nir.LIF, nir.IF, nir.CubaLIF, nir.Input) +WEIGHT_NODES = (nir.Affine, nir.Conv2d, nir.Conv2d, nir.Linear) + + +class IntegratorMethod(str, Enum): + FORWARD = "Forward-Euler" + EXPONENTIAL = "Exponential-Euler" + + +class ResetMethod(str, Enum): + SUBTRACT = "subtract" # subtract threshold + ZERO = "zero" # reset voltage to zero + + +def get_s2_reset_method(reset: ResetMethod): + """get SpiNNaker2 reset method string.""" + if reset == ResetMethod.SUBTRACT: + return "reset_by_subtraction" + elif reset == ResetMethod.ZERO: + return "reset_to_v_reset" + else: + raise Exception("Unsupported ResetMethod") + + +@dataclass +class ConversionConfig: + """NIR-to-SpiNNaker2-conversion configuration. + + Attributes: + dt: discretization timestep in seconds + output_record: list of variables to record from output populations. + Supported: ["spikes", "v"]. Default: `["spikes"]`. + conn_delay: connection delay in timesteps for creating :spinnaker2.snn.Projection:s. + scale_weights: if True, scale weights to maximum dynamic range [-127, 127], + else don't scale weights + integrator: numerical integration method for SpiNNaker2. This affects how + parameters, especially time constants are translated to the SpiNNaker2 + neuron models. + Available options: + `IntegratorMethod.EXPONENTIAL`: First-order exponential Euler (default) + `IntegratorMethod.FORWARD`: First-order forward Euler + reset: neuron reset mechanism for SpiNNaker2. This affects how the + membrane voltage is reset after a spike is detected. + Available options: + `ResetMethod.SUBTRACT`: subtract threshold from voltage (default) + `ResetMethod.ZERO`: reset voltage to 0.0 + """ + + dt: float = 1.0 + output_record: List[str] = field(default_factory=lambda: ["spikes"]) + conn_delay: int = 1 + scale_weights: bool = False + integrator: IntegratorMethod = IntegratorMethod.EXPONENTIAL + reset: ResetMethod = ResetMethod.SUBTRACT + + +def add_output_to_node(node_name, nir_model, output_name): + assert node_name in nir_model.nodes.keys() + assert output_name not in nir_model.nodes.keys() + node = nir_model.nodes[node_name] + output_node = nir.Output(output_type=node.output_type) + nir_model.nodes[output_name] = output_node + nir_model.edges.append((node_name, output_name)) + + +def recurse_layer_shapes(name, node, nir_model): + targets = get_outgoing_nodes(name, nir_model) + print(name, type(node), node.input_type, node.output_type, "->", [name for name, _ in targets]) + for name, node in targets: + recurse_layer_shapes(name, node, nir_model) + + +def model_summary(nir_model): + inputs = [(name, n) for name, n in nir_model.nodes.items() if isinstance(n, nir.Input)] + assert len(inputs) == 1 + inp = inputs[0] + recurse_layer_shapes(inp[0], inp[1], nir_model) + + +def replace_sumpool2d_by_sumpool2d_if(nir_model): + nodes = nir_model.nodes + edges = nir_model.edges + edges_to_remove = [] + edges_to_add = [] + nodes_to_add = {} + for name, node in nodes.items(): + if isinstance(node, nir.SumPool2d): + outgoing_nodes = get_outgoing_nodes(name, nir_model) + edges = get_outgoing_edges(name, nir_model) + old_edges = [] + for edge_idx, edge in enumerate(edges): + if not isinstance(edge, (nir.LIF, nir.IF)): + old_edges.append(edge) + print("removing edge ", edge) + + if len(old_edges) > 0: + shape = node.output_type["output"] + # new_if_node = nir.IF(r=np.ones(shape), v_threshold=np.ones(node.output_type["output"])) + new_if_node = nir.IF(r=np.ones(shape), v_threshold=np.ones(shape)) + new_if_name = f"{name}_if" + nodes_to_add[new_if_name] = new_if_node + edges_to_add.append((name, new_if_name)) + for edge in old_edges: + print("adding edge", (new_if_name, edge[1])) + edges_to_add.append((new_if_name, edge[1])) + edges_to_remove.extend(old_edges) + + nir_model.nodes.update(nodes_to_add) + for edge in edges_to_remove: + print("really removing edge", edge) + nir_model.edges.remove(edge) + for edge in edges_to_add: + print("really adding edge", edge) + nir_model.edges.append(edge) + return nir_model + + +def replace_avgpool2d_by_avgpool2d_if(nir_model): + nodes = nir_model.nodes + edges = nir_model.edges + edges_to_remove = [] + edges_to_add = [] + nodes_to_add = {} + for name, node in nodes.items(): + if isinstance(node, nir.AvgPool2d): + outgoing_nodes = get_outgoing_nodes(name, nir_model) + edges = get_outgoing_edges(name, nir_model) + old_edges = [] + for edge_idx, edge in enumerate(edges): + if not isinstance(edge, (nir.LIF, nir.IF)): + old_edges.append(edge) + print("removing edge ", edge) + + if len(old_edges) > 0: + shape = node.output_type["output"] + new_if_node = nir.IF(r=np.ones(shape), v_threshold=4 * np.ones(shape)) + new_if_name = f"{name}_if" + nodes_to_add[new_if_name] = new_if_node + edges_to_add.append((name, new_if_name)) + for edge in old_edges: + print("adding edge", (new_if_name, edge[1])) + edges_to_add.append((new_if_name, edge[1])) + edges_to_remove.extend(old_edges) + + nir_model.nodes.update(nodes_to_add) + for edge in edges_to_remove: + print("really removing edge", edge) + nir_model.edges.remove(edge) + for edge in edges_to_add: + print("really adding edge", edge) + nir_model.edges.append(edge) + return nir_model + + +def get_outgoing_edges(node_name, nir_model): + outgoing_edges = [] + for edge in nir_model.edges: + if edge[0] == node_name: + outgoing_edges.append(edge) + return outgoing_edges + + +def get_incoming_nodes(node_name, nir_model): + incoming_nodes = [] + for edge in nir_model.edges: + # print("edge:", edge, " current node:", node_name) + if edge[1] == node_name: + incoming_nodes.append((edge[0], nir_model.nodes[edge[0]])) + return incoming_nodes + + +def get_outgoing_nodes(node_name, nir_model): + """get all outgoing connected nodes of a node.""" + outgoing_nodes = [] + for edge in nir_model.edges: + if edge[0] == node_name: + outgoing_nodes.append((edge[1], nir_model.nodes[edge[1]])) + return outgoing_nodes + + +def get_connected_nodes(node_name, nir_model): + connected_nodes = [] + for edge in nir_model.edges: + if edge[0] == node_name: + connected_nodes.append((edge[1], nir_model.nodes[edge[1]])) + if edge[1] == node_name: + connected_nodes.append((edge[0], nir_model.nodes[edge[0]])) + return connected_nodes + + +def fetch_population_by_name(name: str, populations: list): + for pop in populations: + if pop.name == name: + return pop + raise (NotFoundError(f"Population {name} could not be found!")) + +#for this error: ValueError: non-broadcastable output operand with shape () doesn't match the broadcast shape (11,) +#the shape of the bias array of the LIF node is not compatible with the shape of the bias array of the incomming node -> sol1: tInitialize the bias array with the shape of the incoming bias. +def get_accumulated_bias(node_name: str, nir_model): + """get the accumulated bias for all units in a node. + + This function is applied to neuron nodes such as `nir.LIF` or `nir.CubaLIF` + and looks for incoming nodes with bias (currently: `nir.Affine`, + `nir.Conv1d` and `nir.Conv2d`). From theses incoming nodes the biases are + accumulated and returned. + + Args: + node_name: name of node + nir_model: NIR graph + + Returns: + np.ndarray: Array with accumulated bias with the same shape as the node + """ + node = nir_model.nodes[node_name] + size_bias = node.input_type["input"] + bias = np.zeros(size_bias) + for _, in_node in get_incoming_nodes(node_name, nir_model): + if isinstance(in_node, (nir.Affine, nir.Conv1d, nir.Conv2d)): + # Initialize bias array with the shape of the incoming bias + bias = np.zeros_like(in_node.bias) + bias += in_node.bias + return bias.flatten() + + + +#def get_accumulated_bias(node_name: str, nir_model): + # """get the accumulated bias for all units in a node. + + # This function is applied to neuron nodes such as `nir.LIF` or `nir.CubaLIF` + # and looks for incoming nodes with bias (currently: `nir.Affine`, + # `nir.Conv1d` and `nir.Conv2d`). From theses incoming nodes the biases are + # accumulated and returned. + + # Args: + # node_name: name of node + # nir_model: NIR graph + + # Returns: + # np.ndarray: Array with accumulated bias with the same shape as the node + # """ + # node = nir_model.nodes[node_name] + # print("this is the node:", node) + # size_bias = node.input_type["input"] + # print("size of the bias for LIF node:", size_bias) + # bias = np.zeros(node.input_type["input"]) + # print("this is it's bias:", bias) + # for _, in_node in get_incoming_nodes(node_name, nir_model): + # if isinstance(in_node, (nir.Affine, nir.Conv1d, nir.Conv2d)): + # print("this is the incoming node:", in_node) + # print("this is the bias of the incoming node:", in_node.bias ) + # bias += in_node.bias + # return bias.flatten() + + +def get_max_abs_weight(node_name: str, nir_model): + """get the maximum absolute weight from all synapses + + This function is applied to neuron nodes such as `nir.LIF` or `nir.CubaLIF` + and looks for incoming nodes with weights (currently: `nir.Affine`, + `nir.Linear`, `nir.Conv1d` and `nir.Conv2d`). From theses incoming nodes + the maximum weight is determined and returned. + + Args: + node_name: name of node + nir_model: NIR graph + + Returns: + float: maximum absolute weight + """ + max_weight = 0.0 + for _, in_node in get_incoming_nodes(node_name, nir_model): + if isinstance(in_node, WEIGHT_NODES): + max_weight_node = np.abs(in_node.weight).max() + max_weight = np.max((max_weight, max_weight_node)) + return max_weight + + +def convert_LIF(node: nir.NIRNode, bias: np.ndarray, config: ConversionConfig, w_scale: float): # noqa: N802 + """convert the LIF parameters to SpiNNaker2. + + Args: + node: NIR node + bias: bias array with same shape as node + config: NIR-to-SpiNNaker2-conversion configuration + w_scale: factor by which weights are scaled. Will be applied to + parameters `threshold` and`i_offset`. + + Returns: + tuple: (neuron_params, v_scale) + + `neuron_params` contains the parameters for the spinnaker2 + `lif_no_delay` population, while `v_scale` is the factor by which the + threshold was scaled during translation from NIR to SpiNNaker2. + """ + assert isinstance(node, nir.LIF) + dt = config.dt + tau = node.tau.flatten() + + if config.integrator == IntegratorMethod.FORWARD: + r_factor = (dt / tau) * node.r.flatten() + v_leak_factor = dt / tau + alpha_decay = 1 - dt / node.tau + elif config.integrator == IntegratorMethod.EXPONENTIAL: + r_factor = (1 - np.exp(-dt / tau)) * node.r.flatten() + v_leak_factor = 1 - np.exp(-dt / tau) + alpha_decay = np.exp(-dt / tau) + else: + raise Exception("Unsupported IntegratorMethod") + + v_scale = 1.0 / r_factor # scaling factor from tau_mem+r circuit + scale = v_scale * w_scale # overall scaling factor applied to membrane voltage + + neuron_params = { + "threshold": node.v_threshold.flatten() * scale, + "alpha_decay": alpha_decay, + "i_offset": v_leak_factor * node.v_leak.flatten() * scale + bias.flatten() * w_scale, + "reset": get_s2_reset_method(config.reset), + } + return neuron_params, scale + + +def convert_CubaLIF(node: nir.CubaLIF, bias: np.ndarray, config: ConversionConfig, w_scale: float): # noqa N802 + """convert the CubaLIF parameters to SpiNNaker2. + + Args: + node: NIR node + bias: bias array with same shape as node + config: NIR-to-SpiNNaker2-conversion configuration + w_scale: factor by which weights are scaled. Will be applied to + parameters `threshold` and`i_offset`. + + Returns: + tuple: (neuron_params, v_scale) + + `neuron_params` contains the parameters for the spinnaker2 + `lif_curr_exp_no_delay` population, while `v_scale` is the factor by + which the threshold was scaled during translation from NIR to + SpiNNaker2. + """ + assert isinstance(node, nir.CubaLIF) + dt = config.dt + + if config.integrator == IntegratorMethod.FORWARD: + r_factor = (dt / node.tau_mem) * node.r + w_in_factor = (dt / node.tau_syn) * node.w_in + v_leak_factor = dt / node.tau_mem + syn_decay = 1 - dt / node.tau_syn + alpha_decay = 1 - dt / node.tau_mem + elif config.integrator == IntegratorMethod.EXPONENTIAL: + r_factor = (1 - np.exp(-dt / node.tau_mem)) * node.r + w_in_factor = (1 - np.exp(-dt / node.tau_syn)) * node.w_in + v_leak_factor = 1 - np.exp(-dt / node.tau_mem) + syn_decay = np.exp(-dt / node.tau_syn) + alpha_decay = np.exp(-dt / node.tau_mem) + else: + raise Exception("Unsupported IntegratorMethod") + + v_scale = 1.0 / r_factor # scaling factor from tau_mem+r circuit + I_scale = 1.0 / w_in_factor # scaling factor from the input current circuit + scale = v_scale * I_scale * w_scale # overall scaling factor applied to membrane voltage + + neuron_params = { + "threshold": node.v_threshold * scale, + "alpha_decay": alpha_decay, + "exc_decay": syn_decay, + "inh_decay": syn_decay, + "i_offset": v_leak_factor * node.v_leak * scale + bias * node.w_in * I_scale * w_scale, + "reset": get_s2_reset_method(config.reset), + "t_refrac": 0, + "v_reset": 0.0, # will be ignored for `reset_by_subtraction` + } + return neuron_params, scale + + +def convert_IF(node: nir.NIRNode, bias: np.ndarray, config: ConversionConfig, w_scale: float): # noqa: N802 + """convert the IF parameters to SpiNNaker2. + + Args: + node: NIR node + bias: bias array with same shape as node + config: NIR-to-SpiNNaker2-conversion configuration + w_scale: factor by which weights are scaled. Will be applied to + parameters `threshold` and`i_offset`. + + Returns: + tuple: (neuron_params, v_scale) + + `neuron_params` contains the parameters for the spinnaker2 + `lif_no_delay` population, while `v_scale` is the factor by which the + threshold was scaled during translation from NIR to SpiNNaker2. + """ + assert isinstance(node, nir.IF) + bias_factor = node.r.flatten() + v_scale = 1.0 / bias_factor * w_scale + + neuron_params = { + "threshold": node.v_threshold.flatten() * v_scale, + "alpha_decay": 1.0, + "i_offset": bias * w_scale, + "reset": get_s2_reset_method(config.reset), + } + return neuron_params, v_scale + + +def create_populations(nir_model, config: ConversionConfig): + populations = [] + input_populations = [] + output_populations = [] + for name, node in nir_model.nodes.items(): + print(f"Node '{name}'") + print(f"Got {type(node)}") + if isinstance(node, (nir.LIF, nir.IF, nir.CubaLIF)): + bias = get_accumulated_bias(name, nir_model) + print("this is the accumulated bias:", bias) + w_scale = 1.0 + if config.scale_weights: + max_abs_weight = get_max_abs_weight(name, nir_model) + w_scale = 127.0 / max_abs_weight # TODO: 127 should not be hard-coded + + if any([isinstance(pre_node, nir.Conv2d) for _, pre_node in get_incoming_nodes(name, nir_model)]): + is_conv2d = True + else: + is_conv2d = False + + if isinstance(node, nir.LIF): + neuron_model_name = "lif_conv2d" if is_conv2d else "lif_no_delay" + neuron_params, v_scale = convert_LIF(node, bias, config, w_scale) + elif isinstance(node, nir.IF): + neuron_model_name = "lif_conv2d" if is_conv2d else "lif_no_delay" + neuron_params, v_scale = convert_IF(node, bias, config, w_scale) + else: # CubaLIF + print("Got CubaLIF!") + assert is_conv2d == False, "Conv2d with CubaLIF is currently not supported!" + neuron_model_name = "lif_curr_exp_no_delay" + neuron_params, v_scale = convert_CubaLIF(node, bias, config, w_scale) + + if any([isinstance(out_node, nir.Output) for _, out_node in get_outgoing_nodes(name, nir_model)]): + record = config.output_record + print(record) + has_output = True + else: + record = [] + has_output = False + + print("record: ", record) + input_shape = node.input_type.get("input", []) # Get input shape or empty list + if not input_shape: + # Handle case where input shape is empty + input_shape = [250] # Set input shape to 250 + + print("for node:", node, "this is the input_type:", input_shape) + print("this is the input shape:", input_shape) + print("and this is the size:", np.prod(input_shape)) + print(input_shape, "->", node.output_type.get("output")) + + # Create the population with the adjusted input shape + pop = snn.Population( + size=np.prod(input_shape), # Calculate size based on input shape + neuron_model=neuron_model_name, + params=neuron_params, + name=name, + record=record, + ) + print("this is the created population:", pop) + pop.nir_v_scale = v_scale # save scaling factors to rescale recorded voltages + pop.nir_w_scale = w_scale # weight scale needed later for creation of projections + populations.append(pop) + if has_output: + output_populations.append(pop) + + elif isinstance(node, nir.Input): + # infer shape + input_shape = node.input_type["input"] + print("this is the input shape of your model:", input_shape) + #assert input_shape.size == 1 or input_shape.size == 3, "only 1d or 3d input allowed" + pop = snn.Population(size=np.prod(input_shape), neuron_model="spike_list", params={}, name=name) + populations.append(pop) + input_populations.append(pop) + elif isinstance(node, nir.Output): + pass + elif isinstance(node, WEIGHT_NODES): + # TODO: check for any other combinations? + for _, target_node in get_connected_nodes(name, nir_model): + if isinstance(target_node, WEIGHT_NODES): + raise HardwareConstraintError( + f"Two successive layers with weights are not " + f"supported! Got: {type(node)} and {type(target_node)}" + ) + elif isinstance(node, nir.SumPool2d): + pass + elif isinstance(node, nir.AvgPool2d): + pass + elif isinstance(node, nir.Flatten): + pass + else: + print("Got", type(node)) + raise NotImplementedError(type(node)) + print("input: ", [(e, type(n)) for e, n in get_incoming_nodes(name, nir_model)]) + print("output: ", [(e, type(n)) for e, n in get_outgoing_nodes(name, nir_model)]) + print("") + return populations, input_populations, output_populations + + +def create_projection(origin_name, target_name, affine_node, populations, delay): + assert isinstance(affine_node, (nir.Affine, nir.Linear)) + pre = fetch_population_by_name(origin_name, populations) + post = fetch_population_by_name(target_name, populations) + conns = ann2snn_helpers.connection_list_from_dense_weights(affine_node.weight.T * post.nir_w_scale, delay) + proj = snn.Projection(pre=pre, post=post, connections=conns) + return proj + + +def get_conv2d_params(conv2d_node, post_node): + input_shape = conv2d_node.input_type["input"] + conv_weights = conv2d_node.weight.swapaxes(0, 1) * post_node.nir_w_scale + conv_weights = conv_weights.astype(np.int8) + conv_params = { + "in_height": input_shape[1], + "in_width": input_shape[2], + "stride_x": nir.ir._index_tuple(conv2d_node.stride, 0), + "stride_y": nir.ir._index_tuple(conv2d_node.stride, 1), + "pool_x": 1, + "pool_y": 1, + "pad_top": nir.ir._index_tuple(conv2d_node.padding, 0), + "pad_bottom": nir.ir._index_tuple(conv2d_node.padding, 0), + "pad_left": nir.ir._index_tuple(conv2d_node.padding, 1), + "pad_right": nir.ir._index_tuple(conv2d_node.padding, 1), + } + return conv_params, conv_weights + + +def create_conv2d_projection(origin_name, target_name, conv2d_node, populations, delay): + assert delay == 0, f"Conv2DProjection do not support a delay differnt from 1" + pre = fetch_population_by_name(origin_name, populations) + post = fetch_population_by_name(target_name, populations) + conv_params, conv_weights = get_conv2d_params(conv2d_node, post) + print(conv_params, "weights: ", conv_weights.shape) + + proj = snn.Conv2DProjection(pre, post, conv_weights, conv_params) + + return proj + + +def create_sumpool2d_projection(origin_name, target_name, sumpool2d_node, populations, delay): + # this might be a bit hacky, but should work... + pre = fetch_population_by_name(origin_name, populations) + post = fetch_population_by_name(target_name, populations) + # C_out, C_in, H, W + sumpool2d_conns, sumpool2d_output_shape = ann2snn_helpers.connection_list_for_sumpool2d( + input_shape=sumpool2d_node.input_type["input"], + stride=sumpool2d_node.stride, + kernel_size=sumpool2d_node.kernel_size, + padding=sumpool2d_node.padding, + delay=delay, + data_order="torch", + ) + proj = snn.Projection(pre=pre, post=post, connections=sumpool2d_conns) + return proj + + +def create_avgpool2d_projection(origin_name, target_name, avgpool2d_node, populations, delay): + # this might be a bit hacky, but should work... + pre = fetch_population_by_name(origin_name, populations) + post = fetch_population_by_name(target_name, populations) + # C_out, C_in, H, W + avgpool2d_conns, avgpool2d_output_shape = ann2snn_helpers.connection_list_for_avgpool2d( + input_shape=avgpool2d_node.input_type["input"], + stride=avgpool2d_node.stride, + kernel_size=avgpool2d_node.kernel_size, + padding=avgpool2d_node.padding, + delay=delay, + data_order="torch", + ) + + proj = snn.Projection(pre=pre, post=post, connections=avgpool2d_conns) + return proj + + +def create_sumpool2d_conv2d_projection(origin_name, target_name, sumpool2d_node, conv2d_node, populations, delay): + assert delay == 0, f"Conv2DProjection do not support a delay different from 1" + # this might be a bit hacky, but should work... + pre = fetch_population_by_name(origin_name, populations) + post = fetch_population_by_name(target_name, populations) + + conv_params, conv_weights = get_conv2d_params(conv2d_node, post) + print(conv_params, "weights: ", conv_weights.shape) + print(sumpool2d_node) + assert all(sumpool2d_node.kernel_size == sumpool2d_node.stride) + assert nir.ir._index_tuple(sumpool2d_node.padding, 0) == 0 + assert nir.ir._index_tuple(sumpool2d_node.padding, 1) == 0 + conv_params["pool_x"] = nir.ir._index_tuple(sumpool2d_node.kernel_size, 0) + conv_params["pool_y"] = nir.ir._index_tuple(sumpool2d_node.kernel_size, 1) + # fix input sizes (we need to take the ones from the sumpool2d_node) + input_shape = sumpool2d_node.input_type["input"] + conv_params["in_height"] = input_shape[1] + conv_params["in_width"] = input_shape[2] + + proj = snn.Conv2DProjection(pre, post, conv_weights, conv_params) + + return proj + + +def create_avgpool2d_conv2d_projection( + origin_name, target_name, avgpool2d_node, conv2d_node, populations, delay, config +): + assert delay == 0, f"Conv2DProjection do not support a delay different from 1" + # this might be a bit hacky, but should work... + pre = fetch_population_by_name(origin_name, populations) + post = fetch_population_by_name(target_name, populations) + + conv_params, conv_weights = get_conv2d_params(conv2d_node, post) + print(conv_params, "weights: ", conv_weights.shape) + print(avgpool2d_node) + # Ensure average pooling kernel size and stride are the same + assert all(avgpool2d_node.kernel_size == avgpool2d_node.stride) + # Calculate the scaling factor for weights + kernel_size = avgpool2d_node.kernel_size[0] * avgpool2d_node.kernel_size[1] + scale_factor = 1 / kernel_size + # Scale the convolution weights by the factor of 4 (average pooling) and the scale factor + conv_weights *= scale_factor / 4 + # Check if weight scaling is enabled in the config + if config.scale_weights: + max_abs_weight = np.max(np.abs(conv_weights)) + w_scale = 127.0 / max_abs_weight + conv_weights *= w_scale + # Set convolution parameters + conv_params["pool_x"] = nir.ir._index_tuple(avgpool2d_node.kernel_size, 0) + conv_params["pool_y"] = nir.ir._index_tuple(avgpool2d_node.kernel_size, 1) + # Set input sizes from the avgpool2d_node + input_shape = avgpool2d_node.input_type["input"] + conv_params["in_height"] = input_shape[1] + conv_params["in_width"] = input_shape[2] + + proj = snn.Conv2DProjection(pre, post, conv_weights, conv_params) + + return proj + + +def create_sumpool2d_affine_projection(origin_name, target_name, sumpool2d_node, affine_node, populations, delay): + # this might be a bit hacky, but should work... + pre = fetch_population_by_name(origin_name, populations) + post = fetch_population_by_name(target_name, populations) + # C_out, C_in, H, W + sumpool2d_conns, sumpool2d_output_shape = ann2snn_helpers.connection_list_for_sumpool2d( + input_shape=sumpool2d_node.input_type["input"], + stride=sumpool2d_node.stride, + kernel_size=sumpool2d_node.kernel_size, + padding=sumpool2d_node.padding, + delay=delay, + data_order="torch", + ) + affine_conns = ann2snn_helpers.connection_list_from_dense_weights( + affine_node.weight.T * post.nir_w_scale, delay=delay + ) + conns = ann2snn_helpers.join_conn_lists(sumpool2d_conns, affine_conns) + + proj = snn.Projection(pre=pre, post=post, connections=conns) + return proj + + +def create_avgpool2d_affine_projection(origin_name, target_name, avgpool2d_node, affine_node, populations, delay): + # this might be a bit hacky, but should work... + pre = fetch_population_by_name(origin_name, populations) + post = fetch_population_by_name(target_name, populations) + # C_out, C_in, H, W + avgpool2d_conns, avgpool2d_output_shape = ann2snn_helpers.connection_list_for_avgpool2d( + input_shape=avgpool2d_node.input_type["input"], + stride=avgpool2d_node.stride, + kernel_size=avgpool2d_node.kernel_size, + padding=avgpool2d_node.padding, + delay=delay, + data_order="torch", + ) + affine_conns = ann2snn_helpers.connection_list_from_dense_weights( + affine_node.weight.T * post.nir_w_scale, delay=delay + ) + conns = ann2snn_helpers.join_conn_lists(avgpool2d_conns, affine_conns) + + proj = snn.Projection(pre=pre, post=post, connections=conns) + return proj + + +def create_projections(nir_model, populations, delay=1): + logger.debug("creating projections(): start") + projections = [] + edges = nir_model.edges + + for edge in edges: + logger.debug(f"checking {edge[0]}->{edge[1]}") + origin = nir_model.nodes[edge[0]] + if isinstance(origin, SPIKE_SOURCE_NODES): + target_name = edge[1] + target = nir_model.nodes[target_name] + if isinstance(target, (nir.Affine, nir.Linear)): + logger.debug(f" found {type(target)}, next search for neuron node") + final_targets = get_outgoing_nodes(target_name, nir_model) + for final_target_name, final_target in final_targets: + logger.debug(f" found final target {final_target_name}") + assert isinstance(final_target, NEURON_NODES) + logger.debug(f" create projection between {edge[0]} and {final_target_name}") + proj = create_projection(edge[0], final_target_name, target, populations, delay) + projections.append(proj) + elif isinstance(target, nir.SumPool2d): + logger.debug(" found SumPool2d, next search for path to neurons") + sumpooltargets = get_outgoing_nodes(target_name, nir_model) + for sumpooltarget_name, sumpooltarget in sumpooltargets: + if isinstance(sumpooltarget, nir.Conv2d): + logger.debug(f" found Conv2d: {sumpooltarget_name}") + final_targets = get_outgoing_nodes(sumpooltarget_name, nir_model) + for final_target_name, final_target in final_targets: + logger.debug(f" found final target {final_target_name}") + assert isinstance(final_target, NEURON_NODES) + logger.debug( + f" create sumpool2d_convd_projection between {edge[0]} and {final_target_name}" + ) + proj = create_sumpool2d_conv2d_projection( + edge[0], final_target_name, target, sumpooltarget, populations, delay + ) + projections.append(proj) + elif isinstance(sumpooltarget, NEURON_NODES): + logger.debug(f" found Neuron: {sumpooltarget_name}") + final_targets = get_outgoing_nodes(sumpooltarget_name, nir_model) + nodetype = "IF" if isinstance(sumpooltarget, nir.IF) else "LIF" + logger.debug(f" create sumpool2d_projection between {edge[0]} and {sumpooltarget_name}") + proj = create_sumpool2d_projection(edge[0], sumpooltarget_name, target, populations, delay) + projections.append(proj) + elif isinstance(sumpooltarget, nir.Flatten): + logger.debug(f" found Flatten: {sumpooltarget_name}") + flatten_targets = get_outgoing_nodes(sumpooltarget_name, nir_model) + for flatten_target_name, flatten_target in flatten_targets: + if isinstance(flatten_target, nir.Affine): + final_targets = get_outgoing_nodes(flatten_target_name, nir_model) + for final_target_name, final_target in final_targets: + logger.debug(f" found final target {final_target_name}") + assert isinstance(final_target, NEURON_NODES) + logger.debug( + f" create sumpool2d_affine_projection between {edge[0]} and " + f"{final_target_name}" + ) + proj = create_sumpool2d_affine_projection( + edge[0], + final_target_name, + target, + flatten_target, + populations, + delay, + ) + projections.append(proj) + else: + raise ( + NotImplementedError( + "Currently after SumPool2d->Flatten, Affine has to follow! Other combinations " + "not supported yet!" + ) + ) + else: + raise ( + NotImplementedError( + "Currently SumPool2d can only be connected to Conv2d or Flatten, others are not " + "supported yet!" + ) + ) + + elif isinstance(target, nir.AvgPool2d): + logger.debug(" found AvgPool2d, next search for path to neurons") + avgpooltargets = get_outgoing_nodes(target_name, nir_model) + for avgpooltarget_name, avgpooltarget in avgpooltargets: + if isinstance(avgpooltarget, nir.Conv2d): + logger.debug(f" found Conv2d: {avgpooltarget_name}") + final_targets = get_outgoing_nodes(avgpooltarget_name, nir_model) + for final_target_name, final_target in final_targets: + logger.debug(f" found final target {final_target_name}") + assert isinstance(final_target, NEURON_NODES) + logger.debug( + f" create avgpool2d_convd_projection between {edge[0]} and {final_target_name}" + ) + proj = create_avgpool2d_conv2d_projection( + edge[0], final_target_name, target, avgpooltarget, populations, delay + ) + projections.append(proj) + elif isinstance(avgpooltarget, NEURON_NODES): + logger.debug(f" found Neuron: {avgpooltarget_name}") + final_targets = get_outgoing_nodes(avgpooltarget_name, nir_model) + nodetype = "IF" if isinstance(avgpooltarget, nir.IF) else "LIF" + logger.debug(f" create avgpool2d_projection between {edge[0]} and {avgpooltarget_name}") + proj = create_avgpool2d_projection(edge[0], avgpooltarget_name, target, populations, delay) + projections.append(proj) + elif isinstance(avgpooltarget, nir.Flatten): + logger.debug(f" found Flatten: {avgpooltarget_name}") + flatten_targets = get_outgoing_nodes(avgpooltarget_name, nir_model) + for flatten_target_name, flatten_target in flatten_targets: + if isinstance(flatten_target, nir.Affine): + final_targets = get_outgoing_nodes(flatten_target_name, nir_model) + for final_target_name, final_target in final_targets: + logger.debug(f" found final target {final_target_name}") + assert isinstance(final_target, NEURON_NODES) + logger.debug( + f" create avgpool2d_affine_projection between {edge[0]} and " + f"{final_target_name}" + ) + proj = create_avgpool2d_affine_projection( + edge[0], + final_target_name, + target, + flatten_target, + populations, + delay, + ) + projections.append(proj) + else: + raise ( + NotImplementedError( + "Currently after AvgPool2d->Flatten, Affine has to follow! Other combinations " + "not supported yet!" + ) + ) + else: + raise ( + NotImplementedError( + "Currently AvgPool2d can only be connected to Conv2d or Flatten, others are not " + "supported yet!" + ) + ) + + elif isinstance(target, nir.Conv2d): + logger.debug(" found Conv2d, next search for path to neurons") + final_targets = get_outgoing_nodes(target_name, nir_model) + for final_target_name, final_target in final_targets: + logger.debug(f" found final target {final_target_name}") + print("NEURON_NODES are: (nir.LIF, nir.IF, nir.CubaLIF)") + print("this is the final_target:", final_target) + assert isinstance(final_target, NEURON_NODES) + logger.debug(f" create conv2d_projection between {edge[0]} and {final_target_name}") + proj = create_conv2d_projection(edge[0], final_target_name, target, populations, delay) + projections.append(proj) + elif isinstance(target, NEURON_NODES): + raise ( + NotImplementedError( + f"Direct connections from spike source nodes ({SPIKE_SOURCE_NODES}) to neuron nodes " + f"({NEURON_NODES}) not supported yet!" + ) + ) + else: + logger.debug(" discard edge as target does not represent a real connection.") + else: + logger.debug(" discard edge as source is not spike source node") + return projections + + +def from_nir(nir_model: nir.NIRGraph, config: ConversionConfig = None): + """create SpiNNaker2 network from NIR graph. + + Args: + nir_model: NIR graph + config: NIR-to-SpiNNaker2-conversion configuration + Returns: + tuple of length 3: (net, input_pops, output_pops) + + Details: + net(snn.Network): SpiNNaker2 Network + input_pops(list[snn.Population]): list of input populations + output_pops(list[snn.Population]): list of output populations + """ + assert isinstance(nir_model, nir.NIRGraph) + if config == None: + config = ConversionConfig() + logger.info("from_nir(): create spinnaker2.Network from NIR graph") + populations, input_populations, output_populations = create_populations(nir_model, config) + print("these are the populations:", populations) + # Print the names of the populations + print("Population Names:", [pop.name for pop in populations]) + logger.info(f"Created {len(populations)} populations: {[(_.name) for _ in populations]}") + projections = create_projections(nir_model, populations, delay=config.conn_delay) + logger.info(f"Created {len(projections)} projections: {[(_.name) for _ in projections]}") + + net = snn.Network() + net.add(*populations, *projections) + net.validate() + return net, input_populations, output_populations \ No newline at end of file diff --git a/examples/testconv2d+avgpool.nir b/examples/testconv2d+avgpool.nir new file mode 100644 index 00000000..aa7e25b5 Binary files /dev/null and b/examples/testconv2d+avgpool.nir differ diff --git a/examples/tutorial_sae.ipynb b/examples/tutorial_sae.ipynb old mode 100755 new mode 100644 diff --git a/examples/tutorial_snntorch_to_nir.ipynb b/examples/tutorial_snntorch_to_nir.ipynb new file mode 100644 index 00000000..bbcc383c --- /dev/null +++ b/examples/tutorial_snntorch_to_nir.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import snntorch as snn\n", + "import torch\n", + "import nir\n", + "from snntorch import export_to_nir, import_from_nir" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generating an example NIR Graph from snntorch model:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# 2D sample data \n", + "sample_data = torch.randn(1, 1, 28, 28) #input image size is 28x28 and 1 channel #the dimension here is batch_size*number_of_channels*height*width\n", + "\n", + "class NetWithAvgPool(torch.nn.Module):\n", + " def __init__(self):\n", + " super(NetWithAvgPool, self).__init__()\n", + " self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)\n", + " self.lif1 = snn.Leaky(beta=0.9, init_hidden=True)\n", + " self.fc1 = torch.nn.Linear(28*28*16 // 4, 500) # Adjusting input size after pooling layer\n", + " self.lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True)\n", + "\n", + " def forward(self, x):\n", + " x = torch.nn.functional.avg_pool2d(self.conv1(x), kernel_size=2, stride=2) # kernel_size=2 and stride=2 for avg_pool2d\n", + " x = x.view(-1, 28*28*16 // 4) # Adjusting the view based on the output shape after pooling\n", + " x = self.lif1(x)\n", + " x = self.fc1(x)\n", + " x = self.lif2(x)\n", + " return x\n", + "\n", + "net_with_avgpool = NetWithAvgPool()\n", + "nir_graphtest = export_to_nir(net_with_avgpool, sample_data, ignore_dims=[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nir_graphtest.infer_types()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nir.write(f\"./snnTorch/examples/testconv2d+avgpool.nir\", nir_graphtest)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "nir_graph = nir.read(\"testconv2d+avgpool.nir\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#dict_keys(['conv1', 'fc1', 'input', 'lif1', 'lif2', 'output'])\n", + "nir_graph.nodes['lif1']" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LIF(tau=0.0009999997615814777, r=9.999997615814776, v_leak=0.0, v_threshold=1.0, input_type={'input': array([], dtype=float64)}, output_type={'output': array([], dtype=float64)}, metadata={})" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nir_graph.nodes['lif1']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nir_graph.nodes.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nir_graph.edges" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pytest.txt b/pytest.txt new file mode 100644 index 00000000..0066f647 --- /dev/null +++ b/pytest.txt @@ -0,0 +1,30 @@ +============================= test session starts ============================== +platform linux -- Python 3.11.5, pytest-8.1.1, pluggy-1.4.0 +rootdir: /home/sirine/PHD-work/snntorch-nir/snntorch +configfile: setup.cfg +testpaths: tests +collected 119 items + +tests/test_nir.py ...... [ 5%] +tests/test_snntorch/functional/test_loss.py ..................... [ 22%] +tests/test_snntorch/test_alpha.py ....... [ 28%] +tests/test_snntorch/test_bntt.py ......... [ 36%] +tests/test_snntorch/test_graded_spikes.py ...... [ 41%] +tests/test_snntorch/test_lapicque.py ....... [ 47%] +tests/test_snntorch/test_leaky.py ........ [ 53%] +tests/test_snntorch/test_rleaky.py ....... [ 59%] +tests/test_snntorch/test_rsynaptic.py ....... [ 65%] +tests/test_snntorch/test_sconv2dlstm.py ...... [ 70%] +tests/test_snntorch/test_slstm.py ...... [ 75%] +tests/test_snntorch/test_synaptic.py ....... [ 81%] +tests/test_snntorch.py .. [ 83%] +tests/test_spikegen.py .................... [100%] + +=============================== warnings summary =============================== +tests/test_nir.py::TestNIR::test_import_conv_nir +tests/test_nir.py::TestNIR::test_import_conv_nir + /home/sirine/miniconda3/lib/python3.11/site-packages/nir/ir.py:73: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.) + shapes.append(int(shape)) + +-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html +======================= 119 passed, 2 warnings in 1.31s ======================== diff --git a/setup.py b/setup.py index 41acb7bb..49f1479f 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ setup( author="Jason K. Eshraghian", author_email="jeshragh@ucsc.edu", - python_requires=">=3.8", + python_requires=">=3.9", classifiers=[ "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", @@ -42,7 +42,6 @@ "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/snntorch/__init__.py b/snntorch/__init__.py index 0482a113..16ec19f5 100644 --- a/snntorch/__init__.py +++ b/snntorch/__init__.py @@ -2,4 +2,4 @@ from ._neurons import * from ._layers import * from .export_nir import export_to_nir -from .import_nir import import_from_nir \ No newline at end of file +from .import_nir import import_from_nir diff --git a/snntorch/_neurons/__init__.py b/snntorch/_neurons/__init__.py index c2ec2ed4..7e411cfe 100644 --- a/snntorch/_neurons/__init__.py +++ b/snntorch/_neurons/__init__.py @@ -33,4 +33,4 @@ from .sconv2dlstm import SConv2dLSTM from .slstm import SLSTM -from .leakyparallel import LeakyParallel \ No newline at end of file +from .leakyparallel import LeakyParallel diff --git a/snntorch/_neurons/leakykernel.py b/snntorch/_neurons/leakykernel.py index 59ef4df8..36325da8 100644 --- a/snntorch/_neurons/leakykernel.py +++ b/snntorch/_neurons/leakykernel.py @@ -1,10 +1,11 @@ import torch import torch.nn as nn + class LeakyKernel(nn.Module): """ A parallel implementation of the Leaky neuron with a fused input linear layer. - All time steps are passed to the input at once. + All time steps are passed to the input at once. This implementation uses `torch.nn.RNN` to accelerate the implementation. First-order leaky integrate-and-fire neuron model. @@ -23,7 +24,7 @@ class LeakyKernel(nn.Module): * :math:`β` - Membrane potential decay rate Several differences between `LeakyParallel` and `Leaky` include: - + * Negative hidden states are clipped due to the forced ReLU operation in RNN * Linear weights are included in addition to recurrent weights * `beta` is clipped between [0,1] and cloned to `weight_hh_l` only upon layer initialization. It is unused otherwise @@ -57,9 +58,9 @@ def __init__(self): def forward(self, x): spk1 = self.lif1(x) spk2 = self.lif2(spk1) - return spk2 + return spk2 + - :param input_size: The number of expected features in the input `x` :type input_size: int @@ -100,25 +101,25 @@ def forward(self, x): to False :type learn_threshold: bool, optional - :param weight_hh_enable: Option to set the hidden matrix to be dense or - diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works. - Dense (True) would allow the membrane potential of one LIF neuron to + :param weight_hh_enable: Option to set the hidden matrix to be dense or + diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works. + Dense (True) would allow the membrane potential of one LIF neuron to influence all others, and follow the RNN default implementation. Defaults to False :type weight_hh_enable: bool, optional Inputs: \\input_ - - **input_** of shape of shape `(L, H_{in})` for unbatched input, - or `(L, N, H_{in})` containing the features of the input sequence. + - **input_** of shape of shape `(L, H_{in})` for unbatched input, + or `(L, N, H_{in})` containing the features of the input sequence. Outputs: spk - **spk** of shape `(L, batch, input_size)`: tensor containing the output spikes. - + where: `L = sequence length` - + `N = batch size` `H_{in} = input_size` @@ -141,8 +142,6 @@ def __init__( input_size, hidden_size, beta=None, - - bias=True, threshold=1.0, dropout=0.0, @@ -161,19 +160,28 @@ def __init__( # take in a data sample of size T x B x Dims # compile a 1-D Conv Kernel equivalent to the size of the full time-sweep - - # - - self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu', - bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype) - + + # + + self.rnn = nn.RNN( + input_size, + hidden_size, + num_layers=1, + nonlinearity="relu", + bias=bias, + batch_first=False, + dropout=dropout, + device=device, + dtype=dtype, + ) + self._beta_buffer(beta, learn_beta) self.hidden_size = hidden_size if self.beta is not None: self.beta = self.beta.clamp(0, 1) - if spike_grad is None: + if spike_grad is None: self.spike_grad = self.ATan.apply else: self.spike_grad = spike_grad @@ -185,7 +193,7 @@ def __init__( # Register a gradient hook to clamp out non-diagonal matrices in backward pass if learn_beta: self.rnn.weight_hh_l0.register_hook(self.grad_hook) - + if not learn_beta: # Make the weights non-learnable self.rnn.weight_hh_l0.requires_grad_(False) @@ -202,11 +210,11 @@ def __init__( def forward(self, input_): mem = self.rnn(input_) # mem[0] contains relu'd outputs, mem[1] contains final hidden state - mem_shift = mem[0] - self.threshold # self.rnn.weight_hh_l0 + mem_shift = mem[0] - self.threshold # self.rnn.weight_hh_l0 spk = self.spike_grad(mem_shift) spk = spk * self.graded_spikes_factor return spk - + @staticmethod def _surrogate_bypass(input_): return (input_ > 0).float() @@ -261,11 +269,11 @@ def backward(ctx, grad_output): * grad_input ) return grad, None - + def weight_hh_enable(self): mask = torch.eye(self.hidden_size, self.hidden_size) self.rnn.weight_hh_l0.data = self.rnn.weight_hh_l0.data * mask - + def grad_hook(self, grad): device = grad.device # Create a mask that is 1 on the diagonal and 0 elsewhere @@ -279,7 +287,9 @@ def _beta_to_weight_hh(self): # Set all weights to the scalar value of self.beta if isinstance(self.beta, float) or isinstance(self.beta, int): self.rnn.weight_hh_l0.fill_(self.beta) - elif isinstance(self.beta, torch.Tensor) or isinstance(self.beta, torch.FloatTensor): + elif isinstance(self.beta, torch.Tensor) or isinstance( + self.beta, torch.FloatTensor + ): if len(self.beta) == 1: self.rnn.weight_hh_l0.fill_(self.beta[0]) elif len(self.beta) == self.hidden_size: @@ -287,8 +297,10 @@ def _beta_to_weight_hh(self): for i in range(self.hidden_size): self.rnn.weight_hh_l0.data[i].fill_(self.beta[i]) else: - raise ValueError("Beta must be either a single value or of length 'hidden_size'.") - + raise ValueError( + "Beta must be either a single value or of length 'hidden_size'." + ) + def _beta_buffer(self, beta, learn_beta): if not isinstance(beta, torch.Tensor): if beta is not None: @@ -296,7 +308,7 @@ def _beta_buffer(self, beta, learn_beta): self.register_buffer("beta", beta) def _graded_spikes_buffer( - self, graded_spikes_factor, learn_graded_spikes_factor + self, graded_spikes_factor, learn_graded_spikes_factor ): if not isinstance(graded_spikes_factor, torch.Tensor): graded_spikes_factor = torch.as_tensor(graded_spikes_factor) @@ -311,4 +323,4 @@ def _threshold_buffer(self, threshold, learn_threshold): if learn_threshold: self.threshold = nn.Parameter(threshold) else: - self.register_buffer("threshold", threshold) \ No newline at end of file + self.register_buffer("threshold", threshold) diff --git a/snntorch/_neurons/leakyparallel.py b/snntorch/_neurons/leakyparallel.py index 531671ee..a06bfeb8 100644 --- a/snntorch/_neurons/leakyparallel.py +++ b/snntorch/_neurons/leakyparallel.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn + class LeakyParallel(nn.Module): """ A parallel implementation of the Leaky neuron with a fused input linear layer. @@ -172,10 +173,19 @@ def __init__( dtype=None, ): super().__init__() - - self.rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='relu', - bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype) - + + self.rnn = nn.RNN( + input_size, + hidden_size, + num_layers=1, + nonlinearity="relu", + bias=bias, + batch_first=False, + dropout=dropout, + device=device, + dtype=dtype, + ) + self._beta_buffer(beta, learn_beta) self.hidden_size = hidden_size @@ -194,7 +204,7 @@ def __init__( # Register a gradient hook to clamp out non-diagonal matrices in backward pass if learn_beta: self.rnn.weight_hh_l0.register_hook(self.grad_hook) - + if not learn_beta: # Make the weights non-learnable self.rnn.weight_hh_l0.requires_grad_(False) @@ -211,11 +221,11 @@ def __init__( def forward(self, input_): mem = self.rnn(input_) # mem[0] contains relu'd outputs, mem[1] contains final hidden state - mem_shift = mem[0] - self.threshold # self.rnn.weight_hh_l0 + mem_shift = mem[0] - self.threshold # self.rnn.weight_hh_l0 spk = self.spike_grad(mem_shift) spk = spk * self.graded_spikes_factor return spk - + @staticmethod def _surrogate_bypass(input_): return (input_ > 0).float() @@ -270,11 +280,11 @@ def backward(ctx, grad_output): * grad_input ) return grad, None - + def weight_hh_enable(self): mask = torch.eye(self.hidden_size, self.hidden_size) self.rnn.weight_hh_l0.data = self.rnn.weight_hh_l0.data * mask - + def grad_hook(self, grad): device = grad.device # Create a mask that is 1 on the diagonal and 0 elsewhere @@ -288,7 +298,9 @@ def _beta_to_weight_hh(self): # Set all weights to the scalar value of self.beta if isinstance(self.beta, float) or isinstance(self.beta, int): self.rnn.weight_hh_l0.fill_(self.beta) - elif isinstance(self.beta, torch.Tensor) or isinstance(self.beta, torch.FloatTensor): + elif isinstance(self.beta, torch.Tensor) or isinstance( + self.beta, torch.FloatTensor + ): if len(self.beta) == 1: self.rnn.weight_hh_l0.fill_(self.beta[0]) elif len(self.beta) == self.hidden_size: @@ -296,8 +308,10 @@ def _beta_to_weight_hh(self): for i in range(self.hidden_size): self.rnn.weight_hh_l0.data[i].fill_(self.beta[i]) else: - raise ValueError("Beta must be either a single value or of length 'hidden_size'.") - + raise ValueError( + "Beta must be either a single value or of length 'hidden_size'." + ) + def _beta_buffer(self, beta, learn_beta): if not isinstance(beta, torch.Tensor): if beta is not None: @@ -305,7 +319,7 @@ def _beta_buffer(self, beta, learn_beta): self.register_buffer("beta", beta) def _graded_spikes_buffer( - self, graded_spikes_factor, learn_graded_spikes_factor + self, graded_spikes_factor, learn_graded_spikes_factor ): if not isinstance(graded_spikes_factor, torch.Tensor): graded_spikes_factor = torch.as_tensor(graded_spikes_factor) diff --git a/snntorch/_neurons/leakyunroll.py b/snntorch/_neurons/leakyunroll.py index 3178d136..fc9beb1d 100644 --- a/snntorch/_neurons/leakyunroll.py +++ b/snntorch/_neurons/leakyunroll.py @@ -1,10 +1,11 @@ import torch import torch.nn as nn + class LeakyParallel(nn.Module): """ A parallel implementation of the Leaky neuron intended to handle arbitrary input dimensions. - All time steps are passed to the input at once. + All time steps are passed to the input at once. This implementation uses `torch.nn.RNN` to accelerate the implementation. First-order leaky integrate-and-fire neuron model. @@ -23,7 +24,7 @@ class LeakyParallel(nn.Module): * :math:`β` - Membrane potential decay rate Several differences between `LeakyParallel` and `Leaky` include: - + * Negative hidden states are clipped due to the forced ReLU operation in RNN * Linear weights are included in addition to recurrent weights * `beta` is clipped between [0,1] and cloned to `weight_hh_l` only upon layer initialization. It is unused otherwise @@ -57,9 +58,9 @@ def __init__(self): def forward(self, x): spk1 = self.lif1(x) spk2 = self.lif2(spk1) - return spk2 + return spk2 + - :param input_size: The number of expected features in the input `x`. The output of a linear layer should be an int, whereas the output of a 2D-convolution should be a tuple with 3 int values :type input_size: int or tuple @@ -100,25 +101,25 @@ def forward(self, x): to False :type learn_threshold: bool, optional - :param weight_hh_enable: Option to set the hidden matrix to be dense or - diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works. - Dense (True) would allow the membrane potential of one LIF neuron to + :param weight_hh_enable: Option to set the hidden matrix to be dense or + diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works. + Dense (True) would allow the membrane potential of one LIF neuron to influence all others, and follow the RNN default implementation. Defaults to False :type weight_hh_enable: bool, optional Inputs: \\input_ - - **input_** of shape of shape `(L, H_{in})` for unbatched input, - or `(L, N, H_{in})` containing the features of the input sequence. + - **input_** of shape of shape `(L, H_{in})` for unbatched input, + or `(L, N, H_{in})` containing the features of the input sequence. Outputs: spk - **spk** of shape `(L, batch, input_size)`: tensor containing the output spikes. - + where: `L = sequence length` - + `N = batch size` `H_{in} = input_size` @@ -163,9 +164,18 @@ def __init__( unrolled_input_size = self._process_input() # initialize weights: input linear layer is diagonal filled w/ones to prevent linear from doing anything - self.rnn = nn.RNN(unrolled_input_size, unrolled_input_size, num_layers=1, nonlinearity='relu', - bias=bias, batch_first=False, dropout=dropout, device=device, dtype=dtype) - + self.rnn = nn.RNN( + unrolled_input_size, + unrolled_input_size, + num_layers=1, + nonlinearity="relu", + bias=bias, + batch_first=False, + dropout=dropout, + device=device, + dtype=dtype, + ) + self._beta_buffer(beta, learn_beta) self.hidden_size = unrolled_input_size @@ -177,7 +187,7 @@ def __init__( else: self.spike_grad = spike_grad - self.weight_ih_disable() # disable input linear layer + self.weight_ih_disable() # disable input linear layer self._beta_to_weight_hh() if weight_hh_enable is False: # Initial gradient and weights of w_hh are made diagonal @@ -185,7 +195,7 @@ def __init__( # Register a gradient hook to clamp out non-diagonal matrices in backward pass if learn_beta: self.rnn.weight_hh_l0.register_hook(self.grad_hook) - + if not learn_beta: # Make the weights non-learnable self.rnn.weight_hh_l0.requires_grad_(False) @@ -203,12 +213,12 @@ def forward(self, input_): input_ = self.process_tensor(input_) mem = self.rnn(input_) # mem[0] contains relu'd outputs, mem[1] contains final hidden state - mem_shift = mem[0] - self.threshold # self.rnn.weight_hh_l0 + mem_shift = mem[0] - self.threshold # self.rnn.weight_hh_l0 spk = self.spike_grad(mem_shift) spk = spk * self.graded_spikes_factor spk = self.unprocess_tensor(self, spk) return spk - + @staticmethod def _surrogate_bypass(input_): return (input_ > 0).float() @@ -263,7 +273,7 @@ def backward(ctx, grad_output): * grad_input ) return grad, None - + def _process_input(self): # Check if the input is an integer if isinstance(self.input_size, int): @@ -275,18 +285,19 @@ def _process_input(self): for item in self.input_size: # Ensure each item in the tuple is an integer if not isinstance(item, int): - raise ValueError("All elements in the tuple must be integers") + raise ValueError( + "All elements in the tuple must be integers" + ) product *= item return product else: raise TypeError("Input must be an integer or a tuple of integers") - def weight_hh_enable(self): mask = torch.eye(self.hidden_size, self.hidden_size) self.rnn.weight_hh_l0.data = self.rnn.weight_hh_l0.data * mask - + def weight_ih_disable(self): with torch.no_grad(): mask = torch.eye(self.input_size, self.input_size) @@ -300,7 +311,7 @@ def process_tensor(self, input_): return input_.flatten(2) else: raise ValueError("input_size must be either an int or a tuple") - + def unprocess_tensor(self, input_): if isinstance(self.input_size, int): return input_ @@ -309,7 +320,6 @@ def unprocess_tensor(self, input_): else: raise ValueError("input_size must be either an int or a tuple") - def grad_hook(self, grad): device = grad.device # Create a mask that is 1 on the diagonal and 0 elsewhere @@ -323,7 +333,9 @@ def _beta_to_weight_hh(self): # Set all weights to the scalar value of self.beta if isinstance(self.beta, float) or isinstance(self.beta, int): self.rnn.weight_hh_l0.fill_(self.beta) - elif isinstance(self.beta, torch.Tensor) or isinstance(self.beta, torch.FloatTensor): + elif isinstance(self.beta, torch.Tensor) or isinstance( + self.beta, torch.FloatTensor + ): if len(self.beta) == 1: self.rnn.weight_hh_l0.fill_(self.beta[0]) elif len(self.beta) == self.hidden_size: @@ -331,8 +343,10 @@ def _beta_to_weight_hh(self): for i in range(self.hidden_size): self.rnn.weight_hh_l0.data[i].fill_(self.beta[i]) else: - raise ValueError("Beta must be either a single value or of length 'hidden_size'.") - + raise ValueError( + "Beta must be either a single value or of length 'hidden_size'." + ) + def _beta_buffer(self, beta, learn_beta): if not isinstance(beta, torch.Tensor): if beta is not None: @@ -340,7 +354,7 @@ def _beta_buffer(self, beta, learn_beta): self.register_buffer("beta", beta) def _graded_spikes_buffer( - self, graded_spikes_factor, learn_graded_spikes_factor + self, graded_spikes_factor, learn_graded_spikes_factor ): if not isinstance(graded_spikes_factor, torch.Tensor): graded_spikes_factor = torch.as_tensor(graded_spikes_factor) @@ -355,4 +369,4 @@ def _threshold_buffer(self, threshold, learn_threshold): if learn_threshold: self.threshold = nn.Parameter(threshold) else: - self.register_buffer("threshold", threshold) \ No newline at end of file + self.register_buffer("threshold", threshold) diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py index 471b9e92..3f96be2c 100644 --- a/snntorch/export_nir.py +++ b/snntorch/export_nir.py @@ -1,11 +1,15 @@ -from typing import Optional +#from typing import Optional +from typing import Optional, Tuple, Union import torch +import os +import sys import nir import numpy as np import nirtorch import snntorch as snn + def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: """Convert a single snnTorch module to the equivalent object in the Neuromorphic Intermediate Representation (NIR). This function is used internally by the export_to_nir @@ -22,7 +26,45 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: :return: return the NIR node :rtype: Optional[nir.NIRNode] """ - if isinstance(module, snn.Leaky): + # Adding Conv2d layer + """ + if isinstance(module, torch.nn.Conv2d): + return nir.Conv2d( + input_shape=None, + weight=module.weight.detach(), + bias=module.bias.detach(), + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + """ + #modifiying bias of the conv2d layer extraction + if isinstance(module, torch.nn.Conv2d): + return nir.Conv2d( + input_shape=None, + weight=module.weight.detach(), + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + #better handle for the bias if it's False + bias=( + module.bias.detach() + if isinstance(module.bias, torch.Tensor) + else torch.zeros((module.weight.shape[0])) + ), + ) + elif isinstance(module, torch.nn.AvgPool2d): + return nir.AvgPool2d( + kernel_size=module.kernel_size, # (Height, Width) + stride=module.kernel_size + if module.stride is None + else module.stride, # (Height, width) + padding=(0, 0), # (Height, width) + ) + + elif isinstance(module, snn.Leaky): dt = 1e-4 beta = module.beta.detach().numpy() @@ -41,13 +83,11 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: elif isinstance(module, torch.nn.Linear): if module.bias is None: - return nir.Linear( - weight=module.weight.data.detach().numpy() - ) + return nir.Linear(weight=module.weight.data.detach().numpy()) else: return nir.Affine( weight=module.weight.data.detach().numpy(), - bias=module.bias.data.detach().numpy() + bias=module.bias.data.detach().numpy(), ) elif isinstance(module, snn.Synaptic): @@ -76,7 +116,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: elif isinstance(module, snn.RLeaky): # TODO(stevenabreu7): implement RLeaky - raise NotImplementedError('RLeaky not supported') + raise NotImplementedError("RLeaky not supported") elif isinstance(module, snn.RSynaptic): if module.all_to_all: @@ -85,7 +125,9 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: else: if len(module.recurrent.V.shape) == 0: # TODO: handle this better - if V is a scalar, then the weight has wrong shape - raise ValueError('V must be a vector, cannot infer layer size for scalar V') + raise ValueError( + "V must be a vector, cannot infer layer size for scalar V" + ) n_neurons = module.recurrent.V.shape[0] w = np.diag(module.recurrent.V.data.detach().numpy()) w_rec = nir.Linear(weight=w) @@ -105,30 +147,50 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: v_leak = np.zeros_like(beta) w_in = tau_syn / dt - return nir.NIRGraph(nodes={ - 'input': nir.Input(input_type=[n_neurons]), - 'lif': nir.CubaLIF( - v_threshold=vthr, - tau_mem=tau_mem, - tau_syn=tau_syn, - r=r, - v_leak=v_leak, - w_in=w_in, - ), - 'w_rec': w_rec, - 'output': nir.Output(output_type=[n_neurons]) - }, edges=[ - ('input', 'lif'), ('lif', 'w_rec'), ('w_rec', 'lif'), ('lif', 'output') - ]) + return nir.NIRGraph( + nodes={ + "input": nir.Input(input_type=[n_neurons]), + "lif": nir.CubaLIF( + v_threshold=vthr, + tau_mem=tau_mem, + tau_syn=tau_syn, + r=r, + v_leak=v_leak, + w_in=w_in, + ), + "w_rec": w_rec, + "output": nir.Output(output_type=[n_neurons]), + }, + edges=[ + ("input", "lif"), + ("lif", "w_rec"), + ("w_rec", "lif"), + ("lif", "output"), + ], + ) + elif isinstance(module, torch.nn.Flatten): + # Getting rid of the batch dimension for NIR + start_dim = ( + module.start_dim - 1 if module.start_dim > 0 else module.start_dim + ) + end_dim = module.end_dim - 1 if module.end_dim > 0 else module.end_dim + return nir.Flatten( + input_type=None, + start_dim=start_dim, + end_dim=end_dim, + ) else: - print(f'[WARNING] module not implemented: {module.__class__.__name__}') + print(f"[WARNING] module not implemented: {module.__class__.__name__}") return None def export_to_nir( - module: torch.nn.Module, sample_data: torch.Tensor, model_name: str = "snntorch", - model_fwd_args=[], ignore_dims=[] + module: torch.nn.Module, + sample_data: torch.Tensor, + model_name: str = "snntorch", + model_fwd_args=[], + ignore_dims=[], ) -> nir.NIRNode: """Convert an snnTorch module to the Neuromorphic Intermediate Representation (NIR). This function uses nirtorch to extract the computational graph of the torch module, @@ -160,7 +222,7 @@ def export_to_nir( sample_data = torch.randn(1, 784) nir_graph = export_to_nir(net, sample_data) - + :param module: Network model (either wrapped in Sequential container or as a class) :type module: torch.nn.Module @@ -180,8 +242,12 @@ def export_to_nir( :rtype: nir.NIRNode """ nir_graph = nirtorch.extract_nir_graph( - module, _extract_snntorch_module, sample_data, model_name=model_name, + module, + _extract_snntorch_module, + sample_data, + model_name=model_name, ignore_submodules_of=[snn.RLeaky, snn.RSynaptic], - model_fwd_args=model_fwd_args, ignore_dims=ignore_dims + model_fwd_args=model_fwd_args, + ignore_dims=ignore_dims, ) return nir_graph diff --git a/snntorch/functional/loss.py b/snntorch/functional/loss.py index f1e19da9..3157604b 100644 --- a/snntorch/functional/loss.py +++ b/snntorch/functional/loss.py @@ -893,4 +893,4 @@ def _ce_temporal_cases(self): raise ValueError( '`inverse` must be of type string containing either "negate" ' 'or "reciprocal".' - ) + ) \ No newline at end of file diff --git a/snntorch/functional/quant.py b/snntorch/functional/quant.py index 5ac31108..cc426c5e 100644 --- a/snntorch/functional/quant.py +++ b/snntorch/functional/quant.py @@ -132,7 +132,6 @@ def state_quant( num_levels, ) - # exponential / non-uniform quantization else: if multiplier is None: @@ -155,8 +154,12 @@ def state_quant( # asymmetric: shifted to threshold if thr_centered: - max_val = threshold + (threshold * upper_limit) # maximum level that can be reached - min_val = -(threshold + (threshold * lower_limit)) # minimum level that can be reached + max_val = threshold + ( + threshold * upper_limit + ) # maximum level that can be reached + min_val = -( + threshold + (threshold * lower_limit) + ) # minimum level that can be reached num_levels = 2 << num_bits - 1 # total number of levels @@ -165,36 +168,46 @@ def state_quant( lower_range = threshold - min_val # levels below the threshold upper_range = max_val - threshold # levels above the threshold - lower_percent = lower_range / overall_range # percent +66 the threshold - upper_percent = upper_range / overall_range # percent above the threshold + lower_percent = ( + lower_range / overall_range + ) # percent +66 the threshold + upper_percent = ( + upper_range / overall_range + ) # percent above the threshold - lower_bits = math.floor(num_levels * lower_percent) # how many levels lower than the threshold - upper_bits = num_levels - lower_bits # how many bits above the threshold + lower_bits = math.floor( + num_levels * lower_percent + ) # how many levels lower than the threshold + upper_bits = ( + num_levels - lower_bits + ) # how many bits above the threshold lower_summation = 0 store_val = [] if lower_bits != 0: - for j in reversed(range(lower_bits)): # figure out how much the summation travels - lower_curr = (multiplier ** j) - lower_summation += (multiplier ** j) + for j in reversed( + range(lower_bits) + ): # figure out how much the summation travels + lower_curr = multiplier ** j + lower_summation += multiplier ** j lower_room = lower_summation / lower_range min_temp_sum = min_val # store_val.append(min_temp_sum) - for j in (range(lower_bits)): + for j in range(lower_bits): lower_curr = multiplier ** j store_val.append(min_temp_sum) - min_temp_sum += (lower_curr / lower_room) + min_temp_sum += lower_curr / lower_room # store_val.append(min_temp_sum) if upper_bits != 0: upper_summation = 0 for j in reversed(range(upper_bits)): - upper_curr = (multiplier ** j) - upper_summation += (multiplier ** j) + upper_curr = multiplier ** j + upper_summation += multiplier ** j upper_room = upper_summation / upper_range @@ -204,7 +217,7 @@ def state_quant( for j in reversed(range(upper_bits)): upper_curr = multiplier ** j # store_val.append(max_temp_sum) - max_temp_sum += (upper_curr / upper_room) + max_temp_sum += upper_curr / upper_room store_val.append(max_temp_sum) # store_val.append(max_temp_sum) @@ -212,8 +225,12 @@ def state_quant( # centered about zero else: - max_val = threshold + (threshold * upper_limit) # maximum level that can be reached - min_val = -(threshold + (threshold * lower_limit)) # minimum level that can be reached + max_val = threshold + ( + threshold * upper_limit + ) # maximum level that can be reached + min_val = -( + threshold + (threshold * lower_limit) + ) # minimum level that can be reached num_levels = 2 << num_bits - 1 # total number of levels @@ -222,30 +239,40 @@ def state_quant( lower_range = 0 - min_val # levels below the threshold upper_range = max_val - 0 # levels above the threshold - lower_percent = lower_range / overall_range # percent +66 the threshold - upper_percent = upper_range / overall_range # percent above the threshold + lower_percent = ( + lower_range / overall_range + ) # percent +66 the threshold + upper_percent = ( + upper_range / overall_range + ) # percent above the threshold - lower_bits = math.floor(num_levels * lower_percent) # how many levels lower than the threshold - upper_bits = num_levels - lower_bits # how many bits above the threshold + lower_bits = math.floor( + num_levels * lower_percent + ) # how many levels lower than the threshold + upper_bits = ( + num_levels - lower_bits + ) # how many bits above the threshold lower_summation = 0 store_val = [] if lower_bits != 0: - for j in reversed(range(lower_bits)): # figure out how much the summation travels - lower_curr = (multiplier ** j) - lower_summation += (multiplier ** j) + for j in reversed( + range(lower_bits) + ): # figure out how much the summation travels + lower_curr = multiplier ** j + lower_summation += multiplier ** j lower_room = lower_summation / lower_range min_temp_sum = min_val # store_val.append(min_temp_sum) - for j in (range(lower_bits)): + for j in range(lower_bits): lower_curr = multiplier ** j store_val.append(min_temp_sum) - min_temp_sum += (lower_curr / lower_room) + min_temp_sum += lower_curr / lower_room # store_val.append(min_temp_sum) @@ -253,8 +280,8 @@ def state_quant( upper_summation = 0 for j in reversed(range(upper_bits)): - upper_curr = (multiplier ** j) - upper_summation += (multiplier ** j) + upper_curr = multiplier ** j + upper_summation += multiplier ** j upper_room = upper_summation / upper_range @@ -263,7 +290,7 @@ def state_quant( # store_val.append(max_temp_sum) for j in reversed(range(upper_bits)): upper_curr = multiplier ** j - max_temp_sum += (upper_curr / upper_room) + max_temp_sum += upper_curr / upper_room store_val.append(max_temp_sum) # store_val.append(max_temp_sum) @@ -273,4 +300,4 @@ def state_quant( def inner(x): return StateQuant.apply(x, levels) - return inner \ No newline at end of file + return inner diff --git a/snntorch/functional/stdp_learner.py b/snntorch/functional/stdp_learner.py index 7d9304dc..a2960542 100644 --- a/snntorch/functional/stdp_learner.py +++ b/snntorch/functional/stdp_learner.py @@ -10,86 +10,118 @@ def stdp_linear_single_step( - fc: nn.Linear, in_spike: torch.Tensor, out_spike: torch.Tensor, + fc: nn.Linear, + in_spike: torch.Tensor, + out_spike: torch.Tensor, trace_pre: Union[float, torch.Tensor, None], trace_post: Union[float, torch.Tensor, None], - tau_pre: float, tau_post: float, - f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x + tau_pre: float, + tau_post: float, + f_pre: Callable = lambda x: x, + f_post: Callable = lambda x: x, ): if trace_pre is None: - trace_pre = 0. + trace_pre = 0.0 if trace_post is None: - trace_post = 0. + trace_post = 0.0 weight = fc.weight.data - trace_pre = trace_pre - trace_pre / tau_pre + in_spike # shape = [batch_size, N_in] - trace_post = trace_post - trace_post / tau_post + out_spike # shape = [batch_size, N_out] + trace_pre = ( + trace_pre - trace_pre / tau_pre + in_spike + ) # shape = [batch_size, N_in] + trace_post = ( + trace_post - trace_post / tau_post + out_spike + ) # shape = [batch_size, N_out] # [batch_size, N_out, N_in] -> [N_out, N_in] - delta_w_pre = -f_pre(weight) * (trace_post.unsqueeze(2) * in_spike.unsqueeze(1)).sum(0) - delta_w_post = f_post(weight) * (trace_pre.unsqueeze(1) * out_spike.unsqueeze(2)).sum(0) + delta_w_pre = -f_pre(weight) * ( + trace_post.unsqueeze(2) * in_spike.unsqueeze(1) + ).sum(0) + delta_w_post = f_post(weight) * ( + trace_pre.unsqueeze(1) * out_spike.unsqueeze(2) + ).sum(0) return trace_pre, trace_post, delta_w_pre + delta_w_post def mstdp_linear_single_step( - fc: nn.Linear, in_spike: torch.Tensor, out_spike: torch.Tensor, + fc: nn.Linear, + in_spike: torch.Tensor, + out_spike: torch.Tensor, trace_pre: Union[float, torch.Tensor, None], trace_post: Union[float, torch.Tensor, None], - tau_pre: float, tau_post: float, - f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x + tau_pre: float, + tau_post: float, + f_pre: Callable = lambda x: x, + f_post: Callable = lambda x: x, ): if trace_pre is None: - trace_pre = 0. + trace_pre = 0.0 if trace_post is None: - trace_post = 0. + trace_post = 0.0 weight = fc.weight.data - trace_pre = trace_pre * math.exp(-1 / tau_pre) + in_spike # shape = [batch_size, C_in] - trace_post = trace_post * math.exp(-1 / tau_post) + out_spike # shape = [batch_size, C_out] + trace_pre = ( + trace_pre * math.exp(-1 / tau_pre) + in_spike + ) # shape = [batch_size, C_in] + trace_post = ( + trace_post * math.exp(-1 / tau_post) + out_spike + ) # shape = [batch_size, C_out] # [batch_size, N_out, N_in] - eligibility = f_post(weight) * (trace_pre.unsqueeze(1) * out_spike.unsqueeze(2)) - \ - f_pre(weight) * (trace_post.unsqueeze(2) * in_spike.unsqueeze(1)) + eligibility = f_post(weight) * ( + trace_pre.unsqueeze(1) * out_spike.unsqueeze(2) + ) - f_pre(weight) * (trace_post.unsqueeze(2) * in_spike.unsqueeze(1)) return trace_pre, trace_post, eligibility def mstdpet_linear_single_step( - fc: nn.Linear, in_spike: torch.Tensor, out_spike: torch.Tensor, + fc: nn.Linear, + in_spike: torch.Tensor, + out_spike: torch.Tensor, trace_pre: Union[float, torch.Tensor, None], trace_post: Union[float, torch.Tensor, None], - tau_pre: float, tau_post: float, tau_trace: float, - f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x + tau_pre: float, + tau_post: float, + tau_trace: float, + f_pre: Callable = lambda x: x, + f_post: Callable = lambda x: x, ): if trace_pre is None: - trace_pre = 0. + trace_pre = 0.0 if trace_post is None: - trace_post = 0. + trace_post = 0.0 weight = fc.weight.data trace_pre = trace_pre * math.exp(-1 / tau_pre) + in_spike trace_post = trace_post * math.exp(-1 / tau_post) + out_spike - eligibility = f_post(weight) * torch.outer(out_spike, trace_pre) - \ - f_pre(weight) * torch.outer(trace_post, in_spike) + eligibility = f_post(weight) * torch.outer(out_spike, trace_pre) - f_pre( + weight + ) * torch.outer(trace_post, in_spike) return trace_pre, trace_post, eligibility def stdp_conv2d_single_step( - conv: nn.Conv2d, in_spike: torch.Tensor, out_spike: torch.Tensor, - trace_pre: Union[torch.Tensor, None], trace_post: Union[torch.Tensor, None], - tau_pre: float, tau_post: float, - f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x + conv: nn.Conv2d, + in_spike: torch.Tensor, + out_spike: torch.Tensor, + trace_pre: Union[torch.Tensor, None], + trace_post: Union[torch.Tensor, None], + tau_pre: float, + tau_post: float, + f_pre: Callable = lambda x: x, + f_post: Callable = lambda x: x, ): if conv.dilation != (1, 1): raise NotImplementedError( - 'STDP with dilation != 1 for Conv2d has not been implemented!' + "STDP with dilation != 1 for Conv2d has not been implemented!" ) if conv.groups != 1: raise NotImplementedError( - 'STDP with groups != 1 for Conv2d has not been implemented!' + "STDP with groups != 1 for Conv2d has not been implemented!" ) stride_h = conv.stride[0] @@ -100,10 +132,11 @@ def stdp_conv2d_single_step( else: pH = conv.padding[0] pW = conv.padding[1] - if conv.padding_mode != 'zeros': + if conv.padding_mode != "zeros": in_spike = F.pad( - in_spike, conv._reversed_padding_repeated_twice, - mode=conv.padding_mode + in_spike, + conv._reversed_padding_repeated_twice, + mode=conv.padding_mode, ) else: in_spike = F.pad(in_spike, pad=(pW, pW, pH, pH)) @@ -127,37 +160,51 @@ def stdp_conv2d_single_step( h_end = in_spike.shape[2] - conv.weight.shape[2] + 1 + h w_end = in_spike.shape[3] - conv.weight.shape[3] + 1 + w - pre_spike = in_spike[:, :, h:h_end:stride_h, w:w_end:stride_w] # shape = [batch_size, C_in, h_out, w_out] + pre_spike = in_spike[ + :, :, h:h_end:stride_h, w:w_end:stride_w + ] # shape = [batch_size, C_in, h_out, w_out] post_spike = out_spike # shape = [batch_size, C_out, h_out, h_out] - weight = conv.weight.data[:, :, h, w] # shape = [batch_size_out, C_in] + weight = conv.weight.data[ + :, :, h, w + ] # shape = [batch_size_out, C_in] - tr_pre = trace_pre[:, :, h:h_end:stride_h, w:w_end:stride_w] # shape = [batch_size, C_in, h_out, w_out] + tr_pre = trace_pre[ + :, :, h:h_end:stride_h, w:w_end:stride_w + ] # shape = [batch_size, C_in, h_out, w_out] tr_post = trace_post # shape = [batch_size, C_out, h_out, w_out] - delta_w_pre = - (f_pre(weight) * - (tr_post.unsqueeze(2) * pre_spike.unsqueeze(1)) - .permute([1, 2, 0, 3, 4]).sum(dim=[2, 3, 4])) - delta_w_post = f_post(weight) * \ - (tr_pre.unsqueeze(1) * post_spike.unsqueeze(2)) \ - .permute([1, 2, 0, 3, 4]).sum(dim=[2, 3, 4]) + delta_w_pre = -( + f_pre(weight) + * (tr_post.unsqueeze(2) * pre_spike.unsqueeze(1)) + .permute([1, 2, 0, 3, 4]) + .sum(dim=[2, 3, 4]) + ) + delta_w_post = f_post(weight) * ( + tr_pre.unsqueeze(1) * post_spike.unsqueeze(2) + ).permute([1, 2, 0, 3, 4]).sum(dim=[2, 3, 4]) delta_w[:, :, h, w] += delta_w_pre + delta_w_post return trace_pre, trace_post, delta_w def stdp_conv1d_single_step( - conv: nn.Conv1d, in_spike: torch.Tensor, out_spike: torch.Tensor, - trace_pre: Union[torch.Tensor, None], trace_post: Union[torch.Tensor, None], - tau_pre: float, tau_post: float, - f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x + conv: nn.Conv1d, + in_spike: torch.Tensor, + out_spike: torch.Tensor, + trace_pre: Union[torch.Tensor, None], + trace_post: Union[torch.Tensor, None], + tau_pre: float, + tau_post: float, + f_pre: Callable = lambda x: x, + f_post: Callable = lambda x: x, ): if conv.dilation != (1,): raise NotImplementedError( - 'STDP with dilation != 1 for Conv1d has not been implemented!' + "STDP with dilation != 1 for Conv1d has not been implemented!" ) if conv.groups != 1: raise NotImplementedError( - 'STDP with groups != 1 for Conv1d has not been implemented!' + "STDP with groups != 1 for Conv1d has not been implemented!" ) stride_l = conv.stride[0] @@ -166,10 +213,11 @@ def stdp_conv1d_single_step( pass else: pL = conv.padding[0] - if conv.padding_mode != 'zeros': + if conv.padding_mode != "zeros": in_spike = F.pad( - in_spike, conv._reversed_padding_repeated_twice, - mode=conv.padding_mode + in_spike, + conv._reversed_padding_repeated_twice, + mode=conv.padding_mode, ) else: in_spike = F.pad(in_spike, pad=(pL, pL)) @@ -190,29 +238,40 @@ def stdp_conv1d_single_step( delta_w = torch.zeros_like(conv.weight.data) for l in range(conv.weight.shape[2]): l_end = in_spike.shape[2] - conv.weight.shape[2] + 1 + l - pre_spike = in_spike[:, :, l:l_end:stride_l] # shape = [batch_size, C_in, l_out] + pre_spike = in_spike[ + :, :, l:l_end:stride_l + ] # shape = [batch_size, C_in, l_out] post_spike = out_spike # shape = [batch_size, C_out, l_out] weight = conv.weight.data[:, :, l] # shape = [batch_size_out, C_in] - tr_pre = trace_pre[:, :, l:l_end:stride_l] # shape = [batch_size, C_in, l_out] + tr_pre = trace_pre[ + :, :, l:l_end:stride_l + ] # shape = [batch_size, C_in, l_out] tr_post = trace_post # shape = [batch_size, C_out, l_out] - delta_w_pre = - (f_pre(weight) * - (tr_post.unsqueeze(2) * pre_spike.unsqueeze(1)) - .permute([1, 2, 0, 3]).sum(dim=[2, 3])) - delta_w_post = f_post(weight) * \ - (tr_pre.unsqueeze(1) * post_spike.unsqueeze(2)) \ - .permute([1, 2, 0, 3]).sum(dim=[2, 3]) + delta_w_pre = -( + f_pre(weight) + * (tr_post.unsqueeze(2) * pre_spike.unsqueeze(1)) + .permute([1, 2, 0, 3]) + .sum(dim=[2, 3]) + ) + delta_w_post = f_post(weight) * ( + tr_pre.unsqueeze(1) * post_spike.unsqueeze(2) + ).permute([1, 2, 0, 3]).sum(dim=[2, 3]) delta_w[:, :, l] += delta_w_pre + delta_w_post return trace_pre, trace_post, delta_w + class STDPLearner(nn.Module): def __init__( self, - synapse: Union[nn.Conv2d, nn.Linear], sn, - tau_pre: float, tau_post: float, - f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x + synapse: Union[nn.Conv2d, nn.Linear], + sn, + tau_pre: float, + tau_post: float, + f_pre: Callable = lambda x: x, + f_post: Callable = lambda x: x, ): super().__init__() self.tau_pre = tau_pre @@ -238,10 +297,10 @@ def enable(self): self.in_spike_monitor.enable() self.out_spike_monitor.enable() - def step(self, on_grad: bool = True, scale: float = 1.): + def step(self, on_grad: bool = True, scale: float = 1.0): length = self.in_spike_monitor.records.__len__() delta_w = None - + if isinstance(self.synapse, nn.Linear): stdp_f = stdp_linear_single_step elif isinstance(self.synapse, nn.Conv2d): @@ -252,16 +311,25 @@ def step(self, on_grad: bool = True, scale: float = 1.): raise NotImplementedError(self.synapse) for _ in range(length): - in_spike = self.in_spike_monitor.records.pop(0) # [batch_size, N_in] - out_spike = self.out_spike_monitor.records.pop(0) # [batch_size, N_out] + in_spike = self.in_spike_monitor.records.pop( + 0 + ) # [batch_size, N_in] + out_spike = self.out_spike_monitor.records.pop( + 0 + ) # [batch_size, N_out] self.trace_pre, self.trace_post, dw = stdp_f( - self.synapse, in_spike, out_spike, - self.trace_pre, self.trace_post, - self.tau_pre, self.tau_post, - self.f_pre, self.f_post + self.synapse, + in_spike, + out_spike, + self.trace_pre, + self.trace_post, + self.tau_pre, + self.tau_post, + self.f_pre, + self.f_post, ) - if scale != 1.: + if scale != 1.0: dw *= scale delta_w = dw if (delta_w is None) else (delta_w + dw) @@ -273,4 +341,3 @@ def step(self, on_grad: bool = True, scale: float = 1.): self.synapse.weight.grad = self.synapse.weight.grad - delta_w else: return delta_w - diff --git a/snntorch/import_nir.py b/snntorch/import_nir.py index 154a6a17..1a3c664e 100644 --- a/snntorch/import_nir.py +++ b/snntorch/import_nir.py @@ -4,8 +4,9 @@ import torch import snntorch as snn - -def _create_rnn_subgraph(graph: nir.NIRGraph, lif_nk: str, w_nk: str) -> nir.NIRGraph: +def _create_rnn_subgraph( + graph: nir.NIRGraph, lif_nk: str, w_nk: str +) -> nir.NIRGraph: """Take a NIRGraph plus the node keys for a LIF and a W_rec, and return a new NIRGraph which has the RNN subgraph replaced with a subgraph (i.e., a single NIRGraph node). @@ -30,29 +31,40 @@ def _create_rnn_subgraph(graph: nir.NIRGraph, lif_nk: str, w_nk: str) -> nir.NIR :rtype: nir.NIRGraph """ # NOTE: assuming that the LIF and W_rec have keys of form xyz.abc - sg_key = lif_nk.split('.')[0] # TODO: make this more general? + sg_key = lif_nk.split(".")[0] # TODO: make this more general? # create subgraph for RNN sg_edges = [ - (lif_nk, w_nk), (w_nk, lif_nk), (lif_nk, f'{sg_key}.output'), (f'{sg_key}.input', w_nk) + (lif_nk, w_nk), + (w_nk, lif_nk), + (lif_nk, f"{sg_key}.output"), + (f"{sg_key}.input", w_nk), ] sg_nodes = { lif_nk: graph.nodes[lif_nk], w_nk: graph.nodes[w_nk], - f'{sg_key}.input': nir.Input(graph.nodes[lif_nk].input_type), - f'{sg_key}.output': nir.Output(graph.nodes[lif_nk].output_type), + f"{sg_key}.input": nir.Input(graph.nodes[lif_nk].input_type), + f"{sg_key}.output": nir.Output(graph.nodes[lif_nk].output_type), } sg = nir.NIRGraph(nodes=sg_nodes, edges=sg_edges) # remove subgraph edges from graph - graph.edges = [e for e in graph.edges if e not in [(lif_nk, w_nk), (w_nk, lif_nk)]] + graph.edges = [ + e for e in graph.edges if e not in [(lif_nk, w_nk), (w_nk, lif_nk)] + ] # remove subgraph nodes from graph - graph.nodes = {k: v for k, v in graph.nodes.items() if k not in [lif_nk, w_nk]} + graph.nodes = { + k: v for k, v in graph.nodes.items() if k not in [lif_nk, w_nk] + } # change edges of type (x, lif_nk) to (x, sg_key) - graph.edges = [(e[0], sg_key) if e[1] == lif_nk else e for e in graph.edges] + graph.edges = [ + (e[0], sg_key) if e[1] == lif_nk else e for e in graph.edges + ] # change edges of type (lif_nk, x) to (sg_key, x) - graph.edges = [(sg_key, e[1]) if e[0] == lif_nk else e for e in graph.edges] + graph.edges = [ + (sg_key, e[1]) if e[0] == lif_nk else e for e in graph.edges + ] # insert subgraph into graph and return graph.nodes[sg_key] = sg @@ -61,7 +73,7 @@ def _create_rnn_subgraph(graph: nir.NIRGraph, lif_nk: str, w_nk: str) -> nir.NIR def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: """Take a NIRGraph and replace any RNN subgraphs with a single NIRGraph node. - Goes through the NIRGraph to find any RNN subgraphs, and replaces them with a single NIRGraph node, + Goes through the NIRGraph to find any RNN subgraphs, and replaces them with a single NIRGraph node, using the _create_rnn_subgraph function. :param graph: NIRGraph @@ -70,10 +82,10 @@ def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: :return: NIRGraph with RNN subgraphs replaced with a single NIRGraph node :rtype: nir.NIRGraph """ - print('replace rnn subgraph with nirgraph') + print("replace rnn subgraph with nirgraph") if len(set(graph.edges)) != len(graph.edges): - print('[WARNING] duplicate edges found, removing') + print("[WARNING] duplicate edges found, removing") graph.edges = list(set(graph.edges)) # find cycle of LIF <> Dense nodes @@ -97,10 +109,10 @@ def _replace_rnn_subgraph_with_nirgraph(graph: nir.NIRGraph) -> nir.NIRGraph: return graph -def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): +def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): # type: ignore """Try parsing the presented graph as a RNN subgraph. Assumes the graph is a valid RNN subgraph with four nodes in the following structure: - + ``` Input -> LIF | CubaLIF -> Output ^ @@ -122,28 +134,38 @@ def _parse_rnn_subgraph(graph: nir.NIRGraph) -> (nir.NIRNode, nir.NIRNode, int): :rtype: int """ sub_nodes = graph.nodes.values() - assert len(sub_nodes) == 4, 'only 4-node RNN allowed in subgraph' + assert len(sub_nodes) == 4, "only 4-node RNN allowed in subgraph" try: input_node = [n for n in sub_nodes if isinstance(n, nir.Input)][0] output_node = [n for n in sub_nodes if isinstance(n, nir.Output)][0] - lif_node = [n for n in sub_nodes if isinstance(n, (nir.LIF, nir.CubaLIF))][0] - wrec_node = [n for n in sub_nodes if isinstance(n, (nir.Affine, nir.Linear))][0] + lif_node = [ + n for n in sub_nodes if isinstance(n, (nir.LIF, nir.CubaLIF)) + ][0] + wrec_node = [ + n for n in sub_nodes if isinstance(n, (nir.Affine, nir.Linear)) + ][0] except IndexError: - raise ValueError('invalid RNN subgraph - could not find all required nodes') + raise ValueError( + "invalid RNN subgraph - could not find all required nodes" + ) lif_size = list(input_node.input_type.values())[0][0] - assert lif_size == list(output_node.output_type.values())[0][0], 'output size mismatch' - assert lif_size == lif_node.v_threshold.size, 'lif size mismatch (v_threshold)' - assert lif_size == wrec_node.weight.shape[0], 'w_rec shape mismatch' - assert lif_size == wrec_node.weight.shape[1], 'w_rec shape mismatch' + assert ( + lif_size == list(output_node.output_type.values())[0][0] + ), "output size mismatch" + assert ( + lif_size == lif_node.v_threshold.size + ), "lif size mismatch (v_threshold)" + assert lif_size == wrec_node.weight.shape[0], "w_rec shape mismatch" + assert lif_size == wrec_node.weight.shape[1], "w_rec shape mismatch" return lif_node, wrec_node, lif_size def _nir_to_snntorch_module( - node: nir.NIRNode, hack_w_scale=True, init_hidden=False + node: nir.NIRNode, hack_w_scale=True, init_hidden=False ) -> torch.nn.Module: """Convert a NIR node to a snnTorch module. This function is used by the import_from_nir function. - + :param node: NIR node :type node: nir.NIRNode @@ -160,7 +182,7 @@ def _nir_to_snntorch_module( return None elif isinstance(node, nir.Affine): - assert node.bias is not None, 'bias must be specified for Affine layer' + assert node.bias is not None, "bias must be specified for Affine layer" mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0]) mod.weight.data = torch.Tensor(node.weight) @@ -169,7 +191,9 @@ def _nir_to_snntorch_module( return mod elif isinstance(node, nir.Linear): - mod = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=False) + mod = torch.nn.Linear( + node.weight.shape[1], node.weight.shape[0], bias=False + ) mod.weight.data = torch.Tensor(node.weight) return mod @@ -191,20 +215,22 @@ def _nir_to_snntorch_module( if isinstance(node, nir.Flatten): return torch.nn.Flatten(node.start_dim, node.end_dim) - if isinstance(node, nir.SumPool2d): + if isinstance(node, nir.AvgPool2d): return torch.nn.AvgPool2d( kernel_size=tuple(node.kernel_size), stride=tuple(node.stride), padding=tuple(node.padding), - divisor_override=1, # turn AvgPool into SumPool + # divisor_override=1, ) elif isinstance(node, nir.IF): - assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' - assert np.unique(node.r).size == 1, 'r must be same for all neurons' + assert ( + np.unique(node.v_threshold).size == 1 + ), "v_threshold must be same for all neurons" + assert np.unique(node.r).size == 1, "r must be same for all neurons" vthr = np.unique(node.v_threshold)[0] r = np.unique(node.r)[0] - assert r == 1, 'r != 1 not supported' + assert r == 1, "r != 1 not supported" mod = snn.Leaky( beta=0.9, threshold=vthr * r, @@ -216,27 +242,35 @@ def _nir_to_snntorch_module( elif isinstance(node, nir.LIF): dt = 1e-4 - assert np.allclose(node.v_leak, 0.), 'v_leak not supported' - assert np.unique(node.v_threshold).size == 1, 'v_threshold must be same for all neurons' + assert np.allclose(node.v_leak, 0.0), "v_leak not supported" + assert ( + np.unique(node.v_threshold).size == 1 + ), "v_threshold must be same for all neurons" beta = 1 - (dt / node.tau) vthr = node.v_threshold w_scale = node.r * dt / node.tau - if not np.allclose(w_scale, 1.): + if not np.allclose(w_scale, 1.0): if hack_w_scale: vthr = vthr / np.unique(w_scale)[0] - print('[warning] scaling weights to avoid scaling inputs') - print(f'w_scale: {w_scale}, r: {node.r}, dt: {dt}, tau: {node.tau}') + print("[warning] scaling weights to avoid scaling inputs") + print( + f"w_scale: {w_scale}, r: {node.r}, dt: {dt}, tau: {node.tau}" + ) else: - raise NotImplementedError('w_scale must be 1, or the same for all neurons') + raise NotImplementedError( + "w_scale must be 1, or the same for all neurons" + ) - assert np.unique(vthr).size == 1, 'LIF v_thr must be same for all neurons' + assert ( + np.unique(vthr).size == 1 + ), "LIF v_thr must be same for all neurons" return snn.Leaky( beta=beta, threshold=np.unique(vthr)[0], - reset_mechanism='zero', + reset_mechanism="zero", init_hidden=init_hidden, reset_delay=False, ) @@ -244,23 +278,31 @@ def _nir_to_snntorch_module( elif isinstance(node, nir.CubaLIF): dt = 1e-4 - assert np.allclose(node.v_leak, 0), 'v_leak not supported' - assert np.allclose(node.r, node.tau_mem / dt), 'r not supported in CubaLIF' + assert np.allclose(node.v_leak, 0), "v_leak not supported" + assert np.allclose( + node.r, node.tau_mem / dt + ), "r not supported in CubaLIF" alpha = 1 - (dt / node.tau_syn) beta = 1 - (dt / node.tau_mem) vthr = node.v_threshold w_scale = node.w_in * (dt / node.tau_syn) - if not np.allclose(w_scale, 1.): + if not np.allclose(w_scale, 1.0): if hack_w_scale: vthr = vthr / w_scale - print('[warning] scaling weights to avoid scaling inputs') - print(f'w_scale: {w_scale}, w_in: {node.w_in}, dt: {dt}, tau_syn: {node.tau_syn}') + print("[warning] scaling weights to avoid scaling inputs") + print( + f"w_scale: {w_scale}, w_in: {node.w_in}, dt: {dt}, tau_syn: {node.tau_syn}" + ) else: - raise NotImplementedError('w_scale must be 1, or the same for all neurons') + raise NotImplementedError( + "w_scale must be 1, or the same for all neurons" + ) - assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' + assert ( + np.unique(vthr).size == 1 + ), "CubaLIF v_thr must be same for all neurons" if np.unique(alpha).size == 1: alpha = float(np.unique(alpha)[0]) @@ -271,7 +313,7 @@ def _nir_to_snntorch_module( alpha=alpha, beta=beta, threshold=float(np.unique(vthr)[0]), - reset_mechanism='zero', + reset_mechanism="zero", init_hidden=init_hidden, reset_delay=False, ) @@ -280,30 +322,42 @@ def _nir_to_snntorch_module( lif_node, wrec_node, lif_size = _parse_rnn_subgraph(node) if isinstance(lif_node, nir.LIF): - raise NotImplementedError('LIF in subgraph not supported') + raise NotImplementedError("LIF in subgraph not supported") elif isinstance(lif_node, nir.CubaLIF): dt = 1e-4 - assert np.allclose(lif_node.v_leak, 0), 'v_leak not supported' - assert np.allclose(lif_node.r, lif_node.tau_mem / dt), 'r not supported in CubaLIF' + assert np.allclose(lif_node.v_leak, 0), "v_leak not supported" + assert np.allclose( + lif_node.r, lif_node.tau_mem / dt + ), "r not supported in CubaLIF" alpha = 1 - (dt / lif_node.tau_syn) beta = 1 - (dt / lif_node.tau_mem) vthr = lif_node.v_threshold w_scale = lif_node.w_in * (dt / lif_node.tau_syn) - if not np.allclose(w_scale, 1.): + if not np.allclose(w_scale, 1.0): if hack_w_scale: vthr = vthr / w_scale - print(f'[warning] scaling weights to avoid scaling inputs. w_scale: {w_scale}') - print(f'w_in: {lif_node.w_in}, dt: {dt}, tau_syn: {lif_node.tau_syn}') + print( + f"[warning] scaling weights to avoid scaling inputs. w_scale: {w_scale}" + ) + print( + f"w_in: {lif_node.w_in}, dt: {dt}, tau_syn: {lif_node.tau_syn}" + ) else: - raise NotImplementedError('w_scale must be 1, or the same for all neurons') + raise NotImplementedError( + "w_scale must be 1, or the same for all neurons" + ) - assert np.unique(vthr).size == 1, 'CubaLIF v_thr must be same for all neurons' + assert ( + np.unique(vthr).size == 1 + ), "CubaLIF v_thr must be same for all neurons" - diagonal = np.array_equal(wrec_node.weight, np.diag(np.diag(wrec_node.weight))) + diagonal = np.array_equal( + wrec_node.weight, np.diag(np.diag(wrec_node.weight)) + ) if np.unique(alpha).size == 1: alpha = float(np.unique(alpha)[0]) @@ -311,7 +365,9 @@ def _nir_to_snntorch_module( beta = float(np.unique(beta)[0]) if diagonal: - V = torch.from_numpy(np.diag(wrec_node.weight)).to(dtype=torch.float32) + V = torch.from_numpy(np.diag(wrec_node.weight)).to( + dtype=torch.float32 + ) else: V = None @@ -319,7 +375,7 @@ def _nir_to_snntorch_module( alpha=alpha, beta=beta, threshold=float(np.unique(vthr)[0]), - reset_mechanism='zero', + reset_mechanism="zero", init_hidden=init_hidden, all_to_all=not diagonal, linear_features=lif_size if not diagonal else None, @@ -328,13 +384,21 @@ def _nir_to_snntorch_module( ) if isinstance(rsynaptic.recurrent, torch.nn.Linear): - rsynaptic.recurrent.weight.data = torch.Tensor(wrec_node.weight) + rsynaptic.recurrent.weight.data = torch.Tensor( + wrec_node.weight + ) if isinstance(wrec_node, nir.Affine): - rsynaptic.recurrent.bias.data = torch.Tensor(wrec_node.bias) + rsynaptic.recurrent.bias.data = torch.Tensor( + wrec_node.bias + ) else: - rsynaptic.recurrent.bias.data = torch.zeros_like(rsynaptic.recurrent.bias) + rsynaptic.recurrent.bias.data = torch.zeros_like( + rsynaptic.recurrent.bias + ) else: - rsynaptic.recurrent.V.data = torch.diagonal(torch.Tensor(wrec_node.weight)) + rsynaptic.recurrent.V.data = torch.diagonal( + torch.Tensor(wrec_node.weight) + ) return rsynaptic @@ -342,14 +406,16 @@ def _nir_to_snntorch_module( return torch.nn.Identity() else: - print('[WARNING] could not parse node of type:', node.__class__.__name__) + print( + "[WARNING] could not parse node of type:", node.__class__.__name__ + ) return None def import_from_nir(graph: nir.NIRGraph) -> torch.nn.Module: """Convert a NIRGraph to a snnTorch module. This function is the inverse of export_to_nir. - It proceeds by wrapping any recurrent connections into NIR sub-graphs, then converts each - NIR module into the equivalent snnTorch module, and wraps them into a torch.nn.Module + It proceeds by wrapping any recurrent connections into NIR sub-graphs, then converts each + NIR module into the equivalent snnTorch module, and wraps them into a torch.nn.Module using the generic GraphExecutor from NIRTorch to execute all modules in the right order. Missing features: @@ -376,7 +442,7 @@ def import_from_nir(graph: nir.NIRGraph) -> torch.nn.Module: nir_graph = export_to_nir(net, sample_data, model_name="snntorch") net2 = import_from_nir(nir_graph) - + :param graph: NIR graph :type graph: NIR.NIRGraph diff --git a/snntorch/spikeplot.py b/snntorch/spikeplot.py index e0b945cc..8f92d471 100644 --- a/snntorch/spikeplot.py +++ b/snntorch/spikeplot.py @@ -78,7 +78,9 @@ def animator(data, fig, ax, num_steps=False, interval=40, cmap="plasma"): for step in range( num_steps ): # im appears unused but is required by camera.snap() - im = ax.imshow(data[step], cmap=cmap, vmin=data.min(), vmax=data.max()) # noqa: F841 + im = ax.imshow( + data[step], cmap=cmap, vmin=data.min(), vmax=data.max() + ) # noqa: F841 camera.snap() anim = camera.animate(interval=interval) diff --git a/snntorch/spikevision/_utils.py b/snntorch/spikevision/_utils.py index 6747b25f..1ab1eaf5 100644 --- a/snntorch/spikevision/_utils.py +++ b/snntorch/spikevision/_utils.py @@ -28,7 +28,7 @@ def load_ATIS_bin(filename): ) # Process time stamp overflow events - time_increment = 2**13 + time_increment = 2 ** 13 overflow_indices = np.where(all_y == 240)[0] for overflow_index in overflow_indices: all_ts[overflow_index:] += time_increment @@ -166,7 +166,7 @@ def load_jaer( "read %i (~ %.2fM) AE events, duration= %.2fs" % ( len(timestamps), - len(timestamps) / float(10**6), + len(timestamps) / float(10 ** 6), (timestamps[-1] - timestamps[0]) * td, ) ) diff --git a/snntorch/surrogate.py b/snntorch/surrogate.py index 31914e93..148c5eeb 100644 --- a/snntorch/surrogate.py +++ b/snntorch/surrogate.py @@ -641,6 +641,7 @@ def custom_fast_sigmoid(input_, grad_input, spikes): ).to(device) """ + @staticmethod def forward(ctx, input_, custom_surrogate_function): out = (input_ > 0).float() diff --git a/snntorch/utils.py b/snntorch/utils.py index 7d2c35bc..cd12cc00 100644 --- a/snntorch/utils.py +++ b/snntorch/utils.py @@ -249,4 +249,3 @@ def _final_layer_check(net): return 4 else: # if not from snn, assume from nn with 1 return return 1 - \ No newline at end of file diff --git a/tests/test_nir.py b/tests/test_nir.py index 0fa1f635..dd2b5ff8 100644 --- a/tests/test_nir.py +++ b/tests/test_nir.py @@ -7,12 +7,43 @@ import snntorch as snn import torch - +# sample data for snntorch_sequential @pytest.fixture(scope="module") def sample_data(): return torch.ones((4, 784)) +# sample data for snntorch with conv2d_avgpool +@pytest.fixture(scope="module") +def sample_data2(): + return torch.randn(1, 1, 28, 28) + + +class NetWithAvgPool(torch.nn.Module): + def __init__(self): + super(NetWithAvgPool, self).__init__() + self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1) + self.lif1 = snn.Leaky(beta=0.9, init_hidden=True) + self.fc1 = torch.nn.Linear(28 * 28 * 16 // 4, 500) + self.lif2 = snn.Leaky(beta=0.9, init_hidden=True, output=True) + + def forward(self, x): + x = torch.nn.functional.avg_pool2d( + self.conv1(x), kernel_size=2, stride=2 + ) + x = x.view(-1, 28 * 28 * 16 // 4) + x = self.lif1(x) + x = self.fc1(x) + x = self.lif2(x) + return x + + +@pytest.fixture(scope="module") +def net_with_avg_pool(): + net = NetWithAvgPool() + return net + + @pytest.fixture(scope="module") def snntorch_sequential(): lif1 = snn.Leaky(beta=0.9, init_hidden=True) @@ -67,6 +98,30 @@ def test_export_sequential(self, snntorch_sequential, sample_data): assert isinstance(nir_graph.nodes["2"], nir.Affine) assert isinstance(nir_graph.nodes["3"], nir.LIF) + def test_export_NetWithAvgPool(self, net_with_avg_pool, sample_data2): + nir_graph = snn.export_to_nir(net_with_avg_pool, sample_data2) + assert nir_graph is not None + # dict_keys(['conv1', 'fc1', 'input', 'lif1', 'lif2', 'output']) + assert set(nir_graph.nodes.keys()) == set( + ["input", "output"] + + ["conv1", "fc1", "input", "lif1", "lif2", "output"] + ), nir_graph.nodes.keys() + assert set(nir_graph.edges) == set( + [ + ("lif2", "output"), + ("lif1", "fc1"), + ("fc1", "lif2"), + ("input", "conv1"), + ("conv1", "output"), + ] + ) + assert isinstance(nir_graph.nodes["input"], nir.Input) + assert isinstance(nir_graph.nodes["output"], nir.Output) + assert isinstance(nir_graph.nodes["conv1"], nir.Conv2d) + assert isinstance(nir_graph.nodes["lif1"], nir.LIF) + assert isinstance(nir_graph.nodes["fc1"], nir.Affine) + assert isinstance(nir_graph.nodes["lif2"], nir.LIF) + def test_export_recurrent(self, snntorch_recurrent, sample_data): nir_graph = snn.export_to_nir(snntorch_recurrent, sample_data) assert nir_graph is not None @@ -98,6 +153,13 @@ def test_import_nir(self): out, _ = net(torch.ones(1, 1)) assert out.shape == (1, 1), out.shape + def test_import_conv_nir(self): + graph = nir.read("examples/testconv2d+avgpool.nir") + net = snn.import_from_nir(graph) + assert net is not None + out, _ = net(torch.randn(1, 1, 1, 1)) + assert out.shape == (1, 16, 1, 1), out.shape + def test_commute_sequential(self, snntorch_sequential, sample_data): x = torch.rand((4, 784)) y_snn, state = snntorch_sequential(x) diff --git a/tests/test_snntorch/functional/test_loss.py b/tests/test_snntorch/functional/test_loss.py index c80b93b3..f6c3c2dc 100644 --- a/tests/test_snntorch/functional/test_loss.py +++ b/tests/test_snntorch/functional/test_loss.py @@ -42,201 +42,261 @@ def test_ce_rate_loss_base(self, spike_predicted_, targets_labels_): loss_fn = sf.ce_rate_loss() assert loss_fn.weight is None - assert loss_fn.reduction == 'mean' + assert loss_fn.reduction == "mean" def test_ce_rate_loss_unreduced(self, spike_predicted_, targets_labels_): - unreduced_loss_fn = sf.ce_rate_loss(reduction='none') + unreduced_loss_fn = sf.ce_rate_loss(reduction="none") unreduced_loss = unreduced_loss_fn(spike_predicted_, targets_labels_) reduced_loss_fn = sf.ce_rate_loss() reduced_loss = reduced_loss_fn(spike_predicted_, targets_labels_) - assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item()) + assert_approximate_equality( + unreduced_loss.mean().item(), reduced_loss.item() + ) - def test_ce_rate_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_): + def test_ce_rate_loss_weighted( + self, spike_predicted_, targets_labels_, class_weights_ + ): weighted_loss_fn = sf.ce_rate_loss(weight=class_weights_) weighted_loss = weighted_loss_fn(spike_predicted_, targets_labels_) # unreduced, unweighted loss - vanilla_loss_fn = sf.ce_rate_loss(reduction='none') + vanilla_loss_fn = sf.ce_rate_loss(reduction="none") vanilla_loss = vanilla_loss_fn(spike_predicted_, targets_labels_) # weight multiplier weight_multiplier = class_weights_[targets_labels_] # expectation - expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean()) + expected_weighted_loss = (vanilla_loss * weight_multiplier).mean() - assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item()) + assert_approximate_equality( + weighted_loss.item(), expected_weighted_loss.item() + ) def test_ce_count_loss_base(self, spike_predicted_, targets_labels_): loss_fn = sf.ce_count_loss() assert loss_fn.weight is None - assert loss_fn.reduction == 'mean' + assert loss_fn.reduction == "mean" def test_ce_count_loss_unreduced(self, spike_predicted_, targets_labels_): - unreduced_loss_fn = sf.ce_count_loss(reduction='none') + unreduced_loss_fn = sf.ce_count_loss(reduction="none") unreduced_loss = unreduced_loss_fn(spike_predicted_, targets_labels_) reduced_loss_fn = sf.ce_count_loss() reduced_loss = reduced_loss_fn(spike_predicted_, targets_labels_) - assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item()) + assert_approximate_equality( + unreduced_loss.mean().item(), reduced_loss.item() + ) - def test_ce_count_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_): + def test_ce_count_loss_weighted( + self, spike_predicted_, targets_labels_, class_weights_ + ): weighted_loss_fn = sf.ce_count_loss(weight=class_weights_) weighted_loss = weighted_loss_fn(spike_predicted_, targets_labels_) # unreduced, unweighted loss - vanilla_loss_fn = sf.ce_count_loss(reduction='none') + vanilla_loss_fn = sf.ce_count_loss(reduction="none") vanilla_loss = vanilla_loss_fn(spike_predicted_, targets_labels_) # weight multiplier weight_multiplier = class_weights_[targets_labels_] # expectation - expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean()) + expected_weighted_loss = (vanilla_loss * weight_multiplier).mean() - assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item()) + assert_approximate_equality( + weighted_loss.item(), expected_weighted_loss.item() + ) - def test_ce_max_membrane_loss_base(self, membrane_predicted_, targets_labels_): + def test_ce_max_membrane_loss_base( + self, membrane_predicted_, targets_labels_ + ): loss_fn = sf.ce_max_membrane_loss() assert loss_fn.weight is None - assert loss_fn.reduction == 'mean' + assert loss_fn.reduction == "mean" - def test_ce_max_membrane_loss_unreduced(self, membrane_predicted_, targets_labels_): - unreduced_loss_fn = sf.ce_max_membrane_loss(reduction='none') - unreduced_loss = unreduced_loss_fn(membrane_predicted_, targets_labels_) + def test_ce_max_membrane_loss_unreduced( + self, membrane_predicted_, targets_labels_ + ): + unreduced_loss_fn = sf.ce_max_membrane_loss(reduction="none") + unreduced_loss = unreduced_loss_fn( + membrane_predicted_, targets_labels_ + ) reduced_loss_fn = sf.ce_max_membrane_loss() reduced_loss = reduced_loss_fn(membrane_predicted_, targets_labels_) - assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item()) + assert_approximate_equality( + unreduced_loss.mean().item(), reduced_loss.item() + ) - def test_ce_max_membrane_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_): + def test_ce_max_membrane_loss_weighted( + self, spike_predicted_, targets_labels_, class_weights_ + ): weighted_loss_fn = sf.ce_max_membrane_loss(weight=class_weights_) weighted_loss = weighted_loss_fn(spike_predicted_, targets_labels_) # unreduced, unweighted loss - vanilla_loss_fn = sf.ce_max_membrane_loss(reduction='none') + vanilla_loss_fn = sf.ce_max_membrane_loss(reduction="none") vanilla_loss = vanilla_loss_fn(spike_predicted_, targets_labels_) # weight multiplier weight_multiplier = class_weights_[targets_labels_] # expectation - expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean()) + expected_weighted_loss = (vanilla_loss * weight_multiplier).mean() - assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item()) + assert_approximate_equality( + weighted_loss.item(), expected_weighted_loss.item() + ) def test_mse_count_loss_base(self, spike_predicted_, targets_labels_): loss_fn = sf.mse_count_loss() assert loss_fn.weight is None - assert loss_fn.reduction == 'mean' + assert loss_fn.reduction == "mean" def test_mse_count_loss_unreduced(self, spike_predicted_, targets_labels_): - unreduced_loss_fn = sf.mse_count_loss(reduction='none') + unreduced_loss_fn = sf.mse_count_loss(reduction="none") unreduced_loss = unreduced_loss_fn(spike_predicted_, targets_labels_) reduced_loss_fn = sf.mse_count_loss() reduced_loss = reduced_loss_fn(spike_predicted_, targets_labels_) - assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item()) + assert_approximate_equality( + unreduced_loss.mean().item(), reduced_loss.item() + ) - def test_mse_count_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_): + def test_mse_count_loss_weighted( + self, spike_predicted_, targets_labels_, class_weights_ + ): weighted_loss_fn = sf.mse_count_loss(weight=class_weights_) weighted_loss = weighted_loss_fn(spike_predicted_, targets_labels_) # unreduced, unweighted loss - vanilla_loss_fn = sf.mse_count_loss(reduction='none') + vanilla_loss_fn = sf.mse_count_loss(reduction="none") vanilla_loss = vanilla_loss_fn(spike_predicted_, targets_labels_) # weight multiplier weight_multiplier = class_weights_[targets_labels_] # expectation - expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean()) + expected_weighted_loss = (vanilla_loss * weight_multiplier).mean() - assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item()) + assert_approximate_equality( + weighted_loss.item(), expected_weighted_loss.item() + ) - def test_mse_membrane_loss_base(self, membrane_predicted_, targets_labels_): + def test_mse_membrane_loss_base( + self, membrane_predicted_, targets_labels_ + ): loss_fn = sf.mse_membrane_loss() assert loss_fn.weight is None - assert loss_fn.reduction == 'mean' + assert loss_fn.reduction == "mean" - def test_mse_membrane_loss_unreduced(self, membrane_predicted_, targets_labels_): - unreduced_loss_fn = sf.mse_membrane_loss(reduction='none') - unreduced_loss = unreduced_loss_fn(membrane_predicted_, targets_labels_) + def test_mse_membrane_loss_unreduced( + self, membrane_predicted_, targets_labels_ + ): + unreduced_loss_fn = sf.mse_membrane_loss(reduction="none") + unreduced_loss = unreduced_loss_fn( + membrane_predicted_, targets_labels_ + ) reduced_loss_fn = sf.mse_membrane_loss() reduced_loss = reduced_loss_fn(membrane_predicted_, targets_labels_) - assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item()) + assert_approximate_equality( + unreduced_loss.mean().item(), reduced_loss.item() + ) - def test_mse_membrane_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_): + def test_mse_membrane_loss_weighted( + self, spike_predicted_, targets_labels_, class_weights_ + ): weighted_loss_fn = sf.mse_membrane_loss(weight=class_weights_) weighted_loss = weighted_loss_fn(spike_predicted_, targets_labels_) # unreduced, unweighted loss - vanilla_loss_fn = sf.mse_membrane_loss(reduction='none') + vanilla_loss_fn = sf.mse_membrane_loss(reduction="none") vanilla_loss = vanilla_loss_fn(spike_predicted_, targets_labels_) # weight multiplier weight_multiplier = class_weights_[targets_labels_] # expectation - expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean()) + expected_weighted_loss = (vanilla_loss * weight_multiplier).mean() - assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item()) + assert_approximate_equality( + weighted_loss.item(), expected_weighted_loss.item() + ) def test_mse_temporal_loss_base(self, spike_predicted_, targets_labels_): loss_fn = sf.mse_temporal_loss(on_target=1, off_target=0) assert loss_fn.weight is None - assert loss_fn.reduction == 'mean' + assert loss_fn.reduction == "mean" - def test_mse_temporal_loss_unreduced(self, spike_predicted_, targets_labels_): - unreduced_loss_fn = sf.mse_temporal_loss(reduction='none') + def test_mse_temporal_loss_unreduced( + self, spike_predicted_, targets_labels_ + ): + unreduced_loss_fn = sf.mse_temporal_loss(reduction="none") unreduced_loss = unreduced_loss_fn(spike_predicted_, targets_labels_) reduced_loss_fn = sf.mse_temporal_loss() reduced_loss = reduced_loss_fn(spike_predicted_, targets_labels_) - assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item()) + assert_approximate_equality( + unreduced_loss.mean().item(), reduced_loss.item() + ) - def test_mse_temporal_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_): + def test_mse_temporal_loss_weighted( + self, spike_predicted_, targets_labels_, class_weights_ + ): weighted_loss_fn = sf.mse_temporal_loss(weight=class_weights_) weighted_loss = weighted_loss_fn(spike_predicted_, targets_labels_) # unreduced, unweighted loss - vanilla_loss_fn = sf.mse_temporal_loss(reduction='none') + vanilla_loss_fn = sf.mse_temporal_loss(reduction="none") vanilla_loss = vanilla_loss_fn(spike_predicted_, targets_labels_) # weight multiplier weight_multiplier = class_weights_[targets_labels_] # expectation - expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean()) + expected_weighted_loss = (vanilla_loss * weight_multiplier).mean() - assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item()) + assert_approximate_equality( + weighted_loss.item(), expected_weighted_loss.item() + ) def test_ce_temporal_loss_base(self, spike_predicted_, targets_labels_): loss_fn = sf.ce_temporal_loss() assert loss_fn.weight is None - assert loss_fn.reduction == 'mean' + assert loss_fn.reduction == "mean" - def test_ce_temporal_loss_unreduced(self, spike_predicted_, targets_labels_): - unreduced_loss_fn = sf.ce_temporal_loss(reduction='none') + def test_ce_temporal_loss_unreduced( + self, spike_predicted_, targets_labels_ + ): + unreduced_loss_fn = sf.ce_temporal_loss(reduction="none") unreduced_loss = unreduced_loss_fn(spike_predicted_, targets_labels_) reduced_loss_fn = sf.ce_temporal_loss() reduced_loss = reduced_loss_fn(spike_predicted_, targets_labels_) - assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item()) + assert_approximate_equality( + unreduced_loss.mean().item(), reduced_loss.item() + ) - def test_ce_temporal_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_): + def test_ce_temporal_loss_weighted( + self, spike_predicted_, targets_labels_, class_weights_ + ): weighted_loss_fn = sf.ce_temporal_loss(weight=class_weights_) weighted_loss = weighted_loss_fn(spike_predicted_, targets_labels_) # unreduced, unweighted loss - vanilla_loss_fn = sf.ce_temporal_loss(reduction='none') + vanilla_loss_fn = sf.ce_temporal_loss(reduction="none") vanilla_loss = vanilla_loss_fn(spike_predicted_, targets_labels_) # weight multiplier weight_multiplier = class_weights_[targets_labels_] # expectation - expected_weighted_loss = ((vanilla_loss * weight_multiplier).sum() / weight_multiplier.sum()) + expected_weighted_loss = ( + vanilla_loss * weight_multiplier + ).sum() / weight_multiplier.sum() - assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item()) + assert_approximate_equality( + weighted_loss.item(), expected_weighted_loss.item() + ) diff --git a/tests/test_snntorch/test_alpha.py b/tests/test_snntorch/test_alpha.py index 29e59935..cc6f3dde 100644 --- a/tests/test_snntorch/test_alpha.py +++ b/tests/test_snntorch/test_alpha.py @@ -17,9 +17,12 @@ def input_(): def alpha_instance(): return snn.Alpha(alpha=0.6, beta=0.5, reset_mechanism="subtract") + @pytest.fixture(scope="module") def alpha_instance_surrogate(): - return snn.Alpha(alpha=0.6, beta=0.5, reset_mechanism="subtract", surrogate_disable=True) + return snn.Alpha( + alpha=0.6, beta=0.5, reset_mechanism="subtract", surrogate_disable=True + ) @pytest.fixture(scope="module") @@ -142,8 +145,7 @@ def test_alpha_cases(self, alpha_hidden_instance, input_): with pytest.raises(TypeError): alpha_hidden_instance(input_, input_) - def test_alpha_compile_fullgraph(self, alpha_instance_surrogate, input_): explanation = dynamo.explain(alpha_instance_surrogate)(input_[0]) - assert explanation.graph_break_count == 0 \ No newline at end of file + assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_bntt.py b/tests/test_snntorch/test_bntt.py index bc839541..97d22b92 100644 --- a/tests/test_snntorch/test_bntt.py +++ b/tests/test_snntorch/test_bntt.py @@ -31,12 +31,10 @@ def batchnormtt1d_instance(): class TestBatchNormTT1d: - @pytest.mark.parametrize("time_steps, num_features", ([1, 1], [3, 2], [6, 3])) - def test_batchnormtt1d_init( - self, - time_steps, - num_features - ): + @pytest.mark.parametrize( + "time_steps, num_features", ([1, 1], [3, 2], [6, 3]) + ) + def test_batchnormtt1d_init(self, time_steps, num_features): batchnormtt1d_instance = snn.BatchNormTT1d(num_features, time_steps) assert len(batchnormtt1d_instance) == time_steps @@ -48,9 +46,7 @@ def test_batchnormtt1d_init( assert module.bias is None def test_batchnormtt1d_with_2d_input( - self, - batchnormtt1d_instance, - input2d_ + self, batchnormtt1d_instance, input2d_ ): for step, batchnormtt1d_module in enumerate(batchnormtt1d_instance): out = batchnormtt1d_module(input2d_[step]) @@ -58,9 +54,7 @@ def test_batchnormtt1d_with_2d_input( assert out.shape == input2d_[step].shape def test_batchnormtt1d_with_3d_input( - self, - batchnormtt1d_instance, - input3d_ + self, batchnormtt1d_instance, input3d_ ): for step, batchnormtt1d_module in enumerate(batchnormtt1d_instance): out = batchnormtt1d_module(input3d_[step]) @@ -74,12 +68,10 @@ def batchnormtt2d_instance(): class TestBatchNormTT2d: - @pytest.mark.parametrize("time_steps, num_features", ([1, 1], [3, 2], [6, 3])) - def test_batchnormtt2d_init( - self, - time_steps, - num_features - ): + @pytest.mark.parametrize( + "time_steps, num_features", ([1, 1], [3, 2], [6, 3]) + ) + def test_batchnormtt2d_init(self, time_steps, num_features): batchnormtt2d_instance = snn.BatchNormTT2d(num_features, time_steps) assert len(batchnormtt2d_instance) == time_steps @@ -91,9 +83,7 @@ def test_batchnormtt2d_init( assert module.bias is None def test_batchnormtt2d_with_4d_input( - self, - batchnormtt2d_instance, - input4d_ + self, batchnormtt2d_instance, input4d_ ): for step, batchnormtt2d_module in enumerate(batchnormtt2d_instance): out = batchnormtt2d_module(input4d_[step]) diff --git a/tests/test_snntorch/test_lapicque.py b/tests/test_snntorch/test_lapicque.py index f76dd232..a72e49dc 100644 --- a/tests/test_snntorch/test_lapicque.py +++ b/tests/test_snntorch/test_lapicque.py @@ -135,7 +135,9 @@ def test_lapicque_cases(self, lapicque_hidden_instance, input_): with pytest.raises(TypeError): lapicque_hidden_instance(input_, input_) - def test_lapicque_compile_fullgraph(self, lapicque_instance_surrogate, input_): + def test_lapicque_compile_fullgraph( + self, lapicque_instance_surrogate, input_ + ): explanation = dynamo.explain(lapicque_instance_surrogate)(input_[0]) assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_rleaky.py b/tests/test_snntorch/test_rleaky.py index 617488d4..de315e94 100644 --- a/tests/test_snntorch/test_rleaky.py +++ b/tests/test_snntorch/test_rleaky.py @@ -142,9 +142,7 @@ def test_lreaky_cases(self, rleaky_hidden_instance, input_): with pytest.raises(TypeError): rleaky_hidden_instance(input_, input_, input_) - def test_rleaky_compile_fullgraph( - self, rleaky_instance_surrogate, input_ - ): + def test_rleaky_compile_fullgraph(self, rleaky_instance_surrogate, input_): explanation = dynamo.explain(rleaky_instance_surrogate)(input_[0]) assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_slstm.py b/tests/test_snntorch/test_slstm.py index 8aa46b79..fd1ab5fd 100644 --- a/tests/test_snntorch/test_slstm.py +++ b/tests/test_snntorch/test_slstm.py @@ -127,9 +127,7 @@ def test_slstm_init_hidden_reset_subtract( assert spk_rec[0].size() == (1, 2) - def test_slstm_compile_fullgraph( - self, slstm_instance_surrogate, input_ - ): + def test_slstm_compile_fullgraph(self, slstm_instance_surrogate, input_): explanation = dynamo.explain(slstm_instance_surrogate)(input_[0]) assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_synaptic.py b/tests/test_snntorch/test_synaptic.py index 262f5ca8..8f6a8cbb 100644 --- a/tests/test_snntorch/test_synaptic.py +++ b/tests/test_snntorch/test_synaptic.py @@ -17,6 +17,7 @@ def input_(): def synaptic_instance(): return snn.Synaptic(alpha=0.5, beta=0.5) + @pytest.fixture(scope="module") def synaptic_instance_surrogate(): return snn.Synaptic(alpha=0.5, beta=0.5, surrogate_disable=True) @@ -129,7 +130,9 @@ def test_synaptic_cases(self, synaptic_hidden_instance, input_): with pytest.raises(TypeError): synaptic_hidden_instance(input_, input_) - def test_synaptic_compile_fullgraph(self, synaptic_instance_surrogate, input_): + def test_synaptic_compile_fullgraph( + self, synaptic_instance_surrogate, input_ + ): explanation = dynamo.explain(synaptic_instance_surrogate)(input_[0]) - assert explanation.graph_break_count == 0 \ No newline at end of file + assert explanation.graph_break_count == 0 diff --git a/tests/testconv2d+avgpool.nir b/tests/testconv2d+avgpool.nir new file mode 100644 index 00000000..3dbca46a Binary files /dev/null and b/tests/testconv2d+avgpool.nir differ