Skip to content

Commit

Permalink
Fix bug with integrate_odes() for numeric solver (nest#1147)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbabu authored Dec 16, 2024
1 parent c87b682 commit 0596e8e
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,17 @@ extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 %
}
{%- endif %}


{% set numeric_state_variables_to_be_integrated = numeric_state_variables + purely_numeric_state_variables_moved %}
{%- if ast.get_args() | length > 0 %}
{%- set numeric_state_variables_to_be_integrated = utils.filter_variables_list(numeric_state_variables_to_be_integrated, ast.get_args()) %}
{%- endif %}
{%- for variable_name in numeric_state_variables + numeric_state_variables_moved %}
{%- set update_expr = numeric_update_expressions[variable_name] %}
{%- set variable_symbol = variable_symbols[variable_name] %}
{%- if use_gap_junctions %}
f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in utils.integrate_odes_args_strs_from_function_call(ast) + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr)|replace("node.B_." + gap_junction_port + "_grid_sum_", "(node.B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %};
f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in numeric_state_variables_to_be_integrated + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr)|replace("node.B_." + gap_junction_port + "_grid_sum_", "(node.B_." + gap_junction_port + "_grid_sum_ + __I_gap)") }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %};
{%- else %}
f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in utils.integrate_odes_args_strs_from_function_call(ast) + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr) }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %};
f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in numeric_state_variables_to_be_integrated + utils.all_convolution_variable_names(astnode) %}{{ gsl_printer.print(update_expr) }}{% else %}0{% endif %}{% else %}{{ gsl_printer.print(update_expr) }}{% endif %};
{%- endif %}
{%- endfor %}

Expand Down
120 changes: 120 additions & 0 deletions tests/nest_tests/resources/aeif_cond_alpha_alt_neuron.nestml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
aeif_cond_alpha - Conductance based exponential integrate-and-fire neuron model
###############################################################################

Description
+++++++++++

aeif_psc_alpha is the adaptive exponential integrate and fire neuron according to Brette and Gerstner (2005), with post-synaptic conductances in the form of a bi-exponential ("alpha") function.

The membrane potential is given by the following differential equation:

.. math::

C_m \frac{dV_m}{dt} =
-g_L(V_m-E_L)+g_L\Delta_T\exp\left(\frac{V_m-V_{th}}{\Delta_T}\right) -
g_e(t)(V_m-E_e) \\
-g_i(t)(V_m-E_i)-w + I_e

and

.. math::

\tau_w \frac{dw}{dt} = a(V_m-E_L) - w

Note that the membrane potential can diverge to positive infinity due to the exponential term. To avoid numerical instabilities, instead of :math:`V_m`, the value :math:`\min(V_m,V_{peak})` is used in the dynamical equations.


References
++++++++++

.. [1] Brette R and Gerstner W (2005). Adaptive exponential
integrate-and-fire model as an effective description of neuronal
activity. Journal of Neurophysiology. 943637-3642
DOI: https://doi.org/10.1152/jn.00686.2005


See also
++++++++

