Skip to content

Commit

Permalink
Merge pull request #308 from Gabo-Tor/fix_tutorial_2
Browse files Browse the repository at this point in the history
Fix tutorial 2
  • Loading branch information
jeshraghian authored Apr 2, 2024
2 parents a1d97b7 + 25f4389 commit bdc1b49
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
24 changes: 12 additions & 12 deletions docs/tutorials/tutorial_2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ The neuron model is now stored in ``lif1``. To use this neuron:

**Inputs**

* ``spk_in``: each element of :math:`I_{\rm in}` is sequentially passed as an input (0 for now)
* ``cur_in``: each element of :math:`I_{\rm in}` is sequentially passed as an input (0 for now)
* ``mem``: the membrane potential, previously :math:`U[t]`, is also passed as input. Initialize it arbitrarily as :math:`U[0] = 0.9~V`.

**Outputs**
Expand All @@ -321,7 +321,7 @@ These all need to be of type ``torch.Tensor``.

# Initialize membrane, input, and output
mem = torch.ones(1) * 0.9 # U=0.9 at t=0
cur_in = torch.zeros(num_steps) # I=0 for all t
cur_in = torch.zeros(num_steps, 1) # I=0 for all t
spk_out = torch.zeros(1) # initialize output spikes

These values are only for the initial time step :math:`t=0`.
Expand Down Expand Up @@ -382,7 +382,7 @@ Let’s visualize what this looks like by triggering a current pulse of
::

# Initialize input current pulse
cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.1), 0) # input current turns on at t=10
cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.1), 0) # input current turns on at t=10
# Initialize membrane, output and recordings
mem = torch.zeros(1) # membrane potential of 0 at t=0
Expand Down Expand Up @@ -430,7 +430,7 @@ Now what if the step input was clipped at :math:`t=30ms`?
::

# Initialize current pulse, membrane and outputs
cur_in1 = torch.cat((torch.zeros(10), torch.ones(20)*(0.1), torch.zeros(170)), 0) # input turns on at t=10, off at t=30
cur_in1 = torch.cat((torch.zeros(10, 1), torch.ones(20, 1)*(0.1), torch.zeros(170, 1)), 0) # input turns on at t=10, off at t=30
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec1 = [mem]
Expand Down Expand Up @@ -462,7 +462,7 @@ time window must be decreased.
::

# Increase amplitude of current pulse; half the time.
cur_in2 = torch.cat((torch.zeros(10), torch.ones(10)*0.111, torch.zeros(180)), 0) # input turns on at t=10, off at t=20
cur_in2 = torch.cat((torch.zeros(10, 1), torch.ones(10, 1)*0.111, torch.zeros(180, 1)), 0) # input turns on at t=10, off at t=20
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec2 = [mem]
Expand All @@ -487,7 +487,7 @@ amplitude:
::

# Increase amplitude of current pulse; quarter the time.
cur_in3 = torch.cat((torch.zeros(10), torch.ones(5)*0.147, torch.zeros(185)), 0) # input turns on at t=10, off at t=15
cur_in3 = torch.cat((torch.zeros(10, 1), torch.ones(5, 1)*0.147, torch.zeros(185, 1)), 0) # input turns on at t=10, off at t=15
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec3 = [mem]
Expand Down Expand Up @@ -526,7 +526,7 @@ membrane potential will jump straight up in virtually zero rise time:
::

# Current spike input
cur_in4 = torch.cat((torch.zeros(10), torch.ones(1)*0.5, torch.zeros(189)), 0) # input only on for 1 time step
cur_in4 = torch.cat((torch.zeros(10, 1), torch.ones(1, 1)*0.5, torch.zeros(189, 1)), 0) # input only on for 1 time step
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec4 = [mem]
Expand Down Expand Up @@ -685,7 +685,7 @@ As before, all of that code is condensed by calling the built-in Lapicque neuron
::

# Initialize inputs and outputs
cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.2), 0)
cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.2), 0)
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec = [mem]
Expand Down Expand Up @@ -732,7 +732,7 @@ approaches the threshold :math:`U_{\rm thr}` faster:
::

# Initialize inputs and outputs
cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) # increased current
cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0) # increased current
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec = [mem]
Expand Down Expand Up @@ -766,7 +766,7 @@ rest of the code block is the exact same as above:
lif3 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5)
# Initialize inputs and outputs
cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0)
cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0)
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec = [mem]
Expand Down Expand Up @@ -806,7 +806,7 @@ generated input spikes.
::

# Create a 1-D random spike train. Each element has a probability of 40% of firing.
spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)
spk_in = spikegen.rate_conv(torch.ones((num_steps,1)) * 0.40)

Run the following code block to see how many spikes have been generated.

