Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/djlouie/snntorch
Browse files Browse the repository at this point in the history
  • Loading branch information
djlouie committed Apr 11, 2024
2 parents 292f688 + bdc1b49 commit 764b7e5
Show file tree
Hide file tree
Showing 29 changed files with 748 additions and 962 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-tag.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
python-version: ['3.8', '3.9', '3.10', '3.11']

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion _version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# fmt: off
__version__ = '0.7.0'
__version__ = '0.8.1'
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# fmt: off
__version__ = '0.7.0'
__version__ = '0.8.1'
# fmt: on


Expand Down
24 changes: 12 additions & 12 deletions docs/tutorials/tutorial_2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ The neuron model is now stored in ``lif1``. To use this neuron:

**Inputs**

* ``spk_in``: each element of :math:`I_{\rm in}` is sequentially passed as an input (0 for now)
* ``cur_in``: each element of :math:`I_{\rm in}` is sequentially passed as an input (0 for now)
* ``mem``: the membrane potential, previously :math:`U[t]`, is also passed as input. Initialize it arbitrarily as :math:`U[0] = 0.9~V`.

**Outputs**
Expand All @@ -321,7 +321,7 @@ These all need to be of type ``torch.Tensor``.

# Initialize membrane, input, and output
mem = torch.ones(1) * 0.9 # U=0.9 at t=0
cur_in = torch.zeros(num_steps) # I=0 for all t
cur_in = torch.zeros(num_steps, 1) # I=0 for all t
spk_out = torch.zeros(1) # initialize output spikes

These values are only for the initial time step :math:`t=0`.
Expand Down Expand Up @@ -382,7 +382,7 @@ Let’s visualize what this looks like by triggering a current pulse of
::

# Initialize input current pulse
cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.1), 0) # input current turns on at t=10
cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.1), 0) # input current turns on at t=10
# Initialize membrane, output and recordings
mem = torch.zeros(1) # membrane potential of 0 at t=0
Expand Down Expand Up @@ -430,7 +430,7 @@ Now what if the step input was clipped at :math:`t=30ms`?
::

# Initialize current pulse, membrane and outputs
cur_in1 = torch.cat((torch.zeros(10), torch.ones(20)*(0.1), torch.zeros(170)), 0) # input turns on at t=10, off at t=30
cur_in1 = torch.cat((torch.zeros(10, 1), torch.ones(20, 1)*(0.1), torch.zeros(170, 1)), 0) # input turns on at t=10, off at t=30
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec1 = [mem]
Expand Down Expand Up @@ -462,7 +462,7 @@ time window must be decreased.
::

# Increase amplitude of current pulse; half the time.
cur_in2 = torch.cat((torch.zeros(10), torch.ones(10)*0.111, torch.zeros(180)), 0) # input turns on at t=10, off at t=20
cur_in2 = torch.cat((torch.zeros(10, 1), torch.ones(10, 1)*0.111, torch.zeros(180, 1)), 0) # input turns on at t=10, off at t=20
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec2 = [mem]
Expand All @@ -487,7 +487,7 @@ amplitude:
::

# Increase amplitude of current pulse; quarter the time.
cur_in3 = torch.cat((torch.zeros(10), torch.ones(5)*0.147, torch.zeros(185)), 0) # input turns on at t=10, off at t=15
cur_in3 = torch.cat((torch.zeros(10, 1), torch.ones(5, 1)*0.147, torch.zeros(185, 1)), 0) # input turns on at t=10, off at t=15
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec3 = [mem]
Expand Down Expand Up @@ -526,7 +526,7 @@ membrane potential will jump straight up in virtually zero rise time:
::

# Current spike input
cur_in4 = torch.cat((torch.zeros(10), torch.ones(1)*0.5, torch.zeros(189)), 0) # input only on for 1 time step
cur_in4 = torch.cat((torch.zeros(10, 1), torch.ones(1, 1)*0.5, torch.zeros(189, 1)), 0) # input only on for 1 time step
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec4 = [mem]
Expand Down Expand Up @@ -685,7 +685,7 @@ As before, all of that code is condensed by calling the built-in Lapicque neuron
::

# Initialize inputs and outputs
cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.2), 0)
cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.2), 0)
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec = [mem]
Expand Down Expand Up @@ -732,7 +732,7 @@ approaches the threshold :math:`U_{\rm thr}` faster:
::

# Initialize inputs and outputs
cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) # increased current
cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0) # increased current
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec = [mem]
Expand Down Expand Up @@ -766,7 +766,7 @@ rest of the code block is the exact same as above:
lif3 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5)
# Initialize inputs and outputs
cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0)
cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0)
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec = [mem]
Expand Down Expand Up @@ -806,7 +806,7 @@ generated input spikes.
::

# Create a 1-D random spike train. Each element has a probability of 40% of firing.
spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)
spk_in = spikegen.rate_conv(torch.ones((num_steps,1)) * 0.40)

Run the following code block to see how many spikes have been generated.