iaf_psc_alpha, aeif_psc_exp
"""
model aeif_cond_alpha_alt_neuron:

state:
V_m mV = E_L # Membrane potential
w pA = 0 pA # Spike-adaptation current
refr_t ms = 0 ms # Refractory period timer
g_exc nS = 0 nS # AHP conductance
g_exc' nS/ms = 0 nS/ms # AHP conductance
g_inh nS = 0 nS # AHP conductance
g_inh' nS/ms = 0 nS/ms # AHP conductance

equations:
inline V_bounded mV = min(V_m, V_peak) # prevent exponential divergence

g_exc'' = -2 * g_exc' / tau_syn_exc - g_exc / tau_syn_exc**2
g_inh'' = -2 * g_inh' / tau_syn_inh - g_inh / tau_syn_inh**2

# Add inlines to simplify the equation definition of V_m
inline exp_arg real = (V_bounded - V_th) / Delta_T
inline I_spike pA = g_L * Delta_T * exp(exp_arg)

V_m' = (-g_L * (V_bounded - E_L) + I_spike - g_exc * (V_bounded - E_exc) - g_inh * (V_bounded - E_inh) - w + I_e + I_stim) / C_m
w' = (a * (V_bounded - E_L) - w) / tau_w

refr_t' = -1e3 * ms/s # refractoriness is implemented as an ODE, representing a timer counting back down to zero. XXX: TODO: This should simply read ``refr_t' = -1 / s`` (see https://github.com/nest/nestml/issues/984)

parameters:
# membrane parameters
C_m pF = 281.0 pF # Membrane Capacitance
refr_T ms = 2 ms # Duration of refractory period
V_reset mV = -60.0 mV # Reset Potential
g_L nS = 30.0 nS # Leak Conductance
E_L mV = -70.6 mV # Leak reversal Potential (aka resting potential)

# spike adaptation parameters
a nS = 4 nS # Subthreshold adaptation
b pA = 80.5 pA # Spike-triggered adaptation
Delta_T mV = 2.0 mV # Slope factor
tau_w ms = 144.0 ms # Adaptation time constant
V_th mV = -50.4 mV # Threshold Potential
V_peak mV = 0 mV # Spike detection threshold

# synaptic parameters
tau_syn_exc ms = 0.2 ms # Synaptic Time Constant Excitatory Synapse
tau_syn_inh ms = 2.0 ms # Synaptic Time Constant for Inhibitory Synapse
E_exc mV = 0 mV # Excitatory reversal Potential
E_inh mV = -85.0 mV # Inhibitory reversal Potential

# constant external input current
I_e pA = 0 pA

input:
exc_spikes <- excitatory spike
inh_spikes <- inhibitory spike
I_stim pA <- continuous

output:
spike

update:
if refr_t > 0 ms:
# neuron is absolute refractory, do not evolve V_m
integrate_odes(g_exc, g_inh, w, refr_t)
else:
# neuron not refractory
integrate_odes(g_exc, g_inh, V_m, w)

onReceive(exc_spikes):
g_exc' += exc_spikes * (e / tau_syn_exc) * nS * s

onReceive(inh_spikes):
g_inh' += inh_spikes * (e / tau_syn_inh) * nS * s

onCondition(refr_t <= 0 ms and V_m >= V_th):
# threshold crossing
refr_t = refr_T # start of the refractory period
V_m = V_reset
w += b
emit_spike()
85 changes: 73 additions & 12 deletions tests/nest_tests/test_integrate_odes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@

try:
import matplotlib

matplotlib.use("Agg")
import matplotlib.ticker
import matplotlib.pyplot as plt

TEST_PLOTS = True
except Exception:
TEST_PLOTS = False
Expand All @@ -50,22 +52,24 @@ def setUp(self):
r"""Generate the model code"""

generate_nest_target(input_path=[os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))),
os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "integrate_odes_nonlinear_test.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))),],
os.path.join("resources", "alpha_function_2nd_order_ode_neuron.nestml"))),
os.path.realpath(os.path.join(os.path.dirname(__file__),
os.path.join("resources", "aeif_cond_alpha_alt_neuron.nestml")))],
logging_level="INFO",
suffix="_nestml")

def test_convolutions_always_integrated(self):
r"""Test that synaptic integration continues for iaf_psc_exp, even when neuron is refractory."""

sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]
spike_interval = 5. # [ms]
sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]
spike_interval = 5. # [ms]

nest.set_verbosity("M_ALL")
nest.ResetKernel()
Expand Down Expand Up @@ -120,8 +124,8 @@ def test_convolutions_always_integrated(self):
def test_integrate_odes(self):
r"""Test the integrate_odes() function, in particular when not all the ODEs are being integrated."""

sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]
sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]

nest.set_verbosity("M_ALL")
nest.ResetKernel()
Expand Down Expand Up @@ -169,8 +173,8 @@ def test_integrate_odes(self):
def test_integrate_odes_nonlinear(self):
r"""Test the integrate_odes() function, in particular when not all the ODEs are being integrated, for nonlinear ODEs."""

sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]
sim_time: float = 100. # [ms]
resolution: float = .1 # [ms]

nest.set_verbosity("M_ALL")
nest.ResetKernel()
Expand Down Expand Up @@ -205,7 +209,6 @@ def test_integrate_odes_nonlinear(self):
for _ax in ax:
_ax.grid(which="major", axis="both")
_ax.grid(which="minor", axis="x", linestyle=":", alpha=.4)
# _ax.minorticks_on()
_ax.set_xlim(0., sim_time)
_ax.legend()

Expand All @@ -218,15 +221,17 @@ def test_integrate_odes_nonlinear(self):
def test_integrate_odes_params(self):
r"""Test the integrate_odes() function, in particular with respect to the parameter types."""

fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml")))
fname = os.path.realpath(
os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml")))
generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")

assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2

def test_integrate_odes_params2(self):
r"""Test the integrate_odes() function, in particular with respect to non-existent parameter variables."""

fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params2.nestml")))
fname = os.path.realpath(
os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params2.nestml")))
generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")

assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2
Expand Down Expand Up @@ -276,3 +281,59 @@ def test_integrate_odes_higher_order(self):
# verify
np.testing.assert_allclose(x_actual[-1], 0.10737970490959549)
np.testing.assert_allclose(y_actual[-1], 0.6211608596446752)

def test_integrate_odes_numeric_higher_order(self):
r"""
Tests for higher-order ODEs of the form F(x'',x',x)=0, integrate_odes(x) integrates the full dynamics of x with a numeric solver.
"""
resolution = 0.1
simtime = 800.
params_nestml = {"V_peak": 0.0, "a": 4.0, "b": 80.5, "E_L": -70.6,
"g_L": 300.0, 'E_exc': 20.0, 'E_inh': -85.0,
'tau_syn_exc': 40.0, 'tau_syn_inh': 20.0}

params_nest = {"V_peak": 0.0, "a": 4.0, "b": 80.5, "E_L": -70.6,
"g_L": 300.0, 'E_ex': 20.0, 'E_in': -85.0,
'tau_syn_ex': 40.0, 'tau_syn_in': 20.0}

for model in ["aeif_cond_alpha_alt_neuron_nestml", "aeif_cond_alpha"]:
nest.set_verbosity("M_ALL")
nest.ResetKernel()
nest.SetKernelStatus({"resolution": resolution})
try:
nest.Install("nestmlmodule")
except Exception:
# ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
pass
n = nest.Create(model)
if "_nestml" in model:
nest.SetStatus(n, params_nestml)
else:
nest.SetStatus(n, params_nest)

spike = nest.Create("spike_generator")
spike_times = [10.0, 400.0]
nest.SetStatus(spike, {"spike_times": spike_times})
nest.Connect(spike, n, syn_spec={"weight": 0.1, "delay": 1.0})
nest.Connect(spike, n, syn_spec={"weight": -0.2, "delay": 100.})

mm = nest.Create("multimeter", params={"record_from": ["V_m"]})
nest.Connect(mm, n)

nest.Simulate(simtime)
times = mm.get()["events"]["times"]
if "_nestml" in model:
v_m_nestml = mm.get()["events"]["V_m"]
else:
v_m_nest = mm.get()["events"]["V_m"]

if TEST_PLOTS:
fig, ax = plt.subplots(nrows=1)

ax.plot(times, v_m_nestml, label="NESTML")
ax.plot(times, v_m_nest, label="NEST")
ax.legend()

fig.savefig("/tmp/test_integrate_odes_numeric_higher_order.png")

np.testing.assert_allclose(v_m_nestml, v_m_nest)

0 comments on commit 0596e8e

Please sign in to comment.