Expand Down Expand Up @@ -889,7 +889,7 @@ This can be explicitly overridden by passing the argument
lif4 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5, reset_mechanism="zero")
# Initialize inputs and outputs
spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)
spk_in = spikegen.rate_conv(torch.ones((num_steps, 1)) * 0.40)
mem = torch.ones(1)*0.5
spk_out = torch.zeros(1)
mem_rec0 = [mem]
Expand Down
24 changes: 12 additions & 12 deletions examples/tutorial_2_lif_neuron.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@
"source": [
"# Initialize membrane, input, and output\n",
"mem = torch.ones(1) * 0.9 # U=0.9 at t=0\n",
"cur_in = torch.zeros(num_steps) # I=0 for all t \n",
"cur_in = torch.zeros(num_steps, 1) # I=0 for all t \n",
"spk_out = torch.zeros(1) # initialize output spikes"
]
},
Expand Down Expand Up @@ -688,7 +688,7 @@
"outputs": [],
"source": [
"# Initialize input current pulse\n",
"cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.1), 0) # input current turns on at t=10\n",
"cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.1), 0) # input current turns on at t=10\n",
"\n",
"# Initialize membrane, output and recordings\n",
"mem = torch.zeros(1) # membrane potential of 0 at t=0\n",
Expand Down Expand Up @@ -776,7 +776,7 @@
"outputs": [],
"source": [
"# Initialize current pulse, membrane and outputs\n",
"cur_in1 = torch.cat((torch.zeros(10), torch.ones(20)*(0.1), torch.zeros(170)), 0) # input turns on at t=10, off at t=30\n",
"cur_in1 = torch.cat((torch.zeros(10, 1), torch.ones(20, 1)*(0.1), torch.zeros(170, 1)), 0) # input turns on at t=10, off at t=30\n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1)\n",
"mem_rec1 = [mem]"
Expand Down Expand Up @@ -820,7 +820,7 @@
"outputs": [],
"source": [
"# Increase amplitude of current pulse; half the time.\n",
"cur_in2 = torch.cat((torch.zeros(10), torch.ones(10)*0.111, torch.zeros(180)), 0) # input turns on at t=10, off at t=20\n",
"cur_in2 = torch.cat((torch.zeros(10, 1), torch.ones(10, 1)*0.111, torch.zeros(180, 1)), 0) # input turns on at t=10, off at t=20\n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1)\n",
"mem_rec2 = [mem]\n",
Expand Down Expand Up @@ -853,7 +853,7 @@
"outputs": [],
"source": [
"# Increase amplitude of current pulse; quarter the time.\n",
"cur_in3 = torch.cat((torch.zeros(10), torch.ones(5)*0.147, torch.zeros(185)), 0) # input turns on at t=10, off at t=15\n",
"cur_in3 = torch.cat((torch.zeros(10, 1), torch.ones(5, 1)*0.147, torch.zeros(185, 1)), 0) # input turns on at t=10, off at t=15\n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1)\n",
"mem_rec3 = [mem]\n",
Expand Down Expand Up @@ -907,7 +907,7 @@
"outputs": [],
"source": [
"# Current spike input\n",
"cur_in4 = torch.cat((torch.zeros(10), torch.ones(1)*0.5, torch.zeros(189)), 0) # input only on for 1 time step\n",
"cur_in4 = torch.cat((torch.zeros(10, 1), torch.ones(1, 1)*0.5, torch.zeros(189, 1)), 0) # input only on for 1 time step\n",
"mem = torch.zeros(1) \n",
"spk_out = torch.zeros(1)\n",
"mem_rec4 = [mem]\n",
Expand Down Expand Up @@ -1120,7 +1120,7 @@
"outputs": [],
"source": [
"# Initialize inputs and outputs\n",
"cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.2), 0)\n",
"cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.2), 0)\n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1) \n",
"mem_rec = [mem]\n",
Expand Down Expand Up @@ -1180,7 +1180,7 @@
"outputs": [],
"source": [
"# Initialize inputs and outputs\n",
"cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) # increased current\n",
"cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0) # increased current\n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1) \n",
"mem_rec = [mem]\n",
Expand Down Expand Up @@ -1222,7 +1222,7 @@
"lif3 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5)\n",
"\n",
"# Initialize inputs and outputs\n",
"cur_in = torch.cat((torch.zeros(10), torch.ones(190)*0.3), 0) \n",
"cur_in = torch.cat((torch.zeros(10, 1), torch.ones(190, 1)*0.3), 0) \n",
"mem = torch.zeros(1)\n",
"spk_out = torch.zeros(1) \n",
"mem_rec = [mem]\n",
Expand Down Expand Up @@ -1278,7 +1278,7 @@
"outputs": [],
"source": [
"# Create a 1-D random spike train. Each element has a probability of 40% of firing.\n",
"spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)"
"spk_in = spikegen.rate_conv(torch.ones((num_steps,1)) * 0.40)"
]
},
{
Expand Down Expand Up @@ -1372,7 +1372,7 @@
"lif4 = snn.Lapicque(R=5.1, C=5e-3, time_step=1e-3, threshold=0.5, reset_mechanism=\"zero\")\n",
"\n",
"# Initialize inputs and outputs\n",
"spk_in = spikegen.rate_conv(torch.ones((num_steps)) * 0.40)\n",
"spk_in = spikegen.rate_conv(torch.ones((num_steps, 1)) * 0.40)\n",
"mem = torch.ones(1)*0.5\n",
"spk_out = torch.zeros(1)\n",
"mem_rec0 = [mem]\n",
Expand Down Expand Up @@ -1466,7 +1466,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.11.8"
},
"vscode": {
"interpreter": {
Expand Down

0 comments on commit bdc1b49

Please sign in to comment.