Expand Down Expand Up @@ -889,7 +889,7 @@ This can be explicitly overridden by passing the argument
lif4 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5, reset_mechanism="zero")
# Initialize inputs and outputs
spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)
spk_in = spikegen.rate_conv(torch.ones((num_steps, 1)) * 0.40)
mem = torch.ones(1)*0.5
spk_out = torch.zeros(1)
mem_rec0 = [mem]
Expand Down
24 changes: 12 additions & 12 deletions docs/tutorials/tutorial_exoplanet_hunter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,20 @@ Before diving into the code, let's gain an understanding of what Exoplanet Detec
The transit method is a widely used and successful technique for
detecting exoplanets. When an exoplanet transits its host star, it
causes a temporary reduction in the star's light flux (brightness).
Compared to other techniques, the transmit method has has discovered
Compared to other techniques, the transit method has discovered
the largest number of planets.

Astronomers use telescopes equipped with photometers or
spectrophotometers to continuously monitor the brightness of a star over
time. Repeated observations of multiple transits allows astronomers to
time. Repeated observations of multiple transits allow astronomers to
gather more detailed information about the exoplanet, such as its
atmosphere and the presence of moons.

Space telescopes like NASA's Kepler and TESS (Transiting Exoplanet
Survey Satellite) have been instrumental in discovering thousands of
exoplanets using the transit method. Without the Earth's atmosphere in the way,
there is less interference and more precise measurements are possible.
The transit method continues to be a key tool in advancing our understanding of
exoplanets using the transit method. Without the Earth's atmosphere to hinder observations,
there is minimal interference, allowing for more precise measurements.
The transit method remains a key tool in furthering our comprehension of
exoplanetary systems. For more information about the transit method, you can
visit `NASA Exoplanet Exploration
Page <https://exoplanets.nasa.gov/alien-worlds/ways-to-find-a-planet/#/2>`__.
Expand All @@ -107,7 +107,7 @@ Page <https://exoplanets.nasa.gov/alien-worlds/ways-to-find-a-planet/#/2>`__.
The drawback of this method is that the angle between the planet's
orbital plane and the direction of the observer's line of sight must be
sufficiently small. Therefore, the chance of this phenomenon occurring is not
high. Thus more time and resources must be spent to detect and confirm
high. Thus, more time and resources must be allocated to detect and confirm
the existence of an exoplanet. These resources include the Kepler
telescope and ESA's CoRoT when they were still operational.

Expand Down Expand Up @@ -202,8 +202,8 @@ datasets <https://pytorch.org/tutorials/beginner/data_loading_tutorial.html>`__
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Given the low chance of detecting exoplanets, this dataset is very imbalanced.
Most samples are negative, i.e., there are very few exoplanets from the observed
light intensity data. If your model was to simply predict 'no exoplanet' for every sample,
Most samples are negative, meaning there are very few exoplanets from the observed
light intensity data. If your model were to simply predict 'no exoplanet' for every sample,
then it would achieve very high accuracy. This indicates that accuracy is a poor metric for success.

Let's first probe our data to gain insight into how imbalanced it is.
Expand Down Expand Up @@ -245,7 +245,7 @@ To deal with the imbalance of our dataset, let's Synthetic Minority
Over-Sampling Technique (SMOTE). SMOTE works by
generating synthetic samples from the minority class to balance the
distribution (typically implemented using the nearest neighbors
strategy). By implementing SMOTE, we attempt to reduce bias towards
strategy). By implementing SMOTE, we attempt to reduce bias toward
stars without exoplanets (the majority class).

.. code:: python
Expand Down Expand Up @@ -368,9 +368,9 @@ After loading the data, let's see what our data looks like.
The code block below follows the same syntax as with the `official
snnTorch
tutorial <https://snntorch.readthedocs.io/en/latest/tutorials/index.html>`__.
In contrast to other tutorials however, this model passes data across the entire sequence in parallel.
In that sense, it is more akin to how attention-based mechanisms take data.
Turning this into a more 'online' method would likely involve pre-processing to downsample the exceedingly long sequence length.
In contrast to other tutorials, however, this model concurrently processes data across the entire sequence.
In that sense, it is more akin to how attention-based mechanisms handle data.
Turning this into a more 'online' method would likely involve preprocessing to downsample the exceedingly long sequence length.

.. code:: python
Expand Down
File renamed without changes.
File renamed without changes.
24 changes: 12 additions & 12 deletions examples/tutorial_2_lif_neuron.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@
"source": [
"# Initialize membrane, input, and output\n",
"mem = torch.ones(1) * 0.9 # U=0.9 at t=0\n",
"cur_in = torch.zeros(num_steps) # I=0 for all t \n",
"cur_in = torch.zeros(num_steps, 1) # I=0 for all t \n",
"spk_out = torch.zeros(1) # initialize output spikes"
]
},
Expand Down Expand Up @@ -688,7 +688,7 @@
"outputs": [],
"source": [
"# Initialize input current pulse\n",
"cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.1), 0) # input current turns on at t=10\n",
"cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.1), 0) # input current turns on at t=10\n",
"\n",
"# Initialize membrane, output and recordings\n",
"mem = torch.zeros(1) # membrane potential of 0 at t=0\n",
Expand Down Expand Up @@ -776,7 +776,7 @@
"outputs": [],
"source": [
"# Initialize current pulse, membrane and outputs\n",
"cur_in1 = torch.cat((torch.zeros(10), torch.ones(20)*(0.1), torch.zeros(170)), 0) # input turns on at t=10, off at t=30\n",
"cur_in1 = torch.cat((torch.zeros(10, 1), torch.ones(20, 1)*(0.1), torch.zeros(170, 1)), 0) # input turns on at t=10, off at t=30\n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1)\n",
"mem_rec1 = [mem]"
Expand Down Expand Up @@ -820,7 +820,7 @@
"outputs": [],
"source": [
"# Increase amplitude of current pulse; half the time.\n",
"cur_in2 = torch.cat((torch.zeros(10), torch.ones(10)*0.111, torch.zeros(180)), 0) # input turns on at t=10, off at t=20\n",
"cur_in2 = torch.cat((torch.zeros(10, 1), torch.ones(10, 1)*0.111, torch.zeros(180, 1)), 0) # input turns on at t=10, off at t=20\n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1)\n",
"mem_rec2 = [mem]\n",
Expand Down Expand Up @@ -853,7 +853,7 @@
"outputs": [],
"source": [
"# Increase amplitude of current pulse; quarter the time.\n",
"cur_in3 = torch.cat((torch.zeros(10), torch.ones(5)*0.147, torch.zeros(185)), 0) # input turns on at t=10, off at t=15\n",
"cur_in3 = torch.cat((torch.zeros(10, 1), torch.ones(5, 1)*0.147, torch.zeros(185, 1)), 0) # input turns on at t=10, off at t=15\n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1)\n",
"mem_rec3 = [mem]\n",
Expand Down Expand Up @@ -907,7 +907,7 @@
"outputs": [],
"source": [
"# Current spike input\n",
"cur_in4 = torch.cat((torch.zeros(10), torch.ones(1)*0.5, torch.zeros(189)), 0) # input only on for 1 time step\n",
"cur_in4 = torch.cat((torch.zeros(10, 1), torch.ones(1, 1)*0.5, torch.zeros(189, 1)), 0) # input only on for 1 time step\n",
"mem = torch.zeros(1) \n",
"spk_out = torch.zeros(1)\n",
"mem_rec4 = [mem]\n",
Expand Down Expand Up @@ -1120,7 +1120,7 @@
"outputs": [],
"source": [
"# Initialize inputs and outputs\n",
"cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.2), 0)\n",
"cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.2), 0)\n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1) \n",
"mem_rec = [mem]\n",
Expand Down Expand Up @@ -1180,7 +1180,7 @@
"outputs": [],
"source": [
"# Initialize inputs and outputs\n",
"cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) # increased current\n",
"cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0) # increased current\n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1) \n",
"mem_rec = [mem]\n",
Expand Down Expand Up @@ -1222,7 +1222,7 @@
"lif3 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5)\n",
"\n",
"# Initialize inputs and outputs\n",
"cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) \n",
"cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0) \n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1) \n",
"mem_rec = [mem]\n",
Expand Down Expand Up @@ -1278,7 +1278,7 @@
"outputs": [],
"source": [
"# Create a 1-D random spike train. Each element has a probability of 40% of firing.\n",
"spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)"
"spk_in = spikegen.rate_conv(torch.ones((num_steps,1)) * 0.40)"
]
},
{
Expand Down Expand Up @@ -1372,7 +1372,7 @@
"lif4 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5, reset_mechanism=\"zero\")\n",
"\n",
"# Initialize inputs and outputs\n",
"spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)\n",
"spk_in = spikegen.rate_conv(torch.ones((num_steps, 1)) * 0.40)\n",
"mem = torch.ones(1)*0.5\n",
"spk_out = torch.zeros(1)\n",
"mem_rec0 = [mem]\n",
Expand Down Expand Up @@ -1466,7 +1466,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.11.8"
},
"vscode": {
"interpreter": {
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.7.0
current_version = 0.8.1
commit = True
tag = True

Expand Down Expand Up @@ -31,3 +31,4 @@ test = pytest
[tool:pytest]
testpaths = tests
addopts = --ignore=setup.py

5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# history = history_file.read()

# fmt: off
__version__ = '0.7.0'
__version__ = '0.8.1'
# fmt: on

requirements = [
Expand All @@ -31,7 +31,7 @@
setup(
author="Jason K. Eshraghian",
author_email="[email protected]",
python_requires=">=3.7",
python_requires=">=3.8",
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
Expand All @@ -42,7 +42,6 @@
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
Expand Down
Loading

0 comments on commit 764b7e5

Please sign in to comment.